#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
  Copyright © 2008-2012 Joel Schaerer
  Copyright © 2012-2014 William Ting

  * This program is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; either version 3, or (at your option)
  any later version.

  * This program is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.

  * You should have received a copy of the GNU General Public License
  along with this program; if not, write to the Free Software Foundation, Inc.,
  51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
"""

from __future__ import print_function

from argparse import ArgumentParser
from difflib import SequenceMatcher
from itertools import chain
from math import sqrt
from operator import attrgetter
from operator import itemgetter
import os
import re
import sys

# Python 2/3 compatibility shim: expose the same names (ifilter, imap,
# os.getcwdu) on both major versions so the rest of the module can use them
# unconditionally.
if sys.version_info[0] == 3:
    # On Python 3 the builtins are used under the itertools names.
    ifilter = filter
    imap = map
    # os.getcwdu does not exist on Python 3; alias it to os.getcwd.
    os.getcwdu = os.getcwd
else:
    from itertools import ifilter
    from itertools import imap

from autojump_data import dictify
from autojump_data import entriefy
from autojump_data import Entry
from autojump_data import load
from autojump_data import save
from autojump_utils import decode
from autojump_utils import encode_local
from autojump_utils import first
from autojump_utils import get_tab_needle_and_path
from autojump_utils import get_pwd
from autojump_utils import has_uppercase
from autojump_utils import is_osx
from autojump_utils import last
from autojump_utils import print_entry
from autojump_utils import print_tab_menu
from autojump_utils import sanitize
from autojump_utils import take

# Version string reported by the -v/--version command line option.
VERSION = '22.0.0-alpha'
# Minimum SequenceMatcher ratio an entry needs to be accepted by match_fuzzy().
FUZZY_MATCH_THRESHOLD = 0.6
# Maximum number of matches handed to the tab completion menu.
TAB_ENTRIES_COUNT = 9
# Separator passed to get_tab_needle_and_path() and print_tab_menu().
TAB_SEPARATOR = '__'


def set_defaults():
    """
    Build the configuration dictionary holding the data file locations.

    The data directory is `~/Library/autojump` when is_osx() is true.
    Otherwise it is the value of the XDG_DATA_HOME environment variable,
    falling back to `~/.local/share/autojump` when that variable is unset.

    NOTE: when XDG_DATA_HOME is set its value is used as-is, i.e. no
    `autojump` sub-directory is appended to it.

    Returns a dict with the keys 'data_path', 'backup_path' and 'tmp_path'.
    """
    home = os.path.expanduser('~')

    if is_osx():
        data_home = os.path.join(home, 'Library', 'autojump')
    else:
        fallback = os.path.join(home, '.local', 'share', 'autojump')
        data_home = os.getenv('XDG_DATA_HOME', fallback)

    return {
        'data_path': os.path.join(data_home, 'autojump.txt'),
        'backup_path': os.path.join(data_home, 'autojump.txt.bak'),
        'tmp_path': os.path.join(data_home, 'data.tmp'),
    }


def parse_arguments(argv=None):
    """
    Parse the command line and return the resulting argparse namespace.

    argv is an optional list of argument strings. When it is None (the
    default) argparse falls back to sys.argv[1:], which is the previous
    behaviour of this function.

    The namespace has the attributes: directory, add, increase, decrease,
    complete, purge and stat. `-v/--version` prints the version and exits.
    """
    parser = ArgumentParser(
            # a single literal: a backslash continuation inside the string
            # would embed the source indentation into the description
            description='Automatically jump to directory passed as an '
                        'argument.',
            epilog="Please see autojump(1) man pages for full documentation.")
    parser.add_argument(
            'directory', metavar='DIRECTORY', nargs='*', default='',
            help='directory to jump to')
    parser.add_argument(
            '-a', '--add', metavar='DIRECTORY',
            help='add path')
    # -i / -d take an optional weight: `const` is used when the flag is given
    # without a value, `default` (False) when the flag is absent.
    parser.add_argument(
            '-i', '--increase', metavar='WEIGHT', nargs='?', type=int,
            const=10, default=False,
            help='increase current directory weight')
    parser.add_argument(
            '-d', '--decrease', metavar='WEIGHT', nargs='?', type=int,
            const=15, default=False,
            help='decrease current directory weight')
    parser.add_argument(
            '--complete', action="store_true", default=False,
            help='used for tab completion')
    parser.add_argument(
            '--purge', action="store_true", default=False,
            help='remove non-existent paths from database')
    parser.add_argument(
            '-s', '--stat', action="store_true", default=False,
            help='show database entries and their key weights')
    parser.add_argument(
            '-v', '--version', action="version", version="%(prog)s v" +
            VERSION, help='show version information')

    return parser.parse_args(argv)


def add_path(data, path, weight=10):
    """
    Insert `path` into `data` or raise the weight of an existing entry.

    The path is decoded and stripped of trailing separators. The new weight
    is sqrt(old_weight ** 2 + weight ** 2), with a missing entry counting
    as 0. The user's home directory is never stored: for it `data` is
    returned untouched together with a zero-weight Entry.

    os.path.realpath() is deliberately not applied: keeping symlinked paths
    (and the duplicate entries that may result) is preferred over collapsing
    everything to a single canonical path.

    Returns a (data, Entry) tuple; `data` is modified in place.
    """
    key = decode(path).rstrip(os.sep)

    if key != os.path.expanduser('~'):
        previous = data.get(key, 0)
        data[key] = sqrt((previous ** 2) + (weight ** 2))
        return data, Entry(key, data[key])

    return data, Entry(key, 0)


def decrease_path(data, path, weight=15):
    """
    Lower the weight of `path` by `weight`, never going below zero.

    The path is decoded and stripped of trailing separators; a path that is
    not yet in `data` is treated as having weight 0 and is stored with
    weight 0. Returns a (data, Entry) tuple; `data` is modified in place.
    """
    key = decode(path).rstrip(os.sep)
    remaining = data.get(key, 0) - weight
    data[key] = max(0, remaining)
    return data, Entry(key, data[key])


def detect_smartcase(needles):
    """
    Decide whether searching should ignore case.

    Returns False (case sensitive) as soon as has_uppercase() is true for
    one of the needles, and True (case insensitive) otherwise.
    """
    for needle in needles:
        if has_uppercase(needle):
            return False
    return True


def find_matches(entries, needles):
    """
    Return an iterator over the entries matching `needles`.

    `needles` is a sequence of strings. Entries whose path equals the
    current working directory are skipped, the rest are sorted by
    descending weight and run through match_consecutive(), match_fuzzy()
    and match_anywhere() in that order. Results whose path no longer exists
    on disk are filtered out. Entry('.', 0) is appended as a last candidate
    so calling shell functions have an argument to `cd` to.
    """
    # Look the working directory up once, inside the try block: os.getcwdu()
    # raises OSError when the current directory has been removed. Doing the
    # lookup lazily inside the filter predicate would raise outside of the
    # try block (and hit the filesystem once per entry).
    try:
        pwd = os.getcwdu()
    except OSError:
        # no usable working directory, so there is nothing to exclude
        pwd = None

    not_cwd = lambda entry: entry.path != pwd

    data = sorted(
            ifilter(not_cwd, entries),
            key=attrgetter('weight'),
            reverse=True)

    ignore_case = detect_smartcase(needles)

    exists = lambda entry: os.path.exists(entry.path)
    return ifilter(
            exists,
            chain(
                match_consecutive(needles, data, ignore_case),
                match_fuzzy(needles, data, ignore_case),
                match_anywhere(needles, data, ignore_case),
                # default return value so calling shell functions have an
                # argument to `cd` to
                [Entry('.', 0)]))


def handle_tab_completion(needle, entries):
    """
    Print the tab completion result for `needle`.

    Exits with status 0 when `needle` is empty. Otherwise `needle` is split
    with get_tab_needle_and_path(): when that yields a path, the path is
    printed; if not, a menu of at most TAB_ENTRIES_COUNT matches is printed
    for the tab needle, or for the raw needle when no tab needle was found.
    """
    if not needle:
        sys.exit(0)

    tab_needle, path = get_tab_needle_and_path(needle, TAB_SEPARATOR)

    if path:
        # found complete tab completion entry
        print(encode_local(path))
        return

    # a partial tab completion entry takes precedence over the raw needle
    menu_needle = tab_needle if tab_needle else needle

    # find_matches() expects a sequence of needles. Passing the bare string
    # would make the matchers treat every character as a separate needle.
    print_tab_menu(
            menu_needle,
            take(TAB_ENTRIES_COUNT, find_matches(entries, [menu_needle])),
            TAB_SEPARATOR)



def match_anywhere(needles, haystack, ignore_case=False):
    """
    Matches needles anywhere in the path as long as they're in the same (but
    not necessary consecutive) order.

    Needles are matched literally: regular expression metacharacters in a
    needle (e.g. 'c++' or 'a.b') have no special meaning.

    `haystack` is an iterable of objects with a `path` attribute. Returns an
    iterator over the matching entries, in haystack order.

    For example:
        needles = ['foo', 'baz']
        regex needle = r'.*foo.*baz.*'
        haystack = [
            (path="/foo/bar/baz", weight=10),
            (path="/baz/foo/bar", weight=10),
            (path="/foo/baz", weight=10)]

        result = [
            (path="/foo/bar/baz", weight=10),
            (path="/foo/baz", weight=10)]
    """
    # escape the needles so user input is never interpreted as a pattern
    regex_needle = '.*' + '.*'.join(imap(re.escape, needles)) + '.*'
    regex_flags = re.IGNORECASE | re.UNICODE if ignore_case else re.UNICODE
    pattern = re.compile(regex_needle, regex_flags)

    found = lambda entry: pattern.search(entry.path)
    return ifilter(found, haystack)


def match_consecutive(needles, haystack, ignore_case=False):
    """
    Matches consecutive needles at the end of a path.

    Each needle has to appear in its own path component, the components have
    to be adjacent and the last one has to be the final component of the
    path. Needles are matched literally: regular expression metacharacters
    in a needle (e.g. 'c++' or 'a.b') have no special meaning.

    `haystack` is an iterable of objects with a `path` attribute. Returns an
    iterator over the matching entries, in haystack order.

    For example:
        needles = ['foo', 'baz']
        haystack = [
            (path="/foo/bar/baz", weight=10),
            (path="/foo/baz/moo", weight=10),
            (path="/moo/foo/baz", weight=10),
            (path="/foo/baz", weight=10)]

        regex_needle = re.compile(r'''
            foo     # needle #1
            [^/]*   # all characters except os.sep zero or more times
            /       # os.sep
            [^/]*   # all characters except os.sep zero or more times
            baz     # needle #2
            [^/]*   # all characters except os.sep zero or more times
            $       # end of string
            ''')

        result = [
            (path="/moo/foo/baz", weight=10),
            (path="/foo/baz", weight=10)]
    """
    # the separator is escaped as well: a backslash separator would
    # otherwise corrupt the character class below
    sep = re.escape(os.sep)
    regex_no_sep = '[^' + sep + ']*'
    regex_one_sep = regex_no_sep + sep + regex_no_sep
    regex_no_sep_end = regex_no_sep + '$'
    # escape the needles so user input is never interpreted as a pattern
    regex_needle = regex_one_sep.join(imap(re.escape, needles)) + \
        regex_no_sep_end
    regex_flags = re.IGNORECASE | re.UNICODE if ignore_case else re.UNICODE
    pattern = re.compile(regex_needle, regex_flags)

    found = lambda entry: pattern.search(entry.path)
    return ifilter(found, haystack)


def match_fuzzy(needles, haystack, ignore_case=False):
    """
    Approximate matching of the last needle against the last component of
    every path. An entry is kept when the SequenceMatcher ratio between the
    two is at least FUZZY_MATCH_THRESHOLD. With `ignore_case` both the
    needle and the path are lowercased before comparing.

    Returns an iterator over the matching entries, in haystack order.

    For example:
        needles = ['foo', 'bar']
        haystack = [
            (path="/foo/bar/baz", weight=11),
            (path="/foo/baz/moo", weight=10),
            (path="/moo/foo/baz", weight=10),
            (path="/foo/baz", weight=10),
            (path="/foo/bar", weight=10)]

        result = [
            (path="/foo/bar/baz", weight=11),
            (path="/moo/foo/baz", weight=10),
            (path="/foo/baz", weight=10),
            (path="/foo/bar", weight=10)]

    This is a weak heuristic and used as a last resort to find matches.
    """
    target = last(needles)
    if ignore_case:
        target = target.lower()

    def similarity(entry):
        # compare against the final path component only
        path = entry.path.lower() if ignore_case else entry.path
        tail = last(os.path.split(path))
        return SequenceMatcher(a=target, b=tail).ratio()

    def close_enough(entry):
        return similarity(entry) >= FUZZY_MATCH_THRESHOLD

    return ifilter(close_enough, haystack)


def purge_missing_paths(entries):
    """
    Lazily yield only those entries whose `path` exists on the filesystem,
    preserving their original order.
    """
    return (entry for entry in entries if os.path.exists(entry.path))


def print_stats(data, data_path):
    """
    Print every database entry in ascending weight order, followed by a
    summary: total weight, number of entries, the weight of the current
    working directory and the location of the data file.

    `data` maps paths to weights; `data_path` is only echoed back.
    """
    by_weight = sorted(data.items(), key=lambda item: item[1])
    for path, weight in by_weight:
        print_entry(Entry(path, weight))

    total_weight = sum(data.values())
    entry_count = len(data)

    print("________________________________________\n")
    print("%d:\t total weight" % total_weight)
    print("%d:\t number of entries" % entry_count)

    try:
        cwd_weight = data.get(os.getcwdu(), 0)
        print("%.2f:\t current directory weight" % cwd_weight)
    except OSError:
        # the current directory no longer exists, skip this line
        pass

    print("\ndata:\t %s" % data_path)


def main(args):
    """
    Run the action selected by the parsed command line arguments.

    The options are treated as mutually exclusive and checked in a fixed
    order (add, complete, decrease, increase, purge, stat); the first one
    that is set is executed. Without any option the directory needles are
    looked up and the best matching path is printed.

    Always returns 0, which is used as the process exit status.
    """
    config = set_defaults()

    if args.add:
        updated = first(add_path(load(config), args.add))
        save(config, updated)
        return 0

    if args.complete:
        needle = first(sanitize(args.directory))
        handle_tab_completion(needle=needle, entries=entriefy(load(config)))
        return 0

    if args.decrease:
        database = load(config)
        database, entry = decrease_path(database, get_pwd(), args.decrease)
        save(config, database)
        print_entry(entry)
        return 0

    if args.increase:
        database = load(config)
        database, entry = add_path(database, get_pwd(), args.increase)
        save(config, database)
        print_entry(entry)
        return 0

    if args.purge:
        before = load(config)
        after = dictify(purge_missing_paths(entriefy(before)))
        save(config, after)
        print("Purged %d entries." % (len(before) - len(after)))
        return 0

    if args.stat:
        print_stats(load(config), config['data_path'])
        return 0

    if not args.directory:
        # default return value so calling shell functions have an argument
        # to `cd` to
        print(encode_local('.'))
        return 0

    entries = entriefy(load(config))
    needles = sanitize(args.directory)
    _, path = get_tab_needle_and_path(first(needles), TAB_SEPARATOR)

    if not path:
        # not a complete tab completion entry: use the best match instead
        path = first(find_matches(entries, needles)).path

    print(encode_local(path))
    return 0


# Script entry point: parse the command line, run main() and use its return
# value as the process exit status.
if __name__ == "__main__":
    sys.exit(main(parse_arguments()))