1
0
mirror of https://github.com/wting/autojump synced 2024-10-27 20:34:07 +00:00

Refactor away from global variables, use defaultdict/iteration instead.

This commit is contained in:
William Ting 2013-05-14 17:34:19 -05:00
parent f5ff5a126f
commit 2582ad6421

View File

@ -21,117 +21,87 @@
from __future__ import division, print_function from __future__ import division, print_function
import sys
import os
try: try:
import argparse import argparse
except ImportError: except ImportError:
# Python 2.6 support
sys.path.append(os.path.dirname(os.path.realpath(__file__))) sys.path.append(os.path.dirname(os.path.realpath(__file__)))
import autojump_argparse as argparse import autojump_argparse as argparse
sys.path.pop() sys.path.pop()
from operator import itemgetter
import collections
import math
import operator
import os
import re import re
import shutil import shutil
import sys
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
VERSION = 'release-v21.5.9'
MAX_KEYWEIGHT = 1000
MAX_STORED_PATHS = 1000
COMPLETION_SEPARATOR = '__'
ARGS = None
CONFIG_DIR = None
DB_FILE = None
TESTING = False
# load config from environmental variables
if 'AUTOJUMP_DATA_DIR' in os.environ:
CONFIG_DIR = os.environ.get('AUTOJUMP_DATA_DIR')
else:
xdg_data_dir = os.environ.get('XDG_DATA_HOME') or \
os.path.join(os.environ['HOME'], '.local', 'share')
CONFIG_DIR = os.path.join(xdg_data_dir, 'autojump')
KEEP_ALL_ENTRIES = False
if 'AUTOJUMP_KEEP_ALL_ENTRIES' in os.environ and \
os.environ.get('AUTOJUMP_KEEP_ALL_ENTRIES') == '1':
KEEP_ALL_ENTRIES = True
ALWAYS_IGNORE_CASE = False
if 'AUTOJUMP_IGNORE_CASE' in os.environ and \
os.environ.get('AUTOJUMP_IGNORE_CASE') == '1':
ALWAYS_IGNORE_CASE = True
KEEP_SYMLINKS = False
if 'AUTOJUMP_KEEP_SYMLINKS' in os.environ and \
os.environ.get('AUTOJUMP_KEEP_SYMLINKS') == '1':
KEEP_SYMLINKS = True
if CONFIG_DIR == os.path.expanduser('~'):
DB_FILE = CONFIG_DIR + '/.autojump.txt'
else:
DB_FILE = CONFIG_DIR + '/autojump.txt'
class Database: class Database:
""" """
Object for interfacing with autojump database. Object for interfacing with autojump database file.
""" """
def __init__(self, filename): def __init__(self, config):
self.filename = filename self.config = config
self.data = {} self.filename = config['db']
self.data = collections.defaultdict(int)
self.load() self.load()
def __len__(self): def __len__(self):
return len(self.data) return len(self.data)
def add(self, path, increment = 10): def add(self, path, increment=10):
""" """
Increase weight of existing paths or initialize new ones to 10. Increase weight of existing paths or initialize new ones to 10.
""" """
path = path.rstrip(os.sep) path = path.rstrip(os.sep)
if path not in self.data:
self.data[path] = increment if self.data[path]:
else:
import math
self.data[path] = math.sqrt((self.data[path]**2) + (increment**2)) self.data[path] = math.sqrt((self.data[path]**2) + (increment**2))
else:
self.data[path] = increment
self.save() self.save()
def decrease(self, path, increment = 15): def decrease(self, path, increment=15):
""" """
Decrease weight of existing path. Unknown ones are ignored. Decrease weight of existing path. Unknown ones are ignored.
""" """
if path in self.data:
if self.data[path] < increment: if self.data[path] < increment:
self.data[path] = 0 self.data[path] = 0
else: else:
self.data[path] -= increment self.data[path] -= increment
self.save() self.save()
def decay(self): def decay(self):
""" """
Decay database entries. Decay database entries.
""" """
for k in self.data.keys(): try:
self.data[k] *= 0.9 items = self.data.iteritems()
except AttributeError:
items = self.data.items()
for path, _ in items:
self.data[path] *= 0.9
def get_weight(self, path): def get_weight(self, path):
""" """
Return path weight. Return path weight.
""" """
if path in self.data:
return self.data[path] return self.data[path]
else:
return 0
def load(self, error_recovery = False): def load(self, error_recovery = False):
""" """
Try to open the database file, recovering from backup if needed. Open database file, recovering from backup if needed.
""" """
if os.path.exists(self.filename): if os.path.exists(self.filename):
try: try:
if sys.version_info >= (3, 0): if sys.version_info >= (3, 0):
with open(self.filename, 'r', encoding = 'utf-8') as f: with open(self.filename, 'r', encoding='utf-8') as f:
for line in f.readlines(): for line in f.readlines():
weight, path = line[:-1].split("\t", 1) weight, path = line[:-1].split("\t", 1)
path = decode(path, 'utf-8') path = decode(path, 'utf-8')
@ -162,10 +132,11 @@ class Database:
""" """
Trims and decays database entries when exceeding settings. Trims and decays database entries when exceeding settings.
""" """
if sum(self.data.values()) > MAX_KEYWEIGHT: if sum(self.data.values()) > self.config['max_weight']:
self.decay() self.decay()
if len(self.data) > MAX_STORED_PATHS: if len(self.data) > self.config['max_paths']:
self.trim() self.trim()
self.save() self.save()
def purge(self): def purge(self):
@ -173,10 +144,12 @@ class Database:
Deletes all entries that no longer exist on system. Deletes all entries that no longer exist on system.
""" """
removed = [] removed = []
for path in list(self.data.keys()): for path in list(self.data.keys()):
if not os.path.exists(path): if not os.path.exists(path):
removed.append(path) removed.append(path)
del self.data[path] del self.data[path]
self.save() self.save()
return removed return removed
@ -189,9 +162,9 @@ class Database:
if ((not os.path.exists(self.filename)) or if ((not os.path.exists(self.filename)) or
os.name == 'nt' or os.name == 'nt' or
os.getuid() == os.stat(self.filename)[4]): os.getuid() == os.stat(self.filename)[4]):
temp = NamedTemporaryFile(dir = CONFIG_DIR, delete = False) temp = NamedTemporaryFile(dir=self.config['data'], delete=False)
for path, weight in sorted(self.data.items(), for path, weight in sorted(self.data.items(),
key=itemgetter(1), key=operator.itemgetter(1),
reverse=True): reverse=True):
temp.write((unico("%s\t%s\n" % (weight, path)).encode("utf-8"))) temp.write((unico("%s\t%s\n" % (weight, path)).encode("utf-8")))
@ -223,18 +196,57 @@ class Database:
If database has exceeded MAX_STORED_PATHS, removes bottom 10%. If database has exceeded MAX_STORED_PATHS, removes bottom 10%.
""" """
dirs = list(self.data.items()) dirs = list(self.data.items())
dirs.sort(key=itemgetter(1)) dirs.sort(key=operator.itemgetter(1))
remove_cnt = int(percent * len(dirs)) remove_cnt = int(percent * len(dirs))
for path, _ in dirs[:remove_cnt]: for path, _ in dirs[:remove_cnt]:
del self.data[path] del self.data[path]
def set_defaults():
config = {}
def options(): config['version'] = 'release-v21.6.0'
""" config['max_weight'] = 1000
Parse command line options. config['max_paths'] = 1000
""" config['separator'] = '__'
global ARGS
config['ignore_case'] = False
config['keep_entries'] = False
config['keep_symlinks'] = False
config['debug'] = False
home = os.path.expanduser('HOME')
xdg_data = os.environ.get('XDG_DATA_HOME') or \
os.path.join(home, '.local', 'share')
config['data'] = os.path.join(xdg_data, 'autojump')
config['db'] = config['data'] + '/autojump.txt'
return config
def parse_env(config):
home = os.path.expanduser('HOME')
if 'AUTOJUMP_DATA_DIR' in os.environ:
config['data'] = os.environ.get('AUTOJUMP_DATA_DIR')
config['db'] = config['data'] + '/autojump.txt'
if config['data'] == home:
config['db'] = config['data'] + '/.autojump.txt'
if 'AUTOJUMP_KEEP_ALL_ENTRIES' in os.environ and \
os.environ.get('AUTOJUMP_KEEP_ALL_ENTRIES') == '1':
config['keep_entries'] = True
if 'AUTOJUMP_IGNORE_CASE' in os.environ and \
os.environ.get('AUTOJUMP_IGNORE_CASE') == '1':
config['ignore_case'] = True
if 'AUTOJUMP_KEEP_SYMLINKS' in os.environ and \
os.environ.get('AUTOJUMP_KEEP_SYMLINKS') == '1':
config['keep_symlinks'] = True
return config
def parse_arg(config):
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description='Automatically jump to \ description='Automatically jump to \
directory passed as an argument.', directory passed as an argument.',
@ -263,48 +275,48 @@ def options():
'-s', '--stat', action="store_true", default=False, '-s', '--stat', action="store_true", default=False,
help='show database entries and their key weights') help='show database entries and their key weights')
parser.add_argument( parser.add_argument(
'-v', '--version', action="version", version="%(prog)s " + VERSION, '-v', '--version', action="version", version="%(prog)s " +
help='show version information and exit') config['version'], help='show version information and exit')
ARGS = parser.parse_args() args = parser.parse_args()
db = Database(config)
# The home dir can be reached quickly by "cd" and may interfere with other if (args.add):
# directories if (args.add != os.path.expanduser("~")):
if (ARGS.add): db.add(decode(args.add))
if(ARGS.add != os.path.expanduser("~")):
db = Database(DB_FILE)
db.add(decode(ARGS.add))
return True
if (ARGS.decrease): sys.exit(0)
if(ARGS.decrease != os.path.expanduser("~")):
db = Database(DB_FILE)
# FIXME: handle symlinks?
db.decrease(os.getcwd(), ARGS.decrease)
return True
if (ARGS.purge): if (args.decrease):
db = Database(DB_FILE) if (args.decrease != os.path.expanduser("~")):
db.decrease(os.getcwd(), args.decrease)
sys.exit(0)
if (args.purge):
removed = db.purge() removed = db.purge()
if len(removed) > 0:
if len(removed):
for dir in removed: for dir in removed:
output(unico(dir)) output(unico(dir))
print("Number of database entries removed: %d" % len(removed))
return True
if (ARGS.stat): print("Number of database entries removed: %d" % len(removed))
db = Database(DB_FILE)
dirs = list(db.data.items()) sys.exit(0)
dirs.sort(key=itemgetter(1))
for path, count in dirs[-100:]: if (args.stat):
output(unico("%.1f:\t%s") % (count, path)) for path, weight in sorted(db.data, key=db.data.get)[-100:]:
output(unico("%.1f:\t%s") % (weight, path))
print("________________________________________\n") print("________________________________________\n")
print("%d:\t total key weight" % sum(db.data.values())) print("%d:\t total key weight" % sum(db.data.values()))
print("%d:\t stored directories" % len(dirs)) print("%d:\t stored directories" % len(dirs))
print("db file: %s" % DB_FILE) print("db file: %s" % DB_FILE)
return True
return False sys.exit(0)
config['args'] = args
return config
def decode(text, encoding=None, errors="strict"): def decode(text, encoding=None, errors="strict"):
""" """
@ -385,7 +397,7 @@ def find_matches(db, patterns, max_matches=1, ignore_case=False, fuzzy=False):
current_dir = None current_dir = None
dirs = list(db.data.items()) dirs = list(db.data.items())
dirs.sort(key=itemgetter(1), reverse=True) dirs.sort(key=operator.itemgetter(1), reverse=True)
results = [] results = []
if fuzzy: if fuzzy:
from difflib import get_close_matches from difflib import get_close_matches
@ -450,12 +462,9 @@ def find_matches(db, patterns, max_matches=1, ignore_case=False, fuzzy=False):
return results return results
def shell_utility(): def main():
""" config = parse_arg(parse_env(set_defaults()))
Run this when autojump is called as a shell utility. db = Database(config['db'], config)
"""
if options(): return True
db = Database(DB_FILE)
# if no directories, add empty string # if no directories, add empty string
if (ARGS.directory == ''): if (ARGS.directory == ''):
@ -515,8 +524,7 @@ def shell_utility():
if not KEEP_ALL_ENTRIES: if not KEEP_ALL_ENTRIES:
db.maintenance() db.maintenance()
return True return 0
if __name__ == "__main__": if __name__ == "__main__":
if not shell_utility(): sys.exit(main())
sys.exit(1)