Refactor away from global variables, use defaultdict/iteration instead.

pull/209/head
William Ting 11 years ago
parent f5ff5a126f
commit 2582ad6421

@ -21,117 +21,87 @@
from __future__ import division, print_function
import sys
import os
try:
import argparse
except ImportError:
# Python 2.6 support
sys.path.append(os.path.dirname(os.path.realpath(__file__)))
import autojump_argparse as argparse
sys.path.pop()
from operator import itemgetter
import collections
import math
import operator
import os
import re
import shutil
import sys
from tempfile import NamedTemporaryFile
VERSION = 'release-v21.5.9'
MAX_KEYWEIGHT = 1000
MAX_STORED_PATHS = 1000
COMPLETION_SEPARATOR = '__'
ARGS = None
CONFIG_DIR = None
DB_FILE = None
TESTING = False
# load config from environmental variables
if 'AUTOJUMP_DATA_DIR' in os.environ:
CONFIG_DIR = os.environ.get('AUTOJUMP_DATA_DIR')
else:
xdg_data_dir = os.environ.get('XDG_DATA_HOME') or \
os.path.join(os.environ['HOME'], '.local', 'share')
CONFIG_DIR = os.path.join(xdg_data_dir, 'autojump')
KEEP_ALL_ENTRIES = False
if 'AUTOJUMP_KEEP_ALL_ENTRIES' in os.environ and \
os.environ.get('AUTOJUMP_KEEP_ALL_ENTRIES') == '1':
KEEP_ALL_ENTRIES = True
ALWAYS_IGNORE_CASE = False
if 'AUTOJUMP_IGNORE_CASE' in os.environ and \
os.environ.get('AUTOJUMP_IGNORE_CASE') == '1':
ALWAYS_IGNORE_CASE = True
KEEP_SYMLINKS = False
if 'AUTOJUMP_KEEP_SYMLINKS' in os.environ and \
os.environ.get('AUTOJUMP_KEEP_SYMLINKS') == '1':
KEEP_SYMLINKS = True
if CONFIG_DIR == os.path.expanduser('~'):
DB_FILE = CONFIG_DIR + '/.autojump.txt'
else:
DB_FILE = CONFIG_DIR + '/autojump.txt'
class Database:
"""
Object for interfacing with autojump database.
Object for interfacing with autojump database file.
"""
def __init__(self, filename):
self.filename = filename
self.data = {}
def __init__(self, config):
self.config = config
self.filename = config['db']
self.data = collections.defaultdict(int)
self.load()
def __len__(self):
return len(self.data)
def add(self, path, increment = 10):
def add(self, path, increment=10):
"""
Increase weight of existing paths or initialize new ones to 10.
"""
path = path.rstrip(os.sep)
if path not in self.data:
self.data[path] = increment
else:
import math
if self.data[path]:
self.data[path] = math.sqrt((self.data[path]**2) + (increment**2))
else:
self.data[path] = increment
self.save()
def decrease(self, path, increment = 15):
def decrease(self, path, increment=15):
"""
Decrease weight of existing path. Unknown ones are ignored.
"""
if path in self.data:
if self.data[path] < increment:
self.data[path] = 0
else:
self.data[path] -= increment
self.save()
if self.data[path] < increment:
self.data[path] = 0
else:
self.data[path] -= increment
self.save()
def decay(self):
"""
Decay database entries.
"""
for k in self.data.keys():
self.data[k] *= 0.9
try:
items = self.data.iteritems()
except AttributeError:
items = self.data.items()
for path, _ in items:
self.data[path] *= 0.9
def get_weight(self, path):
"""
Return path weight.
"""
if path in self.data:
return self.data[path]
else:
return 0
return self.data[path]
def load(self, error_recovery = False):
"""
Try to open the database file, recovering from backup if needed.
Open database file, recovering from backup if needed.
"""
if os.path.exists(self.filename):
try:
if sys.version_info >= (3, 0):
with open(self.filename, 'r', encoding = 'utf-8') as f:
with open(self.filename, 'r', encoding='utf-8') as f:
for line in f.readlines():
weight, path = line[:-1].split("\t", 1)
path = decode(path, 'utf-8')
@ -162,10 +132,11 @@ class Database:
"""
Trims and decays database entries when exceeding settings.
"""
if sum(self.data.values()) > MAX_KEYWEIGHT:
if sum(self.data.values()) > self.config['max_weight']:
self.decay()
if len(self.data) > MAX_STORED_PATHS:
if len(self.data) > self.config['max_paths']:
self.trim()
self.save()
def purge(self):
@ -173,10 +144,12 @@ class Database:
Deletes all entries that no longer exist on system.
"""
removed = []
for path in list(self.data.keys()):
if not os.path.exists(path):
removed.append(path)
del self.data[path]
self.save()
return removed
@ -189,9 +162,9 @@ class Database:
if ((not os.path.exists(self.filename)) or
os.name == 'nt' or
os.getuid() == os.stat(self.filename)[4]):
temp = NamedTemporaryFile(dir = CONFIG_DIR, delete = False)
temp = NamedTemporaryFile(dir=self.config['data'], delete=False)
for path, weight in sorted(self.data.items(),
key=itemgetter(1),
key=operator.itemgetter(1),
reverse=True):
temp.write((unico("%s\t%s\n" % (weight, path)).encode("utf-8")))
@ -223,18 +196,57 @@ class Database:
If database has exceeded MAX_STORED_PATHS, removes bottom 10%.
"""
dirs = list(self.data.items())
dirs.sort(key=itemgetter(1))
dirs.sort(key=operator.itemgetter(1))
remove_cnt = int(percent * len(dirs))
for path, _ in dirs[:remove_cnt]:
del self.data[path]
def set_defaults():
config = {}
def options():
"""
Parse command line options.
"""
global ARGS
config['version'] = 'release-v21.6.0'
config['max_weight'] = 1000
config['max_paths'] = 1000
config['separator'] = '__'
config['ignore_case'] = False
config['keep_entries'] = False
config['keep_symlinks'] = False
config['debug'] = False
home = os.path.expanduser('HOME')
xdg_data = os.environ.get('XDG_DATA_HOME') or \
os.path.join(home, '.local', 'share')
config['data'] = os.path.join(xdg_data, 'autojump')
config['db'] = config['data'] + '/autojump.txt'
return config
def parse_env(config):
home = os.path.expanduser('HOME')
if 'AUTOJUMP_DATA_DIR' in os.environ:
config['data'] = os.environ.get('AUTOJUMP_DATA_DIR')
config['db'] = config['data'] + '/autojump.txt'
if config['data'] == home:
config['db'] = config['data'] + '/.autojump.txt'
if 'AUTOJUMP_KEEP_ALL_ENTRIES' in os.environ and \
os.environ.get('AUTOJUMP_KEEP_ALL_ENTRIES') == '1':
config['keep_entries'] = True
if 'AUTOJUMP_IGNORE_CASE' in os.environ and \
os.environ.get('AUTOJUMP_IGNORE_CASE') == '1':
config['ignore_case'] = True
if 'AUTOJUMP_KEEP_SYMLINKS' in os.environ and \
os.environ.get('AUTOJUMP_KEEP_SYMLINKS') == '1':
config['keep_symlinks'] = True
return config
def parse_arg(config):
parser = argparse.ArgumentParser(
description='Automatically jump to \
directory passed as an argument.',
@ -263,48 +275,48 @@ def options():
'-s', '--stat', action="store_true", default=False,
help='show database entries and their key weights')
parser.add_argument(
'-v', '--version', action="version", version="%(prog)s " + VERSION,
help='show version information and exit')
ARGS = parser.parse_args()
# The home dir can be reached quickly by "cd" and may interfere with other
# directories
if (ARGS.add):
if(ARGS.add != os.path.expanduser("~")):
db = Database(DB_FILE)
db.add(decode(ARGS.add))
return True
if (ARGS.decrease):
if(ARGS.decrease != os.path.expanduser("~")):
db = Database(DB_FILE)
# FIXME: handle symlinks?
db.decrease(os.getcwd(), ARGS.decrease)
return True
if (ARGS.purge):
db = Database(DB_FILE)
'-v', '--version', action="version", version="%(prog)s " +
config['version'], help='show version information and exit')
args = parser.parse_args()
db = Database(config)
if (args.add):
if (args.add != os.path.expanduser("~")):
db.add(decode(args.add))
sys.exit(0)
if (args.decrease):
if (args.decrease != os.path.expanduser("~")):
db.decrease(os.getcwd(), args.decrease)
sys.exit(0)
if (args.purge):
removed = db.purge()
if len(removed) > 0:
if len(removed):
for dir in removed:
output(unico(dir))
print("Number of database entries removed: %d" % len(removed))
return True
if (ARGS.stat):
db = Database(DB_FILE)
dirs = list(db.data.items())
dirs.sort(key=itemgetter(1))
for path, count in dirs[-100:]:
output(unico("%.1f:\t%s") % (count, path))
sys.exit(0)
if (args.stat):
for path, weight in sorted(db.data, key=db.data.get)[-100:]:
output(unico("%.1f:\t%s") % (weight, path))
print("________________________________________\n")
print("%d:\t total key weight" % sum(db.data.values()))
print("%d:\t stored directories" % len(dirs))
print("db file: %s" % DB_FILE)
return True
return False
sys.exit(0)
config['args'] = args
return config
def decode(text, encoding=None, errors="strict"):
"""
@ -385,7 +397,7 @@ def find_matches(db, patterns, max_matches=1, ignore_case=False, fuzzy=False):
current_dir = None
dirs = list(db.data.items())
dirs.sort(key=itemgetter(1), reverse=True)
dirs.sort(key=operator.itemgetter(1), reverse=True)
results = []
if fuzzy:
from difflib import get_close_matches
@ -450,12 +462,9 @@ def find_matches(db, patterns, max_matches=1, ignore_case=False, fuzzy=False):
return results
def shell_utility():
"""
Run this when autojump is called as a shell utility.
"""
if options(): return True
db = Database(DB_FILE)
def main():
config = parse_arg(parse_env(set_defaults()))
db = Database(config['db'], config)
# if no directories, add empty string
if (ARGS.directory == ''):
@ -515,8 +524,7 @@ def shell_utility():
if not KEEP_ALL_ENTRIES:
db.maintenance()
return True
return 0
if __name__ == "__main__":
if not shell_utility():
sys.exit(1)
sys.exit(main())

Loading…
Cancel
Save