#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
  Copyright © 2008-2012 Joel Schaerer
  Copyright © 2012-2013 William Ting

  *  This program is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; either version 3, or (at your option)
  any later version.

  *  This program is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.

  *  You should have received a copy of the GNU General Public License
  along with this program; if not, write to the Free Software
  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
"""

from __future__ import division, print_function

import collections
import difflib
import math
import operator
import os
import re
import shutil
import sys
import tempfile

try:
    import argparse
except ImportError:
    # Python 2.6 support
    sys.path.append(os.path.dirname(os.path.realpath(__file__)))
    import autojump_argparse as argparse
    sys.path.pop()

class Database:
    """
    Abstraction for interfacing with the autojump database file.

    The database is a plain-text, UTF-8 encoded file holding one entry per
    line: the weight, a TAB character, then the path. Entries are kept in
    memory in `self.data`, a defaultdict mapping path -> weight.
    """

    def __init__(self, config):
        """
        Remember the configuration and load the database from disk.

        `config` is a dictionary; this class reads its 'db', 'data', 'home'
        and 'max_paths' keys.
        """
        self.config = config
        self.filename = config['db']
        self.data = collections.defaultdict(int)
        self.load()

    def __len__(self):
        """
        Number of paths currently held in memory.
        """
        return len(self.data)

    def load(self, error_recovery=False):
        """
        Open database file, recovering from backup if needed.

        `error_recovery` is True when this call is itself the result of a
        backup restore; it stops load() and load_backup() from calling each
        other forever.
        """
        if not os.path.exists(self.filename):
            self.load_backup(error_recovery)
            return

        try:
            # only Python 3's open() understands the encoding argument
            if sys.version_info >= (3, 0):
                db_file = open(self.filename, 'r', encoding='utf-8')
            else:
                db_file = open(self.filename, 'r')

            with db_file:
                for line in db_file:
                    # strip only the line terminator so that a final line
                    # without a trailing newline keeps its last character
                    weight, path = line.rstrip("\n").split("\t", 1)
                    self.data[decode(path, 'utf-8')] = float(weight)
        except (IOError, EOFError):
            self.load_backup(error_recovery)

    def load_backup(self, error_recovery=False):
        """
        Loads database from backup file.

        The backup (database filename + '.bak') is copied over the database
        and loaded. Nothing happens when there is no backup, or when
        `error_recovery` is set (a restore has already been attempted).
        """
        backup = self.filename + '.bak'

        if os.path.exists(backup) and not error_recovery:
            print('Problem with autojump database, '
                  'trying to recover from backup...', file=sys.stderr)
            shutil.copy(backup, self.filename)
            self.load(True)

    def save(self):
        """
        Save database atomically and preserve backup, creating new database if
        needed.

        Entries are written to a temporary file in the data directory which
        is then moved over the database. The save is skipped when the
        existing database is owned by another user (check not done on
        Windows). The backup is refreshed when it is missing or more than a
        day old.
        """
        # check file existence and permissions
        if (os.path.exists(self.filename) and
                os.name != 'nt' and
                os.getuid() != os.stat(self.filename).st_uid):
            return

        temp = tempfile.NamedTemporaryFile(dir=self.config['data'],
                                           delete=False)

        # catching disk errors and skipping save when the data can't be
        # written out or the file handle can't be closed.
        try:
            for path, weight in sorted(self.data.items(),
                                       key=operator.itemgetter(1),
                                       reverse=True):
                temp.write(unico("%s\t%s\n" % (weight, path)).encode("utf-8"))

            # http://thunk.org/tytso/blog/2009/03/15/dont-fear-the-fsync/
            temp.flush()
            os.fsync(temp.fileno())
            temp.close()
        except (IOError, OSError) as ex:
            print("Error saving autojump database (disk full?) (%s)" % ex,
                  file=sys.stderr)
            # best effort: don't leave the partial temporary file behind
            try:
                temp.close()
                os.remove(temp.name)
            except (IOError, OSError):
                pass
            return

        shutil.move(temp.name, self.filename)

        backup = self.filename + ".bak"
        try:
            import time
            # refresh the backup at most once a day (86400 seconds)
            if (not os.path.exists(backup) or
                    time.time() - os.path.getmtime(backup) > 86400):
                shutil.copy(self.filename, backup)
        except (IOError, OSError) as ex:
            print("Error while creating backup autojump file. (%s)" % ex,
                  file=sys.stderr)

    def add(self, path, increment=10):
        """
        Increase weight of existing paths or initialize new ones to
        `increment`, then save. The home directory is never added.
        """
        if path == self.config['home']:
            return

        path = path.rstrip(os.sep)
        current = self.data.get(path, 0)

        if current:
            # weights are combined as the root of the sum of squares
            self.data[path] = math.sqrt((current ** 2) + (increment ** 2))
        else:
            self.data[path] = increment

        self.save()

    def decrease(self, path, increment=15):
        """
        Decrease weight of existing path, never going below zero, then save.
        Unknown paths are ignored.
        """
        if path == self.config['home']:
            return

        # a plain self.data[path] lookup would insert a zero-weight entry
        if path not in self.data:
            return

        self.data[path] = max(self.data[path] - increment, 0)
        self.save()

    def get_weight(self, path):
        """
        Return the weight stored for `path`, or 0 for unknown paths. The
        lookup does not create an entry.
        """
        return self.data.get(path, 0)

    def maintenance(self):
        """
        Decay weights by 10%, periodically remove bottom 10% entries.

        Entries are removed, and the database saved, only when more than
        config['max_paths'] paths are stored; otherwise the decay only
        affects the in-memory data.
        """
        # only values of existing keys change, so iterating is safe
        for path in self.data:
            self.data[path] *= 0.9

        if len(self.data) > self.config['max_paths']:
            remove_cnt = int(0.1 * len(self.data))
            for path in sorted(self.data, key=self.data.get)[:remove_cnt]:
                del self.data[path]

            self.save()

    def purge(self):
        """
        Remove non-existent paths, save, and return the removed paths.
        """
        removed = [path for path in list(self.data.keys())
                   if not os.path.exists(path)]

        for path in removed:
            del self.data[path]

        self.save()
        return removed

def set_defaults():
    """
    Build and return the default configuration dictionary.

    The data directory is 'autojump' under $XDG_DATA_HOME, falling back to
    ~/.local/share when that variable is unset or empty.
    """
    home = os.path.expanduser('~')

    data_root = os.environ.get('XDG_DATA_HOME')
    if not data_root:
        data_root = os.path.join(home, '.local', 'share')
    data_dir = os.path.join(data_root, 'autojump')

    return {
        'version': 'release-v21.7.0',
        'max_paths': 1000,
        'separator': '__',
        'home': home,
        'ignore_case': False,
        'keep_symlinks': False,
        'debug': False,
        'match_cnt': 1,
        'data': data_dir,
        'db': data_dir + '/autojump.txt',
    }

def parse_env(config):
    """
    Apply AUTOJUMP_* environment variables to `config` and return it.

    AUTOJUMP_DATA_DIR relocates the data directory and database file; the
    database becomes a hidden file when the data directory is the home
    directory. AUTOJUMP_IGNORE_CASE and AUTOJUMP_KEEP_SYMLINKS switch their
    option on when set to exactly '1'.
    """
    env = os.environ

    if 'AUTOJUMP_DATA_DIR' in env:
        data_dir = env.get('AUTOJUMP_DATA_DIR')
        config['data'] = data_dir
        config['db'] = data_dir + '/autojump.txt'

    if config['data'] == config['home']:
        config['db'] = config['data'] + '/.autojump.txt'

    flags = (
        ('AUTOJUMP_IGNORE_CASE', 'ignore_case'),
        ('AUTOJUMP_KEEP_SYMLINKS', 'keep_symlinks'),
    )
    for variable, option in flags:
        # an unset variable yields None, which never equals '1'
        if env.get(variable) == '1':
            config[option] = True

    return config

def parse_arg(config):
    """
    Parse the command line and carry out the database management options.

    --add, --increase, --decrease, --purge and --stat are handled here and
    end the program with exit status 0. Otherwise the parsed arguments are
    stored under config['args'] and `config` is returned; --complete also
    raises config['match_cnt'] to 9 and enables config['ignore_case'].
    """
    parser = argparse.ArgumentParser(
        description='Automatically jump to directory passed as an argument.',
        epilog="Please see autojump(1) man pages for full documentation.",
    )
    parser.add_argument(
        'directory',
        metavar='DIRECTORY',
        nargs='*',
        default='',
        help='directory to jump to',
    )
    parser.add_argument(
        '-a', '--add',
        metavar='DIRECTORY',
        help='manually add path to database',
    )
    parser.add_argument(
        '-i', '--increase',
        metavar='WEIGHT',
        nargs='?',
        type=int,
        const=20,
        default=False,
        help='manually increase path weight in database',
    )
    parser.add_argument(
        '-d', '--decrease',
        metavar='WEIGHT',
        nargs='?',
        type=int,
        const=15,
        default=False,
        help='manually decrease path weight in database',
    )
    parser.add_argument(
        '-b', '--bash',
        action="store_true",
        default=False,
        help='enclose directory quotes to prevent errors',
    )
    parser.add_argument(
        '--complete',
        action="store_true",
        default=False,
        help='used for tab completion',
    )
    parser.add_argument(
        '--purge',
        action="store_true",
        default=False,
        help='delete all database entries that no longer exist on system',
    )
    parser.add_argument(
        '-s', '--stat',
        action="store_true",
        default=False,
        help='show database entries and their key weights',
    )
    parser.add_argument(
        '-v', '--version',
        action="version",
        version="%(prog)s " + config['version'],
        help='show version information and exit',
    )

    args = parser.parse_args()
    db = Database(config)

    if args.add:
        db.add(decode(args.add))
        sys.exit(0)

    # --increase and --decrease both act on the current directory and
    # report its weight before and after; --increase is checked first
    adjustments = (
        (args.increase, db.add),
        (args.decrease, db.decrease),
    )
    for amount, adjust in adjustments:
        if amount:
            cwd = os.getcwd()
            print("%.2f:\t old directory weight" % db.get_weight(cwd))
            adjust(cwd, amount)
            print("%.2f:\t new directory weight" % db.get_weight(cwd))
            sys.exit(0)

    if args.purge:
        removed = db.purge()

        for stale_path in removed:
            output(stale_path)

        print("Number of database entries removed: %d" % len(removed))

        sys.exit(0)

    if args.stat:
        # ascending by weight; only the 100 heaviest entries are listed
        by_weight = sorted(db.data.items(), key=operator.itemgetter(1))
        for path, weight in by_weight[-100:]:
            output("%.1f:\t%s" % (weight, path))

        print("________________________________________\n")
        print("%d:\t total key weight" % sum(db.data.values()))
        print("%d:\t stored directories" % len(db.data))
        print("%.2f:\t current directory weight" % db.get_weight(os.getcwd()))

        print("\ndb file: %s" % config['db'])
        sys.exit(0)

    if args.complete:
        config['match_cnt'] = 9
        config['ignore_case'] = True

    config['args'] = args
    return config

def decode(text, encoding=None, errors="strict"):
    """
    Return `text` as unicode.

    On Python 3 the text is returned untouched. On Python 2 it is decoded
    with `encoding` (the filesystem encoding when None) using the `errors`
    handling scheme.
    """
    if sys.version_info[0] > 2:
        return text

    codec = sys.getfilesystemencoding() if encoding is None else encoding
    return text.decode(codec, errors)

def output_quotes(config, text):
    """
    Print `text`, wrapped in single quotes when both the --complete and
    --bash options are set.
    """
    args = config['args']
    quote = "'" if (args.complete and args.bash) else ""

    output("%s%s%s" % (quote, text, quote))

def output(text, encoding=None):
    """
    Print `text` to standard output.

    On Python 2 the text is converted to unicode and encoded first, with the
    filesystem encoding unless `encoding` is given, to minimize encoding
    mismatch problems in directory names. On Python 3 it is printed as is.
    """
    if sys.version_info[0] <= 2:
        codec = sys.getfilesystemencoding() if encoding is None else encoding
        print(unicode(text).encode(codec))
        return

    print(text)

def unico(text):
    """
    Return `text` unchanged on Python 3; on Python 2 convert it to a unicode
    object.
    """
    return text if sys.version_info[0] > 2 else unicode(text)

def match(path, pattern, only_end=False, ignore_case=False):
    """
    Check whether a path matches a particular pattern, and return
    the remaining part of the string.

    path        -- path to search in
    pattern     -- substring to look for
    only_end    -- only search the trailing path components (as many
                   components as the pattern has, counting its slashes)
    ignore_case -- compare case-insensitively

    Returns a (found, remaining) tuple. When the pattern is found,
    `remaining` is the part of `path` following the match, so that the next
    pattern cannot match the same text again. When it is not found,
    `remaining` is `path` unchanged.
    """
    if only_end:
        match_path = "/".join(path.split('/')[-1-pattern.count('/'):])
    else:
        match_path = path

    # match_path is always a suffix of path; remember where it starts so an
    # index into match_path can be translated back into an index into path
    offset = len(path) - len(match_path)

    if ignore_case:
        match_path = match_path.lower()
        pattern = pattern.lower()

    find_idx = match_path.find(pattern)
    if find_idx == -1:
        return (False, path)

    # truncate path to avoid matching a pattern multiple times
    return (True, path[offset + find_idx + len(pattern):])

def find_matches(config, db, patterns, ignore_case=False, fuzzy=False):
    """
    Find paths matching patterns, up to config['match_cnt'] of them.

    config      -- configuration dictionary ('debug' and 'match_cnt' are read)
    db          -- database object; its `data` mapping of path -> weight is
                   searched from the highest weight down
    patterns    -- non-empty list of patterns; each must match in turn, the
                   last one against the end of the path only
    ignore_case -- compare case-insensitively
    fuzzy       -- instead of substring matching, return the single best
                   approximate match of the last pattern against the last
                   component of the stored paths

    Paths that do not exist on disk are skipped unless config['debug'] is
    set. The current directory is only returned when it is the sole match
    (never in fuzzy mode). Returns a list of paths, possibly empty.
    """
    try:
        current_dir = decode(os.path.realpath(os.curdir))
    except OSError:
        current_dir = None

    dirs = sorted(db.data.items(), key=operator.itemgetter(1), reverse=True)

    if ignore_case:
        patterns = [p.lower() for p in patterns]

    if fuzzy:
        # create dictionary of end paths to compare against
        end_dirs = {}
        for path, _ in dirs:
            end = path.split('/')[-1]
            if ignore_case:
                end = end.lower()

            # collisions: ignore lower weight paths
            if end not in end_dirs:
                end_dirs[end] = path

        # find the first usable match, best candidates first
        while True:
            found = difflib.get_close_matches(patterns[-1], end_dirs, n=1,
                                              cutoff=.6)
            if not found:
                return []

            # found[0] is only the last path component; existence and the
            # current-directory test must use the full stored path
            candidate = end_dirs[found[0]]

            # avoid jumping to current directory
            if (os.path.exists(candidate) or config['debug']) and \
                    current_dir != os.path.realpath(candidate):
                return [candidate]

            # continue with the last found directory removed
            del end_dirs[found[0]]

    results = []
    current_dir_match = False
    last = len(patterns) - 1

    for path, _ in dirs:
        found, remaining = True, path
        for n, p in enumerate(patterns):
            # for single/last pattern, only check end of path
            found, remaining = match(remaining, p, n == last, ignore_case)
            if not found:
                break

        if not found or not (os.path.exists(path) or config['debug']):
            continue

        # avoid jumping to current directory
        # (call out to realpath this late to not stat all dirs)
        if current_dir == os.path.realpath(path):
            current_dir_match = True
            continue

        if path not in results:
            results.append(path)

        if len(results) >= config['match_cnt']:
            break

    # if current directory is the only match, add it to results
    if not results and current_dir_match:
        results.append(current_dir)

    return results

def main():
    """
    Resolve the command line patterns to a directory and print it.

    Matching is attempted with the configured case sensitivity, then
    case-insensitively, then approximately. Returns 0 when something was
    printed and 1 when nothing matched.
    """
    config = parse_arg(parse_env(set_defaults()))
    sep = config['separator']
    db = Database(config)

    # checking command line directory arguments
    directories = config['args'].directory
    if directories:
        patterns = [decode(d) for d in directories]
    else:
        patterns = [unico('')]

    # check for tab completion
    tab_choice = None
    last_pattern = patterns[-1]
    selection = re.search(sep+r'([0-9]+)', last_pattern)

    if selection:
        # user has selected a tab completion entry
        config['match_cnt'] = 9
        tab_choice = int(selection.group(1))
        patterns[-1] = re.sub(sep+r'[0-9]+.*', '', last_pattern)
    else:
        partial = re.match(r'(.*)'+sep, last_pattern)
        if partial:
            # partial tab match, display choices again
            config['match_cnt'] = 9
            patterns[-1] = partial.group(1)

    case_insensitive = config['ignore_case']
    results = find_matches(config, db, patterns,
                           ignore_case=case_insensitive)

    # if no results, try ignoring case
    if not results and not case_insensitive:
        results = find_matches(config, db, patterns, ignore_case=True)

    # if no results, try approximate matching
    if not results:
        results = find_matches(config, db, patterns, ignore_case=True,
                               fuzzy=True)

    if not results:
        return 1

    if tab_choice and tab_choice <= len(results):
        output_quotes(config, results[tab_choice-1])
    elif len(results) > 1 and config['args'].complete:
        # numbered completion entries, counted from 1
        for number, candidate in enumerate(results[:9], 1):
            output_quotes(config, '%s%s%d%s%s' %
                          (patterns[-1], sep, number, sep, candidate))
    else:
        output_quotes(config, results[0])

    db.maintenance()

    return 0

# Run only when executed as a script; the value returned by main() becomes
# the process exit status.
if __name__ == "__main__":
    sys.exit(main())