mirror of
				https://github.com/wting/autojump
				synced 2025-06-13 12:54:07 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			508 lines
		
	
	
		
			16 KiB
		
	
	
	
		
			Python
		
	
	
		
			Executable File
		
	
	
	
	
			
		
		
	
	
			508 lines
		
	
	
		
			16 KiB
		
	
	
	
		
			Python
		
	
	
		
			Executable File
		
	
	
	
	
| #!/usr/bin/env python
 | |
| # -*- coding: utf-8 -*-
 | |
| """
 | |
|   Copyright © 2008-2012 Joel Schaerer
 | |
|   Copyright © 2012      William Ting
 | |
| 
 | |
|   *  This program is free software; you can redistribute it and/or modify
 | |
|   it under the terms of the GNU General Public License as published by
 | |
|   the Free Software Foundation; either version 3, or (at your option)
 | |
|   any later version.
 | |
| 
 | |
|   *  This program is distributed in the hope that it will be useful,
 | |
|   but WITHOUT ANY WARRANTY; without even the implied warranty of
 | |
|   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 | |
|   GNU General Public License for more details.
 | |
| 
 | |
|   *  You should have received a copy of the GNU General Public License
 | |
|   along with this program; if not, write to the Free Software
 | |
|   Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
 | |
| """
 | |
| 
 | |
from __future__ import division, print_function

import collections
import difflib
import math
import operator
import os
import re
import shutil
import sys
import tempfile

try:
    import argparse
except ImportError:
    # Python 2.6 has no argparse: load the bundled autojump_argparse module
    # from this script's directory. os and sys must already be imported here.
    sys.path.append(os.path.dirname(os.path.realpath(__file__)))
    import autojump_argparse as argparse
    sys.path.pop()
 | |
| 
 | |
class Database:
    """
    Object for interfacing with autojump database file.

    The file holds one ``weight<TAB>path`` entry per line. Entries live in
    ``self.data`` (path -> weight) and are written back atomically by
    ``save``, which also refreshes a ``.bak`` copy at most once per day.
    """

    def __init__(self, config):
        # config must provide 'db' here; other methods also read 'home',
        # 'data' and 'max_paths'.
        self.config = config
        self.filename = config['db']
        self.data = collections.defaultdict(int)
        self.load()

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _normalize(path):
        """
        Strip trailing separators so '/a/b/' and '/a/b' share one entry.
        The filesystem root is kept as os.sep instead of collapsing to ''.
        """
        return path.rstrip(os.sep) or os.sep

    def add(self, path, increment=10):
        """
        Increase the weight of ``path`` by ``increment`` and save.

        An existing weight w becomes sqrt(w**2 + increment**2), so repeated
        visits add less and less; a new path starts at ``increment``. The
        home directory is never recorded.
        """
        path = self._normalize(path)
        if path == self.config['home']:
            return

        current = self.data.get(path, 0)
        if current:
            self.data[path] = math.sqrt(current ** 2 + increment ** 2)
        else:
            self.data[path] = increment

        self.save()

    def decrease(self, path, increment=15):
        """
        Decrease weight of existing path (never below 0) and save.
        Unknown paths and the home directory are ignored.
        """
        path = self._normalize(path)
        # Membership test instead of self.data[path]: indexing the
        # defaultdict would silently insert (and then persist) unknown paths.
        if path == self.config['home'] or path not in self.data:
            return

        self.data[path] = max(self.data[path] - increment, 0)
        self.save()

    def get_weight(self, path):
        """
        Return path weight, or 0 for unknown paths (without inserting them).
        """
        return self.data.get(self._normalize(path), 0)

    @staticmethod
    def _parse_lines(lines):
        """
        Parse ``weight<TAB>path`` lines into a list of (path, weight) pairs.

        Blank lines are skipped; a line without a tab or with a non-numeric
        weight raises ValueError.
        """
        entries = []
        for line in lines:
            # rstrip('\n') rather than line[:-1]: the last line may lack a
            # trailing newline, and slicing would cut off a path character.
            line = line.rstrip('\n')
            if not line:
                continue
            weight, path = line.split('\t', 1)
            entries.append((path, float(weight)))
        return entries

    def load(self, error_recovery=False):
        """
        Read the database file into self.data, recovering from backup if
        needed.

        A missing, unreadable or malformed file triggers load_backup.
        ``error_recovery`` is True when this call is itself the recovery
        attempt, which prevents a second fallback.
        """
        if not os.path.exists(self.filename):
            self.load_backup(error_recovery)
            return

        try:
            if sys.version_info >= (3, 0):
                # Text mode with an explicit encoding already yields str.
                with open(self.filename, 'r', encoding='utf-8') as f:
                    entries = self._parse_lines(f)
            else:
                # Python 2 reads byte strings; decode paths to unicode.
                with open(self.filename, 'r') as f:
                    entries = [(decode(path, 'utf-8'), weight)
                               for path, weight in self._parse_lines(f)]
        except (IOError, EOFError, ValueError):
            # ValueError covers malformed lines and undecodable UTF-8.
            # Nothing has been merged into self.data at this point.
            self.load_backup(error_recovery)
            return

        self.data.update(entries)

    def load_backup(self, error_recovery=False):
        """
        Restore the database file from its ``.bak`` copy and reload it.

        Does nothing when no backup exists or when ``error_recovery`` is
        True (a restore was already attempted during this load).
        """
        backup = self.filename + '.bak'
        if error_recovery or not os.path.exists(backup):
            return

        print('Problem with autojump database, '
              'trying to recover from backup...', file=sys.stderr)
        shutil.copy(backup, self.filename)
        self.load(True)

    def maintenance(self):
        """
        Decay all weights by 10%; when the database holds more than
        config['max_paths'] entries, drop the lowest-weighted 10% and save.

        NOTE(review): the database is only written when entries are pruned,
        so the decay alone is not persisted -- confirm this is intended.
        """
        # Rebinding values of existing keys is safe while iterating a dict.
        for path in self.data:
            self.data[path] *= 0.9

        if len(self.data) > self.config['max_paths']:
            remove_cnt = int(0.1 * len(self.data))
            for path in sorted(self.data, key=self.data.get)[:remove_cnt]:
                del self.data[path]

            self.save()

    def purge(self):
        """
        Deletes all entries that no longer exist on system.

        Returns the list of removed paths.
        """
        removed = [path for path in list(self.data) if not os.path.exists(path)]
        for path in removed:
            del self.data[path]

        self.save()
        return removed

    @staticmethod
    def _discard_temp(temp):
        """Best-effort close and removal of an abandoned temporary file."""
        try:
            temp.close()
        except (IOError, OSError):
            pass  # the file is being thrown away; a failed close is moot
        try:
            os.remove(temp.name)
        except OSError:
            pass

    def save(self):
        """
        Save database atomically and refresh the daily backup, creating the
        data directory and database file if needed.

        Nothing is written when the database file exists and is owned by
        another user (the ownership check is skipped on Windows). Disk and
        permission errors are reported on stderr instead of being raised.
        """
        # os.getuid does not exist on Windows, hence the os.name guard first.
        if (os.path.exists(self.filename) and os.name != 'nt' and
                os.getuid() != os.stat(self.filename).st_uid):
            return

        data_dir = self.config['data']
        if not os.path.isdir(data_dir):
            try:
                os.makedirs(data_dir)
            except OSError as ex:
                print("Error creating autojump data directory. (%s)" % ex,
                      file=sys.stderr)
                return

        # Write to a temp file in the same directory, fsync it, then move it
        # over the database so readers never see a half-written file.
        temp = None
        try:
            temp = tempfile.NamedTemporaryFile(dir=data_dir, delete=False)
            for path, weight in sorted(self.data.items(),
                                       key=operator.itemgetter(1),
                                       reverse=True):
                temp.write(unico("%s\t%s\n" % (weight, path)).encode("utf-8"))

            # http://thunk.org/tytso/blog/2009/03/15/dont-fear-the-fsync/
            temp.flush()
            os.fsync(temp.fileno())
            temp.close()
        except (IOError, OSError) as ex:
            # os.fsync raises OSError, file writes IOError on Python 2.
            print("Error saving autojump database (disk full?) (%s)" % ex,
                  file=sys.stderr)
            if temp is not None:
                self._discard_temp(temp)
            return

        shutil.move(temp.name, self.filename)

        backup = self.filename + ".bak"
        try:
            import time
            # Refresh the backup when missing or older than one day.
            if (not os.path.exists(backup) or
                    time.time() - os.path.getmtime(backup) > 86400):
                shutil.copy(self.filename, backup)
        except (IOError, OSError) as ex:
            print("Error while creating backup autojump file. (%s)" % ex,
                  file=sys.stderr)
 | |
| 
 | |
def set_defaults():
    """
    Build the default configuration dictionary.

    The data directory is $XDG_DATA_HOME/autojump, or
    ~/.local/share/autojump when XDG_DATA_HOME is unset or empty.
    """
    config = {}

    config['version'] = 'release-v21.6.5'
    config['max_paths'] = 1000
    config['separator'] = '__'
    # '~' is what expanduser expands; the literal 'HOME' is returned as-is.
    config['home'] = os.path.expanduser('~')

    config['ignore_case'] = False
    config['keep_symlinks'] = False
    config['debug'] = False
    config['match_cnt'] = 1

    xdg_data = (os.environ.get('XDG_DATA_HOME') or
                os.path.join(config['home'], '.local', 'share'))
    config['data'] = os.path.join(xdg_data, 'autojump')
    config['db'] = config['data'] + '/autojump.txt'

    return config
 | |
| 
 | |
def parse_env(config):
    """
    Apply AUTOJUMP_* environment overrides to ``config``.

    AUTOJUMP_DATA_DIR replaces the data directory (and database path).
    When the data directory is the home directory, a dot-prefixed
    database file name is used. AUTOJUMP_IGNORE_CASE / AUTOJUMP_KEEP_SYMLINKS
    set to exactly '1' switch the matching flag on. The dict is modified in
    place and also returned.
    """
    environ = os.environ

    if 'AUTOJUMP_DATA_DIR' in environ:
        config['data'] = environ.get('AUTOJUMP_DATA_DIR')
        config['db'] = config['data'] + '/autojump.txt'

    if config['data'] == config['home']:
        config['db'] = config['data'] + '/.autojump.txt'

    for variable, key in (('AUTOJUMP_IGNORE_CASE', 'ignore_case'),
                          ('AUTOJUMP_KEEP_SYMLINKS', 'keep_symlinks')):
        if environ.get(variable) == '1':
            config[key] = True

    return config
 | |
| 
 | |
def _build_arg_parser(config):
    """Create the command-line parser; config['version'] feeds --version."""
    parser = argparse.ArgumentParser(
        description='Automatically jump to directory passed as an argument.',
        epilog="Please see autojump(1) man pages for full documentation.")
    add = parser.add_argument
    add('directory', metavar='DIRECTORY', nargs='*', default='',
        help='directory to jump to')
    add('-a', '--add', metavar='DIRECTORY',
        help='manually add path to database')
    add('-i', '--increase', metavar='WEIGHT', nargs='?', type=int,
        const=20, default=False,
        help='manually increase path weight in database')
    add('-d', '--decrease', metavar='WEIGHT', nargs='?', type=int,
        const=15, default=False,
        help='manually decrease path weight in database')
    add('-b', '--bash', action="store_true", default=False,
        help='enclose directory quotes to prevent errors')
    add('--complete', action="store_true", default=False,
        help='used for tab completion')
    add('--purge', action="store_true", default=False,
        help='delete all database entries that no longer exist on system')
    add('-s', '--stat', action="store_true", default=False,
        help='show database entries and their key weights')
    add('-v', '--version', action="version",
        version="%(prog)s " + config['version'],
        help='show version information and exit')
    return parser


def parse_arg(config):
    """
    Parse sys.argv and handle the one-shot commands.

    --add, --increase, --decrease, --purge and --stat act on the database
    and exit the process with status 0. Otherwise the parsed namespace is
    stored in config['args'] (with --complete widening the match settings)
    and config is returned.
    """
    args = _build_arg_parser(config).parse_args()
    database = Database(config)

    if args.add:
        database.add(decode(args.add))
        sys.exit(0)

    # --increase takes precedence over --decrease, as the first one set wins.
    for amount, adjust in ((args.increase, database.add),
                           (args.decrease, database.decrease)):
        if amount:
            print("%.2f:\t old directory weight" % database.get_weight(os.getcwd()))
            adjust(os.getcwd(), amount)
            print("%.2f:\t new directory weight" % database.get_weight(os.getcwd()))
            sys.exit(0)

    if args.purge:
        removed = database.purge()
        for entry in removed:
            output(entry)
        print("Number of database entries removed: %d" % len(removed))
        sys.exit(0)

    if args.stat:
        ranked = sorted(database.data.items(), key=operator.itemgetter(1))
        for path, weight in ranked[-100:]:
            output("%.1f:\t%s" % (weight, path))

        print("________________________________________\n")
        print("%d:\t total key weight" % sum(database.data.values()))
        print("%d:\t stored directories" % len(database.data))
        print("%.2f:\t current directory weight" % database.get_weight(os.getcwd()))

        print("\ndb file: %s" % config['db'])
        sys.exit(0)

    if args.complete:
        # completion lists up to 9 candidates, matched case-insensitively
        config['match_cnt'] = 9
        config['ignore_case'] = True

    config['args'] = args
    return config
 | |
| 
 | |
def decode(text, encoding=None, errors="strict"):
    """
    Return ``text`` as a unicode string.

    Python 3 strings are returned unchanged. On Python 2 the byte string is
    decoded with ``encoding`` (the filesystem encoding when None) using the
    ``errors`` policy.
    """
    if sys.version_info[0] <= 2:
        codec = sys.getfilesystemencoding() if encoding is None else encoding
        return text.decode(codec, errors)
    return text
 | |
| 
 | |
def output_quotes(config, text):
    """
    Print ``text`` through output(), wrapped in single quotes when both
    --complete and --bash were given (config['args']).
    """
    args = config['args']
    quote = "'" if (args.complete and args.bash) else ""
    output("%s%s%s" % (quote, text, quote))
 | |
| 
 | |
def output(text, encoding=None):
    """
    Print ``text``. On Python 2 it is first converted to unicode and encoded
    with ``encoding`` (the filesystem encoding when None), which keeps
    directory names printed in the same encoding the filesystem uses.
    """
    if sys.version_info[0] > 2:
        print(text)
        return

    codec = sys.getfilesystemencoding() if encoding is None else encoding
    print(unicode(text).encode(codec))
 | |
| 
 | |
def unico(text):
    """
    Return ``unicode(text)`` on Python 2; on Python 3 return ``text`` as is.
    """
    return text if sys.version_info[0] > 2 else unicode(text)
 | |
| 
 | |
def match(path, pattern, only_end=False, ignore_case=False):
    """
    Check whether a path matches a particular pattern, and return
    the remaining part of the string.

    Returns ``(found, remainder)``. On a match, ``remainder`` is the searched
    text after the first occurrence of ``pattern``. A caller that chains the
    next pattern onto it therefore cannot match the same characters twice.
    On a miss, ``path`` is returned unchanged.

    only_end    -- search only the trailing component(s) of ``path``: one
                   component more than ``pattern`` has '/' separators.
    ignore_case -- compare lower-cased text; the remainder is then
                   lower-cased too (what a chained case-insensitive call
                   expects).
    """
    if only_end:
        segment = "/".join(path.split('/')[-1 - pattern.count('/'):])
    else:
        segment = path

    if ignore_case:
        segment = segment.lower()
        pattern = pattern.lower()

    find_idx = segment.find(pattern)
    if find_idx == -1:
        return (False, path)

    # truncate past the match to avoid matching a pattern multiple times
    return (True, segment[find_idx + len(pattern):])
 | |
| 
 | |
def _find_fuzzy_match(config, dirs, pattern, current_dir, ignore_case):
    """
    Return ``[path]`` for the closest approximate match of ``pattern`` against
    the last component of each path in ``dirs``, or ``[]`` if none qualifies.
    """
    # Map last path component -> full path. ``dirs`` is sorted by descending
    # weight, so on name collisions the heaviest path is kept.
    end_dirs = {}
    for path, _ in dirs:
        end = path.split('/')[-1]
        if ignore_case:
            end = end.lower()
        end_dirs.setdefault(end, path)

    # find the first (highest weight) acceptable match
    while True:
        found = difflib.get_close_matches(pattern, end_dirs, n=1, cutoff=.6)
        if not found:
            return []
        full_path = end_dirs[found[0]]
        # Check the stored full path, not the bare component: the component
        # alone would be resolved relative to the current directory.
        # Also avoid jumping to the current directory.
        if ((os.path.exists(full_path) or config['debug']) and
                current_dir != os.path.realpath(full_path)):
            return [full_path]
        # drop this candidate and retry with the next-closest name
        del end_dirs[found[0]]


def find_matches(config, db, patterns, ignore_case=False, fuzzy=False):
    """
    Find paths matching patterns up to max_matches.

    Paths are tried in order of descending weight. Every pattern except the
    last is checked with ``match`` against what remains of the path; the
    last is checked only against the path's trailing component(s). At most
    config['match_cnt'] paths are returned. Paths missing from disk are
    skipped unless config['debug'] is set. The current directory is returned
    only when it is the sole match.

    With ``fuzzy=True`` only the last pattern is used, compared
    approximately against final path components, and at most one path is
    returned.
    """
    try:
        current_dir = decode(os.path.realpath(os.curdir))
    except OSError:
        current_dir = None

    dirs = sorted(db.data.items(), key=operator.itemgetter(1), reverse=True)

    if ignore_case:
        patterns = [p.lower() for p in patterns]

    if fuzzy:
        return _find_fuzzy_match(config, dirs, patterns[-1], current_dir,
                                 ignore_case)

    results = []
    current_dir_match = False
    last = len(patterns) - 1
    for path, _ in dirs:
        found, remainder = True, path
        for n, pattern in enumerate(patterns):
            # for single/last pattern, only check end of path
            found, remainder = match(remainder, pattern, n == last, ignore_case)
            if not found:
                break

        if not found or not (os.path.exists(path) or config['debug']):
            continue

        # avoid jumping to current directory
        # (call out to realpath this late to not stat all dirs)
        if current_dir == os.path.realpath(path):
            current_dir_match = True
            continue

        if path not in results:
            results.append(path)

        if len(results) >= config['match_cnt']:
            break

    # if current directory is the only match, add it to results
    if not results and current_dir_match:
        results.append(current_dir)

    return results
 | |
| 
 | |
def main():
    """
    Entry point: parse configuration, match the command-line patterns
    against the database and print the chosen path or completion candidates.

    Returns 0 when something was printed (after running db.maintenance()),
    1 when nothing matched.
    """
    config = parse_arg(parse_env(set_defaults()))
    sep = config['separator']
    db = Database(config)

    # checking command line directory arguments
    if config['args'].directory:
        patterns = [decode(d) for d in config['args'].directory]
    else:
        patterns = [unico('')]

    # Tab completion: "foo__2" selects the 2nd candidate, while a bare
    # "foo__" strips the separator so the candidates are listed again.
    # -1 means "no selection".
    tab_choice = -1
    sep_re = re.escape(sep)
    tab_match = re.search(sep_re + "([0-9]+)", patterns[-1])
    if tab_match:  # user has selected a tab completion entry
        tab_choice = int(tab_match.group(1))
        patterns[-1] = re.sub(sep_re + "[0-9]+.*", "", patterns[-1])
    else:  # user hasn't selected a tab completion, display choices again
        tab_match = re.match("(.*)" + sep_re, patterns[-1])
        if tab_match:
            patterns[-1] = tab_match.group(1)

    results = find_matches(config, db, patterns,
                           ignore_case=config['ignore_case'])

    # if no results, try ignoring case
    if not results and not config['ignore_case']:
        results = find_matches(config, db, patterns, ignore_case=True)

    # if no results, try approximate matching
    if not results:
        results = find_matches(config, db, patterns, ignore_case=True,
                               fuzzy=True)

    # tab_choice is 1-based; -1 (no selection) and out-of-range values must
    # never be used as an index.
    if 1 <= tab_choice <= len(results):
        output_quotes(config, results[tab_choice - 1])
    elif len(results) > 1 and config['args'].complete:
        # NOTE(review): output() already ends each line with a newline, so
        # the trailing '\n' yields blank lines between candidates -- confirm
        # the completion scripts expect this.
        for n, r in enumerate(results[:9]):
            output_quotes(config,
                          '%s%s%d%s%s\n' % (patterns[-1], sep, n + 1, sep, r))
    elif results:
        output_quotes(config, results[0])
    else:
        return 1

    db.maintenance()

    return 0
 | |
| 
 | |
if __name__ == "__main__":
    # main() returns the process exit status (0 on a match, 1 otherwise).
    raise SystemExit(main())
 |