From b82eec714a2e3e34f55627f3c94f5ebef11dede8 Mon Sep 17 00:00:00 2001 From: Paul Fitzpatrick Date: Mon, 27 Jul 2020 14:57:36 -0400 Subject: [PATCH] (core) move data engine code to core Summary: this moves sandbox/grist to core, and adds a requirements.txt file for reconstructing the content of sandbox/thirdparty. Test Plan: existing tests pass. Tested core functionality manually. Tested docker build manually. Reviewers: dsagal Reviewed By: dsagal Differential Revision: https://phab.getgrist.com/D2563 --- app/server/lib/NSandbox.ts | 13 +- buildtools/prepare_python.sh | 8 + package.json | 1 + sandbox/gen_js_schema.py | 52 + sandbox/grist/acl.py | 404 +++ sandbox/grist/action_obj.py | 79 + sandbox/grist/actions.py | 204 ++ sandbox/grist/codebuilder.py | 323 ++ sandbox/grist/column.py | 391 +++ sandbox/grist/csv_patch.py | 84 + sandbox/grist/depend.py | 155 + sandbox/grist/docactions.py | 240 ++ sandbox/grist/docmodel.py | 352 ++ sandbox/grist/engine.py | 1257 +++++++ sandbox/grist/functions/__init__.py | 12 + sandbox/grist/functions/date.py | 773 +++++ sandbox/grist/functions/info.py | 520 +++ sandbox/grist/functions/logical.py | 165 + sandbox/grist/functions/lookup.py | 80 + sandbox/grist/functions/math.py | 830 +++++ sandbox/grist/functions/schedule.py | 329 ++ sandbox/grist/functions/stats.py | 615 ++++ sandbox/grist/functions/test_schedule.py | 270 ++ sandbox/grist/functions/text.py | 590 ++++ sandbox/grist/gencode.py | 189 ++ sandbox/grist/gpath.py | 141 + sandbox/grist/grist.py | 15 + sandbox/grist/identifiers.py | 117 + sandbox/grist/import_actions.py | 309 ++ sandbox/grist/imports/__init__.py | 0 .../nyc_schools_progress_report_ec_2013.xlsx | Bin 0 -> 285213 bytes sandbox/grist/imports/main.py | 60 + sandbox/grist/imports/test_messytables.py | 22 + sandbox/grist/logger.py | 74 + sandbox/grist/lookup.py | 216 ++ sandbox/grist/main.py | 119 + sandbox/grist/match_counter.py | 31 + sandbox/grist/migrations.py | 750 +++++ sandbox/grist/moment.py | 258 ++ sandbox/grist/moment_parse.py | 159 + sandbox/grist/objtypes.py | 375 +++ sandbox/grist/records.py | 167 + sandbox/grist/relabeling.py | 331 ++ sandbox/grist/relation.py | 122 + sandbox/grist/repl.py | 87 + sandbox/grist/runtests.py | 33 + sandbox/grist/sandbox.py | 100 + sandbox/grist/schema.py | 354 ++ sandbox/grist/summary.py | 319 ++ sandbox/grist/table.py | 482 +++ sandbox/grist/table_data_set.py | 133 + sandbox/grist/test_acl.py | 512 +++ sandbox/grist/test_actions.py | 79 + sandbox/grist/test_codebuilder.py | 183 + sandbox/grist/test_column_actions.py | 453 +++ sandbox/grist/test_completion.py | 98 + sandbox/grist/test_derived.py | 294 ++ sandbox/grist/test_display_cols.py | 492 +++ sandbox/grist/test_docmodel.py | 249 ++ sandbox/grist/test_engine.py | 559 ++++ sandbox/grist/test_find_col.py | 47 + sandbox/grist/test_formula_error.py | 646 ++++ sandbox/grist/test_functions.py | 29 + sandbox/grist/test_gencode.py | 204 ++ sandbox/grist/test_gpath.py | 159 + sandbox/grist/test_import_actions.py | 150 + sandbox/grist/test_import_transform.py | 174 + sandbox/grist/test_logger.py | 38 + sandbox/grist/test_lookups.py | 646 ++++ sandbox/grist/test_match_counter.py | 147 + sandbox/grist/test_migrations.py | 195 ++ sandbox/grist/test_moment.py | 328 ++ sandbox/grist/test_relabeling.py | 361 ++ sandbox/grist/test_renames.py | 390 +++ sandbox/grist/test_renames2.py | 396 +++ sandbox/grist/test_side_effects.py | 119 + sandbox/grist/test_summary.py | 843 +++++ sandbox/grist/test_summary2.py | 1022 ++++++ sandbox/grist/test_table_actions.py | 
317 ++ sandbox/grist/test_table_data_set.py | 174 + sandbox/grist/test_textbuilder.py | 85 + sandbox/grist/test_treeview.py | 59 + sandbox/grist/test_twowaymap.py | 154 + sandbox/grist/test_types.py | 595 ++++ sandbox/grist/test_useractions.py | 914 +++++ sandbox/grist/testsamples.py | 43 + sandbox/grist/testscript.json | 2951 +++++++++++++++++ sandbox/grist/testutil.py | 149 + sandbox/grist/textbuilder.py | 179 + sandbox/grist/treeview.py | 32 + sandbox/grist/twowaymap.py | 252 ++ sandbox/grist/tzdata.data | Bin 0 -> 909997 bytes sandbox/grist/useractions.py | 1549 +++++++++ sandbox/grist/usercode.py | 68 + sandbox/grist/usertypes.py | 461 +++ sandbox/install_tz.js | 31 + sandbox/requirements.txt | 17 + 97 files changed, 29551 insertions(+), 2 deletions(-) create mode 100755 buildtools/prepare_python.sh create mode 100644 sandbox/gen_js_schema.py create mode 100644 sandbox/grist/acl.py create mode 100644 sandbox/grist/action_obj.py create mode 100644 sandbox/grist/actions.py create mode 100644 sandbox/grist/codebuilder.py create mode 100644 sandbox/grist/column.py create mode 100644 sandbox/grist/csv_patch.py create mode 100644 sandbox/grist/depend.py create mode 100644 sandbox/grist/docactions.py create mode 100644 sandbox/grist/docmodel.py create mode 100644 sandbox/grist/engine.py create mode 100644 sandbox/grist/functions/__init__.py create mode 100644 sandbox/grist/functions/date.py create mode 100644 sandbox/grist/functions/info.py create mode 100644 sandbox/grist/functions/logical.py create mode 100644 sandbox/grist/functions/lookup.py create mode 100644 sandbox/grist/functions/math.py create mode 100644 sandbox/grist/functions/schedule.py create mode 100644 sandbox/grist/functions/stats.py create mode 100644 sandbox/grist/functions/test_schedule.py create mode 100644 sandbox/grist/functions/text.py create mode 100644 sandbox/grist/gencode.py create mode 100644 sandbox/grist/gpath.py create mode 100644 sandbox/grist/grist.py create mode 100644 sandbox/grist/identifiers.py create mode 100644 sandbox/grist/import_actions.py create mode 100644 sandbox/grist/imports/__init__.py create mode 100644 sandbox/grist/imports/fixtures/nyc_schools_progress_report_ec_2013.xlsx create mode 100644 sandbox/grist/imports/main.py create mode 100644 sandbox/grist/imports/test_messytables.py create mode 100644 sandbox/grist/logger.py create mode 100644 sandbox/grist/lookup.py create mode 100644 sandbox/grist/main.py create mode 100644 sandbox/grist/match_counter.py create mode 100644 sandbox/grist/migrations.py create mode 100644 sandbox/grist/moment.py create mode 100644 sandbox/grist/moment_parse.py create mode 100644 sandbox/grist/objtypes.py create mode 100644 sandbox/grist/records.py create mode 100644 sandbox/grist/relabeling.py create mode 100644 sandbox/grist/relation.py create mode 100644 sandbox/grist/repl.py create mode 100644 sandbox/grist/runtests.py create mode 100644 sandbox/grist/sandbox.py create mode 100644 sandbox/grist/schema.py create mode 100644 sandbox/grist/summary.py create mode 100644 sandbox/grist/table.py create mode 100644 sandbox/grist/table_data_set.py create mode 100644 sandbox/grist/test_acl.py create mode 100644 sandbox/grist/test_actions.py create mode 100644 sandbox/grist/test_codebuilder.py create mode 100644 sandbox/grist/test_column_actions.py create mode 100644 sandbox/grist/test_completion.py create mode 100644 sandbox/grist/test_derived.py create mode 100644 sandbox/grist/test_display_cols.py create mode 100644 sandbox/grist/test_docmodel.py create mode 100644 
sandbox/grist/test_engine.py create mode 100644 sandbox/grist/test_find_col.py create mode 100644 sandbox/grist/test_formula_error.py create mode 100644 sandbox/grist/test_functions.py create mode 100644 sandbox/grist/test_gencode.py create mode 100644 sandbox/grist/test_gpath.py create mode 100644 sandbox/grist/test_import_actions.py create mode 100644 sandbox/grist/test_import_transform.py create mode 100644 sandbox/grist/test_logger.py create mode 100644 sandbox/grist/test_lookups.py create mode 100644 sandbox/grist/test_match_counter.py create mode 100644 sandbox/grist/test_migrations.py create mode 100644 sandbox/grist/test_moment.py create mode 100644 sandbox/grist/test_relabeling.py create mode 100644 sandbox/grist/test_renames.py create mode 100644 sandbox/grist/test_renames2.py create mode 100644 sandbox/grist/test_side_effects.py create mode 100644 sandbox/grist/test_summary.py create mode 100644 sandbox/grist/test_summary2.py create mode 100644 sandbox/grist/test_table_actions.py create mode 100644 sandbox/grist/test_table_data_set.py create mode 100644 sandbox/grist/test_textbuilder.py create mode 100644 sandbox/grist/test_treeview.py create mode 100644 sandbox/grist/test_twowaymap.py create mode 100644 sandbox/grist/test_types.py create mode 100644 sandbox/grist/test_useractions.py create mode 100644 sandbox/grist/testsamples.py create mode 100644 sandbox/grist/testscript.json create mode 100644 sandbox/grist/testutil.py create mode 100644 sandbox/grist/textbuilder.py create mode 100644 sandbox/grist/treeview.py create mode 100644 sandbox/grist/twowaymap.py create mode 100644 sandbox/grist/tzdata.data create mode 100644 sandbox/grist/useractions.py create mode 100644 sandbox/grist/usercode.py create mode 100644 sandbox/grist/usertypes.py create mode 100644 sandbox/install_tz.js create mode 100644 sandbox/requirements.txt diff --git a/app/server/lib/NSandbox.ts b/app/server/lib/NSandbox.ts index 159929e3..34ba33df 100644 --- a/app/server/lib/NSandbox.ts +++ b/app/server/lib/NSandbox.ts @@ -16,6 +16,7 @@ type SandboxMethod = (...args: any[]) => any; export interface ISandboxCommand { process: string; + libraryPath: string; } export interface ISandboxOptions { @@ -59,7 +60,7 @@ export class NSandbox implements ISandbox { if (command) { return spawn(command.process, pythonArgs, - {env: {PYTHONPATH: 'grist:thirdparty'}, + {env: {PYTHONPATH: command.libraryPath}, cwd: path.join(process.cwd(), 'sandbox'), ...spawnOptions}); } @@ -319,7 +320,13 @@ export class NSandboxCreator implements ISandboxCreator { } public create(options: ISandboxCreationOptions): ISandbox { + // Main script to run. const defaultEntryPoint = this._flavor === 'pynbox' ? 'grist/main.pyc' : 'grist/main.py'; + // Python library path is only configurable when flavor is unsandboxed. + // In this case, expect to find library files in a virtualenv built by core + // buildtools/prepare_python.sh + const pythonVersion = 'python2.7'; + const libraryPath = `grist:../venv/lib/${pythonVersion}/site-packages`; const args = [options.entryPoint || defaultEntryPoint]; if (!options.entryPoint && options.comment) { // When using default entry point, we can add on a comment as an argument - it isn't @@ -332,6 +339,7 @@ export class NSandboxCreator implements ISandboxCreator { selLdrArgs.push( // TODO: Only modules that we share with plugins should be mounted. They could be gathered in // a "$APPROOT/sandbox/plugin" folder, only which get mounted. + // TODO: These settings only make sense for pynbox flavor. 
'-E', 'PYTHONPATH=grist:thirdparty', '-m', `${options.sandboxMount}:/sandbox:ro`); } @@ -346,7 +354,8 @@ export class NSandboxCreator implements ISandboxCreator { selLdrArgs, ...(this._flavor === 'pynbox' ? {} : { command: { - process: "python2.7" + process: pythonVersion, + libraryPath } }) }); diff --git a/buildtools/prepare_python.sh b/buildtools/prepare_python.sh new file mode 100755 index 00000000..4f0cc408 --- /dev/null +++ b/buildtools/prepare_python.sh @@ -0,0 +1,8 @@ +#!/bin/bash + +if [ ! -e venv ]; then + virtualenv -ppython2.7 venv +fi + +. venv/bin/activate +pip install --no-deps -r sandbox/requirements.txt diff --git a/package.json b/package.json index f9adaf63..a3b27aac 100644 --- a/package.json +++ b/package.json @@ -5,6 +5,7 @@ "main": "index.js", "scripts": { "start": "tsc --build -w --preserveWatchOutput & webpack --config buildtools/webpack.config.js --mode development --watch --hide-modules & NODE_PATH=_build:_build/stubs nodemon -w _build/app/server -w _build/app/common _build/stubs/app/server/server.js & wait", + "install:python": "buildtools/prepare_python.sh", "build:prod": "tsc --build && webpack --config buildtools/webpack.config.js --mode production", "start:prod": "node _build/stubs/app/server/server" }, diff --git a/sandbox/gen_js_schema.py b/sandbox/gen_js_schema.py new file mode 100644 index 00000000..4154ec48 --- /dev/null +++ b/sandbox/gen_js_schema.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python -B +""" +Generates a JS schema file from sandbox/grist/schema.py. +""" + +import os +import sys +sys.path.append(os.path.join(os.path.dirname(__file__), 'grist')) + +import schema # pylint: disable=import-error,wrong-import-position + +# These are the types that appear in Grist metadata columns. +_ts_types = { + "Bool": "boolean", + "DateTime": "number", + "Int": "number", + "PositionNumber": "number", + "Ref": "number", + "Text": "string", +} + +def get_ts_type(col_type): + col_type = col_type.split(':', 1)[0] # Strip suffix for Ref:, DateTime:, etc. + return _ts_types.get(col_type, "CellValue") + +def main(): + print """ +/*** THIS FILE IS AUTO-GENERATED BY %s ***/ +// tslint:disable:object-literal-key-quotes + +export const schema = { +""" % __file__ + + for table in schema.schema_create_actions(): + print ' "%s": {' % table.table_id + for column in table.columns: + print ' %-20s: "%s",' % (column['id'], column['type']) + print ' },\n' + + print """}; + +export interface SchemaTypes { +""" + for table in schema.schema_create_actions(): + print ' "%s": {' % table.table_id + for column in table.columns: + print ' %s: %s;' % (column['id'], get_ts_type(column['type'])) + print ' };\n' + print "}" + +if __name__ == '__main__': + main() diff --git a/sandbox/grist/acl.py b/sandbox/grist/acl.py new file mode 100644 index 00000000..e2a09a43 --- /dev/null +++ b/sandbox/grist/acl.py @@ -0,0 +1,404 @@ +# Access Control Lists. +# +# This modules is used by engine.py to split actions according to recipient, as well as to +# validate whether an action received from a peer is allowed by the rules. + +# Where are ACLs applied? +# ----------------------- +# Read ACLs (which control who can see data) are implemented by "acl_read_split" operation, which +# takes an action group and returns an action bundle, which is a list of ActionEnvelopes, each +# containing a smaller action group associated with a set of recipients who should get it. Note +# that the order of ActionEnvelopes matters, and actions should be applied in that order. 
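# --- Editor's aside: an illustrative sketch, not part of this patch, of the bundle produced by
# acl_read_split, based on the ActionBundle/Envelope classes defined in action_obj.py below.
# Instance ids, table/column names and values here are invented for the example.
example_bundle_json = {
  "envelopes": [{"recipients": ["#ALL"]}, {"recipients": ["instanceA", "instanceB"]}],
  # Each stored/calc/undo entry pairs an envelope index with a docAction repr.
  "stored": [
    (0, ["AddColumn", "Table1", "D", {"type": "Text", "isFormula": False, "formula": ""}]),
    (1, ["UpdateRecord", "Table1", 17, {"D": "hello"}]),
  ],
  "calc": [],
  "undo": [],
  "retValues": [None],
  "rules": [2],      # row ids of the ACLRule records that produced this split
}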
+# +# In principle, this operation can be done either in the Python data engine or in Node. We do it +# in Python. The clearest reason is the need to apply ACL formulas. Not that it's impossible to do +# on the Node side, but currently formula values are only maintained on the Python side. + +# UserActions and ACLs +# -------------------- +# Each action starts with a UserAction, which is turned by the data engine into a number of +# DocActions. We then split DocActions by recipient according to ACL rules. But should recipients +# receive UserActions too? +# +# If UserAction is shared, we need to split it similarly to docactions, because it will often +# contain data that some recipients should not see (e.g. a BulkUpdateRecord user-action generated +# by a copy-paste). An additional difficulty is that splitting by recipient may sometimes require +# creating multiple actions. Further, trimmed UserActions aren't enough for purposes (2) or (3). +# +# Our solution will be not to send around UserActions at all, since DocActions are sufficient to +# update a document. But UserActions are needed for some things: +# (1) Present a meaningful description to users for the action log. This should be possible, and +# may in fact be better, to do from docactions only. +# (2) Redo actions. We currently use UserActions for this, but we can treat this as "undo of the +# undo", relying on docactions, which is in fact more general. Any difficulties with that +# are the same as for Undo, and are not specific to Redo anyway. +# (3) Rebase pending actions after getting peers' actions from the hub. This only needs to be +# done by the authoring instance, which will keep its original UserAction. We don't need to +# share the UserAction for this purpose. + +# Initial state +# ------------- +# With sharing enabled, the ACL rules have particular defaults (in particular, with the current +# user included in the Owners group, and that group having full access in the default rule). +# Before sharing is enabled, this cannot be completely set up, because the current user is +# unknown, nor are the user's instances, and the initial instance may not even have an instanceId. +# +# Our approach is that default rules and groups are created immediately, before sharing is +# enabled. The Owners group stays empty, and action bundles end up destined for an empty list of +# recipients. Node handles empty list of recipients as its own instanceId when sharing is off. +# +# When sharing is enabled, actions are sent to add a user with the user's instances, including the +# current instance's real instanceId, to the Owners group, and Node stops handling empty list of +# recipients as special, relying on the presence of the actual instanceId instead. + +# Identifying tables and columns +# ------------------------------ +# If we use tableId and colId in rules, then rules need to be adjusted when tables or columns are +# renamed or removed. If we use tableRef and colRef, then such rules cannot apply to metadata +# tables (which don't have refs at all). +# +# Additionally, a benefit of using tableId and colId is that this is how actions identify tables +# and columns, so rules can be applied to actions without additional lookups. +# +# For these reasons, we use tableId and colId, rather than refs (row-ids). + +# Summary tables +# -------------- +# It's not sufficient to identify summary tables by their actual tableId or by tableRef, since +# both may change when summaries are removed and recreated.
They should instead be identified +# by a value similar to tableTitle() in DocModel.js, specifically by the combination of source +# tableId and colIds of all group-by columns. + +# Actions on Permission Changes +# ----------------------------- +# When VIEW principals are added or removed for a table/column, or a VIEW ACLFormula adds or +# removes principals for a row, those principals need to receive AddRecord or RemoveRecord +# doc-actions (or equivalent). TODO: This needs to be handled. + +from collections import OrderedDict +import action_obj +import logger +log = logger.Logger(__name__, logger.INFO) + +class Permissions(object): + # Permission types and their combination are represented as bits of a single integer. + VIEW = 0x1 + UPDATE = 0x2 + ADD = 0x4 + REMOVE = 0x8 + SCHEMA_EDIT = 0x10 + ACL_EDIT = 0x20 + EDITOR = VIEW | UPDATE | ADD | REMOVE + ADMIN = EDITOR | SCHEMA_EDIT + OWNER = ADMIN | ACL_EDIT + + @classmethod + def includes(cls, superset, subset): + return (superset & subset) == subset + + @classmethod + def includes_view(cls, permissions): + return cls.includes(permissions, cls.VIEW) + + +# Sentinel object to represent "all rows", for internal use in this file. +_ALL_ROWS = "ALL_ROWS" + + +# Representation of ACL resources that's used by the ACL class. An instance of this class becomes +# the value of DocInfo.acl_resources formula. +# Note that the default ruleset (for tableId None, colId None) must exist. +# TODO: ensure that the default ruleset is created, and cannot be deleted. +class ResourceMap(object): + def __init__(self, resource_records): + self._col_resources = {} # Maps table_id to [(resource, col_id_set), ...] + self._default_resources = {} # Maps table_id (or None for global default) to resource record. + for resource in resource_records: + # Note that resource.tableId is the empty string ('') for the default table (represented as + # None in self._default_resources), and resource.colIds is '' for the table's default rule. + table_id = resource.tableId or None + if not resource.colIds: + self._default_resources[table_id] = resource + else: + col_id_set = set(resource.colIds.split(',')) + self._col_resources.setdefault(table_id, []).append((resource, col_id_set)) + + def get_col_resources(self, table_id): + """ + Returns a list of (resource, col_id_set) pairs, where resource is a record in ACLResources. + """ + return self._col_resources.get(table_id, []) + + def get_default_resource(self, table_id): + """ + Returns the "default" resource record for the given table. + """ + return self._default_resources.get(table_id) or self._default_resources.get(None) + +# Used by docmodel.py for DocInfo.acl_resources formula. +def build_resources(resource_records): + return ResourceMap(resource_records) + + +class ACL(object): + # Special recipients, or instanceIds. ALL is the special recipient for schema actions that + # should be shared with all collaborators of the document. + ALL = '#ALL' + ALL_SET = frozenset([ALL]) + EMPTY_SET = frozenset([]) + + def __init__(self, docmodel): + self._docmodel = docmodel + + def get_acl_resources(self): + try: + return self._docmodel.doc_info.table.get_record(1).acl_resources + except KeyError: + return None + + def _find_resources(self, table_id, col_ids): + """ + Yields tuples (resource, col_id_set) where each col_id_set represents the intersection of the + resouces's columns with col_ids. These intersections may be empty. 
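# --- Editor's aside: a small illustration, not part of this patch, of how the Permissions
# bit-flags above combine and are tested with Permissions.includes().
assert Permissions.EDITOR == (Permissions.VIEW | Permissions.UPDATE |
                              Permissions.ADD | Permissions.REMOVE)
assert Permissions.includes(Permissions.OWNER, Permissions.SCHEMA_EDIT)    # owners can edit schema
assert not Permissions.includes(Permissions.EDITOR, Permissions.ACL_EDIT)  # editors cannot edit ACLs
assert Permissions.includes_view(Permissions.VIEW | Permissions.UPDATE)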
+ + If col_ids is None, then it's treated as "all columns", and each col_id_set represents all of + the resource's columns. For the default resource then, it yields (resource, None) + """ + resource_map = self.get_acl_resources() + + if col_ids is None: + for resource, col_id_set in resource_map.get_col_resources(table_id): + yield resource, col_id_set + resource = resource_map.get_default_resource(table_id) + yield resource, None + + else: + seen = set() + for resource, col_id_set in resource_map.get_col_resources(table_id): + seen.update(col_id_set) + yield resource, col_id_set.intersection(col_ids) + + resource = resource_map.get_default_resource(table_id) + yield resource, set(c for c in col_ids if c not in seen) + + @classmethod + def _acl_read_split_rows(cls, resource, row_ids): + """ + Scans through ACL rules for the resouce, yielding tuples of the form (rule, row_id, + instances), to say which rowId should be sent to each set of instances according to the rule. + """ + for rule in resource.ruleset: + if not Permissions.includes_view(rule.permissions): + continue + common_instances = _get_instances(rule.principalsList) + if rule.aclColumn and row_ids is not None: + for (row_id, principals) in get_row_principals(rule.aclColumn, row_ids): + yield (rule, row_id, common_instances | _get_instances(principals)) + else: + yield (rule, _ALL_ROWS, common_instances) + + @classmethod + def _acl_read_split_instance_sets(cls, resource, row_ids): + """ + Yields tuples of the form (instances, rowset, rules) for different sets of instances, to say + which rowIds for the given resource should be sent to each set of instances, and which rules + enabled that. When a set of instances should get all rows, rowset is None. + """ + for instances, group in _group((instances, (rule, row_id)) for (rule, row_id, instances) + in cls._acl_read_split_rows(resource, row_ids)): + rules = set(item[0] for item in group) + rowset = frozenset(item[1] for item in group) + yield (instances, _ALL_ROWS if _ALL_ROWS in rowset else rowset, rules) + + @classmethod + def _acl_read_split_resource(cls, resource, row_ids, docaction, output): + """ + Given an ACLResource record and optionally row_ids (which may be None), appends to output + tuples of the form `(instances, rules, action)`, where `action` is docaction itself or a part + of it that should be sent to the corresponding set of instances. + """ + if docaction is None: + return + + # Different rules may produce different recipients for the same set of rows. We group outputs + # by sets of rows (which determine a subaction), and take a union of all the recipients. + for rowset, group in _group((rowset, (instances, rules)) for (instances, rowset, rules) + in cls._acl_read_split_instance_sets(resource, row_ids)): + da = docaction if rowset is _ALL_ROWS else _subaction(docaction, row_ids=rowset) + if da is not None: + all_instances = frozenset(i for item in group for i in item[0]) + all_rules = set(r for item in group for r in item[1]) + output.append((all_instances, all_rules, da)) + + def _acl_read_split_docaction(self, docaction, output): + """ + Given just a docaction, appends to output tuples of the form `(instances, rules, action)`, + where `action` is docaction itself or a part of it that should be sent to `instances`, and + `rules` is the set of ACLRules that allowed that (empty set for schema actions). + """ + parts = _get_docaction_parts(docaction) + if parts is None: # This is a schema action, to send to everyone. 
+ # We want to send schema actions to everyone on the document, represented by None. + output.append((ACL.ALL_SET, set(), docaction)) + return + + table_id, row_ids, col_ids = parts + for resource, col_id_set in self._find_resources(table_id, col_ids): + da = _subaction(docaction, col_ids=col_id_set) + if da is not None: + self._acl_read_split_resource(resource, row_ids, da, output) + + def _acl_read_split_docactions(self, docactions): + """ + Returns a list of tuples `(instances, rules, action)`. See _acl_read_split_docaction. + """ + if not self.get_acl_resources(): + return [(ACL.EMPTY_SET, None, da) for da in docactions] + + output = [] + for da in docactions: + self._acl_read_split_docaction(da, output) + return output + + def acl_read_split(self, action_group): + """ + Returns an ActionBundle, containing actions from the given action_group, split by the sets of + instances to which actions should be sent. + """ + bundle = action_obj.ActionBundle() + envelopeIndices = {} # Maps instance-sets to envelope indices. + + def getEnvIndex(instances): + envIndex = envelopeIndices.setdefault(instances, len(bundle.envelopes)) + if envIndex == len(bundle.envelopes): + bundle.envelopes.append(action_obj.Envelope(instances)) + return envIndex + + def split_into_envelopes(docactions, out_rules, output): + for (instances, rules, action) in self._acl_read_split_docactions(docactions): + output.append((getEnvIndex(instances), action)) + if rules: + out_rules.update(r.id for r in rules) + + split_into_envelopes(action_group.stored, bundle.rules, bundle.stored) + split_into_envelopes(action_group.calc, bundle.rules, bundle.calc) + split_into_envelopes(action_group.undo, bundle.rules, bundle.undo) + bundle.retValues = action_group.retValues + return bundle + + +class OrderedDefaultListDict(OrderedDict): + def __missing__(self, key): + self[key] = value = [] + return value + +def _group(iterable_of_pairs): + """ + Group iterable of pairs (a, b), returning pairs (a, [list of b]). The order of the groups, and + of items within a group, is according to the first seen. + """ + groups = OrderedDefaultListDict() + for key, value in iterable_of_pairs: + groups[key].append(value) + return groups.iteritems() + + +def _get_instances(principals): + """ + Returns a frozenset of all instances for all passed-in principals. + """ + instances = set() + for p in principals: + instances.update(i.instanceId for i in p.allInstances) + return frozenset(instances) + + +def get_row_principals(_acl_column, _rows): + # TODO TBD. Need to implement this (with tests) for acl-formulas for row-level access control. + return [] + + +#---------------------------------------------------------------------- + +def _get_docaction_parts(docaction): + """ + Returns a tuple of (table_id, row_ids, col_ids), any of whose members may be None, or None if + this action should not get split. + """ + return _docaction_part_helpers[docaction.__class__.__name__](docaction) + +# Helpers for _get_docaction_parts to extract for each action type the table, rows, and columns +# that a docaction of that type affects. Note that we are only talking here about the data +# affected. Schema actions do not get trimmed, since we decided against having a separate +# (and confusing) "SCHEMA_VIEW" permission. All peers will know the schema. 
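# --- Editor's aside: an illustrative sketch, not part of this patch; the table and column names
# are invented. It shows what the helpers below extract from a data action, and how _subaction()
# (defined further down) trims an action to a subset of its columns.
import actions
full = actions.BulkUpdateRecord('People', [1, 2], {'Name': ['a', 'b'], 'Age': [30, 40]})
table_id, row_ids, col_ids = _get_docaction_parts(full)
assert (table_id, row_ids, sorted(col_ids)) == ('People', [1, 2], ['Age', 'Name'])
assert _subaction(full, col_ids={'Name'}) == \
    actions.BulkUpdateRecord('People', [1, 2], {'Name': ['a', 'b']})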
+_docaction_part_helpers = { + 'AddRecord' : lambda a: (a.table_id, [a.row_id], a.columns.keys()), + 'BulkAddRecord' : lambda a: (a.table_id, a.row_ids, a.columns.keys()), + 'RemoveRecord' : lambda a: (a.table_id, [a.row_id], None), + 'BulkRemoveRecord' : lambda a: (a.table_id, a.row_ids, None), + 'UpdateRecord' : lambda a: (a.table_id, [a.row_id], a.columns.keys()), + 'BulkUpdateRecord' : lambda a: (a.table_id, a.row_ids, a.columns.keys()), + 'ReplaceTableData' : lambda a: (a.table_id, a.row_ids, a.columns.keys()), + 'AddColumn' : lambda a: None, + 'RemoveColumn' : lambda a: None, + 'RenameColumn' : lambda a: None, + 'ModifyColumn' : lambda a: None, + 'AddTable' : lambda a: None, + 'RemoveTable' : lambda a: None, + 'RenameTable' : lambda a: None, +} + + +#---------------------------------------------------------------------- + +def _subaction(docaction, row_ids=None, col_ids=None): + """ + For data actions, extracts and returns a part of docaction that applies only to the given + row_ids and/or col_ids, if given. If the part of the action is empty, returns None. + """ + helper = _subaction_helpers[docaction.__class__.__name__] + try: + return docaction.__class__._make(helper(docaction, row_ids, col_ids)) + except _NoMatch: + return None + +# Helpers for _subaction(), one for each action type, which return the tuple of values for the +# trimmed action. From this tuple a new action is automatically created by _subaction. If any part +# of the action becomes empty, the helpers raise _NoMatch exception. +_subaction_helpers = { + # pylint: disable=line-too-long + 'AddRecord' : lambda a, r, c: (a.table_id, match(r, a.row_id), match_keys_keep_empty(c, a.columns)), + 'BulkAddRecord' : lambda a, r, c: (a.table_id, match_list(r, a.row_ids), match_keys_keep_empty(c, a.columns)), + 'RemoveRecord' : lambda a, r, c: (a.table_id, match(r, a.row_id)), + 'BulkRemoveRecord' : lambda a, r, c: (a.table_id, match_list(r, a.row_ids)), + 'UpdateRecord' : lambda a, r, c: (a.table_id, match(r, a.row_id), match_keys_skip_empty(c, a.columns)), + 'BulkUpdateRecord' : lambda a, r, c: (a.table_id, match_list(r, a.row_ids), match_keys_skip_empty(c, a.columns)), + 'ReplaceTableData' : lambda a, r, c: (a.table_id, match_list(r, a.row_ids), match_keys_keep_empty(c, a.columns)), + 'AddColumn' : lambda a, r, c: a, + 'RemoveColumn' : lambda a, r, c: a, + 'RenameColumn' : lambda a, r, c: a, + 'ModifyColumn' : lambda a, r, c: a, + 'AddTable' : lambda a, r, c: a, + 'RemoveTable' : lambda a, r, c: a, + 'RenameTable' : lambda a, r, c: a, +} + +def match(subset, item): + return item if (subset is None or item in subset) else no_match() + +def match_list(subset, items): + return items if subset is None else ([i for i in items if i in subset] or no_match()) + +def match_keys_keep_empty(subset, items): + return items if subset is None else ( + {k: v for (k, v) in items.iteritems() if k in subset}) + +def match_keys_skip_empty(subset, items): + return items if subset is None else ( + {k: v for (k, v) in items.iteritems() if k in subset} or no_match()) + +class _NoMatch(Exception): + pass + +def no_match(): + raise _NoMatch() diff --git a/sandbox/grist/action_obj.py b/sandbox/grist/action_obj.py new file mode 100644 index 00000000..0bff6576 --- /dev/null +++ b/sandbox/grist/action_obj.py @@ -0,0 +1,79 @@ +""" +This module defines ActionGroup, ActionEnvelope, and ActionBundle -- classes that together +represent the result of applying a UserAction to a document. + +In general, UserActions refer to logical actions performed by the user. 
DocActions are the +individual steps to which UserActions translate. + +A list of UserActions applied together translates to multiple DocActions, packaged into an +ActionGroup. In a separate step, this ActionGroup is split up according to ACL rules into and +ActionBundle consisting of ActionEnvelopes, each containing a smaller set of actions associated +with the set of recipients who should receive them. +""" + +import actions + + +class ActionGroup(object): + """ + ActionGroup packages different types of doc actions for returning them to the instance. + + The ActionGroup stores actions produced by the engine in the course of processing one or more + UserActions, plus an array of return values, one for each UserAction. + """ + def __init__(self): + self.calc = [] + self.stored = [] + self.undo = [] + self.retValues = [] + + def get_repr(self): + return { + "calc": map(actions.get_action_repr, self.calc), + "stored": map(actions.get_action_repr, self.stored), + "undo": map(actions.get_action_repr, self.undo), + "retValues": self.retValues + } + + @classmethod + def from_json_obj(cls, data): + ag = ActionGroup() + ag.calc = map(actions.action_from_repr, data.get('calc', [])) + ag.stored = map(actions.action_from_repr, data.get('stored', [])) + ag.undo = map(actions.action_from_repr, data.get('undo', [])) + ag.retValues = data.get('retValues', []) + return ag + + +class Envelope(object): + """ + Envelope contains information about recipients as a set (or frozenset) of instanceIds. + """ + def __init__(self, recipient_set): + self.recipients = recipient_set + + def to_json_obj(self): + return {"recipients": sorted(self.recipients)} + +class ActionBundle(object): + """ + ActionBundle contains actions arranged into envelopes, i.e. split up by sets of recipients. + Note that different Envelopes contain different sets of recipients (which may overlap however). + """ + def __init__(self): + self.envelopes = [] + self.stored = [] # Pairs of (envIndex, docAction) + self.calc = [] # Pairs of (envIndex, docAction) + self.undo = [] # Pairs of (envIndex, docAction) + self.retValues = [] + self.rules = set() # RowIds of ACLRule records used to construct this ActionBundle. + + def to_json_obj(self): + return { + "envelopes": [e.to_json_obj() for e in self.envelopes], + "stored": [(env, actions.get_action_repr(a)) for (env, a) in self.stored], + "calc": [(env, actions.get_action_repr(a)) for (env, a) in self.calc], + "undo": [(env, actions.get_action_repr(a)) for (env, a) in self.undo], + "retValues": self.retValues, + "rules": sorted(self.rules) + } diff --git a/sandbox/grist/actions.py b/sandbox/grist/actions.py new file mode 100644 index 00000000..14cfccb1 --- /dev/null +++ b/sandbox/grist/actions.py @@ -0,0 +1,204 @@ +""" +actions.py defines the action objects used in the Python code, and functions to convert between +them and the serializable docActions objects used to communicate with the outside. + +When communicating with Node, docActions are represented as arrays [actionName, arguments...]. +""" + +from collections import namedtuple +import inspect + +import objtypes + +def _eq_with_type(self, other): + # pylint: disable=unidiomatic-typecheck + return tuple(self) == tuple(other) and type(self) == type(other) + +def _ne_with_type(self, other): + return not _eq_with_type(self, other) + +def namedtuple_eq(typename, field_names): + """ + Just like namedtuple, but these objects are only considered equal to other objects of the same + type (not just to any tuple with the same values). 
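# --- Editor's aside: a small illustration, not part of this patch, of the difference from a plain
# namedtuple (RemoveRecord, defined just below, is built with namedtuple_eq):
PlainRemove = namedtuple('PlainRemove', ('table_id', 'row_id'))
assert PlainRemove('People', 1) == ('People', 1)        # plain namedtuple: equal to any tuple
assert RemoveRecord('People', 1) != ('People', 1)       # namedtuple_eq: type must match too
assert RemoveRecord('People', 1) == RemoveRecord('People', 1)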
+ """ + n = namedtuple(typename, field_names) + n.__eq__ = _eq_with_type + n.__ne__ = _ne_with_type + return n + +# For Record actions, the parameters are as follows: +# table_id: string name of the table. +# row_id: numeric row identifier +# row_ids: list of row identifiers +# columns: dictionary mapping col_id (string name of column) to the value for the given +# row_id, or an array of values parallel to the array of row_ids. +AddRecord = namedtuple_eq('AddRecord', ('table_id', 'row_id', 'columns')) +BulkAddRecord = namedtuple_eq('BulkAddRecord', ('table_id', 'row_ids', 'columns')) +RemoveRecord = namedtuple_eq('RemoveRecord', ('table_id', 'row_id')) +BulkRemoveRecord = namedtuple_eq('BulkRemoveRecord', ('table_id', 'row_ids')) +UpdateRecord = namedtuple_eq('UpdateRecord', ('table_id', 'row_id', 'columns')) +BulkUpdateRecord = namedtuple_eq('BulkUpdateRecord', ('table_id', 'row_ids', 'columns')) + +# Identical to BulkAddRecord, but implies emptying out the table first. +ReplaceTableData = namedtuple_eq('ReplaceTableData', BulkAddRecord._fields) + +# For Column actions, the parameters are: +# table_id: string name of the table. +# col_id: string name of column +# col_info: dictionary with particular keys +# type: string type of the column +# isFormula: bool, whether it is a formula column +# formula: string text of the formula, or empty string +# Other keys may be set in col_info (e.g. widgetOptions, label) but are not currently used in +# the schema (only such values from the metadata tables is used). +AddColumn = namedtuple_eq('AddColumn', ('table_id', 'col_id', 'col_info')) +RemoveColumn = namedtuple_eq('RemoveColumn', ('table_id', 'col_id')) +RenameColumn = namedtuple_eq('RenameColumn', ('table_id', 'old_col_id', 'new_col_id')) +ModifyColumn = namedtuple_eq('ModifyColumn', ('table_id', 'col_id', 'col_info')) + +# For Table actions, the parameters are: +# table_id: string name of the table. +# columns: array of col_info objects, as described for Column actions above, containing also: +# id: string name of the column (aka col_id in Column actions) +AddTable = namedtuple_eq('AddTable', ('table_id', 'columns')) +RemoveTable = namedtuple_eq('RemoveTable', ('table_id',)) +RenameTable = namedtuple_eq('RenameTable', ('old_table_id', 'new_table_id')) + +# Identical to BulkAddRecord, just a clearer type name for loading or fetching data. +TableData = namedtuple_eq('TableData', BulkAddRecord._fields) + +action_types = dict((key, val) for (key, val) in globals().items() + if inspect.isclass(val) and issubclass(val, tuple)) + +# This is the set of names of all the actions that affect the schema. +schema_actions = {name for name in action_types + if name.endswith("Column") or name.endswith("Table")} + +def _add_simplify(SingleActionType, BulkActionType): + """ + Add .simplify method to "Bulk" actions, which returns None for no rows, non-Bulk version for a + single row, and the original action otherwise. 
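# --- Editor's aside: an illustrative sketch, not part of this patch (table/column names invented),
# of .simplify() once the _add_simplify() calls below have run:
assert BulkUpdateRecord('Orders', [7], {'Status': ['shipped']}).simplify() == \
    UpdateRecord('Orders', 7, {'Status': 'shipped'})
assert BulkUpdateRecord('Orders', [], {'Status': []}).simplify() is None
two_rows = BulkUpdateRecord('Orders', [7, 8], {'Status': ['shipped', 'open']})
assert two_rows.simplify() is two_rows      # more than one row: returned unchanged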
+ """ + if len(SingleActionType._fields) < 3: + def get_first(self): + return SingleActionType(self.table_id, self.row_ids[0]) + else: + def get_first(self): + return SingleActionType(self.table_id, self.row_ids[0], + { key: col[0] for key, col in self.columns.iteritems()}) + def simplify(self): + return None if not self.row_ids else (get_first(self) if len(self.row_ids) == 1 else self) + + BulkActionType.simplify = simplify + +_add_simplify(AddRecord, BulkAddRecord) +_add_simplify(RemoveRecord, BulkRemoveRecord) +_add_simplify(UpdateRecord, BulkUpdateRecord) + + +def get_action_repr(action_obj): + """ + Converts an action object, such as UpdateRecord into a docAction array. + """ + return [action_obj.__class__.__name__] + list(encode_objects(action_obj)) + +def action_from_repr(doc_action): + """ + Converts a docAction array into an object such as UpdateRecord. + """ + action_type = action_types.get(doc_action[0]) + if not action_type: + raise ValueError('Unknown action %s' % (doc_action[0],)) + + try: + return decode_objects(action_type(*doc_action[1:])) + except TypeError as e: + raise TypeError("%s: %s" % (doc_action[0], e.message)) + + +def convert_recursive_helper(converter, data): + """ + Given JSON-like data (a nested collection of lists or arrays), which may include Action tuples, + replaces all primitive values with converter(value). It should be used as follows: + + def my_convert(data): + if data needs conversion: + return converted_value + return convert_recursive_helper(my_convert, data) + """ + if isinstance(data, dict): + return {converter(k): converter(v) for k, v in data.iteritems()} + elif isinstance(data, list): + return [converter(el) for el in data] + elif isinstance(data, tuple): + return type(data)(*[converter(el) for el in data]) + else: + return data + +def convert_action_values(converter, action): + """ + Replaces all data values in an action that includes actual data with converter(value). + """ + if isinstance(action, (AddRecord, UpdateRecord)): + return type(action)(action.table_id, action.row_id, + {k: converter(v) for k, v in action.columns.iteritems()}) + if isinstance(action, (BulkAddRecord, BulkUpdateRecord, ReplaceTableData, TableData)): + return type(action)(action.table_id, action.row_ids, + {k: map(converter, v) for k, v in action.columns.iteritems()}) + return action + +def convert_recursive_in_action(converter, data): + """ + Like convert_recursive_helper, but only values of Grist cells (i.e. individual values in data + columns) get passed through converter. + """ + def inner(data): + if isinstance(data, tuple): + return convert_action_values(converter, data) + return convert_recursive_helper(inner, data) + return inner(data) + +def encode_objects(data): + return convert_recursive_in_action(objtypes.encode_object, data) + +def decode_objects(data, decoder=objtypes.decode_object): + """ + Decode objects in values of a DocAction or a data structure containing DocActions. + """ + return convert_recursive_in_action(decoder, data) + +def decode_bulk_values(bulk_values, decoder=objtypes.decode_object): + """ + Decode objects in values of the form {col_id: array_of_values}, as present in bulk DocActions + and UserActions. + """ + return {k: map(decoder, v) for (k, v) in bulk_values.iteritems()} + +def transpose_bulk_action(bulk_action): + """ + Generates namedtuples for records in a bulk action such as BulkAddRecord. Such actions store + values by columns, so in effect this transposes them, yielding them by rows. 
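# --- Editor's aside: an illustrative sketch, not part of this patch (table/column names invented),
# of transpose_bulk_action(); fields come out as ('id', 'Age', 'Name') since columns are sorted.
bulk = BulkAddRecord('People', [1, 2], {'Name': ['Ann', 'Bob'], 'Age': [30, 40]})
assert [tuple(r) for r in transpose_bulk_action(bulk)] == [(1, 30, 'Ann'), (2, 40, 'Bob')]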
+ """ + items = sorted(bulk_action.columns.items()) + RecordType = namedtuple('Record', ['id'] + [col_id for (col_id, values) in items]) + for row in zip(bulk_action.row_ids, *[values for (col_id, values) in items]): + yield RecordType(*row) + + +def prune_actions(action_list, table_id, col_id): + """ + Modifies action_list in-place, removing any mention of (table_id, col_id). Both must be given + and not None in this implementation. + """ + keep = [] + for a in action_list: + if getattr(a, 'table_id', None) == table_id: + if hasattr(a, 'columns'): + a.columns.pop(col_id, None) + if not a.columns: + continue + if getattr(a, 'col_id', None) == col_id: + continue + keep.append(a) + action_list[:] = keep diff --git a/sandbox/grist/codebuilder.py b/sandbox/grist/codebuilder.py new file mode 100644 index 00000000..ec73ade6 --- /dev/null +++ b/sandbox/grist/codebuilder.py @@ -0,0 +1,323 @@ +import ast +import contextlib +import re +import six + +import astroid +import asttokens +import textbuilder +import logger +log = logger.Logger(__name__, logger.INFO) + + +DOLLAR_REGEX = re.compile(r'\$(?=[a-zA-Z_][a-zA-Z_0-9]*)') + +# For functions needing lazy evaluation, the slice for which arguments to wrap in a lambda. +LAZY_ARG_FUNCTIONS = { + 'IF': slice(1, 3), + 'ISERR': slice(0, 1), + 'ISERROR': slice(0, 1), + 'IFERROR': slice(0, 1), +} + +def make_formula_body(formula, default_value, assoc_value=None): + """ + Given a formula, returns a textbuilder.Builder object suitable to be the body of a function, + with the formula transformed to replace `$foo` with `rec.foo`, and to insert `return` if + appropriate. Assoc_value is associated with textbuilder.Text() to be returned by map_back_patch. + """ + if isinstance(formula, six.binary_type): + formula = formula.decode('utf8') + + if not formula.strip(): + return textbuilder.Text('return ' + repr(default_value), assoc_value) + + formula_builder_text = textbuilder.Text(formula, assoc_value) + + # Start with a temporary builder, since we need to translate "$" before we can parse the code at + # all (namely, we turn '$foo' into 'DOLLARfoo' first). Once we can parse the code, we'll create + # a proper set of patches. Note that we initially translate into 'DOLLARfoo' rather than + # 'rec.foo', so that the translated entity is a single token: this makes for more precisely + # reported errors if there are any. + tmp_patches = textbuilder.make_regexp_patches(formula, DOLLAR_REGEX, 'DOLLAR') + tmp_formula = textbuilder.Replacer(formula_builder_text, tmp_patches) + + # Parse the formula into an abstract syntax tree (AST), catching syntax errors. + try: + atok = asttokens.ASTTokens(tmp_formula.get_text(), parse=True) + except SyntaxError as e: + return textbuilder.Text(_create_syntax_error_code(tmp_formula, formula, e)) + + # Parse formula and generate error code on assignment to rec + with use_inferences(InferRecAssignment): + try: + astroid.parse(tmp_formula.get_text()) + except SyntaxError as e: + return textbuilder.Text(_create_syntax_error_code(tmp_formula, formula, e)) + + # Once we have a tree, go through it and create a subset of the dollar patches that are actually + # relevant. E.g. this is where we'll skip the "$foo" patches that appear in strings or comments. 
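# --- Editor's aside: an illustrative sketch, not part of this patch, of the net effect of
# make_formula_body (column names invented, output shown roughly):
#   make_formula_body('$Amount * 1.1', 0.0).get_text()
#       -> 'return rec.Amount * 1.1'
#   make_formula_body('IF($Total > 0, $Total, "n/a")', None).get_text()
#       -> 'return IF(rec.Total > 0, lambda: (rec.Total), lambda: ("n/a"))'
#   make_formula_body('', 0.0).get_text()
#       -> 'return 0.0'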
+ patches = [] + for node in ast.walk(atok.tree): + if isinstance(node, ast.Name) and node.id.startswith('DOLLAR'): + input_pos = tmp_formula.map_back_offset(node.first_token.startpos) + m = DOLLAR_REGEX.match(formula, input_pos) + # If there is no match, then we must have had a "DOLLARblah" identifier that didn't come + # from translating a "$" prefix. + if m: + patches.append(textbuilder.make_patch(formula, m.start(0), m.end(0), 'rec.')) + + # Wrap arguments to the top-level "IF()" function into lambdas, for lazy evaluation. This is + # to ensure it's not affected by an exception in the unused value, to match Excel behavior. + if isinstance(node, ast.Call) and isinstance(node.func, ast.Name): + lazy_args_slice = LAZY_ARG_FUNCTIONS.get(node.func.id) + if lazy_args_slice: + for arg in node.args[lazy_args_slice]: + start, end = map(tmp_formula.map_back_offset, atok.get_text_range(arg)) + patches.append(textbuilder.make_patch(formula, start, start, 'lambda: (')) + patches.append(textbuilder.make_patch(formula, end, end, ')')) + + # If the last statement is an expression that has its result unused (an ast.Expr node), + # then insert a "return" keyword. + last_statement = atok.tree.body[-1] if atok.tree.body else None + if isinstance(last_statement, ast.Expr): + input_pos = tmp_formula.map_back_offset(last_statement.first_token.startpos) + patches.append(textbuilder.make_patch(formula, input_pos, input_pos, "return ")) + elif last_statement is None: + # If we have an empty body (e.g. just a comment), add a 'pass' at the end. + patches.append(textbuilder.make_patch(formula, len(formula), len(formula), '\npass')) + + # Apply the new set of patches to the original formula to get the real output. + final_formula = textbuilder.Replacer(formula_builder_text, patches) + + # Try parsing again before returning it just in case we have new syntax errors. These are + # possible in cases when a single token ('DOLLARfoo') is valid but an expression ('rec.foo') is + # not, e.g. `foo($bar=1)` or `def $foo()`. + try: + atok = asttokens.ASTTokens(final_formula.get_text(), parse=True) + except SyntaxError as e: + return textbuilder.Text(_create_syntax_error_code(final_formula, formula, e)) + + # We return the text-builder object whose .get_text() is the final formula. + return final_formula + + +def _create_syntax_error_code(builder, input_text, err): + """ + Returns the text for a function that raises the given SyntaxError and includes the offending + code in a commented-out form. In addition, it translates the error's position from builder's + output to input_text. + """ + output_ln = asttokens.LineNumbers(builder.get_text()) + input_ln = asttokens.LineNumbers(input_text) + # A SyntaxError contains .lineno and .offset (1-based), which we need to translate to offset + # within the transformed text, so that it can be mapped back to an offset in the original text, + # and finally translated back into a line number and 1-based position to report to the user. An + # example is that "$x*" is translated to "return x*", and the syntax error in the transformed + # python code (line 2 offset 9) needs to be translated to be in line 2 offset 3. 
+ output_offset = output_ln.line_to_offset(err.lineno, err.offset - 1 if err.offset else 0) + input_offset = builder.map_back_offset(output_offset) + line, col = input_ln.offset_to_line(input_offset) + return "%s\nraise %s('%s on line %d col %d')" % ( + textbuilder.line_start_re.sub('# ', input_text.rstrip()), + type(err).__name__, err.args[0], line, col + 1) + +#---------------------------------------------------------------------- + +def infer(node): + try: + return next(node.infer(), None) + except astroid.exceptions.InferenceError, e: + return "InferenceError on %r: %r" % (node, e) + + +_lookup_method_names = ('lookupOne', 'lookupRecords') + +def _is_table(node): + """ + Return true if obj is a class defining a user table. + """ + return (isinstance(node, astroid.nodes.Class) and node.decorators and + node.decorators.nodes[0].as_string() == 'grist.UserTable') + + + +@contextlib.contextmanager +def use_inferences(*inference_tips): + transform_args = [(cls.node_class, astroid.inference_tip(cls.infer), cls.filter) + for cls in inference_tips] + for args in transform_args: + astroid.MANAGER.register_transform(*args) + yield + for args in transform_args: + astroid.MANAGER.unregister_transform(*args) + + +class InferenceTip(object): + """ + Base class for inference tips. A derived class can implement the filter() and infer() class + methods, and then register() will put that inference helper into use. + """ + node_class = None + + @classmethod + def filter(cls, node): + raise NotImplementedError() + + @classmethod + def infer(cls, node, context): + raise NotImplementedError() + + +class InferReferenceColumn(InferenceTip): + """ + Inference helper to treat the return value of `grist.Reference("Foo")` as an instance of the + table `Foo`. + """ + node_class = astroid.nodes.Call + + @classmethod + def filter(cls, node): + return (isinstance(node.func, astroid.nodes.Attribute) and + node.func.as_string() in ('grist.Reference', 'grist.ReferenceList')) + + @classmethod + def infer(cls, node, context=None): + table_id = node.args[0].value + table_class = next(node.root().igetattr(table_id)) + yield astroid.bases.Instance(table_class) + + +def _get_formula_type(function_node): + decorators = function_node.decorators.nodes if function_node.decorators else () + for dec in decorators: + if (isinstance(dec, astroid.nodes.Call) and + dec.func.as_string() == 'grist.formulaType'): + return dec.args[0] + return None + + +class InferReferenceFormula(InferenceTip): + """ + Inference helper to treat functions decorated with `grist.formulaType(grist.Reference("Foo"))` + as returning instances of table `Foo`. + """ + node_class = astroid.nodes.Function + + @classmethod + def filter(cls, node): + # All methods on tables are really used as properties. + return _is_table(node.parent.frame()) + + @classmethod + def infer(cls, node, context=None): + ftype = _get_formula_type(node) + if ftype and InferReferenceColumn.filter(ftype): + return InferReferenceColumn.infer(ftype, context) + return node.infer_call_result(node.parent.frame(), context) + + +class InferLookupReference(InferenceTip): + """ + Inference helper to treat the return value of `Table.lookupRecords(...)` as returning instances + of table `Table`. 
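# --- Editor's aside: an illustrative sketch, not part of this patch (table/column names invented),
# of the kind of formula code these inference tips are meant to resolve:
#     person = People.lookupOne(Name=rec.Name)
#     return person.Age
# InferLookupReference (and InferLookupComprehension for list comprehensions) let astroid treat
# `person` as an instance of the People table, so `person.Age` is recognized as a column reference
# for autocompletion and for rename support in parse_grist_names() below.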
+ """ + node_class = astroid.nodes.Call + + @classmethod + def filter(cls, node): + return (isinstance(node.func, astroid.nodes.Attribute) and + node.func.attrname in _lookup_method_names and + _is_table(infer(node.func.expr))) + + @classmethod + def infer(cls, node, context=None): + yield astroid.bases.Instance(infer(node.func.expr)) + +class InferLookupComprehension(InferenceTip): + node_class = astroid.nodes.AssignName + + @classmethod + def filter(cls, node): + compr = node.parent + if not isinstance(compr, astroid.nodes.Comprehension): + return False + if not isinstance(compr.iter, astroid.nodes.Call): + return False + return InferLookupReference.filter(compr.iter) + + @classmethod + def infer(cls, node, context=None): + return InferLookupReference.infer(node.parent.iter) + +class InferRecAssignment(InferenceTip): + """ + Inference helper to raise exception on assignment to `rec`. + """ + node_class = astroid.nodes.AssignName + + @classmethod + def filter(cls, node): + if node.name == 'rec': + raise SyntaxError('Grist disallows assignment to the special variable "rec"', + ('', node.lineno, node.col_offset, "")) + + @classmethod + def infer(cls, node, context): + raise NotImplementedError() + +#---------------------------------------------------------------------- + +def parse_grist_names(builder): + """ + Returns a list of tuples (col_info, start_pos, table_id, col_id): + col_info: (table_id, col_id) for the formula the name is found in. It is the value passed + in by gencode.py to codebuilder.make_formula_body(). + start_pos: Index of the start character of the name in col_info.formula + table_id: Parsed name when the tuple is for a table name; the name of the column's table + when the tuple is for a column name. + col_id: None when tuple is for a table name; col_id when the tuple is for a column name. 
+ """ + code_text = builder.get_text() + + with use_inferences(InferReferenceColumn, InferReferenceFormula, InferLookupReference, + InferLookupComprehension): + atok = asttokens.ASTTokens(code_text, tree=astroid.builder.parse(code_text)) + + def make_tuple(start, end, table_id, col_id): + name = col_id or table_id + assert end - start == len(name) + patch = textbuilder.Patch(start, end, name, name) + assert code_text[start:end] == name + patch_source = builder.map_back_patch(patch) + if not patch_source: + return None + in_text, in_value, in_patch = patch_source + return (in_value, in_patch.start, table_id, col_id) + + parsed_names = [] + for node in asttokens.util.walk(atok.tree): + if isinstance(node, astroid.nodes.Name): + obj = infer(node) + if _is_table(obj): + start, end = atok.get_text_range(node) + parsed_names.append(make_tuple(start, end, obj.name, None)) + + elif isinstance(node, astroid.nodes.Attribute): + obj = infer(node.expr) + if isinstance(obj, astroid.bases.Instance): + cls = obj._proxied + if _is_table(cls): + tok = node.last_token + start, end = tok.startpos, tok.endpos + parsed_names.append(make_tuple(start, end, cls.name, node.attrname)) + elif isinstance(node, astroid.nodes.Keyword): + func = node.parent.func + if isinstance(func, astroid.nodes.Attribute) and func.attrname in _lookup_method_names: + obj = infer(func.expr) + if _is_table(obj): + tok = node.first_token + start, end = tok.startpos, tok.endpos + parsed_names.append(make_tuple(start, end, obj.name, node.arg)) + + return filter(None, parsed_names) diff --git a/sandbox/grist/column.py b/sandbox/grist/column.py new file mode 100644 index 00000000..4b28011c --- /dev/null +++ b/sandbox/grist/column.py @@ -0,0 +1,391 @@ +import types +from collections import namedtuple + +import depend +import objtypes +import usertypes +import relabeling +import relation +import moment +import logger +from sortedcontainers import SortedListWithKey + +log = logger.Logger(__name__, logger.INFO) + +MANUAL_SORT = 'manualSort' +MANUAL_SORT_COL_INFO = { + 'id': MANUAL_SORT, + 'type': 'ManualSortPos', + 'formula': '', + 'isFormula': False +} +MANUAL_SORT_DEFAULT = 2147483647.0 + +SPECIAL_COL_IDS = {'id', MANUAL_SORT} + +def is_user_column(col_id): + """ + Returns whether the col_id is of a user column (as opposed to special columns that can't be used + for user data). + """ + return col_id not in SPECIAL_COL_IDS and not col_id.startswith('#') + +def is_visible_column(col_id): + """ + Returns whether this is an id of a column that's intended to be shown to the user. + """ + return is_user_column(col_id) and not col_id.startswith('gristHelper_') + +def is_virtual_column(col_id): + """ + Returns whether col_id is of a special column that does not get communicated outside of the + sandbox. Lookup maps are an example. + """ + return col_id.startswith('#') + +def is_validation_column_name(name): + return name.startswith("validation___") + +ColInfo = namedtuple('ColInfo', ('type_obj', 'is_formula', 'method')) + +def get_col_info(col_model, default_func=None): + if isinstance(col_model, types.FunctionType): + type_obj = getattr(col_model, 'grist_type', usertypes.Any()) + return ColInfo(type_obj, is_formula=True, method=col_model) + else: + return ColInfo(col_model, is_formula=False, method=col_model.default_func or default_func) + + +class BaseColumn(object): + """ + BaseColumn holds a column of data, whether raw or computed. 
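# --- Editor's aside: a small illustration, not part of this patch, of the column-id helpers above.
assert is_user_column('Name') and is_visible_column('Name')
assert not is_user_column('manualSort')                 # special column, not user data
assert is_user_column('gristHelper_Display') and not is_visible_column('gristHelper_Display')
assert is_virtual_column('#lookup')                     # never communicated outside the sandbox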
+ """ + def __init__(self, table, col_id, col_info): + self.type_obj = col_info.type_obj + self._data = [] + self.col_id = col_id + self.table_id = table.table_id + self.node = depend.Node(self.table_id, col_id) + self._is_formula = col_info.is_formula + self._is_private = bool(col_info.method) and getattr(col_info.method, 'is_private', False) + self.method = col_info.method + + # Always initialize to include the special empty record at index 0. + self.growto(1) + + def update_method(self, method): + """ + After rebuilding user code, we reuse existing column objects, but need to replace their + 'method' function. The method may refer to variables in the generated "usercode" module, and + it's important that all such references are to the rebuilt "usercode" module. + """ + self.method = method + + def is_formula(self): + """ + Whether this is a formula column. Note that a non-formula column may have an associated + method, which is used to fill in defaults when a record is added. + """ + return self._is_formula + + def is_private(self): + """ + Returns whether this method is private to the sandbox. If so, changes to this column do not + get communicated to outside the sandbox via actions. + """ + return self._is_private + + def has_formula(self): + """ + has_formula is true if formula is set, whether or not this is a formula column. + """ + return self.method is not None + + def clear(self): + self._data = [] + self.growto(1) # Always include the special empty record at index 0. + + def destroy(self): + """ + Called when the column is deleted. + """ + del self._data[:] + + def growto(self, size): + if len(self._data) < size: + self._data.extend([self.getdefault()] * (size - len(self._data))) + + def size(self): + return len(self._data) + + def set(self, row_id, value): + """ + Sets the value of this column for the given row_id. Value should be as returned by convert(), + i.e. of the right type, or alttext, or error (but should NOT be random wrong types). + """ + try: + self._data[row_id] = value + except IndexError: + self.growto(row_id + 1) + self._data[row_id] = value + + def unset(self, row_id): + """ + Sets the value for the given row_id to the default value. + """ + self.set(row_id, self.getdefault()) + + def get_cell_value(self, row_id): + """ + Returns the "rich" value for the given row_id, i.e. the value that would be seen by formulas. + E.g. for ReferenceColumns it'll be the referred-to Record object. For cells containing + alttext, this will be an AltText object. For RaisedException objects that represent a thrown + error, this will re-raise that error. + """ + raw = self.raw_get(row_id) + if isinstance(raw, objtypes.RaisedException): + raise raw.error + if self.type_obj.is_right_type(raw): + return self._make_rich_value(raw) + return usertypes.AltText(str(raw), self.type_obj.typename()) + + def _make_rich_value(self, typed_value): + """ + Called by get_cell_value() with a value of the right type for this column. Should be + implemented by derived classes to produce a "rich" version of the value. + """ + # pylint: disable=no-self-use + return typed_value + + def raw_get(self, row_id): + """ + Returns the value stored for the given row_id. This may be an error or alttext, and it does + not convert to a richer object. + """ + try: + return self._data[row_id] + except IndexError: + return self.getdefault() + + def safe_get(self, row_id): + """ + Returns a value of the right type, or the default value if the stored value had a wrong type. 
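# --- Editor's aside: an illustrative sketch, not part of this patch, of how these accessors differ
# for a Numeric column whose stored values are [0.0, 17.5, "n/a"]:
#   raw_get(2)        -> "n/a"                      (whatever is stored, unconverted)
#   safe_get(2)       -> the column default         (stored value has the wrong type)
#   get_cell_value(2) -> AltText("n/a", "Numeric")  (what formulas see for a wrong-typed value)
#   get_cell_value(1) -> 17.5                       (right type: the "rich" value)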
+ """ + raw = self.raw_get(row_id) + return raw if self.type_obj.is_right_type(raw) else self.getdefault() + + def getdefault(self): + """ + Returns the default value for this column. This is a static default; the implementation of + "default formula" logic is separate. + """ + return self.type_obj.default + + def sample_value(self): + """ + Returns a sample value for this column, used for auto-completions. E.g. for a date, this + returns an actual datetime object rather than None (only its attributes should matter). + """ + return self.type_obj.default + + def copy_from_column(self, other_column): + """ + Replace this column's data entirely with data from another column of the same exact type. + """ + self._data[:] = other_column._data + + def convert(self, value_to_convert): + """ + Converts a value of any type to this column's type, returning either the converted value (for + which is_right_type is true), or an alttext string, or an error object. + """ + return self.type_obj.convert(value_to_convert) + + def prepare_new_values(self, values, ignore_data=False): + """ + This allows us to modify values and also produce adjustments to existing records. This + currently is only used by PositionColumn. Returns two lists: new_values, and + [(row_id, new_value)] list of adjustments to existing records. + If ignore_data is True, makes adjustments without regard to the existing data; this is used + for processing ReplaceTableData actions. + """ + # pylint: disable=no-self-use, unused-argument + return values, [] + + +class DataColumn(BaseColumn): + """ + DataColumn describes a column of raw data, and holds it. + """ + pass + +class BoolColumn(BaseColumn): + def set(self, row_id, value): + # When 1 or 1.0 is loaded, we should see it as True, and similarly 0 as False. This is similar + # to how, after loading a number into a DateColumn, we should see a date, except we adjust + # booleans at set() time. + bool_value = True if value == 1 else (False if value == 0 else value) + super(BoolColumn, self).set(row_id, bool_value) + +class NumericColumn(BaseColumn): + def set(self, row_id, value): + # Make sure any integers are treated as floats to avoid truncation. + # Uses `type(value) == int` rather than `isintance(value, int)` to specifically target + # ints and not bools (which are singleton instances the class int in python). But + # perhaps something should be done about bools also? + # pylint: disable=unidiomatic-typecheck + super(NumericColumn, self).set(row_id, float(value) if type(value) == int else value) + +_sample_date = moment.ts_to_date(0) +_sample_datetime = moment.ts_to_dt(0, None, moment.TZ_UTC) + +class DateColumn(BaseColumn): + """ + DateColumn contains numerical timestamps represented as seconds since epoch, in type float, + to midnight of specific UTC dates. Accessing them yields date objects. + """ + def _make_rich_value(self, typed_value): + return typed_value and moment.ts_to_date(typed_value) + + def sample_value(self): + return _sample_date + +class DateTimeColumn(BaseColumn): + """ + DateTimeColumn contains numerical timestamps represented as seconds since epoch, in type float, + and a timestamp associated with the column. Accessing them yields datetime objects. 
+ """ + def __init__(self, table, col_id, col_info): + super(DateTimeColumn, self).__init__(table, col_id, col_info) + self._timezone = col_info.type_obj.timezone + + def _make_rich_value(self, typed_value): + return typed_value and moment.ts_to_dt(typed_value, self._timezone) + + def sample_value(self): + return _sample_datetime + +class PositionColumn(BaseColumn): + def __init__(self, table, col_id, col_info): + super(PositionColumn, self).__init__(table, col_id, col_info) + # This is a list of row_ids, ordered by the position. + self._sorted_rows = SortedListWithKey(key=self.raw_get) + + def set(self, row_id, value): + self._sorted_rows.discard(row_id) + super(PositionColumn, self).set(row_id, value) + if value != self.getdefault(): + self._sorted_rows.add(row_id) + + def copy_from_column(self, other_column): + super(PositionColumn, self).copy_from_column(other_column) + self._sorted_rows = SortedListWithKey(other_column._sorted_rows[:], key=self.raw_get) + + def prepare_new_values(self, values, ignore_data=False): + # This does the work of adjusting positions and relabeling existing rows with new position + # (without changing sort order) to make space for the new positions. Note that this is also + # used for updating a position for an existing row: we'll find a new value for it; later when + # this value is set, the old position will be removed and the new one added. + if ignore_data: + rows = SortedListWithKey([], key=self.raw_get) + else: + rows = self._sorted_rows + adjustments, new_values = relabeling.prepare_inserts(rows, values) + return new_values, [(self._sorted_rows[i], pos) for (i, pos) in adjustments] + + +class BaseReferenceColumn(BaseColumn): + """ + Base class for ReferenceColumn and ReferenceListColumn. + """ + def __init__(self, table, col_id, col_info): + super(BaseReferenceColumn, self).__init__(table, col_id, col_info) + # We can assume that all tables have been instantiated, but not all initialized. + target_table_id = self.type_obj.table_id + self._target_table = table._engine.tables.get(target_table_id, None) + self._relation = relation.ReferenceRelation(table.table_id, target_table_id, col_id) + # Note that we need to remove these back-references when the column is removed. + if self._target_table: + self._target_table._back_references.add(self) + + def destroy(self): + # Destroy the column and remove the back-reference we created in the constructor. + super(BaseReferenceColumn, self).destroy() + if self._target_table: + self._target_table._back_references.remove(self) + + def _update_references(self, row_id, old_value, new_value): + raise NotImplementedError() + + def set(self, row_id, value): + old = self.safe_get(row_id) + super(BaseReferenceColumn, self).set(row_id, value) + new = self.safe_get(row_id) + self._update_references(row_id, old, new) + + def copy_from_column(self, other_column): + super(BaseReferenceColumn, self).copy_from_column(other_column) + self._relation.clear() + # This is hacky: we should have an interface to iterate through values of a column. (As it is, + # self._data may include values for non-existent rows; it works here because those values are + # falsy, which makes them ignored by self._update_references). + for row_id, value in enumerate(self._data): + if isinstance(value, int): + self._update_references(row_id, None, value) + + def sample_value(self): + return self._target_table.sample_record + + +class ReferenceColumn(BaseReferenceColumn): + """ + ReferenceColumn contains IDs of rows in another table. 
Accessing them yields the records in the + other table. + """ + def _make_rich_value(self, typed_value): + # If we refer to an invalid table, return integers rather than fail completely. + if not self._target_table: + return typed_value + # For a Reference, values must either refer to an existing record, or be 0. In all tables, + # the 0 index will contain the all-defaults record. + return self._target_table.Record(self._target_table, typed_value, self._relation) + + def _update_references(self, row_id, old_value, new_value): + if old_value: + self._relation.remove_reference(row_id, old_value) + if new_value: + self._relation.add_reference(row_id, new_value) + + +class ReferenceListColumn(BaseReferenceColumn): + """ + ReferenceListColumn maintains for each row a list of references (row IDs) into another table. + Accessing them yields RecordSets. + """ + def _update_references(self, row_id, old_list, new_list): + for old_value in old_list or (): + self._relation.remove_reference(row_id, old_value) + for new_value in new_list or (): + self._relation.add_reference(row_id, new_value) + + def _make_rich_value(self, typed_value): + if typed_value is None: + typed_value = [] + # If we refer to an invalid table, return integers rather than fail completely. + if not self._target_table: + return typed_value + return self._target_table.RecordSet(self._target_table, typed_value, self._relation) + + +# Set up the relationship between usertypes objects and column objects. +usertypes.BaseColumnType.ColType = DataColumn +usertypes.Reference.ColType = ReferenceColumn +usertypes.ReferenceList.ColType = ReferenceListColumn +usertypes.DateTime.ColType = DateTimeColumn +usertypes.Date.ColType = DateColumn +usertypes.PositionNumber.ColType = PositionColumn +usertypes.Bool.ColType = BoolColumn +usertypes.Numeric.ColType = NumericColumn + +def create_column(table, col_id, col_info): + return col_info.type_obj.ColType(table, col_id, col_info) diff --git a/sandbox/grist/csv_patch.py b/sandbox/grist/csv_patch.py new file mode 100644 index 00000000..ef04f196 --- /dev/null +++ b/sandbox/grist/csv_patch.py @@ -0,0 +1,84 @@ +import re +import csv + +# Monkey-patch csv.Sniffer class, in which the quote/delimiter detection has silly bugs in the +# regexp that it uses. It also seems poorly-implemented in other ways. We can probably do better +# by not using csv.Sniffer at all. +# The method below is a modified copy of the same-named method in the standard csv.Sniffer class. +def _guess_quote_and_delimiter(_self, data, delimiters): + """ + Looks for text enclosed between two identical quotes + (the probable quotechar) which are preceded and followed + by the same character (the probable delimiter). + For example: + ,'some text', + The quote with the most wins, same with the delimiter. + If there is no quotechar the delimiter can't be determined + this way. + """ + + regexp = re.compile( + r""" + (?:(?P<delim>[^\w\n"\'])|^|\n) # delimiter or start-of-line + (?P<space>\ ?)
# optional initial space + (?P<quote>["\']).*?(?P=quote) # quote-surrounded field + (?:(?P=delim)|$|\r?\n) # delimiter or end-of-line + """, re.VERBOSE | re.DOTALL | re.MULTILINE) + matches = regexp.findall(data) + + if not matches: + # (quotechar, doublequote, delimiter, skipinitialspace) + return ('', False, None, 0) + quotes = {} + delims = {} + spaces = 0 + for m in matches: + n = regexp.groupindex['quote'] - 1 + key = m[n] + if key: + quotes[key] = quotes.get(key, 0) + 1 + try: + n = regexp.groupindex['delim'] - 1 + key = m[n] + except KeyError: + continue + if key and (delimiters is None or key in delimiters): + delims[key] = delims.get(key, 0) + 1 + try: + n = regexp.groupindex['space'] - 1 + except KeyError: + continue + if m[n]: + spaces += 1 + + quotechar = reduce(lambda a, b, _quotes = quotes: + (_quotes[a] > _quotes[b]) and a or b, quotes.keys()) + + if delims: + delim = reduce(lambda a, b, _delims = delims: + (_delims[a] > _delims[b]) and a or b, delims.keys()) + skipinitialspace = delims[delim] == spaces + if delim == '\n': # most likely a file with a single column + delim = '' + else: + # there is *no* delimiter, it's a single column of quoted data + delim = '' + skipinitialspace = 0 + + # if we see an extra quote between delimiters, we've got a + # double quoted format + dq_regexp = re.compile( + (r"((%(delim)s)|^)\W*%(quote)s[^%(delim)s\n]*%(quote)" + + r"s[^%(delim)s\n]*%(quote)s\W*((%(delim)s)|$)") % \ + {'delim':re.escape(delim), 'quote':quotechar}, re.MULTILINE) + + + + if dq_regexp.search(data): + doublequote = True + else: + doublequote = False + + return (quotechar, doublequote, delim, skipinitialspace) + +csv.Sniffer._guess_quote_and_delimiter = _guess_quote_and_delimiter diff --git a/sandbox/grist/depend.py b/sandbox/grist/depend.py new file mode 100644 index 00000000..41be2e1e --- /dev/null +++ b/sandbox/grist/depend.py @@ -0,0 +1,155 @@ +""" +depend.py provides classes and functions to manage the dependency graph for grist formulas. + +Conceptually, all dependency relationships are the Edges (Node1, Relation, Node2), meaning that +Node1 depends on Node2. Each Node represents a column in a particular table (could be a derived +table, such as for subtotals). The Relation determines the row mapping, i.e. which rows in Node1 +column need to be recomputed when a row changes in Node2 column. + +When a formula is evaluated, the Record and RecordSet objects maintain a reference to the Relation +in use, while property access determines which Nodes (or columns) depend on one another. +""" + +# Note: this is partly inspired by the implementation of the ninja build system, see +# https://github.com/martine/ninja/blob/master/src/graph.h + +# Idea for the future: we can consider the concept from ninja of "order-only deps", which are +# needed before we can build the outputs, but which don't cause the outputs to rebuild. Support +# for this (with computed values properly persisted) could allow some cool use cases, like columns +# that recompute manually rather than automatically. + +from collections import namedtuple +from sortedcontainers import SortedSet + +class Node(namedtuple('Node', ('table_id', 'col_id'))): + """ + Each Node in the dependency graph represents a column in a table. + """ + __slots__ = () # This is a memory-saving device to keep these objects small + + def __str__(self): + return '[%s.%s]' % (self.table_id, self.col_id) + + +class Edge(namedtuple('Edge', ('out_node', 'in_node', 'relation'))): + """ + Each Edge connects two Nodes using a Relation.
It says that out_node depends on in_node, so that + a change to in_node should trigger a recomputation of out_node. + """ + __slots__ = () # This is a memory-saving device to keep these objects small + + def __str__(self): + return '[%s.%s: %s.%s @ %s]' % (self.out_node.table_id, self.out_node.col_id, + self.in_node.table_id, self.in_node.col_id, self.relation) + + +class CircularRefError(RuntimeError): + """ + Exception thrown when a formula column references itself, directly or indirectly. + """ + pass + + +class _AllRows(object): + """ + Special constant that indicates to `invalidate_deps` that all rows are affected and an entire + column is to be invalidated. + """ + pass + +ALL_ROWS = _AllRows() + +class Graph(object): + """ + Represents the dependency graph for all data in a grist document. + """ + def __init__(self): + # The set of all Edges, i.e. the complete dependency graph. + self._all_edges = set() + + # Map from node to the set of edges having it as the in_node (i.e. edges to dependents). + self._in_node_map = {} + + # Map from node to the set of edges having it as the out_node (i.e. edges to dependencies). + self._out_node_map = {} + + def dump_graph(self): + """ + Print out the graph to stdout, for debugging. + """ + print "Dependency graph (%d edges):" % len(self._all_edges) + for edge in sorted(self._all_edges): + print " %s" % (edge,) + + def add_edge(self, out_node, in_node, relation): + """ + Adds an edge to the global dependency graph: out_node depends on in_node, i.e. a change to + in_node should trigger a recomputation of out_node. + """ + edge = Edge(out_node, in_node, relation) + self._all_edges.add(edge) + self._in_node_map.setdefault(edge.in_node, set()).add(edge) + self._out_node_map.setdefault(edge.out_node, set()).add(edge) + + def clear_dependencies(self, out_node): + """ + Removes all edges which affect the given out_node, i.e. all of its dependencies. + """ + remove_edges = self._out_node_map.pop(out_node, ()) + for edge in remove_edges: + self._all_edges.remove(edge) + self._in_node_map.get(edge.in_node, set()).remove(edge) + edge.relation.reset_all() + + def reset_dependencies(self, node, dirty_rows): + """ + For edges the given node depends on, reset the given output rows. This is called just before + the rows get recomputed, to allow the relations to clear out state for those rows if needed. + """ + in_edges = self._out_node_map.get(node, ()) + for edge in in_edges: + edge.relation.reset_rows(dirty_rows) + + def remove_node_if_unused(self, node): + """ + Removes the given node if it has no dependents. Returns True if the node is gone, False if the + node has dependents. + """ + if self._in_node_map.get(node, None): + return False + self.clear_dependencies(node) + self._in_node_map.pop(node, None) + return True + + def invalidate_deps(self, dirty_node, dirty_rows, recompute_map, include_self=True): + """ + Invalidates the given rows in the given node, and all of its dependents, i.e. all the nodes + that recursively depend on dirty_node. If include_self is False, then skips the given node + (e.g. if the node is raw data rather than formula). Results are added to recompute_map, which + is a dict mapping Nodes to sets of rows that need to be recomputed. + + If dirty_rows is ALL_ROWS, the whole column is affected, and dependencies get recomputed from + scratch. ALL_ROWS propagates to all dependent columns, so those also get recomputed in full. 
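To make the propagation described above concrete, here is a small sketch of Graph usage (assuming depend.py and sortedcontainers are importable; the table and column names, and the identity relation, are made up for illustration):

import depend

class IdentityRelation(object):
  # Minimal stand-in for the relations in relation.py: rows map to themselves.
  def get_affected_rows(self, rows):
    return rows
  def reset_all(self):
    pass

raw = depend.Node('People', 'Age')
formula = depend.Node('People', 'AgeNextYear')

graph = depend.Graph()
graph.add_edge(formula, raw, IdentityRelation())

recompute_map = {}
graph.invalidate_deps(raw, [2, 5], recompute_map, include_self=False)
print(recompute_map)   # roughly {Node(table_id='People', col_id='AgeNextYear'): SortedSet([2, 5])}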
+ """ + if include_self: + if recompute_map.get(dirty_node) == ALL_ROWS: + return + if dirty_rows == ALL_ROWS: + recompute_map[dirty_node] = ALL_ROWS + # If all rows are being recomputed, clear the dependencies of the affected column. (We add + # dependencies in the course of recomputing, but we can only start from an empty set of + # dependencies if we are about to recompute all rows.) + self.clear_dependencies(dirty_node) + else: + out_rows = recompute_map.setdefault(dirty_node, SortedSet()) + prev_count = len(out_rows) + out_rows.update(dirty_rows) + # Don't bother recursing into dependencies if we didn't actually update anything. + if len(out_rows) <= prev_count: + return + + # Iterate through a copy of _in_node_map, because recursive clear_dependencies may modify it. + for edge in list(self._in_node_map.get(dirty_node, ())): + affected_rows = (ALL_ROWS if dirty_rows == ALL_ROWS else + edge.relation.get_affected_rows(dirty_rows)) + self.invalidate_deps(edge.out_node, affected_rows, recompute_map, include_self=True) diff --git a/sandbox/grist/docactions.py b/sandbox/grist/docactions.py new file mode 100644 index 00000000..fffb9f3d --- /dev/null +++ b/sandbox/grist/docactions.py @@ -0,0 +1,240 @@ +import actions +import schema +import logger +from usertypes import strict_equal + +log = logger.Logger(__name__, logger.INFO) + +class DocActions(object): + def __init__(self, engine): + self._engine = engine + + #---------------------------------------- + # Actions on records. + #---------------------------------------- + + def AddRecord(self, table_id, row_id, column_values): + self.BulkAddRecord( + table_id, [row_id], {key: [val] for key, val in column_values.iteritems()}) + + def BulkAddRecord(self, table_id, row_ids, column_values): + table = self._engine.tables[table_id] + for row_id in row_ids: + assert row_id not in table.row_ids, \ + "docactions.[Bulk]AddRecord for existing record #%s" % row_id + + self._engine.out_actions.undo.append(actions.BulkRemoveRecord(table_id, row_ids).simplify()) + + self._engine.add_records(table_id, row_ids, column_values) + + def RemoveRecord(self, table_id, row_id): + return self.BulkRemoveRecord(table_id, [row_id]) + + def BulkRemoveRecord(self, table_id, row_ids): + table = self._engine.tables[table_id] + + # Collect the undo values, and unset all values in the column (i.e. set to defaults), just to + # make sure we don't have stale values hanging around. + undo_values = {} + for column in table.all_columns.itervalues(): + if not column.is_formula() and column.col_id != "id": + col_values = map(column.raw_get, row_ids) + default = column.getdefault() + # If this column had all default values, don't include it into the undo BulkAddRecord. + if not all(strict_equal(val, default) for val in col_values): + undo_values[column.col_id] = col_values + for row_id in row_ids: + column.unset(row_id) + + # Generate the undo action. + self._engine.out_actions.undo.append( + actions.BulkAddRecord(table_id, row_ids, undo_values).simplify()) + + # Invalidate the deleted rows, so that anything that depends on them gets recomputed. 
+ self._engine.invalidate_records(table_id, row_ids) + + def UpdateRecord(self, table_id, row_id, columns): + self.BulkUpdateRecord( + table_id, [row_id], {key: [val] for key, val in columns.iteritems()}) + + def BulkUpdateRecord(self, table_id, row_ids, columns): + table = self._engine.tables[table_id] + for row_id in row_ids: + assert row_id in table.row_ids, \ + "docactions.[Bulk]UpdateRecord for non-existent record #%s" % row_id + + # Load the updated values. + undo_values = {} + for col_id, values in columns.iteritems(): + col = table.get_column(col_id) + undo_values[col_id] = map(col.raw_get, row_ids) + for (row_id, value) in zip(row_ids, values): + col.set(row_id, value) + + # Generate the undo action. + self._engine.out_actions.undo.append( + actions.BulkUpdateRecord(table_id, row_ids, undo_values).simplify()) + + # Invalidate the updated rows, just for the columns that got changed (and, as always, + # anything that depends on them). + self._engine.invalidate_records(table_id, row_ids, col_ids=columns.keys()) + + + def ReplaceTableData(self, table_id, row_ids, column_values): + old_data = self._engine.fetch_table(table_id, formulas=False) + self._engine.out_actions.undo.append(actions.ReplaceTableData(*old_data)) + self._engine.load_table(actions.TableData(table_id, row_ids, column_values)) + + #---------------------------------------- + # Actions on columns. + #---------------------------------------- + + def AddColumn(self, table_id, col_id, col_info): + table = self._engine.tables[table_id] + assert not table.has_column(col_id), "Column %s already exists in %s" % (col_id, table_id) + + # Add the new column to the schema object maintained in the engine. + self._engine.schema[table_id].columns[col_id] = schema.dict_to_col(col_info, col_id=col_id) + self._engine.rebuild_usercode() + self._engine.new_column_name(table) + + # Generate the undo action. + self._engine.out_actions.undo.append(actions.RemoveColumn(table_id, col_id)) + + def RemoveColumn(self, table_id, col_id): + table = self._engine.tables[table_id] + assert table.has_column(col_id), "Column %s not in table %s" % (col_id, table_id) + + # Generate (if needed) the undo action to restore the data. + undo_action = None + column = table.get_column(col_id) + if not column.is_formula(): + default = column.getdefault() + # Add to undo a BulkUpdateRecord for non-default values in the column being removed. + row_ids = [r for r in table.row_ids if not strict_equal(column.raw_get(r), default)] + undo_action = actions.BulkUpdateRecord(table_id, row_ids, { + column.col_id: map(column.raw_get, row_ids) + }).simplify() + + # Remove the specified column from the schema object. + colinfo = self._engine.schema[table_id].columns.pop(col_id) + self._engine.rebuild_usercode() + + # Generate the undo action(s). + if undo_action: + self._engine.out_actions.undo.append(undo_action) + self._engine.out_actions.undo.append(actions.AddColumn( + table_id, col_id, schema.col_to_dict(colinfo, include_id=False))) + + def RenameColumn(self, table_id, old_col_id, new_col_id): + table = self._engine.tables[table_id] + + assert table.has_column(old_col_id), "Column %s not in table %s" % (old_col_id, table_id) + assert not table.has_column(new_col_id), \ + "Column %s already exists in %s" % (new_col_id, table_id) + old_column = table.get_column(old_col_id) + + # Replace the renamed column in the schema object. 
+ schema_table_info = self._engine.schema[table_id] + colinfo = schema_table_info.columns.pop(old_col_id) + schema_table_info.columns[new_col_id] = schema.SchemaColumn( + new_col_id, colinfo.type, colinfo.isFormula, colinfo.formula) + + self._engine.rebuild_usercode() + self._engine.new_column_name(table) + + # We replaced the old column with a new Column object (not strictly necessary, but simpler). + # For a raw data column, we need to copy over the data from the old column object. + new_column = table.get_column(new_col_id) + new_column.copy_from_column(old_column) + + # Generate the undo action. + self._engine.out_actions.undo.append(actions.RenameColumn(table_id, new_col_id, old_col_id)) + + def ModifyColumn(self, table_id, col_id, col_info): + table = self._engine.tables[table_id] + assert table.has_column(col_id), "Column %s not in table %s" % (col_id, table_id) + old_column = table.get_column(col_id) + + # Modify the specified column in the schema object. + schema_table_info = self._engine.schema[table_id] + old = schema_table_info.columns[col_id] + new = schema.SchemaColumn(col_id, + col_info.get('type', old.type), + bool(col_info.get('isFormula', old.isFormula)), + col_info.get('formula', old.formula)) + if new == old: + log.info("ModifyColumn called which was a noop") + return + + undo_col_info = {k: v for k, v in schema.col_to_dict(old, include_id=False).iteritems() + if k in col_info} + + # Remove the column from the schema, then re-add it, to force creation of a new column object. + schema_table_info.columns.pop(col_id) + self._engine.rebuild_usercode() + + schema_table_info.columns[col_id] = new + self._engine.rebuild_usercode() + + # Fill in the new column with the values from the old column. + new_column = table.get_column(col_id) + for row_id in table.row_ids: + new_column.set(row_id, old_column.raw_get(row_id)) + + # Generate the undo action. + self._engine.out_actions.undo.append(actions.ModifyColumn(table_id, col_id, undo_col_info)) + + #---------------------------------------- + # Actions on tables. + #---------------------------------------- + + def AddTable(self, table_id, columns): + assert table_id not in self._engine.tables, "Table %s already exists" % table_id + + # Update schema, and re-generate the module code. + self._engine.schema[table_id] = schema.SchemaTable(table_id, schema.dict_list_to_cols(columns)) + self._engine.rebuild_usercode() + + # Generate the undo action. + self._engine.out_actions.undo.append(actions.RemoveTable(table_id)) + + def RemoveTable(self, table_id): + assert table_id in self._engine.tables, "Table %s doesn't exist" % table_id + + # Create undo actions to restore all the data records of this table. + table_data = self._engine.fetch_table(table_id, formulas=False) + undo_action = actions.BulkAddRecord(*table_data).simplify() + if undo_action: + self._engine.out_actions.undo.append(undo_action) + + # Update schema, and re-generate the module code. + schema_table = self._engine.schema.pop(table_id) + self._engine.rebuild_usercode() + + # Generate the undo action. + self._engine.out_actions.undo.append(actions.AddTable( + table_id, schema.cols_to_dict_list(schema_table.columns))) + + def RenameTable(self, old_table_id, new_table_id): + assert old_table_id in self._engine.tables, "Table %s doesn't exist" % old_table_id + assert new_table_id not in self._engine.tables, "Table %s already exists" % new_table_id + + old_table = self._engine.tables[old_table_id] + + # Update schema, and re-generate the module code. 
+ old = self._engine.schema.pop(old_table_id) + self._engine.schema[new_table_id] = schema.SchemaTable(new_table_id, old.columns) + self._engine.rebuild_usercode() + + # Copy over all columns from the old table to the new. + new_table = self._engine.tables[new_table_id] + for new_column in new_table.all_columns.itervalues(): + if not new_column.is_formula(): + new_column.copy_from_column(old_table.get_column(new_column.col_id)) + new_table.grow_to_max() # We need to bring formula columns to the right size too. + + # Generate the undo action. + self._engine.out_actions.undo.append(actions.RenameTable(new_table_id, old_table_id)) + +# end diff --git a/sandbox/grist/docmodel.py b/sandbox/grist/docmodel.py new file mode 100644 index 00000000..639c801e --- /dev/null +++ b/sandbox/grist/docmodel.py @@ -0,0 +1,352 @@ +""" +This file provides convenient access to document metadata that is internal to the sandbox. +Specifically, it has handles to the metadata tables, and adds helpful formula columns to tables +which exist only in the sandbox and are not communicated to the client. + +It is similar in purpose to DocModel.js on the client side. +""" +import itertools +import json + +import acl +import records +import usertypes +import relabeling +import table +import moment + +def _record_set(table_id, group_by, sort_by=None): + @usertypes.formulaType(usertypes.ReferenceList(table_id)) + def func(rec, table): + lookup_table = table.docmodel.get_table(table_id) + return lookup_table.lookupRecords(sort_by=sort_by, **{group_by: rec.id}) + return func + + +def _record_inverse(table_id, ref_col): + @usertypes.formulaType(usertypes.Reference(table_id)) + def func(rec, table): + lookup_table = table.docmodel.get_table(table_id) + return lookup_table.lookupOne(**{ref_col: rec.id}) + return func + + +class MetaTableExtras(object): + """ + Container class for enhancements to metadata table models. The members (formula methods) defined + for a nested class here will automatically be added as members to same-named metadata table. + """ + # pylint: disable=no-self-argument,no-member,unused-argument,not-an-iterable + class _grist_DocInfo(object): + def acl_resources(rec, table): + """ + Returns a map of ACL resources for use by acl.py. It is done in a formula so that it + automatically recomputes when anything changes in _grist_ACLResources table. + """ + # pylint: disable=no-self-use + return acl.build_resources(table.docmodel.get_table('_grist_ACLResources').lookupRecords()) + + @usertypes.formulaType(usertypes.Any()) + def tzinfo(rec, table): + # pylint: disable=no-self-use + try: + return moment.tzinfo(rec.timezone) + except KeyError: + return moment.TZ_UTC + + class _grist_Tables(object): + columns = _record_set('_grist_Tables_column', 'parentId', sort_by='parentPos') + viewSections = _record_set('_grist_Views_section', 'tableRef') + tableViews = _record_set('_grist_TableViews', 'tableRef') + summaryTables = _record_set('_grist_Tables', 'summarySourceTable') + + def summaryKey(rec, table): + """ + Returns the tuple of sorted colRefs for summary columns. This uniquely identifies a summary + table among other summary tables for the same source table. 
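To make the generated members above concrete: columns = _record_set('_grist_Tables_column', 'parentId', sort_by='parentPos') produces a formula roughly equivalent to writing this by hand (a sketch, not the actual generated code):

@usertypes.formulaType(usertypes.ReferenceList('_grist_Tables_column'))
def columns(rec, table):
  # All column records whose parentId refers back to this table, in parentPos order.
  return table.docmodel.get_table('_grist_Tables_column').lookupRecords(
      sort_by='parentPos', parentId=rec.id)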
+ """ + # pylint: disable=not-an-iterable + return (tuple(sorted(int(c.summarySourceCol) for c in rec.columns if c.summarySourceCol)) + if rec.summarySourceTable else None) + + def setAutoRemove(rec, table): + """Marks the table for removal if it's a summary table with no more view sections.""" + table.docmodel.setAutoRemove(rec, rec.summarySourceTable and not rec.viewSections) + + + class _grist_Tables_column(object): + viewFields = _record_set('_grist_Views_section_field', 'colRef') + summaryGroupByColumns = _record_set('_grist_Tables_column', 'summarySourceCol') + usedByCols = _record_set('_grist_Tables_column', 'displayCol') + usedByFields = _record_set('_grist_Views_section_field', 'displayCol') + + def tableId(rec, table): + return rec.parentId.tableId + + def numDisplayColUsers(rec, table): + """ + Returns the number of cols and fields using this col as a display col + """ + return len(rec.usedByCols) + len(rec.usedByFields) + + def setAutoRemove(rec, table): + """Marks the col for removal if it's a display helper col with no more users.""" + table.docmodel.setAutoRemove(rec, + rec.colId.startswith('gristHelper_Display') and rec.numDisplayColUsers == 0) + + + class _grist_Views(object): + viewSections = _record_set('_grist_Views_section', 'parentId') + tabBarItems = _record_set('_grist_TabBar', 'viewRef') + tableViewItems = _record_set('_grist_TableViews', 'viewRef') + primaryViewTable = _record_inverse('_grist_Tables', 'primaryViewId') + pageItems = _record_set('_grist_Pages', 'viewRef') + + class _grist_Views_section(object): + fields = _record_set('_grist_Views_section_field', 'parentId', sort_by='parentPos') + + class _grist_ACLRules(object): + # The set of rules that applies to this resource + @usertypes.formulaType(usertypes.ReferenceList('_grist_ACLPrincipals')) + def principalsList(rec, table): + return json.loads(rec.principals) + + class _grist_ACLResources(object): + # The set of rules that applies to this resource + ruleset = _record_set('_grist_ACLRules', 'resource') + + class _grist_ACLPrincipals(object): + # Memberships table maintains containment relationships between principals. + memberships = _record_set('_grist_ACLMemberships', 'parent') + + # Children of a User principal are Instances. Children of a Group are Users or other Groups. + @usertypes.formulaType(usertypes.ReferenceList('_grist_ACLPrincipals')) + def children(rec, table): + return [m.child for m in rec.memberships] + + @usertypes.formulaType(usertypes.ReferenceList('_grist_ACLPrincipals')) + def descendants(rec, table): + """ + Descendants through great-grandchildren. (We don't support fully recursive descendants yet, + which may be cleaner.) The max supported level is a group containing subgroups (children), + which contain users (grandchildren), which contain instances (great-grandchildren). + """ + # Include direct children. + ret = set(rec.children) + ret.add(rec) + for c1 in rec.children: + # Include grandchildren (children of each child) + ret.update(c1.children) + for c2 in c1.children: + # Include great-grandchildren (children of each grandchild). 
+ ret.update(c2.children) + return ret + + @usertypes.formulaType(usertypes.ReferenceList('_grist_ACLPrincipals')) + def allInstances(rec, table): + return sorted(r for r in rec.descendants if r.instanceId) + + @usertypes.formulaType(usertypes.Text()) + def name(rec, table): + return ('User:' + rec.userEmail if rec.type == 'user' else + 'Group:' + rec.groupName if rec.type == 'group' else + 'Inst:' + rec.instanceId if rec.type == 'instance' else '') + + +def enhance_model(model_class): + """ + Given a metadata model class, add all members (formula methods) to it from the same-named inner + class of MetaTableExtras. The added members are marked as private; the resulting Column objects + will have col.is_private() as true. + """ + extras_class = getattr(MetaTableExtras, model_class.__name__, None) + if not extras_class: + return + for name, member in extras_class.__dict__.iteritems(): + if not name.startswith("__"): + member.__name__ = name + member.is_private = True + setattr(model_class, name, member) + +# There is a single instance of DocModel per sandbox process and +# global_docmodel is a reference to it +global_docmodel = None + +class DocModel(object): + """ + This class defines more convenient handles to all metadata tables. In addition, it sets + table.docmodel member for each of these tables to itself. Note that it deals with + table.UserTable objects (rather than the lower-level table.Table objects). + """ + def __init__(self, engine): + self._engine = engine + global global_docmodel # pylint: disable=global-statement + if not global_docmodel: + global_docmodel = self + + # Set of records scheduled for automatic removal. + self._auto_remove_set = set() + + def update_tables(self): + """ + Update the table handles we maintain to correspond to the current Engine tables. + """ + self.doc_info = self._prep_table("_grist_DocInfo") + self.tables = self._prep_table("_grist_Tables") + self.columns = self._prep_table("_grist_Tables_column") + self.table_views = self._prep_table("_grist_TableViews") + self.tab_bar = self._prep_table("_grist_TabBar") + self.views = self._prep_table("_grist_Views") + self.view_sections = self._prep_table("_grist_Views_section") + self.view_fields = self._prep_table("_grist_Views_section_field") + self.validations = self._prep_table("_grist_Validations") + self.repl_hist = self._prep_table("_grist_REPL_Hist") + self.attachments = self._prep_table("_grist_Attachments") + self.acl_rules = self._prep_table("_grist_ACLRules") + self.acl_resources = self._prep_table("_grist_ACLResources") + self.acl_principals = self._prep_table("_grist_ACLPrincipals") + self.acl_memberships = self._prep_table("_grist_ACLMemberships") + self.pages = self._prep_table("_grist_Pages") + + def _prep_table(self, name): + """ + Helper that gets the table with the given name, and sets its .doc attribute to DocModel. 
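Once update_tables() has populated these handles, metadata can be queried directly from sandbox code. A hypothetical lookup (the 'People'/'Name' identifiers are invented, and dm stands for the document's DocModel instance):

name_col = dm.columns.lookupOne(tableId='People', colId='Name')   # a _grist_Tables_column record
print(name_col.tableId)              # 'People', via the formula member added above
print(name_col.numDisplayColUsers)   # 0 unless some column or field uses it as displayCol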
+ """ + user_table = self._engine.tables[name].user_table + user_table.docmodel = self + return user_table + + def get_table(self, table_id): + return self._engine.tables[table_id].user_table + + + def get_table_rec(self, table_id): + """Returns the table record for the given table name, or raises ValueError.""" + table_rec = self.tables.lookupOne(tableId=table_id) + if not table_rec: + raise ValueError("No such table: %s" % table_id) + return table_rec + + def get_column_rec(self, table_id, col_id): + """Returns the column record for the given table and column names, or raises ValueError.""" + col_rec = self.columns.lookupOne(tableId=table_id, colId=col_id) + if not col_rec: + raise ValueError("No such column: %s.%s" % (table_id, col_id)) + return col_rec + + + def setAutoRemove(self, record, yes_or_no): + """ + Marks a record for automatic removal. To use, create a formula in your table, e.g. + 'setAutoRemove', which calls `table.docmodel.setAutoRemove(boolean_value)`. Whenever it gets + reevaluated and the boolean_value is true, the record will be automatically removed. + For now, it is only usable in metadata tables, although we could extend to user tables. + """ + if yes_or_no: + self._auto_remove_set.add(record) + else: + self._auto_remove_set.discard(record) + + def apply_auto_removes(self): + """ + Remove the records marked for removal. + """ + # Sort to make sure removals are done in deterministic order. + gone_records = sorted(self._auto_remove_set) + self._auto_remove_set.clear() + self.remove(gone_records) + return bool(gone_records) + + def remove(self, records): + """ + Removes all records in the given iterable of Records. + """ + for table_id, group in itertools.groupby(records, lambda r: r._table.table_id): + self._engine.user_actions.BulkRemoveRecord(table_id, [int(r) for r in group]) + + def update(self, records, **col_values): + """ + Updates all records in the given list of Records or a RecordSet; col_values maps column ids to + values. The values may either be a list of the length len(records), or a non-list value that + will be used for all records. + """ + record_list = list(records) + if not record_list: + return + table_id = record_list[0]._table.table_id + # Make sure these are all records from the same table. + assert all(r._table.table_id == table_id for r in record_list) + row_ids = [int(r) for r in record_list] + values = _unify_col_values(col_values, len(record_list)) + self._engine.user_actions.BulkUpdateRecord(table_id, row_ids, values) + + def add(self, record_set_or_table, **col_values): + """ + Add new records for the given table; col_values maps column ids to values. Values may either + be lists (all of the same length), or non-list values that will be used for all added records. + Either a UserTable or a RecordSet may used as the first argument. If it is a RecordSet created + with lookupRecords, it may set additional col_values. + Returns a list of inserted records. 
+ """ + assert isinstance(record_set_or_table, (records.RecordSet, table.UserTable)) + count = _get_col_values_count(col_values) + values = _unify_col_values(col_values, count) + + if isinstance(record_set_or_table, records.RecordSet): + table_obj = record_set_or_table._table + group_by = record_set_or_table._group_by + if group_by: + values.update((k, [v] * count) for k, v in group_by.iteritems() if k not in values) + else: + table_obj = record_set_or_table.table + + row_ids = self._engine.user_actions.BulkAddRecord(table_obj.table_id, [None] * count, values) + return [table_obj.Record(table_obj, r, None) for r in row_ids] + + def insert(self, record_set, position, **col_values): + """ + Add new records using col_values, inserting them into record_set according to position. + This may only be used when record_set is sorted by a field of type PositionNumber; in + particular it must be the result of lookupRecords() with 'sort_by' parameter. + Position may be numeric (to compare to other sort_by values), or None to insert at the end. + Returns a list of inserted records. + """ + assert isinstance(record_set, records.RecordSet), \ + "docmodel.insert() may only be used on a RecordSet, not %s" % type(record_set) + sort_by = getattr(record_set, '_sort_by', None) + assert sort_by, \ + "docmodel.insert() may only be used on a sorted RecordSet" + column = record_set._table.get_column(sort_by) + assert isinstance(column.type_obj, usertypes.PositionNumber), \ + "docmodel.insert() may only be used on a RecordSet sorted by PositionNumber type column" + + col_values[sort_by] = float('inf') if position is None else position + return self.add(record_set, **col_values) + + def insert_after(self, record_set, position, **col_values): + """ + Same as insert, but when position is equal to the position of an existing record, inserts + after that record; and when position is None, inserts at the beginning. + """ + # We can reuse insert() by just using the next float for position. As long as positions of + # existing records are different, that would necessarily place the new records correctly. + pos = float('-inf') if position is None else relabeling.nextfloat(position) + return self.insert(record_set, pos, **col_values) + + +def _unify_col_values(col_values, count): + """ + Helper that converts a dict mapping keys to values or lists of values to all lists. Non-list + values get turned into lists by repeating them count times. + """ + assert all(len(v) == count for v in col_values.itervalues() if isinstance(v, list)) + return {k: (v if isinstance(v, list) else [v] * count) + for k, v in col_values.iteritems()} + +def _get_col_values_count(col_values): + """ + Helper that returns the length of the first list in among the values of col_values. If none of + the values is a list, returns 1. + """ + first_list = next((v for v in col_values.itervalues() if isinstance(v, list)), None) + return len(first_list) if first_list is not None else 1 diff --git a/sandbox/grist/engine.py b/sandbox/grist/engine.py new file mode 100644 index 00000000..9065e152 --- /dev/null +++ b/sandbox/grist/engine.py @@ -0,0 +1,1257 @@ +""" +The data engine ties the code generated from the schema with the document data, and with +dependency tracking. 
+""" +import contextlib +import itertools +import re +import sys +import traceback +from collections import namedtuple, OrderedDict, Hashable +from sortedcontainers import SortedSet + +import time +import rlcompleter +import acl +import actions +import action_obj +from codebuilder import DOLLAR_REGEX +import depend +import docactions +import docmodel +import gencode +import logger +import match_counter +import objtypes +import schema +import table as table_module +import useractions +import column +import repl + +log = logger.Logger(__name__, logger.INFO) + +class OrderError(Exception): + """ + An exception thrown and handled internally, representing when + evaluating a formula for a cell requires a value from another cell + (or lookup) that has not yet itself been evaluated. Formulas used + to be evaluated recursively, on the program stack, but now ordering + is organized explicitly by watching for this exception and adapting + evaluation order appropriately. + """ + def __init__(self, message, node, row_id): + super(OrderError, self).__init__(message) + self.node = node # The column of the cell evaluated out of order. + self.row_id = row_id # The row_id of the cell evaluated out of order. + self.requiring_node = None # The column of the original cell being evaluated. + # Added later since not known at point of exception. + self.requiring_row_id = None # the row_id of the original cell being evaluated + + def set_requirer(self, node, row_id): + self.requiring_node = node + self.requiring_row_id = row_id + +# An item of work to be done by Engine._update +WorkItem = namedtuple('WorkItem', ('node', 'row_ids', 'locks')) + +# Needed because some comparisons may fail (e.g. datetimes with different tzinfo objects) +def _equal_values(a, b): + try: + return a == b + except Exception: + return False + +# Returns an AddTable action which can be used to reproduce the given docmodel table +def _get_table_actions(table): + schema_cols = [schema.make_column(c.colId, c.type, formula=c.formula, isFormula=c.isFormula) + for c in table.columns] + return actions.AddTable(table.tableId, schema_cols) + + +# skip private members, and methods we don't want to expose to users. +skipped_completions = re.compile(r'\.(_|lookupOrAddDerived|getSummarySourceGroup)') + + +# Unique sentinel values with which columns are initialized before any data is loaded into them. +# For formula columns, it ensures that any calculation produces an unequal value and gets included +# into a calc action. +_pending_sentinel = object() + +# The schema for the data is documented in gencode.py. + +# There is a general process by which values get recomputed. There are two stages: +# (1) when raw data is loaded or changed by an action, it marks things as "dirty". +# This is done using engine.recompute_map, which maps Nodes to sets of dirty rows. +# (2) when up-to-date data is needed, _recompute is called, and updates the dirty rows. +# Up-to-date data is needed when it's required externally (e.g. to send to client), and +# may be needed recursively when other data is being recomputed. + +# In this implementation, rows are identified by a row_id, which functions like an index, so that +# data may be stored in lists and typed arrays. This is very memory-efficient when row_ids are +# dense, but bad when they get too sparse. TODO The proposed solution is to have a condense +# operation which renumbers row_ids when they get too sparse. + +# TODO: +# We should support types SubRecord, SubRecordList, and SubRecordMap. 
Original thought was to +# represent them as derived tables with special names, such as "Foo.field". This breaks several +# assumptions about how to organize generated code. Instead, we can use derived tables with valid +# names (such as "Foo_field"), and add an actual column "field" with an appropriate type. This +# column may refer to derived tables or independent tables. Derived tables would have an extra +# property, marking them as derived, which would affect certain UI decisions. + + +class Engine(object): + """ + The Engine is the core of the grist per-document logic. Some of its methods form the API exposed + to the Node controller. These are: + + Initialization: + + load_empty() + Initializes an empty document; useful for newly-created documents. + + load_meta_tables(meta_tables, meta_columns) + load_table(table_data) + load_done() + These three must be called in-order to initialize a non-empty document. + - First, load_meta_tables() must be called with data for the two special metadata tables + containing the schema. It returns the list of other table names the data engine expects. + - Then load_table() must be called once for each of the other tables (both special tables, + and user tables), with that table's data (no need to call it for empty tables). + - Finally, load_done() must be called once to finish initialization. + + Other methods: + + fetch_table(table_id, formulas) + Returns a TableData object containing the full data for the table. Formula columns + are included only if formulas is True. + + apply_user_actions(user_actions) + Applies a list of UserActions, which are tuples consisting of the name of the action + method (as defind in useractions.py) and the arguments to it. Returns ActionGroup tuple, + containing several categories of DocActions, including the results of computations. + """ + + class ComputeFrame(object): + """ + Represents the node and ID of the value currently being recomputed. There is a stack of + ComputeFrames, because during computation we may access other out-of-date nodes, and need to + recompute those first. + compute_frame.current_row_id gets set to each record ID as we go through them. + """ + def __init__(self, node): + self.node = node + self.current_row_id = None + + + def __init__(self): + # The document data, incuding logic (formulas), and metadata (tables prefixed with "_grist_"). + self.tables = {} # Maps table IDs (or names) to Table objects. + + # Schema contains information about tables and columns, needed in particular to generate the + # code, from which in turn we create all the Table and Column objects. Schema is an + # OrderedDict of tableIds to schema.SchemaTable objects. Each of those contains a .columns + # OrderedDict of colId to schema.SchemaColumns objects. Order is used when generating code. + self.schema = OrderedDict() + + # A more convenient interface to the document metadata. + self.docmodel = docmodel.DocModel(self) + + # The module containing the compiled user code generated from the schema. + self.gencode = gencode.GenCode() + + # Maintain the dependency graph of what Nodes (columns) depend on what other Nodes. + self.dep_graph = depend.Graph() + + # Maps Nodes to sets of dirty rows (that need to be recomputed). + self.recompute_map = {} + + # Maps Nodes to sets of done rows (to avoid recomputing in an infinite loop). + self._recompute_done_map = {} + + # Contains Nodes once an exception value has been seen for them. 
+ self._is_node_exception_reported = set() + + # Contains Edges (node1, node2, relation) already seen during formula accesses. + self._recompute_edge_set = set() + + # Sanity-check counter to check if we are making progress. + self._recompute_done_counter = 0 + + # Maps Nodes to a list of [rowId, value] pairs for cells that have been changed. + # Ordered to preserve the order in which first change was made to a column. + # This allows actions to be emitted in a legacy order that a lot of tests depend + # on. Not necessary to functioning, just a convenience. + self._changes_map = OrderedDict() + + # This is set when we are running engine._update_loop, which has the ability to + # evaluate dependencies. We check this flag in engine._recompute_in_order, which will + # start an update loop if called without one already in place. + self._in_update_loop = False + + # A set of (node, row_id) cell references. When evaluating a formula, a dependency + # on any of these cells implies a circular dependency. + self._locked_cells = set() + + # The lists of actions of different kinds, built up while applying an action. + self.out_actions = action_obj.ActionGroup() + + # Stack of compute frames. + self._compute_stack = [] + + # Certain recomputations are triggered by a particular doc action. This keep track of it. + self._triggering_doc_action = None + + # The list of columns that got deleted while applying an action. + self._gone_columns = [] + + # The set of potentially unused LookupMapColumns. + self._unused_lookups = set() + + # Create the formula tracer that can be overridden to trace formula evaluations. It is called + # with the Column and Record object for the formula about to be evaluated. It's used in tests. + self.formula_tracer = lambda col, record: None + + # Create the object that knows how to interpret UserActions. + self.doc_actions = docactions.DocActions(self) + + # Create the object that knows how to interpret UserActions. + self.user_actions = useractions.UserActions(self) + + # A flag for when a useraction causes a schema change, to verify consistency afterwards. + self._schema_updated = False + + # Locals dict for recently executed code in the REPL + self._repl = repl.REPLInterpreter() + + # The single ACL instance for breaking up and validating actions according to permissions. + self._acl = acl.ACL(self.docmodel) + + # Stores an exception representing the first unevaluated cell met while recomputing the + # current cell. + self._cell_required_error = None + + def load_empty(self): + """ + Initialize an empty document, e.g. a newly-created one. + """ + self.load_meta_tables(actions.TableData('_grist_Tables', [], {}), + actions.TableData('_grist_Tables_column', [], {})) + self.load_done() + + def load_meta_tables(self, meta_tables, meta_columns): + """ + Must be the first method to call for this Engine. The arguments must contain the data for the + _grist_Tables and _grist_Tables_column tables, in the form of actions.TableData. + Returns the list of all the other table names that data engine expects to be loaded. + """ + self.schema = schema.build_schema(meta_tables, meta_columns) + + # Compile the user-defined module code (containing all formulas in particular). + self.rebuild_usercode() + + # Load the data into the now-existing metadata tables. This isn't used directly, it's just a + # mirror of the schema for storage and for looking at. 
+ self.load_table(meta_tables) + self.load_table(meta_columns) + return sorted(table_id for table_id in self.tables + if table_id not in (meta_tables.table_id, meta_columns.table_id)) + + def load_table(self, data): + """ + Must be called for each of the metadata tables (except the ones given to load_meta), and for + each user-defined table. The argument is an actions.TableData object. + """ + table = self.tables[data.table_id] + + # Clear all columns, whether or not they are present in the data. + for column in table.all_columns.itervalues(): + column.clear() + + # Only load non-formula columns + columns = {col_id: data for (col_id, data) in data.columns.iteritems() + if table.has_column(col_id) and not table.get_column(col_id).is_formula()} + + # Add the records. + self.add_records(data.table_id, data.row_ids, columns) + + def load_done(self): + """ + Finalizes the loading of data into this Engine. + """ + self._bring_all_up_to_date() + + def add_records(self, table_id, row_ids, column_values): + """ + Helper to add records to the given table, with row_ids and column_values having the same + interpretation as in TableData or BulkAddRecords. It's used both for the initial loading of + data, and for BulkAddRecords itself. + """ + table = self.tables[table_id] + + growto_size = (max(row_ids) + 1) if row_ids else 1 + + # Create the new records. + id_column = table.get_column('id') + id_column.growto(growto_size) + for row_id in row_ids: + id_column.set(row_id, row_id) + + # Resize all columns to the full table size. + table.grow_to_max() + + # Load the new values. + for col_id, values in column_values.iteritems(): + column = table.get_column(col_id) + column.growto(growto_size) + for row_id, value in itertools.izip(row_ids, values): + column.set(row_id, value) + + # Set all values in formula columns to a special "pending" sentinel value, so that when they + # are calculated, they are considered changed and included into the produced calc actions. + # This matters because the client starts off seeing formula columns as "pending" values too. + for column in table.all_columns.itervalues(): + if not column.is_formula(): + continue + for row_id in row_ids: + column.set(row_id, _pending_sentinel) + + # Invalidate new records to cause the formula columns to get recomputed. + self.invalidate_records(table_id, row_ids) + + def fetch_table(self, table_id, formulas=True, private=False, query=None): + """ + Returns TableData object representing all data in this table. + """ + table = self.tables[table_id] + column_values = {} + + query_cols = [] + if query: + query_cols = [(table.get_column(col_id), values) for (col_id, values) in query.iteritems()] + row_ids = [r for r in table.row_ids + if all((c.raw_get(r) in values) for (c, values) in query_cols)] + + for c in table.all_columns.itervalues(): + # pylint: disable=too-many-boolean-expressions + if ((formulas or not c.is_formula()) + and (private or not c.is_private()) + and c.col_id != "id" and not column.is_virtual_column(c.col_id)): + column_values[c.col_id] = map(c.raw_get, row_ids) + + return actions.TableData(table_id, row_ids, column_values) + + def fetch_table_schema(self): + return self.gencode.get_user_text() + + def fetch_meta_tables(self, formulas=True): + """ + Returns {table_id: TableData} mapping for all metadata tables (those starting with '_grist_'). + + Note the slight naming difference with load_meta_tables: that one expects just two + extra-special tables, whereas fetch_meta_tables returns all special tables. 
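From the caller's side, the load sequence described in the Engine class docstring looks roughly like this (a sketch; meta_tables, meta_columns and stored_data stand for TableData values coming from document storage, and 'People' is an invented table id):

import engine

eng = engine.Engine()
other_table_ids = eng.load_meta_tables(meta_tables, meta_columns)
for table_id in other_table_ids:
  if table_id in stored_data:             # tables with no stored rows can be skipped
    eng.load_table(stored_data[table_id])
eng.load_done()
snapshot = eng.fetch_table('People', formulas=True)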
+ """ + return {table_id: self.fetch_table(table_id, formulas=formulas) + for table_id in self.tables if table_id.startswith('_grist_')} + + def fetch_snapshot(self): + """ + Returns a full list of actions which when applied sequentially recreate the doc database to + its current state. + """ + action_group = action_obj.ActionGroup() + action_group.stored = self._get_snapshot_actions() + return action_group + + def _get_snapshot_actions(self): + """ + Returns a list of action objects which recreate the document database when applied. + """ + schema_actions = schema.schema_create_actions() + table_actions = [_get_table_actions(table) for table in self.docmodel.tables.all] + record_actions = [self._get_record_actions(table_id) for (table_id,t) in self.tables.iteritems() + if t.next_row_id() > 1] + return schema_actions + table_actions + record_actions + + # Returns a BulkAddRecord action which can be used to add the currently existing data to an empty + # version of the table with the given table_id. + def _get_record_actions(self, table_id): + table_data = self.fetch_table(table_id, formulas=False) + return actions.BulkAddRecord(table_id, table_data.row_ids, table_data.columns) + + def find_col_from_values(self, values, n, opt_table_id=None): + """ + Returns a list of colRefs for columns whose values match a given list. The results are ordered + from best to worst according to the number of matches of distinct values. + + If n is non-zero, limits the results to that number. If opt_table_id is given, search only + that table for matching columns. + """ + start_time = time.time() + # Exclude default values, since these will often result in matching new/incomplete columns. + # If a value is unhashable, set() will fail, so we check for that. + sample = set(v for v in values if isinstance(v, Hashable)) + matched_cols = [] + + # If the column has no values, return + if not sample: + return [] + + search_cols = (self.docmodel.get_table_rec(opt_table_id).columns + if opt_table_id in self.tables else self.docmodel.columns.all) + + m = match_counter.MatchCounter(sample) + # Iterates through each valid column in the document, counting matches. + for c in search_cols: + if (not gencode._is_special_table(c.tableId) and + column.is_visible_column(c.colId) and + not c.type.startswith('Ref')): + table = self.tables[c.tableId] + col = table.get_column(c.colId) + matches = m.count_unique(col.raw_get(r) for r in itertools.islice(table.row_ids, 1000)) + if matches > 0: + matched_cols.append((matches, c.id)) + + # Sorts the matched columns by the matches, then select the best-matching columns + matched_cols.sort(reverse=True) + if n: + matched_cols = matched_cols[:n] + + log.info('Found column from values in %.3fs' % (time.time() - start_time)) + return [c[1] for c in matched_cols] + + def assert_schema_consistent(self): + """ + Asserts that the internally-stored schema is equivalent to the schema as represented by the + special tables of metadata. 
+ """ + meta_tables = self.fetch_table('_grist_Tables') + meta_columns = self.fetch_table('_grist_Tables_column') + gen_schema = schema.build_schema(meta_tables, meta_columns) + gen_schema_dicts = {k: (t.tableId, dict(t.columns.iteritems())) + for k, t in gen_schema.iteritems()} + cur_schema_dicts = {k: (t.tableId, dict(t.columns.iteritems())) + for k, t in self.schema.iteritems()} + if cur_schema_dicts != gen_schema_dicts: + import pprint + import difflib + a = (pprint.pformat(cur_schema_dicts) + "\n").splitlines(True) + b = (pprint.pformat(gen_schema_dicts) + "\n").splitlines(True) + raise AssertionError("Internal schema different from that in metadata:\n" + + "".join(difflib.unified_diff(a, b, fromfile="internal", tofile="metadata"))) + + def dump_state(self): + self.dep_graph.dump_graph() + self.dump_recompute_map() + + def dump_recompute_map(self): + log.debug("Recompute map (%d nodes):" % len(self.recompute_map)) + for node, dirty_rows in self.recompute_map.iteritems(): + log.debug(" Node %s: %s" % (node, dirty_rows)) + + @contextlib.contextmanager + def open_compute_frame(self, node): + """ + Use as: `with open_compute_frame(node) as frame:`. This automatically maintains the stack of + ComputeFrames, pushing and popping reliably. + """ + frame = Engine.ComputeFrame(node) + self._compute_stack.append(frame) + try: + yield frame + finally: + self._compute_stack.pop() + + def get_current_frame(self): + """ + Returns the compute frame currently being computed, or None if there isn't one. + """ + return self._compute_stack[-1] if self._compute_stack else None + + def _use_node(self, node, relation, row_ids=[]): + # This is used whenever a formula accesses any part of any record. It's hot code, and + # it's worth optimizing. + + if self._compute_stack and self._compute_stack[-1].node: + # Add an edge to indicate that the node being computed depends on the node passed in. + # Note that during evaluation, we only *add* dependencies. We *remove* them by clearing them + # whenever ALL rows for a node are invalidated (on schema changes and reloads). + current_node = self._compute_stack[-1].node + edge = (current_node, node, relation) + if edge not in self._recompute_edge_set: + self.dep_graph.add_edge(*edge) + self._recompute_edge_set.add(edge) + + # This check is not essential here, but is an optimization that saves cycles. + if self.recompute_map.get(node) is None: + return + + self._recompute(node, row_ids) + + def _pre_update(self): + """ + Called at beginning of _bring_all_up_to_date or _bring_lookups_up_to_date. + Makes sure cell change accumulation is reset. + """ + self._changes_map = OrderedDict() + self._recompute_done_map = {} + self._locked_cells = set() + self._is_node_exception_reported = set() + self._recompute_edge_set = set() + self._cell_required_error = None + + def _post_update(self): + """ + Called at end of _bring_all_up_to_date or _bring_lookups_up_to_date. + Issues actions for any accumulated cell changes. + """ + for node, changes in self._changes_map.iteritems(): + if not changes: + continue + table = self.tables[node.table_id] + col = table.get_column(node.col_id) + # If there are changes, create and add a BulkUpdateRecord either to 'calc' or 'stored' + # actions, as appropriate. 
+ changed_rows = [c[0] for c in changes] + changed_values = [c[1] for c in changes] + action = (actions.BulkUpdateRecord(col.table_id, changed_rows, {col.col_id: changed_values}) + .simplify()) + if action and not col.is_private(): + if col.is_formula(): + self.out_actions.calc.append(action) + else: + # We may compute values for non-formula columns (e.g. for a newly-added record), in which + # case we need a stored action. TODO: If this code path occurs during anything other than + # an AddRecord, we also need an undo action. + self.out_actions.stored.append(action) + self._pre_update() # empty lists/sets/maps + + def _update_loop(self, work_items, ignore_other_changes=False): + """ + Called to compute the specified cells, including any nested dependencies. + Consumes OrderError exceptions, and reacts to them with a strategy for + reordering cell evaluation. That strategy is currently simple: + * Maintain a stack of work item triplets. Each work item has: + - A node (table/column pair). + - A list of row_ids to compute (this can be None, meaning "all"). + - A list of row_ids to "unlock" once finished. + * Until stack is empty, take a work item off the stack and attempt to + _recompute the specified rows of the specified node. + - If an OrderError is received, first check it is for a cell we + requested (_recompute will opportunistically try to compute + other cells we haven't asked for, and it is important for the + purposes of cycle detection to discount that). + - If so, "lock" that cell, push the current work item back on the + stack (remembering which cell to unlock later), and add a new + work item for the cell that threw the OrderError. + + The "lock" serves only for cycle detection. + + The order of stack placement means that the cell that threw + the OrderError will now be evaluated before the cell that + depends on it. + - If not, ignore the OrderError. If we actually need that cell, + We'll get back to it later as we work up the work_items stack. + * The _recompute method, as mentioned, will attempt to compute not + just the requested rows of a particular column, but any other dirty + cells in that column. This is an important optimization for the + common case of columns with non-self-referring dependencies. + """ + self._in_update_loop = True + while self.recompute_map: + self._recompute_done_counter = 0 + self._expected_done_counter = 0 + while work_items: + node, row_ids, locks = work_items.pop() + try: + self._recompute_step(node, require_rows=row_ids) + except OrderError as e: + # Need to schedule re-ordered evaluation + assert node == e.requiring_node + assert (not row_ids) or (e.requiring_row_id in row_ids) + # Put current work item back on stack, and don't dispose its locks + work_items.append(WorkItem(node, row_ids, locks)) + locks = [] + # Add a new work item for the cell we are following up, and lock + # it to forbid circular dependencies + lock = (node, e.requiring_row_id) + work_items.append(WorkItem(e.node, [e.row_id], [lock])) + self._locked_cells.add(lock) + # Discard any locks once work item is complete + for lock in locks: + if lock not in self._locked_cells: + # If cell is already unlocked, don't double-count it. 
+ continue + self._locked_cells.discard(lock) + # Sanity check: make sure we've computed at least one more cell + self._expected_done_counter += 1 + if self._recompute_done_counter < self._expected_done_counter: + raise Exception('data engine not making progress updating dependencies') + if ignore_other_changes: + # For _bring_lookups_up_to_date, we should only wait for the work items + # explicitly requested. + break + # Sanity check that we computed at least one cell. + if self.recompute_map and self._recompute_done_counter == 0: + raise Exception('data engine not making progress updating formulas') + # Figure out remaining work to do, maintaining classic Grist ordering. + nodes = sorted(self.recompute_map.keys(), reverse=True) + work_items = [WorkItem(node, None, []) for node in nodes] + self._in_update_loop = False + + def _bring_all_up_to_date(self): + # Bring all nodes up to date. We iterate in sorted order of the keys so that the order is + # deterministic (which is helpful for tests in particular). + self._pre_update() + try: + # Figure out remaining work to do, maintaining classic Grist ordering. + nodes = sorted(self.recompute_map.keys(), reverse=True) + work_items = [WorkItem(node, None, []) for node in nodes] + self._update_loop(work_items) + # Check if any potentially unused LookupMaps are still unused, and if so, delete them. + for lookup_map in self._unused_lookups: + if self.dep_graph.remove_node_if_unused(lookup_map.node): + self.delete_column(lookup_map) + finally: + self._unused_lookups.clear() + self._post_update() + + def _bring_lookups_up_to_date(self, triggering_doc_action): + # Just bring the lookup nodes up to date. This is part of a somewhat hacky solution in + # apply_doc_action: lookup nodes don't know exactly what depends on them until they are + # recomputed. So invalidating lookup nodes doesn't complete all invalidation; further + # invalidations may be generated in the course of recomputing the lookup nodes. So we force + # recomputation of lookup nodes to ensure that we see up-to-date results between applying doc + # actions. + # + # This matters for private formulas used internally; it isn't needed for external use, since + # all nodes are brought up to date before responding to a user action anyway. + # + # In addition, we expose the triggering doc_action so that lookupOrAddDerived can avoid adding + # a record to a derived table when the trigger itself is a change to the derived table. This + # currently only happens on undo, and is admittedly an ugly workaround. + self._pre_update() + try: + self._triggering_doc_action = triggering_doc_action + nodes = sorted(self.recompute_map.keys(), reverse=True) + nodes = [node for node in nodes if node.col_id.startswith('#lookup')] + work_items = [WorkItem(node, None, []) for node in nodes] + self._update_loop(work_items, ignore_other_changes=True) + finally: + self._triggering_doc_action = None + self._post_update() + + def is_triggered_by_table_action(self, table_id): + # Workaround for lookupOrAddDerived that prevents AddRecord from being created when the + # trigger is itself an action for the same table. See comments for _bring_lookups_up_to_date. + a = self._triggering_doc_action + return a and getattr(a, 'table_id', None) == table_id + + def bring_col_up_to_date(self, col_obj): + """ + Public interface to recompute a column if it is dirty. It also generates a calc or stored + action and adds it into self.out_actions object. 
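The reordering strategy documented for _update_loop boils down to a small stack discipline, shown here as a standalone toy (NeedsFirst and the two-cell dependency are made up; the real code uses OrderError, WorkItem triplets, and locks for cycle detection): when evaluating a cell requires another cell that is not ready, push the current item back and evaluate the prerequisite first.

class NeedsFirst(Exception):
    def __init__(self, cell):
        super(NeedsFirst, self).__init__(cell)
        self.cell = cell

def evaluate(cell, done):
    if cell == 'B' and 'A' not in done:   # toy dependency: B needs A
        raise NeedsFirst('A')
    done.add(cell)

def run(work):
    done = set()
    while work:
        cell = work.pop()
        try:
            evaluate(cell, done)
        except NeedsFirst as e:
            work.append(cell)    # retry the requiring cell after its prerequisite
            work.append(e.cell)  # prerequisite goes on top of the stack, so it runs first
    return done

# run(['B']) == {'A', 'B'}; 'A' gets computed before 'B' is retried.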
+ """ + self._recompute_done_map.pop(col_obj.node, None) + self._recompute(col_obj.node) + + def get_formula_error(self, table_id, col_id, row_id): + """ + Returns an error message (traceback) for one concrete cell which user clicked. + It is sufficient in case when we want to get traceback for only one formula cell with error, + not recomputing the whole column and dependent columns as well. So it recomputes the formula + for this cell and returns error message with details. + """ + table = self.tables[table_id] + col = table.get_column(col_id) + checkpoint = self._get_undo_checkpoint() + try: + return self._recompute_one_cell(None, table, col, row_id) + finally: + # It is possible for formula evaluation to have side-effects that produce DocActions (e.g. + # lookupOrAddDerived() creates those). In case of get_formula_error(), these aren't fully + # processed (e.g. don't get applied to DocStorage), so it's important to reverse them. + self._undo_to_checkpoint(checkpoint) + + def _recompute(self, node, row_ids=None): + """ + Make sure cells of a node are up to date, recomputing as necessary. Can optionally + be limited to a list of rows that are of interest. + """ + if self._in_update_loop: + # This is a nested evaluation. If there are in fact any cells to evaluate, + # this must result in an OrderError. We let engine._recompute_step + # take care of figuring this out. + self._recompute_step(node, allow_evaluation=False, require_rows=row_ids) + else: + # Sometimes _use_node is called from outside _update_loop. In this case, + # we start an _update_loop to compute whatever is required. Otherwise + # nested dependencies would not get computed. + self._update_loop([WorkItem(node, row_ids, [])], ignore_other_changes=True) + + + def _recompute_step(self, node, allow_evaluation=True, require_rows=None): # pylint: disable=too-many-statements + """ + Recomputes a node (i.e. column), evaluating the appropriate formula for the given rows + to get new values. Only columns whose .has_formula() is true should ever have invalidated rows + in recompute_map (this includes data columns with a default formula, for newly-added records). + + If `allow_evaluation` is false, any time we would recompute a node, we instead throw + an OrderError exception. This is used to "flatten" computation - instead of evaluating + nested dependencies on the program stack, an external loop will evaluate them in an + unnested order. Remember that formulas may access other columns, and column access calls + engine._use_node, which calls _recompute to bring those nodes up to date. + + Recompute records changes in _changes_map, which is used later to generate appropriate + BulkUpdateRecord actions, either calc (for formulas) or stored (for non-formula columns). + """ + + dirty_rows = self.recompute_map.get(node, None) + if dirty_rows is None: + return + + table = self.tables[node.table_id] + col = table.get_column(node.col_id) + assert col.has_formula(), "Engine._recompute: called on no-formula node %s" % (node,) + + # Get a sorted list of row IDs, excluding deleted rows (they will sometimes end up in + # recompute_map) and rows already done (since _recompute_done_map got cleared). + if node not in self._recompute_done_map: + # Before starting to evaluate a formula, call reset_rows() + # on all relations with nodes we depend on. E.g. this is + # used for lookups, so that we can reset stored lookup + # information for rows that are about to get reevaluated. 
+ self.dep_graph.reset_dependencies(node, dirty_rows) + self._recompute_done_map[node] = set() + + exclude = self._recompute_done_map[node] + if dirty_rows == depend.ALL_ROWS: + dirty_rows = SortedSet(r for r in table.row_ids if r not in exclude) + self.recompute_map[node] = dirty_rows + require_rows = sorted(require_rows or []) + + # Prevents dependency creation for non-formula nodes. A non-formula column may include a + # formula to eval for a newly-added record. Those shouldn't create dependencies. + formula_node = node if col.is_formula() else None + + changes = None + cleaned = [] # this lists row_ids that can be removed from dirty_rows once we are no + # longer iterating on it. + with self.open_compute_frame(formula_node) as frame: + try: + require_count = len(require_rows) + for i, row_id in enumerate(itertools.chain(require_rows, dirty_rows)): + required = i < require_count or require_count == 0 + if require_count and row_id not in dirty_rows: + # Nothing need be done for required rows that are already up to date. + continue + if row_id not in table.row_ids or row_id in exclude: + # We can declare victory for absent or excluded rows. + cleaned.append(row_id) + continue + if not allow_evaluation: + # We're not actually in a position to evaluate this cell, we need to just + # report that we needed an _update_loop will arrange for us to be called + # again in a better order. + if required: + msg = 'Cell value not available yet' + err = OrderError(msg, node, row_id) + if not self._cell_required_error: + # Cache the exception in case user consumes it or modifies it in their formula. + self._cell_required_error = OrderError(msg, node, row_id) + raise err + # For common-case formulas, all cells in a column are likely to fail in the same way, + # so don't bother trying more from this column until we've reordered. + return + try: + # We figure out if we've hit a cycle here. If so, we just let _recompute_on_cell + # know, so it can set the cell value appropriately and do some other bookkeeping. + cycle = required and (node, row_id) in self._locked_cells + value = self._recompute_one_cell(frame, table, col, row_id, cycle=cycle, node=node) + except OrderError as e: + if not required: + # We're out of order, but for a cell we were evaluating opportunistically. + # Don't throw an exception, since it could lead us off on a wild goose + # chase - let _update_loop focus on one path at a time. + return + # Keep track of why this cell was needed. + e.requiring_node = node + e.requiring_row_id = row_id + raise e + + # Successfully evaluated a cell! Unlock it if it was locked, so other cells can + # use it without triggering a cyclic dependency error. + self._locked_cells.discard((node, row_id)) + + if isinstance(value, objtypes.RaisedException): + is_first = node not in self._is_node_exception_reported + if is_first: + self._is_node_exception_reported.add(node) + log.info(value.details) + value = objtypes.RaisedException(value.error) # strip out details after logging + + # TODO: validation columns should be wrapped to always return True/False (catching + # exceptions), so that we don't need special handling here. + if column.is_validation_column_name(col.col_id): + value = (value in (True, None)) + + # Convert the value, and if needed, set, and include into the returned action. 
+ value = col.convert(value) + if not _equal_values(value, col.raw_get(row_id)): + if not changes: + changes = self._changes_map.setdefault(node, []) + changes.append([row_id, value]) + col.set(row_id, value) + exclude.add(row_id) + cleaned.append(row_id) + self._recompute_done_counter += 1 + # If no particular rows were requested, and we arrive here, + # that means we made it through the whole column! For long + # columns, it is worth deleting dirty_rows in one step rather + # than discarding one cell at a time. + if require_rows is None: + cleaned = [] + dirty_rows = None + + finally: + for row_id in cleaned: + # this modifies self.recompute_map[node], to which dirty_rows is a reference + dirty_rows.discard(row_id) + if not dirty_rows: + self.recompute_map.pop(node) + + def _recompute_one_cell(self, frame, table, col, row_id, cycle=False, node=None): + """ + Recomputes an one formula cell and returns a value. + The value can be: + - the recomputed value in case there are no errors + - exception + - exception with details if flag include_details is set + """ + if frame: + frame.current_row_id = row_id + + # Baffling, but keeping a reference to current generated "usercode" module protects against a + # seeming garbage-collection bug: if during formula evaluation the module gets regenerated + # (e.g. a side-effect causes a formula column to change to non-formula), the stale-module + # formula code that's still running will see None values in the usermodule's module-dictionary; + # just keeping this extra reference allows stale formulas to see valid values. + usercode_reference = self.gencode.usercode + + checkpoint = self._get_undo_checkpoint() + record = table.Record(table, row_id, table._identity_relation) + try: + if cycle: + raise depend.CircularRefError("Circular Reference") + result = col.method(record, table.user_table) + if self._cell_required_error: + raise self._cell_required_error # pylint: disable=raising-bad-type + self.formula_tracer(col, record) + return result + except: # pylint: disable=bare-except + # Since col.method runs untrusted user code, we use a bare except to catch all + # exceptions (even those not derived from BaseException). + + # Before storing the exception value, make sure there isn't an OrderError pending. + # If there is, we will raise it after undoing any side effects. + order_error = self._cell_required_error + + # Otherwise, we use sys.exc_info to recover the raised exception object. + regular_error = sys.exc_info()[1] if not order_error else None + + # It is possible for formula evaluation to have side-effects that produce DocActions (e.g. + # lookupOrAddDerived() creates those). If there is an error, undo any such side-effects. + self._undo_to_checkpoint(checkpoint) + + # Now we can raise the order error, if there was one. Cell evaluation will be reordered + # in response. + if order_error: + self._cell_required_error = None + raise order_error # pylint: disable=raising-bad-type + + self.formula_tracer(col, record) + + include_details = (node not in self._is_node_exception_reported) if node else True + return objtypes.RaisedException(regular_error, include_details) + + def convert_action_values(self, action): + """ + Given a BulkUpdateRecord or BulkAddRecord action, convert the values using the appropriate + Column objects, replacing them with the right-type value, alttext, or error objects. 
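The error handling in _recompute_one_cell follows a pattern worth seeing on its own: run the untrusted formula, and on failure roll back any side-effect actions and return the error as a value. This standalone sketch uses made-up names (WrappedError, a plain list of out_actions) rather than the real objtypes.RaisedException and checkpoint machinery.

import sys

class WrappedError(object):
    def __init__(self, error):
        self.error = error

def eval_cell(formula, record, out_actions):
    checkpoint = len(out_actions)    # remember how many actions existed before evaluation
    try:
        return formula(record)
    except Exception:
        error = sys.exc_info()[1]
        del out_actions[checkpoint:]     # drop side-effect actions produced by the failed call
        return WrappedError(error)       # store the error as the cell's value instead of raising

# eval_cell(lambda rec: 1 / 0, None, []) returns a WrappedError rather than raising.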
+ """ + table_id, row_ids, column_values = action + table = self.tables[action.table_id] + new_values = {} + extra_actions = [] + for col_id, values in column_values.iteritems(): + col_obj = table.get_column(col_id) + values = [col_obj.convert(val) for val in values] + + # If there are values for any PositionNumber columns, ensure PositionNumbers are ordered as + # intended but are all unique, which may require updating other positions. + nvalues, adjustments = col_obj.prepare_new_values(values) + if adjustments: + extra_actions.append(actions.BulkUpdateRecord( + action.table_id, [r for r,v in adjustments], {col_id: [v for r,v in adjustments]})) + + new_values[col_id] = nvalues + + if isinstance(action, (actions.BulkAddRecord, actions.ReplaceTableData)): + # Make sure we call prepare_new_values() for ALL columns when adding rows. The for-loop + # above does it for columns explicitly mentioned; this section does it for the other + # columns, using their default values as input to prepare_new_values(). + ignore_data = isinstance(action, actions.ReplaceTableData) + for col_id, col_obj in table.all_columns.iteritems(): + if col_id in column_values or column.is_virtual_column(col_id) or col_obj.is_formula(): + continue + defaults = [col_obj.getdefault() for r in row_ids] + # We use defaults to get new values or adjustments. If we are replacing data, we'll make + # the adjustments without regard to the existing data. + nvalues, adjustments = col_obj.prepare_new_values(defaults, ignore_data=ignore_data) + if adjustments: + extra_actions.append(actions.BulkUpdateRecord( + action.table_id, [r for r,v in adjustments], {col_id: [v for r,v in adjustments]})) + if nvalues != defaults: + new_values[col_id] = nvalues + + # Return action of the same type (e.g. BulkUpdateAction, BulkAddAction), but with new values, + # as well as any extra actions that were generated (as could happen for position adjustments). + return (type(action)(table_id, row_ids, new_values), extra_actions) + + def trim_update_action(self, action): + """ + Takes a BulkUpdateAction, and returns a new BulkUpdateAction with only those rows that + actually cause any changes. + """ + table_id, row_ids, column_values = action + table = self.tables[action.table_id] + + # Collect for each column the Column object and a list of new values. + cols = [(table.get_column(col_id), values) for (col_id, values) in column_values.iteritems()] + + # In comparisons below, we rely here on Python's "==" operator to check for equality. After a + # type conversion, it may compare the new type to the old, e.g. 1 == 1.0 == True. It's + # important that such equality is acceptable also to JS and to DocStorage. So far, it seems + # just right. + + # Find columns for which any value actually changed. + cols = [(col_obj, values) for (col_obj, values) in cols + if any(values[i] != col_obj.raw_get(row_id) for (i, row_id) in enumerate(row_ids))] + + # Now find the indices of rows for which any value actually changed from what's in its Column. + row_subset = [i for i, row_id in enumerate(row_ids) + if any(values[i] != col_obj.raw_get(row_id) for (col_obj, values) in cols)] + + # Create and return a new action with just the selected subset of rows. 
+ return actions.BulkUpdateRecord( + action.table_id, + [row_ids[i] for i in row_subset], + {col_obj.col_id: [values[i] for i in row_subset] + for (col_obj, values) in cols} + ) + + def eval_user_code(self, src): + ret = self._repl.runsource(src) + self.gencode.usercode.__dict__.update(self._repl.locals) + return ret + + def invalidate_records(self, table_id, row_ids=depend.ALL_ROWS, col_ids=None, + data_cols_to_recompute=frozenset()): + """ + Invalidate the records with the given row_ids. If col_ids is given, only those columns are + invalidated (otherwise all columns). If data_cols_to_recompute is given, then non-formula + col_ids that have an associated formula will get invalidated too, to cause recomputation. + + Note that it's not just about formula columns; pure data columns need to cause invalidation of + formula columns that depend on them. Those data columns that have an associated formula may + additionally (typically on AddRecord) be themselves invalidated, to cause recomputation. + """ + table = self.tables[table_id] + columns = (table.all_columns.values() + if col_ids is None else [table.get_column(c) for c in col_ids]) + for column in columns: + # If data_cols_to_recompute includes this column, compute its default formula. This + # flag is set on AddRecord and BulkAddRecord, when a default formula needs to be computed. + self.invalidate_column(column, row_ids, column.col_id in data_cols_to_recompute) + + def invalidate_column(self, col_obj, row_ids=depend.ALL_ROWS, recompute_data_col=False): + # Normally, only formula columns use include_self (to recompute themselves). However, if + # recompute_data_col is set, default formulas will also be computed. + include_self = col_obj.is_formula() or (col_obj.has_formula() and recompute_data_col) + self.dep_graph.invalidate_deps(col_obj.node, row_ids, self.recompute_map, + include_self=include_self) + + def rebuild_usercode(self): + """ + Compiles the usercode from the schema, and updates all tables and columns to match. + Also, keeps the locals in the repl in sync with the user code, so that the repl has access to + usercode and vice-versa. + """ + self.gencode.make_module(self.schema) + + # Re-populate self.tables, reusing existing tables whenever possible. + old_tables = self.tables + + self.tables = {} + for table_id, user_table in self.gencode.usercode.__dict__.iteritems(): + if isinstance(user_table, table_module.UserTable): + self.tables[table_id] = (old_tables.get(table_id) or table_module.Table(table_id, self)) + + # Now update the table model for each table, and tie it to its UserTable object. + for table_id, table in self.tables.iteritems(): + user_table = getattr(self.gencode.usercode, table_id) + self._update_table_model(table, user_table) + user_table._set_table_impl(table) + + # For any tables that are gone, use self._update_table_model to clean them up. + for table_id, table in old_tables.iteritems(): + if table_id not in self.tables: + self._update_table_model(table, None) + self._repl.locals.pop(table_id, None) + + # Update docmodel with references to the updated metadata tables. + self.docmodel.update_tables() + + # The order here is important to make sure that when we update the usercode, + # we don't overwrite with outdated usercode entries + self._repl.locals.update(self.gencode.usercode.__dict__) + self.gencode.usercode.__dict__.update(self._repl.locals) + + # TODO: Whenever schema changes, we need to adjust the ACL resources to remove or rename + # tableIds and colIds. 
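The row filtering in trim_update_action above can be pictured with plain lists (standalone sketch, made-up data): keep only the rows where at least one submitted value differs from what the column currently stores.

stored = {1: 'a', 2: 'b', 3: 'c'}       # row_id -> current value of one column
row_ids = [1, 2, 3]
new_values = ['a', 'B', 'c']
changed = [(r, v) for r, v in zip(row_ids, new_values) if stored[r] != v]
# changed == [(2, 'B')]; rows 1 and 3 get dropped from the update action.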
+ + + def _update_table_model(self, table, user_table): + """ + Updates the given Table object to match the given user_table (from usercode module). This + builds new columns as needed, and cleans up. To clean up state for a table getting removed, + pass in user_table of None. + """ + # Save the dict of columns before the update. + old_columns = table.all_columns.copy() + + if user_table is None: + new_columns = {} + else: + # Update the table's model. This also builds new columns if needed. + table._rebuild_model(user_table) + new_columns = table.all_columns + + added_col_ids = new_columns.viewkeys() - old_columns.viewkeys() + deleted_col_ids = old_columns.viewkeys() - new_columns.viewkeys() + + # Invalidate the columns that got added and anything that depends on them. + if added_col_ids: + self.invalidate_records(table.table_id, col_ids=added_col_ids) + + for col_id in deleted_col_ids: + self.invalidate_column(old_columns[col_id]) + + # Schedule deleted columns for clean-up. + for c in deleted_col_ids: + self.delete_column(old_columns[c]) + + if user_table is None: + for c in table.get_helper_columns(): + self.delete_column(c) + + + def delete_column(self, col_obj): + # Remove the column from its table. + if col_obj.table_id in self.tables: + self.tables[col_obj.table_id].delete_column(col_obj) + + # Invalidate anything that depends on the column being deleted. The column may be gone from + # the table itself, so we use invalidate_column directly. + self.invalidate_column(col_obj) + # Remove reference to the column from the recompute_map. + self.recompute_map.pop(col_obj.node, None) + # Mark the column to be destroyed at the end of applying this docaction. + self._gone_columns.append(col_obj) + + + def new_column_name(self, table): + """ + Invalidate anything that referenced unknown columns, in case the newly-added name fixes the + broken reference. + """ + self.dep_graph.invalidate_deps(table._new_columns_node, depend.ALL_ROWS, self.recompute_map, + include_self=False) + + def mark_lookupmap_for_cleanup(self, lookup_map_column): + """ + Once a LookupMapColumn seems no longer used, it's added here. We'll check after recomputing + everything, and if still unused, will clean it up. + """ + self._unused_lookups.add(lookup_map_column) + + def apply_user_actions(self, user_actions): + """ + Applies the list of user_actions. Returns an ActionGroup. + """ + # We currently recompute everything and send all calc actions back on every change. If clients + # only need a subset of data loaded, it would be better to filter calc actions, and + # include only those the clients care about. For side-effects, we might want to recompute + # everything, and only filter what we send. + + self.out_actions = action_obj.ActionGroup() + + checkpoint = self._get_undo_checkpoint() + try: + for user_action in user_actions: + self._schema_updated = False + self.out_actions.retValues.append(self._apply_one_user_action(user_action)) + + # If the UserAction touched the schema, check that it is now consistent with metadata. + if self._schema_updated: + self.assert_schema_consistent() + + except Exception, e: + # Save full exception info, so that we can rethrow accurately even if undo also fails. + exc_info = sys.exc_info() + # If we get an exception, we should revert all changes applied so far, to keep things + # consistent internally as well as with the clients and database outside of the sandbox + # (which won't see any changes in case of an error). 
+ log.info("Failed to apply useractions; reverting: %r" % (e,)) + self._undo_to_checkpoint(checkpoint) + + # Check schema consistency again. If this fails, something is really wrong (we tried to go + # back to a good state but failed). We'll just report it loudly. + try: + if self._schema_updated: + self.assert_schema_consistent() + except Exception: + log.error("Inconsistent schema after revert on failure: %s" % traceback.format_exc()) + + # Re-raise the original exception (simple `raise` wouldn't do if undo also fails above). + raise exc_info[0], exc_info[1], exc_info[2] + + # Note that recalculations and auto-removals get included after processing all useractions. + self._bring_all_up_to_date() + + # Apply any triggered record removals. If anything does get removed, recalculate what's needed. + while self.docmodel.apply_auto_removes(): + self._bring_all_up_to_date() + + return self.out_actions + + def acl_split(self, action_group): + """ + Splits ActionGroups, as returned e.g. from apply_user_actions, by permissions. Returns a + single ActionBundle containing of all of the original action_groups. + """ + return self._acl.acl_read_split(action_group) + + def _apply_one_user_action(self, user_action): + """ + Applies a single user action to the document, without running any triggered updates. + A UserAction is a tuple whose first element is the name of the action. + """ + log.debug("applying user_action %s" % (user_action,)) + return getattr(self.user_actions, user_action.__class__.__name__)(*user_action) + + def apply_doc_action(self, doc_action): + """ + Applies a doc action, which is a step of a user action. It is represented by an Action object + as defined in actions.py. + """ + #log.warn("Engine.apply_doc_action %s" % (doc_action,)) + self._gone_columns = [] + + action_name = doc_action.__class__.__name__ + saved_schema = None + if action_name in actions.schema_actions: + self._schema_updated = True + # Make a copy of the schema. If a bug causes a docaction to fail after modifying schema, we + # restore it, or we'll end up with mismatching schema and metadata. + saved_schema = schema.clone_schema(self.schema) + + try: + getattr(self.doc_actions, action_name)(*doc_action) + except Exception: + # Save full exception info, so that we can rethrow accurately even if this clause also fails. + exc_info = sys.exc_info() + if saved_schema: + log.info("Restoring schema and usercode on exception") + self.schema = saved_schema + try: + self.rebuild_usercode() + except Exception: + log.error("Error rebuilding usercode after restoring schema: %s" % traceback.format_exc()) + # Re-raise the original exception (simple `raise` wouldn't do if rebuild also fails above). + raise exc_info[0], exc_info[1], exc_info[2] + + # If any columns got deleted, destroy them to clear _back_references in other tables, and to + # force errors if anything still uses them. Also clear them from calc actions if needed. + for col in self._gone_columns: + # Calc actions may already be generated if the column deletion was triggered by auto-removal. + actions.prune_actions(self.out_actions.calc, col.table_id, col.col_id) + col.destroy() + + # We normally recompute formulas before returning to the user; but some formulas are also used + # internally in-between applying doc actions. We have this workaround to ensure that those are + # up-to-date after each doc action. See more in comments for _bring_lookups_up_to_date. 
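The revert-then-rethrow pattern used by apply_user_actions (and again in apply_doc_action below) is easy to get wrong, because the cleanup may itself handle exceptions and change what a bare `raise` would re-raise. A standalone sketch of the idea, in the same Python 2 style as the patch (apply_all and revert are made-up callables):

import sys

def apply_with_revert(apply_all, revert):
    try:
        apply_all()
    except Exception:
        exc_info = sys.exc_info()   # save first: revert() may run its own try/except,
        revert()                    # which would change what a bare `raise` re-raises
        raise exc_info[0], exc_info[1], exc_info[2]   # Python 2 three-argument re-raise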
+ # We check _compute_stack to avoid a recursive call (happens when a formula produces an + # action, as for derived/summary tables). + if not self._compute_stack: + self._bring_lookups_up_to_date(doc_action) + + def autocomplete(self, txt, table_id): + """ + Return a list of suggested completions of the python fragment supplied. + """ + # replace $ with rec. and add a dummy rec object + tweaked_txt = DOLLAR_REGEX.sub(r'rec.', txt) + # convert a bare $ with nothing after it also + if txt == '$': + tweaked_txt = 'rec.' + table = self.tables[table_id] + context = {'rec': table.sample_record} + context.update(self.gencode.usercode.__dict__) + + completer = rlcompleter.Completer(context) + results = [] + at = 0 + while True: + # Get a possible completion. Result will be None or "" + result = completer.complete(tweaked_txt, at) + at += 1 + if not result: + break + if skipped_completions.search(result): + continue + results.append(result) + # If we changed the prefix (expanding the $ symbol) we now need to change it back. + if tweaked_txt != txt: + results = [txt + result[len(tweaked_txt):] for result in results] + results.sort() + return results + + def _get_undo_checkpoint(self): + """ + You may call _get_undo_checkpoint() and pass its result into _undo_to_checkpoint() to undo + DocActions saved since the first call; but only while in a single apply_user_actions() call. + """ + # We produce a tuple of lengths: one for each of the properties of out_actions ActionObj. + aobj = self.out_actions + return (len(aobj.calc), len(aobj.stored), len(aobj.undo), len(aobj.retValues)) + + def _undo_to_checkpoint(self, checkpoint): + """ + See _get_undo_checkpoint() above. + """ + # Check if out_actions ActionObj grew at all since _get_undo_checkpoint(). If yes, revert by + # applying any undo actions, and trim it back to original state (if we don't trim it, it will + # only grow further, with undo actions themselves getting applied as new doc actions). + new_checkpoint = self._get_undo_checkpoint() + if new_checkpoint != checkpoint: + (len_calc, len_stored, len_undo, len_ret) = checkpoint + undo_actions = self.out_actions.undo[len_undo:] + log.info("Reverting %d doc actions" % len(undo_actions)) + self.user_actions.ApplyUndoActions(map(actions.get_action_repr, undo_actions)) + del self.out_actions.calc[len_calc:] + del self.out_actions.stored[len_stored:] + del self.out_actions.undo[len_undo:] + del self.out_actions.retValues[len_ret:] + + +# end diff --git a/sandbox/grist/functions/__init__.py b/sandbox/grist/functions/__init__.py new file mode 100644 index 00000000..485f4dda --- /dev/null +++ b/sandbox/grist/functions/__init__.py @@ -0,0 +1,12 @@ +# pylint: disable=wildcard-import +from date import * +from info import * +from logical import * +from lookup import * +from math import * +from stats import * +from text import * +from schedule import * + +# Export all uppercase names, for use with `from functions import *`. 
+__all__ = [k for k in dir() if not k.startswith('_') and k.isupper()] diff --git a/sandbox/grist/functions/date.py b/sandbox/grist/functions/date.py new file mode 100644 index 00000000..7262c09c --- /dev/null +++ b/sandbox/grist/functions/date.py @@ -0,0 +1,773 @@ +import calendar +import datetime +import dateutil.parser +import moment +import docmodel + +# pylint: disable=no-member + +_excel_date_zero = datetime.datetime(1899, 12, 30) + + +def _make_datetime(value): + if isinstance(value, datetime.datetime): + return value + elif isinstance(value, datetime.date): + return datetime.datetime.combine(value, datetime.time()) + elif isinstance(value, datetime.time): + return datetime.datetime.combine(datetime.date.today(), value) + elif isinstance(value, basestring): + return dateutil.parser.parse(value) + else: + raise ValueError('Invalid date %r' % (value,)) + +def _get_global_tz(): + if docmodel.global_docmodel: + return docmodel.global_docmodel.doc_info.lookupOne(id=1).tzinfo + return moment.TZ_UTC + +def _get_tzinfo(zonelabel): + """ + A helper that returns a `datetime.tzinfo` instance for zonelabel. Returns the global + document timezone if zonelabel is None. + """ + return moment.tzinfo(zonelabel) if zonelabel else _get_global_tz() + +def DTIME(value, tz=None): + """ + Returns the value converted to a python `datetime` object. The value may be a + `string`, `date` (interpreted as midnight on that day), `time` (interpreted as a + time-of-day today), or an existing `datetime`. + + The returned `datetime` will have its timezone set to the `tz` argument, or the + document's default timezone when `tz` is omitted or None. If the input is itself a + `datetime` with the timezone set, it is returned unchanged (no changes to its timezone). + + >>> DTIME(datetime.date(2017, 1, 1)) + datetime.datetime(2017, 1, 1, 0, 0, tzinfo=moment.tzinfo('America/New_York')) + >>> DTIME(datetime.date(2017, 1, 1), 'Europe/Paris') + datetime.datetime(2017, 1, 1, 0, 0, tzinfo=moment.tzinfo('Europe/Paris')) + >>> DTIME(datetime.datetime(2017, 1, 1)) + datetime.datetime(2017, 1, 1, 0, 0, tzinfo=moment.tzinfo('America/New_York')) + >>> DTIME(datetime.datetime(2017, 1, 1, tzinfo=moment.tzinfo('UTC'))) + datetime.datetime(2017, 1, 1, 0, 0, tzinfo=moment.tzinfo('UTC')) + >>> DTIME(datetime.datetime(2017, 1, 1, tzinfo=moment.tzinfo('UTC')), 'Europe/Paris') + datetime.datetime(2017, 1, 1, 0, 0, tzinfo=moment.tzinfo('UTC')) + >>> DTIME("1/1/2008") + datetime.datetime(2008, 1, 1, 0, 0, tzinfo=moment.tzinfo('America/New_York')) + """ + value = _make_datetime(value) + return value if value.tzinfo else value.replace(tzinfo=_get_tzinfo(tz)) + + +def XL_TO_DATE(value, tz=None): + """ + Converts a provided Excel serial number representing a date into a `datetime` object. + Value is interpreted as the number of days since December 30, 1899. + + (This corresponds to Google Sheets interpretation. Excel starts with Dec. 31, 1899 but wrongly + considers 1900 to be a leap year. Excel for Mac should be configured to use 1900 date system, + i.e. uncheck "Use the 1904 date system" option.) + + The returned `datetime` will have its timezone set to the `tz` argument, or the + document's default timezone when `tz` is omitted or None. 
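Before the doctests, a quick standalone check of the serial-number convention XL_TO_DATE describes (serial 0 is 1899-12-30 and the fractional part is the time of day), using only the standard library:

import datetime

excel_zero = datetime.datetime(1899, 12, 30)
assert excel_zero + datetime.timedelta(days=39448) == datetime.datetime(2008, 1, 1)
assert excel_zero + datetime.timedelta(days=1.5) == datetime.datetime(1899, 12, 31, 12, 0)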
+ + >>> XL_TO_DATE(41100.1875) + datetime.datetime(2012, 7, 10, 4, 30, tzinfo=moment.tzinfo('America/New_York')) + >>> XL_TO_DATE(39448) + datetime.datetime(2008, 1, 1, 0, 0, tzinfo=moment.tzinfo('America/New_York')) + >>> XL_TO_DATE(40982.0625) + datetime.datetime(2012, 3, 14, 1, 30, tzinfo=moment.tzinfo('America/New_York')) + + More tests: + >>> XL_TO_DATE(0) + datetime.datetime(1899, 12, 30, 0, 0, tzinfo=moment.tzinfo('America/New_York')) + >>> XL_TO_DATE(-1) + datetime.datetime(1899, 12, 29, 0, 0, tzinfo=moment.tzinfo('America/New_York')) + >>> XL_TO_DATE(1) + datetime.datetime(1899, 12, 31, 0, 0, tzinfo=moment.tzinfo('America/New_York')) + >>> XL_TO_DATE(1.5) + datetime.datetime(1899, 12, 31, 12, 0, tzinfo=moment.tzinfo('America/New_York')) + >>> XL_TO_DATE(61.0) + datetime.datetime(1900, 3, 1, 0, 0, tzinfo=moment.tzinfo('America/New_York')) + """ + return DTIME(_excel_date_zero, tz) + datetime.timedelta(days=value) + + +def DATE_TO_XL(date_value): + """ + Converts a Python `date` or `datetime` object to the serial number as used by + Excel, with December 30, 1899 as serial number 1. + + See XL_TO_DATE for more explanation. + + >>> DATE_TO_XL(datetime.date(2008, 1, 1)) + 39448.0 + >>> DATE_TO_XL(datetime.date(2012, 3, 14)) + 40982.0 + >>> DATE_TO_XL(datetime.datetime(2012, 3, 14, 1, 30)) + 40982.0625 + + More tests: + >>> DATE_TO_XL(datetime.date(1900, 1, 1)) + 2.0 + >>> DATE_TO_XL(datetime.datetime(1900, 1, 1)) + 2.0 + >>> DATE_TO_XL(datetime.datetime(1900, 1, 1, 12, 0)) + 2.5 + >>> DATE_TO_XL(datetime.datetime(1900, 1, 1, 12, 0, tzinfo=moment.tzinfo('America/New_York'))) + 2.5 + >>> DATE_TO_XL(datetime.date(1900, 3, 1)) + 61.0 + >>> DATE_TO_XL(datetime.datetime(2008, 1, 1)) + 39448.0 + >>> DATE_TO_XL(XL_TO_DATE(39488)) + 39488.0 + >>> dt_ny = XL_TO_DATE(39488) + >>> dt_paris = moment.tz(dt_ny, 'America/New_York').tz('Europe/Paris').datetime() + >>> DATE_TO_XL(dt_paris) + 39488.0 + """ + # If date_value is `naive` it's ok to pass tz to both DTIME as it won't affect the + # result. + return (DTIME(date_value) - DTIME(_excel_date_zero)).total_seconds() / 86400. + + +def DATE(year, month, day): + """ + Returns the `datetime.datetime` object that represents a particular date. + The DATE function is most useful in formulas where year, month, and day are formulas, not + constants. + + If year is between 0 and 1899 (inclusive), adds 1900 to calculate the year. + >>> DATE(108, 1, 2) + datetime.date(2008, 1, 2) + >>> DATE(2008, 1, 2) + datetime.date(2008, 1, 2) + + If month is greater than 12, rolls into the following year. + >>> DATE(2008, 14, 2) + datetime.date(2009, 2, 2) + + If month is less than 1, subtracts that many months plus 1, from the first month in the year. + >>> DATE(2008, -3, 2) + datetime.date(2007, 9, 2) + + If day is greater than the number of days in the given month, rolls into the following months. + >>> DATE(2008, 1, 35) + datetime.date(2008, 2, 4) + + If day is less than 1, subtracts that many days plus 1, from the first day of the given month. + >>> DATE(2008, 1, -15) + datetime.date(2007, 12, 16) + + More tests: + >>> DATE(1900, 1, 1) + datetime.date(1900, 1, 1) + >>> DATE(1900, 0, 0) + datetime.date(1899, 11, 30) + """ + if year < 1900: + year += 1900 + norm_month = (month - 1) % 12 + 1 + norm_year = year + (month - 1) // 12 + return datetime.date(norm_year, norm_month, 1) + datetime.timedelta(days=day - 1) + + +def DATEDIF(start_date, end_date, unit): + """ + Calculates the number of days, months, or years between two dates. 
+ Unit indicates the type of information that you want returned: + + - "Y": The number of complete years in the period. + - "M": The number of complete months in the period. + - "D": The number of days in the period. + - "MD": The difference between the days in start_date and end_date. The months and years of the + dates are ignored. + - "YM": The difference between the months in start_date and end_date. The days and years of the + dates are ignored. + - "YD": The difference between the days of start_date and end_date. The years of the dates are + ignored. + + Two complete years in the period (2) + >>> DATEDIF(DATE(2001, 1, 1), DATE(2003, 1, 1), "Y") + 2 + + 440 days between June 1, 2001, and August 15, 2002 (440) + >>> DATEDIF(DATE(2001, 6, 1), DATE(2002, 8, 15), "D") + 440 + + 75 days between June 1 and August 15, ignoring the years of the dates (75) + >>> DATEDIF(DATE(2001, 6, 1), DATE(2012, 8, 15), "YD") + 75 + + The difference between 1 and 15, ignoring the months and the years of the dates (14) + >>> DATEDIF(DATE(2001, 6, 1), DATE(2002, 8, 15), "MD") + 14 + + More tests: + >>> DATEDIF(DATE(1969, 7, 16), DATE(1969, 7, 24), "D") + 8 + >>> DATEDIF(DATE(2014, 1, 1), DATE(2015, 1, 1), "M") + 12 + >>> DATEDIF(DATE(2014, 1, 2), DATE(2015, 1, 1), "M") + 11 + >>> DATEDIF(DATE(2014, 1, 1), DATE(2024, 1, 1), "Y") + 10 + >>> DATEDIF(DATE(2014, 1, 2), DATE(2024, 1, 1), "Y") + 9 + >>> DATEDIF(DATE(1906, 10, 16), DATE(2004, 2, 3), "YM") + 3 + >>> DATEDIF(DATE(2016, 2, 14), DATE(2016, 3, 14), "YM") + 1 + >>> DATEDIF(DATE(2016, 2, 14), DATE(2016, 3, 13), "YM") + 0 + >>> DATEDIF(DATE(2008, 10, 16), DATE(2019, 12, 3), "MD") + 17 + >>> DATEDIF(DATE(2008, 11, 16), DATE(2019, 1, 3), "MD") + 18 + >>> DATEDIF(DATE(2016, 2, 29), DATE(2017, 2, 28), "Y") + 0 + >>> DATEDIF(DATE(2016, 2, 29), DATE(2017, 2, 29), "Y") + 1 + """ + if isinstance(start_date, datetime.datetime): + start_date = start_date.date() + if isinstance(end_date, datetime.datetime): + end_date = end_date.date() + if unit == 'D': + return (end_date - start_date).days + elif unit == 'M': + months = (end_date.year - start_date.year) * 12 + (end_date.month - start_date.month) + month_delta = 0 if start_date.day <= end_date.day else 1 + return months - month_delta + elif unit == 'Y': + years = end_date.year - start_date.year + year_delta = 0 if (start_date.month, start_date.day) <= (end_date.month, end_date.day) else 1 + return years - year_delta + elif unit == 'MD': + month_delta = 0 if start_date.day <= end_date.day else 1 + return (end_date - DATE(end_date.year, end_date.month - month_delta, start_date.day)).days + elif unit == 'YM': + month_delta = 0 if start_date.day <= end_date.day else 1 + return (end_date.month - start_date.month - month_delta) % 12 + elif unit == 'YD': + year_delta = 0 if (start_date.month, start_date.day) <= (end_date.month, end_date.day) else 1 + return (end_date - DATE(end_date.year - year_delta, start_date.month, start_date.day)).days + else: + raise ValueError('Invalid unit %s' % (unit,)) + + +def DATEVALUE(date_string, tz=None): + """ + Converts a date that is stored as text to a `datetime` object. 
+ + >>> DATEVALUE("1/1/2008") + datetime.datetime(2008, 1, 1, 0, 0, tzinfo=moment.tzinfo('America/New_York')) + >>> DATEVALUE("30-Jan-2008") + datetime.datetime(2008, 1, 30, 0, 0, tzinfo=moment.tzinfo('America/New_York')) + >>> DATEVALUE("2008-12-11") + datetime.datetime(2008, 12, 11, 0, 0, tzinfo=moment.tzinfo('America/New_York')) + >>> DATEVALUE("5-JUL").replace(year=2000) + datetime.datetime(2000, 7, 5, 0, 0, tzinfo=moment.tzinfo('America/New_York')) + + In case of ambiguity, prefer M/D/Y format. + >>> DATEVALUE("1/2/3") + datetime.datetime(2003, 1, 2, 0, 0, tzinfo=moment.tzinfo('America/New_York')) + + More tests: + >>> DATEVALUE("8/22/2011") + datetime.datetime(2011, 8, 22, 0, 0, tzinfo=moment.tzinfo('America/New_York')) + >>> DATEVALUE("22-MAY-2011") + datetime.datetime(2011, 5, 22, 0, 0, tzinfo=moment.tzinfo('America/New_York')) + >>> DATEVALUE("2011/02/23") + datetime.datetime(2011, 2, 23, 0, 0, tzinfo=moment.tzinfo('America/New_York')) + >>> DATEVALUE("11/3/2011") + datetime.datetime(2011, 11, 3, 0, 0, tzinfo=moment.tzinfo('America/New_York')) + >>> DATE_TO_XL(DATEVALUE("11/3/2011")) + 40850.0 + >>> DATEVALUE("asdf") + Traceback (most recent call last): + ... + ValueError: Unknown string format + """ + return dateutil.parser.parse(date_string).replace(tzinfo=_get_tzinfo(tz)) + + +def DAY(date): + """ + Returns the day of a date, as an integer ranging from 1 to 31. Same as `date.day`. + + >>> DAY(DATE(2011, 4, 15)) + 15 + >>> DAY("5/31/2012") + 31 + >>> DAY(datetime.datetime(1900, 1, 1)) + 1 + """ + return _make_datetime(date).day + + +def DAYS(end_date, start_date): + """ + Returns the number of days between two dates. Same as `(end_date - start_date).days`. + + >>> DAYS("3/15/11","2/1/11") + 42 + >>> DAYS(DATE(2011, 12, 31), DATE(2011, 1, 1)) + 364 + >>> DAYS("2/1/11", "3/15/11") + -42 + """ + return (_make_datetime(end_date) - _make_datetime(start_date)).days + + +def EDATE(start_date, months): + """ + Returns the date that is the given number of months before or after `start_date`. Use + EDATE to calculate maturity dates or due dates that fall on the same day of the month as the + date of issue. + + >>> EDATE(DATE(2011, 1, 15), 1) + datetime.date(2011, 2, 15) + >>> EDATE(DATE(2011, 1, 15), -1) + datetime.date(2010, 12, 15) + >>> EDATE(DATE(2011, 1, 15), 2) + datetime.date(2011, 3, 15) + >>> EDATE(DATE(2012, 3, 1), 10) + datetime.date(2013, 1, 1) + >>> EDATE(DATE(2012, 5, 1), -2) + datetime.date(2012, 3, 1) + """ + return DATE(start_date.year, start_date.month + months, start_date.day) + + +def DATEADD(start_date, days=0, months=0, years=0, weeks=0): + """ + Returns the date a given number of days, months, years, or weeks away from `start_date`. You may + specify arguments in any order if you specify argument names. Use negative values to subtract. + + For example, `DATEADD(date, 1)` is the same as `DATEADD(date, days=1)`, ands adds one day to + `date`. `DATEADD(date, years=1, days=-1)` adds one year minus one day. 
+ + >>> DATEADD(DATE(2011, 1, 15), 1) + datetime.date(2011, 1, 16) + >>> DATEADD(DATE(2011, 1, 15), months=1, days=-1) + datetime.date(2011, 2, 14) + >>> DATEADD(DATE(2011, 1, 15), years=-2, months=1, days=3, weeks=2) + datetime.date(2009, 3, 4) + >>> DATEADD(DATE(1975, 4, 30), years=50, weeks=-5) + datetime.date(2025, 3, 26) + """ + return DATE(start_date.year + years, start_date.month + months, + start_date.day + days + weeks * 7) + + +def EOMONTH(start_date, months): + """ + Returns the date for the last day of the month that is the indicated number of months before or + after start_date. Use EOMONTH to calculate maturity dates or due dates that fall on the last day + of the month. + + >>> EOMONTH(DATE(2011, 1, 1), 1) + datetime.date(2011, 2, 28) + >>> EOMONTH(DATE(2011, 1, 15), -3) + datetime.date(2010, 10, 31) + >>> EOMONTH(DATE(2012, 3, 1), 10) + datetime.date(2013, 1, 31) + >>> EOMONTH(DATE(2012, 5, 1), -2) + datetime.date(2012, 3, 31) + """ + return DATE(start_date.year, start_date.month + months + 1, 1) - datetime.timedelta(days=1) + + +def HOUR(time): + """ + Returns the hour of a `datetime`, as an integer from 0 (12:00 A.M.) to 23 (11:00 P.M.). + Same as `time.hour`. + + >>> HOUR(XL_TO_DATE(0.75)) + 18 + >>> HOUR("7/18/2011 7:45") + 7 + >>> HOUR("4/21/2012") + 0 + """ + return _make_datetime(time).hour + + +def ISOWEEKNUM(date): + """ + Returns the ISO week number of the year for a given date. + + >>> ISOWEEKNUM("3/9/2012") + 10 + >>> [ISOWEEKNUM(DATE(2000 + y, 1, 1)) for y in [0,1,2,3,4,5,6,7,8]] + [52, 1, 1, 1, 1, 53, 52, 1, 1] + """ + return _make_datetime(date).isocalendar()[1] + + +def MINUTE(time): + """ + Returns the minutes of `datetime`, as an integer from 0 to 59. + Same as `time.minute`. + + >>> MINUTE(XL_TO_DATE(0.75)) + 0 + >>> MINUTE("7/18/2011 7:45") + 45 + >>> MINUTE("12:59:00 PM") + 59 + >>> MINUTE(datetime.time(12, 58, 59)) + 58 + """ + return _make_datetime(time).minute + + +def MONTH(date): + """ + Returns the month of a date represented, as an integer from from 1 (January) to 12 (December). + Same as `date.month`. + + >>> MONTH(DATE(2011, 4, 15)) + 4 + >>> MONTH("5/31/2012") + 5 + >>> MONTH(datetime.datetime(1900, 1, 1)) + 1 + """ + return _make_datetime(date).month + + +def NOW(tz=None): + """ + Returns the `datetime` object for the current time. + """ + return datetime.datetime.now(_get_tzinfo(tz)) + + + +def SECOND(time): + """ + Returns the seconds of `datetime`, as an integer from 0 to 59. + Same as `time.second`. + + >>> SECOND(XL_TO_DATE(0.75)) + 0 + >>> SECOND("7/18/2011 7:45:13") + 13 + >>> SECOND(datetime.time(12, 58, 59)) + 59 + """ + + return _make_datetime(time).second + + +def TODAY(): + """ + Returns the `date` object for the current date. + """ + return datetime.date.today() + + +_weekday_type_map = { + # type: (first day of week (according to date.weekday()), number to return for it) + 1: (6, 1), + 2: (0, 1), + 3: (0, 0), + 11: (0, 1), + 12: (1, 1), + 13: (2, 1), + 14: (3, 1), + 15: (4, 1), + 16: (5, 1), + 17: (6, 1), +} + +def WEEKDAY(date, return_type=1): + """ + Returns the day of the week corresponding to a date. The day is given as an integer, ranging + from 1 (Sunday) to 7 (Saturday), by default. + + Return_type determines the type of the returned value. + + - 1 (default) - Returns 1 (Sunday) through 7 (Saturday). + - 2 - Returns 1 (Monday) through 7 (Sunday). + - 3 - Returns 0 (Monday) through 6 (Sunday). + - 11 - Returns 1 (Monday) through 7 (Sunday). + - 12 - Returns 1 (Tuesday) through 7 (Monday). 
+ - 13 - Returns 1 (Wednesday) through 7 (Tuesday). + - 14 - Returns 1 (Thursday) through 7 (Wednesday). + - 15 - Returns 1 (Friday) through 7 (Thursday). + - 16 - Returns 1 (Saturday) through 7 (Friday). + - 17 - Returns 1 (Sunday) through 7 (Saturday). + + >>> WEEKDAY(DATE(2008, 2, 14)) + 5 + >>> WEEKDAY(DATE(2012, 3, 1)) + 5 + >>> WEEKDAY(DATE(2012, 3, 1), 1) + 5 + >>> WEEKDAY(DATE(2012, 3, 1), 2) + 4 + >>> WEEKDAY("3/1/2012", 3) + 3 + + More tests: + >>> WEEKDAY(XL_TO_DATE(10000), 1) + 4 + >>> WEEKDAY(DATE(1901, 1, 1)) + 3 + >>> WEEKDAY(DATE(1901, 1, 1), 2) + 2 + >>> [WEEKDAY(DATE(2008, 2, d)) for d in [10, 11, 12, 13, 14, 15, 16, 17]] + [1, 2, 3, 4, 5, 6, 7, 1] + >>> [WEEKDAY(DATE(2008, 2, d), 1) for d in [10, 11, 12, 13, 14, 15, 16, 17]] + [1, 2, 3, 4, 5, 6, 7, 1] + >>> [WEEKDAY(DATE(2008, 2, d), 17) for d in [10, 11, 12, 13, 14, 15, 16, 17]] + [1, 2, 3, 4, 5, 6, 7, 1] + >>> [WEEKDAY(DATE(2008, 2, d), 2) for d in [10, 11, 12, 13, 14, 15, 16, 17]] + [7, 1, 2, 3, 4, 5, 6, 7] + >>> [WEEKDAY(DATE(2008, 2, d), 3) for d in [10, 11, 12, 13, 14, 15, 16, 17]] + [6, 0, 1, 2, 3, 4, 5, 6] + """ + if return_type not in _weekday_type_map: + raise ValueError("Invalid return type %s" % (return_type,)) + (first, index) = _weekday_type_map[return_type] + return (_make_datetime(date).weekday() - first) % 7 + index + + +def WEEKNUM(date, return_type=1): + """ + Returns the week number of a specific date. For example, the week containing January 1 is the + first week of the year, and is numbered week 1. + + Return_type determines which week is considered the first week of the year. + + - 1 (default) - Week 1 is the first week starting Sunday that contains January 1. + - 2 - Week 1 is the first week starting Monday that contains January 1. + - 11 - Week 1 is the first week starting Monday that contains January 1. + - 12 - Week 1 is the first week starting Tuesday that contains January 1. + - 13 - Week 1 is the first week starting Wednesday that contains January 1. + - 14 - Week 1 is the first week starting Thursday that contains January 1. + - 15 - Week 1 is the first week starting Friday that contains January 1. + - 16 - Week 1 is the first week starting Saturday that contains January 1. + - 17 - Week 1 is the first week starting Sunday that contains January 1. + - 21 - ISO 8601 Approach: Week 1 is the first week starting Monday that contains January 4. + Equivalently, it is the week that contains the first Thursday of the year. + + >>> WEEKNUM(DATE(2012, 3, 9)) + 10 + >>> WEEKNUM(DATE(2012, 3, 9), 2) + 11 + >>> WEEKNUM('1/1/1900') + 1 + >>> WEEKNUM('2/1/1900') + 5 + + More tests: + >>> WEEKNUM('2/1/1909', 2) + 6 + >>> WEEKNUM('1/1/1901', 21) + 1 + >>> [WEEKNUM(DATE(2012, 3, 9), t) for t in [1,2,11,12,13,14,15,16,17,21]] + [10, 11, 11, 11, 11, 11, 11, 10, 10, 10] + """ + if return_type == 21: + return ISOWEEKNUM(date) + if return_type not in _weekday_type_map: + raise ValueError("Invalid return type %s" % (return_type,)) + (first, index) = _weekday_type_map[return_type] + date = _make_datetime(date) + jan1 = datetime.datetime(date.year, 1, 1) + week1_start = jan1 - datetime.timedelta(days=(jan1.weekday() - first) % 7) + return (date - week1_start).days // 7 + 1 + + +def YEAR(date): + """ + Returns the year corresponding to a date as an integer. + Same as `date.year`. 
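Stepping back to WEEKNUM above, its week-1 arithmetic can be checked directly with the standard library (standalone sketch; return_type 1 means weeks start on Sunday, i.e. `first = 6` in the weekday-type table):

import datetime

first = 6                                   # date.weekday() value for Sunday
jan1 = datetime.datetime(1900, 1, 1)        # a Monday
week1_start = jan1 - datetime.timedelta(days=(jan1.weekday() - first) % 7)
assert week1_start == datetime.datetime(1899, 12, 31)                    # the preceding Sunday
assert (jan1 - week1_start).days // 7 + 1 == 1                           # WEEKNUM('1/1/1900') == 1
assert (datetime.datetime(1900, 2, 1) - week1_start).days // 7 + 1 == 5  # WEEKNUM('2/1/1900') == 5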
+ + >>> YEAR(DATE(2011, 4, 15)) + 2011 + >>> YEAR("5/31/2030") + 2030 + >>> YEAR(datetime.datetime(1900, 1, 1)) + 1900 + """ + return _make_datetime(date).year + + +def _date_360(y, m, d): + return y * 360 + m * 30 + d + +def _last_of_feb(date): + return date.month == 2 and (date + datetime.timedelta(days=1)).month == 3 + +def YEARFRAC(start_date, end_date, basis=0): + """ + Calculates the fraction of the year represented by the number of whole days between two dates. + + Basis is the type of day count basis to use. + + * `0` (default) - US (NASD) 30/360 + * `1` - Actual/actual + * `2` - Actual/360 + * `3` - Actual/365 + * `4` - European 30/360 + * `-1` - Actual/actual (Google Sheets variation) + + This function is useful for financial calculations. For compatibility with Excel, it defaults to + using the NASD standard calendar. For use in non-financial settings, option `-1` is + likely the best choice. + + See for explanation of + the US 30/360 and European 30/360 methods. See for analysis of + Excel's particular implementation. + + Basis `-1` is similar to `1`, but differs from Excel when dates span both leap and non-leap years. + It matches the calculation in Google Sheets, counting the days in each year as a fraction of + that year's length. + + Fraction of the year between 1/1/2012 and 7/30/12, omitting the Basis argument. + >>> "%.8f" % YEARFRAC(DATE(2012, 1, 1), DATE(2012, 7, 30)) + '0.58055556' + + Fraction between same dates, using the Actual/Actual basis argument. Because 2012 is a Leap + year, it has a 366 day basis. + >>> "%.8f" % YEARFRAC(DATE(2012, 1, 1), DATE(2012, 7, 30), 1) + '0.57650273' + + Fraction between same dates, using the Actual/365 basis argument. Uses a 365 day basis. + >>> "%.8f" % YEARFRAC(DATE(2012, 1, 1), DATE(2012, 7, 30), 3) + '0.57808219' + + More tests: + >>> round(YEARFRAC(DATE(2012, 1, 1), DATE(2012, 6, 30)), 10) + 0.4972222222 + >>> round(YEARFRAC(DATE(2012, 1, 1), DATE(2012, 6, 30), 0), 10) + 0.4972222222 + >>> round(YEARFRAC(DATE(2012, 1, 1), DATE(2012, 6, 30), 1), 10) + 0.4945355191 + >>> round(YEARFRAC(DATE(2012, 1, 1), DATE(2012, 6, 30), 2), 10) + 0.5027777778 + >>> round(YEARFRAC(DATE(2012, 1, 1), DATE(2012, 6, 30), 3), 10) + 0.495890411 + >>> round(YEARFRAC(DATE(2012, 1, 1), DATE(2012, 6, 30), 4), 10) + 0.4972222222 + >>> [YEARFRAC(DATE(2012, 1, 1), DATE(2012, 1, 1), t) for t in [0, 1, -1, 2, 3, 4]] + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0] + >>> [round(YEARFRAC(DATE(1985, 3, 15), DATE(2016, 2, 29), t), 6) for t in [0, 1, -1, 2, 3, 4]] + [30.955556, 30.959617, 30.961202, 31.411111, 30.980822, 30.955556] + >>> [round(YEARFRAC(DATE(2001, 2, 28), DATE(2016, 3, 31), t), 6) for t in [0, 1, -1, 2, 3, 4]] + [15.086111, 15.085558, 15.086998, 15.305556, 15.09589, 15.088889] + >>> [round(YEARFRAC(DATE(1968, 4, 7), DATE(2011, 2, 14), t), 6) for t in [0, 1, -1, 2, 3, 4]] + [42.852778, 42.855578, 42.855521, 43.480556, 42.884932, 42.852778] + + Here we test "basis 1" on leap and non-leap years. 
+ >>> [round(YEARFRAC(DATE(2015, 1, 1), DATE(2015, 3, 1), t), 6) for t in [1, -1]] + [0.161644, 0.161644] + >>> [round(YEARFRAC(DATE(2016, 1, 1), DATE(2016, 3, 1), t), 6) for t in [1, -1]] + [0.163934, 0.163934] + >>> [round(YEARFRAC(DATE(2015, 1, 1), DATE(2016, 1, 1), t), 6) for t in [1, -1]] + [1.0, 1.0] + >>> [round(YEARFRAC(DATE(2016, 1, 1), DATE(2017, 1, 1), t), 6) for t in [1, -1]] + [1.0, 1.0] + >>> [round(YEARFRAC(DATE(2016, 2, 29), DATE(2017, 1, 1), t), 6) for t in [1, -1]] + [0.838798, 0.838798] + >>> [round(YEARFRAC(DATE(2014, 12, 15), DATE(2015, 3, 15), t), 6) for t in [1, -1]] + [0.246575, 0.246575] + + For these examples, Google Sheets differs from Excel, and we match Excel here. + >>> [round(YEARFRAC(DATE(2015, 12, 15), DATE(2016, 3, 15), t), 6) for t in [1, -1]] + [0.248634, 0.248761] + >>> [round(YEARFRAC(DATE(2015, 1, 1), DATE(2016, 2, 29), t), 6) for t in [1, -1]] + [1.160055, 1.161202] + >>> [round(YEARFRAC(DATE(2015, 1, 1), DATE(2016, 2, 28), t), 6) for t in [1, -1]] + [1.157319, 1.15847] + >>> [round(YEARFRAC(DATE(2015, 3, 1), DATE(2016, 2, 29), t), 6) for t in [1, -1]] + [0.997268, 0.999558] + >>> [round(YEARFRAC(DATE(2015, 3, 1), DATE(2016, 2, 28), t), 6) for t in [1, -1]] + [0.99726, 0.996826] + >>> [round(YEARFRAC(DATE(2016, 3, 1), DATE(2017, 1, 1), t), 6) for t in [1, -1]] + [0.838356, 0.836066] + >>> [round(YEARFRAC(DATE(2015, 1, 1), DATE(2017, 1, 1), t), 6) for t in [1, -1]] + [2.000912, 2.0] + """ + # pylint: disable=too-many-return-statements + # This function is actually completely crazy. The rules are strange too. We'll follow the logic + # in http://www.dwheeler.com/yearfrac/excel-ooxml-yearfrac.pdf + if start_date == end_date: + return 0.0 + if start_date > end_date: + start_date, end_date = end_date, start_date + + d1, m1, y1 = start_date.day, start_date.month, start_date.year + d2, m2, y2 = end_date.day, end_date.month, end_date.year + + if basis == 0: + if d1 == 31: + d1 = 30 + if d1 == 30 and d2 == 31: + d2 = 30 + if _last_of_feb(start_date): + d1 = 30 + if _last_of_feb(end_date): + d2 = 30 + return (_date_360(y2, m2, d2) - _date_360(y1, m1, d1)) / 360.0 + + elif basis == 1: + # This implements Excel's convoluted logic. + if (y1 + 1, m1, d1) >= (y2, m2, d2): + # Less than or equal to one year. + if y1 == y2 and calendar.isleap(y1): + year_length = 366.0 + elif (y1, m1, d1) < (y2, 2, 29) <= (y2, m2, d2) and calendar.isleap(y2): + year_length = 366.0 + elif (y1, m1, d1) <= (y1, 2, 29) < (y2, m2, d2) and calendar.isleap(y1): + year_length = 366.0 + else: + year_length = 365.0 + else: + year_length = (datetime.date(y2 + 1, 1, 1) - datetime.date(y1, 1, 1)).days / (y2 + 1.0 - y1) + return (end_date - start_date).days / year_length + + elif basis == -1: + # This is Google Sheets implementation. Call it an overkill, but I think it's more sensible. + # + # Excel's logic has the unfortunate property that YEARFRAC(a, b) + YEARFRAC(b, c) is not + # always equal to YEARFRAC(a, c). Google Sheets implements a variation that does have this + # property, counting the days in each year as a fraction of that year's length (as if each day + # is counted as 1/365 or 1/366 depending on the year). + # + # The one redeeming quality of Excel's logic is that YEARFRAC for two days that differ by + # exactly one year is 1.0 (not always true for GS). But in GS version, YEARFRAC between any + # two Jan 1 is always a whole number (not always true in Excel). 
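+ # For instance, under this basis the period from 2015-12-31 to 2016-01-02 counts as
+ # 1/365 + 1/366: one day as a fraction of 2015 plus one day as a fraction of leap year 2016.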
+ if y1 == y2: + return _one_year_frac(start_date, end_date) + return ( + + _one_year_frac(start_date, datetime.date(y1 + 1, 1, 1)) + + (y2 - y1 - 1) + + _one_year_frac(datetime.date(y2, 1, 1), end_date) + ) + + elif basis == 2: + return (end_date - start_date).days / 360.0 + + elif basis == 3: + return (end_date - start_date).days / 365.0 + + elif basis == 4: + if d1 == 31: + d1 = 30 + if d2 == 31: + d2 = 30 + return (_date_360(y2, m2, d2) - _date_360(y1, m1, d1)) / 360.0 + + raise ValueError('Invalid basis argument %r' % (basis,)) + +def _one_year_frac(start_date, end_date): + year_length = 366.0 if calendar.isleap(start_date.year) else 365.0 + return (end_date - start_date).days / year_length diff --git a/sandbox/grist/functions/info.py b/sandbox/grist/functions/info.py new file mode 100644 index 00000000..0d01466c --- /dev/null +++ b/sandbox/grist/functions/info.py @@ -0,0 +1,520 @@ +# -*- coding: UTF-8 -*- +# pylint: disable=unused-argument + +from __future__ import absolute_import +import datetime +import math +import numbers +import re + +from functions import date # pylint: disable=import-error +from usertypes import AltText # pylint: disable=import-error +from records import Record # pylint: disable=import-error + +def ISBLANK(value): + """ + Returns whether a value refers to an empty cell. It isn't implemented in Grist. To check for an + empty string, use `value == ""`. + """ + raise NotImplementedError() + + +def ISERR(value): + """ + Checks whether a value is an error. In other words, it returns true + if using `value` directly would raise an exception. + + NOTE: Grist implements this by automatically wrapping the argument to use lazy evaluation. + + A more Pythonic approach to checking for errors is: + ``` + try: + ... value ... + except Exception, err: + ... do something about the error ... + ``` + + For example: + + >>> ISERR("Hello") + False + + More tests: + >>> ISERR(lambda: (1/0.1)) + False + >>> ISERR(lambda: (1/0.0)) + True + >>> ISERR(lambda: "test".bar()) + True + >>> ISERR(lambda: "test".upper()) + False + >>> ISERR(lambda: AltText("A")) + False + >>> ISERR(lambda: float('nan')) + False + >>> ISERR(lambda: None) + False + """ + return lazy_value_or_error(value) is _error_sentinel + + +def ISERROR(value): + """ + Checks whether a value is an error or an invalid value. It is similar to `ISERR`, but also + returns true for an invalid value such as NaN or a text value in a Numeric column. + + NOTE: Grist implements this by automatically wrapping the argument to use lazy evaluation. + + >>> ISERROR("Hello") + False + >>> ISERROR(AltText("fail")) + True + >>> ISERROR(float('nan')) + True + + More tests: + >>> ISERROR(AltText("")) + True + >>> [ISERROR(v) for v in [0, None, "", "Test", 17.0]] + [False, False, False, False, False] + >>> ISERROR(lambda: (1/0.1)) + False + >>> ISERROR(lambda: (1/0.0)) + True + >>> ISERROR(lambda: "test".bar()) + True + >>> ISERROR(lambda: "test".upper()) + False + >>> ISERROR(lambda: AltText("A")) + True + >>> ISERROR(lambda: float('nan')) + True + >>> ISERROR(lambda: None) + False + """ + return is_error(lazy_value_or_error(value)) + + +def ISLOGICAL(value): + """ + Checks whether a value is `True` or `False`. + + >>> ISLOGICAL(True) + True + >>> ISLOGICAL(False) + True + >>> ISLOGICAL(0) + False + >>> ISLOGICAL(None) + False + >>> ISLOGICAL("Test") + False + """ + return isinstance(value, bool) + + +def ISNA(value): + """ + Checks whether a value is the error `#N/A`. 
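+ In Grist, `#N/A` is represented as a floating-point NaN (which is what `NA()` returns), so this
+ check is the same as `isinstance(value, float) and math.isnan(value)`.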
+ + >>> ISNA(float('nan')) + True + >>> ISNA(0.0) + False + >>> ISNA('text') + False + >>> ISNA(float('-inf')) + False + """ + return isinstance(value, float) and math.isnan(value) + + +def ISNONTEXT(value): + """ + Checks whether a value is non-textual. + + >>> ISNONTEXT("asdf") + False + >>> ISNONTEXT("") + False + >>> ISNONTEXT(AltText("text")) + False + >>> ISNONTEXT(17.0) + True + >>> ISNONTEXT(None) + True + >>> ISNONTEXT(datetime.date(2011, 1, 1)) + True + """ + return not ISTEXT(value) + + +def ISNUMBER(value): + """ + Checks whether a value is a number. + + >>> ISNUMBER(17) + True + >>> ISNUMBER(-123.123423) + True + >>> ISNUMBER(False) + True + >>> ISNUMBER(float('nan')) + True + >>> ISNUMBER(float('inf')) + True + >>> ISNUMBER('17') + False + >>> ISNUMBER(None) + False + >>> ISNUMBER(datetime.date(2011, 1, 1)) + False + + More tests: + >>> ISNUMBER(AltText("text")) + False + >>> ISNUMBER('') + False + """ + return isinstance(value, numbers.Number) + + +def ISREF(value): + """ + Checks whether a value is a table record. + + For example, if a column person is of type Reference to the People table, then ISREF($person) + is True. + Similarly, ISREF(People.lookupOne(name=$name)) is True. For any other type of value, + ISREF() would evaluate to False. + + >>> ISREF(17) + False + >>> ISREF("Roger") + False + + """ + return isinstance(value, Record) + + +def ISTEXT(value): + """ + Checks whether a value is text. + + >>> ISTEXT("asdf") + True + >>> ISTEXT("") + True + >>> ISTEXT(AltText("text")) + True + >>> ISTEXT(17.0) + False + >>> ISTEXT(None) + False + >>> ISTEXT(datetime.date(2011, 1, 1)) + False + """ + return isinstance(value, (basestring, AltText)) + + +# Regexp for matching email. See ISEMAIL for justification. +_email_regexp = re.compile( + r""" + ^\w # Start with an alphanumeric character + [\w%+/='-]* (\.[\w%+/='-]+)* # Elsewhere allow also a few other special characters + # But no two consecutive periods + @ + ([A-Za-z0-9] # Each part of hostname must start with alphanumeric + ([A-Za-z0-9-]*[A-Za-z0-9])?\. # May have dashes inside, but end in alphanumeric + )+ + [A-Za-z]{2,6}$ # Restrict top-level domain to length {2,6}. Google seems + # to use a whitelist for TLDs longer than 2 characters. + """, re.UNICODE | re.VERBOSE) + + +# Regexp for matching hostname part of URLs (see also ISURL). Duplicates part of _email_regexp. +_hostname_regexp = re.compile( + r"""^ + ([A-Za-z0-9] # Each part of hostname must start with alphanumeric + ([A-Za-z0-9-]*[A-Za-z0-9])?\. # May have dashes inside, but end in alphanumeric + )+ + [A-Za-z]{2,6}$ # Restrict top-level domain to length {2,6}. Google seems + """, re.VERBOSE) + + +def ISEMAIL(value): + u""" + Returns whether a value is a valid email address. + + Note that checking email validity is not an exact science. The technical standard considers many + email addresses valid that are not used in practice, and would not be considered valid by most + users. Instead, we follow Google Sheets implementation, with some differences, noted below. 
+ + >>> ISEMAIL("Abc.123@example.com") + True + >>> ISEMAIL("Bob_O-Reilly+tag@example.com") + True + >>> ISEMAIL("John Doe") + False + >>> ISEMAIL("john@aol...com") + False + + More tests: + >>> ISEMAIL("Abc@example.com") # True, True + True + >>> ISEMAIL("Abc.123@example.com") # True, True + True + >>> ISEMAIL("foo@bar.com") # True, True + True + >>> ISEMAIL("asdf@com.zt") # True, True + True + >>> ISEMAIL("Bob_O-Reilly+tag@example.com") # True, True + True + >>> ISEMAIL("john@server.department.company.com") # True, True + True + >>> ISEMAIL("asdf@mail.ru") # True, True + True + >>> ISEMAIL("fabio@foo.qwer.COM") # True, True + True + >>> ISEMAIL("user+mailbox/department=shipping@example.com") # False, True + True + >>> ISEMAIL(u"user+mailbox/department=shipping@example.com") # False, True + True + >>> ISEMAIL("customer/department=shipping@example.com") # False, True + True + >>> ISEMAIL("Bob_O'Reilly+tag@example.com") # False, True + True + >>> ISEMAIL(u"фыва@mail.ru") # False, True + True + >>> ISEMAIL("my@baddash.-.com") # True, False + False + >>> ISEMAIL("my@baddash.-a.com") # True, False + False + >>> ISEMAIL("my@baddash.b-.com") # True, False + False + >>> ISEMAIL("john@-.com") # True, False + False + >>> ISEMAIL("fabio@disapproved.solutions") # False, False + False + >>> ISEMAIL("!def!xyz%abc@example.com") # False, False + False + >>> ISEMAIL("!#$%&'*+-/=?^_`.{|}~@example.com") # False, False + False + >>> ISEMAIL(u"伊昭傑@郵件.商務") # False, False + False + >>> ISEMAIL(u"राम@मोहन.ईन्फो") # False, Fale + False + >>> ISEMAIL(u"юзер@екзампл.ком") # False, False + False + >>> ISEMAIL(u"θσερ@εχαμπλε.ψομ") # False, False + False + >>> ISEMAIL(u"葉士豪@臺網中心.tw") # False, False + False + >>> ISEMAIL(u"jeff@臺網中心.tw") # False, False + False + >>> ISEMAIL(u"葉士豪@臺網中心.台灣") # False, False + False + >>> ISEMAIL(u"jeff葉@臺網中心.tw") # False, False + False + >>> ISEMAIL("my.name@domain.com") # False, False + False + >>> ISEMAIL("my.name@domain.com") # False, False + False + >>> ISEMAIL("my@.leadingdot.com") # False, False + False + >>> ISEMAIL("my@..leadingfwdot.com") # False, False + False + >>> ISEMAIL("my@..twodots.com") # False, False + False + >>> ISEMAIL("my@twodots..com") # False, False + False + >>> ISEMAIL(".leadingdot@domain.com") # False, False + False + >>> ISEMAIL("..twodots@domain.com") # False, False + False + >>> ISEMAIL("twodots..here@domain.com") # False, False + False + >>> ISEMAIL("me@⒈wouldbeinvalid.com") # False, False + False + >>> ISEMAIL("Foo Bar ") # False, False + False + >>> ISEMAIL("Abc\\@def@example.com") # False, False + False + >>> ISEMAIL("foo@bar@google.com") # False, False + False + >>> ISEMAIL("john@aol...com") # False, False + False + >>> ISEMAIL("x@ทีเอชนิค.ไทย") # False, False + False + >>> ISEMAIL("asdf@mail") # False, False + False + >>> ISEMAIL("example@良好Mail.中国") # False, False + False + """ + return bool(_email_regexp.match(value)) + + +_url_regexp = re.compile( + r"""^ + ((ftp|http|https|gopher|mailto|news|telnet|aim)://)? + (\w+@)? # Allow 'user@' part, esp. useful for mailto: URLs. + ([A-Za-z0-9] # Each part of hostname must start with alphanumeric + ([A-Za-z0-9-]*[A-Za-z0-9])?\. # May have dashes inside, but end in alphanumeric + )+ + [A-Za-z]{2,6} # Restrict top-level domain to length {2,6}. Google seems + # to use a whitelist for TLDs longer than 2 characters. + ([/?][-\w!#$%&'()*+,./:;=?@~]*)?$ # Notably, this excludes <, >, and ". + """, re.VERBOSE) + + +def ISURL(value): + """ + Checks whether a value is a valid URL. 
It does not need to be fully qualified, or to include + "http://" and "www". It does not follow a standard, but attempts to work similarly to ISURL in + Google Sheets, and to return True for text that is likely a URL. + + Valid protocols include ftp, http, https, gopher, mailto, news, telnet, and aim. + + >>> ISURL("http://www.getgrist.com") + True + >>> ISURL("https://foo.com/test_(wikipedia)#cite-1") + True + >>> ISURL("mailto://user@example.com") + True + >>> ISURL("http:///a") + False + + More tests: + >>> ISURL("http://www.google.com") + True + >>> ISURL("www.google.com/") + True + >>> ISURL("google.com") + True + >>> ISURL("http://a.b-c.de") + True + >>> ISURL("a.b-c.de") + True + >>> ISURL("http://j.mp/---") + True + >>> ISURL("ftp://foo.bar/baz") + True + >>> ISURL("https://foo.com/blah_(wikipedia)#cite-1") + True + >>> ISURL("mailto://user@google.com") + True + >>> ISURL("http://user@www.google.com") + True + >>> ISURL("http://foo.com/!#$%25&'()*+,-./=?@_~") + True + >>> ISURL("http://../") + False + >>> ISURL("http://??/") + False + >>> ISURL("a.-b.cd") + False + >>> ISURL("http://foo.bar?q=Spaces should be encoded ") + False + >>> ISURL("//") + False + >>> ISURL("///a") + False + >>> ISURL("http:///a") + False + >>> ISURL("bar://www.google.com") + False + >>> ISURL("http:// shouldfail.com") + False + >>> ISURL("ftps://foo.bar/") + False + >>> ISURL("http://-error-.invalid/") + False + >>> ISURL("http://0.0.0.0") + False + >>> ISURL("http://.www.foo.bar/") + False + >>> ISURL("http://.www.foo.bar./") + False + >>> ISURL("example.com/file[/].html") + False + >>> ISURL("http://example.com/file[/].html") + False + >>> ISURL("http://mw1.google.com/kml-samples/gp/seattle/gigapxl/$[level]/r$[y]_c$[x].jpg") + False + >>> ISURL("http://foo.com/>") + False + """ + value = value.strip() + if ' ' in value: # Disallow spaces inside value. + return False + return bool(_url_regexp.match(value)) + + +def N(value): + """ + Returns the value converted to a number. True/False are converted to 1/0. A date is converted to + Excel-style serial number of the date. Anything else is converted to 0. + + >>> N(7) + 7 + >>> N(7.1) + 7.1 + >>> N("Even") + 0 + >>> N("7") + 0 + >>> N(True) + 1 + >>> N(datetime.datetime(2011, 4, 17)) + 40650.0 + """ + if ISNUMBER(value): + return value + if isinstance(value, datetime.date): + return date.DATE_TO_XL(value) + return 0 + + +def NA(): + """ + Returns the "value not available" error, `#N/A`. + + >>> math.isnan(NA()) + True + """ + return float('nan') + + +def TYPE(value): + """ + Returns a number associated with the type of data passed into the function. This is not + implemented in Grist. Use `isinstance(value, type)` or `type(value)`. + """ + raise NotImplementedError() + +def CELL(info_type, reference): + """ + Returns the requested information about the specified cell. This is not implemented in Grist + """ + raise NotImplementedError() + + +# Unique sentinel value to represent that a lazy value evaluates with an exception. +_error_sentinel = object() + +def lazy_value_or_error(value): + """ + Evaluates a value like lazy_value(), but returns _error_sentinel on exception. + """ + try: + return value() if callable(value) else value + except Exception: + return _error_sentinel + +def is_error(value): + """ + Checks whether a value is an invalid value or _error_sentinel. 
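+ Here an "invalid value" means an `AltText` value (e.g. text in a Numeric column) or a float NaN.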
+ """ + return ((value is _error_sentinel) + or isinstance(value, AltText) + or (isinstance(value, float) and math.isnan(value))) diff --git a/sandbox/grist/functions/logical.py b/sandbox/grist/functions/logical.py new file mode 100644 index 00000000..966278b2 --- /dev/null +++ b/sandbox/grist/functions/logical.py @@ -0,0 +1,165 @@ +from info import lazy_value_or_error, is_error +from usertypes import AltText # pylint: disable=unused-import,import-error + + +def AND(logical_expression, *logical_expressions): + """ + Returns True if all of the arguments are logically true, and False if any are false. + Same as `all([value1, value2, ...])`. + + >>> AND(1) + True + >>> AND(0) + False + >>> AND(1, 1) + True + >>> AND(1,2,3,4) + True + >>> AND(1,2,3,4,0) + False + """ + return all((logical_expression,) + logical_expressions) + + +def FALSE(): + """ + Returns the logical value `False`. You may also use the value `False` directly. This + function is provided primarily for compatibility with other spreadsheet programs. + + >>> FALSE() + False + """ + return False + + +def IF(logical_expression, value_if_true, value_if_false): + """ + Returns one value if a logical expression is `True` and another if it is `False`. + + The equivalent Python expression is: + ``` + value_if_true if logical_expression else value_if_false + ``` + + Since Grist supports multi-line formulas, you may also use Python blocks such as: + ``` + if logical_expression: + return value_if_true + else: + return value_if_false + ``` + + NOTE: Grist follows Excel model by only evaluating one of the value expressions, by + automatically wrapping the expressions to use lazy evaluation. This allows `IF(False, 1/0, 1)` + to evaluate to `1` rather than raise an exception. + + >>> IF(12, "Yes", "No") + 'Yes' + >>> IF(None, "Yes", "No") + 'No' + >>> IF(True, 0.85, 0.0) + 0.85 + >>> IF(False, 0.85, 0.0) + 0.0 + + More tests: + >>> IF(True, lambda: (1/0), lambda: (17)) + Traceback (most recent call last): + ... + ZeroDivisionError: integer division or modulo by zero + >>> IF(False, lambda: (1/0), lambda: (17)) + 17 + """ + return lazy_value(value_if_true) if logical_expression else lazy_value(value_if_false) + + +def IFERROR(value, value_if_error=""): + """ + Returns the first argument if it is not an error value, otherwise returns the second argument if + present, or a blank if the second argument is absent. + + NOTE: Grist handles values that raise an exception by wrapping them to use lazy evaluation. + + >>> IFERROR(float('nan'), "**NAN**") + '**NAN**' + >>> IFERROR(17.17, "**NAN**") + 17.17 + >>> IFERROR("Text") + 'Text' + >>> IFERROR(AltText("hello")) + '' + + More tests: + >>> IFERROR(lambda: (1/0.1), "X") + 10.0 + >>> IFERROR(lambda: (1/0.0), "X") + 'X' + >>> IFERROR(lambda: AltText("A"), "err") + 'err' + >>> IFERROR(lambda: None, "err") + + >>> IFERROR(lambda: foo.bar, 123) + 123 + >>> IFERROR(lambda: "test".bar(), 123) + 123 + >>> IFERROR(lambda: "test".bar()) + '' + >>> IFERROR(lambda: "test".upper(), 123) + 'TEST' + """ + value = lazy_value_or_error(value) + return value if not is_error(value) else value_if_error + + +def NOT(logical_expression): + """ + Returns the opposite of a logical value: `NOT(True)` returns `False`; `NOT(False)` returns + `True`. Same as `not logical_expression`. + + >>> NOT(123) + False + >>> NOT(0) + True + """ + return not logical_expression + + +def OR(logical_expression, *logical_expressions): + """ + Returns True if any of the arguments is logically true, and false if all of the + arguments are false. 
+ Same as `any([value1, value2, ...])`. + + >>> OR(1) + True + >>> OR(0) + False + >>> OR(1, 1) + True + >>> OR(0, 1) + True + >>> OR(0, 0) + False + >>> OR(0,False,0.0,"",None) + False + >>> OR(0,None,3,0) + True + """ + return any((logical_expression,) + logical_expressions) + + +def TRUE(): + """ + Returns the logical value `True`. You may also use the value `True` directly. This + function is provided primarily for compatibility with other spreadsheet programs. + + >>> TRUE() + True + """ + return True + +def lazy_value(value): + """ + Evaluates a lazy value by calling it when it's a callable, or returns it unchanged otherwise. + """ + return value() if callable(value) else value diff --git a/sandbox/grist/functions/lookup.py b/sandbox/grist/functions/lookup.py new file mode 100644 index 00000000..796a48e0 --- /dev/null +++ b/sandbox/grist/functions/lookup.py @@ -0,0 +1,80 @@ +# pylint: disable=redefined-builtin, line-too-long + +def ADDRESS(row, column, absolute_relative_mode, use_a1_notation, sheet): + """Returns a cell reference as a string.""" + raise NotImplementedError() + +def CHOOSE(index, choice1, choice2): + """Returns an element from a list of choices based on index.""" + raise NotImplementedError() + +def COLUMN(cell_reference=None): + """Returns the column number of a specified cell, with `A=1`.""" + raise NotImplementedError() + +def COLUMNS(range): + """Returns the number of columns in a specified array or range.""" + raise NotImplementedError() + +def GETPIVOTDATA(value_name, any_pivot_table_cell, original_column_1, pivot_item_1=None, *args): + """Extracts an aggregated value from a pivot table that corresponds to the specified row and column headings.""" + raise NotImplementedError() + +def HLOOKUP(search_key, range, index, is_sorted): + """Horizontal lookup. Searches across the first row of a range for a key and returns the value of a specified cell in the column found.""" + raise NotImplementedError() + +def HYPERLINK(url, link_label): + """Creates a hyperlink inside a cell.""" + raise NotImplementedError() + +def INDEX(reference, row, column): + """Returns the content of a cell, specified by row and column offset.""" + raise NotImplementedError() + +def INDIRECT(cell_reference_as_string): + """Returns a cell reference specified by a string.""" + raise NotImplementedError() + +def LOOKUP(search_key, search_range_or_search_result_array, result_range=None): + """Looks through a row or column for a key and returns the value of the cell in a result range located in the same position as the search row or column.""" + raise NotImplementedError() + +def MATCH(search_key, range, search_type): + """Returns the relative position of an item in a range that matches a specified value.""" + raise NotImplementedError() + +def OFFSET(cell_reference, offset_rows, offset_columns, height, width): + """Returns a range reference shifted a specified number of rows and columns from a starting cell reference.""" + raise NotImplementedError() + +def ROW(cell_reference): + """Returns the row number of a specified cell.""" + raise NotImplementedError() + +def ROWS(range): + """Returns the number of rows in a specified array or range.""" + raise NotImplementedError() + +def VLOOKUP(table, **field_value_pairs): + """ + Vertical lookup. Searches the given table for a record matching the given `field=value` + arguments. If multiple records match, returns one of them. If none match, returns the special + empty record. 
+ + The returned object is a record whose fields are available using `.field` syntax. For example, + `VLOOKUP(Employees, EmployeeID=$EmpID).Salary`. + + Note that `VLOOKUP` isn't commonly needed in Grist, since [Reference columns](col-refs) are the + best way to link data between tables, and allow simple efficient usage such as `$Person.Age`. + + `VLOOKUP` is exactly quivalent to `table.lookupOne(**field_value_pairs)`. See + [lookupOne](#lookupone). + + For example: + ``` + VLOOKUP(People, First_Name="Lewis", Last_Name="Carroll") + VLOOKUP(People, First_Name="Lewis", Last_Name="Carroll").Age + ``` + """ + return table.lookupOne(**field_value_pairs) diff --git a/sandbox/grist/functions/math.py b/sandbox/grist/functions/math.py new file mode 100644 index 00000000..1970ec65 --- /dev/null +++ b/sandbox/grist/functions/math.py @@ -0,0 +1,830 @@ +# pylint: disable=unused-argument + +from __future__ import absolute_import +import itertools +import math as _math +import operator +import random + +from functions.info import ISNUMBER, ISLOGICAL +import roman + +# Iterates through elements of iterable arguments, or through individual args when not iterable. +def _chain(*values_or_iterables): + for v in values_or_iterables: + try: + for x in v: + yield x + except TypeError: + yield v + + +# Iterates through iterable or other arguments, skipping non-numeric ones. +def _chain_numeric(*values_or_iterables): + for v in _chain(*values_or_iterables): + if ISNUMBER(v) and not ISLOGICAL(v): + yield v + + +# Iterates through iterable or other arguments, replacing non-numeric ones with 0 (or True with 1). +def _chain_numeric_a(*values_or_iterables): + for v in _chain(*values_or_iterables): + yield int(v) if ISLOGICAL(v) else v if ISNUMBER(v) else 0 + + +def _round_toward_zero(value): + return _math.floor(value) if value >= 0 else _math.ceil(value) + +def _round_away_from_zero(value): + return _math.ceil(value) if value >= 0 else _math.floor(value) + +def ABS(value): + """ + Returns the absolute value of a number. + + >>> ABS(2) + 2 + >>> ABS(-2) + 2 + >>> ABS(-4) + 4 + """ + return abs(value) + +def ACOS(value): + """ + Returns the inverse cosine of a value, in radians. + + >>> round(ACOS(-0.5), 9) + 2.094395102 + >>> round(ACOS(-0.5)*180/PI(), 10) + 120.0 + """ + return _math.acos(value) + +def ACOSH(value): + """ + Returns the inverse hyperbolic cosine of a number. + + >>> ACOSH(1) + 0.0 + >>> round(ACOSH(10), 7) + 2.9932228 + """ + return _math.acosh(value) + +def ARABIC(roman_numeral): + """ + Computes the value of a Roman numeral. + + >>> ARABIC("LVII") + 57 + >>> ARABIC('mcmxii') + 1912 + """ + return roman.fromRoman(roman_numeral.upper()) + +def ASIN(value): + """ + Returns the inverse sine of a value, in radians. + + >>> round(ASIN(-0.5), 9) + -0.523598776 + >>> round(ASIN(-0.5)*180/PI(), 10) + -30.0 + >>> round(DEGREES(ASIN(-0.5)), 10) + -30.0 + """ + return _math.asin(value) + +def ASINH(value): + """ + Returns the inverse hyperbolic sine of a number. + + >>> round(ASINH(-2.5), 9) + -1.647231146 + >>> round(ASINH(10), 9) + 2.99822295 + """ + return _math.asinh(value) + +def ATAN(value): + """ + Returns the inverse tangent of a value, in radians. + + >>> round(ATAN(1), 9) + 0.785398163 + >>> ATAN(1)*180/PI() + 45.0 + >>> DEGREES(ATAN(1)) + 45.0 + """ + return _math.atan(value) + +def ATAN2(x, y): + """ + Returns the angle between the x-axis and a line segment from the origin (0,0) to specified + coordinate pair (`x`,`y`), in radians. 
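+ Note the argument order: `ATAN2(x, y)` takes the x-coordinate first, which is the reverse of
+ Python's `math.atan2(y, x)`.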
+ + >>> round(ATAN2(1, 1), 9) + 0.785398163 + >>> round(ATAN2(-1, -1), 9) + -2.35619449 + >>> ATAN2(-1, -1)*180/PI() + -135.0 + >>> DEGREES(ATAN2(-1, -1)) + -135.0 + >>> round(ATAN2(1,2), 9) + 1.107148718 + """ + return _math.atan2(y, x) + +def ATANH(value): + """ + Returns the inverse hyperbolic tangent of a number. + + >>> round(ATANH(0.76159416), 9) + 1.00000001 + >>> round(ATANH(-0.1), 9) + -0.100335348 + """ + return _math.atanh(value) + +def CEILING(value, factor=1): + """ + Rounds a number up to the nearest multiple of factor, or the nearest integer if the factor is + omitted or 1. + + >>> CEILING(2.5, 1) + 3 + >>> CEILING(-2.5, -2) + -4 + >>> CEILING(-2.5, 2) + -2 + >>> CEILING(1.5, 0.1) + 1.5 + >>> CEILING(0.234, 0.01) + 0.24 + """ + return int(_math.ceil(float(value) / factor)) * factor + +def COMBIN(n, k): + """ + Returns the number of ways to choose some number of objects from a pool of a given size of + objects. + + >>> COMBIN(8,2) + 28 + >>> COMBIN(4,2) + 6 + >>> COMBIN(10,7) + 120 + """ + # From http://stackoverflow.com/a/4941932/328565 + k = min(k, n-k) + if k == 0: + return 1 + numer = reduce(operator.mul, xrange(n, n-k, -1)) + denom = reduce(operator.mul, xrange(1, k+1)) + return numer//denom + +def COS(angle): + """ + Returns the cosine of an angle provided in radians. + + >>> round(COS(1.047), 7) + 0.5001711 + >>> round(COS(60*PI()/180), 10) + 0.5 + >>> round(COS(RADIANS(60)), 10) + 0.5 + """ + return _math.cos(angle) + +def COSH(value): + """ + Returns the hyperbolic cosine of any real number. + + >>> round(COSH(4), 6) + 27.308233 + >>> round(COSH(EXP(1)), 7) + 7.6101251 + """ + return _math.cosh(value) + +def DEGREES(angle): + """ + Converts an angle value in radians to degrees. + + >>> round(DEGREES(ACOS(-0.5)), 10) + 120.0 + >>> DEGREES(PI()) + 180.0 + """ + return _math.degrees(angle) + +def EVEN(value): + """ + Rounds a number up to the nearest even integer, rounding away from zero. + + >>> EVEN(1.5) + 2 + >>> EVEN(3) + 4 + >>> EVEN(2) + 2 + >>> EVEN(-1) + -2 + """ + return int(_round_away_from_zero(float(value) / 2)) * 2 + +def EXP(exponent): + """ + Returns Euler's number, e (~2.718) raised to a power. + + >>> round(EXP(1), 8) + 2.71828183 + >>> round(EXP(2), 7) + 7.3890561 + """ + return _math.exp(exponent) + +def FACT(value): + """ + Returns the factorial of a number. + + >>> FACT(5) + 120 + >>> FACT(1.9) + 1 + >>> FACT(0) + 1 + >>> FACT(1) + 1 + >>> FACT(-1) + Traceback (most recent call last): + ... + ValueError: factorial() not defined for negative values + """ + return _math.factorial(int(value)) + +def FACTDOUBLE(value): + """ + Returns the "double factorial" of a number. + + >>> FACTDOUBLE(6) + 48 + >>> FACTDOUBLE(7) + 105 + >>> FACTDOUBLE(3) + 3 + >>> FACTDOUBLE(4) + 8 + """ + return reduce(operator.mul, xrange(value, 1, -2)) + +def FLOOR(value, factor=1): + """ + Rounds a number down to the nearest integer multiple of specified significance. + + >>> FLOOR(3.7,2) + 2 + >>> FLOOR(-2.5,-2) + -2 + >>> FLOOR(2.5,-2) + Traceback (most recent call last): + ... + ValueError: factor argument invalid + >>> FLOOR(1.58,0.1) + 1.5 + >>> FLOOR(0.234,0.01) + 0.23 + """ + if (factor < 0) != (value < 0): + raise ValueError("factor argument invalid") + return int(_math.floor(float(value) / factor)) * factor + +def _gcd(a, b): + while a != 0: + if a > b: + a, b = b, a + a, b = b % a, a + return b + +def GCD(value1, *more_values): + """ + Returns the greatest common divisor of one or more integers. 
+ + >>> GCD(5, 2) + 1 + >>> GCD(24, 36) + 12 + >>> GCD(7, 1) + 1 + >>> GCD(5, 0) + 5 + >>> GCD(0, 5) + 5 + >>> GCD(5) + 5 + >>> GCD(14, 42, 21) + 7 + """ + values = [v for v in (value1,) + more_values if v] + if not values: + return 0 + if any(v < 0 for v in values): + raise ValueError("gcd requires non-negative values") + return reduce(_gcd, map(int, values)) + +def INT(value): + """ + Rounds a number down to the nearest integer that is less than or equal to it. + + >>> INT(8.9) + 8 + >>> INT(-8.9) + -9 + >>> 19.5-INT(19.5) + 0.5 + """ + return int(_math.floor(value)) + +def _lcm(a, b): + return a * b / _gcd(a, b) + +def LCM(value1, *more_values): + """ + Returns the least common multiple of one or more integers. + + >>> LCM(5, 2) + 10 + >>> LCM(24, 36) + 72 + >>> LCM(0, 5) + 0 + >>> LCM(5) + 5 + >>> LCM(10, 100) + 100 + >>> LCM(12, 18) + 36 + >>> LCM(12, 18, 24) + 72 + """ + values = (value1,) + more_values + if any(v < 0 for v in values): + raise ValueError("gcd requires non-negative values") + if any(v == 0 for v in values): + return 0 + return reduce(_lcm, map(int, values)) + +def LN(value): + """ + Returns the the logarithm of a number, base e (Euler's number). + + >>> round(LN(86), 7) + 4.4543473 + >>> round(LN(2.7182818), 7) + 1.0 + >>> round(LN(EXP(3)), 10) + 3.0 + """ + return _math.log(value) + +def LOG(value, base=10): + """ + Returns the the logarithm of a number given a base. + + >>> LOG(10) + 1.0 + >>> LOG(8, 2) + 3.0 + >>> round(LOG(86, 2.7182818), 7) + 4.4543473 + """ + return _math.log(value, base) + +def LOG10(value): + """ + Returns the the logarithm of a number, base 10. + + >>> round(LOG10(86), 9) + 1.934498451 + >>> LOG10(10) + 1.0 + >>> LOG10(100000) + 5.0 + >>> LOG10(10**5) + 5.0 + """ + return _math.log10(value) + +def MOD(dividend, divisor): + """ + Returns the result of the modulo operator, the remainder after a division operation. + + >>> MOD(3, 2) + 1 + >>> MOD(-3, 2) + 1 + >>> MOD(3, -2) + -1 + >>> MOD(-3, -2) + -1 + """ + return dividend % divisor + +def MROUND(value, factor): + """ + Rounds one number to the nearest integer multiple of another. + + >>> MROUND(10, 3) + 9 + >>> MROUND(-10, -3) + -9 + >>> round(MROUND(1.3, 0.2), 10) + 1.4 + >>> MROUND(5, -2) + Traceback (most recent call last): + ... + ValueError: factor argument invalid + """ + if (factor < 0) != (value < 0): + raise ValueError("factor argument invalid") + return int(_round_toward_zero(float(value) / factor + 0.5)) * factor + +def MULTINOMIAL(value1, *more_values): + """ + Returns the factorial of the sum of values divided by the product of the values' factorials. + + >>> MULTINOMIAL(2, 3, 4) + 1260 + >>> MULTINOMIAL(3) + 1 + >>> MULTINOMIAL(1,2,3) + 60 + >>> MULTINOMIAL(0,2,4,6) + 13860 + """ + s = value1 + res = 1 + for v in more_values: + s += v + res *= COMBIN(s, v) + return res + +def ODD(value): + """ + Rounds a number up to the nearest odd integer. + + >>> ODD(1.5) + 3 + >>> ODD(3) + 3 + >>> ODD(2) + 3 + >>> ODD(-1) + -1 + >>> ODD(-2) + -3 + """ + return int(_round_away_from_zero(float(value + 1) / 2)) * 2 - 1 + +def PI(): + """ + Returns the value of Pi to 14 decimal places. + + >>> round(PI(), 9) + 3.141592654 + >>> round(PI()/2, 9) + 1.570796327 + >>> round(PI()*9, 8) + 28.27433388 + """ + return _math.pi + +def POWER(base, exponent): + """ + Returns a number raised to a power. 
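+ Similar to `base ** exponent`, but the result is always a float, since this uses `math.pow`.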
+ + >>> POWER(5,2) + 25.0 + >>> round(POWER(98.6,3.2), 3) + 2401077.222 + >>> round(POWER(4,5.0/4), 9) + 5.656854249 + """ + return _math.pow(base, exponent) + + +def PRODUCT(factor1, *more_factors): + """ + Returns the result of multiplying a series of numbers together. Each argument may be a number or + an array. + + >>> PRODUCT([5,15,30]) + 2250 + >>> PRODUCT([5,15,30], 2) + 4500 + >>> PRODUCT(5,15,[30],[2]) + 4500 + + More tests: + >>> PRODUCT([2, True, None, "", False, "0", 5]) + 10 + >>> PRODUCT([2, True, None, "", False, 0, 5]) + 0 + """ + return reduce(operator.mul, _chain_numeric(factor1, *more_factors)) + +def QUOTIENT(dividend, divisor): + """ + Returns one number divided by another. + + >>> QUOTIENT(5, 2) + 2 + >>> QUOTIENT(4.5, 3.1) + 1 + >>> QUOTIENT(-10, 3) + -3 + """ + return TRUNC(float(dividend) / divisor) + +def RADIANS(angle): + """ + Converts an angle value in degrees to radians. + + >>> round(RADIANS(270), 6) + 4.712389 + """ + return _math.radians(angle) + +def RAND(): + """ + Returns a random number between 0 inclusive and 1 exclusive. + """ + return random.random() + +def RANDBETWEEN(low, high): + """ + Returns a uniformly random integer between two values, inclusive. + """ + return random.randrange(low, high + 1) + +def ROMAN(number, form_unused=None): + """ + Formats a number in Roman numerals. The second argument is ignored in this implementation. + + >>> ROMAN(499,0) + 'CDXCIX' + >>> ROMAN(499.2,0) + 'CDXCIX' + >>> ROMAN(57) + 'LVII' + >>> ROMAN(1912) + 'MCMXII' + """ + # TODO: Maybe we should support the second argument. + return roman.toRoman(int(number)) + +def ROUND(value, places=0): + """ + Rounds a number to a certain number of decimal places according to standard rules. + + >>> ROUND(2.15, 1) # Excel actually gives the more correct 2.2 + 2.1 + >>> ROUND(2.149, 1) + 2.1 + >>> ROUND(-1.475, 2) + -1.48 + >>> ROUND(21.5, -1) + 20.0 + >>> ROUND(626.3,-3) + 1000.0 + >>> ROUND(1.98,-1) + 0.0 + >>> ROUND(-50.55,-2) + -100.0 + """ + # TODO: Excel manages to round 2.15 to 2.2, but Python sees 2.149999... and rounds to 2.1 + # (see Python notes in documentation of `round()`). + return round(value, places) + +def ROUNDDOWN(value, places=0): + """ + Rounds a number to a certain number of decimal places, always rounding down towards zero. + + >>> ROUNDDOWN(3.2, 0) + 3 + >>> ROUNDDOWN(76.9,0) + 76 + >>> ROUNDDOWN(3.14159, 3) + 3.141 + >>> ROUNDDOWN(-3.14159, 1) + -3.1 + >>> ROUNDDOWN(31415.92654, -2) + 31400 + """ + factor = 10**-places + return int(_round_toward_zero(float(value) / factor)) * factor + +def ROUNDUP(value, places=0): + """ + Rounds a number to a certain number of decimal places, always rounding up away from zero. + + >>> ROUNDUP(3.2,0) + 4 + >>> ROUNDUP(76.9,0) + 77 + >>> ROUNDUP(3.14159, 3) + 3.142 + >>> ROUNDUP(-3.14159, 1) + -3.2 + >>> ROUNDUP(31415.92654, -2) + 31500 + """ + factor = 10**-places + return int(_round_away_from_zero(float(value) / factor)) * factor + +def SERIESSUM(x, n, m, a): + """ + Given parameters x, n, m, and a, returns the power series sum a_1*x^n + a_2*x^(n+m) + + ... + a_i*x^(n+(i-1)m), where i is the number of entries in range `a`. + + >>> SERIESSUM(1,0,1,1) + 1 + >>> SERIESSUM(2,1,0,[1,2,3]) + 12 + >>> SERIESSUM(-3,1,1,[2,4,6]) + -132 + >>> round(SERIESSUM(PI()/4,0,2,[1,-1./FACT(2),1./FACT(4),-1./FACT(6)]), 6) + 0.707103 + """ + return sum(coef*pow(x, n+i*m) for i, coef in enumerate(_chain(a))) + +def SIGN(value): + """ + Given an input number, returns `-1` if it is negative, `1` if positive, and `0` if it is zero. 
+ + >>> SIGN(10) + 1 + >>> SIGN(4.0-4.0) + 0 + >>> SIGN(-0.00001) + -1 + """ + return 0 if value == 0 else int(_math.copysign(1, value)) + +def SIN(angle): + """ + Returns the sine of an angle provided in radians. + + >>> round(SIN(PI()), 10) + 0.0 + >>> SIN(PI()/2) + 1.0 + >>> round(SIN(30*PI()/180), 10) + 0.5 + >>> round(SIN(RADIANS(30)), 10) + 0.5 + """ + return _math.sin(angle) + +def SINH(value): + """ + Returns the hyperbolic sine of any real number. + + >>> round(2.868*SINH(0.0342*1.03), 7) + 0.1010491 + """ + return _math.sinh(value) + +def SQRT(value): + """ + Returns the positive square root of a positive number. + + >>> SQRT(16) + 4.0 + >>> SQRT(-16) + Traceback (most recent call last): + ... + ValueError: math domain error + >>> SQRT(ABS(-16)) + 4.0 + """ + return _math.sqrt(value) + + +def SQRTPI(value): + """ + Returns the positive square root of the product of Pi and the given positive number. + + >>> round(SQRTPI(1), 6) + 1.772454 + >>> round(SQRTPI(2), 6) + 2.506628 + """ + return _math.sqrt(_math.pi * value) + +def SUBTOTAL(function_code, range1, range2): + """ + Returns a subtotal for a vertical range of cells using a specified aggregation function. + """ + raise NotImplementedError() + + +def SUM(value1, *more_values): + """ + Returns the sum of a series of numbers. Each argument may be a number or an array. + Non-numeric values are ignored. + + >>> SUM([5,15,30]) + 50 + >>> SUM([5.,15,30], 2) + 52.0 + >>> SUM(5,15,[30],[2]) + 52 + + More tests: + >>> SUM([10.25, None, "", False, "other", 20.5]) + 30.75 + >>> SUM([True, "3", 4], True) + 6 + """ + return sum(_chain_numeric_a(value1, *more_values)) + + +def SUMIF(records, criterion, sum_range): + """ + Returns a conditional sum across a range. + """ + raise NotImplementedError() + +def SUMIFS(sum_range, criteria_range1, criterion1, *args): + """ + Returns the sum of a range depending on multiple criteria. + """ + raise NotImplementedError() + +def SUMPRODUCT(array1, *more_arrays): + """ + Multiplies corresponding components in the given arrays, and returns the sum of those products. + + >>> SUMPRODUCT([3,8,1,4,6,9], [2,6,5,7,7,3]) + 156 + >>> SUMPRODUCT([], [], []) + 0 + >>> SUMPRODUCT([-0.25], [-2], [-3]) + -1.5 + >>> SUMPRODUCT([-0.25, -0.25], [-2, -2], [-3, -3]) + -3.0 + """ + return sum(reduce(operator.mul, values) for values in itertools.izip(array1, *more_arrays)) + +def SUMSQ(value1, value2): + """ + Returns the sum of the squares of a series of numbers and/or cells. + """ + raise NotImplementedError() + +def TAN(angle): + """ + Returns the tangent of an angle provided in radians. + + >>> round(TAN(0.785), 8) + 0.99920399 + >>> round(TAN(45*PI()/180), 10) + 1.0 + >>> round(TAN(RADIANS(45)), 10) + 1.0 + """ + return _math.tan(angle) + +def TANH(value): + """ + Returns the hyperbolic tangent of any real number. + + >>> round(TANH(-2), 6) + -0.964028 + >>> TANH(0) + 0.0 + >>> round(TANH(0.5), 6) + 0.462117 + """ + return _math.tanh(value) + +def TRUNC(value, places=0): + """ + Truncates a number to a certain number of significant digits by omitting less significant + digits. + + >>> TRUNC(8.9) + 8 + >>> TRUNC(-8.9) + -8 + >>> TRUNC(0.45) + 0 + """ + # TRUNC seems indistinguishable from ROUNDDOWN. 
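+ # For example, TRUNC(31415.92654, -2) gives 31400, the same as ROUNDDOWN(31415.92654, -2).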
+ return ROUNDDOWN(value, places) diff --git a/sandbox/grist/functions/schedule.py b/sandbox/grist/functions/schedule.py new file mode 100644 index 00000000..30aa398f --- /dev/null +++ b/sandbox/grist/functions/schedule.py @@ -0,0 +1,329 @@ +from datetime import datetime, timedelta +import re +from date import DATEADD, NOW, DTIME +from moment_parse import MONTH_NAMES, DAY_NAMES + +# Limit exports to schedule, so that upper-case constants like MONTH_NAMES, DAY_NAMES don't end up +# exposed as if Excel-style functions (or break docs generation). +__all__ = ['SCHEDULE'] + +def SCHEDULE(schedule, start=None, count=10, end=None): + """ + Returns the list of `datetime` objects generated according to the `schedule` string. Starts at + `start`, which defaults to NOW(). Generates at most `count` results (10 by default). If `end` is + given, stops there. + + The schedule has the format "INTERVAL: SLOTS, ...". For example: + + annual: Jan-15, Apr-15, Jul-15 -- Three times a year on given dates at midnight. + annual: 1/15, 4/15, 7/15 -- Same as above. + monthly: /1 2pm, /15 2pm -- The 1st and the 15th of each month, at 2pm. + 3-months: /10, +1m /20 -- Every 3 months on the 10th of month 1, 20th of month 2. + weekly: Mo 9am, Tu 9am, Fr 2pm -- Three times a week at specified times. + 2-weeks: Mo, +1w Tu -- Every 2 weeks on Monday of week 1, Tuesday of week 2. + daily: 07:30, 21:00 -- Twice a day at specified times. + 2-day: 12am, 4pm, +1d 8am -- Three times every two days, evenly spaced. + hourly: :15, :45 -- 15 minutes before and after each hour. + 4-hour: :00, 1:20, 2:40 -- Three times every 4 hours, evenly spaced. + 10-minute: +0s -- Every 10 minutes on the minute. + + INTERVAL must be either of the form `N-unit` where `N` is a number and `unit` is one of `year`, + `month`, `week`, `day`, `hour`; or one of the aliases: `annual`, `monthly`, `weekly`, `daily`, + `hourly`, which mean `1-year`, `1-month`, etc. + + SLOTS support the following units: + + `Jan-15` or `1/15` -- Month and day of the month; available when INTERVAL is year-based. + `/15` -- Day of the month, available when INTERVAL is month-based. + `Mon`, `Mo`, `Friday` -- Day of the week (or abbreviation), when INTERVAL is week-based. + 10am, 1:30pm, 15:45 -- Time of day, available for day-based or longer intervals. + :45, :00 -- Minutes of the hour, available when INTERVAL is hour-based. + +1d, +15d -- How many days to add to start of INTERVAL. + +1w -- How many weeks to add to start of INTERVAL. + +1m -- How many months to add to start of INTERVAL. + + The SLOTS are always relative to the INTERVAL rather than to `start`. Week-based intervals start + on Sunday. E.g. `weekly: +1d, +4d` is the same as `weekly: Mon, Thu`, and generates times on + Mondays and Thursdays regardless of `start`. + + The first generated time is determined by the *unit* of the INTERVAL without regard to the + multiple. E.g. both "2-week: Mon" and "3-week: Mon" start on the first Monday after `start`, and + then generate either every second or every third Monday after that. Similarly, `24-hour: :00` + starts with the first top-of-the-hour after `start` (not with midnight), and then repeats every + 24 hours. To start with the midnight after `start`, use `daily: 0:00`. + + For interval units of a day or longer, if time-of-day is not specified, it defaults to midnight. + + The time zone of `start` determines the time zone of the generated times. 
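+ The results are generated lazily (as a generator); wrap the result in `list()` if an actual list
+ is needed.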
+ + >>> def show(dates): return [d.strftime("%Y-%m-%d %H:%M") for d in dates] + >>> start = datetime(2018, 9, 4, 14, 0); # 2pm on Tue, Sep 4 2018. + + >>> show(SCHEDULE('annual: Jan-15, Apr-15, Jul-15, Oct-15', start=start, count=4)) + ['2018-10-15 00:00', '2019-01-15 00:00', '2019-04-15 00:00', '2019-07-15 00:00'] + + >>> show(SCHEDULE('annual: 1/15, 4/15, 7/15', start=start, count=4)) + ['2019-01-15 00:00', '2019-04-15 00:00', '2019-07-15 00:00', '2020-01-15 00:00'] + + >>> show(SCHEDULE('monthly: /1 2pm, /15 5pm', start=start, count=4)) + ['2018-09-15 17:00', '2018-10-01 14:00', '2018-10-15 17:00', '2018-11-01 14:00'] + + >>> show(SCHEDULE('3-months: /10, +1m /20', start=start, count=4)) + ['2018-09-10 00:00', '2018-10-20 00:00', '2018-12-10 00:00', '2019-01-20 00:00'] + + >>> show(SCHEDULE('weekly: Mo 9am, Tu 9am, Fr 2pm', start=start, count=4)) + ['2018-09-07 14:00', '2018-09-10 09:00', '2018-09-11 09:00', '2018-09-14 14:00'] + + >>> show(SCHEDULE('2-weeks: Mo, +1w Tu', start=start, count=4)) + ['2018-09-11 00:00', '2018-09-17 00:00', '2018-09-25 00:00', '2018-10-01 00:00'] + + >>> show(SCHEDULE('daily: 07:30, 21:00', start=start, count=4)) + ['2018-09-04 21:00', '2018-09-05 07:30', '2018-09-05 21:00', '2018-09-06 07:30'] + + >>> show(SCHEDULE('2-day: 12am, 4pm, +1d 8am', start=start, count=4)) + ['2018-09-04 16:00', '2018-09-05 08:00', '2018-09-06 00:00', '2018-09-06 16:00'] + + >>> show(SCHEDULE('hourly: :15, :45', start=start, count=4)) + ['2018-09-04 14:15', '2018-09-04 14:45', '2018-09-04 15:15', '2018-09-04 15:45'] + + >>> show(SCHEDULE('4-hour: :00, +1H :20, +2H :40', start=start, count=4)) + ['2018-09-04 14:00', '2018-09-04 15:20', '2018-09-04 16:40', '2018-09-04 18:00'] + """ + return Schedule(schedule).series(start or NOW(), end, count=count) + +class Delta(object): + """ + Similar to timedelta, keeps intervals by unit. Specifically, this is needed for months + and years, since those can't be represented exactly with a timedelta. + """ + def __init__(self): + self._timedelta = timedelta(0) + self._months = 0 + + def add_interval(self, number, unit): + if unit == 'months': + self._months += number + elif unit == 'years': + self._months += number * 12 + else: + self._timedelta += timedelta(**{unit: number}) + return self + + def add_to(self, dtime): + return datetime.combine(DATEADD(dtime, months=self._months), dtime.timetz()) + self._timedelta + + +class Schedule(object): + """ + Schedule parses a schedule spec into an interval and slots in the constructor. Then the series() + method applies it to any start/end dates. + """ + def __init__(self, spec_string): + parts = spec_string.split(":", 1) + if len(parts) != 2: + raise ValueError("schedule must have the form INTERVAL: SLOTS, ...") + + count, unit = _parse_interval(parts[0].strip()) + self._interval_unit = unit + self._interval = Delta().add_interval(count, unit) + self._slots = [_parse_slot(t, self._interval_unit) for t in parts[1].split(",")] + + def series(self, start_dtime, end_dtime, count=10): + # Start with a preceding unit boundary, then check the slots within that unit and start with + # the first one that's at start_dtime or later. 
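+ # (For example, with a week-based interval and a start of Tue 2018-09-04, scanning begins at
+ # the preceding Sunday, 2018-09-02.)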
+ start_dtime = DTIME(start_dtime) + end_dtime = end_dtime and DTIME(end_dtime) + dtime = _round_down_to_unit(start_dtime, self._interval_unit) + while True: + for slot in self._slots: + if count <= 0: + return + out = slot.add_to(dtime) + if out < start_dtime: + continue + if end_dtime is not None and out > end_dtime: + return + yield out + count -= 1 + dtime = self._interval.add_to(dtime) + +def _fail(message): + raise ValueError(message) + +def _round_down_to_unit(dtime, unit): + """ + Rounds datetime down to the given unit. Weeks are rounded to start of Sunday. + """ + tz = dtime.tzinfo + return ( datetime(dtime.year, 1, 1, tzinfo=tz) if unit == 'years' + else datetime(dtime.year, dtime.month, 1, tzinfo=tz) if unit == 'months' + else (dtime - timedelta(days=dtime.isoweekday() % 7)) + .replace(hour=0, minute=0, second=0, microsecond=0) if unit == 'weeks' + else dtime.replace(hour=0, minute=0, second=0, microsecond=0) if unit == 'days' + else dtime.replace(minute=0, second=0, microsecond=0) if unit == 'hours' + else dtime.replace(second=0, microsecond=0) if unit == 'minutes' + else dtime.replace(microsecond=0) if unit == 'seconds' + else _fail("Invalid unit %s" % unit) + ) + +_UNITS = ('years', 'months', 'weeks', 'days', 'hours', 'minutes', 'seconds') +_VALID_UNITS = set(_UNITS) +_SINGULAR_UNITS = dict(zip(('year', 'month', 'week', 'day', 'hour', 'minute', 'second'), _UNITS)) +_SHORT_UNITS = dict(zip(('y', 'm', 'w', 'd', 'H', 'M', 'S'), _UNITS)) + +_INTERVAL_ALIASES = { + 'annual': (1, 'years'), + 'monthly': (1, 'months'), + 'weekly': (1, 'weeks'), + 'daily': (1, 'days'), + 'hourly': (1, 'hours'), +} + +_INTERVAL_RE = re.compile(r'^(?P\d+)[-\s]+(?P[a-z]+)$', re.I) + +# Maps weekday names, including 2- and 3-letter abbreviations, to numbers 0 through 6. +WEEKDAY_OFFSETS = {} +for (i, name) in enumerate(DAY_NAMES): + WEEKDAY_OFFSETS[name] = i + WEEKDAY_OFFSETS[name[:3]] = i + WEEKDAY_OFFSETS[name[:2]] = i + +# Maps month names, including 3-letter abbreviations, to numbers 0 through 11. +MONTH_OFFSETS = {} +for (i, name) in enumerate(MONTH_NAMES): + MONTH_OFFSETS[name] = i + MONTH_OFFSETS[name[:3]] = i + + +def _parse_interval(interval_str): + """ + Given a spec like "daily" or "3-week", returns (N, unit), such as (1, "days") or (3, "weeks"). + """ + interval_str = interval_str.lower() + if interval_str in _INTERVAL_ALIASES: + return _INTERVAL_ALIASES[interval_str] + + m = _INTERVAL_RE.match(interval_str) + if not m: + raise ValueError("Not a valid interval '%s'" % interval_str) + num = int(m.group("num")) + unit = m.group("unit") + unit = _SINGULAR_UNITS.get(unit, unit) + if unit not in _VALID_UNITS: + raise ValueError("Unknown unit '%s' in interval '%s'" % (unit, interval_str)) + return (num, unit) + + +def _parse_slot(slot_str, parent_unit): + """ + Parses a slot in one of several recognized formats. Allowed formats depend on parent_unit, e.g. + 'Jan-15' is valid when parent_unit is 'years', but not when it is 'hours'. We also disallow + using the same unit more than once, which is confusing, e.g. "+1d +2d" or "9:30am +2H". + Returns a Delta object. + """ + parts = slot_str.split() + if not parts: + raise ValueError("At least one slot must be specified") + + delta = Delta() + seen_units = set() + allowed_slot_types = _ALLOWED_SLOTS_BY_UNIT.get(parent_unit) or ('delta',) + + # Slot parts go through parts like "Jan-15 16pm", collecting the offsets into a single Delta. 
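+ # (For a year-based interval, for example, "Jan-15 2pm" combines a month/day slot with a
+ # time-of-day slot.)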
+ for part in parts: + m = _SLOT_RE.match(part) + if not m: + raise ValueError("Invalid slot '%s'" % part) + for slot_type in allowed_slot_types: + if m.group(slot_type): + # If there is a group for one slot type, that's the only group. We find and use the + # corresponding parser, then move on to the next slot part. + for count, unit in _SLOT_PARSERS[slot_type](m): + delta.add_interval(count, unit) + if unit in seen_units: + raise ValueError("Duplicate unit %s in '%s'" % (unit, slot_str)) + seen_units.add(unit) + break + else: + # If none of the allowed slot types was found, it must be a disallowed one. + raise ValueError("Invalid slot '%s' for unit '%s'" % (part, parent_unit)) + return delta + +# We parse all slot types using one big regex. The constants below define one part of the regex +# for each slot type (e.g. to match "Jan-15" or "5:30am" or "+1d"). Note that all group names +# (defined with (?P...)) must be distinct. +_DATE_RE = r'(?:(?P[a-z]+)-|(?P\d+)/)(?P\d+)' +_MDAY_RE = r'/(?P\d+)' +_WDAY_RE = r'(?P[a-z]+)' +_TIME_RE = r'(?P\d+)(?:\:(?P\d{2})(?Pam|pm)?|(?Pam|pm))' +_MINS_RE = r':(?P\d{2})' +_DELTA_RE = r'\+(?P\d+)(?P[a-z]+)' + +# The regex parts are combined and compiled here. Only one group will match, corresponding to one +# slot type. Different slot types depend on the unit of the overall interval. +_SLOT_RE = re.compile( + r'^(?:(?P%s)|(?P%s)|(?P%s)|(?P