(core) move data engine code to core

Summary:
this moves sandbox/grist to core, and adds a requirements.txt
file for reconstructing the content of sandbox/thirdparty.

Test Plan:
existing tests pass.
Tested core functionality manually.  Tested docker build manually.

Reviewers: dsagal

Reviewed By: dsagal

Differential Revision: https://phab.getgrist.com/D2563
pull/4/head
Paul Fitzpatrick 4 years ago
parent 2399baaca2
commit b82eec714a

@ -16,6 +16,7 @@ type SandboxMethod = (...args: any[]) => any;
export interface ISandboxCommand {
process: string;
libraryPath: string;
}
export interface ISandboxOptions {
@ -59,7 +60,7 @@ export class NSandbox implements ISandbox {
if (command) {
return spawn(command.process, pythonArgs,
{env: {PYTHONPATH: 'grist:thirdparty'},
{env: {PYTHONPATH: command.libraryPath},
cwd: path.join(process.cwd(), 'sandbox'), ...spawnOptions});
}
@ -319,7 +320,13 @@ export class NSandboxCreator implements ISandboxCreator {
}
public create(options: ISandboxCreationOptions): ISandbox {
// Main script to run.
const defaultEntryPoint = this._flavor === 'pynbox' ? 'grist/main.pyc' : 'grist/main.py';
// Python library path is only configurable when flavor is unsandboxed.
// In this case, expect to find library files in a virtualenv built by core
// buildtools/prepare_python.sh
const pythonVersion = 'python2.7';
const libraryPath = `grist:../venv/lib/${pythonVersion}/site-packages`;
const args = [options.entryPoint || defaultEntryPoint];
if (!options.entryPoint && options.comment) {
// When using default entry point, we can add on a comment as an argument - it isn't
@ -332,6 +339,7 @@ export class NSandboxCreator implements ISandboxCreator {
selLdrArgs.push(
// TODO: Only modules that we share with plugins should be mounted. They could be gathered in
// a "$APPROOT/sandbox/plugin" folder, only which get mounted.
// TODO: These settings only make sense for pynbox flavor.
'-E', 'PYTHONPATH=grist:thirdparty',
'-m', `${options.sandboxMount}:/sandbox:ro`);
}
@ -346,7 +354,8 @@ export class NSandboxCreator implements ISandboxCreator {
selLdrArgs,
...(this._flavor === 'pynbox' ? {} : {
command: {
process: "python2.7"
process: pythonVersion,
libraryPath
}
})
});

@ -0,0 +1,8 @@
#!/bin/bash
if [ ! -e venv ]; then
virtualenv -ppython2.7 venv
fi
. venv/bin/activate
pip install --no-deps -r sandbox/requirements.txt

@ -5,6 +5,7 @@
"main": "index.js",
"scripts": {
"start": "tsc --build -w --preserveWatchOutput & webpack --config buildtools/webpack.config.js --mode development --watch --hide-modules & NODE_PATH=_build:_build/stubs nodemon -w _build/app/server -w _build/app/common _build/stubs/app/server/server.js & wait",
"install:python": "buildtools/prepare_python.sh",
"build:prod": "tsc --build && webpack --config buildtools/webpack.config.js --mode production",
"start:prod": "node _build/stubs/app/server/server"
},

@ -0,0 +1,52 @@
#!/usr/bin/env python -B
"""
Generates a JS schema file from sandbox/grist/schema.py.
"""
import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), 'grist'))
import schema # pylint: disable=import-error,wrong-import-position
# These are the types that appear in Grist metadata columns.
_ts_types = {
"Bool": "boolean",
"DateTime": "number",
"Int": "number",
"PositionNumber": "number",
"Ref": "number",
"Text": "string",
}
def get_ts_type(col_type):
col_type = col_type.split(':', 1)[0] # Strip suffix for Ref:, DateTime:, etc.
return _ts_types.get(col_type, "CellValue")
def main():
print """
/*** THIS FILE IS AUTO-GENERATED BY %s ***/
// tslint:disable:object-literal-key-quotes
export const schema = {
""" % __file__
for table in schema.schema_create_actions():
print ' "%s": {' % table.table_id
for column in table.columns:
print ' %-20s: "%s",' % (column['id'], column['type'])
print ' },\n'
print """};
export interface SchemaTypes {
"""
for table in schema.schema_create_actions():
print ' "%s": {' % table.table_id
for column in table.columns:
print ' %s: %s;' % (column['id'], get_ts_type(column['type']))
print ' };\n'
print "}"
if __name__ == '__main__':
main()

@ -0,0 +1,404 @@
# Access Control Lists.
#
# This modules is used by engine.py to split actions according to recipient, as well as to
# validate whether an action received from a peer is allowed by the rules.
# Where are ACLs applied?
# -----------------------
# Read ACLs (which control who can see data) are implemented by "acl_read_split" operation, which
# takes an action group and returns an action bundle, which is a list of ActionEnvelopes, each
# containing a smaller action group associated with a set of recipients who should get it. Note
# that the order of ActionEnvelopes matters, and actions should be applied in that order.
#
# In principle, this operation can be done either in the Python data engine or in Node. We do it
# in Python. The clearest reason is the need to apply ACL formulas. Not that it's impossible to do
# on the Node side, but currently formula values are only maintained on the Python side.
# UserActions and ACLs
# --------------------
# Each actions starts with a UserAction, which is turned by the data engine into a number of
# DocActions. We then split DocActions by recipient according to ACL rules. But should recipients
# receive UserActions too?
#
# If UserAction is shared, we need to split it similarly to docactions, because it will often
# contain data that some recipients should not see (e.g. a BulkUpdateRecord user-action generated
# by a copy-paste). An additional difficulty is that splitting by recipient may sometimes require
# creating multiple actions. Further, trimmed UserActions aren't enough for purposes (2) or (3).
#
# Our solution will be not to send around UserActions at all, since DocActions are sufficient to
# update a document. But UserActions are needed for some things:
# (1) Present a meaningful description to users for the action log. This should be possible, and
# may in fact be better, to do from docactions only.
# (2) Redo actions. We currently use UserActions for this, but we can treat this as "undo of the
# undo", relying on docactions, which is in fact more general. Any difficulties with that
# are the same as for Undo, and are not specific to Redo anyway.
# (3) Rebase pending actions after getting peers' actions from the hub. This only needs to be
# done by the authoring instance, which will keep its original UserAction. We don't need to
# share the UserAction for this purpose.
# Initial state
# -------------
# With sharing enabled, the ACL rules have particular defaults (in particular, with the current
# user included in the Owners group, and that group having full access in the default rule).
# Before sharing is enabled, this cannot be completely set up, because the current user is
# unknown, nor are the user's instances, and the initial instance may not even have an instanceId.
#
# Our approach is that default rules and groups are created immediately, before sharing is
# enabled. The Owners group stays empty, and action bundles end up destined for an empty list of
# recipients. Node handles empty list of recipients as its own instanceId when sharing is off.
#
# When sharing is enabled, actions are sent to add a user with the user's instances, including the
# current instance's real instanceId, to the Owners group, and Node stops handling empty list of
# recipients as special, relying on the presence of the actual instanceId instead.
# Identifying tables and columns
# ------------------------------
# If we use tableId and colId in rules, then rules need to be adjusted when tables or columns are
# renamed or removed. If we use tableRef and colRef, then such rules cannot apply to metadata
# tables (which don't have refs at all).
#
# Additionally, a benefit of using tableId and colId is that this is how actions identify tables
# and columns, so rules can be applied to actions without additional lookups.
#
# For these reasons, we use tableId and colId, rather than refs (row-ids).
# Summary tables
# --------------
# It's not sufficient to identify summary tables by their actual tableId or by tableRef, since
# both may change when summaries are removed and recreated. They should instead be identified
# by a value similar to tableTitle() in DocModel.js, specifically by the combination of source
# tableId and colIds of all group-by columns.
# Actions on Permission Changes
# -----------------------------
# When VIEW principals are added or removed for a table/column, or a VIEW ACLFormula adds or
# removes principals for a row, those principals need to receive AddRecord or RemoveRecord
# doc-actions (or equivalent). TODO: This needs to be handled.
from collections import OrderedDict
import action_obj
import logger
log = logger.Logger(__name__, logger.INFO)
class Permissions(object):
# Permission types and their combination are represented as bits of a single integer.
VIEW = 0x1
UPDATE = 0x2
ADD = 0x4
REMOVE = 0x8
SCHEMA_EDIT = 0x10
ACL_EDIT = 0x20
EDITOR = VIEW | UPDATE | ADD | REMOVE
ADMIN = EDITOR | SCHEMA_EDIT
OWNER = ADMIN | ACL_EDIT
@classmethod
def includes(cls, superset, subset):
return (superset & subset) == subset
@classmethod
def includes_view(cls, permissions):
return cls.includes(permissions, cls.VIEW)
# Sentinel object to represent "all rows", for internal use in this file.
_ALL_ROWS = "ALL_ROWS"
# Representation of ACL resources that's used by the ACL class. An instance of this class becomes
# the value of DocInfo.acl_resources formula.
# Note that the default ruleset (for tableId None, colId None) must exist.
# TODO: ensure that the default ruleset is created, and cannot be deleted.
class ResourceMap(object):
def __init__(self, resource_records):
self._col_resources = {} # Maps table_id to [(resource, col_id_set), ...]
self._default_resources = {} # Maps table_id (or None for global default) to resource record.
for resource in resource_records:
# Note that resource.tableId is the empty string ('') for the default table (represented as
# None in self._default_resources), and resource.colIds is '' for the table's default rule.
table_id = resource.tableId or None
if not resource.colIds:
self._default_resources[table_id] = resource
else:
col_id_set = set(resource.colIds.split(','))
self._col_resources.setdefault(table_id, []).append((resource, col_id_set))
def get_col_resources(self, table_id):
"""
Returns a list of (resource, col_id_set) pairs, where resource is a record in ACLResources.
"""
return self._col_resources.get(table_id, [])
def get_default_resource(self, table_id):
"""
Returns the "default" resource record for the given table.
"""
return self._default_resources.get(table_id) or self._default_resources.get(None)
# Used by docmodel.py for DocInfo.acl_resources formula.
def build_resources(resource_records):
return ResourceMap(resource_records)
class ACL(object):
# Special recipients, or instanceIds. ALL is the special recipient for schema actions that
# should be shared with all collaborators of the document.
ALL = '#ALL'
ALL_SET = frozenset([ALL])
EMPTY_SET = frozenset([])
def __init__(self, docmodel):
self._docmodel = docmodel
def get_acl_resources(self):
try:
return self._docmodel.doc_info.table.get_record(1).acl_resources
except KeyError:
return None
def _find_resources(self, table_id, col_ids):
"""
Yields tuples (resource, col_id_set) where each col_id_set represents the intersection of the
resouces's columns with col_ids. These intersections may be empty.
If col_ids is None, then it's treated as "all columns", and each col_id_set represents all of
the resource's columns. For the default resource then, it yields (resource, None)
"""
resource_map = self.get_acl_resources()
if col_ids is None:
for resource, col_id_set in resource_map.get_col_resources(table_id):
yield resource, col_id_set
resource = resource_map.get_default_resource(table_id)
yield resource, None
else:
seen = set()
for resource, col_id_set in resource_map.get_col_resources(table_id):
seen.update(col_id_set)
yield resource, col_id_set.intersection(col_ids)
resource = resource_map.get_default_resource(table_id)
yield resource, set(c for c in col_ids if c not in seen)
@classmethod
def _acl_read_split_rows(cls, resource, row_ids):
"""
Scans through ACL rules for the resouce, yielding tuples of the form (rule, row_id,
instances), to say which rowId should be sent to each set of instances according to the rule.
"""
for rule in resource.ruleset:
if not Permissions.includes_view(rule.permissions):
continue
common_instances = _get_instances(rule.principalsList)
if rule.aclColumn and row_ids is not None:
for (row_id, principals) in get_row_principals(rule.aclColumn, row_ids):
yield (rule, row_id, common_instances | _get_instances(principals))
else:
yield (rule, _ALL_ROWS, common_instances)
@classmethod
def _acl_read_split_instance_sets(cls, resource, row_ids):
"""
Yields tuples of the form (instances, rowset, rules) for different sets of instances, to say
which rowIds for the given resource should be sent to each set of instances, and which rules
enabled that. When a set of instances should get all rows, rowset is None.
"""
for instances, group in _group((instances, (rule, row_id)) for (rule, row_id, instances)
in cls._acl_read_split_rows(resource, row_ids)):
rules = set(item[0] for item in group)
rowset = frozenset(item[1] for item in group)
yield (instances, _ALL_ROWS if _ALL_ROWS in rowset else rowset, rules)
@classmethod
def _acl_read_split_resource(cls, resource, row_ids, docaction, output):
"""
Given an ACLResource record and optionally row_ids (which may be None), appends to output
tuples of the form `(instances, rules, action)`, where `action` is docaction itself or a part
of it that should be sent to the corresponding set of instances.
"""
if docaction is None:
return
# Different rules may produce different recipients for the same set of rows. We group outputs
# by sets of rows (which determine a subaction), and take a union of all the recipients.
for rowset, group in _group((rowset, (instances, rules)) for (instances, rowset, rules)
in cls._acl_read_split_instance_sets(resource, row_ids)):
da = docaction if rowset is _ALL_ROWS else _subaction(docaction, row_ids=rowset)
if da is not None:
all_instances = frozenset(i for item in group for i in item[0])
all_rules = set(r for item in group for r in item[1])
output.append((all_instances, all_rules, da))
def _acl_read_split_docaction(self, docaction, output):
"""
Given just a docaction, appends to output tuples of the form `(instances, rules, action)`,
where `action` is docaction itself or a part of it that should be sent to `instances`, and
`rules` is the set of ACLRules that allowed that (empty set for schema actions).
"""
parts = _get_docaction_parts(docaction)
if parts is None: # This is a schema action, to send to everyone.
# We want to send schema actions to everyone on the document, represented by None.
output.append((ACL.ALL_SET, set(), docaction))
return
table_id, row_ids, col_ids = parts
for resource, col_id_set in self._find_resources(table_id, col_ids):
da = _subaction(docaction, col_ids=col_id_set)
if da is not None:
self._acl_read_split_resource(resource, row_ids, da, output)
def _acl_read_split_docactions(self, docactions):
"""
Returns a list of tuples `(instances, rules, action)`. See _acl_read_split_docaction.
"""
if not self.get_acl_resources():
return [(ACL.EMPTY_SET, None, da) for da in docactions]
output = []
for da in docactions:
self._acl_read_split_docaction(da, output)
return output
def acl_read_split(self, action_group):
"""
Returns an ActionBundle, containing actions from the given action_group, split by the sets of
instances to which actions should be sent.
"""
bundle = action_obj.ActionBundle()
envelopeIndices = {} # Maps instance-sets to envelope indices.
def getEnvIndex(instances):
envIndex = envelopeIndices.setdefault(instances, len(bundle.envelopes))
if envIndex == len(bundle.envelopes):
bundle.envelopes.append(action_obj.Envelope(instances))
return envIndex
def split_into_envelopes(docactions, out_rules, output):
for (instances, rules, action) in self._acl_read_split_docactions(docactions):
output.append((getEnvIndex(instances), action))
if rules:
out_rules.update(r.id for r in rules)
split_into_envelopes(action_group.stored, bundle.rules, bundle.stored)
split_into_envelopes(action_group.calc, bundle.rules, bundle.calc)
split_into_envelopes(action_group.undo, bundle.rules, bundle.undo)
bundle.retValues = action_group.retValues
return bundle
class OrderedDefaultListDict(OrderedDict):
def __missing__(self, key):
self[key] = value = []
return value
def _group(iterable_of_pairs):
"""
Group iterable of pairs (a, b), returning pairs (a, [list of b]). The order of the groups, and
of items within a group, is according to the first seen.
"""
groups = OrderedDefaultListDict()
for key, value in iterable_of_pairs:
groups[key].append(value)
return groups.iteritems()
def _get_instances(principals):
"""
Returns a frozenset of all instances for all passed-in principals.
"""
instances = set()
for p in principals:
instances.update(i.instanceId for i in p.allInstances)
return frozenset(instances)
def get_row_principals(_acl_column, _rows):
# TODO TBD. Need to implement this (with tests) for acl-formulas for row-level access control.
return []
#----------------------------------------------------------------------
def _get_docaction_parts(docaction):
"""
Returns a tuple of (table_id, row_ids, col_ids), any of whose members may be None, or None if
this action should not get split.
"""
return _docaction_part_helpers[docaction.__class__.__name__](docaction)
# Helpers for _get_docaction_parts to extract for each action type the table, rows, and columns
# that a docaction of that type affects. Note that we are only talking here about the data
# affected. Schema actions do not get trimmed, since we decided against having a separate
# (and confusing) "SCHEMA_VIEW" permission. All peers will know the schema.
_docaction_part_helpers = {
'AddRecord' : lambda a: (a.table_id, [a.row_id], a.columns.keys()),
'BulkAddRecord' : lambda a: (a.table_id, a.row_ids, a.columns.keys()),
'RemoveRecord' : lambda a: (a.table_id, [a.row_id], None),
'BulkRemoveRecord' : lambda a: (a.table_id, a.row_ids, None),
'UpdateRecord' : lambda a: (a.table_id, [a.row_id], a.columns.keys()),
'BulkUpdateRecord' : lambda a: (a.table_id, a.row_ids, a.columns.keys()),
'ReplaceTableData' : lambda a: (a.table_id, a.row_ids, a.columns.keys()),
'AddColumn' : lambda a: None,
'RemoveColumn' : lambda a: None,
'RenameColumn' : lambda a: None,
'ModifyColumn' : lambda a: None,
'AddTable' : lambda a: None,
'RemoveTable' : lambda a: None,
'RenameTable' : lambda a: None,
}
#----------------------------------------------------------------------
def _subaction(docaction, row_ids=None, col_ids=None):
"""
For data actions, extracts and returns a part of docaction that applies only to the given
row_ids and/or col_ids, if given. If the part of the action is empty, returns None.
"""
helper = _subaction_helpers[docaction.__class__.__name__]
try:
return docaction.__class__._make(helper(docaction, row_ids, col_ids))
except _NoMatch:
return None
# Helpers for _subaction(), one for each action type, which return the tuple of values for the
# trimmed action. From this tuple a new action is automatically created by _subaction. If any part
# of the action becomes empty, the helpers raise _NoMatch exception.
_subaction_helpers = {
# pylint: disable=line-too-long
'AddRecord' : lambda a, r, c: (a.table_id, match(r, a.row_id), match_keys_keep_empty(c, a.columns)),
'BulkAddRecord' : lambda a, r, c: (a.table_id, match_list(r, a.row_ids), match_keys_keep_empty(c, a.columns)),
'RemoveRecord' : lambda a, r, c: (a.table_id, match(r, a.row_id)),
'BulkRemoveRecord' : lambda a, r, c: (a.table_id, match_list(r, a.row_ids)),
'UpdateRecord' : lambda a, r, c: (a.table_id, match(r, a.row_id), match_keys_skip_empty(c, a.columns)),
'BulkUpdateRecord' : lambda a, r, c: (a.table_id, match_list(r, a.row_ids), match_keys_skip_empty(c, a.columns)),
'ReplaceTableData' : lambda a, r, c: (a.table_id, match_list(r, a.row_ids), match_keys_keep_empty(c, a.columns)),
'AddColumn' : lambda a, r, c: a,
'RemoveColumn' : lambda a, r, c: a,
'RenameColumn' : lambda a, r, c: a,
'ModifyColumn' : lambda a, r, c: a,
'AddTable' : lambda a, r, c: a,
'RemoveTable' : lambda a, r, c: a,
'RenameTable' : lambda a, r, c: a,
}
def match(subset, item):
return item if (subset is None or item in subset) else no_match()
def match_list(subset, items):
return items if subset is None else ([i for i in items if i in subset] or no_match())
def match_keys_keep_empty(subset, items):
return items if subset is None else (
{k: v for (k, v) in items.iteritems() if k in subset})
def match_keys_skip_empty(subset, items):
return items if subset is None else (
{k: v for (k, v) in items.iteritems() if k in subset} or no_match())
class _NoMatch(Exception):
pass
def no_match():
raise _NoMatch()

@ -0,0 +1,79 @@
"""
This module defines ActionGroup, ActionEnvelope, and ActionBundle -- classes that together
represent the result of applying a UserAction to a document.
In general, UserActions refer to logical actions performed by the user. DocActions are the
individual steps to which UserActions translate.
A list of UserActions applied together translates to multiple DocActions, packaged into an
ActionGroup. In a separate step, this ActionGroup is split up according to ACL rules into and
ActionBundle consisting of ActionEnvelopes, each containing a smaller set of actions associated
with the set of recipients who should receive them.
"""
import actions
class ActionGroup(object):
"""
ActionGroup packages different types of doc actions for returning them to the instance.
The ActionGroup stores actions produced by the engine in the course of processing one or more
UserActions, plus an array of return values, one for each UserAction.
"""
def __init__(self):
self.calc = []
self.stored = []
self.undo = []
self.retValues = []
def get_repr(self):
return {
"calc": map(actions.get_action_repr, self.calc),
"stored": map(actions.get_action_repr, self.stored),
"undo": map(actions.get_action_repr, self.undo),
"retValues": self.retValues
}
@classmethod
def from_json_obj(cls, data):
ag = ActionGroup()
ag.calc = map(actions.action_from_repr, data.get('calc', []))
ag.stored = map(actions.action_from_repr, data.get('stored', []))
ag.undo = map(actions.action_from_repr, data.get('undo', []))
ag.retValues = data.get('retValues', [])
return ag
class Envelope(object):
"""
Envelope contains information about recipients as a set (or frozenset) of instanceIds.
"""
def __init__(self, recipient_set):
self.recipients = recipient_set
def to_json_obj(self):
return {"recipients": sorted(self.recipients)}
class ActionBundle(object):
"""
ActionBundle contains actions arranged into envelopes, i.e. split up by sets of recipients.
Note that different Envelopes contain different sets of recipients (which may overlap however).
"""
def __init__(self):
self.envelopes = []
self.stored = [] # Pairs of (envIndex, docAction)
self.calc = [] # Pairs of (envIndex, docAction)
self.undo = [] # Pairs of (envIndex, docAction)
self.retValues = []
self.rules = set() # RowIds of ACLRule records used to construct this ActionBundle.
def to_json_obj(self):
return {
"envelopes": [e.to_json_obj() for e in self.envelopes],
"stored": [(env, actions.get_action_repr(a)) for (env, a) in self.stored],
"calc": [(env, actions.get_action_repr(a)) for (env, a) in self.calc],
"undo": [(env, actions.get_action_repr(a)) for (env, a) in self.undo],
"retValues": self.retValues,
"rules": sorted(self.rules)
}

@ -0,0 +1,204 @@
"""
actions.py defines the action objects used in the Python code, and functions to convert between
them and the serializable docActions objects used to communicate with the outside.
When communicating with Node, docActions are represented as arrays [actionName, arguments...].
"""
from collections import namedtuple
import inspect
import objtypes
def _eq_with_type(self, other):
# pylint: disable=unidiomatic-typecheck
return tuple(self) == tuple(other) and type(self) == type(other)
def _ne_with_type(self, other):
return not _eq_with_type(self, other)
def namedtuple_eq(typename, field_names):
"""
Just like namedtuple, but these objects are only considered equal to other objects of the same
type (not just to any tuple with the same values).
"""
n = namedtuple(typename, field_names)
n.__eq__ = _eq_with_type
n.__ne__ = _ne_with_type
return n
# For Record actions, the parameters are as follows:
# table_id: string name of the table.
# row_id: numeric row identifier
# row_ids: list of row identifiers
# columns: dictionary mapping col_id (string name of column) to the value for the given
# row_id, or an array of values parallel to the array of row_ids.
AddRecord = namedtuple_eq('AddRecord', ('table_id', 'row_id', 'columns'))
BulkAddRecord = namedtuple_eq('BulkAddRecord', ('table_id', 'row_ids', 'columns'))
RemoveRecord = namedtuple_eq('RemoveRecord', ('table_id', 'row_id'))
BulkRemoveRecord = namedtuple_eq('BulkRemoveRecord', ('table_id', 'row_ids'))
UpdateRecord = namedtuple_eq('UpdateRecord', ('table_id', 'row_id', 'columns'))
BulkUpdateRecord = namedtuple_eq('BulkUpdateRecord', ('table_id', 'row_ids', 'columns'))
# Identical to BulkAddRecord, but implies emptying out the table first.
ReplaceTableData = namedtuple_eq('ReplaceTableData', BulkAddRecord._fields)
# For Column actions, the parameters are:
# table_id: string name of the table.
# col_id: string name of column
# col_info: dictionary with particular keys
# type: string type of the column
# isFormula: bool, whether it is a formula column
# formula: string text of the formula, or empty string
# Other keys may be set in col_info (e.g. widgetOptions, label) but are not currently used in
# the schema (only such values from the metadata tables is used).
AddColumn = namedtuple_eq('AddColumn', ('table_id', 'col_id', 'col_info'))
RemoveColumn = namedtuple_eq('RemoveColumn', ('table_id', 'col_id'))
RenameColumn = namedtuple_eq('RenameColumn', ('table_id', 'old_col_id', 'new_col_id'))
ModifyColumn = namedtuple_eq('ModifyColumn', ('table_id', 'col_id', 'col_info'))
# For Table actions, the parameters are:
# table_id: string name of the table.
# columns: array of col_info objects, as described for Column actions above, containing also:
# id: string name of the column (aka col_id in Column actions)
AddTable = namedtuple_eq('AddTable', ('table_id', 'columns'))
RemoveTable = namedtuple_eq('RemoveTable', ('table_id',))
RenameTable = namedtuple_eq('RenameTable', ('old_table_id', 'new_table_id'))
# Identical to BulkAddRecord, just a clearer type name for loading or fetching data.
TableData = namedtuple_eq('TableData', BulkAddRecord._fields)
action_types = dict((key, val) for (key, val) in globals().items()
if inspect.isclass(val) and issubclass(val, tuple))
# This is the set of names of all the actions that affect the schema.
schema_actions = {name for name in action_types
if name.endswith("Column") or name.endswith("Table")}
def _add_simplify(SingleActionType, BulkActionType):
"""
Add .simplify method to "Bulk" actions, which returns None for no rows, non-Bulk version for a
single row, and the original action otherwise.
"""
if len(SingleActionType._fields) < 3:
def get_first(self):
return SingleActionType(self.table_id, self.row_ids[0])
else:
def get_first(self):
return SingleActionType(self.table_id, self.row_ids[0],
{ key: col[0] for key, col in self.columns.iteritems()})
def simplify(self):
return None if not self.row_ids else (get_first(self) if len(self.row_ids) == 1 else self)
BulkActionType.simplify = simplify
_add_simplify(AddRecord, BulkAddRecord)
_add_simplify(RemoveRecord, BulkRemoveRecord)
_add_simplify(UpdateRecord, BulkUpdateRecord)
def get_action_repr(action_obj):
"""
Converts an action object, such as UpdateRecord into a docAction array.
"""
return [action_obj.__class__.__name__] + list(encode_objects(action_obj))
def action_from_repr(doc_action):
"""
Converts a docAction array into an object such as UpdateRecord.
"""
action_type = action_types.get(doc_action[0])
if not action_type:
raise ValueError('Unknown action %s' % (doc_action[0],))
try:
return decode_objects(action_type(*doc_action[1:]))
except TypeError as e:
raise TypeError("%s: %s" % (doc_action[0], e.message))
def convert_recursive_helper(converter, data):
"""
Given JSON-like data (a nested collection of lists or arrays), which may include Action tuples,
replaces all primitive values with converter(value). It should be used as follows:
def my_convert(data):
if data needs conversion:
return converted_value
return convert_recursive_helper(my_convert, data)
"""
if isinstance(data, dict):
return {converter(k): converter(v) for k, v in data.iteritems()}
elif isinstance(data, list):
return [converter(el) for el in data]
elif isinstance(data, tuple):
return type(data)(*[converter(el) for el in data])
else:
return data
def convert_action_values(converter, action):
"""
Replaces all data values in an action that includes actual data with converter(value).
"""
if isinstance(action, (AddRecord, UpdateRecord)):
return type(action)(action.table_id, action.row_id,
{k: converter(v) for k, v in action.columns.iteritems()})
if isinstance(action, (BulkAddRecord, BulkUpdateRecord, ReplaceTableData, TableData)):
return type(action)(action.table_id, action.row_ids,
{k: map(converter, v) for k, v in action.columns.iteritems()})
return action
def convert_recursive_in_action(converter, data):
"""
Like convert_recursive_helper, but only values of Grist cells (i.e. individual values in data
columns) get passed through converter.
"""
def inner(data):
if isinstance(data, tuple):
return convert_action_values(converter, data)
return convert_recursive_helper(inner, data)
return inner(data)
def encode_objects(data):
return convert_recursive_in_action(objtypes.encode_object, data)
def decode_objects(data, decoder=objtypes.decode_object):
"""
Decode objects in values of a DocAction or a data structure containing DocActions.
"""
return convert_recursive_in_action(decoder, data)
def decode_bulk_values(bulk_values, decoder=objtypes.decode_object):
"""
Decode objects in values of the form {col_id: array_of_values}, as present in bulk DocActions
and UserActions.
"""
return {k: map(decoder, v) for (k, v) in bulk_values.iteritems()}
def transpose_bulk_action(bulk_action):
"""
Generates namedtuples for records in a bulk action such as BulkAddRecord. Such actions store
values by columns, so in effect this transposes them, yielding them by rows.
"""
items = sorted(bulk_action.columns.items())
RecordType = namedtuple('Record', ['id'] + [col_id for (col_id, values) in items])
for row in zip(bulk_action.row_ids, *[values for (col_id, values) in items]):
yield RecordType(*row)
def prune_actions(action_list, table_id, col_id):
"""
Modifies action_list in-place, removing any mention of (table_id, col_id). Both must be given
and not None in this implementation.
"""
keep = []
for a in action_list:
if getattr(a, 'table_id', None) == table_id:
if hasattr(a, 'columns'):
a.columns.pop(col_id, None)
if not a.columns:
continue
if getattr(a, 'col_id', None) == col_id:
continue
keep.append(a)
action_list[:] = keep

@ -0,0 +1,323 @@
import ast
import contextlib
import re
import six
import astroid
import asttokens
import textbuilder
import logger
log = logger.Logger(__name__, logger.INFO)
DOLLAR_REGEX = re.compile(r'\$(?=[a-zA-Z_][a-zA-Z_0-9]*)')
# For functions needing lazy evaluation, the slice for which arguments to wrap in a lambda.
LAZY_ARG_FUNCTIONS = {
'IF': slice(1, 3),
'ISERR': slice(0, 1),
'ISERROR': slice(0, 1),
'IFERROR': slice(0, 1),
}
def make_formula_body(formula, default_value, assoc_value=None):
"""
Given a formula, returns a textbuilder.Builder object suitable to be the body of a function,
with the formula transformed to replace `$foo` with `rec.foo`, and to insert `return` if
appropriate. Assoc_value is associated with textbuilder.Text() to be returned by map_back_patch.
"""
if isinstance(formula, six.binary_type):
formula = formula.decode('utf8')
if not formula.strip():
return textbuilder.Text('return ' + repr(default_value), assoc_value)
formula_builder_text = textbuilder.Text(formula, assoc_value)
# Start with a temporary builder, since we need to translate "$" before we can parse the code at
# all (namely, we turn '$foo' into 'DOLLARfoo' first). Once we can parse the code, we'll create
# a proper set of patches. Note that we initially translate into 'DOLLARfoo' rather than
# 'rec.foo', so that the translated entity is a single token: this makes for more precisely
# reported errors if there are any.
tmp_patches = textbuilder.make_regexp_patches(formula, DOLLAR_REGEX, 'DOLLAR')
tmp_formula = textbuilder.Replacer(formula_builder_text, tmp_patches)
# Parse the formula into an abstract syntax tree (AST), catching syntax errors.
try:
atok = asttokens.ASTTokens(tmp_formula.get_text(), parse=True)
except SyntaxError as e:
return textbuilder.Text(_create_syntax_error_code(tmp_formula, formula, e))
# Parse formula and generate error code on assignment to rec
with use_inferences(InferRecAssignment):
try:
astroid.parse(tmp_formula.get_text())
except SyntaxError as e:
return textbuilder.Text(_create_syntax_error_code(tmp_formula, formula, e))
# Once we have a tree, go through it and create a subset of the dollar patches that are actually
# relevant. E.g. this is where we'll skip the "$foo" patches that appear in strings or comments.
patches = []
for node in ast.walk(atok.tree):
if isinstance(node, ast.Name) and node.id.startswith('DOLLAR'):
input_pos = tmp_formula.map_back_offset(node.first_token.startpos)
m = DOLLAR_REGEX.match(formula, input_pos)
# If there is no match, then we must have had a "DOLLARblah" identifier that didn't come
# from translating a "$" prefix.
if m:
patches.append(textbuilder.make_patch(formula, m.start(0), m.end(0), 'rec.'))
# Wrap arguments to the top-level "IF()" function into lambdas, for lazy evaluation. This is
# to ensure it's not affected by an exception in the unused value, to match Excel behavior.
if isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
lazy_args_slice = LAZY_ARG_FUNCTIONS.get(node.func.id)
if lazy_args_slice:
for arg in node.args[lazy_args_slice]:
start, end = map(tmp_formula.map_back_offset, atok.get_text_range(arg))
patches.append(textbuilder.make_patch(formula, start, start, 'lambda: ('))
patches.append(textbuilder.make_patch(formula, end, end, ')'))
# If the last statement is an expression that has its result unused (an ast.Expr node),
# then insert a "return" keyword.
last_statement = atok.tree.body[-1] if atok.tree.body else None
if isinstance(last_statement, ast.Expr):
input_pos = tmp_formula.map_back_offset(last_statement.first_token.startpos)
patches.append(textbuilder.make_patch(formula, input_pos, input_pos, "return "))
elif last_statement is None:
# If we have an empty body (e.g. just a comment), add a 'pass' at the end.
patches.append(textbuilder.make_patch(formula, len(formula), len(formula), '\npass'))
# Apply the new set of patches to the original formula to get the real output.
final_formula = textbuilder.Replacer(formula_builder_text, patches)
# Try parsing again before returning it just in case we have new syntax errors. These are
# possible in cases when a single token ('DOLLARfoo') is valid but an expression ('rec.foo') is
# not, e.g. `foo($bar=1)` or `def $foo()`.
try:
atok = asttokens.ASTTokens(final_formula.get_text(), parse=True)
except SyntaxError as e:
return textbuilder.Text(_create_syntax_error_code(final_formula, formula, e))
# We return the text-builder object whose .get_text() is the final formula.
return final_formula
def _create_syntax_error_code(builder, input_text, err):
"""
Returns the text for a function that raises the given SyntaxError and includes the offending
code in a commented-out form. In addition, it translates the error's position from builder's
output to input_text.
"""
output_ln = asttokens.LineNumbers(builder.get_text())
input_ln = asttokens.LineNumbers(input_text)
# A SyntaxError contains .lineno and .offset (1-based), which we need to translate to offset
# within the transformed text, so that it can be mapped back to an offset in the original text,
# and finally translated back into a line number and 1-based position to report to the user. An
# example is that "$x*" is translated to "return x*", and the syntax error in the transformed
# python code (line 2 offset 9) needs to be translated to be in line 2 offset 3.
output_offset = output_ln.line_to_offset(err.lineno, err.offset - 1 if err.offset else 0)
input_offset = builder.map_back_offset(output_offset)
line, col = input_ln.offset_to_line(input_offset)
return "%s\nraise %s('%s on line %d col %d')" % (
textbuilder.line_start_re.sub('# ', input_text.rstrip()),
type(err).__name__, err.args[0], line, col + 1)
#----------------------------------------------------------------------
def infer(node):
try:
return next(node.infer(), None)
except astroid.exceptions.InferenceError, e:
return "InferenceError on %r: %r" % (node, e)
_lookup_method_names = ('lookupOne', 'lookupRecords')
def _is_table(node):
"""
Return true if obj is a class defining a user table.
"""
return (isinstance(node, astroid.nodes.Class) and node.decorators and
node.decorators.nodes[0].as_string() == 'grist.UserTable')
@contextlib.contextmanager
def use_inferences(*inference_tips):
transform_args = [(cls.node_class, astroid.inference_tip(cls.infer), cls.filter)
for cls in inference_tips]
for args in transform_args:
astroid.MANAGER.register_transform(*args)
yield
for args in transform_args:
astroid.MANAGER.unregister_transform(*args)
class InferenceTip(object):
"""
Base class for inference tips. A derived class can implement the filter() and infer() class
methods, and then register() will put that inference helper into use.
"""
node_class = None
@classmethod
def filter(cls, node):
raise NotImplementedError()
@classmethod
def infer(cls, node, context):
raise NotImplementedError()
class InferReferenceColumn(InferenceTip):
"""
Inference helper to treat the return value of `grist.Reference("Foo")` as an instance of the
table `Foo`.
"""
node_class = astroid.nodes.Call
@classmethod
def filter(cls, node):
return (isinstance(node.func, astroid.nodes.Attribute) and
node.func.as_string() in ('grist.Reference', 'grist.ReferenceList'))
@classmethod
def infer(cls, node, context=None):
table_id = node.args[0].value
table_class = next(node.root().igetattr(table_id))
yield astroid.bases.Instance(table_class)
def _get_formula_type(function_node):
decorators = function_node.decorators.nodes if function_node.decorators else ()
for dec in decorators:
if (isinstance(dec, astroid.nodes.Call) and
dec.func.as_string() == 'grist.formulaType'):
return dec.args[0]
return None
class InferReferenceFormula(InferenceTip):
"""
Inference helper to treat functions decorated with `grist.formulaType(grist.Reference("Foo"))`
as returning instances of table `Foo`.
"""
node_class = astroid.nodes.Function
@classmethod
def filter(cls, node):
# All methods on tables are really used as properties.
return _is_table(node.parent.frame())
@classmethod
def infer(cls, node, context=None):
ftype = _get_formula_type(node)
if ftype and InferReferenceColumn.filter(ftype):
return InferReferenceColumn.infer(ftype, context)
return node.infer_call_result(node.parent.frame(), context)
class InferLookupReference(InferenceTip):
"""
Inference helper to treat the return value of `Table.lookupRecords(...)` as returning instances
of table `Table`.
"""
node_class = astroid.nodes.Call
@classmethod
def filter(cls, node):
return (isinstance(node.func, astroid.nodes.Attribute) and
node.func.attrname in _lookup_method_names and
_is_table(infer(node.func.expr)))
@classmethod
def infer(cls, node, context=None):
yield astroid.bases.Instance(infer(node.func.expr))
class InferLookupComprehension(InferenceTip):
node_class = astroid.nodes.AssignName
@classmethod
def filter(cls, node):
compr = node.parent
if not isinstance(compr, astroid.nodes.Comprehension):
return False
if not isinstance(compr.iter, astroid.nodes.Call):
return False
return InferLookupReference.filter(compr.iter)
@classmethod
def infer(cls, node, context=None):
return InferLookupReference.infer(node.parent.iter)
class InferRecAssignment(InferenceTip):
"""
Inference helper to raise exception on assignment to `rec`.
"""
node_class = astroid.nodes.AssignName
@classmethod
def filter(cls, node):
if node.name == 'rec':
raise SyntaxError('Grist disallows assignment to the special variable "rec"',
('<string>', node.lineno, node.col_offset, ""))
@classmethod
def infer(cls, node, context):
raise NotImplementedError()
#----------------------------------------------------------------------
def parse_grist_names(builder):
"""
Returns a list of tuples (col_info, start_pos, table_id, col_id):
col_info: (table_id, col_id) for the formula the name is found in. It is the value passed
in by gencode.py to codebuilder.make_formula_body().
start_pos: Index of the start character of the name in col_info.formula
table_id: Parsed name when the tuple is for a table name; the name of the column's table
when the tuple is for a column name.
col_id: None when tuple is for a table name; col_id when the tuple is for a column name.
"""
code_text = builder.get_text()
with use_inferences(InferReferenceColumn, InferReferenceFormula, InferLookupReference,
InferLookupComprehension):
atok = asttokens.ASTTokens(code_text, tree=astroid.builder.parse(code_text))
def make_tuple(start, end, table_id, col_id):
name = col_id or table_id
assert end - start == len(name)
patch = textbuilder.Patch(start, end, name, name)
assert code_text[start:end] == name
patch_source = builder.map_back_patch(patch)
if not patch_source:
return None
in_text, in_value, in_patch = patch_source
return (in_value, in_patch.start, table_id, col_id)
parsed_names = []
for node in asttokens.util.walk(atok.tree):
if isinstance(node, astroid.nodes.Name):
obj = infer(node)
if _is_table(obj):
start, end = atok.get_text_range(node)
parsed_names.append(make_tuple(start, end, obj.name, None))
elif isinstance(node, astroid.nodes.Attribute):
obj = infer(node.expr)
if isinstance(obj, astroid.bases.Instance):
cls = obj._proxied
if _is_table(cls):
tok = node.last_token
start, end = tok.startpos, tok.endpos
parsed_names.append(make_tuple(start, end, cls.name, node.attrname))
elif isinstance(node, astroid.nodes.Keyword):
func = node.parent.func
if isinstance(func, astroid.nodes.Attribute) and func.attrname in _lookup_method_names:
obj = infer(func.expr)
if _is_table(obj):
tok = node.first_token
start, end = tok.startpos, tok.endpos
parsed_names.append(make_tuple(start, end, obj.name, node.arg))
return filter(None, parsed_names)

@ -0,0 +1,391 @@
import types
from collections import namedtuple
import depend
import objtypes
import usertypes
import relabeling
import relation
import moment
import logger
from sortedcontainers import SortedListWithKey
log = logger.Logger(__name__, logger.INFO)
MANUAL_SORT = 'manualSort'
MANUAL_SORT_COL_INFO = {
'id': MANUAL_SORT,
'type': 'ManualSortPos',
'formula': '',
'isFormula': False
}
MANUAL_SORT_DEFAULT = 2147483647.0
SPECIAL_COL_IDS = {'id', MANUAL_SORT}
def is_user_column(col_id):
"""
Returns whether the col_id is of a user column (as opposed to special columns that can't be used
for user data).
"""
return col_id not in SPECIAL_COL_IDS and not col_id.startswith('#')
def is_visible_column(col_id):
"""
Returns whether this is an id of a column that's intended to be shown to the user.
"""
return is_user_column(col_id) and not col_id.startswith('gristHelper_')
def is_virtual_column(col_id):
"""
Returns whether col_id is of a special column that does not get communicated outside of the
sandbox. Lookup maps are an example.
"""
return col_id.startswith('#')
def is_validation_column_name(name):
return name.startswith("validation___")
ColInfo = namedtuple('ColInfo', ('type_obj', 'is_formula', 'method'))
def get_col_info(col_model, default_func=None):
if isinstance(col_model, types.FunctionType):
type_obj = getattr(col_model, 'grist_type', usertypes.Any())
return ColInfo(type_obj, is_formula=True, method=col_model)
else:
return ColInfo(col_model, is_formula=False, method=col_model.default_func or default_func)
class BaseColumn(object):
"""
BaseColumn holds a column of data, whether raw or computed.
"""
def __init__(self, table, col_id, col_info):
self.type_obj = col_info.type_obj
self._data = []
self.col_id = col_id
self.table_id = table.table_id
self.node = depend.Node(self.table_id, col_id)
self._is_formula = col_info.is_formula
self._is_private = bool(col_info.method) and getattr(col_info.method, 'is_private', False)
self.method = col_info.method
# Always initialize to include the special empty record at index 0.
self.growto(1)
def update_method(self, method):
"""
After rebuilding user code, we reuse existing column objects, but need to replace their
'method' function. The method may refer to variables in the generated "usercode" module, and
it's important that all such references are to the rebuilt "usercode" module.
"""
self.method = method
def is_formula(self):
"""
Whether this is a formula column. Note that a non-formula column may have an associated
method, which is used to fill in defaults when a record is added.
"""
return self._is_formula
def is_private(self):
"""
Returns whether this method is private to the sandbox. If so, changes to this column do not
get communicated to outside the sandbox via actions.
"""
return self._is_private
def has_formula(self):
"""
has_formula is true if formula is set, whether or not this is a formula column.
"""
return self.method is not None
def clear(self):
self._data = []
self.growto(1) # Always include the special empty record at index 0.
def destroy(self):
"""
Called when the column is deleted.
"""
del self._data[:]
def growto(self, size):
if len(self._data) < size:
self._data.extend([self.getdefault()] * (size - len(self._data)))
def size(self):
return len(self._data)
def set(self, row_id, value):
"""
Sets the value of this column for the given row_id. Value should be as returned by convert(),
i.e. of the right type, or alttext, or error (but should NOT be random wrong types).
"""
try:
self._data[row_id] = value
except IndexError:
self.growto(row_id + 1)
self._data[row_id] = value
def unset(self, row_id):
"""
Sets the value for the given row_id to the default value.
"""
self.set(row_id, self.getdefault())
def get_cell_value(self, row_id):
"""
Returns the "rich" value for the given row_id, i.e. the value that would be seen by formulas.
E.g. for ReferenceColumns it'll be the referred-to Record object. For cells containing
alttext, this will be an AltText object. For RaisedException objects that represent a thrown
error, this will re-raise that error.
"""
raw = self.raw_get(row_id)
if isinstance(raw, objtypes.RaisedException):
raise raw.error
if self.type_obj.is_right_type(raw):
return self._make_rich_value(raw)
return usertypes.AltText(str(raw), self.type_obj.typename())
def _make_rich_value(self, typed_value):
"""
Called by get_cell_value() with a value of the right type for this column. Should be
implemented by derived classes to produce a "rich" version of the value.
"""
# pylint: disable=no-self-use
return typed_value
def raw_get(self, row_id):
"""
Returns the value stored for the given row_id. This may be an error or alttext, and it does
not convert to a richer object.
"""
try:
return self._data[row_id]
except IndexError:
return self.getdefault()
def safe_get(self, row_id):
"""
Returns a value of the right type, or the default value if the stored value had a wrong type.
"""
raw = self.raw_get(row_id)
return raw if self.type_obj.is_right_type(raw) else self.getdefault()
def getdefault(self):
"""
Returns the default value for this column. This is a static default; the implementation of
"default formula" logic is separate.
"""
return self.type_obj.default
def sample_value(self):
"""
Returns a sample value for this column, used for auto-completions. E.g. for a date, this
returns an actual datetime object rather than None (only its attributes should matter).
"""
return self.type_obj.default
def copy_from_column(self, other_column):
"""
Replace this column's data entirely with data from another column of the same exact type.
"""
self._data[:] = other_column._data
def convert(self, value_to_convert):
"""
Converts a value of any type to this column's type, returning either the converted value (for
which is_right_type is true), or an alttext string, or an error object.
"""
return self.type_obj.convert(value_to_convert)
def prepare_new_values(self, values, ignore_data=False):
"""
This allows us to modify values and also produce adjustments to existing records. This
currently is only used by PositionColumn. Returns two lists: new_values, and
[(row_id, new_value)] list of adjustments to existing records.
If ignore_data is True, makes adjustments without regard to the existing data; this is used
for processing ReplaceTableData actions.
"""
# pylint: disable=no-self-use, unused-argument
return values, []
class DataColumn(BaseColumn):
"""
DataColumn describes a column of raw data, and holds it.
"""
pass
class BoolColumn(BaseColumn):
def set(self, row_id, value):
# When 1 or 1.0 is loaded, we should see it as True, and similarly 0 as False. This is similar
# to how, after loading a number into a DateColumn, we should see a date, except we adjust
# booleans at set() time.
bool_value = True if value == 1 else (False if value == 0 else value)
super(BoolColumn, self).set(row_id, bool_value)
class NumericColumn(BaseColumn):
def set(self, row_id, value):
# Make sure any integers are treated as floats to avoid truncation.
# Uses `type(value) == int` rather than `isintance(value, int)` to specifically target
# ints and not bools (which are singleton instances the class int in python). But
# perhaps something should be done about bools also?
# pylint: disable=unidiomatic-typecheck
super(NumericColumn, self).set(row_id, float(value) if type(value) == int else value)
_sample_date = moment.ts_to_date(0)
_sample_datetime = moment.ts_to_dt(0, None, moment.TZ_UTC)
class DateColumn(BaseColumn):
"""
DateColumn contains numerical timestamps represented as seconds since epoch, in type float,
to midnight of specific UTC dates. Accessing them yields date objects.
"""
def _make_rich_value(self, typed_value):
return typed_value and moment.ts_to_date(typed_value)
def sample_value(self):
return _sample_date
class DateTimeColumn(BaseColumn):
"""
DateTimeColumn contains numerical timestamps represented as seconds since epoch, in type float,
and a timestamp associated with the column. Accessing them yields datetime objects.
"""
def __init__(self, table, col_id, col_info):
super(DateTimeColumn, self).__init__(table, col_id, col_info)
self._timezone = col_info.type_obj.timezone
def _make_rich_value(self, typed_value):
return typed_value and moment.ts_to_dt(typed_value, self._timezone)
def sample_value(self):
return _sample_datetime
class PositionColumn(BaseColumn):
def __init__(self, table, col_id, col_info):
super(PositionColumn, self).__init__(table, col_id, col_info)
# This is a list of row_ids, ordered by the position.
self._sorted_rows = SortedListWithKey(key=self.raw_get)
def set(self, row_id, value):
self._sorted_rows.discard(row_id)
super(PositionColumn, self).set(row_id, value)
if value != self.getdefault():
self._sorted_rows.add(row_id)
def copy_from_column(self, other_column):
super(PositionColumn, self).copy_from_column(other_column)
self._sorted_rows = SortedListWithKey(other_column._sorted_rows[:], key=self.raw_get)
def prepare_new_values(self, values, ignore_data=False):
# This does the work of adjusting positions and relabeling existing rows with new position
# (without changing sort order) to make space for the new positions. Note that this is also
# used for updating a position for an existing row: we'll find a new value for it; later when
# this value is set, the old position will be removed and the new one added.
if ignore_data:
rows = SortedListWithKey([], key=self.raw_get)
else:
rows = self._sorted_rows
adjustments, new_values = relabeling.prepare_inserts(rows, values)
return new_values, [(self._sorted_rows[i], pos) for (i, pos) in adjustments]
class BaseReferenceColumn(BaseColumn):
"""
Base class for ReferenceColumn and ReferenceListColumn.
"""
def __init__(self, table, col_id, col_info):
super(BaseReferenceColumn, self).__init__(table, col_id, col_info)
# We can assume that all tables have been instantiated, but not all initialized.
target_table_id = self.type_obj.table_id
self._target_table = table._engine.tables.get(target_table_id, None)
self._relation = relation.ReferenceRelation(table.table_id, target_table_id, col_id)
# Note that we need to remove these back-references when the column is removed.
if self._target_table:
self._target_table._back_references.add(self)
def destroy(self):
# Destroy the column and remove the back-reference we created in the constructor.
super(BaseReferenceColumn, self).destroy()
if self._target_table:
self._target_table._back_references.remove(self)
def _update_references(self, row_id, old_value, new_value):
raise NotImplementedError()
def set(self, row_id, value):
old = self.safe_get(row_id)
super(BaseReferenceColumn, self).set(row_id, value)
new = self.safe_get(row_id)
self._update_references(row_id, old, new)
def copy_from_column(self, other_column):
super(BaseReferenceColumn, self).copy_from_column(other_column)
self._relation.clear()
# This is hacky: we should have an interface to iterate through values of a column. (As it is,
# self._data may include values for non-existent rows; it works here because those values are
# falsy, which makes them ignored by self._update_references).
for row_id, value in enumerate(self._data):
if isinstance(value, int):
self._update_references(row_id, None, value)
def sample_value(self):
return self._target_table.sample_record
class ReferenceColumn(BaseReferenceColumn):
"""
ReferenceColumn contains IDs of rows in another table. Accessing them yields the records in the
other table.
"""
def _make_rich_value(self, typed_value):
# If we refer to an invalid table, return integers rather than fail completely.
if not self._target_table:
return typed_value
# For a Reference, values must either refer to an existing record, or be 0. In all tables,
# the 0 index will contain the all-defaults record.
return self._target_table.Record(self._target_table, typed_value, self._relation)
def _update_references(self, row_id, old_value, new_value):
if old_value:
self._relation.remove_reference(row_id, old_value)
if new_value:
self._relation.add_reference(row_id, new_value)
class ReferenceListColumn(BaseReferenceColumn):
"""
ReferenceListColumn maintains for each row a list of references (row IDs) into another table.
Accessing them yields RecordSets.
"""
def _update_references(self, row_id, old_list, new_list):
for old_value in old_list or ():
self._relation.remove_reference(row_id, old_value)
for new_value in new_list or ():
self._relation.add_reference(row_id, new_value)
def _make_rich_value(self, typed_value):
if typed_value is None:
typed_value = []
# If we refer to an invalid table, return integers rather than fail completely.
if not self._target_table:
return typed_value
return self._target_table.RecordSet(self._target_table, typed_value, self._relation)
# Set up the relationship between usertypes objects and column objects.
usertypes.BaseColumnType.ColType = DataColumn
usertypes.Reference.ColType = ReferenceColumn
usertypes.ReferenceList.ColType = ReferenceListColumn
usertypes.DateTime.ColType = DateTimeColumn
usertypes.Date.ColType = DateColumn
usertypes.PositionNumber.ColType = PositionColumn
usertypes.Bool.ColType = BoolColumn
usertypes.Numeric.ColType = NumericColumn
def create_column(table, col_id, col_info):
return col_info.type_obj.ColType(table, col_id, col_info)

@ -0,0 +1,84 @@
import re
import csv
# Monkey-patch csv.Sniffer class, in which the quote/delimiter detection has silly bugs in the
# regexp that it uses. It also seems poorly-implemented in other ways. We can probably do better
# by not using csv.Sniffer at all.
# The method below is a modified copy of the same-named method in the standard csv.Sniffer class.
def _guess_quote_and_delimiter(_self, data, delimiters):
"""
Looks for text enclosed between two identical quotes
(the probable quotechar) which are preceded and followed
by the same character (the probable delimiter).
For example:
,'some text',
The quote with the most wins, same with the delimiter.
If there is no quotechar the delimiter can't be determined
this way.
"""
regexp = re.compile(
r"""
(?:(?P<delim>[^\w\n"\'])|^|\n) # delimiter or start-of-line
(?P<space>\ ?) # optional initial space
(?P<quote>["\']).*?(?P=quote) # quote-surrounded field
(?:(?P=delim)|$|\r?\n) # delimiter or end-of-line
""", re.VERBOSE | re.DOTALL | re.MULTILINE)
matches = regexp.findall(data)
if not matches:
# (quotechar, doublequote, delimiter, skipinitialspace)
return ('', False, None, 0)
quotes = {}
delims = {}
spaces = 0
for m in matches:
n = regexp.groupindex['quote'] - 1
key = m[n]
if key:
quotes[key] = quotes.get(key, 0) + 1
try:
n = regexp.groupindex['delim'] - 1
key = m[n]
except KeyError:
continue
if key and (delimiters is None or key in delimiters):
delims[key] = delims.get(key, 0) + 1
try:
n = regexp.groupindex['space'] - 1
except KeyError:
continue
if m[n]:
spaces += 1
quotechar = reduce(lambda a, b, _quotes = quotes:
(_quotes[a] > _quotes[b]) and a or b, quotes.keys())
if delims:
delim = reduce(lambda a, b, _delims = delims:
(_delims[a] > _delims[b]) and a or b, delims.keys())
skipinitialspace = delims[delim] == spaces
if delim == '\n': # most likely a file with a single column
delim = ''
else:
# there is *no* delimiter, it's a single column of quoted data
delim = ''
skipinitialspace = 0
# if we see an extra quote between delimiters, we've got a
# double quoted format
dq_regexp = re.compile(
(r"((%(delim)s)|^)\W*%(quote)s[^%(delim)s\n]*%(quote)" +
r"s[^%(delim)s\n]*%(quote)s\W*((%(delim)s)|$)") % \
{'delim':re.escape(delim), 'quote':quotechar}, re.MULTILINE)
if dq_regexp.search(data):
doublequote = True
else:
doublequote = False
return (quotechar, doublequote, delim, skipinitialspace)
csv.Sniffer._guess_quote_and_delimiter = _guess_quote_and_delimiter

@ -0,0 +1,155 @@
"""
depend.py provides classes and functions to manage the dependency graph for grist formulas.
Conceptually, all dependency relationships are the Edges (Node1, Relation, Node2), meaning that
Node1 depends on Node2. Each Node represents a column in a particular table (could be a derived
table, such as for subtotals). The Relation determines the row mapping, i.e. which rows in Node1
column need to be recomputed when a row changes in Node2 column.
When a formula is evaluated, the Record and RecordSet objects maintain a reference to the Relation
in use, while property access determines which Nodes (or columns) depend on one another.
"""
# Note: this is partly inspired by the implementation of the ninja build system, see
# https://github.com/martine/ninja/blob/master/src/graph.h
# Idea for the future: we can consider the concept from ninja of "order-only deps", which are
# needed before we can build the outputs, but which don't cause the outputs to rebuild. Support
# for this (with computed values properly persisted) could allow some cool use cases, like columns
# that recompute manually rather than automatically.
from collections import namedtuple
from sortedcontainers import SortedSet
class Node(namedtuple('Node', ('table_id', 'col_id'))):
"""
Each Node in the dependency graph represents a column in a table.
"""
__slots__ = () # This is a memory-saving device to keep these objects small
def __str__(self):
return '[%s.%s]' % (self.table_id, self.col_id)
class Edge(namedtuple('Edge', ('out_node', 'in_node', 'relation'))):
"""
Each Edge connects two Nodes using a Relation. It says that out_node depends on in_node, so that
a change to in_node should trigger a recomputation of out_node.
"""
__slots__ = () # This is a memory-saving device to keep these objects small
def __str__(self):
return '[%s.%s: %s.%s @ %s]' % (self.out_node.table_id, self.out_node.col_id,
self.in_node.table_id, self.in_node.col_id, self.relation)
class CircularRefError(RuntimeError):
"""
Exception thrown when a formula column references itself, directly or indirectly.
"""
pass
class _AllRows(object):
"""
Special constant that indicates to `invalidate_deps` that all rows are affected and an entire
column is to be invalidated.
"""
pass
ALL_ROWS = _AllRows()
class Graph(object):
"""
Represents the dependency graph for all data in a grist document.
"""
def __init__(self):
# The set of all Edges, i.e. the complete dependency graph.
self._all_edges = set()
# Map from node to the set of edges having it as the in_node (i.e. edges to dependents).
self._in_node_map = {}
# Map from node to the set of edges having it as the out_node (i.e. edges to dependencies).
self._out_node_map = {}
def dump_graph(self):
"""
Print out the graph to stdout, for debugging.
"""
print "Dependency graph (%d edges):" % len(self._all_edges)
for edge in sorted(self._all_edges):
print " %s" % (edge,)
def add_edge(self, out_node, in_node, relation):
"""
Adds an edge to the global dependency graph: out_node depends on in_node, i.e. a change to
in_node should trigger a recomputation of out_node.
"""
edge = Edge(out_node, in_node, relation)
self._all_edges.add(edge)
self._in_node_map.setdefault(edge.in_node, set()).add(edge)
self._out_node_map.setdefault(edge.out_node, set()).add(edge)
def clear_dependencies(self, out_node):
"""
Removes all edges which affect the given out_node, i.e. all of its dependencies.
"""
remove_edges = self._out_node_map.pop(out_node, ())
for edge in remove_edges:
self._all_edges.remove(edge)
self._in_node_map.get(edge.in_node, set()).remove(edge)
edge.relation.reset_all()
def reset_dependencies(self, node, dirty_rows):
"""
For edges the given node depends on, reset the given output rows. This is called just before
the rows get recomputed, to allow the relations to clear out state for those rows if needed.
"""
in_edges = self._out_node_map.get(node, ())
for edge in in_edges:
edge.relation.reset_rows(dirty_rows)
def remove_node_if_unused(self, node):
"""
Removes the given node if it has no dependents. Returns True if the node is gone, False if the
node has dependents.
"""
if self._in_node_map.get(node, None):
return False
self.clear_dependencies(node)
self._in_node_map.pop(node, None)
return True
def invalidate_deps(self, dirty_node, dirty_rows, recompute_map, include_self=True):
"""
Invalidates the given rows in the given node, and all of its dependents, i.e. all the nodes
that recursively depend on dirty_node. If include_self is False, then skips the given node
(e.g. if the node is raw data rather than formula). Results are added to recompute_map, which
is a dict mapping Nodes to sets of rows that need to be recomputed.
If dirty_rows is ALL_ROWS, the whole column is affected, and dependencies get recomputed from
scratch. ALL_ROWS propagates to all dependent columns, so those also get recomputed in full.
"""
if include_self:
if recompute_map.get(dirty_node) == ALL_ROWS:
return
if dirty_rows == ALL_ROWS:
recompute_map[dirty_node] = ALL_ROWS
# If all rows are being recomputed, clear the dependencies of the affected column. (We add
# dependencies in the course of recomputing, but we can only start from an empty set of
# dependencies if we are about to recompute all rows.)
self.clear_dependencies(dirty_node)
else:
out_rows = recompute_map.setdefault(dirty_node, SortedSet())
prev_count = len(out_rows)
out_rows.update(dirty_rows)
# Don't bother recursing into dependencies if we didn't actually update anything.
if len(out_rows) <= prev_count:
return
# Iterate through a copy of _in_node_map, because recursive clear_dependencies may modify it.
for edge in list(self._in_node_map.get(dirty_node, ())):
affected_rows = (ALL_ROWS if dirty_rows == ALL_ROWS else
edge.relation.get_affected_rows(dirty_rows))
self.invalidate_deps(edge.out_node, affected_rows, recompute_map, include_self=True)

@ -0,0 +1,240 @@
import actions
import schema
import logger
from usertypes import strict_equal
log = logger.Logger(__name__, logger.INFO)
class DocActions(object):
def __init__(self, engine):
self._engine = engine
#----------------------------------------
# Actions on records.
#----------------------------------------
def AddRecord(self, table_id, row_id, column_values):
self.BulkAddRecord(
table_id, [row_id], {key: [val] for key, val in column_values.iteritems()})
def BulkAddRecord(self, table_id, row_ids, column_values):
table = self._engine.tables[table_id]
for row_id in row_ids:
assert row_id not in table.row_ids, \
"docactions.[Bulk]AddRecord for existing record #%s" % row_id
self._engine.out_actions.undo.append(actions.BulkRemoveRecord(table_id, row_ids).simplify())
self._engine.add_records(table_id, row_ids, column_values)
def RemoveRecord(self, table_id, row_id):
return self.BulkRemoveRecord(table_id, [row_id])
def BulkRemoveRecord(self, table_id, row_ids):
table = self._engine.tables[table_id]
# Collect the undo values, and unset all values in the column (i.e. set to defaults), just to
# make sure we don't have stale values hanging around.
undo_values = {}
for column in table.all_columns.itervalues():
if not column.is_formula() and column.col_id != "id":
col_values = map(column.raw_get, row_ids)
default = column.getdefault()
# If this column had all default values, don't include it into the undo BulkAddRecord.
if not all(strict_equal(val, default) for val in col_values):
undo_values[column.col_id] = col_values
for row_id in row_ids:
column.unset(row_id)
# Generate the undo action.
self._engine.out_actions.undo.append(
actions.BulkAddRecord(table_id, row_ids, undo_values).simplify())
# Invalidate the deleted rows, so that anything that depends on them gets recomputed.
self._engine.invalidate_records(table_id, row_ids)
def UpdateRecord(self, table_id, row_id, columns):
self.BulkUpdateRecord(
table_id, [row_id], {key: [val] for key, val in columns.iteritems()})
def BulkUpdateRecord(self, table_id, row_ids, columns):
table = self._engine.tables[table_id]
for row_id in row_ids:
assert row_id in table.row_ids, \
"docactions.[Bulk]UpdateRecord for non-existent record #%s" % row_id
# Load the updated values.
undo_values = {}
for col_id, values in columns.iteritems():
col = table.get_column(col_id)
undo_values[col_id] = map(col.raw_get, row_ids)
for (row_id, value) in zip(row_ids, values):
col.set(row_id, value)
# Generate the undo action.
self._engine.out_actions.undo.append(
actions.BulkUpdateRecord(table_id, row_ids, undo_values).simplify())
# Invalidate the updated rows, just for the columns that got changed (and, as always,
# anything that depends on them).
self._engine.invalidate_records(table_id, row_ids, col_ids=columns.keys())
def ReplaceTableData(self, table_id, row_ids, column_values):
old_data = self._engine.fetch_table(table_id, formulas=False)
self._engine.out_actions.undo.append(actions.ReplaceTableData(*old_data))
self._engine.load_table(actions.TableData(table_id, row_ids, column_values))
#----------------------------------------
# Actions on columns.
#----------------------------------------
def AddColumn(self, table_id, col_id, col_info):
table = self._engine.tables[table_id]
assert not table.has_column(col_id), "Column %s already exists in %s" % (col_id, table_id)
# Add the new column to the schema object maintained in the engine.
self._engine.schema[table_id].columns[col_id] = schema.dict_to_col(col_info, col_id=col_id)
self._engine.rebuild_usercode()
self._engine.new_column_name(table)
# Generate the undo action.
self._engine.out_actions.undo.append(actions.RemoveColumn(table_id, col_id))
def RemoveColumn(self, table_id, col_id):
table = self._engine.tables[table_id]
assert table.has_column(col_id), "Column %s not in table %s" % (col_id, table_id)
# Generate (if needed) the undo action to restore the data.
undo_action = None
column = table.get_column(col_id)
if not column.is_formula():
default = column.getdefault()
# Add to undo a BulkUpdateRecord for non-default values in the column being removed.
row_ids = [r for r in table.row_ids if not strict_equal(column.raw_get(r), default)]
undo_action = actions.BulkUpdateRecord(table_id, row_ids, {
column.col_id: map(column.raw_get, row_ids)
}).simplify()
# Remove the specified column from the schema object.
colinfo = self._engine.schema[table_id].columns.pop(col_id)
self._engine.rebuild_usercode()
# Generate the undo action(s).
if undo_action:
self._engine.out_actions.undo.append(undo_action)
self._engine.out_actions.undo.append(actions.AddColumn(
table_id, col_id, schema.col_to_dict(colinfo, include_id=False)))
def RenameColumn(self, table_id, old_col_id, new_col_id):
table = self._engine.tables[table_id]
assert table.has_column(old_col_id), "Column %s not in table %s" % (old_col_id, table_id)
assert not table.has_column(new_col_id), \
"Column %s already exists in %s" % (new_col_id, table_id)
old_column = table.get_column(old_col_id)
# Replace the renamed column in the schema object.
schema_table_info = self._engine.schema[table_id]
colinfo = schema_table_info.columns.pop(old_col_id)
schema_table_info.columns[new_col_id] = schema.SchemaColumn(
new_col_id, colinfo.type, colinfo.isFormula, colinfo.formula)
self._engine.rebuild_usercode()
self._engine.new_column_name(table)
# We replaced the old column with a new Column object (not strictly necessary, but simpler).
# For a raw data column, we need to copy over the data from the old column object.
new_column = table.get_column(new_col_id)
new_column.copy_from_column(old_column)
# Generate the undo action.
self._engine.out_actions.undo.append(actions.RenameColumn(table_id, new_col_id, old_col_id))
def ModifyColumn(self, table_id, col_id, col_info):
table = self._engine.tables[table_id]
assert table.has_column(col_id), "Column %s not in table %s" % (col_id, table_id)
old_column = table.get_column(col_id)
# Modify the specified column in the schema object.
schema_table_info = self._engine.schema[table_id]
old = schema_table_info.columns[col_id]
new = schema.SchemaColumn(col_id,
col_info.get('type', old.type),
bool(col_info.get('isFormula', old.isFormula)),
col_info.get('formula', old.formula))
if new == old:
log.info("ModifyColumn called which was a noop")
return
undo_col_info = {k: v for k, v in schema.col_to_dict(old, include_id=False).iteritems()
if k in col_info}
# Remove the column from the schema, then re-add it, to force creation of a new column object.
schema_table_info.columns.pop(col_id)
self._engine.rebuild_usercode()
schema_table_info.columns[col_id] = new
self._engine.rebuild_usercode()
# Fill in the new column with the values from the old column.
new_column = table.get_column(col_id)
for row_id in table.row_ids:
new_column.set(row_id, old_column.raw_get(row_id))
# Generate the undo action.
self._engine.out_actions.undo.append(actions.ModifyColumn(table_id, col_id, undo_col_info))
#----------------------------------------
# Actions on tables.
#----------------------------------------
def AddTable(self, table_id, columns):
assert table_id not in self._engine.tables, "Table %s already exists" % table_id
# Update schema, and re-generate the module code.
self._engine.schema[table_id] = schema.SchemaTable(table_id, schema.dict_list_to_cols(columns))
self._engine.rebuild_usercode()
# Generate the undo action.
self._engine.out_actions.undo.append(actions.RemoveTable(table_id))
def RemoveTable(self, table_id):
assert table_id in self._engine.tables, "Table %s doesn't exist" % table_id
# Create undo actions to restore all the data records of this table.
table_data = self._engine.fetch_table(table_id, formulas=False)
undo_action = actions.BulkAddRecord(*table_data).simplify()
if undo_action:
self._engine.out_actions.undo.append(undo_action)
# Update schema, and re-generate the module code.
schema_table = self._engine.schema.pop(table_id)
self._engine.rebuild_usercode()
# Generate the undo action.
self._engine.out_actions.undo.append(actions.AddTable(
table_id, schema.cols_to_dict_list(schema_table.columns)))
def RenameTable(self, old_table_id, new_table_id):
assert old_table_id in self._engine.tables, "Table %s doesn't exist" % old_table_id
assert new_table_id not in self._engine.tables, "Table %s already exists" % new_table_id
old_table = self._engine.tables[old_table_id]
# Update schema, and re-generate the module code.
old = self._engine.schema.pop(old_table_id)
self._engine.schema[new_table_id] = schema.SchemaTable(new_table_id, old.columns)
self._engine.rebuild_usercode()
# Copy over all columns from the old table to the new.
new_table = self._engine.tables[new_table_id]
for new_column in new_table.all_columns.itervalues():
if not new_column.is_formula():
new_column.copy_from_column(old_table.get_column(new_column.col_id))
new_table.grow_to_max() # We need to bring formula columns to the right size too.
# Generate the undo action.
self._engine.out_actions.undo.append(actions.RenameTable(new_table_id, old_table_id))
# end

@ -0,0 +1,352 @@
"""
This file provides convenient access to document metadata that is internal to the sandbox.
Specifically, it has handles to the metadata tables, and adds helpful formula columns to tables
which exist only in the sandbox and are not communicated to the client.
It is similar in purpose to DocModel.js on the client side.
"""
import itertools
import json
import acl
import records
import usertypes
import relabeling
import table
import moment
def _record_set(table_id, group_by, sort_by=None):
@usertypes.formulaType(usertypes.ReferenceList(table_id))
def func(rec, table):
lookup_table = table.docmodel.get_table(table_id)
return lookup_table.lookupRecords(sort_by=sort_by, **{group_by: rec.id})
return func
def _record_inverse(table_id, ref_col):
@usertypes.formulaType(usertypes.Reference(table_id))
def func(rec, table):
lookup_table = table.docmodel.get_table(table_id)
return lookup_table.lookupOne(**{ref_col: rec.id})
return func
class MetaTableExtras(object):
"""
Container class for enhancements to metadata table models. The members (formula methods) defined
for a nested class here will automatically be added as members to same-named metadata table.
"""
# pylint: disable=no-self-argument,no-member,unused-argument,not-an-iterable
class _grist_DocInfo(object):
def acl_resources(rec, table):
"""
Returns a map of ACL resources for use by acl.py. It is done in a formula so that it
automatically recomputes when anything changes in _grist_ACLResources table.
"""
# pylint: disable=no-self-use
return acl.build_resources(table.docmodel.get_table('_grist_ACLResources').lookupRecords())
@usertypes.formulaType(usertypes.Any())
def tzinfo(rec, table):
# pylint: disable=no-self-use
try:
return moment.tzinfo(rec.timezone)
except KeyError:
return moment.TZ_UTC
class _grist_Tables(object):
columns = _record_set('_grist_Tables_column', 'parentId', sort_by='parentPos')
viewSections = _record_set('_grist_Views_section', 'tableRef')
tableViews = _record_set('_grist_TableViews', 'tableRef')
summaryTables = _record_set('_grist_Tables', 'summarySourceTable')
def summaryKey(rec, table):
"""
Returns the tuple of sorted colRefs for summary columns. This uniquely identifies a summary
table among other summary tables for the same source table.
"""
# pylint: disable=not-an-iterable
return (tuple(sorted(int(c.summarySourceCol) for c in rec.columns if c.summarySourceCol))
if rec.summarySourceTable else None)
def setAutoRemove(rec, table):
"""Marks the table for removal if it's a summary table with no more view sections."""
table.docmodel.setAutoRemove(rec, rec.summarySourceTable and not rec.viewSections)
class _grist_Tables_column(object):
viewFields = _record_set('_grist_Views_section_field', 'colRef')
summaryGroupByColumns = _record_set('_grist_Tables_column', 'summarySourceCol')
usedByCols = _record_set('_grist_Tables_column', 'displayCol')
usedByFields = _record_set('_grist_Views_section_field', 'displayCol')
def tableId(rec, table):
return rec.parentId.tableId
def numDisplayColUsers(rec, table):
"""
Returns the number of cols and fields using this col as a display col
"""
return len(rec.usedByCols) + len(rec.usedByFields)
def setAutoRemove(rec, table):
"""Marks the col for removal if it's a display helper col with no more users."""
table.docmodel.setAutoRemove(rec,
rec.colId.startswith('gristHelper_Display') and rec.numDisplayColUsers == 0)
class _grist_Views(object):
viewSections = _record_set('_grist_Views_section', 'parentId')
tabBarItems = _record_set('_grist_TabBar', 'viewRef')
tableViewItems = _record_set('_grist_TableViews', 'viewRef')
primaryViewTable = _record_inverse('_grist_Tables', 'primaryViewId')
pageItems = _record_set('_grist_Pages', 'viewRef')
class _grist_Views_section(object):
fields = _record_set('_grist_Views_section_field', 'parentId', sort_by='parentPos')
class _grist_ACLRules(object):
# The set of rules that applies to this resource
@usertypes.formulaType(usertypes.ReferenceList('_grist_ACLPrincipals'))
def principalsList(rec, table):
return json.loads(rec.principals)
class _grist_ACLResources(object):
# The set of rules that applies to this resource
ruleset = _record_set('_grist_ACLRules', 'resource')
class _grist_ACLPrincipals(object):
# Memberships table maintains containment relationships between principals.
memberships = _record_set('_grist_ACLMemberships', 'parent')
# Children of a User principal are Instances. Children of a Group are Users or other Groups.
@usertypes.formulaType(usertypes.ReferenceList('_grist_ACLPrincipals'))
def children(rec, table):
return [m.child for m in rec.memberships]
@usertypes.formulaType(usertypes.ReferenceList('_grist_ACLPrincipals'))
def descendants(rec, table):
"""
Descendants through great-grandchildren. (We don't support fully recursive descendants yet,
which may be cleaner.) The max supported level is a group containing subgroups (children),
which contain users (grandchildren), which contain instances (great-grandchildren).
"""
# Include direct children.
ret = set(rec.children)
ret.add(rec)
for c1 in rec.children:
# Include grandchildren (children of each child)
ret.update(c1.children)
for c2 in c1.children:
# Include great-grandchildren (children of each grandchild).
ret.update(c2.children)
return ret
@usertypes.formulaType(usertypes.ReferenceList('_grist_ACLPrincipals'))
def allInstances(rec, table):
return sorted(r for r in rec.descendants if r.instanceId)
@usertypes.formulaType(usertypes.Text())
def name(rec, table):
return ('User:' + rec.userEmail if rec.type == 'user' else
'Group:' + rec.groupName if rec.type == 'group' else
'Inst:' + rec.instanceId if rec.type == 'instance' else '')
def enhance_model(model_class):
"""
Given a metadata model class, add all members (formula methods) to it from the same-named inner
class of MetaTableExtras. The added members are marked as private; the resulting Column objects
will have col.is_private() as true.
"""
extras_class = getattr(MetaTableExtras, model_class.__name__, None)
if not extras_class:
return
for name, member in extras_class.__dict__.iteritems():
if not name.startswith("__"):
member.__name__ = name
member.is_private = True
setattr(model_class, name, member)
# There is a single instance of DocModel per sandbox process and
# global_docmodel is a reference to it
global_docmodel = None
class DocModel(object):
"""
This class defines more convenient handles to all metadata tables. In addition, it sets
table.docmodel member for each of these tables to itself. Note that it deals with
table.UserTable objects (rather than the lower-level table.Table objects).
"""
def __init__(self, engine):
self._engine = engine
global global_docmodel # pylint: disable=global-statement
if not global_docmodel:
global_docmodel = self
# Set of records scheduled for automatic removal.
self._auto_remove_set = set()
def update_tables(self):
"""
Update the table handles we maintain to correspond to the current Engine tables.
"""
self.doc_info = self._prep_table("_grist_DocInfo")
self.tables = self._prep_table("_grist_Tables")
self.columns = self._prep_table("_grist_Tables_column")
self.table_views = self._prep_table("_grist_TableViews")
self.tab_bar = self._prep_table("_grist_TabBar")
self.views = self._prep_table("_grist_Views")
self.view_sections = self._prep_table("_grist_Views_section")
self.view_fields = self._prep_table("_grist_Views_section_field")
self.validations = self._prep_table("_grist_Validations")
self.repl_hist = self._prep_table("_grist_REPL_Hist")
self.attachments = self._prep_table("_grist_Attachments")
self.acl_rules = self._prep_table("_grist_ACLRules")
self.acl_resources = self._prep_table("_grist_ACLResources")
self.acl_principals = self._prep_table("_grist_ACLPrincipals")
self.acl_memberships = self._prep_table("_grist_ACLMemberships")
self.pages = self._prep_table("_grist_Pages")
def _prep_table(self, name):
"""
Helper that gets the table with the given name, and sets its .doc attribute to DocModel.
"""
user_table = self._engine.tables[name].user_table
user_table.docmodel = self
return user_table
def get_table(self, table_id):
return self._engine.tables[table_id].user_table
def get_table_rec(self, table_id):
"""Returns the table record for the given table name, or raises ValueError."""
table_rec = self.tables.lookupOne(tableId=table_id)
if not table_rec:
raise ValueError("No such table: %s" % table_id)
return table_rec
def get_column_rec(self, table_id, col_id):
"""Returns the column record for the given table and column names, or raises ValueError."""
col_rec = self.columns.lookupOne(tableId=table_id, colId=col_id)
if not col_rec:
raise ValueError("No such column: %s.%s" % (table_id, col_id))
return col_rec
def setAutoRemove(self, record, yes_or_no):
"""
Marks a record for automatic removal. To use, create a formula in your table, e.g.
'setAutoRemove', which calls `table.docmodel.setAutoRemove(boolean_value)`. Whenever it gets
reevaluated and the boolean_value is true, the record will be automatically removed.
For now, it is only usable in metadata tables, although we could extend to user tables.
"""
if yes_or_no:
self._auto_remove_set.add(record)
else:
self._auto_remove_set.discard(record)
def apply_auto_removes(self):
"""
Remove the records marked for removal.
"""
# Sort to make sure removals are done in deterministic order.
gone_records = sorted(self._auto_remove_set)
self._auto_remove_set.clear()
self.remove(gone_records)
return bool(gone_records)
def remove(self, records):
"""
Removes all records in the given iterable of Records.
"""
for table_id, group in itertools.groupby(records, lambda r: r._table.table_id):
self._engine.user_actions.BulkRemoveRecord(table_id, [int(r) for r in group])
def update(self, records, **col_values):
"""
Updates all records in the given list of Records or a RecordSet; col_values maps column ids to
values. The values may either be a list of the length len(records), or a non-list value that
will be used for all records.
"""
record_list = list(records)
if not record_list:
return
table_id = record_list[0]._table.table_id
# Make sure these are all records from the same table.
assert all(r._table.table_id == table_id for r in record_list)
row_ids = [int(r) for r in record_list]
values = _unify_col_values(col_values, len(record_list))
self._engine.user_actions.BulkUpdateRecord(table_id, row_ids, values)
def add(self, record_set_or_table, **col_values):
"""
Add new records for the given table; col_values maps column ids to values. Values may either
be lists (all of the same length), or non-list values that will be used for all added records.
Either a UserTable or a RecordSet may used as the first argument. If it is a RecordSet created
with lookupRecords, it may set additional col_values.
Returns a list of inserted records.
"""
assert isinstance(record_set_or_table, (records.RecordSet, table.UserTable))
count = _get_col_values_count(col_values)
values = _unify_col_values(col_values, count)
if isinstance(record_set_or_table, records.RecordSet):
table_obj = record_set_or_table._table
group_by = record_set_or_table._group_by
if group_by:
values.update((k, [v] * count) for k, v in group_by.iteritems() if k not in values)
else:
table_obj = record_set_or_table.table
row_ids = self._engine.user_actions.BulkAddRecord(table_obj.table_id, [None] * count, values)
return [table_obj.Record(table_obj, r, None) for r in row_ids]
def insert(self, record_set, position, **col_values):
"""
Add new records using col_values, inserting them into record_set according to position.
This may only be used when record_set is sorted by a field of type PositionNumber; in
particular it must be the result of lookupRecords() with 'sort_by' parameter.
Position may be numeric (to compare to other sort_by values), or None to insert at the end.
Returns a list of inserted records.
"""
assert isinstance(record_set, records.RecordSet), \
"docmodel.insert() may only be used on a RecordSet, not %s" % type(record_set)
sort_by = getattr(record_set, '_sort_by', None)
assert sort_by, \
"docmodel.insert() may only be used on a sorted RecordSet"
column = record_set._table.get_column(sort_by)
assert isinstance(column.type_obj, usertypes.PositionNumber), \
"docmodel.insert() may only be used on a RecordSet sorted by PositionNumber type column"
col_values[sort_by] = float('inf') if position is None else position
return self.add(record_set, **col_values)
def insert_after(self, record_set, position, **col_values):
"""
Same as insert, but when position is equal to the position of an existing record, inserts
after that record; and when position is None, inserts at the beginning.
"""
# We can reuse insert() by just using the next float for position. As long as positions of
# existing records are different, that would necessarily place the new records correctly.
pos = float('-inf') if position is None else relabeling.nextfloat(position)
return self.insert(record_set, pos, **col_values)
def _unify_col_values(col_values, count):
"""
Helper that converts a dict mapping keys to values or lists of values to all lists. Non-list
values get turned into lists by repeating them count times.
"""
assert all(len(v) == count for v in col_values.itervalues() if isinstance(v, list))
return {k: (v if isinstance(v, list) else [v] * count)
for k, v in col_values.iteritems()}
def _get_col_values_count(col_values):
"""
Helper that returns the length of the first list in among the values of col_values. If none of
the values is a list, returns 1.
"""
first_list = next((v for v in col_values.itervalues() if isinstance(v, list)), None)
return len(first_list) if first_list is not None else 1

File diff suppressed because it is too large Load Diff

@ -0,0 +1,12 @@
# pylint: disable=wildcard-import
from date import *
from info import *
from logical import *
from lookup import *
from math import *
from stats import *
from text import *
from schedule import *
# Export all uppercase names, for use with `from functions import *`.
__all__ = [k for k in dir() if not k.startswith('_') and k.isupper()]

@ -0,0 +1,773 @@
import calendar
import datetime
import dateutil.parser
import moment
import docmodel
# pylint: disable=no-member
_excel_date_zero = datetime.datetime(1899, 12, 30)
def _make_datetime(value):
if isinstance(value, datetime.datetime):
return value
elif isinstance(value, datetime.date):
return datetime.datetime.combine(value, datetime.time())
elif isinstance(value, datetime.time):
return datetime.datetime.combine(datetime.date.today(), value)
elif isinstance(value, basestring):
return dateutil.parser.parse(value)
else:
raise ValueError('Invalid date %r' % (value,))
def _get_global_tz():
if docmodel.global_docmodel:
return docmodel.global_docmodel.doc_info.lookupOne(id=1).tzinfo
return moment.TZ_UTC
def _get_tzinfo(zonelabel):
"""
A helper that returns a `datetime.tzinfo` instance for zonelabel. Returns the global
document timezone if zonelabel is None.
"""
return moment.tzinfo(zonelabel) if zonelabel else _get_global_tz()
def DTIME(value, tz=None):
"""
Returns the value converted to a python `datetime` object. The value may be a
`string`, `date` (interpreted as midnight on that day), `time` (interpreted as a
time-of-day today), or an existing `datetime`.
The returned `datetime` will have its timezone set to the `tz` argument, or the
document's default timezone when `tz` is omitted or None. If the input is itself a
`datetime` with the timezone set, it is returned unchanged (no changes to its timezone).
>>> DTIME(datetime.date(2017, 1, 1))
datetime.datetime(2017, 1, 1, 0, 0, tzinfo=moment.tzinfo('America/New_York'))
>>> DTIME(datetime.date(2017, 1, 1), 'Europe/Paris')
datetime.datetime(2017, 1, 1, 0, 0, tzinfo=moment.tzinfo('Europe/Paris'))
>>> DTIME(datetime.datetime(2017, 1, 1))
datetime.datetime(2017, 1, 1, 0, 0, tzinfo=moment.tzinfo('America/New_York'))
>>> DTIME(datetime.datetime(2017, 1, 1, tzinfo=moment.tzinfo('UTC')))
datetime.datetime(2017, 1, 1, 0, 0, tzinfo=moment.tzinfo('UTC'))
>>> DTIME(datetime.datetime(2017, 1, 1, tzinfo=moment.tzinfo('UTC')), 'Europe/Paris')
datetime.datetime(2017, 1, 1, 0, 0, tzinfo=moment.tzinfo('UTC'))
>>> DTIME("1/1/2008")
datetime.datetime(2008, 1, 1, 0, 0, tzinfo=moment.tzinfo('America/New_York'))
"""
value = _make_datetime(value)
return value if value.tzinfo else value.replace(tzinfo=_get_tzinfo(tz))
def XL_TO_DATE(value, tz=None):
"""
Converts a provided Excel serial number representing a date into a `datetime` object.
Value is interpreted as the number of days since December 30, 1899.
(This corresponds to Google Sheets interpretation. Excel starts with Dec. 31, 1899 but wrongly
considers 1900 to be a leap year. Excel for Mac should be configured to use 1900 date system,
i.e. uncheck "Use the 1904 date system" option.)
The returned `datetime` will have its timezone set to the `tz` argument, or the
document's default timezone when `tz` is omitted or None.
>>> XL_TO_DATE(41100.1875)
datetime.datetime(2012, 7, 10, 4, 30, tzinfo=moment.tzinfo('America/New_York'))
>>> XL_TO_DATE(39448)
datetime.datetime(2008, 1, 1, 0, 0, tzinfo=moment.tzinfo('America/New_York'))
>>> XL_TO_DATE(40982.0625)
datetime.datetime(2012, 3, 14, 1, 30, tzinfo=moment.tzinfo('America/New_York'))
More tests:
>>> XL_TO_DATE(0)
datetime.datetime(1899, 12, 30, 0, 0, tzinfo=moment.tzinfo('America/New_York'))
>>> XL_TO_DATE(-1)
datetime.datetime(1899, 12, 29, 0, 0, tzinfo=moment.tzinfo('America/New_York'))
>>> XL_TO_DATE(1)
datetime.datetime(1899, 12, 31, 0, 0, tzinfo=moment.tzinfo('America/New_York'))
>>> XL_TO_DATE(1.5)
datetime.datetime(1899, 12, 31, 12, 0, tzinfo=moment.tzinfo('America/New_York'))
>>> XL_TO_DATE(61.0)
datetime.datetime(1900, 3, 1, 0, 0, tzinfo=moment.tzinfo('America/New_York'))
"""
return DTIME(_excel_date_zero, tz) + datetime.timedelta(days=value)
def DATE_TO_XL(date_value):
"""
Converts a Python `date` or `datetime` object to the serial number as used by
Excel, with December 30, 1899 as serial number 1.
See XL_TO_DATE for more explanation.
>>> DATE_TO_XL(datetime.date(2008, 1, 1))
39448.0
>>> DATE_TO_XL(datetime.date(2012, 3, 14))
40982.0
>>> DATE_TO_XL(datetime.datetime(2012, 3, 14, 1, 30))
40982.0625
More tests:
>>> DATE_TO_XL(datetime.date(1900, 1, 1))
2.0
>>> DATE_TO_XL(datetime.datetime(1900, 1, 1))
2.0
>>> DATE_TO_XL(datetime.datetime(1900, 1, 1, 12, 0))
2.5
>>> DATE_TO_XL(datetime.datetime(1900, 1, 1, 12, 0, tzinfo=moment.tzinfo('America/New_York')))
2.5
>>> DATE_TO_XL(datetime.date(1900, 3, 1))
61.0
>>> DATE_TO_XL(datetime.datetime(2008, 1, 1))
39448.0
>>> DATE_TO_XL(XL_TO_DATE(39488))
39488.0
>>> dt_ny = XL_TO_DATE(39488)
>>> dt_paris = moment.tz(dt_ny, 'America/New_York').tz('Europe/Paris').datetime()
>>> DATE_TO_XL(dt_paris)
39488.0
"""
# If date_value is `naive` it's ok to pass tz to both DTIME as it won't affect the
# result.
return (DTIME(date_value) - DTIME(_excel_date_zero)).total_seconds() / 86400.
def DATE(year, month, day):
"""
Returns the `datetime.datetime` object that represents a particular date.
The DATE function is most useful in formulas where year, month, and day are formulas, not
constants.
If year is between 0 and 1899 (inclusive), adds 1900 to calculate the year.
>>> DATE(108, 1, 2)
datetime.date(2008, 1, 2)
>>> DATE(2008, 1, 2)
datetime.date(2008, 1, 2)
If month is greater than 12, rolls into the following year.
>>> DATE(2008, 14, 2)
datetime.date(2009, 2, 2)
If month is less than 1, subtracts that many months plus 1, from the first month in the year.
>>> DATE(2008, -3, 2)
datetime.date(2007, 9, 2)
If day is greater than the number of days in the given month, rolls into the following months.
>>> DATE(2008, 1, 35)
datetime.date(2008, 2, 4)
If day is less than 1, subtracts that many days plus 1, from the first day of the given month.
>>> DATE(2008, 1, -15)
datetime.date(2007, 12, 16)
More tests:
>>> DATE(1900, 1, 1)
datetime.date(1900, 1, 1)
>>> DATE(1900, 0, 0)
datetime.date(1899, 11, 30)
"""
if year < 1900:
year += 1900
norm_month = (month - 1) % 12 + 1
norm_year = year + (month - 1) // 12
return datetime.date(norm_year, norm_month, 1) + datetime.timedelta(days=day - 1)
def DATEDIF(start_date, end_date, unit):
"""
Calculates the number of days, months, or years between two dates.
Unit indicates the type of information that you want returned:
- "Y": The number of complete years in the period.
- "M": The number of complete months in the period.
- "D": The number of days in the period.
- "MD": The difference between the days in start_date and end_date. The months and years of the
dates are ignored.
- "YM": The difference between the months in start_date and end_date. The days and years of the
dates are ignored.
- "YD": The difference between the days of start_date and end_date. The years of the dates are
ignored.
Two complete years in the period (2)
>>> DATEDIF(DATE(2001, 1, 1), DATE(2003, 1, 1), "Y")
2
440 days between June 1, 2001, and August 15, 2002 (440)
>>> DATEDIF(DATE(2001, 6, 1), DATE(2002, 8, 15), "D")
440
75 days between June 1 and August 15, ignoring the years of the dates (75)
>>> DATEDIF(DATE(2001, 6, 1), DATE(2012, 8, 15), "YD")
75
The difference between 1 and 15, ignoring the months and the years of the dates (14)
>>> DATEDIF(DATE(2001, 6, 1), DATE(2002, 8, 15), "MD")
14
More tests:
>>> DATEDIF(DATE(1969, 7, 16), DATE(1969, 7, 24), "D")
8
>>> DATEDIF(DATE(2014, 1, 1), DATE(2015, 1, 1), "M")
12
>>> DATEDIF(DATE(2014, 1, 2), DATE(2015, 1, 1), "M")
11
>>> DATEDIF(DATE(2014, 1, 1), DATE(2024, 1, 1), "Y")
10
>>> DATEDIF(DATE(2014, 1, 2), DATE(2024, 1, 1), "Y")
9
>>> DATEDIF(DATE(1906, 10, 16), DATE(2004, 2, 3), "YM")
3
>>> DATEDIF(DATE(2016, 2, 14), DATE(2016, 3, 14), "YM")
1
>>> DATEDIF(DATE(2016, 2, 14), DATE(2016, 3, 13), "YM")
0
>>> DATEDIF(DATE(2008, 10, 16), DATE(2019, 12, 3), "MD")
17
>>> DATEDIF(DATE(2008, 11, 16), DATE(2019, 1, 3), "MD")
18
>>> DATEDIF(DATE(2016, 2, 29), DATE(2017, 2, 28), "Y")
0
>>> DATEDIF(DATE(2016, 2, 29), DATE(2017, 2, 29), "Y")
1
"""
if isinstance(start_date, datetime.datetime):
start_date = start_date.date()
if isinstance(end_date, datetime.datetime):
end_date = end_date.date()
if unit == 'D':
return (end_date - start_date).days
elif unit == 'M':
months = (end_date.year - start_date.year) * 12 + (end_date.month - start_date.month)
month_delta = 0 if start_date.day <= end_date.day else 1
return months - month_delta
elif unit == 'Y':
years = end_date.year - start_date.year
year_delta = 0 if (start_date.month, start_date.day) <= (end_date.month, end_date.day) else 1
return years - year_delta
elif unit == 'MD':
month_delta = 0 if start_date.day <= end_date.day else 1
return (end_date - DATE(end_date.year, end_date.month - month_delta, start_date.day)).days
elif unit == 'YM':
month_delta = 0 if start_date.day <= end_date.day else 1
return (end_date.month - start_date.month - month_delta) % 12
elif unit == 'YD':
year_delta = 0 if (start_date.month, start_date.day) <= (end_date.month, end_date.day) else 1
return (end_date - DATE(end_date.year - year_delta, start_date.month, start_date.day)).days
else:
raise ValueError('Invalid unit %s' % (unit,))
def DATEVALUE(date_string, tz=None):
"""
Converts a date that is stored as text to a `datetime` object.
>>> DATEVALUE("1/1/2008")
datetime.datetime(2008, 1, 1, 0, 0, tzinfo=moment.tzinfo('America/New_York'))
>>> DATEVALUE("30-Jan-2008")
datetime.datetime(2008, 1, 30, 0, 0, tzinfo=moment.tzinfo('America/New_York'))
>>> DATEVALUE("2008-12-11")
datetime.datetime(2008, 12, 11, 0, 0, tzinfo=moment.tzinfo('America/New_York'))
>>> DATEVALUE("5-JUL").replace(year=2000)
datetime.datetime(2000, 7, 5, 0, 0, tzinfo=moment.tzinfo('America/New_York'))
In case of ambiguity, prefer M/D/Y format.
>>> DATEVALUE("1/2/3")
datetime.datetime(2003, 1, 2, 0, 0, tzinfo=moment.tzinfo('America/New_York'))
More tests:
>>> DATEVALUE("8/22/2011")
datetime.datetime(2011, 8, 22, 0, 0, tzinfo=moment.tzinfo('America/New_York'))
>>> DATEVALUE("22-MAY-2011")
datetime.datetime(2011, 5, 22, 0, 0, tzinfo=moment.tzinfo('America/New_York'))
>>> DATEVALUE("2011/02/23")
datetime.datetime(2011, 2, 23, 0, 0, tzinfo=moment.tzinfo('America/New_York'))
>>> DATEVALUE("11/3/2011")
datetime.datetime(2011, 11, 3, 0, 0, tzinfo=moment.tzinfo('America/New_York'))
>>> DATE_TO_XL(DATEVALUE("11/3/2011"))
40850.0
>>> DATEVALUE("asdf")
Traceback (most recent call last):
...
ValueError: Unknown string format
"""
return dateutil.parser.parse(date_string).replace(tzinfo=_get_tzinfo(tz))
def DAY(date):
"""
Returns the day of a date, as an integer ranging from 1 to 31. Same as `date.day`.
>>> DAY(DATE(2011, 4, 15))
15
>>> DAY("5/31/2012")
31
>>> DAY(datetime.datetime(1900, 1, 1))
1
"""
return _make_datetime(date).day
def DAYS(end_date, start_date):
"""
Returns the number of days between two dates. Same as `(end_date - start_date).days`.
>>> DAYS("3/15/11","2/1/11")
42
>>> DAYS(DATE(2011, 12, 31), DATE(2011, 1, 1))
364
>>> DAYS("2/1/11", "3/15/11")
-42
"""
return (_make_datetime(end_date) - _make_datetime(start_date)).days
def EDATE(start_date, months):
"""
Returns the date that is the given number of months before or after `start_date`. Use
EDATE to calculate maturity dates or due dates that fall on the same day of the month as the
date of issue.
>>> EDATE(DATE(2011, 1, 15), 1)
datetime.date(2011, 2, 15)
>>> EDATE(DATE(2011, 1, 15), -1)
datetime.date(2010, 12, 15)
>>> EDATE(DATE(2011, 1, 15), 2)
datetime.date(2011, 3, 15)
>>> EDATE(DATE(2012, 3, 1), 10)
datetime.date(2013, 1, 1)
>>> EDATE(DATE(2012, 5, 1), -2)
datetime.date(2012, 3, 1)
"""
return DATE(start_date.year, start_date.month + months, start_date.day)
def DATEADD(start_date, days=0, months=0, years=0, weeks=0):
"""
Returns the date a given number of days, months, years, or weeks away from `start_date`. You may
specify arguments in any order if you specify argument names. Use negative values to subtract.
For example, `DATEADD(date, 1)` is the same as `DATEADD(date, days=1)`, ands adds one day to
`date`. `DATEADD(date, years=1, days=-1)` adds one year minus one day.
>>> DATEADD(DATE(2011, 1, 15), 1)
datetime.date(2011, 1, 16)
>>> DATEADD(DATE(2011, 1, 15), months=1, days=-1)
datetime.date(2011, 2, 14)
>>> DATEADD(DATE(2011, 1, 15), years=-2, months=1, days=3, weeks=2)
datetime.date(2009, 3, 4)
>>> DATEADD(DATE(1975, 4, 30), years=50, weeks=-5)
datetime.date(2025, 3, 26)
"""
return DATE(start_date.year + years, start_date.month + months,
start_date.day + days + weeks * 7)
def EOMONTH(start_date, months):
"""
Returns the date for the last day of the month that is the indicated number of months before or
after start_date. Use EOMONTH to calculate maturity dates or due dates that fall on the last day
of the month.
>>> EOMONTH(DATE(2011, 1, 1), 1)
datetime.date(2011, 2, 28)
>>> EOMONTH(DATE(2011, 1, 15), -3)
datetime.date(2010, 10, 31)
>>> EOMONTH(DATE(2012, 3, 1), 10)
datetime.date(2013, 1, 31)
>>> EOMONTH(DATE(2012, 5, 1), -2)
datetime.date(2012, 3, 31)
"""
return DATE(start_date.year, start_date.month + months + 1, 1) - datetime.timedelta(days=1)
def HOUR(time):
"""
Returns the hour of a `datetime`, as an integer from 0 (12:00 A.M.) to 23 (11:00 P.M.).
Same as `time.hour`.
>>> HOUR(XL_TO_DATE(0.75))
18
>>> HOUR("7/18/2011 7:45")
7
>>> HOUR("4/21/2012")
0
"""
return _make_datetime(time).hour
def ISOWEEKNUM(date):
"""
Returns the ISO week number of the year for a given date.
>>> ISOWEEKNUM("3/9/2012")
10
>>> [ISOWEEKNUM(DATE(2000 + y, 1, 1)) for y in [0,1,2,3,4,5,6,7,8]]
[52, 1, 1, 1, 1, 53, 52, 1, 1]
"""
return _make_datetime(date).isocalendar()[1]
def MINUTE(time):
"""
Returns the minutes of `datetime`, as an integer from 0 to 59.
Same as `time.minute`.
>>> MINUTE(XL_TO_DATE(0.75))
0
>>> MINUTE("7/18/2011 7:45")
45
>>> MINUTE("12:59:00 PM")
59
>>> MINUTE(datetime.time(12, 58, 59))
58
"""
return _make_datetime(time).minute
def MONTH(date):
"""
Returns the month of a date represented, as an integer from from 1 (January) to 12 (December).
Same as `date.month`.
>>> MONTH(DATE(2011, 4, 15))
4
>>> MONTH("5/31/2012")
5
>>> MONTH(datetime.datetime(1900, 1, 1))
1
"""
return _make_datetime(date).month
def NOW(tz=None):
"""
Returns the `datetime` object for the current time.
"""
return datetime.datetime.now(_get_tzinfo(tz))
def SECOND(time):
"""
Returns the seconds of `datetime`, as an integer from 0 to 59.
Same as `time.second`.
>>> SECOND(XL_TO_DATE(0.75))
0
>>> SECOND("7/18/2011 7:45:13")
13
>>> SECOND(datetime.time(12, 58, 59))
59
"""
return _make_datetime(time).second
def TODAY():
"""
Returns the `date` object for the current date.
"""
return datetime.date.today()
_weekday_type_map = {
# type: (first day of week (according to date.weekday()), number to return for it)
1: (6, 1),
2: (0, 1),
3: (0, 0),
11: (0, 1),
12: (1, 1),
13: (2, 1),
14: (3, 1),
15: (4, 1),
16: (5, 1),
17: (6, 1),
}
def WEEKDAY(date, return_type=1):
"""
Returns the day of the week corresponding to a date. The day is given as an integer, ranging
from 1 (Sunday) to 7 (Saturday), by default.
Return_type determines the type of the returned value.
- 1 (default) - Returns 1 (Sunday) through 7 (Saturday).
- 2 - Returns 1 (Monday) through 7 (Sunday).
- 3 - Returns 0 (Monday) through 6 (Sunday).
- 11 - Returns 1 (Monday) through 7 (Sunday).
- 12 - Returns 1 (Tuesday) through 7 (Monday).
- 13 - Returns 1 (Wednesday) through 7 (Tuesday).
- 14 - Returns 1 (Thursday) through 7 (Wednesday).
- 15 - Returns 1 (Friday) through 7 (Thursday).
- 16 - Returns 1 (Saturday) through 7 (Friday).
- 17 - Returns 1 (Sunday) through 7 (Saturday).
>>> WEEKDAY(DATE(2008, 2, 14))
5
>>> WEEKDAY(DATE(2012, 3, 1))
5
>>> WEEKDAY(DATE(2012, 3, 1), 1)
5
>>> WEEKDAY(DATE(2012, 3, 1), 2)
4
>>> WEEKDAY("3/1/2012", 3)
3
More tests:
>>> WEEKDAY(XL_TO_DATE(10000), 1)
4
>>> WEEKDAY(DATE(1901, 1, 1))
3
>>> WEEKDAY(DATE(1901, 1, 1), 2)
2
>>> [WEEKDAY(DATE(2008, 2, d)) for d in [10, 11, 12, 13, 14, 15, 16, 17]]
[1, 2, 3, 4, 5, 6, 7, 1]
>>> [WEEKDAY(DATE(2008, 2, d), 1) for d in [10, 11, 12, 13, 14, 15, 16, 17]]
[1, 2, 3, 4, 5, 6, 7, 1]
>>> [WEEKDAY(DATE(2008, 2, d), 17) for d in [10, 11, 12, 13, 14, 15, 16, 17]]
[1, 2, 3, 4, 5, 6, 7, 1]
>>> [WEEKDAY(DATE(2008, 2, d), 2) for d in [10, 11, 12, 13, 14, 15, 16, 17]]
[7, 1, 2, 3, 4, 5, 6, 7]
>>> [WEEKDAY(DATE(2008, 2, d), 3) for d in [10, 11, 12, 13, 14, 15, 16, 17]]
[6, 0, 1, 2, 3, 4, 5, 6]
"""
if return_type not in _weekday_type_map:
raise ValueError("Invalid return type %s" % (return_type,))
(first, index) = _weekday_type_map[return_type]
return (_make_datetime(date).weekday() - first) % 7 + index
def WEEKNUM(date, return_type=1):
"""
Returns the week number of a specific date. For example, the week containing January 1 is the
first week of the year, and is numbered week 1.
Return_type determines which week is considered the first week of the year.
- 1 (default) - Week 1 is the first week starting Sunday that contains January 1.
- 2 - Week 1 is the first week starting Monday that contains January 1.
- 11 - Week 1 is the first week starting Monday that contains January 1.
- 12 - Week 1 is the first week starting Tuesday that contains January 1.
- 13 - Week 1 is the first week starting Wednesday that contains January 1.
- 14 - Week 1 is the first week starting Thursday that contains January 1.
- 15 - Week 1 is the first week starting Friday that contains January 1.
- 16 - Week 1 is the first week starting Saturday that contains January 1.
- 17 - Week 1 is the first week starting Sunday that contains January 1.
- 21 - ISO 8601 Approach: Week 1 is the first week starting Monday that contains January 4.
Equivalently, it is the week that contains the first Thursday of the year.
>>> WEEKNUM(DATE(2012, 3, 9))
10
>>> WEEKNUM(DATE(2012, 3, 9), 2)
11
>>> WEEKNUM('1/1/1900')
1
>>> WEEKNUM('2/1/1900')
5
More tests:
>>> WEEKNUM('2/1/1909', 2)
6
>>> WEEKNUM('1/1/1901', 21)
1
>>> [WEEKNUM(DATE(2012, 3, 9), t) for t in [1,2,11,12,13,14,15,16,17,21]]
[10, 11, 11, 11, 11, 11, 11, 10, 10, 10]
"""
if return_type == 21:
return ISOWEEKNUM(date)
if return_type not in _weekday_type_map:
raise ValueError("Invalid return type %s" % (return_type,))
(first, index) = _weekday_type_map[return_type]
date = _make_datetime(date)
jan1 = datetime.datetime(date.year, 1, 1)
week1_start = jan1 - datetime.timedelta(days=(jan1.weekday() - first) % 7)
return (date - week1_start).days // 7 + 1
def YEAR(date):
"""
Returns the year corresponding to a date as an integer.
Same as `date.year`.
>>> YEAR(DATE(2011, 4, 15))
2011
>>> YEAR("5/31/2030")
2030
>>> YEAR(datetime.datetime(1900, 1, 1))
1900
"""
return _make_datetime(date).year
def _date_360(y, m, d):
return y * 360 + m * 30 + d
def _last_of_feb(date):
return date.month == 2 and (date + datetime.timedelta(days=1)).month == 3
def YEARFRAC(start_date, end_date, basis=0):
"""
Calculates the fraction of the year represented by the number of whole days between two dates.
Basis is the type of day count basis to use.
* `0` (default) - US (NASD) 30/360
* `1` - Actual/actual
* `2` - Actual/360
* `3` - Actual/365
* `4` - European 30/360
* `-1` - Actual/actual (Google Sheets variation)
This function is useful for financial calculations. For compatibility with Excel, it defaults to
using the NASD standard calendar. For use in non-financial settings, option `-1` is
likely the best choice.
See <https://en.wikipedia.org/wiki/360-day_calendar> for explanation of
the US 30/360 and European 30/360 methods. See <http://www.dwheeler.com/yearfrac/> for analysis of
Excel's particular implementation.
Basis `-1` is similar to `1`, but differs from Excel when dates span both leap and non-leap years.
It matches the calculation in Google Sheets, counting the days in each year as a fraction of
that year's length.
Fraction of the year between 1/1/2012 and 7/30/12, omitting the Basis argument.
>>> "%.8f" % YEARFRAC(DATE(2012, 1, 1), DATE(2012, 7, 30))
'0.58055556'
Fraction between same dates, using the Actual/Actual basis argument. Because 2012 is a Leap
year, it has a 366 day basis.
>>> "%.8f" % YEARFRAC(DATE(2012, 1, 1), DATE(2012, 7, 30), 1)
'0.57650273'
Fraction between same dates, using the Actual/365 basis argument. Uses a 365 day basis.
>>> "%.8f" % YEARFRAC(DATE(2012, 1, 1), DATE(2012, 7, 30), 3)
'0.57808219'
More tests:
>>> round(YEARFRAC(DATE(2012, 1, 1), DATE(2012, 6, 30)), 10)
0.4972222222
>>> round(YEARFRAC(DATE(2012, 1, 1), DATE(2012, 6, 30), 0), 10)
0.4972222222
>>> round(YEARFRAC(DATE(2012, 1, 1), DATE(2012, 6, 30), 1), 10)
0.4945355191
>>> round(YEARFRAC(DATE(2012, 1, 1), DATE(2012, 6, 30), 2), 10)
0.5027777778
>>> round(YEARFRAC(DATE(2012, 1, 1), DATE(2012, 6, 30), 3), 10)
0.495890411
>>> round(YEARFRAC(DATE(2012, 1, 1), DATE(2012, 6, 30), 4), 10)
0.4972222222
>>> [YEARFRAC(DATE(2012, 1, 1), DATE(2012, 1, 1), t) for t in [0, 1, -1, 2, 3, 4]]
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
>>> [round(YEARFRAC(DATE(1985, 3, 15), DATE(2016, 2, 29), t), 6) for t in [0, 1, -1, 2, 3, 4]]
[30.955556, 30.959617, 30.961202, 31.411111, 30.980822, 30.955556]
>>> [round(YEARFRAC(DATE(2001, 2, 28), DATE(2016, 3, 31), t), 6) for t in [0, 1, -1, 2, 3, 4]]
[15.086111, 15.085558, 15.086998, 15.305556, 15.09589, 15.088889]
>>> [round(YEARFRAC(DATE(1968, 4, 7), DATE(2011, 2, 14), t), 6) for t in [0, 1, -1, 2, 3, 4]]
[42.852778, 42.855578, 42.855521, 43.480556, 42.884932, 42.852778]
Here we test "basis 1" on leap and non-leap years.
>>> [round(YEARFRAC(DATE(2015, 1, 1), DATE(2015, 3, 1), t), 6) for t in [1, -1]]
[0.161644, 0.161644]
>>> [round(YEARFRAC(DATE(2016, 1, 1), DATE(2016, 3, 1), t), 6) for t in [1, -1]]
[0.163934, 0.163934]
>>> [round(YEARFRAC(DATE(2015, 1, 1), DATE(2016, 1, 1), t), 6) for t in [1, -1]]
[1.0, 1.0]
>>> [round(YEARFRAC(DATE(2016, 1, 1), DATE(2017, 1, 1), t), 6) for t in [1, -1]]
[1.0, 1.0]
>>> [round(YEARFRAC(DATE(2016, 2, 29), DATE(2017, 1, 1), t), 6) for t in [1, -1]]
[0.838798, 0.838798]
>>> [round(YEARFRAC(DATE(2014, 12, 15), DATE(2015, 3, 15), t), 6) for t in [1, -1]]
[0.246575, 0.246575]
For these examples, Google Sheets differs from Excel, and we match Excel here.
>>> [round(YEARFRAC(DATE(2015, 12, 15), DATE(2016, 3, 15), t), 6) for t in [1, -1]]
[0.248634, 0.248761]
>>> [round(YEARFRAC(DATE(2015, 1, 1), DATE(2016, 2, 29), t), 6) for t in [1, -1]]
[1.160055, 1.161202]
>>> [round(YEARFRAC(DATE(2015, 1, 1), DATE(2016, 2, 28), t), 6) for t in [1, -1]]
[1.157319, 1.15847]
>>> [round(YEARFRAC(DATE(2015, 3, 1), DATE(2016, 2, 29), t), 6) for t in [1, -1]]
[0.997268, 0.999558]
>>> [round(YEARFRAC(DATE(2015, 3, 1), DATE(2016, 2, 28), t), 6) for t in [1, -1]]
[0.99726, 0.996826]
>>> [round(YEARFRAC(DATE(2016, 3, 1), DATE(2017, 1, 1), t), 6) for t in [1, -1]]
[0.838356, 0.836066]
>>> [round(YEARFRAC(DATE(2015, 1, 1), DATE(2017, 1, 1), t), 6) for t in [1, -1]]
[2.000912, 2.0]
"""
# pylint: disable=too-many-return-statements
# This function is actually completely crazy. The rules are strange too. We'll follow the logic
# in http://www.dwheeler.com/yearfrac/excel-ooxml-yearfrac.pdf
if start_date == end_date:
return 0.0
if start_date > end_date:
start_date, end_date = end_date, start_date
d1, m1, y1 = start_date.day, start_date.month, start_date.year
d2, m2, y2 = end_date.day, end_date.month, end_date.year
if basis == 0:
if d1 == 31:
d1 = 30
if d1 == 30 and d2 == 31:
d2 = 30
if _last_of_feb(start_date):
d1 = 30
if _last_of_feb(end_date):
d2 = 30
return (_date_360(y2, m2, d2) - _date_360(y1, m1, d1)) / 360.0
elif basis == 1:
# This implements Excel's convoluted logic.
if (y1 + 1, m1, d1) >= (y2, m2, d2):
# Less than or equal to one year.
if y1 == y2 and calendar.isleap(y1):
year_length = 366.0
elif (y1, m1, d1) < (y2, 2, 29) <= (y2, m2, d2) and calendar.isleap(y2):
year_length = 366.0
elif (y1, m1, d1) <= (y1, 2, 29) < (y2, m2, d2) and calendar.isleap(y1):
year_length = 366.0
else:
year_length = 365.0
else:
year_length = (datetime.date(y2 + 1, 1, 1) - datetime.date(y1, 1, 1)).days / (y2 + 1.0 - y1)
return (end_date - start_date).days / year_length
elif basis == -1:
# This is Google Sheets implementation. Call it an overkill, but I think it's more sensible.
#
# Excel's logic has the unfortunate property that YEARFRAC(a, b) + YEARFRAC(b, c) is not
# always equal to YEARFRAC(a, c). Google Sheets implements a variation that does have this
# property, counting the days in each year as a fraction of that year's length (as if each day
# is counted as 1/365 or 1/366 depending on the year).
#
# The one redeeming quality of Excel's logic is that YEARFRAC for two days that differ by
# exactly one year is 1.0 (not always true for GS). But in GS version, YEARFRAC between any
# two Jan 1 is always a whole number (not always true in Excel).
if y1 == y2:
return _one_year_frac(start_date, end_date)
return (
+ _one_year_frac(start_date, datetime.date(y1 + 1, 1, 1))
+ (y2 - y1 - 1)
+ _one_year_frac(datetime.date(y2, 1, 1), end_date)
)
elif basis == 2:
return (end_date - start_date).days / 360.0
elif basis == 3:
return (end_date - start_date).days / 365.0
elif basis == 4:
if d1 == 31:
d1 = 30
if d2 == 31:
d2 = 30
return (_date_360(y2, m2, d2) - _date_360(y1, m1, d1)) / 360.0
raise ValueError('Invalid basis argument %r' % (basis,))
def _one_year_frac(start_date, end_date):
year_length = 366.0 if calendar.isleap(start_date.year) else 365.0
return (end_date - start_date).days / year_length

@ -0,0 +1,520 @@
# -*- coding: UTF-8 -*-
# pylint: disable=unused-argument
from __future__ import absolute_import
import datetime
import math
import numbers
import re
from functions import date # pylint: disable=import-error
from usertypes import AltText # pylint: disable=import-error
from records import Record # pylint: disable=import-error
def ISBLANK(value):
"""
Returns whether a value refers to an empty cell. It isn't implemented in Grist. To check for an
empty string, use `value == ""`.
"""
raise NotImplementedError()
def ISERR(value):
"""
Checks whether a value is an error. In other words, it returns true
if using `value` directly would raise an exception.
NOTE: Grist implements this by automatically wrapping the argument to use lazy evaluation.
A more Pythonic approach to checking for errors is:
```
try:
... value ...
except Exception, err:
... do something about the error ...
```
For example:
>>> ISERR("Hello")
False
More tests:
>>> ISERR(lambda: (1/0.1))
False
>>> ISERR(lambda: (1/0.0))
True
>>> ISERR(lambda: "test".bar())
True
>>> ISERR(lambda: "test".upper())
False
>>> ISERR(lambda: AltText("A"))
False
>>> ISERR(lambda: float('nan'))
False
>>> ISERR(lambda: None)
False
"""
return lazy_value_or_error(value) is _error_sentinel
def ISERROR(value):
"""
Checks whether a value is an error or an invalid value. It is similar to `ISERR`, but also
returns true for an invalid value such as NaN or a text value in a Numeric column.
NOTE: Grist implements this by automatically wrapping the argument to use lazy evaluation.
>>> ISERROR("Hello")
False
>>> ISERROR(AltText("fail"))
True
>>> ISERROR(float('nan'))
True
More tests:
>>> ISERROR(AltText(""))
True
>>> [ISERROR(v) for v in [0, None, "", "Test", 17.0]]
[False, False, False, False, False]
>>> ISERROR(lambda: (1/0.1))
False
>>> ISERROR(lambda: (1/0.0))
True
>>> ISERROR(lambda: "test".bar())
True
>>> ISERROR(lambda: "test".upper())
False
>>> ISERROR(lambda: AltText("A"))
True
>>> ISERROR(lambda: float('nan'))
True
>>> ISERROR(lambda: None)
False
"""
return is_error(lazy_value_or_error(value))
def ISLOGICAL(value):
"""
Checks whether a value is `True` or `False`.
>>> ISLOGICAL(True)
True
>>> ISLOGICAL(False)
True
>>> ISLOGICAL(0)
False
>>> ISLOGICAL(None)
False
>>> ISLOGICAL("Test")
False
"""
return isinstance(value, bool)
def ISNA(value):
"""
Checks whether a value is the error `#N/A`.
>>> ISNA(float('nan'))
True
>>> ISNA(0.0)
False
>>> ISNA('text')
False
>>> ISNA(float('-inf'))
False
"""
return isinstance(value, float) and math.isnan(value)
def ISNONTEXT(value):
"""
Checks whether a value is non-textual.
>>> ISNONTEXT("asdf")
False
>>> ISNONTEXT("")
False
>>> ISNONTEXT(AltText("text"))
False
>>> ISNONTEXT(17.0)
True
>>> ISNONTEXT(None)
True
>>> ISNONTEXT(datetime.date(2011, 1, 1))
True
"""
return not ISTEXT(value)
def ISNUMBER(value):
"""
Checks whether a value is a number.
>>> ISNUMBER(17)
True
>>> ISNUMBER(-123.123423)
True
>>> ISNUMBER(False)
True
>>> ISNUMBER(float('nan'))
True
>>> ISNUMBER(float('inf'))
True
>>> ISNUMBER('17')
False
>>> ISNUMBER(None)
False
>>> ISNUMBER(datetime.date(2011, 1, 1))
False
More tests:
>>> ISNUMBER(AltText("text"))
False
>>> ISNUMBER('')
False
"""
return isinstance(value, numbers.Number)
def ISREF(value):
"""
Checks whether a value is a table record.
For example, if a column person is of type Reference to the People table, then ISREF($person)
is True.
Similarly, ISREF(People.lookupOne(name=$name)) is True. For any other type of value,
ISREF() would evaluate to False.
>>> ISREF(17)
False
>>> ISREF("Roger")
False
"""
return isinstance(value, Record)
def ISTEXT(value):
"""
Checks whether a value is text.
>>> ISTEXT("asdf")
True
>>> ISTEXT("")
True
>>> ISTEXT(AltText("text"))
True
>>> ISTEXT(17.0)
False
>>> ISTEXT(None)
False
>>> ISTEXT(datetime.date(2011, 1, 1))
False
"""
return isinstance(value, (basestring, AltText))
# Regexp for matching email. See ISEMAIL for justification.
_email_regexp = re.compile(
r"""
^\w # Start with an alphanumeric character
[\w%+/='-]* (\.[\w%+/='-]+)* # Elsewhere allow also a few other special characters
# But no two consecutive periods
@
([A-Za-z0-9] # Each part of hostname must start with alphanumeric
([A-Za-z0-9-]*[A-Za-z0-9])?\. # May have dashes inside, but end in alphanumeric
)+
[A-Za-z]{2,6}$ # Restrict top-level domain to length {2,6}. Google seems
# to use a whitelist for TLDs longer than 2 characters.
""", re.UNICODE | re.VERBOSE)
# Regexp for matching hostname part of URLs (see also ISURL). Duplicates part of _email_regexp.
_hostname_regexp = re.compile(
r"""^
([A-Za-z0-9] # Each part of hostname must start with alphanumeric
([A-Za-z0-9-]*[A-Za-z0-9])?\. # May have dashes inside, but end in alphanumeric
)+
[A-Za-z]{2,6}$ # Restrict top-level domain to length {2,6}. Google seems
""", re.VERBOSE)
def ISEMAIL(value):
u"""
Returns whether a value is a valid email address.
Note that checking email validity is not an exact science. The technical standard considers many
email addresses valid that are not used in practice, and would not be considered valid by most
users. Instead, we follow Google Sheets implementation, with some differences, noted below.
>>> ISEMAIL("Abc.123@example.com")
True
>>> ISEMAIL("Bob_O-Reilly+tag@example.com")
True
>>> ISEMAIL("John Doe")
False
>>> ISEMAIL("john@aol...com")
False
More tests:
>>> ISEMAIL("Abc@example.com") # True, True
True
>>> ISEMAIL("Abc.123@example.com") # True, True
True
>>> ISEMAIL("foo@bar.com") # True, True
True
>>> ISEMAIL("asdf@com.zt") # True, True
True
>>> ISEMAIL("Bob_O-Reilly+tag@example.com") # True, True
True
>>> ISEMAIL("john@server.department.company.com") # True, True
True
>>> ISEMAIL("asdf@mail.ru") # True, True
True
>>> ISEMAIL("fabio@foo.qwer.COM") # True, True
True
>>> ISEMAIL("user+mailbox/department=shipping@example.com") # False, True
True
>>> ISEMAIL(u"user+mailbox/department=shipping@example.com") # False, True
True
>>> ISEMAIL("customer/department=shipping@example.com") # False, True
True
>>> ISEMAIL("Bob_O'Reilly+tag@example.com") # False, True
True
>>> ISEMAIL(u"фыва@mail.ru") # False, True
True
>>> ISEMAIL("my@baddash.-.com") # True, False
False
>>> ISEMAIL("my@baddash.-a.com") # True, False
False
>>> ISEMAIL("my@baddash.b-.com") # True, False
False
>>> ISEMAIL("john@-.com") # True, False
False
>>> ISEMAIL("fabio@disapproved.solutions") # False, False
False
>>> ISEMAIL("!def!xyz%abc@example.com") # False, False
False
>>> ISEMAIL("!#$%&'*+-/=?^_`.{|}~@example.com") # False, False
False
>>> ISEMAIL(u"伊昭傑@郵件.商務") # False, False
False
>>> ISEMAIL(u"राम@मोहन.ईन्फो") # False, Fale
False
>>> ISEMAIL(u"юзер@екзампл.ком") # False, False
False
>>> ISEMAIL(u"θσερ@εχαμπλε.ψομ") # False, False
False
>>> ISEMAIL(u"葉士豪@臺網中心.tw") # False, False
False
>>> ISEMAIL(u"jeff@臺網中心.tw") # False, False
False
>>> ISEMAIL(u"葉士豪@臺網中心.台灣") # False, False
False
>>> ISEMAIL(u"jeff葉@臺網中心.tw") # False, False
False
>>> ISEMAIL("myname@domain.com") # False, False
False
>>> ISEMAIL("my.name@domaincom") # False, False
False
>>> ISEMAIL("my@.leadingdot.com") # False, False
False
>>> ISEMAIL("my@leadingfwdot.com") # False, False
False
>>> ISEMAIL("my@..twodots.com") # False, False
False
>>> ISEMAIL("my@twodots..com") # False, False
False
>>> ISEMAIL(".leadingdot@domain.com") # False, False
False
>>> ISEMAIL("..twodots@domain.com") # False, False
False
>>> ISEMAIL("twodots..here@domain.com") # False, False
False
>>> ISEMAIL("me@⒈wouldbeinvalid.com") # False, False
False
>>> ISEMAIL("Foo Bar <a+2asdf@qwer.bar.com>") # False, False
False
>>> ISEMAIL("Abc\\@def@example.com") # False, False
False
>>> ISEMAIL("foo@bar@google.com") # False, False
False
>>> ISEMAIL("john@aol...com") # False, False
False
>>> ISEMAIL("x@ทีเอชนิค.ไทย") # False, False
False
>>> ISEMAIL("asdf@mail") # False, False
False
>>> ISEMAIL("example@良好Mail.中国") # False, False
False
"""
return bool(_email_regexp.match(value))
_url_regexp = re.compile(
r"""^
((ftp|http|https|gopher|mailto|news|telnet|aim)://)?
(\w+@)? # Allow 'user@' part, esp. useful for mailto: URLs.
([A-Za-z0-9] # Each part of hostname must start with alphanumeric
([A-Za-z0-9-]*[A-Za-z0-9])?\. # May have dashes inside, but end in alphanumeric
)+
[A-Za-z]{2,6} # Restrict top-level domain to length {2,6}. Google seems
# to use a whitelist for TLDs longer than 2 characters.
([/?][-\w!#$%&'()*+,./:;=?@~]*)?$ # Notably, this excludes <, >, and ".
""", re.VERBOSE)
def ISURL(value):
"""
Checks whether a value is a valid URL. It does not need to be fully qualified, or to include
"http://" and "www". It does not follow a standard, but attempts to work similarly to ISURL in
Google Sheets, and to return True for text that is likely a URL.
Valid protocols include ftp, http, https, gopher, mailto, news, telnet, and aim.
>>> ISURL("http://www.getgrist.com")
True
>>> ISURL("https://foo.com/test_(wikipedia)#cite-1")
True
>>> ISURL("mailto://user@example.com")
True
>>> ISURL("http:///a")
False
More tests:
>>> ISURL("http://www.google.com")
True
>>> ISURL("www.google.com/")
True
>>> ISURL("google.com")
True
>>> ISURL("http://a.b-c.de")
True
>>> ISURL("a.b-c.de")
True
>>> ISURL("http://j.mp/---")
True
>>> ISURL("ftp://foo.bar/baz")
True
>>> ISURL("https://foo.com/blah_(wikipedia)#cite-1")
True
>>> ISURL("mailto://user@google.com")
True
>>> ISURL("http://user@www.google.com")
True
>>> ISURL("http://foo.com/!#$%25&'()*+,-./=?@_~")
True
>>> ISURL("http://../")
False
>>> ISURL("http://??/")
False
>>> ISURL("a.-b.cd")
False
>>> ISURL("http://foo.bar?q=Spaces should be encoded ")
False
>>> ISURL("//")
False
>>> ISURL("///a")
False
>>> ISURL("http:///a")
False
>>> ISURL("bar://www.google.com")
False
>>> ISURL("http:// shouldfail.com")
False
>>> ISURL("ftps://foo.bar/")
False
>>> ISURL("http://-error-.invalid/")
False
>>> ISURL("http://0.0.0.0")
False
>>> ISURL("http://.www.foo.bar/")
False
>>> ISURL("http://.www.foo.bar./")
False
>>> ISURL("example.com/file[/].html")
False
>>> ISURL("http://example.com/file[/].html")
False
>>> ISURL("http://mw1.google.com/kml-samples/gp/seattle/gigapxl/$[level]/r$[y]_c$[x].jpg")
False
>>> ISURL("http://foo.com/>")
False
"""
value = value.strip()
if ' ' in value: # Disallow spaces inside value.
return False
return bool(_url_regexp.match(value))
def N(value):
"""
Returns the value converted to a number. True/False are converted to 1/0. A date is converted to
Excel-style serial number of the date. Anything else is converted to 0.
>>> N(7)
7
>>> N(7.1)
7.1
>>> N("Even")
0
>>> N("7")
0
>>> N(True)
1
>>> N(datetime.datetime(2011, 4, 17))
40650.0
"""
if ISNUMBER(value):
return value
if isinstance(value, datetime.date):
return date.DATE_TO_XL(value)
return 0
def NA():
"""
Returns the "value not available" error, `#N/A`.
>>> math.isnan(NA())
True
"""
return float('nan')
def TYPE(value):
"""
Returns a number associated with the type of data passed into the function. This is not
implemented in Grist. Use `isinstance(value, type)` or `type(value)`.
"""
raise NotImplementedError()
def CELL(info_type, reference):
"""
Returns the requested information about the specified cell. This is not implemented in Grist
"""
raise NotImplementedError()
# Unique sentinel value to represent that a lazy value evaluates with an exception.
_error_sentinel = object()
def lazy_value_or_error(value):
"""
Evaluates a value like lazy_value(), but returns _error_sentinel on exception.
"""
try:
return value() if callable(value) else value
except Exception:
return _error_sentinel
def is_error(value):
"""
Checks whether a value is an invalid value or _error_sentinel.
"""
return ((value is _error_sentinel)
or isinstance(value, AltText)
or (isinstance(value, float) and math.isnan(value)))

@ -0,0 +1,165 @@
from info import lazy_value_or_error, is_error
from usertypes import AltText # pylint: disable=unused-import,import-error
def AND(logical_expression, *logical_expressions):
"""
Returns True if all of the arguments are logically true, and False if any are false.
Same as `all([value1, value2, ...])`.
>>> AND(1)
True
>>> AND(0)
False
>>> AND(1, 1)
True
>>> AND(1,2,3,4)
True
>>> AND(1,2,3,4,0)
False
"""
return all((logical_expression,) + logical_expressions)
def FALSE():
"""
Returns the logical value `False`. You may also use the value `False` directly. This
function is provided primarily for compatibility with other spreadsheet programs.
>>> FALSE()
False
"""
return False
def IF(logical_expression, value_if_true, value_if_false):
"""
Returns one value if a logical expression is `True` and another if it is `False`.
The equivalent Python expression is:
```
value_if_true if logical_expression else value_if_false
```
Since Grist supports multi-line formulas, you may also use Python blocks such as:
```
if logical_expression:
return value_if_true
else:
return value_if_false
```
NOTE: Grist follows Excel model by only evaluating one of the value expressions, by
automatically wrapping the expressions to use lazy evaluation. This allows `IF(False, 1/0, 1)`
to evaluate to `1` rather than raise an exception.
>>> IF(12, "Yes", "No")
'Yes'
>>> IF(None, "Yes", "No")
'No'
>>> IF(True, 0.85, 0.0)
0.85
>>> IF(False, 0.85, 0.0)
0.0
More tests:
>>> IF(True, lambda: (1/0), lambda: (17))
Traceback (most recent call last):
...
ZeroDivisionError: integer division or modulo by zero
>>> IF(False, lambda: (1/0), lambda: (17))
17
"""
return lazy_value(value_if_true) if logical_expression else lazy_value(value_if_false)
def IFERROR(value, value_if_error=""):
"""
Returns the first argument if it is not an error value, otherwise returns the second argument if
present, or a blank if the second argument is absent.
NOTE: Grist handles values that raise an exception by wrapping them to use lazy evaluation.
>>> IFERROR(float('nan'), "**NAN**")
'**NAN**'
>>> IFERROR(17.17, "**NAN**")
17.17
>>> IFERROR("Text")
'Text'
>>> IFERROR(AltText("hello"))
''
More tests:
>>> IFERROR(lambda: (1/0.1), "X")
10.0
>>> IFERROR(lambda: (1/0.0), "X")
'X'
>>> IFERROR(lambda: AltText("A"), "err")
'err'
>>> IFERROR(lambda: None, "err")
>>> IFERROR(lambda: foo.bar, 123)
123
>>> IFERROR(lambda: "test".bar(), 123)
123
>>> IFERROR(lambda: "test".bar())
''
>>> IFERROR(lambda: "test".upper(), 123)
'TEST'
"""
value = lazy_value_or_error(value)
return value if not is_error(value) else value_if_error
def NOT(logical_expression):
"""
Returns the opposite of a logical value: `NOT(True)` returns `False`; `NOT(False)` returns
`True`. Same as `not logical_expression`.
>>> NOT(123)
False
>>> NOT(0)
True
"""
return not logical_expression
def OR(logical_expression, *logical_expressions):
"""
Returns True if any of the arguments is logically true, and false if all of the
arguments are false.
Same as `any([value1, value2, ...])`.
>>> OR(1)
True
>>> OR(0)
False
>>> OR(1, 1)
True
>>> OR(0, 1)
True
>>> OR(0, 0)
False
>>> OR(0,False,0.0,"",None)
False
>>> OR(0,None,3,0)
True
"""
return any((logical_expression,) + logical_expressions)
def TRUE():
"""
Returns the logical value `True`. You may also use the value `True` directly. This
function is provided primarily for compatibility with other spreadsheet programs.
>>> TRUE()
True
"""
return True
def lazy_value(value):
"""
Evaluates a lazy value by calling it when it's a callable, or returns it unchanged otherwise.
"""
return value() if callable(value) else value

@ -0,0 +1,80 @@
# pylint: disable=redefined-builtin, line-too-long
def ADDRESS(row, column, absolute_relative_mode, use_a1_notation, sheet):
"""Returns a cell reference as a string."""
raise NotImplementedError()
def CHOOSE(index, choice1, choice2):
"""Returns an element from a list of choices based on index."""
raise NotImplementedError()
def COLUMN(cell_reference=None):
"""Returns the column number of a specified cell, with `A=1`."""
raise NotImplementedError()
def COLUMNS(range):
"""Returns the number of columns in a specified array or range."""
raise NotImplementedError()
def GETPIVOTDATA(value_name, any_pivot_table_cell, original_column_1, pivot_item_1=None, *args):
"""Extracts an aggregated value from a pivot table that corresponds to the specified row and column headings."""
raise NotImplementedError()
def HLOOKUP(search_key, range, index, is_sorted):
"""Horizontal lookup. Searches across the first row of a range for a key and returns the value of a specified cell in the column found."""
raise NotImplementedError()
def HYPERLINK(url, link_label):
"""Creates a hyperlink inside a cell."""
raise NotImplementedError()
def INDEX(reference, row, column):
"""Returns the content of a cell, specified by row and column offset."""
raise NotImplementedError()
def INDIRECT(cell_reference_as_string):
"""Returns a cell reference specified by a string."""
raise NotImplementedError()
def LOOKUP(search_key, search_range_or_search_result_array, result_range=None):
"""Looks through a row or column for a key and returns the value of the cell in a result range located in the same position as the search row or column."""
raise NotImplementedError()
def MATCH(search_key, range, search_type):
"""Returns the relative position of an item in a range that matches a specified value."""
raise NotImplementedError()
def OFFSET(cell_reference, offset_rows, offset_columns, height, width):
"""Returns a range reference shifted a specified number of rows and columns from a starting cell reference."""
raise NotImplementedError()
def ROW(cell_reference):
"""Returns the row number of a specified cell."""
raise NotImplementedError()
def ROWS(range):
"""Returns the number of rows in a specified array or range."""
raise NotImplementedError()
def VLOOKUP(table, **field_value_pairs):
"""
Vertical lookup. Searches the given table for a record matching the given `field=value`
arguments. If multiple records match, returns one of them. If none match, returns the special
empty record.
The returned object is a record whose fields are available using `.field` syntax. For example,
`VLOOKUP(Employees, EmployeeID=$EmpID).Salary`.
Note that `VLOOKUP` isn't commonly needed in Grist, since [Reference columns](col-refs) are the
best way to link data between tables, and allow simple efficient usage such as `$Person.Age`.
`VLOOKUP` is exactly quivalent to `table.lookupOne(**field_value_pairs)`. See
[lookupOne](#lookupone).
For example:
```
VLOOKUP(People, First_Name="Lewis", Last_Name="Carroll")
VLOOKUP(People, First_Name="Lewis", Last_Name="Carroll").Age
```
"""
return table.lookupOne(**field_value_pairs)

@ -0,0 +1,830 @@
# pylint: disable=unused-argument
from __future__ import absolute_import
import itertools
import math as _math
import operator
import random
from functions.info import ISNUMBER, ISLOGICAL
import roman
# Iterates through elements of iterable arguments, or through individual args when not iterable.
def _chain(*values_or_iterables):
for v in values_or_iterables:
try:
for x in v:
yield x
except TypeError:
yield v
# Iterates through iterable or other arguments, skipping non-numeric ones.
def _chain_numeric(*values_or_iterables):
for v in _chain(*values_or_iterables):
if ISNUMBER(v) and not ISLOGICAL(v):
yield v
# Iterates through iterable or other arguments, replacing non-numeric ones with 0 (or True with 1).
def _chain_numeric_a(*values_or_iterables):
for v in _chain(*values_or_iterables):
yield int(v) if ISLOGICAL(v) else v if ISNUMBER(v) else 0
def _round_toward_zero(value):
return _math.floor(value) if value >= 0 else _math.ceil(value)
def _round_away_from_zero(value):
return _math.ceil(value) if value >= 0 else _math.floor(value)
def ABS(value):
"""
Returns the absolute value of a number.
>>> ABS(2)
2
>>> ABS(-2)
2
>>> ABS(-4)
4
"""
return abs(value)
def ACOS(value):
"""
Returns the inverse cosine of a value, in radians.
>>> round(ACOS(-0.5), 9)
2.094395102
>>> round(ACOS(-0.5)*180/PI(), 10)
120.0
"""
return _math.acos(value)
def ACOSH(value):
"""
Returns the inverse hyperbolic cosine of a number.
>>> ACOSH(1)
0.0
>>> round(ACOSH(10), 7)
2.9932228
"""
return _math.acosh(value)
def ARABIC(roman_numeral):
"""
Computes the value of a Roman numeral.
>>> ARABIC("LVII")
57
>>> ARABIC('mcmxii')
1912
"""
return roman.fromRoman(roman_numeral.upper())
def ASIN(value):
"""
Returns the inverse sine of a value, in radians.
>>> round(ASIN(-0.5), 9)
-0.523598776
>>> round(ASIN(-0.5)*180/PI(), 10)
-30.0
>>> round(DEGREES(ASIN(-0.5)), 10)
-30.0
"""
return _math.asin(value)
def ASINH(value):
"""
Returns the inverse hyperbolic sine of a number.
>>> round(ASINH(-2.5), 9)
-1.647231146
>>> round(ASINH(10), 9)
2.99822295
"""
return _math.asinh(value)
def ATAN(value):
"""
Returns the inverse tangent of a value, in radians.
>>> round(ATAN(1), 9)
0.785398163
>>> ATAN(1)*180/PI()
45.0
>>> DEGREES(ATAN(1))
45.0
"""
return _math.atan(value)
def ATAN2(x, y):
"""
Returns the angle between the x-axis and a line segment from the origin (0,0) to specified
coordinate pair (`x`,`y`), in radians.
>>> round(ATAN2(1, 1), 9)
0.785398163
>>> round(ATAN2(-1, -1), 9)
-2.35619449
>>> ATAN2(-1, -1)*180/PI()
-135.0
>>> DEGREES(ATAN2(-1, -1))
-135.0
>>> round(ATAN2(1,2), 9)
1.107148718
"""
return _math.atan2(y, x)
def ATANH(value):
"""
Returns the inverse hyperbolic tangent of a number.
>>> round(ATANH(0.76159416), 9)
1.00000001
>>> round(ATANH(-0.1), 9)
-0.100335348
"""
return _math.atanh(value)
def CEILING(value, factor=1):
"""
Rounds a number up to the nearest multiple of factor, or the nearest integer if the factor is
omitted or 1.
>>> CEILING(2.5, 1)
3
>>> CEILING(-2.5, -2)
-4
>>> CEILING(-2.5, 2)
-2
>>> CEILING(1.5, 0.1)
1.5
>>> CEILING(0.234, 0.01)
0.24
"""
return int(_math.ceil(float(value) / factor)) * factor
def COMBIN(n, k):
"""
Returns the number of ways to choose some number of objects from a pool of a given size of
objects.
>>> COMBIN(8,2)
28
>>> COMBIN(4,2)
6
>>> COMBIN(10,7)
120
"""
# From http://stackoverflow.com/a/4941932/328565
k = min(k, n-k)
if k == 0:
return 1
numer = reduce(operator.mul, xrange(n, n-k, -1))
denom = reduce(operator.mul, xrange(1, k+1))
return numer//denom
def COS(angle):
"""
Returns the cosine of an angle provided in radians.
>>> round(COS(1.047), 7)
0.5001711
>>> round(COS(60*PI()/180), 10)
0.5
>>> round(COS(RADIANS(60)), 10)
0.5
"""
return _math.cos(angle)
def COSH(value):
"""
Returns the hyperbolic cosine of any real number.
>>> round(COSH(4), 6)
27.308233
>>> round(COSH(EXP(1)), 7)
7.6101251
"""
return _math.cosh(value)
def DEGREES(angle):
"""
Converts an angle value in radians to degrees.
>>> round(DEGREES(ACOS(-0.5)), 10)
120.0
>>> DEGREES(PI())
180.0
"""
return _math.degrees(angle)
def EVEN(value):
"""
Rounds a number up to the nearest even integer, rounding away from zero.
>>> EVEN(1.5)
2
>>> EVEN(3)
4
>>> EVEN(2)
2
>>> EVEN(-1)
-2
"""
return int(_round_away_from_zero(float(value) / 2)) * 2
def EXP(exponent):
"""
Returns Euler's number, e (~2.718) raised to a power.
>>> round(EXP(1), 8)
2.71828183
>>> round(EXP(2), 7)
7.3890561
"""
return _math.exp(exponent)
def FACT(value):
"""
Returns the factorial of a number.
>>> FACT(5)
120
>>> FACT(1.9)
1
>>> FACT(0)
1
>>> FACT(1)
1
>>> FACT(-1)
Traceback (most recent call last):
...
ValueError: factorial() not defined for negative values
"""
return _math.factorial(int(value))
def FACTDOUBLE(value):
"""
Returns the "double factorial" of a number.
>>> FACTDOUBLE(6)
48
>>> FACTDOUBLE(7)
105
>>> FACTDOUBLE(3)
3
>>> FACTDOUBLE(4)
8
"""
return reduce(operator.mul, xrange(value, 1, -2))
def FLOOR(value, factor=1):
"""
Rounds a number down to the nearest integer multiple of specified significance.
>>> FLOOR(3.7,2)
2
>>> FLOOR(-2.5,-2)
-2
>>> FLOOR(2.5,-2)
Traceback (most recent call last):
...
ValueError: factor argument invalid
>>> FLOOR(1.58,0.1)
1.5
>>> FLOOR(0.234,0.01)
0.23
"""
if (factor < 0) != (value < 0):
raise ValueError("factor argument invalid")
return int(_math.floor(float(value) / factor)) * factor
def _gcd(a, b):
while a != 0:
if a > b:
a, b = b, a
a, b = b % a, a
return b
def GCD(value1, *more_values):
"""
Returns the greatest common divisor of one or more integers.
>>> GCD(5, 2)
1
>>> GCD(24, 36)
12
>>> GCD(7, 1)
1
>>> GCD(5, 0)
5
>>> GCD(0, 5)
5
>>> GCD(5)
5
>>> GCD(14, 42, 21)
7
"""
values = [v for v in (value1,) + more_values if v]
if not values:
return 0
if any(v < 0 for v in values):
raise ValueError("gcd requires non-negative values")
return reduce(_gcd, map(int, values))
def INT(value):
"""
Rounds a number down to the nearest integer that is less than or equal to it.
>>> INT(8.9)
8
>>> INT(-8.9)
-9
>>> 19.5-INT(19.5)
0.5
"""
return int(_math.floor(value))
def _lcm(a, b):
return a * b / _gcd(a, b)
def LCM(value1, *more_values):
"""
Returns the least common multiple of one or more integers.
>>> LCM(5, 2)
10
>>> LCM(24, 36)
72
>>> LCM(0, 5)
0
>>> LCM(5)
5
>>> LCM(10, 100)
100
>>> LCM(12, 18)
36
>>> LCM(12, 18, 24)
72
"""
values = (value1,) + more_values
if any(v < 0 for v in values):
raise ValueError("gcd requires non-negative values")
if any(v == 0 for v in values):
return 0
return reduce(_lcm, map(int, values))
def LN(value):
"""
Returns the the logarithm of a number, base e (Euler's number).
>>> round(LN(86), 7)
4.4543473
>>> round(LN(2.7182818), 7)
1.0
>>> round(LN(EXP(3)), 10)
3.0
"""
return _math.log(value)
def LOG(value, base=10):
"""
Returns the the logarithm of a number given a base.
>>> LOG(10)
1.0
>>> LOG(8, 2)
3.0
>>> round(LOG(86, 2.7182818), 7)
4.4543473
"""
return _math.log(value, base)
def LOG10(value):
"""
Returns the the logarithm of a number, base 10.
>>> round(LOG10(86), 9)
1.934498451
>>> LOG10(10)
1.0
>>> LOG10(100000)
5.0
>>> LOG10(10**5)
5.0
"""
return _math.log10(value)
def MOD(dividend, divisor):
"""
Returns the result of the modulo operator, the remainder after a division operation.
>>> MOD(3, 2)
1
>>> MOD(-3, 2)
1
>>> MOD(3, -2)
-1
>>> MOD(-3, -2)
-1
"""
return dividend % divisor
def MROUND(value, factor):
"""
Rounds one number to the nearest integer multiple of another.
>>> MROUND(10, 3)
9
>>> MROUND(-10, -3)
-9
>>> round(MROUND(1.3, 0.2), 10)
1.4
>>> MROUND(5, -2)
Traceback (most recent call last):
...
ValueError: factor argument invalid
"""
if (factor < 0) != (value < 0):
raise ValueError("factor argument invalid")
return int(_round_toward_zero(float(value) / factor + 0.5)) * factor
def MULTINOMIAL(value1, *more_values):
"""
Returns the factorial of the sum of values divided by the product of the values' factorials.
>>> MULTINOMIAL(2, 3, 4)
1260
>>> MULTINOMIAL(3)
1
>>> MULTINOMIAL(1,2,3)
60
>>> MULTINOMIAL(0,2,4,6)
13860
"""
s = value1
res = 1
for v in more_values:
s += v
res *= COMBIN(s, v)
return res
def ODD(value):
"""
Rounds a number up to the nearest odd integer.
>>> ODD(1.5)
3
>>> ODD(3)
3
>>> ODD(2)
3
>>> ODD(-1)
-1
>>> ODD(-2)
-3
"""
return int(_round_away_from_zero(float(value + 1) / 2)) * 2 - 1
def PI():
"""
Returns the value of Pi to 14 decimal places.
>>> round(PI(), 9)
3.141592654
>>> round(PI()/2, 9)
1.570796327
>>> round(PI()*9, 8)
28.27433388
"""
return _math.pi
def POWER(base, exponent):
"""
Returns a number raised to a power.
>>> POWER(5,2)
25.0
>>> round(POWER(98.6,3.2), 3)
2401077.222
>>> round(POWER(4,5.0/4), 9)
5.656854249
"""
return _math.pow(base, exponent)
def PRODUCT(factor1, *more_factors):
"""
Returns the result of multiplying a series of numbers together. Each argument may be a number or
an array.
>>> PRODUCT([5,15,30])
2250
>>> PRODUCT([5,15,30], 2)
4500
>>> PRODUCT(5,15,[30],[2])
4500
More tests:
>>> PRODUCT([2, True, None, "", False, "0", 5])
10
>>> PRODUCT([2, True, None, "", False, 0, 5])
0
"""
return reduce(operator.mul, _chain_numeric(factor1, *more_factors))
def QUOTIENT(dividend, divisor):
"""
Returns one number divided by another.
>>> QUOTIENT(5, 2)
2
>>> QUOTIENT(4.5, 3.1)
1
>>> QUOTIENT(-10, 3)
-3
"""
return TRUNC(float(dividend) / divisor)
def RADIANS(angle):
"""
Converts an angle value in degrees to radians.
>>> round(RADIANS(270), 6)
4.712389
"""
return _math.radians(angle)
def RAND():
"""
Returns a random number between 0 inclusive and 1 exclusive.
"""
return random.random()
def RANDBETWEEN(low, high):
"""
Returns a uniformly random integer between two values, inclusive.
"""
return random.randrange(low, high + 1)
def ROMAN(number, form_unused=None):
"""
Formats a number in Roman numerals. The second argument is ignored in this implementation.
>>> ROMAN(499,0)
'CDXCIX'
>>> ROMAN(499.2,0)
'CDXCIX'
>>> ROMAN(57)
'LVII'
>>> ROMAN(1912)
'MCMXII'
"""
# TODO: Maybe we should support the second argument.
return roman.toRoman(int(number))
def ROUND(value, places=0):
"""
Rounds a number to a certain number of decimal places according to standard rules.
>>> ROUND(2.15, 1) # Excel actually gives the more correct 2.2
2.1
>>> ROUND(2.149, 1)
2.1
>>> ROUND(-1.475, 2)
-1.48
>>> ROUND(21.5, -1)
20.0
>>> ROUND(626.3,-3)
1000.0
>>> ROUND(1.98,-1)
0.0
>>> ROUND(-50.55,-2)
-100.0
"""
# TODO: Excel manages to round 2.15 to 2.2, but Python sees 2.149999... and rounds to 2.1
# (see Python notes in documentation of `round()`).
return round(value, places)
def ROUNDDOWN(value, places=0):
"""
Rounds a number to a certain number of decimal places, always rounding down towards zero.
>>> ROUNDDOWN(3.2, 0)
3
>>> ROUNDDOWN(76.9,0)
76
>>> ROUNDDOWN(3.14159, 3)
3.141
>>> ROUNDDOWN(-3.14159, 1)
-3.1
>>> ROUNDDOWN(31415.92654, -2)
31400
"""
factor = 10**-places
return int(_round_toward_zero(float(value) / factor)) * factor
def ROUNDUP(value, places=0):
"""
Rounds a number to a certain number of decimal places, always rounding up away from zero.
>>> ROUNDUP(3.2,0)
4
>>> ROUNDUP(76.9,0)
77
>>> ROUNDUP(3.14159, 3)
3.142
>>> ROUNDUP(-3.14159, 1)
-3.2
>>> ROUNDUP(31415.92654, -2)
31500
"""
factor = 10**-places
return int(_round_away_from_zero(float(value) / factor)) * factor
def SERIESSUM(x, n, m, a):
"""
Given parameters x, n, m, and a, returns the power series sum a_1*x^n + a_2*x^(n+m)
+ ... + a_i*x^(n+(i-1)m), where i is the number of entries in range `a`.
>>> SERIESSUM(1,0,1,1)
1
>>> SERIESSUM(2,1,0,[1,2,3])
12
>>> SERIESSUM(-3,1,1,[2,4,6])
-132
>>> round(SERIESSUM(PI()/4,0,2,[1,-1./FACT(2),1./FACT(4),-1./FACT(6)]), 6)
0.707103
"""
return sum(coef*pow(x, n+i*m) for i, coef in enumerate(_chain(a)))
def SIGN(value):
"""
Given an input number, returns `-1` if it is negative, `1` if positive, and `0` if it is zero.
>>> SIGN(10)
1
>>> SIGN(4.0-4.0)
0
>>> SIGN(-0.00001)
-1
"""
return 0 if value == 0 else int(_math.copysign(1, value))
def SIN(angle):
"""
Returns the sine of an angle provided in radians.
>>> round(SIN(PI()), 10)
0.0
>>> SIN(PI()/2)
1.0
>>> round(SIN(30*PI()/180), 10)
0.5
>>> round(SIN(RADIANS(30)), 10)
0.5
"""
return _math.sin(angle)
def SINH(value):
"""
Returns the hyperbolic sine of any real number.
>>> round(2.868*SINH(0.0342*1.03), 7)
0.1010491
"""
return _math.sinh(value)
def SQRT(value):
"""
Returns the positive square root of a positive number.
>>> SQRT(16)
4.0
>>> SQRT(-16)
Traceback (most recent call last):
...
ValueError: math domain error
>>> SQRT(ABS(-16))
4.0
"""
return _math.sqrt(value)
def SQRTPI(value):
"""
Returns the positive square root of the product of Pi and the given positive number.
>>> round(SQRTPI(1), 6)
1.772454
>>> round(SQRTPI(2), 6)
2.506628
"""
return _math.sqrt(_math.pi * value)
def SUBTOTAL(function_code, range1, range2):
"""
Returns a subtotal for a vertical range of cells using a specified aggregation function.
"""
raise NotImplementedError()
def SUM(value1, *more_values):
"""
Returns the sum of a series of numbers. Each argument may be a number or an array.
Non-numeric values are ignored.
>>> SUM([5,15,30])
50
>>> SUM([5.,15,30], 2)
52.0
>>> SUM(5,15,[30],[2])
52
More tests:
>>> SUM([10.25, None, "", False, "other", 20.5])
30.75
>>> SUM([True, "3", 4], True)
6
"""
return sum(_chain_numeric_a(value1, *more_values))
def SUMIF(records, criterion, sum_range):
"""
Returns a conditional sum across a range.
"""
raise NotImplementedError()
def SUMIFS(sum_range, criteria_range1, criterion1, *args):
"""
Returns the sum of a range depending on multiple criteria.
"""
raise NotImplementedError()
def SUMPRODUCT(array1, *more_arrays):
"""
Multiplies corresponding components in the given arrays, and returns the sum of those products.
>>> SUMPRODUCT([3,8,1,4,6,9], [2,6,5,7,7,3])
156
>>> SUMPRODUCT([], [], [])
0
>>> SUMPRODUCT([-0.25], [-2], [-3])
-1.5
>>> SUMPRODUCT([-0.25, -0.25], [-2, -2], [-3, -3])
-3.0
"""
return sum(reduce(operator.mul, values) for values in itertools.izip(array1, *more_arrays))
def SUMSQ(value1, value2):
"""
Returns the sum of the squares of a series of numbers and/or cells.
"""
raise NotImplementedError()
def TAN(angle):
"""
Returns the tangent of an angle provided in radians.
>>> round(TAN(0.785), 8)
0.99920399
>>> round(TAN(45*PI()/180), 10)
1.0
>>> round(TAN(RADIANS(45)), 10)
1.0
"""
return _math.tan(angle)
def TANH(value):
"""
Returns the hyperbolic tangent of any real number.
>>> round(TANH(-2), 6)
-0.964028
>>> TANH(0)
0.0
>>> round(TANH(0.5), 6)
0.462117
"""
return _math.tanh(value)
def TRUNC(value, places=0):
"""
Truncates a number to a certain number of significant digits by omitting less significant
digits.
>>> TRUNC(8.9)
8
>>> TRUNC(-8.9)
-8
>>> TRUNC(0.45)
0
"""
# TRUNC seems indistinguishable from ROUNDDOWN.
return ROUNDDOWN(value, places)

@ -0,0 +1,329 @@
from datetime import datetime, timedelta
import re
from date import DATEADD, NOW, DTIME
from moment_parse import MONTH_NAMES, DAY_NAMES
# Limit exports to schedule, so that upper-case constants like MONTH_NAMES, DAY_NAMES don't end up
# exposed as if Excel-style functions (or break docs generation).
__all__ = ['SCHEDULE']
def SCHEDULE(schedule, start=None, count=10, end=None):
"""
Returns the list of `datetime` objects generated according to the `schedule` string. Starts at
`start`, which defaults to NOW(). Generates at most `count` results (10 by default). If `end` is
given, stops there.
The schedule has the format "INTERVAL: SLOTS, ...". For example:
annual: Jan-15, Apr-15, Jul-15 -- Three times a year on given dates at midnight.
annual: 1/15, 4/15, 7/15 -- Same as above.
monthly: /1 2pm, /15 2pm -- The 1st and the 15th of each month, at 2pm.
3-months: /10, +1m /20 -- Every 3 months on the 10th of month 1, 20th of month 2.
weekly: Mo 9am, Tu 9am, Fr 2pm -- Three times a week at specified times.
2-weeks: Mo, +1w Tu -- Every 2 weeks on Monday of week 1, Tuesday of week 2.
daily: 07:30, 21:00 -- Twice a day at specified times.
2-day: 12am, 4pm, +1d 8am -- Three times every two days, evenly spaced.
hourly: :15, :45 -- 15 minutes before and after each hour.
4-hour: :00, 1:20, 2:40 -- Three times every 4 hours, evenly spaced.
10-minute: +0s -- Every 10 minutes on the minute.
INTERVAL must be either of the form `N-unit` where `N` is a number and `unit` is one of `year`,
`month`, `week`, `day`, `hour`; or one of the aliases: `annual`, `monthly`, `weekly`, `daily`,
`hourly`, which mean `1-year`, `1-month`, etc.
SLOTS support the following units:
`Jan-15` or `1/15` -- Month and day of the month; available when INTERVAL is year-based.
`/15` -- Day of the month, available when INTERVAL is month-based.
`Mon`, `Mo`, `Friday` -- Day of the week (or abbreviation), when INTERVAL is week-based.
10am, 1:30pm, 15:45 -- Time of day, available for day-based or longer intervals.
:45, :00 -- Minutes of the hour, available when INTERVAL is hour-based.
+1d, +15d -- How many days to add to start of INTERVAL.
+1w -- How many weeks to add to start of INTERVAL.
+1m -- How many months to add to start of INTERVAL.
The SLOTS are always relative to the INTERVAL rather than to `start`. Week-based intervals start
on Sunday. E.g. `weekly: +1d, +4d` is the same as `weekly: Mon, Thu`, and generates times on
Mondays and Thursdays regardless of `start`.
The first generated time is determined by the *unit* of the INTERVAL without regard to the
multiple. E.g. both "2-week: Mon" and "3-week: Mon" start on the first Monday after `start`, and
then generate either every second or every third Monday after that. Similarly, `24-hour: :00`
starts with the first top-of-the-hour after `start` (not with midnight), and then repeats every
24 hours. To start with the midnight after `start`, use `daily: 0:00`.
For interval units of a day or longer, if time-of-day is not specified, it defaults to midnight.
The time zone of `start` determines the time zone of the generated times.
>>> def show(dates): return [d.strftime("%Y-%m-%d %H:%M") for d in dates]
>>> start = datetime(2018, 9, 4, 14, 0); # 2pm on Tue, Sep 4 2018.
>>> show(SCHEDULE('annual: Jan-15, Apr-15, Jul-15, Oct-15', start=start, count=4))
['2018-10-15 00:00', '2019-01-15 00:00', '2019-04-15 00:00', '2019-07-15 00:00']
>>> show(SCHEDULE('annual: 1/15, 4/15, 7/15', start=start, count=4))
['2019-01-15 00:00', '2019-04-15 00:00', '2019-07-15 00:00', '2020-01-15 00:00']
>>> show(SCHEDULE('monthly: /1 2pm, /15 5pm', start=start, count=4))
['2018-09-15 17:00', '2018-10-01 14:00', '2018-10-15 17:00', '2018-11-01 14:00']
>>> show(SCHEDULE('3-months: /10, +1m /20', start=start, count=4))
['2018-09-10 00:00', '2018-10-20 00:00', '2018-12-10 00:00', '2019-01-20 00:00']
>>> show(SCHEDULE('weekly: Mo 9am, Tu 9am, Fr 2pm', start=start, count=4))
['2018-09-07 14:00', '2018-09-10 09:00', '2018-09-11 09:00', '2018-09-14 14:00']
>>> show(SCHEDULE('2-weeks: Mo, +1w Tu', start=start, count=4))
['2018-09-11 00:00', '2018-09-17 00:00', '2018-09-25 00:00', '2018-10-01 00:00']
>>> show(SCHEDULE('daily: 07:30, 21:00', start=start, count=4))
['2018-09-04 21:00', '2018-09-05 07:30', '2018-09-05 21:00', '2018-09-06 07:30']
>>> show(SCHEDULE('2-day: 12am, 4pm, +1d 8am', start=start, count=4))
['2018-09-04 16:00', '2018-09-05 08:00', '2018-09-06 00:00', '2018-09-06 16:00']
>>> show(SCHEDULE('hourly: :15, :45', start=start, count=4))
['2018-09-04 14:15', '2018-09-04 14:45', '2018-09-04 15:15', '2018-09-04 15:45']
>>> show(SCHEDULE('4-hour: :00, +1H :20, +2H :40', start=start, count=4))
['2018-09-04 14:00', '2018-09-04 15:20', '2018-09-04 16:40', '2018-09-04 18:00']
"""
return Schedule(schedule).series(start or NOW(), end, count=count)
class Delta(object):
"""
Similar to timedelta, keeps intervals by unit. Specifically, this is needed for months
and years, since those can't be represented exactly with a timedelta.
"""
def __init__(self):
self._timedelta = timedelta(0)
self._months = 0
def add_interval(self, number, unit):
if unit == 'months':
self._months += number
elif unit == 'years':
self._months += number * 12
else:
self._timedelta += timedelta(**{unit: number})
return self
def add_to(self, dtime):
return datetime.combine(DATEADD(dtime, months=self._months), dtime.timetz()) + self._timedelta
class Schedule(object):
"""
Schedule parses a schedule spec into an interval and slots in the constructor. Then the series()
method applies it to any start/end dates.
"""
def __init__(self, spec_string):
parts = spec_string.split(":", 1)
if len(parts) != 2:
raise ValueError("schedule must have the form INTERVAL: SLOTS, ...")
count, unit = _parse_interval(parts[0].strip())
self._interval_unit = unit
self._interval = Delta().add_interval(count, unit)
self._slots = [_parse_slot(t, self._interval_unit) for t in parts[1].split(",")]
def series(self, start_dtime, end_dtime, count=10):
# Start with a preceding unit boundary, then check the slots within that unit and start with
# the first one that's at start_dtime or later.
start_dtime = DTIME(start_dtime)
end_dtime = end_dtime and DTIME(end_dtime)
dtime = _round_down_to_unit(start_dtime, self._interval_unit)
while True:
for slot in self._slots:
if count <= 0:
return
out = slot.add_to(dtime)
if out < start_dtime:
continue
if end_dtime is not None and out > end_dtime:
return
yield out
count -= 1
dtime = self._interval.add_to(dtime)
def _fail(message):
raise ValueError(message)
def _round_down_to_unit(dtime, unit):
"""
Rounds datetime down to the given unit. Weeks are rounded to start of Sunday.
"""
tz = dtime.tzinfo
return ( datetime(dtime.year, 1, 1, tzinfo=tz) if unit == 'years'
else datetime(dtime.year, dtime.month, 1, tzinfo=tz) if unit == 'months'
else (dtime - timedelta(days=dtime.isoweekday() % 7))
.replace(hour=0, minute=0, second=0, microsecond=0) if unit == 'weeks'
else dtime.replace(hour=0, minute=0, second=0, microsecond=0) if unit == 'days'
else dtime.replace(minute=0, second=0, microsecond=0) if unit == 'hours'
else dtime.replace(second=0, microsecond=0) if unit == 'minutes'
else dtime.replace(microsecond=0) if unit == 'seconds'
else _fail("Invalid unit %s" % unit)
)
_UNITS = ('years', 'months', 'weeks', 'days', 'hours', 'minutes', 'seconds')
_VALID_UNITS = set(_UNITS)
_SINGULAR_UNITS = dict(zip(('year', 'month', 'week', 'day', 'hour', 'minute', 'second'), _UNITS))
_SHORT_UNITS = dict(zip(('y', 'm', 'w', 'd', 'H', 'M', 'S'), _UNITS))
_INTERVAL_ALIASES = {
'annual': (1, 'years'),
'monthly': (1, 'months'),
'weekly': (1, 'weeks'),
'daily': (1, 'days'),
'hourly': (1, 'hours'),
}
_INTERVAL_RE = re.compile(r'^(?P<num>\d+)[-\s]+(?P<unit>[a-z]+)$', re.I)
# Maps weekday names, including 2- and 3-letter abbreviations, to numbers 0 through 6.
WEEKDAY_OFFSETS = {}
for (i, name) in enumerate(DAY_NAMES):
WEEKDAY_OFFSETS[name] = i
WEEKDAY_OFFSETS[name[:3]] = i
WEEKDAY_OFFSETS[name[:2]] = i
# Maps month names, including 3-letter abbreviations, to numbers 0 through 11.
MONTH_OFFSETS = {}
for (i, name) in enumerate(MONTH_NAMES):
MONTH_OFFSETS[name] = i
MONTH_OFFSETS[name[:3]] = i
def _parse_interval(interval_str):
"""
Given a spec like "daily" or "3-week", returns (N, unit), such as (1, "days") or (3, "weeks").
"""
interval_str = interval_str.lower()
if interval_str in _INTERVAL_ALIASES:
return _INTERVAL_ALIASES[interval_str]
m = _INTERVAL_RE.match(interval_str)
if not m:
raise ValueError("Not a valid interval '%s'" % interval_str)
num = int(m.group("num"))
unit = m.group("unit")
unit = _SINGULAR_UNITS.get(unit, unit)
if unit not in _VALID_UNITS:
raise ValueError("Unknown unit '%s' in interval '%s'" % (unit, interval_str))
return (num, unit)
def _parse_slot(slot_str, parent_unit):
"""
Parses a slot in one of several recognized formats. Allowed formats depend on parent_unit, e.g.
'Jan-15' is valid when parent_unit is 'years', but not when it is 'hours'. We also disallow
using the same unit more than once, which is confusing, e.g. "+1d +2d" or "9:30am +2H".
Returns a Delta object.
"""
parts = slot_str.split()
if not parts:
raise ValueError("At least one slot must be specified")
delta = Delta()
seen_units = set()
allowed_slot_types = _ALLOWED_SLOTS_BY_UNIT.get(parent_unit) or ('delta',)
# Slot parts go through parts like "Jan-15 16pm", collecting the offsets into a single Delta.
for part in parts:
m = _SLOT_RE.match(part)
if not m:
raise ValueError("Invalid slot '%s'" % part)
for slot_type in allowed_slot_types:
if m.group(slot_type):
# If there is a group for one slot type, that's the only group. We find and use the
# corresponding parser, then move on to the next slot part.
for count, unit in _SLOT_PARSERS[slot_type](m):
delta.add_interval(count, unit)
if unit in seen_units:
raise ValueError("Duplicate unit %s in '%s'" % (unit, slot_str))
seen_units.add(unit)
break
else:
# If none of the allowed slot types was found, it must be a disallowed one.
raise ValueError("Invalid slot '%s' for unit '%s'" % (part, parent_unit))
return delta
# We parse all slot types using one big regex. The constants below define one part of the regex
# for each slot type (e.g. to match "Jan-15" or "5:30am" or "+1d"). Note that all group names
# (defined with (?P<NAME>...)) must be distinct.
_DATE_RE = r'(?:(?P<month_name>[a-z]+)-|(?P<month_num>\d+)/)(?P<month_day>\d+)'
_MDAY_RE = r'/(?P<month_day2>\d+)'
_WDAY_RE = r'(?P<weekday>[a-z]+)'
_TIME_RE = r'(?P<hours>\d+)(?:\:(?P<minutes>\d{2})(?P<ampm1>am|pm)?|(?P<ampm2>am|pm))'
_MINS_RE = r':(?P<minutes2>\d{2})'
_DELTA_RE = r'\+(?P<count>\d+)(?P<unit>[a-z]+)'
# The regex parts are combined and compiled here. Only one group will match, corresponding to one
# slot type. Different slot types depend on the unit of the overall interval.
_SLOT_RE = re.compile(
r'^(?:(?P<date>%s)|(?P<mday>%s)|(?P<wday>%s)|(?P<time>%s)|(?P<mins>%s)|(?P<delta>%s))$' %
(_DATE_RE, _MDAY_RE, _WDAY_RE, _TIME_RE, _MINS_RE, _DELTA_RE), re.IGNORECASE)
# Slot types that make sense for each unit of overall interval. If not listed (e.g. "minutes")
# then only "delta" slot type is allowed.
_ALLOWED_SLOTS_BY_UNIT = {
'years': ('date', 'time', 'delta'),
'months': ('mday', 'time', 'delta'),
'weeks': ('wday', 'time', 'delta'),
'days': ('time', 'delta'),
'hours': ('mins', 'delta'),
}
# The helper methods below parse one slot type each, given a regex match that matched that slot
# type. These are combined and used via the _SLOT_PARSERS dict below.
def _parse_slot_date(m):
mday = int(m.group("month_day"))
month_name = m.group("month_name")
month_num = m.group("month_num")
if month_name:
name = month_name.lower()
if name not in MONTH_OFFSETS:
raise ValueError("Unknown month '%s'" % month_name)
mnum = MONTH_OFFSETS[name]
else:
mnum = int(month_num) - 1
return [(mnum, 'months'), (mday - 1, 'days')]
def _parse_slot_mday(m):
mday = int(m.group("month_day2"))
return [(mday - 1, 'days')]
def _parse_slot_wday(m):
wday = m.group("weekday").lower()
if wday not in WEEKDAY_OFFSETS:
raise ValueError("Unknown day of the week '%s'" % wday)
return [(WEEKDAY_OFFSETS[wday], "days")]
def _parse_slot_time(m):
hours = int(m.group("hours"))
minutes = int(m.group("minutes") or 0)
ampm = m.group("ampm1") or m.group("ampm2")
if ampm:
hours = (hours % 12) + (12 if ampm.lower() == "pm" else 0)
return [(hours, 'hours'), (minutes, 'minutes')]
def _parse_slot_mins(m):
minutes = int(m.group("minutes2"))
return [(minutes, 'minutes')]
def _parse_slot_delta(m):
count = int(m.group("count"))
unit = m.group("unit")
if unit not in _SHORT_UNITS:
raise ValueError("Unknown unit '%s' in interval '%s'" % (unit, m.group()))
return [(count, _SHORT_UNITS[unit])]
_SLOT_PARSERS = {
'date': _parse_slot_date,
'mday': _parse_slot_mday,
'wday': _parse_slot_wday,
'time': _parse_slot_time,
'mins': _parse_slot_mins,
'delta': _parse_slot_delta,
}

@ -0,0 +1,615 @@
# pylint: disable=redefined-builtin, line-too-long, unused-argument
from math import _chain, _chain_numeric, _chain_numeric_a
from info import ISNUMBER, ISLOGICAL
from date import DATE # pylint: disable=unused-import
def _average(iterable):
total, count = 0.0, 0
for value in iterable:
total += value
count += 1
return total / count
def _default_if_empty(iterable, default):
"""
Yields all values from iterable, except when it is empty, yields just the single default value.
"""
empty = True
for value in iterable:
empty = False
yield value
if empty:
yield default
def AVEDEV(value1, value2):
"""Calculates the average of the magnitudes of deviations of data from a dataset's mean."""
raise NotImplementedError()
def AVERAGE(value, *more_values):
"""
Returns the numerical average value in a dataset, ignoring non-numerical values.
Each argument may be a value or an array. Values that are not numbers, including logical
and blank values, and text representations of numbers, are ignored.
>>> AVERAGE([2, -1.0, 11])
4.0
>>> AVERAGE([2, -1, 11, "Hello"])
4.0
>>> AVERAGE([2, -1, "Hello", DATE(2015,1,1)], True, [False, "123", "", 11])
4.0
>>> AVERAGE(False, True)
Traceback (most recent call last):
...
ZeroDivisionError: float division by zero
"""
return _average(_chain_numeric(value, *more_values))
def AVERAGEA(value, *more_values):
"""
Returns the numerical average value in a dataset, counting non-numerical values as 0.
Each argument may be a value of an array. Values that are not numbers, including dates and text
representations of numbers, are counted as 0 (zero). Logical value of True is counted as 1, and
False as 0.
>>> AVERAGEA([2, -1.0, 11])
4.0
>>> AVERAGEA([2, -1, 11, "Hello"])
3.0
>>> AVERAGEA([2, -1, "Hello", DATE(2015,1,1)], True, [False, "123", "", 11.5])
1.5
>>> AVERAGEA(False, True)
0.5
"""
return _average(_chain_numeric_a(value, *more_values))
# Note that Google Sheets offers a similar function, called AVERAGE.WEIGHTED
# (https://support.google.com/docs/answer/9084098?hl=en)
def AVERAGE_WEIGHTED(pairs):
"""
Given a list of (value, weight) pairs, finds the average of the values weighted by the
corresponding weights. Ignores any pairs with a non-numerical value or weight.
If you have two lists, of values and weights, use the Python built-in zip() function to create a
list of pairs.
>>> AVERAGE_WEIGHTED(((95, .25), (90, .1), ("X", .5), (85, .15), (88, .2), (82, .3), (70, None)))
87.7
>>> AVERAGE_WEIGHTED(zip([95, 90, "X", 85, 88, 82, 70], [25, 10, 50, 15, 20, 30, None]))
87.7
>>> AVERAGE_WEIGHTED(zip([95, 90, False, 85, 88, 82, 70], [.25, .1, .5, .15, .2, .3, True]))
87.7
"""
sum_value, sum_weight = 0.0, 0.0
for value, weight in pairs:
# The type-checking here is the same as used by _chain_numeric.
if ISNUMBER(value) and not ISLOGICAL(value) and ISNUMBER(weight) and not ISLOGICAL(weight):
sum_value += value * weight
sum_weight += weight
return sum_value / sum_weight
def AVERAGEIF(criteria_range, criterion, average_range=None):
"""Returns the average of a range depending on criteria."""
raise NotImplementedError()
def AVERAGEIFS(average_range, criteria_range1, criterion1, *args):
"""Returns the average of a range depending on multiple criteria."""
raise NotImplementedError()
def BINOMDIST(num_successes, num_trials, prob_success, cumulative):
"""
Calculates the probability of drawing a certain number of successes (or a maximum number of
successes) in a certain number of tries given a population of a certain size containing a
certain number of successes, with replacement of draws.
"""
raise NotImplementedError()
def CONFIDENCE(alpha, standard_deviation, pop_size):
"""Calculates the width of half the confidence interval for a normal distribution."""
raise NotImplementedError()
def CORREL(data_y, data_x):
"""Calculates r, the Pearson product-moment correlation coefficient of a dataset."""
raise NotImplementedError()
def COUNT(value, *more_values):
"""
Returns the count of numerical values in a dataset, ignoring non-numerical values.
Each argument may be a value or an array. Values that are not numbers, including logical
and blank values, and text representations of numbers, are ignored.
>>> COUNT([2, -1.0, 11])
3
>>> COUNT([2, -1, 11, "Hello"])
3
>>> COUNT([2, -1, "Hello", DATE(2015,1,1)], True, [False, "123", "", 11.5])
3
>>> COUNT(False, True)
0
"""
return sum(1 for v in _chain_numeric(value, *more_values))
def COUNTA(value, *more_values):
"""
Returns the count of all values in a dataset, including non-numerical values.
Each argument may be a value or an array.
>>> COUNTA([2, -1.0, 11])
3
>>> COUNTA([2, -1, 11, "Hello"])
4
>>> COUNTA([2, -1, "Hello", DATE(2015,1,1)], True, [False, "123", "", 11.5])
9
>>> COUNTA(False, True)
2
"""
return sum(1 for v in _chain(value, *more_values))
def COVAR(data_y, data_x):
"""Calculates the covariance of a dataset."""
raise NotImplementedError()
def CRITBINOM(num_trials, prob_success, target_prob):
"""Calculates the smallest value for which the cumulative binomial distribution is greater than or equal to a specified criteria."""
raise NotImplementedError()
def DEVSQ(value1, value2):
"""Calculates the sum of squares of deviations based on a sample."""
raise NotImplementedError()
def EXPONDIST(x, lambda_, cumulative):
"""Returns the value of the exponential distribution function with a specified lambda at a specified value."""
raise NotImplementedError()
def F_DIST(x, degrees_freedom1, degrees_freedom2, cumulative):
"""
Calculates the left-tailed F probability distribution (degree of diversity) for two data sets
with given input x. Alternately called Fisher-Snedecor distribution or Snedecor's F
distribution.
"""
raise NotImplementedError()
def F_DIST_RT(x, degrees_freedom1, degrees_freedom2):
"""
Calculates the right-tailed F probability distribution (degree of diversity) for two data sets
with given input x. Alternately called Fisher-Snedecor distribution or Snedecor's F
distribution.
"""
raise NotImplementedError()
def FDIST(x, degrees_freedom1, degrees_freedom2):
"""
Calculates the right-tailed F probability distribution (degree of diversity) for two data sets
with given input x. Alternately called Fisher-Snedecor distribution or Snedecor's F
distribution.
"""
raise NotImplementedError()
def FISHER(value):
"""Returns the Fisher transformation of a specified value."""
raise NotImplementedError()
def FISHERINV(value):
"""Returns the inverse Fisher transformation of a specified value."""
raise NotImplementedError()
def FORECAST(x, data_y, data_x):
"""Calculates the expected y-value for a specified x based on a linear regression of a dataset."""
raise NotImplementedError()
def GEOMEAN(value1, value2):
"""Calculates the geometric mean of a dataset."""
raise NotImplementedError()
def HARMEAN(value1, value2):
"""Calculates the harmonic mean of a dataset."""
raise NotImplementedError()
def HYPGEOMDIST(num_successes, num_draws, successes_in_pop, pop_size):
"""Calculates the probability of drawing a certain number of successes in a certain number of tries given a population of a certain size containing a certain number of successes, without replacement of draws."""
raise NotImplementedError()
def INTERCEPT(data_y, data_x):
"""Calculates the y-value at which the line resulting from linear regression of a dataset will intersect the y-axis (x=0)."""
raise NotImplementedError()
def KURT(value1, value2):
"""Calculates the kurtosis of a dataset, which describes the shape, and in particular the "peakedness" of that dataset."""
raise NotImplementedError()
def LARGE(data, n):
"""Returns the nth largest element from a data set, where n is user-defined."""
raise NotImplementedError()
def LOGINV(x, mean, standard_deviation):
"""Returns the value of the inverse log-normal cumulative distribution with given mean and standard deviation at a specified value."""
raise NotImplementedError()
def LOGNORMDIST(x, mean, standard_deviation):
"""Returns the value of the log-normal cumulative distribution with given mean and standard deviation at a specified value."""
raise NotImplementedError()
def MAX(value, *more_values):
"""
Returns the maximum value in a dataset, ignoring non-numerical values.
Each argument may be a value or an array. Values that are not numbers, including logical
and blank values, and text representations of numbers, are ignored. Returns 0 if the arguments
contain no numbers.
>>> MAX([2, -1.5, 11.5])
11.5
>>> MAX([2, -1.5, "Hello", DATE(2015, 1, 1)], True, [False, "123", "", 11.5])
11.5
>>> MAX(True, -123)
-123
>>> MAX("123", -123)
-123
>>> MAX("Hello", "123", DATE(2015, 1, 1))
0
"""
return max(_default_if_empty(_chain_numeric(value, *more_values), 0))
def MAXA(value, *more_values):
"""
Returns the maximum numeric value in a dataset.
Each argument may be a value of an array. Values that are not numbers, including dates and text
representations of numbers, are counted as 0 (zero). Logical value of True is counted as 1, and
False as 0. Returns 0 if the arguments contain no numbers.
>>> MAXA([2, -1.5, 11.5])
11.5
>>> MAXA([2, -1.5, "Hello", DATE(2015, 1, 1)], True, [False, "123", "", 11.5])
11.5
>>> MAXA(True, -123)
1
>>> MAXA("123", -123)
0
>>> MAXA("Hello", "123", DATE(2015, 1, 1))
0
"""
return max(_default_if_empty(_chain_numeric_a(value, *more_values), 0))
def MEDIAN(value, *more_values):
"""
Returns the median value in a numeric dataset, ignoring non-numerical values.
Each argument may be a value or an array. Values that are not numbers, including logical
and blank values, and text representations of numbers, are ignored.
Produces an error if the arguments contain no numbers.
The median is the middle number when all values are sorted. So half of the values in the dataset
are less than the median, and half of the values are greater. If there is an even number of
values in the dataset, returns the average of the two numbers in the middle.
>>> MEDIAN(1, 2, 3, 4, 5)
3
>>> MEDIAN(3, 5, 1, 4, 2)
3
>>> MEDIAN(xrange(10))
4.5
>>> MEDIAN("Hello", "123", DATE(2015, 1, 1), 12.3)
12.3
>>> MEDIAN("Hello", "123", DATE(2015, 1, 1))
Traceback (most recent call last):
...
ValueError: MEDIAN requires at least one number
"""
values = sorted(_chain_numeric(value, *more_values))
if not values:
raise ValueError("MEDIAN requires at least one number")
count = len(values)
if count % 2 == 0:
return (values[count / 2 - 1] + values[count / 2]) / 2.0
else:
return values[(count - 1) / 2]
def MIN(value, *more_values):
"""
Returns the minimum value in a dataset, ignoring non-numerical values.
Each argument may be a value or an array. Values that are not numbers, including logical
and blank values, and text representations of numbers, are ignored. Returns 0 if the arguments
contain no numbers.
>>> MIN([2, -1.5, 11.5])
-1.5
>>> MIN([2, -1.5, "Hello", DATE(2015, 1, 1)], True, [False, "123", "", 11.5])
-1.5
>>> MIN(True, 123)
123
>>> MIN("-123", 123)
123
>>> MIN("Hello", "123", DATE(2015, 1, 1))
0
"""
return min(_default_if_empty(_chain_numeric(value, *more_values), 0))
def MINA(value, *more_values):
"""
Returns the minimum numeric value in a dataset.
Each argument may be a value of an array. Values that are not numbers, including dates and text
representations of numbers, are counted as 0 (zero). Logical value of True is counted as 1, and
False as 0. Returns 0 if the arguments contain no numbers.
>>> MINA([2, -1.5, 11.5])
-1.5
>>> MINA([2, -1.5, "Hello", DATE(2015, 1, 1)], True, [False, "123", "", 11.5])
-1.5
>>> MINA(True, 123)
1
>>> MINA("-123", 123)
0
>>> MINA("Hello", "123", DATE(2015, 1, 1))
0
"""
return min(_default_if_empty(_chain_numeric_a(value, *more_values), 0))
def MODE(value1, value2):
"""Returns the most commonly occurring value in a dataset."""
raise NotImplementedError()
def NEGBINOMDIST(num_failures, num_successes, prob_success):
"""Calculates the probability of drawing a certain number of failures before a certain number of successes given a probability of success in independent trials."""
raise NotImplementedError()
def NORMDIST(x, mean, standard_deviation, cumulative):
"""
Returns the value of the normal distribution function (or normal cumulative distribution
function) for a specified value, mean, and standard deviation.
"""
raise NotImplementedError()
def NORMINV(x, mean, standard_deviation):
"""Returns the value of the inverse normal distribution function for a specified value, mean, and standard deviation."""
raise NotImplementedError()
def NORMSDIST(x):
"""Returns the value of the standard normal cumulative distribution function for a specified value."""
raise NotImplementedError()
def NORMSINV(x):
"""Returns the value of the inverse standard normal distribution function for a specified value."""
raise NotImplementedError()
def PEARSON(data_y, data_x):
"""Calculates r, the Pearson product-moment correlation coefficient of a dataset."""
raise NotImplementedError()
def PERCENTILE(data, percentile):
"""Returns the value at a given percentile of a dataset."""
raise NotImplementedError()
def PERCENTRANK(data, value, significant_digits=None):
"""Returns the percentage rank (percentile) of a specified value in a dataset."""
raise NotImplementedError()
def PERCENTRANK_EXC(data, value, significant_digits=None):
"""Returns the percentage rank (percentile) from 0 to 1 exclusive of a specified value in a dataset."""
raise NotImplementedError()
def PERCENTRANK_INC(data, value, significant_digits=None):
"""Returns the percentage rank (percentile) from 0 to 1 inclusive of a specified value in a dataset."""
raise NotImplementedError()
def PERMUT(n, k):
"""Returns the number of ways to choose some number of objects from a pool of a given size of objects, considering order."""
raise NotImplementedError()
def POISSON(x, mean, cumulative):
"""
Returns the value of the Poisson distribution function (or Poisson cumulative distribution
function) for a specified value and mean.
"""
raise NotImplementedError()
def PROB(data, probabilities, low_limit, high_limit=None):
"""Given a set of values and corresponding probabilities, calculates the probability that a value chosen at random falls between two limits."""
raise NotImplementedError()
def QUARTILE(data, quartile_number):
"""Returns a value nearest to a specified quartile of a dataset."""
raise NotImplementedError()
def RANK(value, data, is_ascending=None):
"""Returns the rank of a specified value in a dataset."""
raise NotImplementedError()
def RANK_AVG(value, data, is_ascending=None):
"""Returns the rank of a specified value in a dataset. If there is more than one entry of the same value in the dataset, the average rank of the entries will be returned."""
raise NotImplementedError()
def RANK_EQ(value, data, is_ascending=None):
"""Returns the rank of a specified value in a dataset. If there is more than one entry of the same value in the dataset, the top rank of the entries will be returned."""
raise NotImplementedError()
def RSQ(data_y, data_x):
"""Calculates the square of r, the Pearson product-moment correlation coefficient of a dataset."""
raise NotImplementedError()
def SKEW(value1, value2):
"""Calculates the skewness of a dataset, which describes the symmetry of that dataset about the mean."""
raise NotImplementedError()
def SLOPE(data_y, data_x):
"""Calculates the slope of the line resulting from linear regression of a dataset."""
raise NotImplementedError()
def SMALL(data, n):
"""Returns the nth smallest element from a data set, where n is user-defined."""
raise NotImplementedError()
def STANDARDIZE(value, mean, standard_deviation):
"""Calculates the normalized equivalent of a random variable given mean and standard deviation of the distribution."""
raise NotImplementedError()
# This should make us all cry a little. Because the sandbox does not do Python3 (which has
# statistics package), and because it does not do numpy (because it's native and hasn't been built
# for it), we have to implement simple stats functions by hand.
# TODO: switch to use the statistics package instead, once we upgrade to Python3.
#
# The following implementation of stdev is taken from https://stackoverflow.com/a/27758326/328565
def _mean(data):
return sum(data) / float(len(data))
def _ss(data):
"""Return sum of square deviations of sequence data."""
c = _mean(data)
return sum((x-c)**2 for x in data)
def _stddev(data, ddof=0):
"""Calculates the population standard deviation
by default; specify ddof=1 to compute the sample
standard deviation."""
n = len(data)
ss = _ss(data)
pvar = ss/(n-ddof)
return pvar**0.5
# The examples in the doctests below come from https://support.google.com/docs/answer/3094054 and
# related articles, which helps ensure correctness and compatibility.
def STDEV(value, *more_values):
"""
Calculates the standard deviation based on a sample, ignoring non-numerical values.
>>> STDEV([2, 5, 8, 13, 10])
4.277849927241488
>>> STDEV([2, 5, 8, 13, 10, True, False, "Test"])
4.277849927241488
>>> STDEV([2, 5, 8, 13, 10], 3, 12, 15)
4.810702354423639
>>> STDEV([2, 5, 8, 13, 10], [3, 12, 15])
4.810702354423639
>>> STDEV([5])
Traceback (most recent call last):
...
ZeroDivisionError: float division by zero
"""
return _stddev(list(_chain_numeric(value, *more_values)), 1)
def STDEVA(value, *more_values):
"""
Calculates the standard deviation based on a sample, setting text to the value `0`.
>>> STDEVA([2, 5, 8, 13, 10])
4.277849927241488
>>> STDEVA([2, 5, 8, 13, 10, True, False, "Test"])
4.969550137731641
>>> STDEVA([2, 5, 8, 13, 10], 1, 0, 0)
4.969550137731641
>>> STDEVA([2, 5, 8, 13, 10], [1, 0, 0])
4.969550137731641
>>> STDEVA([5])
Traceback (most recent call last):
...
ZeroDivisionError: float division by zero
"""
return _stddev(list(_chain_numeric_a(value, *more_values)), 1)
def STDEVP(value, *more_values):
"""
Calculates the standard deviation based on an entire population, ignoring non-numerical values.
>>> STDEVP([2, 5, 8, 13, 10])
3.8262252939417984
>>> STDEVP([2, 5, 8, 13, 10, True, False, "Test"])
3.8262252939417984
>>> STDEVP([2, 5, 8, 13, 10], 3, 12, 15)
4.5
>>> STDEVP([2, 5, 8, 13, 10], [3, 12, 15])
4.5
>>> STDEVP([5])
0.0
"""
return _stddev(list(_chain_numeric(value, *more_values)), 0)
def STDEVPA(value, *more_values):
"""
Calculates the standard deviation based on an entire population, setting text to the value `0`.
>>> STDEVPA([2, 5, 8, 13, 10])
3.8262252939417984
>>> STDEVPA([2, 5, 8, 13, 10, True, False, "Test"])
4.648588495446763
>>> STDEVPA([2, 5, 8, 13, 10], 1, 0, 0)
4.648588495446763
>>> STDEVPA([2, 5, 8, 13, 10], [1, 0, 0])
4.648588495446763
>>> STDEVPA([5])
0.0
"""
return _stddev(list(_chain_numeric_a(value, *more_values)), 0)
def STEYX(data_y, data_x):
"""Calculates the standard error of the predicted y-value for each x in the regression of a dataset."""
raise NotImplementedError()
def T_INV(probability, degrees_freedom):
"""Calculates the negative inverse of the one-tailed TDIST function."""
raise NotImplementedError()
def T_INV_2T(probability, degrees_freedom):
"""Calculates the inverse of the two-tailed TDIST function."""
raise NotImplementedError()
def TDIST(x, degrees_freedom, tails):
"""Calculates the probability for Student's t-distribution with a given input (x)."""
raise NotImplementedError()
def TINV(probability, degrees_freedom):
"""Calculates the inverse of the two-tailed TDIST function."""
raise NotImplementedError()
def TRIMMEAN(data, exclude_proportion):
"""Calculates the mean of a dataset excluding some proportion of data from the high and low ends of the dataset."""
raise NotImplementedError()
def TTEST(range1, range2, tails, type):
"""Returns the probability associated with t-test. Determines whether two samples are likely to have come from the same two underlying populations that have the same mean."""
raise NotImplementedError()
def VAR(value1, value2):
"""Calculates the variance based on a sample."""
raise NotImplementedError()
def VARA(value1, value2):
"""Calculates an estimate of variance based on a sample, setting text to the value `0`."""
raise NotImplementedError()
def VARP(value1, value2):
"""Calculates the variance based on an entire population."""
raise NotImplementedError()
def VARPA(value1, value2):
"""Calculates the variance based on an entire population, setting text to the value `0`."""
raise NotImplementedError()
def WEIBULL(x, shape, scale, cumulative):
"""
Returns the value of the Weibull distribution function (or Weibull cumulative distribution
function) for a specified shape and scale.
"""
raise NotImplementedError()
def ZTEST(data, value, standard_deviation):
"""Returns the two-tailed P-value of a Z-test with standard distribution."""
raise NotImplementedError()

@ -0,0 +1,270 @@
from datetime import date, datetime, timedelta
import os
import timeit
import unittest
import moment
import schedule
from functions.date import DTIME
from functions import date as _date
DT = DTIME
TICK = timedelta.resolution
_orig_global_tz_getter = None
class TestSchedule(unittest.TestCase):
def assertDate(self, date_or_dtime, expected_str):
"""Formats date_or_dtime and compares the formatted value."""
return self.assertEqual(date_or_dtime.strftime("%Y-%m-%d %H:%M:%S"), expected_str)
def assertDateIso(self, date_or_dtime, expected_str):
"""Formats date_or_dtime and compares the formatted value."""
return self.assertEqual(date_or_dtime.isoformat(' '), expected_str)
def assertDelta(self, delta, months=0, **timedelta_args):
"""Asserts that the given delta corresponds to the given number of various units."""
self.assertEqual(delta._months, months)
self.assertEqual(delta._timedelta, timedelta(**timedelta_args))
@classmethod
def setUpClass(cls):
global _orig_global_tz_getter # pylint: disable=global-statement
_orig_global_tz_getter = _date._get_global_tz
_date._get_global_tz = lambda: moment.tzinfo('America/New_York')
@classmethod
def tearDownClass(cls):
_date._get_global_tz = _orig_global_tz_getter
def test_round_down_to_unit(self):
RDU = schedule._round_down_to_unit
self.assertDate(RDU(DT("2018-09-04 14:38:11"), "years"), "2018-01-01 00:00:00")
self.assertDate(RDU(DT("2018-01-01 00:00:00"), "years"), "2018-01-01 00:00:00")
self.assertDate(RDU(DT("2018-01-01 00:00:00") - TICK, "years"), "2017-01-01 00:00:00")
self.assertDate(RDU(DT("2018-09-04 14:38:11"), "months"), "2018-09-01 00:00:00")
self.assertDate(RDU(DT("2018-09-01 00:00:00"), "months"), "2018-09-01 00:00:00")
self.assertDate(RDU(DT("2018-09-01 00:00:00") - TICK, "months"), "2018-08-01 00:00:00")
# Note that 9/4 was a Tuesday, so start of the week (Sunday) is 9/2
self.assertDate(RDU(DT("2018-09-04 14:38:11"), "weeks"), "2018-09-02 00:00:00")
self.assertDate(RDU(DT("2018-09-02 00:00:00"), "weeks"), "2018-09-02 00:00:00")
self.assertDate(RDU(DT("2018-09-02 00:00:00") - TICK, "weeks"), "2018-08-26 00:00:00")
self.assertDate(RDU(DT("2018-09-04 14:38:11"), "days"), "2018-09-04 00:00:00")
self.assertDate(RDU(DT("2018-09-04 00:00:00"), "days"), "2018-09-04 00:00:00")
self.assertDate(RDU(DT("2018-09-04 00:00:00") - TICK, "days"), "2018-09-03 00:00:00")
self.assertDate(RDU(DT("2018-09-04 14:38:11"), "hours"), "2018-09-04 14:00:00")
self.assertDate(RDU(DT("2018-09-04 14:00:00"), "hours"), "2018-09-04 14:00:00")
self.assertDate(RDU(DT("2018-09-04 14:00:00") - TICK, "hours"), "2018-09-04 13:00:00")
self.assertDate(RDU(DT("2018-09-04 14:38:11"), "minutes"), "2018-09-04 14:38:00")
self.assertDate(RDU(DT("2018-09-04 14:38:00"), "minutes"), "2018-09-04 14:38:00")
self.assertDate(RDU(DT("2018-09-04 14:38:00") - TICK, "minutes"), "2018-09-04 14:37:00")
self.assertDate(RDU(DT("2018-09-04 14:38:11"), "seconds"), "2018-09-04 14:38:11")
self.assertDate(RDU(DT("2018-09-04 14:38:11") - TICK, "seconds"), "2018-09-04 14:38:10")
with self.assertRaisesRegexp(ValueError, r"Invalid unit inches"):
RDU(DT("2018-09-04 14:38:11"), "inches")
def test_round_down_to_unit_tz(self):
RDU = schedule._round_down_to_unit
dt = datetime(2018, 1, 1, 0, 0, 0, tzinfo=moment.tzinfo("America/New_York"))
self.assertDateIso(RDU(dt, "years"), "2018-01-01 00:00:00-05:00")
self.assertDateIso(RDU(dt - TICK, "years"), "2017-01-01 00:00:00-05:00")
self.assertDateIso(RDU(dt, "months"), "2018-01-01 00:00:00-05:00")
self.assertDateIso(RDU(dt - TICK, "months"), "2017-12-01 00:00:00-05:00")
# 2018-01-01 is a Monday
self.assertDateIso(RDU(dt, "weeks"), "2017-12-31 00:00:00-05:00")
self.assertDateIso(RDU(dt - timedelta(days=1) - TICK, "weeks"), "2017-12-24 00:00:00-05:00")
self.assertDateIso(RDU(dt, "days"), "2018-01-01 00:00:00-05:00")
self.assertDateIso(RDU(dt - TICK, "days"), "2017-12-31 00:00:00-05:00")
self.assertDateIso(RDU(dt, "hours"), "2018-01-01 00:00:00-05:00")
self.assertDateIso(RDU(dt - TICK, "hours"), "2017-12-31 23:00:00-05:00")
def test_parse_interval(self):
self.assertEqual(schedule._parse_interval("annual"), (1, "years"))
self.assertEqual(schedule._parse_interval("daily"), (1, "days"))
self.assertEqual(schedule._parse_interval("1-year"), (1, "years"))
self.assertEqual(schedule._parse_interval("1 year"), (1, "years"))
self.assertEqual(schedule._parse_interval("1 Years"), (1, "years"))
self.assertEqual(schedule._parse_interval("25-months"), (25, "months"))
self.assertEqual(schedule._parse_interval("3-day"), (3, "days"))
self.assertEqual(schedule._parse_interval("2-hour"), (2, "hours"))
with self.assertRaisesRegexp(ValueError, "Not a valid interval"):
schedule._parse_interval("1Year")
with self.assertRaisesRegexp(ValueError, "Not a valid interval"):
schedule._parse_interval("1y")
with self.assertRaisesRegexp(ValueError, "Unknown unit"):
schedule._parse_interval("1-daily")
def test_parse_slot(self):
self.assertDelta(schedule._parse_slot('Jan-15', 'years'), months=0, days=14)
self.assertDelta(schedule._parse_slot('1/15', 'years'), months=0, days=14)
self.assertDelta(schedule._parse_slot('march-1', 'years'), months=2, days=0)
self.assertDelta(schedule._parse_slot('03/09', 'years'), months=2, days=8)
self.assertDelta(schedule._parse_slot('/15', 'months'), days=14)
self.assertDelta(schedule._parse_slot('/1', 'months'), days=0)
self.assertDelta(schedule._parse_slot('Mon', 'weeks'), days=1)
self.assertDelta(schedule._parse_slot('tu', 'weeks'), days=2)
self.assertDelta(schedule._parse_slot('Friday', 'weeks'), days=5)
self.assertDelta(schedule._parse_slot('10am', 'days'), hours=10)
self.assertDelta(schedule._parse_slot('1:30pm', 'days'), hours=13, minutes=30)
self.assertDelta(schedule._parse_slot('15:45', 'days'), hours=15, minutes=45)
self.assertDelta(schedule._parse_slot('Apr-1 9am', 'years'), months=3, days=0, hours=9)
self.assertDelta(schedule._parse_slot('/3 12:30', 'months'), days=2, hours=12, minutes=30)
self.assertDelta(schedule._parse_slot('Sat 6:15pm', 'weeks'), days=6, hours=18, minutes=15)
self.assertDelta(schedule._parse_slot(':45', 'hours'), minutes=45)
self.assertDelta(schedule._parse_slot(':00', 'hours'), minutes=00)
self.assertDelta(schedule._parse_slot('+1d', 'days'), days=1)
self.assertDelta(schedule._parse_slot('+15d', 'months'), days=15)
self.assertDelta(schedule._parse_slot('+3w', 'weeks'), weeks=3)
self.assertDelta(schedule._parse_slot('+2m', 'years'), months=2)
self.assertDelta(schedule._parse_slot('+1y', 'years'), months=12)
# Test a few combinations.
self.assertDelta(schedule._parse_slot('+1y 4/5 3:45pm +30S', 'years'),
months=15, days=4, hours=15, minutes=45, seconds=30)
self.assertDelta(schedule._parse_slot('+2w Wed +6H +20M +40S', 'weeks'),
weeks=2, days=3, hours=6, minutes=20, seconds=40)
self.assertDelta(schedule._parse_slot('+2m /20 11pm', 'months'), months=2, days=19, hours=23)
self.assertDelta(schedule._parse_slot('+2M +30S', 'minutes'), minutes=2, seconds=30)
def test_parse_slot_errors(self):
# Test failures with duplicate units
with self.assertRaisesRegexp(ValueError, 'Duplicate unit'):
schedule._parse_slot('+1d +2d', 'weeks')
with self.assertRaisesRegexp(ValueError, 'Duplicate unit'):
schedule._parse_slot('9:30am +2H', 'days')
with self.assertRaisesRegexp(ValueError, 'Duplicate unit'):
schedule._parse_slot('/15 +1d', 'months')
with self.assertRaisesRegexp(ValueError, 'Duplicate unit'):
schedule._parse_slot('Feb-1 12:30pm +20M', 'years')
# Test failures with improper slot types
with self.assertRaisesRegexp(ValueError, 'Invalid slot.*for unit'):
schedule._parse_slot('Feb-1', 'weeks')
with self.assertRaisesRegexp(ValueError, 'Invalid slot.*for unit'):
schedule._parse_slot('Monday', 'months')
with self.assertRaisesRegexp(ValueError, 'Invalid slot.*for unit'):
schedule._parse_slot('4/15', 'hours')
with self.assertRaisesRegexp(ValueError, 'Invalid slot.*for unit'):
schedule._parse_slot('/1', 'years')
# Test failures with outright invalid slot syntax.
with self.assertRaisesRegexp(ValueError, 'Invalid slot'):
schedule._parse_slot('Feb:1', 'weeks')
with self.assertRaisesRegexp(ValueError, 'Invalid slot'):
schedule._parse_slot('/1d', 'months')
with self.assertRaisesRegexp(ValueError, 'Invalid slot'):
schedule._parse_slot('10', 'hours')
with self.assertRaisesRegexp(ValueError, 'Invalid slot'):
schedule._parse_slot('H1', 'years')
# Test failures with unknown values
with self.assertRaisesRegexp(ValueError, 'Unknown month'):
schedule._parse_slot('februarium-1', 'years')
with self.assertRaisesRegexp(ValueError, 'Unknown day of the week'):
schedule._parse_slot('snu', 'weeks')
with self.assertRaisesRegexp(ValueError, 'Unknown unit'):
schedule._parse_slot('+1t', 'hours')
def test_schedule(self):
# A few more examples. The ones in doctest strings are those that help documentation; the rest
# are in this file to keep the size of the main file more manageable.
# Note that the start of 2018-01-01 is a Monday
self.assertEqual(list(schedule.SCHEDULE(
"1-week: +1d 9:30am, +4d 3:30pm", start=datetime(2018,1,1), end=datetime(2018,1,31))),
[
DT("2018-01-01 09:30:00"), DT("2018-01-04 15:30:00"),
DT("2018-01-08 09:30:00"), DT("2018-01-11 15:30:00"),
DT("2018-01-15 09:30:00"), DT("2018-01-18 15:30:00"),
DT("2018-01-22 09:30:00"), DT("2018-01-25 15:30:00"),
DT("2018-01-29 09:30:00"),
])
self.assertEqual(list(schedule.SCHEDULE(
"3-month: +0d 12pm", start=datetime(2018,1,1), end=datetime(2018,6,30))),
[DT('2018-01-01 12:00:00'), DT('2018-04-01 12:00:00')])
# Ensure we can use date() object for start/end too.
self.assertEqual(list(schedule.SCHEDULE(
"3-month: +0d 12pm", start=date(2018,1,1), end=date(2018,6,30))),
[DT('2018-01-01 12:00:00'), DT('2018-04-01 12:00:00')])
# We can even use strings.
self.assertEqual(list(schedule.SCHEDULE(
"3-month: +0d 12pm", start="2018-01-01", end="2018-06-30")),
[DT('2018-01-01 12:00:00'), DT('2018-04-01 12:00:00')])
def test_timezone(self):
# Verify that the time zone of `start` determines the time zone of generated times.
tz_ny = moment.tzinfo("America/New_York")
self.assertEqual([d.isoformat(' ') for d in schedule.SCHEDULE(
"daily: 9am", count=4, start=datetime(2018, 2, 14, tzinfo=tz_ny))],
[ '2018-02-14 09:00:00-05:00', '2018-02-15 09:00:00-05:00',
'2018-02-16 09:00:00-05:00', '2018-02-17 09:00:00-05:00' ])
tz_la = moment.tzinfo("America/Los_Angeles")
self.assertEqual([d.isoformat(' ') for d in schedule.SCHEDULE(
"daily: 9am, 4:30pm", count=4, start=datetime(2018, 2, 14, 9, 0, tzinfo=tz_la))],
[ '2018-02-14 09:00:00-08:00', '2018-02-14 16:30:00-08:00',
'2018-02-15 09:00:00-08:00', '2018-02-15 16:30:00-08:00' ])
tz_utc = moment.tzinfo("UTC")
self.assertEqual([d.isoformat(' ') for d in schedule.SCHEDULE(
"daily: 9am, 4:30pm", count=4, start=datetime(2018, 2, 14, 17, 0, tzinfo=tz_utc))],
[ '2018-02-15 09:00:00+00:00', '2018-02-15 16:30:00+00:00',
'2018-02-16 09:00:00+00:00', '2018-02-16 16:30:00+00:00' ])
# This is not really a test but just a way to see some timing information about Schedule
# implementation. Run with env PY_TIMING_TESTS=1 in the environment, and the console output will
# include the measured times.
@unittest.skipUnless(os.getenv("PY_TIMING_TESTS") == "1", "Set PY_TIMING_TESTS=1 for timing")
def test_timing(self):
N = 1000
sched = "weekly: Mo 10:30am, We 10:30am"
setup = """
from functions import schedule
from datetime import datetime
"""
setup = "from functions import test_schedule as t"
expected_result = [
datetime(2018, 9, 24, 10, 30), datetime(2018, 9, 26, 22, 30),
datetime(2018, 10, 1, 10, 30), datetime(2018, 10, 3, 22, 30),
]
self.assertEqual(timing_schedule_full(), expected_result)
t = min(timeit.repeat(stmt="t.timing_schedule_full()", setup=setup, number=N, repeat=3))
print "\n*** SCHEDULE call with 4 points: %.2f us" % (t * 1000000 / N)
t = min(timeit.repeat(stmt="t.timing_schedule_init()", setup=setup, number=N, repeat=3))
print "*** Schedule constructor: %.2f us" % (t * 1000000 / N)
self.assertEqual(timing_schedule_series(), expected_result)
t = min(timeit.repeat(stmt="t.timing_schedule_series()", setup=setup, number=N, repeat=3))
print "*** Schedule series with 4 points: %.2f us" % (t * 1000000 / N)
def timing_schedule_full():
return list(schedule.SCHEDULE("weekly: Mo 10:30am, We 10:30pm",
start=datetime(2018, 9, 23), count=4))
def timing_schedule_init():
return schedule.Schedule("weekly: Mo 10:30am, We 10:30pm")
def timing_schedule_series(sched=schedule.Schedule("weekly: Mo 10:30am, We 10:30pm")):
return list(sched.series(datetime(2018, 9, 23), None, count=4))

@ -0,0 +1,590 @@
# -*- coding: UTF-8 -*-
import datetime
import dateutil.parser
import numbers
import re
from usertypes import AltText # pylint: disable=import-error
def CHAR(table_number):
"""
Convert a number into a character according to the current Unicode table.
Same as `unichr(number)`.
>>> CHAR(65)
u'A'
>>> CHAR(33)
u'!'
"""
return unichr(table_number)
# See http://stackoverflow.com/a/93029/328565
_control_chars = ''.join(map(unichr, range(0,32) + range(127,160)))
_control_char_re = re.compile('[%s]' % re.escape(_control_chars))
def CLEAN(text):
"""
Returns the text with the non-printable characters removed.
This removes both characters with values 0 through 31, and other Unicode characters in the
"control characters" category.
>>> CLEAN(CHAR(9) + "Monthly report" + CHAR(10))
u'Monthly report'
"""
return _control_char_re.sub('', text)
def CODE(string):
"""
Returns the numeric Unicode map value of the first character in the string provided.
Same as `ord(string[0])`.
>>> CODE("A")
65
>>> CODE("!")
33
>>> CODE("!A")
33
"""
return ord(string[0])
def CONCATENATE(string, *more_strings):
"""
Joins together any number of text strings into one string. Also available under the name
`CONCAT`. Same as the Python expression `"".join(array_of_strings)`.
>>> CONCATENATE("Stream population for ", "trout", " ", "species", " is ", 32, "/mile.")
u'Stream population for trout species is 32/mile.'
>>> CONCATENATE("In ", 4, " days it is ", datetime.date(2016,1,1))
u'In 4 days it is 2016-01-01'
>>> CONCATENATE("abc")
u'abc'
>>> CONCAT(0, "abc")
u'0abc'
"""
return u''.join(unicode(val) for val in (string,) + more_strings)
CONCAT = CONCATENATE
def DOLLAR(number, decimals=2):
"""
Formats a number into a formatted dollar amount, with decimals rounded to the specified place (.
If decimals value is omitted, it defaults to 2.
>>> DOLLAR(1234.567)
'$1,234.57'
>>> DOLLAR(1234.567, -2)
'$1,200'
>>> DOLLAR(-1234.567, -2)
'($1,200)'
>>> DOLLAR(-0.123, 4)
'($0.1230)'
>>> DOLLAR(99.888)
'$99.89'
>>> DOLLAR(0)
'$0.00'
>>> DOLLAR(10, 0)
'$10'
"""
formatted = "${:,.{}f}".format(round(abs(number), decimals), max(0, decimals))
return formatted if number >= 0 else "(" + formatted + ")"
def EXACT(string1, string2):
"""
Tests whether two strings are identical. Same as `string2 == string2`.
>>> EXACT("word", "word")
True
>>> EXACT("Word", "word")
False
>>> EXACT("w ord", "word")
False
"""
return string1 == string2
def FIND(find_text, within_text, start_num=1):
"""
Returns the position at which a string is first found within text.
Find is case-sensitive. The returned position is 1 if within_text starts with find_text.
Start_num specifies the character at which to start the search, defaulting to 1 (the first
character of within_text).
If find_text is not found, or start_num is invalid, raises ValueError.
>>> FIND("M", "Miriam McGovern")
1
>>> FIND("m", "Miriam McGovern")
6
>>> FIND("M", "Miriam McGovern", 3)
8
>>> FIND(" #", "Hello world # Test")
12
>>> FIND("gle", "Google", 1)
4
>>> FIND("GLE", "Google", 1)
Traceback (most recent call last):
...
ValueError: substring not found
>>> FIND("page", "homepage")
5
>>> FIND("page", "homepage", 6)
Traceback (most recent call last):
...
ValueError: substring not found
"""
return within_text.index(find_text, start_num - 1) + 1
def FIXED(number, decimals=2, no_commas=False):
"""
Formats a number with a fixed number of decimal places (2 by default), and commas.
If no_commas is True, then omits the commas.
>>> FIXED(1234.567, 1)
'1,234.6'
>>> FIXED(1234.567, -1)
'1,230'
>>> FIXED(-1234.567, -1, True)
'-1230'
>>> FIXED(44.332)
'44.33'
>>> FIXED(3521.478, 2, False)
'3,521.48'
>>> FIXED(-3521.478, 1, True)
'-3521.5'
>>> FIXED(3521.478, 0, True)
'3521'
>>> FIXED(3521.478, -2, True)
'3500'
"""
comma_flag = '' if no_commas else ','
return "{:{}.{}f}".format(round(number, decimals), comma_flag, max(0, decimals))
def LEFT(string, num_chars=1):
"""
Returns a substring of length num_chars from the beginning of the given string. If num_chars is
omitted, it is assumed to be 1. Same as `string[:num_chars]`.
>>> LEFT("Sale Price", 4)
'Sale'
>>> LEFT('Swededn')
'S'
>>> LEFT('Text', -1)
Traceback (most recent call last):
...
ValueError: num_chars invalid
"""
if num_chars < 0:
raise ValueError("num_chars invalid")
return string[:num_chars]
def LEN(text):
"""
Returns the number of characters in a text string. Same as `len(text)`.
>>> LEN("Phoenix, AZ")
11
>>> LEN("")
0
>>> LEN(" One ")
11
"""
return len(text)
def LOWER(text):
"""
Converts a specified string to lowercase. Same as `text.lower()`.
>>> LOWER("E. E. Cummings")
'e. e. cummings'
>>> LOWER("Apt. 2B")
'apt. 2b'
"""
return text.lower()
def MID(text, start_num, num_chars):
"""
Returns a segment of a string, starting at start_num. The first character in text has
start_num 1.
>>> MID("Fluid Flow", 1, 5)
'Fluid'
>>> MID("Fluid Flow", 7, 20)
'Flow'
>>> MID("Fluid Flow", 20, 5)
''
>>> MID("Fluid Flow", 0, 5)
Traceback (most recent call last):
...
ValueError: start_num invalid
"""
if start_num < 1:
raise ValueError("start_num invalid")
return text[start_num - 1 : start_num - 1 + num_chars]
def PROPER(text):
"""
Capitalizes each word in a specified string. It converts the first letter of each word to
uppercase, and all other letters to lowercase. Same as `text.title()`.
>>> PROPER('this is a TITLE')
'This Is A Title'
>>> PROPER('2-way street')
'2-Way Street'
>>> PROPER('76BudGet')
'76Budget'
"""
return text.title()
def REGEXEXTRACT(text, regular_expression):
"""
Extracts the first part of text that matches regular_expression.
>>> REGEXEXTRACT("Google Doc 101", "[0-9]+")
'101'
>>> REGEXEXTRACT("The price today is $826.25", "[0-9]*\\.[0-9]+[0-9]+")
'826.25'
If there is a parenthesized expression, it is returned instead of the whole match.
>>> REGEXEXTRACT("(Content) between brackets", "\\(([A-Za-z]+)\\)")
'Content'
>>> REGEXEXTRACT("Foo", "Bar")
Traceback (most recent call last):
...
ValueError: REGEXEXTRACT text does not match
"""
m = re.search(regular_expression, text)
if not m:
raise ValueError("REGEXEXTRACT text does not match")
return m.group(1) if m.lastindex else m.group(0)
def REGEXMATCH(text, regular_expression):
"""
Returns whether a piece of text matches a regular expression.
>>> REGEXMATCH("Google Doc 101", "[0-9]+")
True
>>> REGEXMATCH("Google Doc", "[0-9]+")
False
>>> REGEXMATCH("The price today is $826.25", "[0-9]*\\.[0-9]+[0-9]+")
True
>>> REGEXMATCH("(Content) between brackets", "\\(([A-Za-z]+)\\)")
True
>>> REGEXMATCH("Foo", "Bar")
False
"""
return bool(re.search(regular_expression, text))
def REGEXREPLACE(text, regular_expression, replacement):
"""
Replaces all parts of text matching the given regular expression with replacement text.
>>> REGEXREPLACE("Google Doc 101", "[0-9]+", "777")
'Google Doc 777'
>>> REGEXREPLACE("Google Doc", "[0-9]+", "777")
'Google Doc'
>>> REGEXREPLACE("The price is $826.25", "[0-9]*\\.[0-9]+[0-9]+", "315.75")
'The price is $315.75'
>>> REGEXREPLACE("(Content) between brackets", "\\(([A-Za-z]+)\\)", "Word")
'Word between brackets'
>>> REGEXREPLACE("Foo", "Bar", "Baz")
'Foo'
"""
return re.sub(regular_expression, replacement, text)
def REPLACE(old_text, start_num, num_chars, new_text):
"""
Replaces part of a text string with a different text string. Start_num is counted from 1.
>>> REPLACE("abcdefghijk", 6, 5, "*")
'abcde*k'
>>> REPLACE("2009", 3, 2, "10")
'2010'
>>> REPLACE('123456', 1, 3, '@')
'@456'
>>> REPLACE('foo', 1, 0, 'bar')
'barfoo'
>>> REPLACE('foo', 0, 1, 'bar')
Traceback (most recent call last):
...
ValueError: start_num invalid
"""
if start_num < 1:
raise ValueError("start_num invalid")
return old_text[:start_num - 1] + new_text + old_text[start_num - 1 + num_chars:]
def REPT(text, number_times):
"""
Returns specified text repeated a number of times. Same as `text * number_times`.
The result of the REPT function cannot be longer than 32767 characters, or it raises a
ValueError.
>>> REPT("*-", 3)
'*-*-*-'
>>> REPT('-', 10)
'----------'
>>> REPT('-', 0)
''
>>> len(REPT('---', 10000))
30000
>>> REPT('---', 11000)
Traceback (most recent call last):
...
ValueError: number_times invalid
>>> REPT('-', -1)
Traceback (most recent call last):
...
ValueError: number_times invalid
"""
if number_times < 0 or len(text) * number_times > 32767:
raise ValueError("number_times invalid")
return text * int(number_times)
def RIGHT(string, num_chars=1):
"""
Returns a substring of length num_chars from the end of a specified string. If num_chars is
omitted, it is assumed to be 1. Same as `string[-num_chars:]`.
>>> RIGHT("Sale Price", 5)
'Price'
>>> RIGHT('Stock Number')
'r'
>>> RIGHT('Text', 100)
'Text'
>>> RIGHT('Text', -1)
Traceback (most recent call last):
...
ValueError: num_chars invalid
"""
if num_chars < 0:
raise ValueError("num_chars invalid")
return string[-num_chars:]
def SEARCH(find_text, within_text, start_num=1):
"""
Returns the position at which a string is first found within text, ignoring case.
Find is case-sensitive. The returned position is 1 if within_text starts with find_text.
Start_num specifies the character at which to start the search, defaulting to 1 (the first
character of within_text).
If find_text is not found, or start_num is invalid, raises ValueError.
>>> SEARCH("e", "Statements", 6)
7
>>> SEARCH("margin", "Profit Margin")
8
>>> SEARCH(" ", "Profit Margin")
7
>>> SEARCH('"', 'The "boss" is here.')
5
>>> SEARCH("gle", "Google")
4
>>> SEARCH("GLE", "Google")
4
"""
# .lower() isn't always correct for unicode. See http://stackoverflow.com/a/29247821/328565
return within_text.lower().index(find_text.lower(), start_num - 1) + 1
def SUBSTITUTE(text, old_text, new_text, instance_num=None):
u"""
Replaces existing text with new text in a string. It is useful when you know the substring of
text to replace. Use REPLACE when you know the position of text to replace.
If instance_num is given, it specifies which occurrence of old_text to replace. If omitted, all
occurrences are replaced.
Same as `text.replace(old_text, new_text)` when instance_num is omitted.
>>> SUBSTITUTE("Sales Data", "Sales", "Cost")
'Cost Data'
>>> SUBSTITUTE("Quarter 1, 2008", "1", "2", 1)
'Quarter 2, 2008'
>>> SUBSTITUTE("Quarter 1, 2011", "1", "2", 3)
'Quarter 1, 2012'
More tests:
>>> SUBSTITUTE("Hello world", "", "-")
'Hello world'
>>> SUBSTITUTE("Hello world", " ", "-")
'Hello-world'
>>> SUBSTITUTE("Hello world", " ", 12.1)
'Hello12.1world'
>>> SUBSTITUTE(u"Hello world", u" ", 12.1)
u'Hello12.1world'
>>> SUBSTITUTE("Hello world", "world", "")
'Hello '
>>> SUBSTITUTE("Hello", "world", "")
'Hello'
Overlapping matches are all counted when looking for instance_num.
>>> SUBSTITUTE('abababab', 'abab', 'xxxx')
'xxxxxxxx'
>>> SUBSTITUTE('abababab', 'abab', 'xxxx', 1)
'xxxxabab'
>>> SUBSTITUTE('abababab', 'abab', 'xxxx', 2)
'abxxxxab'
>>> SUBSTITUTE('abababab', 'abab', 'xxxx', 3)
'ababxxxx'
>>> SUBSTITUTE('abababab', 'abab', 'xxxx', 4)
'abababab'
>>> SUBSTITUTE('abababab', 'abab', 'xxxx', 0)
Traceback (most recent call last):
...
ValueError: instance_num invalid
"""
if not old_text:
return text
if not isinstance(new_text, basestring):
new_text = str(new_text)
if instance_num is None:
return text.replace(old_text, new_text)
if instance_num <= 0:
raise ValueError("instance_num invalid")
# No trivial way to replace nth occurrence.
i = -1
for c in xrange(instance_num):
i = text.find(old_text, i + 1)
if i < 0:
return text
return text[:i] + new_text + text[i + len(old_text):]
def T(value):
"""
Returns value if value is text, or the empty string when value is not text.
>>> T('Text')
'Text'
>>> T(826)
''
>>> T('826')
'826'
>>> T(False)
''
>>> T('100 points')
'100 points'
>>> T(AltText('Text'))
'Text'
>>> T(float('nan'))
''
"""
return (value if isinstance(value, basestring) else
str(value) if isinstance(value, AltText) else "")
def TEXT(number, format_type):
"""
Converts a number into text according to a specified format. It is not yet implemented in
Grist.
"""
raise NotImplementedError()
_trim_re = re.compile(r' +')
def TRIM(text):
"""
Removes all spaces from text except for single spaces between words. Note that TRIM does not
remove other whitespace such as tab or newline characters.
>>> TRIM(" First Quarter\\n Earnings ")
'First Quarter\\n Earnings'
>>> TRIM("")
''
"""
return _trim_re.sub(' ', text.strip())
def UPPER(text):
"""
Converts a specified string to uppercase. Same as `text.lower()`.
>>> UPPER("e. e. cummings")
'E. E. CUMMINGS'
>>> UPPER("Apt. 2B")
'APT. 2B'
"""
return text.upper()
def VALUE(text):
"""
Converts a string in accepted date, time or number formats into a number or date.
>>> VALUE("$1,000")
1000
>>> VALUE("16:48:00") - VALUE("12:00:00")
datetime.timedelta(0, 17280)
>>> VALUE("01/01/2012")
datetime.datetime(2012, 1, 1, 0, 0)
>>> VALUE("")
0
>>> VALUE(0)
0
>>> VALUE("826")
826
>>> VALUE("-826.123123123")
-826.123123123
>>> VALUE(float('nan'))
nan
>>> VALUE("Invalid")
Traceback (most recent call last):
...
ValueError: text cannot be parsed to a number
>>> VALUE("13/13/13")
Traceback (most recent call last):
...
ValueError: text cannot be parsed to a number
"""
# This is not particularly robust, but makes an attempt to handle a number of cases: numbers,
# including optional comma separators, dates/times, leading dollar-sign.
if isinstance(text, (numbers.Number, datetime.date)):
return text
text = text.strip().lstrip('$')
nocommas = text.replace(',', '')
if nocommas == "":
return 0
try:
return int(nocommas)
except ValueError:
pass
try:
return float(nocommas)
except ValueError:
pass
try:
return dateutil.parser.parse(text)
except ValueError:
pass
raise ValueError('text cannot be parsed to a number')

@ -0,0 +1,189 @@
"""
gencode.py is the module that generates a python module based on the schema in a grist document.
An example of the module it generates is available in usercode.py.
The schema for grist data is:
<schema> = [ <table_info> ]
<table_info> = {
"tableId": <string>,
"columns": [ <column_info> ],
}
<column_info> = {
"id": <string>,
"type": <string>
"isFormula": <boolean>,
"formula": <opt_string>,
}
"""
import re
import imp
from collections import OrderedDict
import codebuilder
import summary
import table
import textbuilder
from usertypes import get_type_default
import logger
log = logger.Logger(__name__, logger.INFO)
indent_str = " "
# Matches newlines that are followed by a non-empty line.
indent_line_re = re.compile(r'^(?=.*\S)', re.M)
def indent(body, levels=1):
"""Indents all lines in body (which should be a textbuilder.Builder), except empty ones."""
patches = textbuilder.make_regexp_patches(body.get_text(), indent_line_re, indent_str * levels)
return textbuilder.Replacer(body, patches)
#----------------------------------------------------------------------
def get_grist_type(col_type):
"""Returns code for a grist usertype object given a column type string."""
col_type_split = col_type.split(':', 1)
typename = col_type_split[0]
if typename == 'Ref':
typename = 'Reference'
elif typename == 'RefList':
typename = 'ReferenceList'
arg = col_type_split[1] if len(col_type_split) > 1 else ''
arg = arg.strip().replace("'", "\\'")
return "grist.%s(%s)" % (typename, ("'%s'" % arg) if arg else '')
class GenCode(object):
"""
GenCode generates the Python code for a Grist document, including converting formulas to Python
functions and producing a Python specification of all the tables with data and formula fields.
To save the costly work of generating formula code, it maintains a formula cache. It is a
dictionary mapping (table_id, col_id, formula) to a textbuilder.Builder. On each run of
make_module(), it will use the previously cached values for lookups, and replace the contents
of the cache with current values. If ever we need to generate code for unrelated schemas, to
benefit from the cache, a separate GenCode object should be used for each schema.
"""
def __init__(self):
self._formula_cache = {}
self._new_formula_cache = {}
self._full_builder = None
self._user_builder = None
self._usercode = None
def _make_formula_field(self, col_info, table_id, name=None, include_type=True):
"""Returns the code for a formula field."""
# If the caller didn't specify a special name, use the colId
name = name or col_info.colId
decl = "def %s(rec, table):\n" % name
# This is where we get to use the formula cache, and save the work of rebuilding formulas.
key = (table_id, col_info.colId, col_info.formula)
body = self._formula_cache.get(key)
if body is None:
default = get_type_default(col_info.type)
body = codebuilder.make_formula_body(col_info.formula, default, (table_id, col_info.colId))
self._new_formula_cache[key] = body
decorator = ''
if include_type and col_info.type != 'Any':
decorator = '@grist.formulaType(%s)\n' % get_grist_type(col_info.type)
return textbuilder.Combiner(['\n' + decorator + decl, indent(body), '\n'])
def _make_data_field(self, col_info, table_id):
"""Returns the code for a data field."""
parts = []
if col_info.formula:
parts.append(self._make_formula_field(col_info, table_id,
name=table.get_default_func_name(col_info.colId),
include_type=False))
parts.append("%s = %s\n" % (col_info.colId, get_grist_type(col_info.type)))
return textbuilder.Combiner(parts)
def _make_field(self, col_info, table_id):
"""Returns the code for a field."""
assert not col_info.colId.startswith("_")
if col_info.isFormula:
return self._make_formula_field(col_info, table_id)
else:
return self._make_data_field(col_info, table_id)
def _make_table_model(self, table_info, summary_tables):
"""Returns the code for a table model."""
table_id = table_info.tableId
source_table_id = summary.decode_summary_table_name(table_id)
# Sort columns by "isFormula" to output all data columns before all formula columns.
columns = sorted(table_info.columns.itervalues(), key=lambda c: c.isFormula)
parts = ["@grist.UserTable\nclass %s:\n" % table_id]
if source_table_id:
parts.append(indent(textbuilder.Text("_summarySourceTable = %r\n" % source_table_id)))
for col_info in columns:
parts.append(indent(self._make_field(col_info, table_id)))
if summary_tables:
# Include summary formulas, for the user's information.
formulas = OrderedDict((c.colId, c) for s in summary_tables
for c in s.columns.itervalues() if c.isFormula)
parts.append(indent(textbuilder.Text("\nclass _Summary:\n")))
for col_info in formulas.itervalues():
parts.append(indent(self._make_field(col_info, table_id), levels=2))
return textbuilder.Combiner(parts)
def make_module(self, schema):
"""Regenerates the code text and usercode module from upated document schema."""
# Collect summary tables to group them by source table.
summary_tables = {}
for table_info in schema.itervalues():
source_table_id = summary.decode_summary_table_name(table_info.tableId)
if source_table_id:
summary_tables.setdefault(source_table_id, []).append(table_info)
fullparts = ["import grist\n" +
"from functions import * # global uppercase functions\n" +
"import datetime, math, re # modules commonly needed in formulas\n"]
userparts = fullparts[:]
for table_info in schema.itervalues():
table_model = self._make_table_model(table_info, summary_tables.get(table_info.tableId))
fullparts.append("\n\n")
fullparts.append(table_model)
if not _is_special_table(table_info.tableId):
userparts.append("\n\n")
userparts.append(table_model)
# Once all formulas are generated, replace the formula cache with the newly-populated version.
self._formula_cache = self._new_formula_cache
self._new_formula_cache = {}
self._full_builder = textbuilder.Combiner(fullparts)
self._user_builder = textbuilder.Combiner(userparts)
self._usercode = exec_module_text(self._full_builder.get_text())
def get_user_text(self):
"""Returns the text of the user-facing part of the generated code."""
return self._user_builder.get_text()
@property
def usercode(self):
"""Returns the generated usercode module."""
return self._usercode
def grist_names(self):
return codebuilder.parse_grist_names(self._full_builder)
def _is_special_table(table_id):
return table_id.startswith("_grist_") or bool(summary.decode_summary_table_name(table_id))
def exec_module_text(module_text):
# pylint: disable=exec-used
mod = imp.new_module("usercode")
exec module_text in mod.__dict__
return mod

@ -0,0 +1,141 @@
def _is_array(obj):
return isinstance(obj, list)
def get(obj, path):
"""
Looks up and returns a path in the object. Returns None if the path isn't there.
"""
for part in path:
try:
obj = obj[part]
except(KeyError, IndexError):
return None
return obj
def glob(obj, path, func, extra_arg):
"""
Resolves wildcards in `path`, calling func for all matching paths. Returns the number of
times that func was called.
obj - An object to scan.
path - Path to an item in an object or an array in obj. May contain the special key '*', which
-- for arrays only -- means "for all indices".
func - Will be called as func(subobj, key, fullPath, extraArg).
extra_arg - An arbitrary value to pass along to func, for convenience.
Returns count of matching paths, for which func got called.
"""
return _globHelper(obj, path, path, func, extra_arg)
def _globHelper(obj, path, full_path, func, extra_arg):
for i, part in enumerate(path[:-1]):
if part == "*" and _is_array(obj):
# We got an array wildcard
subpath = path[i + 1:]
count = 0
for subobj in obj:
count += _globHelper(subobj, subpath, full_path, func, extra_arg)
return count
try:
obj = obj[part]
except:
raise Exception("gpath.glob: non-existent object at " +
describe(full_path[:len(full_path) - len(path) + i + 1]))
return func(obj, path[-1], full_path, extra_arg) or 1
def place(obj, path, value):
"""
Sets or deletes an object property in DocObj.
gpath - Path to an Object in obj.
value - Any value. Setting None will remove the selected object key.
"""
return glob(obj, path, _placeHelper, value)
def _placeHelper(subobj, key, full_path, value):
if not isinstance(subobj, dict):
raise Exception("gpath.place: not a plain object at " + describe(dirname(full_path)))
if value is not None:
subobj[key] = value
elif key in subobj:
del subobj[key]
def _checkIsArray(subobj, errPrefix, index, itemPath, isInsert):
"""
This is a helper for checking operations on arrays, and throwing descriptive errors.
"""
if subobj is None:
raise Exception(errPrefix + ": non-existent object at " + describe(dirname(itemPath)))
elif not _is_array(subobj):
raise Exception(errPrefix + ": not an array at " + describe(dirname(itemPath)))
else:
length = len(subobj)
validIndex = (isinstance(index, int) and index >= 0 and index < length)
validInsertIndex = (index is None or index == length)
if not (validIndex or (isInsert and validInsertIndex)):
raise Exception(errPrefix + ": invalid array index: " + describe(itemPath))
def insert(obj, path, value):
"""
Inserts an element into an array in DocObj.
gpath - Path to an item in an array in obj.
The new value will be inserted before the item pointed to by gpath.
The last component of gpath may be null, in which case the value is appended at the end.
value - Any value.
"""
return glob(obj, path, _insertHelper, value)
def _insertHelper(subobj, index, fullPath, value):
_checkIsArray(subobj, "gpath.insert", index, fullPath, True)
if index is None:
subobj.append(value)
else:
subobj.insert(index, value)
def update(obj, path, value):
"""
Updates an element in an array in DocObj.
gpath - Path to an item in an array in obj.
value - Any value.
"""
return glob(obj, path, _updateHelper, value)
def _updateHelper(subobj, index, fullPath, value):
if index == '*':
_checkIsArray(subobj, "gpath.update", None, fullPath, True)
for i in xrange(len(subobj)):
subobj[i] = value
return len(subobj)
else:
_checkIsArray(subobj, "gpath.update", index, fullPath, False)
subobj[index] = value
def remove(obj, path):
"""
Removes an element from an array in DocObj.
gpath - Path to an item in an array in obj.
"""
return glob(obj, path, _removeHelper, None)
def _removeHelper(subobj, index, fullPath, _):
_checkIsArray(subobj, "gpath.remove", index, fullPath, False)
del subobj[index]
def dirname(path):
"""
Returns path without the last component, like a directory name in a filesystem path.
"""
return path[:-1]
def basename(path):
"""
Returns the last component of path, like base name of a filesystem path.
"""
return path[-1] if path else None
def describe(path):
"""
Returns a human-readable representation of path.
"""
return "/" + "/".join(str(p) for p in path)

@ -0,0 +1,15 @@
"""
This file packages together other modules needed for usercode in order to create
a consistent API accessible with only "import grist".
"""
# pylint: disable=unused-import
# These imports are used in processing generated usercode.
from usertypes import Any, Text, Blob, Int, Bool, Date, DateTime, \
Numeric, Choice, Id, Attachments, AltText, ifError
from usertypes import PositionNumber, ManualSortPos, Reference, ReferenceList, formulaType
from table import UserTable
from records import Record, RecordSet
DOCS = [(__name__, (Record, RecordSet, UserTable)),
('lookup', (UserTable.lookupOne, UserTable.lookupRecords))]

@ -0,0 +1,117 @@
"""
A module for creating and sanitizing identifiers
"""
import re
from string import ascii_uppercase
import itertools
from keyword import iskeyword
import logger
log = logger.Logger(__name__, logger.INFO)
_invalid_ident_char_re = re.compile(r'[^a-zA-Z0-9_]+')
_invalid_ident_start_re = re.compile(r'^(?=[0-9_])')
def _sanitize_ident(ident, prefix="c", capitalize=False):
"""
Helper for pick_ident, which given a suggested identifier, massages it to ensure it's valid for
python (and sqlite). In particular, leaves only alphanumeric characters, and prepends `prefix`
if it doesn't start with a letter.
Returns empty string if there are no valid identifier characters, so consider using as
(_sanitize_ident(...) or "your_default").
"""
ident = "" if ident is None else str(ident)
ident = _invalid_ident_char_re.sub('_', ident).lstrip('_')
ident = _invalid_ident_start_re.sub(prefix, ident)
if not ident:
return ident
if capitalize:
# Just capitalize the first letter (do NOT lowercase other letters like str.title() does).
ident = ident[0].capitalize() + ident[1:]
# Prevent names that are illegal to assign to
# iskeyword doesn't catch None/True/False in Python 2.x, but does in 3.x
# (None is actually an error, Python 2.x doesn't make assigning to True or False an error,
# but I think we don't want to allow users to do that)
while iskeyword(ident) or ident in ['None', 'True', 'False']:
ident = prefix + ident
return ident
_ends_in_digit_re = re.compile(r'\d$')
def _add_suffix(ident_base, avoid=set(), next_suffix=1):
"""
Helper which appends a numerical suffix to ident_base, incrementing it until the result doesn't
conflict with anything in the `avoid` set.
"""
if _ends_in_digit_re.search(ident_base):
ident_base += "_"
while True:
ident = "%s%d" % (ident_base, next_suffix)
if ident.upper() not in avoid:
return ident
next_suffix += 1
def _maybe_add_suffix(ident, avoid):
"""
Returns the first of ident, ident2, ident3 etc. that's not in the `avoid` set.
"""
return ident if (ident.upper() not in avoid) else _add_suffix(ident, avoid, 2)
def _uppercase(avoid):
return {name.upper() for name in avoid}
def pick_table_ident(ident, avoid=set()):
"""
Given a suggested identifier (which may be None), creates a sanitized table identifier,
possibly with a numerical suffix that doesn't conflict with anything in the `avoid` set.
"""
avoid = _uppercase(avoid)
ident = _sanitize_ident(ident, prefix="T", capitalize=True)
return _maybe_add_suffix(ident, avoid) if ident else _add_suffix("Table", avoid, 1)
def pick_col_ident(ident, avoid=set()):
"""
Given a suggested identifier (which may be None), creates a sanitized column identifier,
possibly with a numerical suffix that doesn't conflict with anything in the `avoid` set.
"""
avoid = _uppercase(avoid)
ident = _sanitize_ident(ident, prefix="c")
return _maybe_add_suffix(ident, avoid) if ident else _gen_ident(avoid)
def pick_col_ident_list(ident_list, avoid=set()):
"""
Given a list of suggested identifiers (which may be invalid), returns a list of valid sanitized
unique identifiers, that don't conflict with anything in the `avoid` set or with each other.
"""
avoid = _uppercase(avoid)
result = []
for ident in ident_list:
ident = pick_col_ident(ident, avoid=avoid)
avoid.add(ident.upper())
result.append(ident)
return result
def _gen_ident(avoid):
"""
Helper for pick_ident, which generates a valid identifier
when pick_ident is called without a suggested identifier or default.
It returns the first identifier that does not conflict with any elements of the avoid set.
The identifier is a letter or combination of letters that follows a
similar pattern to what excel uses for naming columns.
i.e. A, B, ... Z, AA, AB, ... AZ, BA, etc
"""
avoid = _uppercase(avoid)
for letter in _make_letters():
if letter not in avoid:
return letter
def _make_letters():
length = 1
while True:
for letters in itertools.product(ascii_uppercase, repeat=length):
yield ''.join(letters)
length +=1

@ -0,0 +1,309 @@
from collections import namedtuple
import column
import identifiers
import logger
log = logger.Logger(__name__, logger.INFO)
# Prefix for transform columns created during imports.
_import_transform_col_prefix = 'gristHelper_Import_'
def _gen_colids(transform_rule):
"""
For a transform_rule with colIds = None,
fills in colIds generated from labels.
"""
dest_cols = transform_rule["destCols"]
if any(dc["colId"] for dc in dest_cols):
raise ValueError("transform_rule already has colIds in _gen_colids")
col_labels = [dest_col["label"] for dest_col in dest_cols]
col_ids = identifiers.pick_col_ident_list(col_labels, avoid={'id'})
for dest_col, col_id in zip(dest_cols, col_ids):
dest_col["colId"] = col_id
def _strip_prefixes(transform_rule):
"If transform_rule has prefixed _col_ids, strips prefix"
dest_cols = transform_rule["destCols"]
for dest_col in dest_cols:
colId = dest_col["colId"]
if colId and colId.startswith(_import_transform_col_prefix):
dest_col["colId"] = colId[len(_import_transform_col_prefix):]
class ImportActions(object):
def __init__(self, useractions, docmodel, engine):
self._useractions = useractions
self._docmodel = docmodel
self._engine = engine
########################
## NOTES
# transform_rule is an object like this: {
# destCols: [ { colId, label, type, formula }, ... ],
# ..., # other params unused in sandbox
# }
#
# colId is defined if into_new_table, otherwise is None
# GenImporterView gets a hidden table with a preview of the import data (~100 rows)
# It adds formula cols and viewsections to the hidden table for the user to
# preview and edit import options. GenImporterView can start with a default transform_rule
# from table columns, or use one that's passed in (for reimporting).
# client/components/Importer.ts then puts together transform_rule, which
# specifies destination column formulas, types, labels, and colIds. It only contains colIds
# if importing into an existing table, and they are sometimes prefixed with
# _import_transform_col_prefix (if transform_rule comes from client)
# TransformAndFinishImport gets the full hidden_table (reparsed) and a transform_rule,
# (or can use a default one if it's not provided). It fills in colIds if necessary and
# strips colId prefixes. It also skips creating some formula columns
# (ones with trivial copy formulas) as an optimization.
def _MakeDefaultTransformRule(self, hidden_table_id, dest_table_id):
"""
Makes a basic transform_rule.dest_cols copying all the source cols
hidden_table_id: table with src data
dest_table_id: table data is going to
If dst_table is null, copy all src columns
If dst_table exists, copy all dst columns, and make copy formulas if any names match
returns transform_rule with only destCols filled in
"""
tables = self._docmodel.tables
hidden_table_rec = tables.lookupOne(tableId=hidden_table_id)
# will use these to set default formulas (if column names match in src and dest table)
src_cols = {c.colId for c in hidden_table_rec.columns}
target_table = tables.lookupOne(tableId=dest_table_id) if dest_table_id else hidden_table_rec
target_cols = target_table.columns
# makes dest_cols for each column in target_cols (defaults to same columns as hidden_table)
#loop through visible, non-formula target columns
dest_cols = []
for c in target_cols:
if column.is_visible_column(c.colId) and (not c.isFormula or c.formula == ""):
dest_cols.append( {
"label": c.label,
"colId": c.colId if dest_table_id else None, #should be None if into new table
"type": c.type,
"formula": ("$" + c.colId) if (c.colId in src_cols) else ''
})
return {"destCols": dest_cols}
# doesnt generate other fields of transform_rule, but sandbox only used destCols
# Returns
def _MakeImportTransformColumns(self, hidden_table_id, transform_rule, gen_all):
"""
Makes prefixed columns in the grist hidden import table (hidden_table_id)
hidden_table_id: id of temporary hidden table in which columns are made
transform_rule: defines columns to make (colids must be filled in!)
gen_all: If true, all columns will be generated
If false, formulas that just copy will be skipped, and blank formulas will be skipped
returns list of newly created colrefs (rowids into _grist_Tables_column)
"""
tables = self._docmodel.tables
hidden_table_rec = tables.lookupOne(tableId=hidden_table_id)
src_cols = {c.colId for c in hidden_table_rec.columns}
log.debug("destCols:" + repr(transform_rule['destCols']))
#wrap dest_cols as namedtuples, to allow access like 'dest_col.param'
dest_cols = [namedtuple('col', c.keys())(*c.values()) for c in transform_rule['destCols']]
log.debug("_MakeImportTransformColumns: {}".format("gen_all" if gen_all else "optimize"))
#create prefixed formula column for each of dest_cols
#take formula from transform_rule
new_cols = []
for c in dest_cols:
# skip copy and blank columns (unless gen_all)
formula = c.formula.strip()
isCopyFormula = (formula.startswith("$") and formula[1:] in src_cols)
isBlankFormula = not formula
if gen_all or (not isCopyFormula and not isBlankFormula):
#if colId specified, use that. Else label is fine
new_col_id = _import_transform_col_prefix + (c.colId or c.label)
new_col_spec = {
"label": c.label,
"type": c.type,
"isFormula": True,
"formula": c.formula}
result = self._useractions.doAddColumn(hidden_table_id, new_col_id, new_col_spec)
new_cols.append(result["colRef"])
return new_cols
def DoGenImporterView(self, source_table_id, dest_table_id, transform_rule = None):
"""
Generates viewsections/formula columns for importer
source_table_id: id of temporary hidden table, data parsed from data source
dest_table_id: id of table to import to, or None for new table
transform_rule: transform_rule to reuse (if it still applies), if None will generate new one
Removes old transform viewSection and columns for source_table_id, and creates new ones that
match the destination table.
Returns the rowId of the newly added section or 0 if no source table (source_table_id
can be None in case of importing empty file).
Creates formula columns for transforms (match columns in dest table)
"""
tables = self._docmodel.tables
src_table_rec = tables.lookupOne(tableId=source_table_id)
# for new table, dest_table_id is None
dst_table_rec = tables.lookupOne(tableId=dest_table_id) if dest_table_id else src_table_rec
# ======== Cleanup old sections/columns
# Transform sections are created without a parent View, so we delete all such sections here.
old_sections = [s for s in src_table_rec.viewSections if not s.parentId]
self._docmodel.remove(old_sections)
# Transform columns are those that start with a special prefix.
old_cols = [c for c in src_table_rec.columns
if c.colId.startswith(_import_transform_col_prefix)]
self._docmodel.remove(old_cols)
#======== Prepare/normalize transform_rule, Create new formula columns
# Defaults to duplicating dest_table columns (or src_table columns for a new table)
# If transform_rule provided, use that
if transform_rule is None:
transform_rule = self._MakeDefaultTransformRule(source_table_id, dest_table_id)
else: #ensure prefixes, colIds are correct
_strip_prefixes(transform_rule)
if not dest_table_id: # into new table: 'colId's are undefined
_gen_colids(transform_rule)
else:
if None in (dc["colId"] for dc in transform_rule["destCols"]):
errstr = "colIds must be defined in transform_rule for importing into existing table: "
raise ValueError(errstr + repr(transform_rule))
new_cols = self._MakeImportTransformColumns(source_table_id, transform_rule, gen_all=True)
# we want to generate all columns so user can see them and edit
#========= Create new transform view section.
new_section = self._docmodel.add(self._docmodel.view_sections,
tableRef=src_table_rec.id,
parentKey='record',
borderWidth=1, defaultWidth=100,
sortColRefs='[]')[0]
self._docmodel.add(new_section.fields, colRef=new_cols)
return new_section.id
def DoTransformAndFinishImport(self, hidden_table_id, dest_table_id,
into_new_table, transform_rule):
"""
Finishes import into new or existing table depending on flag 'into_new_table'
Returns destination table id. (new or existing)
"""
hidden_table = self._engine.tables[hidden_table_id]
hidden_table_rec = self._docmodel.tables.lookupOne(tableId=hidden_table_id)
src_cols = {c.colId for c in hidden_table_rec.columns}
log.debug("Starting TransformAndFinishImport, dest_cols:\n "
+ str(transform_rule["destCols"] if transform_rule else "None"))
log.debug("hidden_table_id:" + hidden_table_id)
log.debug("hidden table columns: "
+ str([(a.colId, a.label, a.type) for a in hidden_table_rec.columns]))
log.debug("dest_table_id: "
+ str(dest_table_id) + ('(NEW)' if into_new_table else '(Existing)'))
# === fill in blank transform rule
if not transform_rule:
transform_dest = None if into_new_table else dest_table_id
transform_rule = self._MakeDefaultTransformRule(hidden_table_id, transform_dest)
dest_cols = transform_rule["destCols"]
# === Normalize transform rule (gen colids)
_strip_prefixes(transform_rule) #when transform_rule from client, colIds will be prefixed
if into_new_table: # 'colId's are undefined if making new table
_gen_colids(transform_rule)
else:
if None in (dc["colId"] for dc in dest_cols):
errstr = "colIds must be defined in transform_rule for importing into existing table: "
raise ValueError(errstr + repr(transform_rule))
log.debug("Normalized dest_cols:\n " + str(dest_cols))
# ======== Make and update formula columns
# Make columns from transform_rule (now with filled-in colIds colIds),
# gen_all false skips copy columns (faster)
new_cols = self._MakeImportTransformColumns(hidden_table_id, transform_rule, gen_all=False)
self._engine._bring_all_up_to_date()
# ========= Fetch Data for each col
# (either copying, blank, or from formula column)
row_ids = list(hidden_table.row_ids) #fetch row_ids now, before we remove hidden_table
log.debug("num rows: " + str(len(row_ids)))
column_data = {} # { col:[values...], ... }
for curr_col in dest_cols:
formula = curr_col["formula"].strip()
if formula:
if (formula.startswith("$") and formula[1:] in src_cols): #copy formula
src_col_id = formula[1:]
else: #normal formula, fetch from prefix column
src_col_id = _import_transform_col_prefix + curr_col["colId"]
log.debug("Copying from: " + src_col_id)
column_data[curr_col["colId"]] = map(hidden_table.get_column(src_col_id).raw_get, row_ids)
# ========= Cleanup, Prepare new table (if needed), insert data
self._useractions.RemoveTable(hidden_table_id)
if into_new_table:
col_specs = [ {'type': curr_col['type'], 'id': curr_col['colId'], 'label': curr_col['label']}
for curr_col in dest_cols]
log.debug("Making new table. Columns:\n " + str(col_specs))
new_table = self._useractions.AddTable(dest_table_id, col_specs)
dest_table_id = new_table['table_id']
self._useractions.BulkAddRecord(dest_table_id, [None] * len(row_ids), column_data)
log.debug("Finishing TransformAndFinishImport")
return dest_table_id

@ -0,0 +1,60 @@
"""This module loads a file_importer that implements the Grist import
API, and calls its selected method passing argument received from
PluginManager.sandboxImporter(). It returns an object formatted so
that it can be used by Grist.
"""
import sys
import argparse
import logging
import imp
import json
import marshal
log = logging.getLogger(__name__)
# Include /thirdparty into module search paths, in particular for messytables.
# pylint: disable=wrong-import-position
sys.path.append('/thirdparty')
def marshal_data(export_list):
return marshal.dumps(export_list, 2)
def main():
parser = argparse.ArgumentParser()
parser.add_argument('-d', '--debug', action='store_true',
help="Print debug instead of producing normal binary output")
parser.add_argument('-t', '--table',
help="Suggested table name to use with CSV imports")
parser.add_argument('-n', '--plugin-name', required=True,
help="Name of a python module implementing the import API.")
parser.add_argument('-p', '--plugin-path',
help="Location of the module.")
parser.add_argument('--action-options',
help="Options to pass to the action. See API documentation.")
parser.add_argument('action', help='Action to call',
choices=['can_parse', 'parse_file'])
parser.add_argument('input', help='File to convert')
args = parser.parse_args()
s = logging.StreamHandler()
s.setFormatter(logging.Formatter(fmt='%(asctime)s.%(msecs)03d %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'))
rootLogger = logging.getLogger()
rootLogger.addHandler(s)
rootLogger.setLevel(logging.DEBUG if args.debug else logging.INFO)
import_plugin = imp.load_compiled(
args.plugin_name,
args.plugin_path)
options = {}
if args.action_options:
options = json.loads(args.action_options)
parsed_data = getattr(import_plugin, args.action)(args.input, **options)
marshalled_data = marshal_data(parsed_data)
log.info("Marshalled data has %d bytes", len(marshalled_data))
if not args.debug:
sys.stdout.write(marshalled_data)
if __name__ == "__main__":
main()

@ -0,0 +1,22 @@
import unittest
import messytables
import os
class TestMessyTables(unittest.TestCase):
# Just a skeleton test
def test_any_tableset(self):
path = os.path.join(os.path.dirname(__file__),
"fixtures", "nyc_schools_progress_report_ec_2013.xlsx")
with open(path, "r") as f:
table_set = messytables.any.any_tableset(f, extension=os.path.splitext(path)[1])
self.assertIsInstance(table_set, messytables.XLSTableSet)
self.assertEqual([t.name for t in table_set.tables],
['Summary', 'Student Progress', 'Student Performance', 'School Environment',
'Closing the Achievement Gap', 'Middle School Course Metrics',
'All Information', 'Peer Groups'])
if __name__ == "__main__":
unittest.main()

@ -0,0 +1,74 @@
"""
Logging for code running in the sandbox. The output simply goes to stderr (which gets to the
console of the Node process), but the levels allow some configuration.
We don't use the `logging` module because it assumes more about the `time` module than we have.
Usage:
import logger
log = logger.Logger(__name__, logger.DEBUG) # Or logger.WARN; default is logger.INFO.
log.info("Hello world")
-> produces "[I] [foo.bar] Hello world"
"""
import sys
# Level definitions
DEBUG = 10
INFO = 20
WARN = 30
ERROR = 40
CRITICAL = 50
# Level strings
level_strings = {
DEBUG: 'DEBUG',
INFO: 'INFO',
WARN: 'WARN',
ERROR: 'ERROR',
CRITICAL: 'CRITICAL',
}
def log_stderr(level, name, msg):
sys.stderr.write("[%s] [%s] %s\n" % (level_strings.get(level, '?'), name, msg))
sys.stderr.flush()
_log_handler = log_stderr
def set_handler(log_handler):
"""
Allows overriding the handler for all log messages. The handler should be a function, called as
log_handler(level, name, message).
Returns the handler which was set previously.
"""
global _log_handler # pylint: disable=global-statement
prev = _log_handler
_log_handler = log_handler
return prev
class Logger(object):
"""
The object that actually provides the logging interface, specifically the methods debug, info,
warn, error, and critical. The constructor takes an argument for a name that gets included in
each message, and a minimum level, below which messages get ignored.
"""
def __init__(self, name, min_level=INFO):
self._name = name
self._min_level = min_level
def _log(self, level, msg):
if level >= self._min_level:
_log_handler(level, self._name, msg)
def debug(self, msg):
self._log(DEBUG, msg)
def info(self, msg):
self._log(INFO, msg)
def warn(self, msg):
self._log(WARN, msg)
def error(self, msg):
self._log(ERROR, msg)
def critical(self, msg):
self._log(CRITICAL, msg)

@ -0,0 +1,216 @@
import column
import depend
import records
import relation
import twowaymap
import usertypes
import logger
log = logger.Logger(__name__, logger.INFO)
def _extract(cell_value):
"""
When cell_value is a Record, returns its rowId. Otherwise returns the value unchanged.
This is to allow lookups to work with reference columns.
"""
if isinstance(cell_value, records.Record):
return cell_value._row_id
return cell_value
class LookupMapColumn(column.BaseColumn):
"""
Conceptually a LookupMapColumn is associated with a table ("target table") and maintains for
each row a key (which is a tuple of values from the named columns), which is fast to look up.
The lookup is generally performed in a formula in a different table ("referring table").
LookupMapColumn is similar to a FormulaColumn in that it needs to do some computation whenever
one of its dependencies changes: namely, it needs to update the index.
Although it acts as a column, a LookupMapColumn isn't included among its table's columns, and
doesn't have a column id.
Compared to relational database, LookupMapColumn is analogous to a database index.
"""
def __init__(self, table, col_id, col_ids_tuple):
# Note that self._recalc_rec_method is passed in as the formula's "method".
col_info = column.ColInfo(usertypes.Any(), is_formula=True, method=self._recalc_rec_method)
super(LookupMapColumn, self).__init__(table, col_id, col_info)
self._col_ids_tuple = col_ids_tuple
self._engine = table._engine
# Two-way map between rowIds of the target table (on the left) and key tuples (on the right).
# Multiple rows can map to the same key. The map is populated by engine's _recompute when this
# node is brought up-to-date.
self._row_key_map = twowaymap.TwoWayMap(left=set, right="single")
self._engine.invalidate_column(self)
# Map of referring Node to _LookupRelation. Different tables may do lookups using this
# LookupMapColumn, and that creates a dependency from other Nodes to us, with a relation
# between referring rows and the lookup keys. This map stores these relations.
self._lookup_relations = {}
def _recalc_rec_method(self, rec, table):
"""
LookupMapColumn acts as a formula column, and this method is the "formula" called whenever
a dependency changes. If LookupMapColumn indexes columns (A,B), then a change to A or B would
cause the LookupMapColumn to be invalidated for the corresponding rows, and brought up to date
during formula recomputation by calling this method. It shold take O(1) time per affected row.
"""
old_key = self._row_key_map.lookup_left(rec._row_id)
# Note that getattr(rec, col_id) is what creates the correct dependency, as well as ensures
# that the columns used to index by are brought up-to-date (in case they are formula columns).
new_key = tuple(_extract(rec._get_col(_col_id)) for _col_id in self._col_ids_tuple)
try:
self._row_key_map.insert(rec._row_id, new_key)
except TypeError:
# If key is not hashable, ignore it, just remove the old_key then.
self._row_key_map.remove(rec._row_id, old_key)
new_key = None
# It's OK if None is one of the values, since None will just never be found as a key.
self._invalidate_affected({old_key, new_key})
def unset(self, row_id):
# This is called on record removal, and is necessary to deal with removed records.
old_key = self._row_key_map.lookup_left(row_id)
self._row_key_map.remove(row_id, old_key)
self._invalidate_affected({old_key})
def _invalidate_affected(self, affected_keys):
# For each known relation, figure out which referring rows are affected, and invalidate them.
# The engine will notice that there have been more invalidations, and recompute things again.
for node, rel in self._lookup_relations.iteritems():
affected_rows = rel.get_affected_rows_by_keys(affected_keys)
self._engine.invalidate_records(node.table_id, affected_rows, col_ids=(node.col_id,))
def _get_relation(self, referring_node):
"""
Helper which returns an existing or new _LookupRelation object for the given referring Node.
"""
rel = self._lookup_relations.get(referring_node)
if not rel:
rel = _LookupRelation(self, referring_node)
self._lookup_relations[referring_node] = rel
return rel
def _delete_relation(self, referring_node):
self._lookup_relations.pop(referring_node, None)
if not self._lookup_relations:
self._engine.mark_lookupmap_for_cleanup(self)
def do_lookup(self, key):
"""
Looks up key in the lookup map and returns a tuple with two elements: the set of matching
records (as a set object, not ordered), and the Relation object for those records, relating
the current frame to the returned records. Returns an empty set if no records match.
"""
key = tuple(_extract(val) for val in key)
current_frame = self._engine.get_current_frame()
if current_frame:
rel = self._get_relation(current_frame.node)
rel._add_lookup(current_frame.current_row_id, key)
else:
rel = None
# The _use_node call both brings LookupMapColumn up-to-date, and creates a dependency on it.
# Relation of None isn't valid, but it happens to be unused when there is no current_frame.
row_ids = self._row_key_map.lookup_right(key, set())
self._engine._use_node(self.node, rel, row_ids)
if not row_ids:
row_ids = self._row_key_map.lookup_right(key, set())
return row_ids, rel
def _get_key(self, target_row_id):
"""
Helper used by _LookupRelation to get the key associated with the given target row id.
"""
return self._row_key_map.lookup_left(target_row_id)
# Override various column methods, since LookupMapColumn doesn't care to store any values. To
# outside code, it looks like a column of None's.
def raw_get(self, value):
return None
def convert(self, value):
return None
def get_cell_value(self, row_id):
return None
def set(self, row_id, value):
pass
#----------------------------------------------------------------------
class _LookupRelation(relation.Relation):
"""
_LookupRelation maintains a mapping between rows of a table doing a lookup to the rows getting
returned from the lookup. Lookups are implemented using a LookupMapColumn, and a _LookupRelation
with in conjunction with its LookupMapColumn.
_LookupRelation are created and owned by LookupMapColumn, and should not be created directly by
other code.
"""
def __init__(self, lookup_map, referring_node):
super(_LookupRelation, self).__init__(referring_node.table_id, lookup_map.table_id)
self._lookup_map = lookup_map
self._referring_node = referring_node
# Maps referring rows to keys, where multiple rows may map to the same key AND one row may
# map to multiple keys (if a formula does multiple lookup calls).
self._row_key_map = twowaymap.TwoWayMap(left=set, right=set)
def __str__(self):
return "_LookupRelation(%s->%s)" % (self._referring_node, self.target_table)
def get_affected_rows(self, target_row_ids):
# Each target row (result of a lookup by key) is associated with a key, and all rows that
# looked up an affected key are affected by a change to any associated row. We remember which
# rows looked up which key in self._row_key_map, so that when some target row changes to a new
# key, we can know which referring rows need to be recomputed.
return self.get_affected_rows_by_keys({ self._lookup_map._get_key(r) for r in target_row_ids })
def get_affected_rows_by_keys(self, keys):
"""
This is used by LookupMapColumn to know which rows got affected when a target row changed to
have a different key. Keys can be any iterable. A key of None is allowed and affects nothing.
"""
affected_rows = set()
for key in keys:
if key is not None:
affected_rows.update(self._row_key_map.lookup_right(key, default=()))
return affected_rows
def _add_lookup(self, referring_row_id, key):
"""
Helper used by LookupMapColumn to store the fact that the given key was looked up in the
process of computing the given referring_row_id.
"""
self._row_key_map.insert(referring_row_id, key)
def reset_rows(self, referring_rows):
"""
Called when starting to compute a formula, so that mappings for the given referring_rows can
be cleared as they are about to be rebuilt.
"""
# Clear out references from referring_rows.
if referring_rows == depend.ALL_ROWS:
self._row_key_map.clear()
else:
for row_id in referring_rows:
self._row_key_map.remove_left(row_id)
def reset_all(self):
"""
Called when the dependency using this relation is reset, and this relation is no longer used.
"""
# In this case also, remove it from the LookupMapColumn. Once all relations are gone, the
# lookup map can get cleaned up.
self._row_key_map.clear()
self._lookup_map._delete_relation(self._referring_node)

@ -0,0 +1,119 @@
"""
This module defines what sandbox functions are made available to the Node controller,
and starts the grist sandbox. See engine.py for the API documentation.
"""
import sys
sys.path.append('thirdparty')
# pylint: disable=wrong-import-position
import marshal
import functools
import actions
import sandbox
import engine
import migrations
import schema
import useractions
import objtypes
import logger
log = logger.Logger(__name__, logger.INFO)
def export(method):
# Wrap each method so that it logs a message that it's being called.
@functools.wraps(method)
def wrapper(*args, **kwargs):
log.debug("calling %s" % method.__name__)
return method(*args, **kwargs)
sandbox.register(method.__name__, wrapper)
def table_data_from_db(table_name, table_data_repr):
if table_data_repr is None:
return actions.TableData(table_name, [], {})
table_data_parsed = marshal.loads(table_data_repr)
id_col = table_data_parsed.pop("id")
return actions.TableData(table_name, id_col,
actions.decode_bulk_values(table_data_parsed, _decode_db_value))
def _decode_db_value(value):
# Decode database values received from SQLite's allMarshal() call. These are encoded by
# marshalling certain types and storing as BLOBs (received in Python as binary strings, as
# opposed to text which is received as unicode). See also encodeValue() in DocStorage.js
# TODO For the moment, the sandbox uses binary strings throughout (with text in utf8 encoding).
# We should switch to representing text with unicode instead. This requires care, at least in
# fixing various occurrences of str() in our code, which may fail and which return wrong type.
t = type(value)
if t == unicode:
return value.encode('utf8')
elif t == str:
return objtypes.decode_object(marshal.loads(value))
else:
return value
def main():
eng = engine.Engine()
@export
def apply_user_actions(action_reprs):
action_group = eng.apply_user_actions([useractions.from_repr(u) for u in action_reprs])
return eng.acl_split(action_group).to_json_obj()
@export
def fetch_table(table_id, formulas=True, query=None):
return actions.get_action_repr(eng.fetch_table(table_id, formulas=formulas, query=query))
@export
def fetch_table_schema():
return eng.fetch_table_schema()
@export
def fetch_snapshot():
action_group = eng.fetch_snapshot()
return eng.acl_split(action_group).to_json_obj()
@export
def autocomplete(txt, table_id):
return eng.autocomplete(txt, table_id)
@export
def find_col_from_values(values, n, opt_table_id):
return eng.find_col_from_values(values, n, opt_table_id)
@export
def fetch_meta_tables(formulas=True):
return {table_id: actions.get_action_repr(table_data)
for (table_id, table_data) in eng.fetch_meta_tables(formulas).iteritems()}
@export
def load_meta_tables(meta_tables, meta_columns):
return eng.load_meta_tables(table_data_from_db("_grist_Tables", meta_tables),
table_data_from_db("_grist_Tables_column", meta_columns))
@export
def load_table(table_name, table_data):
return eng.load_table(table_data_from_db(table_name, table_data))
@export
def create_migrations(all_tables):
doc_actions = migrations.create_migrations(
{t: table_data_from_db(t, data) for t, data in all_tables.iteritems()})
return map(actions.get_action_repr, doc_actions)
@export
def get_version():
return schema.SCHEMA_VERSION
@export
def get_formula_error(table_id, col_id, row_id):
return objtypes.encode_object(eng.get_formula_error(table_id, col_id, row_id))
export(eng.load_empty)
export(eng.load_done)
sandbox.run()
if __name__ == "__main__":
main()

@ -0,0 +1,31 @@
"""
Simple class which, given a sample, can quickly count the size of overlap with an iterable.
All elements of sample must be hashable.
This is mainly in its own file in order to be able to test and time possible alternative
implementations.
"""
class MatchCounter(object):
def __init__(self, sample):
self.sample = set(sample)
def count_unique(self, iterable):
"""
Returns the count of unique elements of iterable that are present in sample. The sample may
only contain hashable elements, so non-hashable elements of iterable are never counted.
"""
# The simplest implementation is 5 times faster:
# len(self.sample.intersection(iterable))
# but fails if iterable can ever contain non-hashable values (e.g. list). This is the next
# best alternative. Attempting to skip non-hashable values with `isinstance(v, Hashable)` is
# another order of magnitude slower.
seen = set()
for v in iterable:
try:
if v in self.sample:
seen.add(v)
except TypeError:
# Non-hashable values can't possibly be in self.sample, so just don't count those.
pass
return len(seen)

@ -0,0 +1,750 @@
import json
import re
import actions
import identifiers
import schema
import summary
import table_data_set
import logger
log = logger.Logger(__name__, logger.INFO)
# PHILOSOPHY OF MIGRATIONS.
#
# We should probably never remove, modify, or rename metadata tables or columns.
# Instead, we should only add.
#
# We can mark old columns/tables as deprecated, which should be ignored except to prevent us from
# adding same-named entities in the future.
#
# If we change the meaning of a column, we have to create a new column with a new name.
#
# This should make it at least barely possible to share documents by people who are not all on the
# same Grist version (even so, it will require more work). It should also make it somewhat safe to
# upgrade and then open the document with a previous version.
all_migrations = {}
def noop_migration(_all_tables):
return []
def create_migrations(all_tables):
"""
Creates and returns a list of DocActions needed to bring this document to
schema.SCHEMA_VERSION. It requires as input all of the documents tables.
all_tables: all tables as a dictionary mapping table name to TableData.
"""
try:
doc_version = all_tables['_grist_DocInfo'].columns["schemaVersion"][0]
except Exception:
doc_version = 0
# We create a TableDataSet, and populate it with the subset of the current schema that matches
# all_tables. For missing items, we make up tables and incomplete columns, which should be OK
# since we would not be adding new records to deprecated columns.
# Note that this approach makes it NOT OK to change column types.
tdset = table_data_set.TableDataSet()
# For each table in the provided metadata tables, create an AddTable action.
user_schema = schema.build_schema(all_tables['_grist_Tables'],
all_tables['_grist_Tables_column'],
include_builtin=False)
for t in user_schema.itervalues():
tdset.apply_doc_action(actions.AddTable(t.tableId, schema.cols_to_dict_list(t.columns)))
# For each old table/column, construct an AddTable action using the current schema.
new_schema = {a.table_id: a for a in schema.schema_create_actions()}
for table_id, data in sorted(all_tables.iteritems()):
# User tables should already be in tdset; the rest must be metadata tables.
if table_id not in tdset.all_tables:
new_col_info = {}
if table_id in new_schema:
new_col_info = {c['id']: c for c in new_schema[table_id].columns}
# Use an incomplete default for unknown (i.e. deprecated) columns; some uses of the column
# would be invalid, such as adding a new record with missing values.
col_info = sorted([new_col_info.get(col_id, {'id': col_id}) for col_id in data.columns])
tdset.apply_doc_action(actions.AddTable(table_id, col_info))
# And load in the original data, interpreting the TableData object as BulkAddRecord action.
tdset.apply_doc_action(actions.BulkAddRecord(*data))
migration_actions = []
for version in xrange(doc_version + 1, schema.SCHEMA_VERSION + 1):
migration_actions.extend(all_migrations.get(version, noop_migration)(tdset))
# Note that if we are downgrading versions (i.e. doc_version is higher), then the following is
# the only action we include into the migration.
migration_actions.append(actions.UpdateRecord('_grist_DocInfo', 1, {
'schemaVersion': schema.SCHEMA_VERSION
}))
return migration_actions
def get_last_migration_version():
"""
Returns the last schema version number for which we have a migration defined.
"""
return max(all_migrations.iterkeys())
def migration(schema_version):
"""
Decorator for migrations that associates the decorated migration function with the given
schema_version. This decorate function will be run to migrate forward to schema_version.
"""
def add_migration(migration_func):
all_migrations[schema_version] = migration_func
return migration_func
return add_migration
# A little shorthand to make AddColumn actions more concise.
def add_column(table_id, col_id, col_type, *args, **kwargs):
return actions.AddColumn(table_id, col_id,
schema.make_column(col_id, col_type, *args, **kwargs))
# Another shorthand to only add a column if it isn't already there.
def maybe_add_column(tdset, table_id, col_id, col_type, *args, **kwargs):
if col_id not in tdset.all_tables[table_id].columns:
return add_column(table_id, col_id, col_type, *args, **kwargs)
return None
# Returns the next unused row id for the records of the table given by table_id.
def next_id(tdset, table_id):
row_ids = tdset.all_tables[table_id].row_ids
return max(row_ids) + 1 if row_ids else 1
# Parses a json string, but returns an empty object for invalid json.
def safe_parse(json_str):
try:
return json.loads(json_str)
except ValueError:
return {}
@migration(schema_version=1)
def migration1(tdset):
"""
Add TabItems table, and populate based on existing sections.
"""
doc_actions = []
# The very first migration is extra-lax, and creates some tables that are missing in some test
# docs. That's only because we did not distinguish schema version before migrations were
# implemented. Other migrations should not need such conditionals.
if '_grist_Attachments' not in tdset.all_tables:
doc_actions.append(actions.AddTable("_grist_Attachments", [
schema.make_column("fileIdent", "Text"),
schema.make_column("fileName", "Text"),
schema.make_column("fileType", "Text"),
schema.make_column("fileSize", "Int"),
schema.make_column("timeUploaded", "DateTime")
]))
if '_grist_TabItems' not in tdset.all_tables:
doc_actions.append(actions.AddTable("_grist_TabItems", [
schema.make_column("tableRef", "Ref:_grist_Tables"),
schema.make_column("viewRef", "Ref:_grist_Views"),
]))
if 'schemaVersion' not in tdset.all_tables['_grist_DocInfo'].columns:
doc_actions.append(add_column('_grist_DocInfo', 'schemaVersion', 'Int'))
doc_actions.extend([
add_column('_grist_Attachments', 'imageHeight', 'Int'),
add_column('_grist_Attachments', 'imageWidth', 'Int'),
])
view_sections = actions.transpose_bulk_action(tdset.all_tables['_grist_Views_section'])
rows = sorted({(s.tableRef, s.parentId) for s in view_sections})
if rows:
values = {'tableRef': [r[0] for r in rows],
'viewRef': [r[1] for r in rows]}
row_ids = range(1, len(rows) + 1)
doc_actions.append(actions.ReplaceTableData('_grist_TabItems', row_ids, values))
return tdset.apply_doc_actions(doc_actions)
@migration(schema_version=2)
def migration2(tdset):
"""
Add TableViews table, and populate based on existing sections.
Add TabBar table, and populate based on existing views.
Add PrimaryViewId to Tables and populated using relatedViews
"""
# Maps tableRef to viewRef
primary_views = {}
# Associate each view with a single table; this dict includes primary views.
views_to_table = {}
# For each table, find a view to serve as the primary view.
view_sections = actions.transpose_bulk_action(tdset.all_tables['_grist_Views_section'])
for s in view_sections:
if s.tableRef not in primary_views and s.parentKey == "record":
# The view containing this section is a good candidate for primary view.
primary_views[s.tableRef] = s.parentId
if s.parentId not in views_to_table:
# The first time we see a (view, table) combination, associate the view with that table.
views_to_table[s.parentId] = s.tableRef
def create_primary_views_action(primary_views):
row_ids = sorted(primary_views.keys())
values = {'primaryViewId': [primary_views[r] for r in row_ids]}
return actions.BulkUpdateRecord('_grist_Tables', row_ids, values)
def create_tab_bar_action(views_to_table):
row_ids = range(1, len(views_to_table) + 1)
return actions.ReplaceTableData('_grist_TabBar', row_ids, {
'viewRef': sorted(views_to_table.keys())
})
def create_table_views_action(views_to_table, primary_views):
related_views = sorted(set(views_to_table.keys()) - set(primary_views.values()))
row_ids = range(1, len(related_views) + 1)
return actions.ReplaceTableData('_grist_TableViews', row_ids, {
'tableRef': [views_to_table[v] for v in related_views],
'viewRef': related_views,
})
return tdset.apply_doc_actions([
actions.AddTable('_grist_TabBar', [
schema.make_column('viewRef', 'Ref:_grist_Views'),
]),
actions.AddTable('_grist_TableViews', [
schema.make_column('tableRef', 'Ref:_grist_Tables'),
schema.make_column('viewRef', 'Ref:_grist_Views'),
]),
add_column('_grist_Tables', 'primaryViewId', 'Ref:_grist_Views'),
create_primary_views_action(primary_views),
create_tab_bar_action(views_to_table),
create_table_views_action(views_to_table, primary_views)
])
@migration(schema_version=3)
def migration3(tdset):
"""
There is no longer a "Derived" type for columns, and summary tables use the type suitable for
the column being summarized. For old documents, convert "Derived" type to "Any", and adjust the
usage of "lookupOrAddDerived()" function.
"""
# Note that this is a complicated migration, and mainly acceptable because it is before our very
# first release. For a released product, a change like this should be done in a backwards
# compatible way: keep but deprecate 'Derived'; introduce a lookupOrAddDerived2() to use for new
# summary tables, but keep the old interface as well for existing ones. The reason is that such
# migrations are error-prone and may mess up customers' data.
doc_actions = []
tables = list(actions.transpose_bulk_action(tdset.all_tables['_grist_Tables']))
tables_map = {t.id: t for t in tables}
columns = list(actions.transpose_bulk_action(tdset.all_tables['_grist_Tables_column']))
# Convert columns from type 'Derived' to type 'Any'
affected_cols = [c for c in columns if c.type == 'Derived']
if affected_cols:
doc_actions.extend(
actions.ModifyColumn(tables_map[c.parentId].tableId, c.colId, {'type': 'Any'})
for c in affected_cols
)
doc_actions.append(actions.BulkUpdateRecord(
'_grist_Tables_column',
[c.id for c in affected_cols],
{'type': ['Any' for c in affected_cols]}
))
# Convert formulas of the form '.lookupOrAddDerived($x,$y)' to '.lookupOrAddDerived(x=$x,y=$y)'
formula_re = re.compile(r'(\w+).lookupOrAddDerived\((.*?)\)')
arg_re = re.compile(r'^\$(\w+)$')
def replace(match):
args = ", ".join(arg_re.sub(r'\1=$\1', arg.strip()) for arg in match.group(2).split(","))
return '%s.lookupOrAddDerived(%s)' % (match.group(1), args)
formula_updates = []
for c in columns:
new_formula = c.formula and formula_re.sub(replace, c.formula)
if new_formula != c.formula:
formula_updates.append((c, new_formula))
if formula_updates:
doc_actions.extend(
actions.ModifyColumn(tables_map[c.parentId].tableId, c.colId, {'formula': f})
for c, f in formula_updates
)
doc_actions.append(actions.BulkUpdateRecord(
'_grist_Tables_column',
[c.id for c, f in formula_updates],
{'formula': [f for c, f in formula_updates]}
))
return tdset.apply_doc_actions(doc_actions)
@migration(schema_version=4)
def migration4(tdset):
"""
Add TabPos column to TabBar table
"""
doc_actions = []
row_ids = tdset.all_tables['_grist_TabBar'].row_ids
doc_actions.append(add_column('_grist_TabBar', 'tabPos', 'PositionNumber'))
doc_actions.append(actions.BulkUpdateRecord('_grist_TabBar', row_ids, {'tabPos': row_ids}))
return tdset.apply_doc_actions(doc_actions)
@migration(schema_version=5)
def migration5(tdset):
return tdset.apply_doc_actions([
add_column('_grist_Views', 'primaryViewTable', 'Ref:_grist_Tables',
formula='_grist_Tables.lookupOne(primaryViewId=$id)', isFormula=True),
])
@migration(schema_version=6)
def migration6(tdset):
# This undoes the previous migration, since primaryViewTable is now a formula private to the
# sandbox rather than part of the document schema.
return tdset.apply_doc_actions([
actions.RemoveColumn('_grist_Views', 'primaryViewTable'),
])
@migration(schema_version=7)
def migration7(tdset):
"""
Add summarySourceTable/summarySourceCol fields to metadata, and adjust existing summary tables
to correspond to the new style.
"""
# Note: this migration has some faults.
# - It doesn't delete viewSectionFields for columns it removes (if a user added some special
# columns manually.
# - It doesn't fix types of Reference columns that refer to old-style summary tables
# (if the user created some such columns manually).
doc_actions = filter(None, [
maybe_add_column(tdset, '_grist_Tables', 'summarySourceTable', 'Ref:_grist_Tables'),
maybe_add_column(tdset, '_grist_Tables_column', 'summarySourceCol', 'Ref:_grist_Tables_column')
])
# Maps tableRef to Table object.
tables_map = {t.id: t for t in actions.transpose_bulk_action(tdset.all_tables['_grist_Tables'])}
# Maps tableName to tableRef
table_name_to_ref = {t.tableId: t.id for t in tables_map.itervalues()}
# List of Column objects
columns = list(actions.transpose_bulk_action(tdset.all_tables['_grist_Tables_column']))
# Maps columnRef to Column object.
columns_map_by_ref = {c.id: c for c in columns}
# Maps (tableRef, colName) to Column object.
columns_map_by_table_colid = {(c.parentId, c.colId): c for c in columns}
# Set of all tableNames.
table_name_set = set(table_name_to_ref.keys())
remove_cols = [] # List of columns to remove
formula_updates = [] # List of (column, new_table_name, new_formula) pairs
table_renames = [] # List of (table, new_name) pairs
source_tables = [] # List of (table, summarySourceTable) pairs
source_cols = [] # List of (column, summarySourceColumn) pairs
# Summary tables used to be named as "Summary_<SourceName>_<ColRef1>_<ColRef2>". This regular
# expression parses that.
summary_re = re.compile(r'^Summary_(\w+?)((?:_\d+)*)$')
for t in tables_map.itervalues():
m = summary_re.match(t.tableId)
if not m or m.group(1) not in table_name_to_ref:
continue
# We have a valid summary table.
source_table_name = m.group(1)
source_table_ref = table_name_to_ref[source_table_name]
groupby_colrefs = map(int, m.group(2).strip("_").split("_"))
# Prepare a new-style name for the summary table. Be sure not to conflict with existing tables
# or with each other (i.e. don't rename multiple tables to the same name).
new_name = summary.encode_summary_table_name(source_table_name)
new_name = identifiers.pick_table_ident(new_name, avoid=table_name_set)
table_name_set.add(new_name)
log.warn("Upgrading summary table %s for %s(%s) to %s" % (
t.tableId, source_table_name, groupby_colrefs, new_name))
# Remove the "lookupOrAddDerived" column from the source table (which is named using the
# summary table name for its colId).
remove_cols.extend(c for c in columns
if c.parentId == source_table_ref and c.colId == t.tableId)
# Upgrade the "group" formula in the summary table.
expected_group_formula = "%s.lookupRecords(%s=$id)" % (source_table_name, t.tableId)
new_formula = "table.getSummarySourceGroup(rec)"
formula_updates.extend((c, new_name, new_formula) for c in columns
if (c.parentId == t.id and c.colId == "group" and
c.formula == expected_group_formula))
# Schedule a rename of the summary table.
table_renames.append((t, new_name))
# Set summarySourceTable fields on the metadata.
source_tables.append((t, source_table_ref))
# Set summarySourceCol fields in the metadata. We need to find the right summary column.
groupby_cols = set()
for col_ref in groupby_colrefs:
src_col = columns_map_by_ref.get(col_ref)
sum_col = columns_map_by_table_colid.get((t.id, src_col.colId)) if src_col else None
if sum_col:
groupby_cols.add(sum_col)
source_cols.append((sum_col, src_col.id))
else:
log.warn("Upgrading summary table %s: couldn't find column %s" % (t.tableId, col_ref))
# Finally, we have to remove all non-formula columns that are not groupby-columns (e.g.
# 'manualSort'), because the new approach assumes ALL non-formula columns are for groupby.
remove_cols.extend(c for c in columns
if c.parentId == t.id and c not in groupby_cols and not c.isFormula)
# Create all the doc actions from the arrays we prepared.
# Process remove_cols
doc_actions.extend(
actions.RemoveColumn(tables_map[c.parentId].tableId, c.colId) for c in remove_cols)
doc_actions.append(actions.BulkRemoveRecord(
'_grist_Tables_column', [c.id for c in remove_cols]))
# Process table_renames
doc_actions.extend(
actions.RenameTable(t.tableId, new) for (t, new) in table_renames)
doc_actions.append(actions.BulkUpdateRecord(
'_grist_Tables', [t.id for t, new in table_renames],
{'tableId': [new for t, new in table_renames]}
))
# Process source_tables and source_cols
doc_actions.append(actions.BulkUpdateRecord(
'_grist_Tables', [t.id for t, ref in source_tables],
{'summarySourceTable': [ref for t, ref in source_tables]}
))
doc_actions.append(actions.BulkUpdateRecord(
'_grist_Tables_column', [t.id for t, ref in source_cols],
{'summarySourceCol': [ref for t, ref in source_cols]}
))
# Process formula_updates. Do this last since recalculation of these may cause new records added
# to summary tables, so we should have all the tables correctly set up by this time.
doc_actions.extend(
actions.ModifyColumn(table_id, c.colId, {'formula': f})
for c, table_id, f in formula_updates)
doc_actions.append(actions.BulkUpdateRecord(
'_grist_Tables_column', [c.id for c, t, f in formula_updates],
{'formula': [f for c, t, f in formula_updates]}
))
return tdset.apply_doc_actions(doc_actions)
@migration(schema_version=8)
def migration8(tdset):
return tdset.apply_doc_actions([
add_column('_grist_Tables_column', 'untieColIdFromLabel', 'Bool'),
])
@migration(schema_version=9)
def migration9(tdset):
return tdset.apply_doc_actions([
add_column('_grist_Tables_column', 'displayCol', 'Ref:_grist_Tables_column'),
add_column('_grist_Views_section_field', 'displayCol', 'Ref:_grist_Tables_column'),
])
@migration(schema_version=10)
def migration10(tdset):
"""
Add displayCol to all reference cols, with formula $<ref_col_id>.<visible_col_id>
(Note that displayCol field was added in the previous migration.)
"""
doc_actions = []
tables = list(actions.transpose_bulk_action(tdset.all_tables['_grist_Tables']))
columns = list(actions.transpose_bulk_action(tdset.all_tables['_grist_Tables_column']))
# Maps tableRef to tableId.
tables_map = {t.id: t.tableId for t in tables}
# Maps tableRef to sets of colIds in the tables. Used to prevent repeated colIds.
table_col_ids = {t.id: set(tdset.all_tables[t.tableId].columns.keys()) for t in tables}
# Get the next sequential column row id.
row_id = next_id(tdset, '_grist_Tables_column')
for c in columns:
# If a column is a reference with an unset display column, add a display column.
if c.type.startswith('Ref:') and not c.displayCol:
# Get visible_col_id. If not found, row id is used and no display col is necessary.
visible_col_id = ""
try:
visible_col_id = json.loads(c.widgetOptions).get('visibleCol')
if not visible_col_id:
continue
except Exception:
continue # If invalid widgetOptions, skip this column.
# Set formula to use the current visibleCol in widgetOptions.
formula = ("$%s.%s" % (c.colId, visible_col_id))
# Get a unique colId for the display column, and add it to the set of used ids.
used_col_ids = table_col_ids[c.parentId]
display_col_id = identifiers.pick_col_ident('gristHelper_Display', avoid=used_col_ids)
used_col_ids.add(display_col_id)
# Add all actions to the list.
doc_actions.append(add_column(tables_map[c.parentId], 'gristHelper_Display', 'Any',
formula=formula, isFormula=True))
doc_actions.append(actions.AddRecord('_grist_Tables_column', row_id, {
'parentPos': 1.0,
'label': 'gristHelper_Display',
'isFormula': True,
'parentId': c.parentId,
'colId': 'gristHelper_Display',
'formula': formula,
'widgetOptions': '',
'type': 'Any'
}))
doc_actions.append(actions.UpdateRecord('_grist_Tables_column', c.id, {'displayCol': row_id}))
# Increment row id to the next unused.
row_id += 1
return tdset.apply_doc_actions(doc_actions)
@migration(schema_version=11)
def migration11(tdset):
return tdset.apply_doc_actions([
add_column('_grist_Views_section', 'embedId', 'Text'),
])
@migration(schema_version=12)
def migration12(tdset):
return tdset.apply_doc_actions([
add_column('_grist_Views_section', 'options', 'Text')
])
@migration(schema_version=13)
def migration13(tdset):
# Adds a basketId to the entire document to take advantage of basket functionality.
# From this version on, embedId is deprecated.
return tdset.apply_doc_actions([
add_column('_grist_DocInfo', 'basketId', 'Text')
])
@migration(schema_version=14)
def migration14(tdset):
# Create the ACL table AND also the default ACL groups, default resource, and the default rule.
# These match the actions applied to new document by 'InitNewDoc' useraction (as of v14).
return tdset.apply_doc_actions([
actions.AddTable('_grist_ACLMemberships', [
schema.make_column('parent', 'Ref:_grist_ACLPrincipals'),
schema.make_column('child', 'Ref:_grist_ACLPrincipals'),
]),
actions.AddTable('_grist_ACLPrincipals', [
schema.make_column('userName', 'Text'),
schema.make_column('groupName', 'Text'),
schema.make_column('userEmail', 'Text'),
schema.make_column('instanceId', 'Text'),
schema.make_column('type', 'Text'),
]),
actions.AddTable('_grist_ACLResources', [
schema.make_column('colIds', 'Text'),
schema.make_column('tableId', 'Text'),
]),
actions.AddTable('_grist_ACLRules', [
schema.make_column('aclFormula', 'Text'),
schema.make_column('principals', 'Text'),
schema.make_column('resource', 'Ref:_grist_ACLResources'),
schema.make_column('aclColumn', 'Ref:_grist_Tables_column'),
schema.make_column('permissions', 'Int'),
]),
# Set up initial ACL data.
actions.BulkAddRecord('_grist_ACLPrincipals', [1,2,3,4], {
'type': ['group', 'group', 'group', 'group'],
'groupName': ['Owners', 'Admins', 'Editors', 'Viewers'],
}),
actions.AddRecord('_grist_ACLResources', 1, {
'tableId': '', 'colIds': ''
}),
actions.AddRecord('_grist_ACLRules', 1, {
'resource': 1, 'permissions': 0x3F, 'principals': '[1]'
}),
])
@migration(schema_version=15)
def migration15(tdset):
# Adds a filter JSON property to each field.
# From this version on, filterSpec in _grist_Views_section is deprecated.
doc_actions = [
add_column('_grist_Views_section_field', 'filter', 'Text')
]
# Get all section and field data to move section filter data to the fields
sections = list(actions.transpose_bulk_action(tdset.all_tables['_grist_Views_section']))
fields = list(actions.transpose_bulk_action(tdset.all_tables['_grist_Views_section_field']))
specs = {s.id: safe_parse(s.filterSpec) for s in sections}
# Move filter data from sections to fields
for f in fields:
# If the field belongs to the section and the field's colRef is in its filterSpec,
# pull the filter setting from the section.
filter_spec = specs.get(f.parentId)
if filter_spec and str(f.colRef) in filter_spec:
doc_actions.append(actions.UpdateRecord('_grist_Views_section_field', f.id, {
'filter': json.dumps(filter_spec[str(f.colRef)])
}))
return tdset.apply_doc_actions(doc_actions)
@migration(schema_version=16)
def migration16(tdset):
# Add visibleCol to columns and view fields, and set it from columns' and fields' widgetOptions.
doc_actions = [
add_column('_grist_Tables_column', 'visibleCol', 'Ref:_grist_Tables_column'),
add_column('_grist_Views_section_field', 'visibleCol', 'Ref:_grist_Tables_column'),
]
# Maps tableId to table, for looking up target table as listed in "Ref:*" types.
tables = list(actions.transpose_bulk_action(tdset.all_tables['_grist_Tables']))
tables_by_id = {t.tableId: t for t in tables}
# Allow looking up columns by ref or by (tableRef, colId)
columns = list(actions.transpose_bulk_action(tdset.all_tables['_grist_Tables_column']))
columns_by_ref = {c.id: c for c in columns}
columns_by_id = {(c.parentId, c.colId): c.id for c in columns}
# Helper which returns the {'visibleCol', 'widgetOptions'} update visibleCol should be set.
def convert_visible_col(col, widget_options):
if not col.type.startswith('Ref:'):
return None
# To set visibleCol, we need to know the target table. Skip if we can't find it.
target_table = tables_by_id.get(col.type[len('Ref:'):])
if not target_table:
return None
try:
parsed_options = json.loads(widget_options)
except Exception:
return None # If invalid widgetOptions, skip this column.
visible_col_id = parsed_options.pop('visibleCol', None)
if not visible_col_id:
return None
# Find visible_col_id as the column name in the appropriate table.
target_col_ref = (0 if visible_col_id == 'id' else
columns_by_id.get((target_table.id, visible_col_id), None))
if target_col_ref is None:
return None
# Use compact separators without whitespace, to match how JS encodes JSON.
return {'visibleCol': target_col_ref,
'widgetOptions': json.dumps(parsed_options, separators=(',', ':')) }
for c in columns:
new_values = convert_visible_col(c, c.widgetOptions)
if new_values:
doc_actions.append(actions.UpdateRecord('_grist_Tables_column', c.id, new_values))
fields = list(actions.transpose_bulk_action(tdset.all_tables['_grist_Views_section_field']))
for f in fields:
c = columns_by_ref.get(f.colRef)
if c:
new_values = convert_visible_col(c, f.widgetOptions)
if new_values:
doc_actions.append(actions.UpdateRecord('_grist_Views_section_field', f.id, new_values))
return tdset.apply_doc_actions(doc_actions)
@migration(schema_version=17)
def migration17(tdset):
"""
There is no longer an "Image" type for columns, as "Attachments" now serves as a
display type for arbitrary files including images. Convert "Image" columns to "Attachments"
columns.
"""
doc_actions = []
tables = list(actions.transpose_bulk_action(tdset.all_tables['_grist_Tables']))
tables_map = {t.id: t for t in tables}
columns = list(actions.transpose_bulk_action(tdset.all_tables['_grist_Tables_column']))
# Convert columns from type 'Image' to type 'Attachments'
affected_cols = [c for c in columns if c.type == 'Image']
conv = lambda val: [val] if isinstance(val, int) and val > 0 else []
if affected_cols:
# Update the types in the data tables
doc_actions.extend(
actions.ModifyColumn(tables_map[c.parentId].tableId, c.colId, {'type': 'Attachments'})
for c in affected_cols
)
# Update the values to lists
for c in affected_cols:
if c.isFormula:
# Formula columns don't have data stored in DB, should not have data changes.
continue
table_id = tables_map[c.parentId].tableId
table = tdset.all_tables[table_id]
doc_actions.append(
actions.BulkUpdateRecord(table_id, table.row_ids,
{c.colId: [conv(val) for val in table.columns[c.colId]]})
)
# Update the types in the metadata tables
doc_actions.append(actions.BulkUpdateRecord(
'_grist_Tables_column',
[c.id for c in affected_cols],
{'type': ['Attachments' for c in affected_cols]}
))
return tdset.apply_doc_actions(doc_actions)
@migration(schema_version=18)
def migration18(tdset):
return tdset.apply_doc_actions([
add_column('_grist_DocInfo', 'timezone', 'Text'),
# all documents prior to this migration have been created in New York
actions.UpdateRecord('_grist_DocInfo', 1, {'timezone': 'America/New_York'})
])
@migration(schema_version=19)
def migration19(tdset):
return tdset.apply_doc_actions([
add_column('_grist_Tables', 'onDemand', 'Bool'),
])
@migration(schema_version=20)
def migration20(tdset):
"""
Add _grist_Pages table and populate based on existing TableViews entries, ie: tables are sorted
alphabetically by their `tableId` and views are gathered within their corresponding table and
sorted by their id.
"""
tables = list(actions.transpose_bulk_action(tdset.all_tables['_grist_Tables']))
table_map = {t.id: t for t in tables}
table_views = list(actions.transpose_bulk_action(tdset.all_tables['_grist_TableViews']))
# Old docs may include "Other views", not associated with any table. Don't include those in
# table_views_map: they'll get included but not sorted or grouped by tableId.
table_views_map = {tv.viewRef: table_map[tv.tableRef].tableId
for tv in table_views if tv.tableRef in table_map}
views = list(actions.transpose_bulk_action(tdset.all_tables['_grist_Views']))
def view_key(view):
"""
Returns ("Table1", 2) where "Table1" is the view's tableId and 2 the view id. For
primary view (ie: not referenced in _grist_TableViews) returns ("Table1", -1). Useful
to get the list of views sorted in the same way as in the Table side pane. We use -1
for primary view to make sure they come first among all the views of the same table.
"""
if view.id in table_views_map:
return (table_views_map[view.id], view.id)
# the name of primary view's is the same as the tableId
return (view.name, -1)
views.sort(key=view_key)
row_ids = range(1, len(views) + 1)
return tdset.apply_doc_actions([
actions.AddTable('_grist_Pages', [
schema.make_column('viewRef', 'Ref:_grist_Views'),
schema.make_column('pagePos', 'PositionNumber'),
schema.make_column('indentation', 'Int'),
]),
actions.ReplaceTableData('_grist_Pages', row_ids, {
'viewRef': [v.id for v in views],
'pagePos': row_ids,
'indentation': [1 if v.id in table_views_map else 0 for v in views]
})
])

@ -0,0 +1,258 @@
from datetime import datetime, timedelta, tzinfo as _tzinfo
from collections import namedtuple
import marshal
from time import time
import bisect
import itertools
import os
import moment_parse
import iso8601
# This is prepared by sandbox/install_tz.py
ZoneRecord = namedtuple("ZoneRecord", ("name", "abbrs", "offsets", "untils"))
# moment.py mirrors core functionality of moment-timezone.js
# moment.py includes function parse, located and documented in moment_parse.py
# Documentation: http://momentjs.com/timezone/docs/
EPOCH = datetime(1970, 1, 1)
DATE_EPOCH = EPOCH.date()
CURRENT_DATE = DATE_EPOCH + timedelta(seconds=time())
_TZDATA = None
# Returns a dictionary mapping timezone name to ZoneRecord object. It reads the data on first
# call, caches it, and returns cached data on all future calls.
def get_tz_data():
global _TZDATA # pylint: disable=global-statement
if _TZDATA is None:
all_zones = read_tz_raw_data()
# The marshalled data is an array of tuples (name, abbrs, offsets, untils)
_TZDATA = {x[0]: ZoneRecord._make(x) for x in all_zones}
return _TZDATA
# Reads and returns the marshalled tzdata file (produced by sandbox/install_tz.py).
# The return value is a list of tuples (name, abbrs, offsets, untils).
def read_tz_raw_data():
tzfile = os.path.join(os.path.dirname(__file__), "tzdata.data")
with open(tzfile, "rb") as tzdata:
return marshal.load(tzdata)
# Converts a UTC datetime to timestamp in milliseconds.
def utc_to_ts_ms(dt):
return (dt.replace(tzinfo=None) - EPOCH).total_seconds() * 1000
# Converts timestamp in seconds to datetime in the given timezone. If tzinfo is given, then zone
# is ignored and may be None.
def ts_to_dt(timestamp, zone, tzinfo=None):
return (EPOCH_UTC + timedelta(seconds=timestamp)).astimezone(tzinfo or zone.get_tzinfo(None))
# Converts datetime to timestamp in seconds. Optional timezone may be given to serve as the
# default if dt is unaware (has no associated timezone).
def dt_to_ts(dt, timezone=None):
offset = dt.utcoffset()
if offset is None:
offset = timezone.dt_offset(dt) if timezone else timedelta(0)
return (dt.replace(tzinfo=None) - offset - EPOCH).total_seconds()
# Converts timestamp in seconds to date.
def ts_to_date(timestamp):
return DATE_EPOCH + timedelta(seconds=timestamp)
# Converts date to timestamp of the midnight in seconds, in the given timezone, or UTC by default.
def date_to_ts(date, timezone=None):
ts = (date - DATE_EPOCH).total_seconds()
return ts if not timezone else ts - timezone.offset(ts * 1000).total_seconds()
# Calls parse from moment_parse.py
def parse(date_string, parse_format, zonelabel='UTC'):
return moment_parse.parse(date_string, parse_format, zonelabel)
# Parses a datetime in the ISO format, YYYY-MM-DDTHH:MM:SS.mmmmmm+HH:MM. Most parts are optional;
# see https://pypi.org/project/iso8601/ for details. Returns a timestamp in seconds.
def parse_iso(date_string, timezone=None):
dt = iso8601.parse_date(date_string, default_timezone=None)
return dt_to_ts(dt, timezone)
# Parses a date in ISO format, ignoring all time components. Returns timestamp of UTC midnight.
def parse_iso_date(date_string):
dt = iso8601.parse_date(date_string, default_timezone=None)
return date_to_ts(dt.date())
class tz(object):
"""Implements basics of moment.js and moment-timezone.js"""
# dt (datetime / number) - Either a local datetime in the time of the
# provided timezone or a timestamp since epoch in milliseconds.
# zonelabel (string) - The name of the timezone; should correspond to
# one of the names in the moment-timezone json data.
def __init__(self, dt, zonelabel="UTC"):
self._tzinfo = tzinfo(zonelabel)
if isinstance(dt, datetime):
timestamp = dt_to_ts(dt.replace(tzinfo=self._tzinfo)) * 1000
elif isinstance(dt, (float, int, long)):
timestamp = dt
else:
raise TypeError("'dt' should be a datetime object or a numeric type")
self.timestamp = timestamp
# Returns the timestamp in seconds
def timestamp_s(self):
return self.timestamp / 1000
# Changes the timezone to the one corresponding to 'zonelabel' without
# changing the underlying time since epoch.
def tz(self, zonelabel):
self._tzinfo = tzinfo(zonelabel)
return self
# Returns a datetime object with the moment-timezone object's local time and the timezone
# at the current timestamp.
def datetime(self):
return ts_to_dt(self.timestamp / 1000.0, None, self._tzinfo)
def zoneName(self):
return self._tzinfo.zone.name
def zoneAbbr(self):
return self._tzinfo.zone.abbr(self.timestamp)
def zoneOffset(self):
return self._tzinfo.zone.offset(self.timestamp)
class TzInfo(_tzinfo):
"""
Implements datetime.tzinfo interface using moment-timezone data. If favor_offset is used, it
tells which offset to favor when a datetime is ambiguous. If None, the offset that's in effect
earlier is favored.
"""
def __init__(self, zone, favor_offset):
super(TzInfo, self).__init__()
self.zone = zone
self._favor_offset = favor_offset
def utcoffset(self, dt):
"""Implementation of tzinfo.utcoffset interface."""
return self.zone.dt_offset(dt, self._favor_offset)
def tzname(self, dt):
"""Implementation of tzinfo.tzname interface."""
abbr = self.zone.dt_tzname(dt, self._favor_offset)
# tzname must return a string, not unicode.
return abbr.encode('utf8') if isinstance(abbr, unicode) else abbr
def dst(self, dt):
"""Implementation of tzinfo.dst interface."""
return self.utcoffset(dt) - self.zone.standard_offset
def fromutc(self, dt):
# This produces a datetime with a specific offset, and sets tzinfo that favors that offset.
offset = self.zone.offset(utc_to_ts_ms(dt))
return (dt + offset).replace(tzinfo=self.zone.get_tzinfo(offset))
def __repr__(self):
"""
Produces a friendly representation
>>> moment.tzinfo('America/New_York')
moment.tzinfo('America/New_York')
"""
return 'moment.tzinfo({!r})'.format(self.zone.name)
class Zone(object):
"""
Implements the zone object of moment-timezone.js, and contains the logic needed by TzInfo.
This is the class that interfaces directly with moment-timezone data.
"""
def __init__(self, zonelabel):
"""
Creates a Zone object for the given zonelabel, which must be a string key into the
moment-timezone json data.
"""
zone_data = get_tz_data()[zonelabel]
self.name = zonelabel
self.untils = zone_data.untils[:-1] # In ms. We omit the trailing None value.
self.abbrs = zone_data.abbrs
self.offsets = zone_data.offsets # Offsets in minutes.
self.standard_offset = timedelta(minutes=-self.offsets[0])
# "Until" times adjusted by the corresponding offsets. These are used in translating from
# datetime to absolute timestamp.
self.offset_untils = [until - offset * 60000 for (until, offset) in
itertools.izip(self.untils, self.offsets)]
# Cache of TzInfo objects for this Zone, used by get_tzinfo(). There could be multiple TzInfo
# objects, one for each possible offset, but their behavior only differs for ambiguous time.
self._tzinfo = {}
def dt_offset(self, dt, favor_offset=None):
"""Returns the timedelta for timezone offset east of UTC at the given datetime."""
i = self._index_dt(dt, favor_offset)
return timedelta(minutes = -self.offsets[i])
def dt_tzname(self, dt, favor_offset=None):
"""Returns the timezone abbreviation (e.g. EST or EDT) at the given datetime."""
i = self._index_dt(dt, favor_offset)
return self.abbrs[i]
def offset(self, timestamp_ms):
"""Returns the timedelta for timezone offset east of UTC at the given ms timestamp."""
i = self._index(timestamp_ms)
return timedelta(minutes = -self.offsets[i])
def abbr(self, timestamp_ms):
"""Returns the timezone abbreviation (e.g. EST or EDT) at the given ms timestamp."""
i = self._index(timestamp_ms)
return self.abbrs[i]
def _index(self, timestamp):
"""Helper to return the index into the offsets data corresponding to the given timestamp."""
return bisect.bisect_right(self.untils, timestamp)
def _index_dt(self, dt, favor_offset):
"""
Helper to return the index into the offsets data corresponding to the given datetime.
In case of ambiguous dates, will favor the given favor_offset. If it is None or doesn't match
the later of the two offsets, will use the offset that's was in effect earlier.
"""
timestamp = utc_to_ts_ms(dt)
i = bisect.bisect_right(self.offset_untils, timestamp)
if i < len(self.offset_untils) and timestamp >= self.untils[i] - self.offsets[i + 1] * 60000:
# We have an ambiguous time and can use self.offsets[i] or self.offsets[i + 1]. If
# favor_offset matches the later offset, use that. Otherwise, prefer the earlier one.
if timedelta(minutes=-self.offsets[i + 1]) == favor_offset:
return i + 1
return i
def get_tzinfo(self, favor_offset):
"""
Returns a TzInfo object for this Zone that favors the given offset in case of ambiguity.
If favor_offset is none, ambiguous times are resolved to the offset that comes into effect
earlier. This is used with a particular offset by TzInfo.fromutc() method, which is part of
implementation of TzInfo.astimezone(). We distinguish ambiguous times by using TzInfo variants
that favor one offset or another for different meanings of the ambiguous times.
"""
return (self._tzinfo.get(favor_offset) or
self._tzinfo.setdefault(favor_offset, TzInfo(self, favor_offset)))
_zone_cache = {}
def get_zone(zonelabel):
"""Returns Zone(zonelabel), with caching."""
return (_zone_cache.get(zonelabel) or
_zone_cache.setdefault(zonelabel, Zone(zonelabel)))
def tzinfo(zonelabel, favor_offset=None):
"""
Returns TzInfo instance for zonelabel, with the optional favor_offset (mainly for internal use
by astimezone via fromutc).
"""
return get_zone(zonelabel).get_tzinfo(favor_offset)
# Some more globals that rely on the machinery above.
TZ_UTC = tzinfo('UTC')
EPOCH_UTC = EPOCH.replace(tzinfo=TZ_UTC) # Same as EPOCH, but an "aware" instance.

@ -0,0 +1,159 @@
import re
from collections import OrderedDict
from datetime import datetime
import moment
# Regex list of lowercase months with characters after the first three made optional
MONTH_NAMES = ['january', 'february', 'march', 'april', 'may', 'june', 'july', 'august',
'september', 'october', 'november', 'december']
MONTHS = [m[:3]+"(?:"+m[3:]+")?" if len(m) > 3 else m[:3] for m in MONTH_NAMES]
# Regex list of lowercase weekdays with characters after the first three made optional
DAY_NAMES = ['sunday', 'monday', 'tuesday', 'wednesday', 'thursday', 'friday', 'saturday']
WEEKDAYS = [d[:3]+"(?:"+d[3:]+")?" for d in DAY_NAMES]
# Acceptable format tokens mapped to what they should match in the date string
# Ordered so that larger configurations are matched first
DATE_TOKENS = OrderedDict([
("HH", r"(?P<H>\d{1,2})"), # 24 hr
("H", r"(?P<H>\d{1,2})"),
("hh", r"(?P<h>\d{1,2})"), # 12 hr
("h", r"(?P<h>\d{1,2})"),
("mm", r"(?P<m>\d{1,2})"), # min
("m", r"(?P<m>\d{1,2})"),
("A", r"(?P<A>[ap]m?)"), # am/pm
("a", r"(?P<A>[ap]m?)"),
("ss", r"(?P<s>\d{1,2})"), # sec
("s", r"(?P<s>\d{1,2})"),
("SSSSSS", r"(?P<S>\d{1,6})"), # fractional second
("SSSSS", r"(?P<S>\d{1,6})"),
("SSSS", r"(?P<S>\d{1,6})"),
("SSS", r"(?P<S>\d{1,6})"),
("SS", r"(?P<S>\d{1,6})"),
("S", r"(?P<S>\d{1,6})"),
("YYYY", r"(?P<YY>\d{4}|\d{2})"), # 4 or 2 digit year
("YY", r"(?P<YY>\d{2})"), # 2 digit year
("MMMM", r"(?P<MMM>" + ("|".join(MONTHS)) + ")"), # month name, abbr or not
("MMM", r"(?P<MMM>" + ("|".join(MONTHS)) + ")"),
("MM", r"(?P<M>\d{1,2})"), # month num
("M", r"(?P<M>\d{1,2})"),
("DD", r"(?P<D>\d{1,2})"), # day num
("Do", r"(?P<D>\d{1,2})(st|nd|rd|th)"),
("D", r"(?P<D>\d{1,2})"),
("dddd", r"(" + ("|".join(WEEKDAYS)) + ")"), # day name, abbr or not (ignored)
("ddd", r"(" + ("|".join(WEEKDAYS)) + ")")
])
DATE_TOKENS_REGEX = re.compile("("+("|".join(DATE_TOKENS))+")")
# List of separators to replace and match any standard date/time separators
SEP = r"[\s/.\-:,]*"
SEP_REGEX = re.compile(SEP)
# Maps date parse format to compile regex
FORMAT_CACHE = {}
# Parses date_string using parse_format in the style of moment.js
# See: http://momentjs.com/docs/#/parsing
# Supports the following tokens:
# H HH 0..23 24 hour time
# h hh 1..12 12 hour time used with a A.
# a A am pm Post or ante meridiem
# m mm 0..59 Minutes
# s ss 0..59 Seconds
# S SS SSS 0..999 Fractional seconds
# YYYY 2014 4 or 2 digit year
# YY 14 2 digit year
# M MM 1..12 Month number
# MMM MMMM Jan..December Month name in locale set by moment.locale()
# D DD 1..31 Day of month
# Do 1st..31st Day of month with ordinal
def parse(date_string, parse_format, zonelabel='UTC', override_current_date=None):
"""Parse a date string via a moment.js style parse format and a timezone string.
Supported tokens are documented above. Returns seconds since epoch"""
if parse_format in FORMAT_CACHE:
# Check if parse_format has been cache, and retrieve if so
parser = FORMAT_CACHE[parse_format]
else:
# e.g. "MM-YY" -> "(?P<mm>\d{1,2})-(?P<yy>\d{2})"
# Note that DATE_TOKENS is ordered so that the longer letter chains are recognized first
tokens = DATE_TOKENS_REGEX.split(parse_format)
tokens = [DATE_TOKENS[t] if t in DATE_TOKENS else SEP_REGEX.sub(SEP, t) for t in tokens]
# Compile new token string ignoring case (for month names)
parser = re.compile(''.join(tokens), re.I)
FORMAT_CACHE[parse_format] = parser
match = parser.match(date_string)
# Throw error if matching failed
if match is None:
raise Exception("Failed to parse %s with %s" % (date_string, parse_format))
# Create datetime from the results of parsing
current_date = override_current_date or moment.CURRENT_DATE
m = match.groupdict()
dt = datetime(
year=getYear(m, current_date.year),
month=getMonth(m, current_date.month),
day=int(m['D']) if ('D' in m) else current_date.day,
hour=getHour(m),
minute=int(m['m']) if ('m' in m) else 0,
second=int(m['s']) if ('s' in m) else 0,
microsecond=getMicrosecond(m)
)
# Parses the datetime with the given timezone to return the seconds since EPOCH
return moment.tz(dt, zonelabel).timestamp_s()
def getYear(match_dict, current_year):
if 'YYYY' in match_dict:
return int(match_dict['YYYY'])
elif 'YY' in match_dict:
match = match_dict['YY']
if len(match) == 2:
# Must guess on the century, choose so the result is closest to the current year
# The first year that could be meant by YY is the current year - 50.
first = current_year - 50
# We are seeking k such that 100k + YY is between first and first + 100.
# first <= 100k + YY < first + 100
# 0 <= 100k + YY - first < 100
# The value inside the comparison operators is precisely (YY - first) % 100.
# So we can calculate the century 100k as (YY - first) % 100 - (YY - first).
return first + (int(match) - first) % 100
else:
return int(match)
else:
return current_year
def getMonth(match_dict, current_month):
if 'M' in match_dict:
return int(match_dict['M'])
elif 'MMM' in match_dict:
return lazy_index(MONTHS, match_dict['MMM'][:3].lower()) + 1
else:
return current_month
def getHour(match_dict):
if 'H' in match_dict:
return int(match_dict['H'])
elif 'h' in match_dict:
hr = int(match_dict['h']) % 12
merid = 12 if 'A' in match_dict and match_dict['A'][0] == "p" else 0
return hr + merid
else:
return 0
def getMicrosecond(match_dict):
if 'S' in match_dict:
match = match_dict['S']
return int(match + ("0"*(6-len(match))) if len(match) < 6 else match[:6])
else:
return 0
# Gets the index of the first string from iter that starts with startswith
def lazy_index(l, startswith, missing=None):
for i, token in enumerate(l):
if token[:len(startswith)] == startswith:
return i
return missing

@ -0,0 +1,375 @@
"""
This module implements handling of non-primitive objects as values in Grist data cells. It is
currently only used to handle errors thrown from formulas.
Non-primitive values are represented in actions as [type_name, args...].
objtypes.register_converter() - registers a new supported object type.
objtypes.encode_object(obj) - returns a marshallable list representation.
objtypes.decode_object(val) - returns an object represented by the [name, args...] argument.
If an object cannot be encoded or decoded, a RaisedError exception is encoded or returned instead.
In a formula, this would cause an exception to be raised.
"""
import marshal
import exceptions
import traceback
from datetime import date, datetime
import moment
import records
class UnmarshallableError(ValueError):
"""
Error raised when an object cannot be represented in an action by Grist. It happens if the
object is of a type for which there is no registered converter, or if encoding it involves
values that cannot be marshalled.
"""
pass
class ConversionError(ValueError):
"""
Indicates a failure to convert a value between Grist types. We don't usually expose it to the
user, since such a failure normally results in silent alttext.
"""
pass
class InvalidTypedValue(ValueError):
"""
Indicates that AltText was in place of a typed value and produced an error. The value of AltText
is included into the exception, both to be more informative, and to sort displayCols properly.
"""
def __init__(self, typename, value):
super(InvalidTypedValue, self).__init__(typename)
self.typename = typename
self.value = value
def __str__(self):
return "Invalid %s: %s" % (self.typename, self.value)
_max_js_int = 1<<31
def is_int_short(value):
return -_max_js_int <= value < _max_js_int
def check_marshallable(value):
"""
Raises UnmarshallableError if value cannot be marshalled.
"""
if isinstance(value, (str, unicode, float, bool)) or value is None:
# We don't need to marshal these to know they are marshallable.
return
if isinstance(value, (long, int)):
# Ints are also marshallable, except that we only support 32-bit ints on JS side.
if not is_int_short(value):
raise UnmarshallableError("Integer too large")
return
# Other things we need to try to know.
try:
marshal.dumps(value)
except Exception as e:
raise UnmarshallableError(str(e))
def is_marshallable(value):
"""
Returns a boolean for whether the value can be marshalled.
"""
try:
check_marshallable(value)
return True
except Exception:
return False
# Maps of type or name to (type, name, converter) tuple.
_registered_converters_by_name = {}
_registered_converters_by_type = {}
def register_converter_by_type(type_, converter_func):
assert type_ not in _registered_converters_by_type
_registered_converters_by_type[type_] = converter_func
def register_converter_by_name(converter, type_, name):
assert name not in _registered_converters_by_name
_registered_converters_by_name[name] = (type_, name, converter)
def register_converter(converter, type_, name=None):
"""
Register a new converter for the given type, with the given name (defaulting to type.__name__).
The converter must implement methods:
converter.encode_args(obj) - should return [args...] as a python list of
marshallable arguments.
converter.decode_args(type, arglist) - should return obj of type `type`.
It's up to the converter to ensure that converter.decode_args(type(obj),
converter.encode_args(obj)) returns a value equivalent to the original obj.
"""
if name is None:
name = type_.__name__
register_converter_by_name(converter, type_, name)
register_converter_by_type(type_, _encode_obj_impl(converter, name))
def deregister_converter(name):
"""
De-register a named converter if previously registered.
"""
prev = _registered_converters_by_name.pop(name, None)
if prev:
del _registered_converters_by_type[prev[0]]
def encode_object(obj):
"""
Given an object, returns [typename, args...] array of marshallable values, which should be
sufficient to reconstruct `obj`. Given a primitive object, returns it unchanged.
If obj failed to encode, yields an encoding for RaisedException(UnmarshallableError, message).
I.e. on reading this back, and using the value, we'll get UnmarshallableError exception.
"""
try:
t = type(obj)
try:
converter = (
_registered_converters_by_type.get(t) or
_registered_converters_by_type[getattr(t, '_objtypes_converter_type', t)])
except KeyError:
raise UnmarshallableError("No converter for type %s" % type(obj))
return converter(obj)
except Exception as e:
# Don't risk calling encode_object recursively; instead encode a RaisedException error
# manually with arguments that ought not fail.
return ["E", "UnmarshallableError", str(e), repr(obj)]
def decode_object(value):
"""
Given a value of the form [typename, args...], returns an object represented by it. If typename
is unknown, or construction fails for any reason, returns (not raises!) RaisedException with
original exception in its .error property.
"""
if not isinstance(value, (tuple, list)):
return value
try:
name = value[0]
args = value[1:]
try:
type_, _, converter = _registered_converters_by_name[name]
except KeyError:
raise KeyError("Unknown object type %r" % name)
return converter.decode_args(type_, args)
except Exception as e:
return RaisedException(e)
class SelfConverter(object):
"""
Converter for objects that implement the converter interface:
self.encode_args() - should return a list of marshallable arguments.
cls.decode_args(args...) - should return an instance given the arguments from encode_args.
"""
@classmethod
def encode_args(cls, obj):
return obj.encode_args()
@classmethod
def decode_args(cls, type_, args):
return type_.decode_args(*args)
#----------------------------------------------------------------------
# Implementations of encoding objects. For basic types, there is nothing to encode, but for
# integers, we check that they are in JS range.
def _encode_obj_impl(converter, name):
def inner(obj):
try:
args = converter.encode_args(obj)
except Exception:
raise UnmarshallableError("Encoding of %s failed" % name)
for arg in args:
check_marshallable(arg)
return [name] + args
return inner
def _encode_identity(value):
return value
def _encode_integer(value):
if not is_int_short(value):
raise UnmarshallableError("Integer too large")
return value
register_converter_by_type(str, _encode_identity)
register_converter_by_type(unicode, _encode_identity)
register_converter_by_type(float, _encode_identity)
register_converter_by_type(bool, _encode_identity)
register_converter_by_type(type(None), _encode_identity)
register_converter_by_type(long, _encode_integer)
register_converter_by_type(int, _encode_integer)
#----------------------------------------------------------------------
class RaisedException(object):
"""
RaisedException is a special type of object which indicates that a value in a cell isn't a plain
value but an exception to be raised. All caught exceptions are wrapped in RaisedException. The
original exception is saved in the .error attribute. The traceback is saved in .details
attribute only when needed (flag include_details is set).
RaisedException is registered under a special short name ("E") to save bytes since it's such a
widely-used wrapper. To encode_args, it simply returns the entire encoded stored error, e.g.
RaisedException(ValueError("foo")) is encoded as ["E", "ValueError", "foo"].
"""
def __init__(self, error, include_details=False):
self.error = error
self.details = traceback.format_exc() if include_details else None
def encode_args(self):
# TODO: We should probably return all args, to communicate the error details to the browser
# and to DB (for when we store formula results). There are two concerns: one is that it's
# potentially quite verbose; the other is that it's makes the tests more annoying (again b/c
# verbose).
if self.details:
return [type(self.error).__name__, str(self.error), self.details]
if isinstance(self.error, InvalidTypedValue):
return [type(self.error).__name__, self.error.typename, self.error.value]
return [type(self.error).__name__]
@classmethod
def decode_args(cls, *args):
return cls(decode_object(args))
def __eq__(self, other):
return isinstance(other, type(self)) and self.encode_args() == other.encode_args()
def __ne__(self, other):
return not self.__eq__(other)
class ExceptionConverter(object):
"""
Converter for any type derived from BaseException. On encoding it returns the exception object's
.args attribute, and uses them on decoding as constructor arguments to instantiate the error.
"""
@classmethod
def encode_args(cls, obj):
return list(getattr(obj, 'args', ()))
@classmethod
def decode_args(cls, type_, args):
return type_(*args)
# Register all Exceptions as valid types that can be handled by Grist.
for _, my_type in exceptions.__dict__.iteritems():
if isinstance(my_type, type) and issubclass(my_type, BaseException):
register_converter(ExceptionConverter, my_type)
# Register the special exceptions we defined.
register_converter(ExceptionConverter, UnmarshallableError)
register_converter(ExceptionConverter, ConversionError)
# Register the special wrapper class for raised exceptions with a custom short name.
register_converter(SelfConverter, RaisedException, "E")
class RecordList(list):
"""
Just like list but allows setting custom attributes, which we use for remembering _group_by and
_sort_by attributes when storing RecordSet as usertypes.ReferenceList type.
"""
def __init__(self, row_ids, group_by=None, sort_by=None):
list.__init__(self, row_ids)
self._group_by = group_by
self._sort_by = sort_by
def __repr__(self):
return "RecordList(%r, group_by=%r, sort_by=%r)" % (
list.__repr__(self), self._group_by, self._sort_by)
class ListConverter(object):
"""
Converter for the 'list' type.
"""
@classmethod
def encode_args(cls, obj):
return obj
@classmethod
def decode_args(cls, type_, args):
return type_(args)
# Register a converter for lists, also with a custom short name. It is used, in particular, for
# ReferenceLists. The first line ensures RecordLists are encoded as just lists; the second line
# overrides the decoding of 'L', so that it always decodes to a plain list, since for now, at
# least, there is no need to accept incoming RecordLists.
register_converter_by_type(RecordList, _encode_obj_impl(ListConverter, "L"))
register_converter(ListConverter, list, "L")
class DateTimeConverter(object):
"""
Converter for the 'datetime.datetime' type.
"""
@classmethod
def encode_args(cls, obj):
return [moment.dt_to_ts(obj), obj.tzinfo.zone.name]
@classmethod
def decode_args(cls, _type, args):
return moment.ts_to_dt(args[0], moment.Zone(args[1]))
# Register a converter for dates, also with a custom short name.
register_converter(DateTimeConverter, datetime, "D")
class DateConverter(object):
"""
Converter for the 'datetime.date' type.
"""
@classmethod
def encode_args(cls, obj):
return [moment.date_to_ts(obj)]
@classmethod
def decode_args(cls, _type, args):
return moment.ts_to_date(args[0])
register_converter(DateConverter, date, "d")
# We don't currently have a good way to convert an incoming marshalled record to a proper Record
# object for an appropriate table. We don't expect incoming marshalled records at all, but if such
# a thing happens, we'll construct this RecordStub.
class RecordStub(object):
def __init__(self, table_id, row_id):
self.table_id = table_id
self.row_id = row_id
class RecordConverter(object):
"""
Converter for 'record.Record' objects.
"""
@classmethod
def encode_args(cls, obj):
return [obj._table.table_id, obj._row_id]
@classmethod
def decode_args(cls, _type, args):
return RecordStub(args[0], args[1])
# When marshalling any subclass of Record in objtypes.py, we'll use the base Record as the type.
records.Record._objtypes_converter_type = records.Record
register_converter(RecordConverter, records.Record, "R")

@ -0,0 +1,167 @@
"""
Implements the base classes for Record and RecordSet objects used to represent records in Grist
tables. Individual tables use derived versions of these, which add per-column properties.
"""
import functools
@functools.total_ordering
class Record(object):
"""
Name: Record, rec
A Record represents a record of data. It is the primary means of accessing values in formulas. A
Record for a particular table has a property for each data and formula column in the table.
In a formula, `$field` is translated to `rec.field`, where `rec` is the Record for which the
formula is being evaluated.
For example:
```
def Full_Name(rec, table):
return rec.First_Name + ' ' + rec.LastName
def Name_Length(rec, table):
return len(rec.Full_Name)
```
"""
# Some documentation for method-like parts of Record, which aren't actually methods.
_DOC_EXTRA = (
"""
Name: $Field, rec.Field
Usage: __$__*Field* or __rec__*.Field*
Access the field named "Field" of the current record. E.g. `$First_Name` or `rec.First_Name`.
""",
"""
Name: $group, rec.group
Usage: __$group__
In a summary view, `$group` is a special field containing the list of Records that are
summarized by the current summary line. E.g. `len($group)` is the count of those records.
See [RecordSet](#recordset) for useful properties offered by the returned object.
Examples:
```
sum($group.Amount) # Sum of the Amount field in the matching records
sum(r.Amount for r in $group) # Same as sum($group.Amount)
sum(r.Amount for r in $group if r > 0) # Sum of only the positive amounts
sum(r.Shares * r.Price for r in $group) # Sum of shares * price products
```
"""
)
# Record is always a thin class, containing essentially a reference to a row in the table. The
# properties to access individual fields of a row are provided in per-table derived classes.
def __init__(self, table, row_id, relation=None):
"""
Creates a Record object.
table - Table object, in which this record lives.
row_id - The ID of the record within table.
relation - Relation object for how this record was obtained; used in dependency tracking.
"""
self._table = table
self._row_id = row_id
self._source_relation = relation or table._identity_relation
def _get_col(self, col_id):
return self._table._get_col_value(col_id, self._row_id, self._source_relation)
# Look up a property of the record. Internal properties are simple.
# For columns, we explicitly check that we have them before attempting to access.
# Otherwise AttributeError is ambiguous - it could be because we don't have the
# column, or because the column threw an AttributeError when evaluated.
def __getattr__(self, name):
if name in self._table.all_columns:
return self._get_col(name)
return self._table._attribute_error(name, self._source_relation)
def __hash__(self):
return hash((self._table, self._row_id))
def __eq__(self, other):
return (isinstance(other, Record) and
(self._table, self._row_id) == (other._table, other._row_id))
def __ne__(self, other):
return not self.__eq__(other)
def __lt__(self, other):
return (self._table.table_id, self._row_id) < (other._table.table_id, other._row_id)
def __int__(self):
return self._row_id
def __nonzero__(self):
return bool(self._row_id)
def __repr__(self):
return "%s[%s]" % (self._table.table_id, self._row_id)
def _clone_with_relation(self, src_relation):
return self.__class__(self._table, self._row_id,
relation=src_relation.compose(self._source_relation))
class RecordSet(object):
"""
A RecordSet represents a collection of records, as returned by `Table.lookupRecords()` or
`$group` property in summary views.
A RecordSet allows iterating through the records:
```
sum(r.Amount for r in Students.lookupRecords(First_Name="John", Last_Name="Doe"))
min(r.DueDate for r in Tasks.lookupRecords(Owner="Bob"))
```
RecordSets also provide a convenient way to access the list of values for a particular field for
all the records, as `record_set.Field`. For example, the examples above are equivalent to:
```
sum(Students.lookupRecords(First_Name="John", Last_Name="Doe").Amount)
min(Tasks.lookupRecords(Owner="Bob").DueDate)
```
You can get the number of records in a RecordSet using `len`, e.g. `len($group)`.
"""
def __init__(self, table, row_ids, relation=None, group_by=None, sort_by=None):
"""
group_by may be a dictionary mapping column names to values that are all the same for the given
RecordSet. sort_by may be the column name used for sorting this record set. Both are set by
lookupRecords, and used when using RecordSet to insert new records.
"""
self._table = table
self._row_ids = row_ids
self._source_relation = relation or table._identity_relation
# If row_ids is itself a RecordSet, default to its _group_by and _sort_by properties.
self._group_by = group_by or getattr(row_ids, '_group_by', None)
self._sort_by = sort_by or getattr(row_ids, '_sort_by', None)
def __len__(self):
return len(self._row_ids)
def __nonzero__(self):
return bool(self._row_ids)
def __iter__(self):
for row_id in self._row_ids:
yield self.Record(self._table, row_id, self._source_relation)
def get_one(self):
row_id = min(self._row_ids) if self._row_ids else 0
return self.Record(self._table, row_id, self._source_relation)
def _get_col(self, col_id):
return self._table._get_col_subset(col_id, self._row_ids, self._source_relation)
def __getattr__(self, name):
if name in self._table.all_columns:
return self._get_col(name)
return self._table._attribute_error(name, self._source_relation)
def _clone_with_relation(self, src_relation):
return self.__class__(self._table, self._row_ids,
relation=src_relation.compose(self._source_relation),
group_by=self._group_by,
sort_by=self._sort_by)

@ -0,0 +1,331 @@
"""
This module is used in the implementation of ordering of records in Grist. Order is maintained
using floating-point "positions". E.g. inserting a record will normally add a record with position
being the average of its neighbor's positions.
The difficulty is that it's possible (and sometimes easy) to get floats closer and closer
together, until they are too close (and average of neighbors is equal to one of them). This
requires adjusting existing positions.
This problem is known in computer science as the List-Labeling Problem. There are known algorithms
which maintain ordered labels using fixed number of bits. We use an approach that requires
amortized log(N) relabelings per insert.
For references:
[Wikipedia] https://en.wikipedia.org/wiki/Order-maintenance_problem
The Wikipedia article describes in particular an approach using Scapegoat Trees.
[Bender] http://erikdemaine.org/papers/DietzSleator_ESA2002/paper.pdf
This paper by Bender et al is the best I found that describes the theory and a reasonably
simple solution that doesn't require explicit trees. This is what we rely on here.
What complicates our approach is that inserts never modify positions directly; instead, when we
have items to insert, we need to prepare adjustments (both to new and existing positions), which
are then turned into DocActions to be communicated and applied (both in memory and in storage).
The interface offered by this class is a single `prepare_inserts()` function, which takes a sorted
list and a list of keys, and returns the adjustments to existing records and to the new keys.
Note that we rely heavily here on availability of a sorted container, for which we use the
sortedcontainers module from here:
http://www.grantjenks.com/docs/sortedcontainers/sortedlist.html
https://github.com/grantjenks/sorted_containers
Note also that unlike the original paper we deal with floats rather than integers. This is to
maximize the number of usable bits, since other parts of the system (namely Javascript) don't
support 64-bits integers. We also avoid renumbering everything when we double the number of
elements. The changes aren't vetted theoretically, and may break some conclusions from the paper.
Throughout this file, "key" refers to the floating point value that's called a "label" in
list-labeling papers, "position" elsewhere in Grist code, and "key" in sortedcontainers docs.
"""
import bisect
import itertools
import math
import struct
from sortedcontainers import SortedList, SortedListWithKey
def prepare_inserts_dumb(sortedlist, keys):
"""
This is the dumb implementation of repositioning: whenever we don't have enough space to insert
keys, just renumber everything 1 through N.
"""
# It's still a bit tricky to do this because we need to return adjustments to existing and new
# keys, without actually inserting and renumbering.
ins_groups, ungroup_func = _group_insertions(sortedlist, keys)
insertions = []
adjustments = []
def get_endpoints(index, count):
before = sortedlist._key(sortedlist[index - 1]) if index > 0 else 0.0
after = (sortedlist._key(sortedlist[index])
if index < len(sortedlist) else before + count + 1)
return (before, after)
def is_valid_insert(index, count):
before, after = get_endpoints(index, count)
return is_valid_range(before, get_range(before, after, count), after)
if all(is_valid_insert(index, ins_count) for index, ins_count in ins_groups):
for index, ins_count in ins_groups:
before, after = get_endpoints(index, ins_count)
insertions.extend(get_range(before, after, ins_count))
else:
next_key = 1.0
prev_index = 0
# Complete the renumbering by forcing an extra empty group at the end.
ins_groups.append((len(sortedlist), 0))
for index, ins_count in ins_groups:
adj_count = index - prev_index
adjustments.extend(itertools.izip(xrange(prev_index, index),
frange_from(next_key, adj_count)))
next_key += adj_count
insertions.extend(frange_from(next_key, ins_count))
next_key += ins_count
prev_index = index
return adjustments, ungroup_func(insertions)
def prepare_inserts(sortedlist, keys):
"""
Takes a SortedListWithKey and a list of keys to insert. The keys should be floats.
Returns two lists: [(index, new_key), ...], [new_keys...]
The first list contains pairs for existing items in sortedlist that need to be adjusted to have
new keys (these will not change the ordering). The second is a list of new keys to use in place
of keys. To avoid reorderings, adjustments should be applied before insertions.
"""
worklist = ListWithAdjustments(sortedlist)
ins_groups, ungroup_func = _group_insertions(sortedlist, keys)
for index, ins_count in ins_groups:
worklist.prep_inserts_at_index(index, ins_count)
return worklist.get_adjustments(), ungroup_func(worklist.get_insertions())
def _group_insertions(sortedlist, keys):
"""
Given a list of keys to insert into sortedlist, returns the pair:
[(index, count), ...] pairs for how many items to insert immediately before each index.
ungroup(new_keys): a function that rearranges new keys to match the original keys.
"""
# We'll go through keys to insert in increasing order, to process consecutive keys together.
ins_keys = sorted((key, i) for i, key in enumerate(keys))
# We group by the index at which a new key is to be inserted.
ins_groups = [(index, len(list(ins_iter))) for index, ins_iter in
itertools.groupby(ins_keys, key=lambda pair: sortedlist.bisect_key_left(pair[0]))]
indices = [i for key, i in ins_keys]
def ungroup(new_keys):
return [key for _, key in sorted(zip(indices, new_keys))]
return ins_groups, ungroup
def frange_from(start, count):
return [start + i for i in xrange(count)]
def nextfloat(x):
"""
Returns the next representable float after the float x. This is useful to indicate insertions
AFTER ane existing element.
(See http://stackoverflow.com/a/10426033/328565 for implementation info).
"""
n = struct.unpack('<q', struct.pack('<d', x or 0.0))[0]
n += (1 if n >= 0 else -1)
return struct.unpack('<d', struct.pack('<q', n))[0]
def prevfloat(x):
n = struct.unpack('<q', struct.pack('<d', x or 0.0))[0]
n -= (1 if n >= 0 else -1)
return struct.unpack('<d', struct.pack('<q', n))[0]
class ListWithAdjustments(object):
"""
To prepare inserts, we adjust elements to be inserted and elements in the underlying list. We
don't want to actually touch the underlying list, but we need to remember the adjustments,
because later adjustments may depend on and readjust earlier ones.
"""
def __init__(self, orig_list):
"""
Orig_list must be a a SortedListWithKey.
"""
self._orig_list = orig_list
self._key = orig_list._key
# Stores pairs (i, new_key) where i is an index into orig_list.
# Note that adjustments don't affect the order in the original list, so the list is sorted
# both on keys an on indices; and a missing index i means that (i, orig_key) fits into the
# adjustments list both by key and by index.
self._adjustments = SortedListWithKey(key=lambda pair: pair[1])
# Stores keys for new insertions.
self._insertions = SortedList()
def get_insertions(self):
return self._insertions
def get_adjustments(self):
return self._adjustments
def _adj_bisect_key_left(self, key):
"""
Works as bisect_key_left(key) on the orig_list as if all adjustments have been applied.
"""
adj_index = self._adjustments.bisect_key_left(key)
adj_next = (self._adjustments[adj_index][0] if adj_index < len(self._adjustments)
else len(self._orig_list))
adj_prev = self._adjustments[adj_index - 1][0] if adj_index > 0 else -1
orig_index = self._orig_list.bisect_key_left(key)
if adj_prev < orig_index and orig_index < adj_next:
return orig_index
return adj_next
def _adj_get_key(self, index):
"""
Returns the key corresponding to the given index into orig_list as if all adjustments have
been applied.
"""
i = bisect.bisect_left(self._adjustments, (index, float('-inf')))
if i < len(self._adjustments) and self._adjustments[i][0] == index:
return self._adjustments[i][1]
return self._key(self._orig_list[index])
def count_range(self, begin, end):
"""
Returns the number of elements with keys in the half-open interval [begin, end).
"""
adj_begin = self._adj_bisect_key_left(begin)
adj_end = self._adj_bisect_key_left(end)
ins_begin = self._insertions.bisect_left(begin)
ins_end = self._insertions.bisect_left(end)
return (adj_end - adj_begin) + (ins_end - ins_begin)
def _adjust_range(self, begin, end):
"""
Make changes to stored adjustments and insertions to distribute them equally in the half-open
interval of keys [begin, end).
"""
adj_begin = self._adj_bisect_key_left(begin)
adj_end = self._adj_bisect_key_left(end)
ins_begin = self._insertions.bisect_left(begin)
ins_end = self._insertions.bisect_left(end)
self._do_adjust_range(adj_begin, adj_end, ins_begin, ins_end, begin, end)
def _adjust_all(self):
"""
Renumber everything to be equally distributed in the open interval (new_begin, new_end).
"""
orig_len = len(self._orig_list)
ins_len = len(self._insertions)
self._do_adjust_range(0, orig_len, 0, ins_len, 0.0, orig_len + ins_len + 1.0)
def _do_adjust_range(self, adj_begin, adj_end, ins_begin, ins_end, new_begin_key, new_end_key):
"""
Implements renumbering as used by _adjust_range() and _adjust_all().
"""
count = (adj_end - adj_begin) + (ins_end - ins_begin)
prev_keys = ([(self._adj_get_key(i), False, i) for i in xrange(adj_begin, adj_end)] +
[(self._insertions[i], True, i) for i in xrange(ins_begin, ins_end)])
prev_keys.sort()
new_keys = get_range(new_begin_key, new_end_key, count)
for (old_key, is_insert, i), new_key in itertools.izip(prev_keys, new_keys):
if is_insert:
self._insertions.remove(old_key)
self._insertions.add(new_key)
else:
# (i, old_key) pair may not be among _adjustments, so we discard() rather than remove().
self._adjustments.discard((i, old_key))
self._adjustments.add((i, new_key))
def prep_inserts_at_index(self, index, count):
# This is the crux of the algorithm, inspired by the [Bender] paper (cited above).
# Here's a brief summary of the algorithm, and of our departures from it.
# - The algorithm inserts keys while it is able. When there isn't enough space, it walks
# enclosing intervals around the key it wants to insert, doubling the interval each time,
# until it finds an interval that doesn't overflow. The overflow threshold is calculated in
# such a way that the bigger the interval, the smaller the density it seeks.
# - The algorithm uses integers, picking the number of bits to work for list length between
# n/2 and 2n, and rebuilding from scratch any time length moves out of this range. We don't
# rebuild anything, don't change number of bits, and use floats. This breaks some of the
# theoretical results, and thinking about floats is much harder than about integers. So we
# are not on particularly solid ground with these changes (but it seems to work).
# - We try different thresholds, which seems to perform better. This is mentioned in "Variable
# T" section of [Bender] paper, but our approach isn't quite the same. So it's also on shaky
# theoretical ground.
assert count > 0
begin = self._adj_get_key(index - 1) if index > 0 else 0.0
end = self._adj_get_key(index) if index < len(self._orig_list) else begin + count + 1
if begin < 0 or end <= 0 or math.isinf(max(begin, end)):
# This should only happen if we have some invalid positions (e.g. from before we started
# using this logic). In this case, just renumber everything 1 through n (leaving space so
# that the count insertions take the first count integers).
self._insertions.update([begin if index > 0 else float('-inf')] * count)
self._adjust_all()
return
self._insertions.update(get_range(begin, end, count))
if not is_valid_range(begin, self._insertions.irange(begin, end), end):
assert self.count_range(begin, end) > 0
min_key, max_key = self._find_sparse_enough_range(begin, end)
self._adjust_range(min_key, max_key)
assert is_valid_range(begin, self._insertions.irange(begin, end), end)
def _find_sparse_enough_range(self, begin, end):
# frac is a parameter used for relabeling, corresponding to 2/T in [Bender]. Its
# interpretation is that frac^i is the overflow limit for intervals of size 2^i.
for frac in (1.14, 1.3):
thresh = 1
for i in xrange(64):
rbegin, rend = range_around_float(begin, i)
assert self.count_range(rbegin, rend) > 0
if end <= rend and self.count_range(rbegin, rend) < thresh:
return (rbegin, rend)
thresh *= frac
raise ValueError("This isn't expected")
def is_valid_range(begin, iterable, end):
"""
Return true if all inserted keys in the range [begin, end] are distinct, and different from
the endpoints.
"""
return all_distinct(itertools.chain((begin,), iterable, (end,)))
def all_distinct(iterable):
"""
Returns true if none of the consecutive items in the iterable are the same.
"""
a, b = itertools.tee(iterable)
next(b, None)
return all(x != y for x, y in itertools.izip(a, b))
def range_around_float(x, i):
"""
Returns a pair (min, max) of floats such that the half-open interval [min,max) contains 2^i
representable floats, with x among them.
"""
# This is hard to explain (so easy for this to be wrong). m is in [0.5, 1), with 52 bits of
# precision (for 64-bit double-precision floats, as Python uses). We are trying to zero-out the
# last i bits of the precision. So we shift the mantissa left by (52-i) bits, round down
# (zeroing out remaining i bits), then shift back.
m, e = math.frexp(x)
mf = math.floor(math.ldexp(m, 53 - i))
exp = e + i - 53
return (math.ldexp(mf, exp), math.ldexp(mf + 1, exp))
def get_range(start, end, count):
"""
Returns an equally-distributed list of floats greater than start and less than end.
"""
step = float(end - start) / (count + 1)
# Ensure all resulting values are strictly less than end.
limit = prevfloat(end)
return [min(start + step * k, limit) for k in xrange(1, count + 1)]

@ -0,0 +1,122 @@
"""
A Relation represent mapping between rows, and used in determining which rows need to be
recomputed when something changes.
Relations can be determined by a foreign key or another form of lookup, and they may be composed.
For example, if Person.zip is the formula 'rec.school.address.zip', it involves three Relations:
ReferenceRelation between Person and School tables, another ReferenceRelation between School and
Address tables. Together, they form ComposedRelation relation between Person and Address tables.
"""
import depend
class Relation(object):
"""
Represents a row mapping between two tables. The arguments are table IDs (not actual tables).
"""
def __init__(self, referring_table, target_table):
self.referring_table = referring_table
self.target_table = target_table
# Maps the relation objects that we wrap to the resulting composed relations.
self._target_relations = {}
def get_affected_rows(self, input_rows):
"""
Given an iterable over input (dependency) rows, returns a `set` of output (dependent) rows.
"""
raise NotImplementedError()
def reset_all(self):
"""
Called when the dependency using this relation is reset, and this relation is no longer used.
"""
self.reset_rows(depend.ALL_ROWS)
def reset_rows(self, referring_rows):
"""
Call when starting to compute a formula to tell a Relation that it can start with a clean
slate for all row_ids in the passed-in iterable.
"""
pass
def compose(self, other_relation):
r = self._target_relations.get(other_relation)
if r is None:
r = self._target_relations[other_relation] = ComposedRelation(self, other_relation)
return r
class IdentityRelation(Relation):
"""
The trivial mapping, used to represent the relation between fields of the same record.
"""
def __init__(self, table_id):
super(IdentityRelation, self).__init__(table_id, table_id)
def get_affected_rows(self, input_rows):
return input_rows
def __str__(self):
return "Identity(%s)" % self.referring_table
# Important: we intentionally do not optimize compose() for an IdentityRelation, since
# (Identity + Rel) is not the same Rel when it comes to reset_rows() calls. [See test_lookups.py
# test_dependencies_relations_bug for a detailed description of a bug this can cause.]
class ComposedRelation(Relation):
"""
Represents a composition of two Relations. E.g. if referring side maps Students to Schools, and
target_side maps Schools to Addresses, then the composition maps Students to Addresses (so a
Student records depend on Address records, and changes to Address records affect Students).
"""
def __init__(self, referring_side, target_side):
assert referring_side.target_table == target_side.referring_table
super(ComposedRelation, self).__init__(referring_side.referring_table,
target_side.target_table)
self.source_relation = referring_side
self.target_relation = target_side
def get_affected_rows(self, input_rows):
return self.source_relation.get_affected_rows(
self.target_relation.get_affected_rows(input_rows))
def reset_rows(self, referring_rows):
# In the example from the doc-string, this says that certain Students are being recomputed, so
# no longer refer to any Schools. It doesn't say anything about Schools' dependence on
# Addresses, so there is nothing to reset in self.target_relation.
self.source_relation.reset_rows(referring_rows)
def __str__(self):
return "%s + %s" % (self.source_relation, self.target_relation)
class ReferenceRelation(Relation):
"""
Base class for Relations between records in two tables.
"""
def __init__(self, referring_table, target_table, ref_col_id):
super(ReferenceRelation, self).__init__(referring_table, target_table)
self.inverse_map = {} # maps target rows to sets of referring rows
self._ref_col_id = ref_col_id
def __str__(self):
return "ReferenceRelation(%s.%s)" % (self.referring_table, self._ref_col_id)
def get_affected_rows(self, input_rows):
# Each input row (target of the reference link) may be pointed to by multiple references,
# so we need to take the union of all of those sets.
affected_rows = set()
for target_row_id in input_rows:
affected_rows.update(self.inverse_map.get(target_row_id, ()))
return affected_rows
def add_reference(self, referring_row_id, target_row_id):
self.inverse_map.setdefault(target_row_id, set()).add(referring_row_id)
def remove_reference(self, referring_row_id, target_row_id):
self.inverse_map[target_row_id].remove(referring_row_id)
def clear(self):
self.inverse_map.clear()

@ -0,0 +1,87 @@
"""
This module implements an interpreter for a REPL. It subclasses Python's
code.InteractiveInterpreter class, implementing most of its methods, but with
slight changes in order to be convenient for Grist's purposes
"""
import code
import sys
from StringIO import StringIO
from collections import namedtuple
SUCCESS = 0
INCOMPLETE = 1
ERROR = 2
EvalTuple = namedtuple("EvalTuple", ("output", "error", "status"))
#pylint: disable=exec-used, bare-except
class REPLInterpreter(code.InteractiveInterpreter):
def __init__(self):
code.InteractiveInterpreter.__init__(self)
self.error_text = ""
def runsource(self, source, filename="<input>", symbol="single"):
"""
Compiles and executes source. Returns an EvalTuple with a status
INCOMPLETE if the code is incomplete,
ERROR if it encountered a compilation or run-time error,
SUCCESS otherwise.
an output, which gives all of the output of the user's program
(with stderr piped to stdout, essentially, though mock-file objects are used)
an error, which reports a syntax error at compilation or a runtime error with a
Traceback.
"""
old_stdout = sys.stdout
old_stderr = sys.stderr
user_output = StringIO()
self.error_text = ""
try:
code = self.compile(source, filename, symbol)
except (OverflowError, SyntaxError, ValueError):
self.showsyntaxerror(filename)
status = ERROR
else:
status = INCOMPLETE if code is None else SUCCESS
if status == SUCCESS:
try:
# We use temproray variables to access stdio/stdout
# to make sure the client can't do funky things
# like get/set attr and have that hurt us
sys.stdout = user_output
sys.stderr = user_output
exec code in self.locals
except:
# bare except to catch absolutely all things the user can throw
self.showtraceback()
status = ERROR
finally:
sys.stdout = old_stdout
sys.stderr = old_stderr
program_output = user_output.getvalue()
user_output.close()
return EvalTuple(program_output, self.error_text, status)
def write(self, txt):
"""
Used by showsyntaxerror and showtraceback
"""
self.error_text += txt
def runcode(self, code):
"""
This would normally do the part of runsource after compiling the code, but doesn't quite
make sense as its own function for our purposes because it couldn't support an INCOMPLETE
return value, etc. We explicitly hide it here to make sure the base class's version isn't
called by accident.
"""
raise NotImplementedError("REPLInterpreter.runcode not implemented, use runsource instead")

@ -0,0 +1,33 @@
"""
Helper to run Python unittests in the sandbox. They can be run directly as follows:
./sandbox/nacl/bin/run -E PYTHONPATH=/thirdparty python -m unittest discover -v -s /grist
This modules makes this a bit easier, and adds support for --xunit option, needed for running
tests under 'arc unit' and under Jenkins.
./sandbox/nacl/bin/run python /grist/runtests.py [--xunit]
"""
import os
import sys
import unittest
sys.path.append('/thirdparty')
def main():
# Change to the directory of this file (/grist in sandbox), to discover everything under it.
os.chdir(os.path.dirname(__file__))
argv = sys.argv[:]
test_runner = None
if "--xunit" in argv:
import xmlrunner
argv.remove("--xunit")
test_runner = xmlrunner.XMLTestRunner(stream=sys.stdout)
if all(arg.startswith("-") for arg in argv[1:]):
argv.insert(1, "discover")
unittest.main(module=None, argv=argv, testRunner=test_runner)
if __name__ == '__main__':
main()

@ -0,0 +1,100 @@
"""
Implements the python side of the data engine sandbox, which allows us to register functions on
the python side and call them from Node.js.
Usage:
import sandbox
sandbox.register(func_name, func)
sandbox.call_external("hello", 1, 2, 3)
sandbox.run()
"""
import os
import marshal
import signal
import sys
import traceback
def log(msg):
sys.stderr.write(str(msg) + "\n")
sys.stderr.flush()
class Sandbox(object):
"""
This class works in conjunction with Sandbox.js to allow function calls
between the Node process and this sandbox.
The sandbox provides two pipes (on fds 3 and 4) to send data to and from the sandboxed
process. Data on these is serialized using `marshal` module. All messages are comprised of a
msgCode followed immediatedly by msgBody, with the following msgCodes:
CALL = call to the other side. The data must be an array of [func_name, arguments...]
DATA = data must be a value to return to a call from the other side
EXC = data must be an exception to return to a call from the other side
"""
CALL = None
DATA = True
EXC = False
def __init__(self):
self._functions = {}
self._external_input = os.fdopen(3, "r", 64*1024)
self._external_output = os.fdopen(4, "w", 64*1024)
def _send_to_js(self, msgCode, msgBody):
# (Note that marshal version 2 is the default; we specify it explicitly for clarity. The
# difference with version 0 is that version 2 uses a faster binary format for floats.)
# For large data, JS's Unmarshaller is very inefficient parsing it if it gets it piecewise.
# It's much better to ensure the whole blob is sent as one write. We marshal the resulting
# buffer again so that the reader can quickly tell how many bytes to expect.
buf = marshal.dumps((msgCode, msgBody), 2)
marshal.dump(buf, self._external_output, 2)
self._external_output.flush()
def call_external(self, name, *args):
self._send_to_js(Sandbox.CALL, (name,) + args)
(msgCode, data) = self.run(break_on_response=True)
if msgCode == Sandbox.EXC:
raise Exception(data)
return data
def register(self, func_name, func):
self._functions[func_name] = func
def run(self, break_on_response=False):
while True:
try:
msgCode = marshal.load(self._external_input)
data = marshal.load(self._external_input)
except EOFError:
break
if msgCode != Sandbox.CALL:
if break_on_response:
return (msgCode, data)
continue
if not isinstance(data, list) or len(data) < 1:
raise ValueError("Bad call " + data)
try:
fname = data[0]
args = data[1:]
ret = self._functions[fname](*args)
self._send_to_js(Sandbox.DATA, ret)
except Exception as e:
traceback.print_exc()
self._send_to_js(Sandbox.EXC, "%s %s" % (type(e).__name__, e))
if break_on_response:
raise Exception("Sandbox disconnected unexpectedly")
sandbox = Sandbox()
def call_external(name, *args):
return sandbox.call_external(name, *args)
def register(func_name, func):
sandbox.register(func_name, func)
def run():
sandbox.run()

@ -0,0 +1,354 @@
"""
schema.py defines the schema of the tables describing Grist's own data structures. While users can
create tables, add and remove columns, etc, Grist stores various document metadata (about the
users' tables, views, etc.) also in tables.
Before changing this file, please review:
https://phab.getgrist.com/w/migrations/
"""
import itertools
from collections import OrderedDict, namedtuple
import actions
SCHEMA_VERSION = 20
def make_column(col_id, col_type, formula='', isFormula=False):
return {
"id": col_id,
"type": col_type,
"isFormula": isFormula,
"formula": formula
}
def schema_create_actions():
return [
# The document-wide metadata. It's all contained in a single record with id=1.
actions.AddTable("_grist_DocInfo", [
make_column("docId", "Text"), # DEPRECATED: docId is now stored in _gristsys_FileInfo
make_column("peers", "Text"), # DEPRECATED: now _grist_ACLPrincipals is used for this
# Basket id of the document for online storage, if a Basket has been created for it.
make_column("basketId", "Text"),
# Version number of the document. It tells us how to migrate it to reach SCHEMA_VERSION.
make_column("schemaVersion", "Int"),
# Document timezone.
make_column("timezone", "Text"),
]),
# The names of the user tables. This does NOT include built-in tables.
actions.AddTable("_grist_Tables", [
make_column("tableId", "Text"),
make_column("primaryViewId","Ref:_grist_Views"),
# For a summary table, this points to the corresponding source table.
make_column("summarySourceTable", "Ref:_grist_Tables"),
# A table may be marked as "onDemand", which will keep its data out of the data engine, and
# only available to the frontend when requested.
make_column("onDemand", "Bool")
]),
# All columns in all user tables.
actions.AddTable("_grist_Tables_column", [
make_column("parentId", "Ref:_grist_Tables"),
make_column("parentPos", "PositionNumber"),
make_column("colId", "Text"),
make_column("type", "Text"),
make_column("widgetOptions","Text"), # JSON extending column's widgetOptions
make_column("isFormula", "Bool"),
make_column("formula", "Text"),
make_column("label", "Text"),
# Normally a change to label changes colId as well, unless untieColIdFromLabel is True.
# (We intentionally pick a variable whose default value is false.)
make_column("untieColIdFromLabel", "Bool"),
# For a group-by column in a summary table, this points to the corresponding source column.
make_column("summarySourceCol", "Ref:_grist_Tables_column"),
# Points to a display column, if it exists, for this column.
make_column("displayCol", "Ref:_grist_Tables_column"),
# For Ref cols only, points to the column in the pointed-to table, which is to be displayed.
# E.g. Foo.person may have a visibleCol pointing to People.Name, with the displayCol
# pointing to Foo._gristHelper_DisplayX column with the formula "$person.Name".
make_column("visibleCol", "Ref:_grist_Tables_column"),
]),
# DEPRECATED: Previously used to keep import options, and allow the user to change them.
actions.AddTable("_grist_Imports", [
make_column("tableRef", "Ref:_grist_Tables"),
make_column("origFileName", "Text"),
make_column("parseFormula", "Text", isFormula=True,
formula="grist.parseImport(rec, table._engine)"),
# The following translate directly to csv module options. We can use csv.Sniffer to guess
# them based on a sample of the data (it also guesses hasHeaders option).
make_column("delimiter", "Text", formula="','"),
make_column("doublequote", "Bool", formula="True"),
make_column("escapechar", "Text"),
make_column("quotechar", "Text", formula="'\"'"),
make_column("skipinitialspace", "Bool"),
# Other parameters Grist understands.
make_column("encoding", "Text", formula="'utf8'"),
make_column("hasHeaders", "Bool"),
]),
# DEPRECATED: Previously - All external database credentials attached to the document
actions.AddTable("_grist_External_database", [
make_column("host", "Text"),
make_column("port", "Int"),
make_column("username", "Text"),
make_column("dialect", "Text"),
make_column("database", "Text"),
make_column("storage", "Text"),
]),
# DEPRECATED: Previously - Reference to a table from an external database
actions.AddTable("_grist_External_table", [
make_column("tableRef", "Ref:_grist_Tables"),
make_column("databaseRef", "Ref:_grist_External_database"),
make_column("tableName", "Text"),
]),
# Document tabs that represent a cross-reference between Tables and Views
actions.AddTable("_grist_TableViews", [
make_column("tableRef", "Ref:_grist_Tables"),
make_column("viewRef", "Ref:_grist_Views"),
]),
# DEPRECATED: Previously used to cross-reference between Tables and Views
actions.AddTable("_grist_TabItems", [
make_column("tableRef", "Ref:_grist_Tables"),
make_column("viewRef", "Ref:_grist_Views"),
]),
actions.AddTable("_grist_TabBar", [
make_column("viewRef", "Ref:_grist_Views"),
make_column("tabPos", "PositionNumber"),
]),
# Table for storing the tree of pages. 'pagePos' and 'indentation' columns gives how a page is
# shown in the panel: 'pagePos' determines the page overall position when no pages are collapsed
# (ie: all pages are visible) and 'indentation' gives the level of nesting (depth). Note that
# the parent-child relationships between pages have to be inferred from the variation of
# `indentation` between consecutive pages. For instance a difference of +1 between two
# consecutive pages means that the second page is the child of the first page. A difference of 0
# means that both are siblings and a difference of -1 means that the second page is a sibling to
# the first page parent.
actions.AddTable("_grist_Pages", [
make_column("viewRef", "Ref:_grist_Views"),
make_column("indentation", "Int"),
make_column("pagePos", "PositionNumber"),
]),
# All user views.
actions.AddTable("_grist_Views", [
make_column("name", "Text"),
make_column("type", "Text"), # TODO: Should this be removed?
make_column("layoutSpec", "Text"), # JSON string describing the view layout
]),
# The sections of user views (e.g. a view may contain a list section and a detail section).
# Different sections may need different parameters, so this table includes columns for all
# possible parameters, and any given section will use some subset, depending on its type.
actions.AddTable("_grist_Views_section", [
make_column("tableRef", "Ref:_grist_Tables"),
make_column("parentId", "Ref:_grist_Views"),
# parentKey is the type of view section, such as 'list', 'detail', or 'single'.
# TODO: rename this (e.g. to "sectionType").
make_column("parentKey", "Text"),
make_column("title", "Text"),
make_column("defaultWidth", "Int", formula="100"),
make_column("borderWidth", "Int", formula="1"),
make_column("theme", "Text"),
make_column("options", "Text"),
make_column("chartType", "Text"),
make_column("layoutSpec", "Text"), # JSON string describing the record layout
# filterSpec is deprecated as of version 15. Do not remove or reuse.
make_column("filterSpec", "Text"),
make_column("sortColRefs", "Text"),
make_column("linkSrcSectionRef", "Ref:_grist_Views_section"),
make_column("linkSrcColRef", "Ref:_grist_Tables_column"),
make_column("linkTargetColRef", "Ref:_grist_Tables_column"),
# embedId is deprecated as of version 12. Do not remove or reuse.
make_column("embedId", "Text"),
]),
# The fields of a view section.
actions.AddTable("_grist_Views_section_field", [
make_column("parentId", "Ref:_grist_Views_section"),
make_column("parentPos", "PositionNumber"),
make_column("colRef", "Ref:_grist_Tables_column"),
make_column("width", "Int"),
make_column("widgetOptions","Text"), # JSON extending field's widgetOptions
# Points to a display column, if it exists, for this field.
make_column("displayCol", "Ref:_grist_Tables_column"),
# For Ref cols only, may override the column to be displayed fromin the pointed-to table.
make_column("visibleCol", "Ref:_grist_Tables_column"),
# JSON string describing the default filter as map from either an `included` or an
# `excluded` string to an array of column values:
# Ex1: { included: ['foo', 'bar'] }
# Ex2: { excluded: ['apple', 'orange'] }
make_column("filter", "Text")
]),
# The code for all of the validation rules available to a Grist document
actions.AddTable("_grist_Validations", [
make_column("formula", "Text"),
make_column("name", "Text"),
make_column("tableRef", "Int")
]),
# The input code and output text and compilation/runtime errors for usercode
actions.AddTable("_grist_REPL_Hist", [
make_column("code", "Text"),
make_column("outputText", "Text"),
make_column("errorText", "Text")
]),
# All of the attachments attached to this document.
actions.AddTable("_grist_Attachments", [
make_column("fileIdent", "Text"), # Checksum of the file contents. It identifies the file
# data in the _gristsys_Files table.
make_column("fileName", "Text"), # User defined file name
make_column("fileType", "Text"), # A string indicating the MIME type of the data
make_column("fileSize", "Int"), # The size in bytes
make_column("imageHeight", "Int"), # height in pixels
make_column("imageWidth", "Int"), # width in pixels
make_column("timeUploaded", "DateTime")
]),
# All of the ACL rules.
actions.AddTable('_grist_ACLRules', [
make_column('resource', 'Ref:_grist_ACLResources'),
make_column('permissions', 'Int'), # Bit-map of permission types. See acl.py.
make_column('principals', 'Text'), # JSON array of _grist_ACLPrincipals refs.
make_column('aclFormula', 'Text'), # Formula to apply to tableId, which should return
# additional principals for each row.
make_column('aclColumn', 'Ref:_grist_Tables_column')
]),
actions.AddTable('_grist_ACLResources', [
make_column('tableId', 'Text'), # Name of the table this rule applies to, or ''
make_column('colIds', 'Text'), # Comma-separated list of colIds, or ''
]),
# All of the principals used by ACL rules, including users, groups, and instances.
actions.AddTable('_grist_ACLPrincipals', [
make_column('type', 'Text'), # 'user', 'group', or 'instance'
make_column('userEmail', 'Text'), # For 'user' principals
make_column('userName', 'Text'), # For 'user' principals
make_column('groupName', 'Text'), # For 'group' principals
make_column('instanceId', 'Text'), # For 'instance' principals
# docmodel.py defines further `name` and `allInstances`, and members intended as helpers
# only: `memberships`, `children`, and `descendants`.
]),
# Table for containment relationships between Principals, e.g. user contains multiple
# instances, group contains multiple users, and groups may contain other groups.
actions.AddTable('_grist_ACLMemberships', [
make_column('parent', 'Ref:_grist_ACLPrincipals'),
make_column('child', 'Ref:_grist_ACLPrincipals'),
]),
# TODO:
# The Data Engine should not load up the action log or be able to modify it, or know anything
# about it. It's bad if users could hack up data engine logic to mess with history. (E.g. if
# share a doc for editing, and peer tries to hack it, want to know that can revert; i.e. peer
# shouldn't be able to destroy history.) Also, the action log could be big. It's nice to keep
# it in sqlite and not take up memory.
#
# For this reason, JS code perhaps should be the one creating action tables for a new
# document. It should also ignore any actions that attempt to change such tables. I.e. it
# should have some protected tables, perhaps with a different prefix (_gristsys_), which can't
# be changed by actions generated from the data engine.
#
# TODO
# Conversion of schema actions to metadata-change actions perhaps should also be done by JS,
# and metadata tables should be protected (i.e. can't be changed by user). Hmm....
# # The actions that fully determine the history of this database.
# actions.AddTable("_grist_Action", [
# make_column("num", "Int"), # Action-group number
# make_column("time", "Int"), # Milliseconds since Epoch
# make_column("user", "Text"), # User performing this action
# make_column("desc", "Text"), # Action description
# make_column("otherId", "Int"), # For Undo and Redo, id of the other action
# make_column("linkId", "Int"), # Id of the prev action in the same bundle
# make_column("json", "Text"), # JSON representation of the action
# ]),
# # A logical action is comprised potentially of multiple steps.
# actions.AddTable("_grist_Action_step", [
# make_column("parentId", "Ref:_grist_Action"),
# make_column("type", "Text"), # E.g. "undo", "stored"
# make_column("name", "Text"), # E.g. "AddRecord" or "RenameTable"
# make_column("tableId", "Text"), # Name of the table
# make_column("colIds", "Text"), # Comma-separated names of affected columns
# make_column("rowIds", "Text"), # Comma-separated IDs of affected rows
# make_column("values", "Text"), # All values for the affected rows and columns,
# # bundled together, column-wise, as a JSON array.
# ]),
]
# These are little structs to represent the document schema that's used in code generation.
# Schema itself (as stored by Engine) is an OrderedDict(tableId -> SchemaTable), with
# SchemaTable.columns being an OrderedDict(colId -> SchemaColumn).
SchemaTable = namedtuple('SchemaTable', ('tableId', 'columns'))
SchemaColumn = namedtuple('SchemaColumn', ('colId', 'type', 'isFormula', 'formula'))
# Helpers to convert between schema structures and dicts used in schema actions.
def dict_to_col(col, col_id=None):
"""Convert dict as used in AddColumn/AddTable actions to a SchemaColumn object."""
return SchemaColumn(col_id or col["id"], col["type"], bool(col["isFormula"]), col["formula"])
def col_to_dict(col, include_id=True):
"""Convert SchemaColumn to dict to use in AddColumn/AddTable actions."""
ret = {"type": col.type, "isFormula": col.isFormula, "formula": col.formula}
if include_id:
ret["id"] = col.colId
return ret
def dict_list_to_cols(dict_list):
"""Convert list of column dicts to an OrderedDict of SchemaColumns."""
return OrderedDict((c["id"], dict_to_col(c)) for c in dict_list)
def cols_to_dict_list(cols):
"""Convert OrderedDict of SchemaColumns to an array of column dicts."""
return [col_to_dict(c) for c in cols.values()]
def clone_schema(schema):
return OrderedDict((t, SchemaTable(s.tableId, s.columns.copy()))
for (t, s) in schema.iteritems())
def build_schema(meta_tables, meta_columns, include_builtin=True):
"""
Arguments are TableData objects for the _grist_Tables and _grist_Tables_column tables.
Returns the schema object for engine.py, used in particular in gencode.py.
"""
assert meta_tables.table_id == '_grist_Tables'
assert meta_columns.table_id == '_grist_Tables_column'
# Schema is an OrderedDict.
schema = OrderedDict()
if include_builtin:
for t in schema_create_actions():
schema[t.table_id] = SchemaTable(t.table_id, dict_list_to_cols(t.columns))
# Construct a list of columns sorted by table and position.
collist = sorted(actions.transpose_bulk_action(meta_columns),
key=lambda c: (c.parentId, c.parentPos))
coldict = {t: list(cols) for t, cols in itertools.groupby(collist, lambda r: r.parentId)}
for t in actions.transpose_bulk_action(meta_tables):
columns = OrderedDict((c.colId, SchemaColumn(c.colId, c.type, c.isFormula, c.formula))
for c in coldict[t.id])
schema[t.tableId] = SchemaTable(t.tableId, columns)
return schema

@ -0,0 +1,319 @@
from collections import namedtuple
import json
import re
import logger
log = logger.Logger(__name__, logger.INFO)
ColInfo = namedtuple('ColInfo', ('colId', 'type', 'isFormula', 'formula',
'widgetOptions', 'label'))
def _make_col_info(col=None, **values):
"""Return a ColInfo() with the given fields, optionally copying values from the given column."""
for key in ColInfo._fields:
values.setdefault(key, getattr(col, key) if col else None)
return ColInfo(**values)
def _get_colinfo_dict(col_info, with_id=False):
"""Return a dict suitable to use with AddColumn or AddTable (when with_id=True) actions."""
col_values = {k: v for k, v in col_info._asdict().iteritems() if v is not None and k != 'colId'}
if with_id:
col_values['id'] = col_info.colId
return col_values
# To generate code, we need to know for each summary table, what its source table is. It would be
# easy if we had access to metadata records, but (at least for now) we generate all code based on
# schema only. So we encode the source table name inside of the summary table name.
#
# The encoding includes the length of the source table name, to avoid the possibility of ambiguity
# between the second summary table for "Foo", and the first summary table for "Foo2".
#
# Note that it means we need to rename summary tables when the source table is renamed.
def encode_summary_table_name(source_table_name):
"""
Create a summary table name that reliably encodes the source table name. It can be decoded even
if a suffix is added to the returned name.
"""
return "GristSummary_%d_%s" % (len(source_table_name), source_table_name)
_summary_re = re.compile(r'GristSummary_(\d+)_')
def decode_summary_table_name(summary_table_name):
"""
Extract the name of the source table from the summary table name.
"""
m = _summary_re.match(summary_table_name)
if m:
start = m.end(0)
length = int(m.group(1))
source_name = summary_table_name[start : start + length]
if len(source_name) == length:
return source_name
return None
def _group_colinfo(source_table):
"""Returns ColInfo() for the 'group' column that must be present in every summary table."""
return _make_col_info(colId='group', type='RefList:%s' % source_table.tableId,
isFormula=True, formula='table.getSummarySourceGroup(rec)')
def _update_sort_spec(sort_spec, old_table, new_table):
"""
Replace column references in the sort spec (which is a JSON string encoding a list of column
refs, negated for descending) with references to the new table. Returns the new JSON string,
or empty string in case of a problem.
"""
old_cols_map = {c.id: c.colId for c in old_table.columns}
new_cols_map = {c.colId: c.id for c in new_table.columns}
# When adjusting, we take a possibly negated old colRef, and produce a new colRef.
# If anything is gone, we return 0, which will be excluded from the new sort spec.
def adjust(col_spec):
sign = 1 if col_spec >= 0 else -1
return sign * new_cols_map.get(old_cols_map.get(abs(col_spec)), 0)
try:
old_sort_spec = json.loads(sort_spec)
new_sort_spec = filter(None, [adjust(col_spec) for col_spec in old_sort_spec])
return json.dumps(new_sort_spec, separators=(',', ':'))
except Exception, e:
log.warn("update_summary_section: can't parse sortColRefs JSON; clearing sortColRefs")
return ''
class SummaryActions(object):
def __init__(self, useractions, docmodel):
self.useractions = useractions
self.docmodel = docmodel
def _get_or_add_columns(self, table, all_colinfo):
"""
Given a table record and a list of ColInfo objects, generates a list of corresponding column
records in the table, creating appropriate columns if they don't yet exist.
"""
prior = {c.colId: c for c in table.columns}
for ci in all_colinfo:
col = prior.get(ci.colId)
if col and col.type == ci.type and col.formula == ci.formula:
yield col
else:
result = self.useractions.doAddColumn(table.tableId, ci.colId,
_get_colinfo_dict(ci, with_id=False))
yield self.docmodel.columns.table.get_record(result['colRef'])
def _get_or_create_summary(self, source_table, source_groupby_columns, formula_colinfo):
"""
Finds a summary table or creates a new one, based on source_table, grouped by the columns
in groupby_colinfo, and containing formulas in formula_colinfo. Source_table should be a
Record from _grist_Tables, and other arguments should be lists of ColInfo objects.
Returns the tuple (summary_table, groupby_columns, formula_columns).
"""
key = tuple(sorted(int(c) for c in source_groupby_columns))
groupby_colinfo = [_make_col_info(col=c, isFormula=False, formula='')
for c in source_groupby_columns]
summary_table = next((t for t in source_table.summaryTables if t.summaryKey == key), None)
created = False
if not summary_table:
result = self.useractions.doAddTable(
encode_summary_table_name(source_table.tableId),
[_get_colinfo_dict(ci, with_id=True) for ci in groupby_colinfo + formula_colinfo],
summarySourceTableRef=source_table.id)
summary_table = self.docmodel.tables.table.get_record(result['id'])
created = True
# Note that in this case, _get_or_add_columns() below should not add any new columns,
# but only return existing ones. (The table may contain extra columns, e.g. 'manualSort',
# at least in theory.)
groupby_columns = list(self._get_or_add_columns(summary_table, groupby_colinfo))
formula_columns = list(self._get_or_add_columns(summary_table, formula_colinfo))
if created:
# Set the summarySourceCol field for all the group-by columns in the table.
self.docmodel.update(groupby_columns,
summarySourceCol=[c.id for c in source_groupby_columns])
assert summary_table.summaryKey == key
return (summary_table, groupby_columns, formula_columns)
def update_summary_section(self, view_section, source_table, source_groupby_columns):
source_groupby_colset = set(source_groupby_columns)
groupby_colids = {c.colId for c in source_groupby_columns}
prev_fields = list(view_section.fields)
# Go through fields figuring out which ones we'll keep.
prev_group_fields, formula_fields, delete_fields = [], [], []
for field in prev_fields:
# Records implement __hash__, so we can look them up in sets.
if field.colRef.summarySourceCol in source_groupby_colset:
prev_group_fields.append(field)
elif field.colRef.isFormula and field.colRef.colId not in groupby_colids:
formula_fields.append(field)
else:
delete_fields.append(field)
# Prepare ColInfo for all columns we want to keep.
formula_colinfo = [_make_col_info(f.colRef) for f in formula_fields]
have_group_col = any(f.colRef.colId == 'group' for f in formula_fields)
if not have_group_col:
formula_colinfo.append(_group_colinfo(source_table))
# Get column records for all the columns we should have in our section.
summary_table, groupby_columns, formula_columns = self._get_or_create_summary(
source_table, source_groupby_columns, formula_colinfo)
if not have_group_col:
# We've added the "group" column; now restore the lists to match what we want in fields.
formula_colinfo.pop()
formula_columns.pop()
# Remember the original table, which we need later to adjust the sort spec (sortColRefs).
orig_table = view_section.tableRef
# This line is a bit hard to explain: we unset viewSection.tableRef before updating all the
# fields, and then set it to the correct value. Note how undo will reverse the operations, and
# produce the same sequence (unset, update fields, set). Client-side code relies on this to
# avoid having to deal with inconsistent view sections while fields are being updated.
self.docmodel.update([view_section], tableRef=0)
# Delete fields no longer relevant.
self.docmodel.remove(delete_fields)
# Update fields for all formula fields and reused group-by fields to point to new columns.
source_col_map = dict(zip(source_groupby_columns, groupby_columns))
prev_group_columns = [source_col_map[f.colRef.summarySourceCol] for f in prev_group_fields]
self.docmodel.update(formula_fields + prev_group_fields,
colRef=[c.id for c in formula_columns + prev_group_columns])
# Finally, we need to create fields for newly-added group-by columns. If there were missing
# fields for any group-by columns before, they'll be created now.
new_group_columns = [c for c in groupby_columns if c not in prev_group_columns]
# Insert these after the last existing group-by field.
insert_pos = prev_group_fields[-1].parentPos if prev_group_fields else None
new_group_fields = self.docmodel.insert_after(view_section.fields, insert_pos,
colRef=[c.id for c in new_group_columns])
# Reorder the group-by fields if needed, to match the order requested.
group_col_to_field = {f.colRef: f for f in prev_group_fields + new_group_fields}
group_fields = [group_col_to_field[c] for c in groupby_columns]
group_positions = [field.parentPos for field in group_fields]
sorted_positions = sorted(group_positions)
if sorted_positions != group_positions:
self.docmodel.update(group_fields, parentPos=sorted_positions)
update_args = {}
if view_section.sortColRefs:
# Fix the sortSpec to refer to the new columns.
update_args['sortColRefs'] = _update_sort_spec(
view_section.sortColRefs, orig_table, summary_table)
# Finally update the section to point to the new table.
self.docmodel.update([view_section], tableRef=summary_table.id, **update_args)
def _find_sister_column(self, source_table, col_id):
"""Returns a summary formula column for source_table with the given col_id, or None."""
for t in source_table.summaryTables:
c = self.docmodel.columns.lookupOne(parentId=t.id, colId=col_id, isFormula=True)
if c:
return c
return None
def _create_summary_colinfo(self, source_table, source_groupby_columns):
"""Come up automatically with a list of columns to include into a summary table."""
# Column 'group' defines the group of records that map to this summary line.
all_colinfo = [_group_colinfo(source_table)]
# For every column in the source data, if there is a same-named formula column in another
# summary table, use it here; otherwise if it's a numerical column, automatically add a
# same-named column with the sum of the values in the group.
groupby_col_ids = {c.colId for c in source_groupby_columns}
for col in source_table.columns:
if col.colId in groupby_col_ids or col.colId == 'group':
continue
c = self._find_sister_column(source_table, col.colId)
if c:
all_colinfo.append(_make_col_info(col=c))
elif col.type in ('Int', 'Numeric'):
all_colinfo.append(_make_col_info(col=col, isFormula=True,
formula='SUM($group.%s)' % col.colId))
# Add a default 'count' column for the number of records in the group, unless a different
# 'count' was already added (which we would then prefer as presumably more useful). We add the
# default 'count' right after 'group', to make it the first of the visible formula columns.
if not any(c.colId == 'count' for c in all_colinfo):
all_colinfo.insert(1, _make_col_info(colId='count', type='Int',
isFormula=True, formula='len($group)'))
return all_colinfo
def create_new_summary_section(self, source_table, source_groupby_columns, view, section_type):
formula_colinfo = list(self._create_summary_colinfo(source_table, source_groupby_columns))
summary_table, groupby_columns, formula_columns = self._get_or_create_summary(
source_table, source_groupby_columns, formula_colinfo)
section = self.docmodel.add(view.viewSections, tableRef=summary_table.id,
parentKey=section_type)[0]
self.docmodel.add(section.fields,
colRef=[c.id for c in groupby_columns + formula_columns
if c.colId != "group"])
return section
def detach_summary_section(self, view_section):
"""
Create a real table equivalent to the given summary section, and update the section to show
the new table instead of the summary.
"""
source_table_id = view_section.tableRef.summarySourceTable.tableId
# Get a list of columns that we need for the new table.
fields = view_section.fields
field_col_recs = [f.colRef for f in fields]
# Prepare the column info for each column.
col_info = [_make_col_info(col=c) for c in field_col_recs if c.colId != 'group']
# Prepare the 'group' column, which is that one column that's different from the original.
group_args = ', '.join('%s=$%s' % (c.summarySourceCol.colId, c.colId)
for c in field_col_recs if c.summarySourceCol)
col_info.append(_make_col_info(colId='group', type='RefList:%s' % source_table_id,
isFormula=True,
formula='%s.lookupRecords(%s)' % (source_table_id, group_args)))
# Create the new table.
res = self.useractions.AddTable(None, [_get_colinfo_dict(ci, with_id=True) for ci in col_info])
new_table = self.docmodel.tables.table.get_record(res["id"])
# Remember the original table, which we need later e.g. to adjust the sort spec (sortColRefs).
orig_table = view_section.tableRef
# Populate the new table.
old_data = self.useractions._engine.fetch_table(orig_table.tableId, formulas=False)
self.useractions.ReplaceTableData(new_table.tableId, old_data.row_ids, old_data.columns)
# Unset viewSection.tableRef before updating the fields, to avoid having inconsistencies. (See
# longer explanation in update_summary_section().)
self.docmodel.update([view_section], tableRef=0)
# Update all fields to point to new columns.
new_col_dict = {c.colId: c.id for c in new_table.columns}
self.docmodel.update(fields, colRef=[new_col_dict[c.colId] for c in field_col_recs])
# If the section is sorted, fix the sortSpec to refer to the new columns.
update_args = {}
if view_section.sortColRefs:
update_args['sortColRefs'] = _update_sort_spec(
view_section.sortColRefs, orig_table, new_table)
# Update the section to point to the new table.
self.docmodel.update([view_section], tableRef=new_table.id, **update_args)

@ -0,0 +1,482 @@
import collections
import types
import column
import depend
import docmodel
import lookup
import records
import relation as relation_module # "relation" is used too much as a variable name below.
import usertypes
import logger
log = logger.Logger(__name__, logger.INFO)
class ColumnView(object):
"""
ColumnView is an iterable that represents one column of a RecordSet. You may iterate through
its values and see its size, but it provides no other interface.
"""
def __init__(self, column_obj, row_ids, relation):
self._column = column_obj
self._row_ids = row_ids
self._source_relation = relation
def __len__(self):
return len(self._row_ids)
def __iter__(self):
for row_id in self._row_ids:
yield _adjust_record(self._source_relation, self._column.get_cell_value(row_id))
def _adjust_record(relation, value):
"""
Helper to adjust a Record's source relation to be the composition with the given relation. This
is used to wrap values like `foo.bar`: if `bar` is a Record, then its source relation should be
the composition of the source relation of `foo` and the relation associated with `bar`.
"""
if isinstance(value, (records.Record, records.RecordSet)):
return value._clone_with_relation(relation)
return value
def _make_sample_record(table_id, col_objs):
"""
Helper to create a sample record for a table, used for auto-completions.
"""
# This type gets created with a property for each column. We use property-methods rather than
# plain properties because this sample record is created before all tables have initialized, so
# reference values (using .sample_record for other tables) are not yet available.
RecType = type(table_id, (), {
# Note col=col to bind col at lambda-creation time; see
# https://stackoverflow.com/questions/10452770/python-lambdas-binding-to-local-values
col.col_id: property(lambda self, col=col: col.sample_value())
for col in col_objs
if column.is_user_column(col.col_id) or col.col_id == 'id'
})
return RecType()
def get_default_func_name(col_id):
return "_default_" + col_id
def get_validation_func_name(index):
return "validation___%d" % index
class UserTable(object):
"""
Each data table in the document is represented in the code by an instance of `UserTable` class.
These names are always capitalized. A UserTable provides access to all the records in the table,
as well as methods to look up particular records.
Every table in the document is available to all formulas.
"""
# UserTables are only created in auto-generated code by using UserTable as decorator for a table
# model class. I.e.
#
# @grist.UserTable
# class Students:
# ...
#
# makes the "Students" identifier an actual UserTable instance, so that Students.lookupRecords
# and so on can be used.
def __init__(self, model_class):
docmodel.enhance_model(model_class)
self.Model = model_class
column_ids = {col for col in model_class.__dict__ if not col.startswith("_")}
column_ids.add('id')
self.Record = type('Record', (records.Record,), {})
self.RecordSet = type('RecordSet', (records.RecordSet,), {})
self.RecordSet.Record = self.Record
self.table = None
def _set_table_impl(self, table_impl):
self.table = table_impl
# Note these methods are named camelCase since they are a public interface exposed to formulas,
# and we decided camelCase was a more user-friendly choice for user-facing functions.
def lookupRecords(self, **field_value_pairs):
"""
Returns the Records from this table that match the given field=value arguments. If
`sort_by=field` is given, sort the results by that field.
For example:
```
People.lookupRecords(Last_Name="Johnson", sort_by="First_Name")
People.lookupRecords(First_Name="George", Last_Name="Washington")
```
See [RecordSet](#recordset) for useful properties offered by the returned object.
"""
return self.table.lookup_records(**field_value_pairs)
def lookupOne(self, **field_value_pairs):
"""
Returns a Record matching the given field=value arguments. If multiple records match, returns
one of them. If none match, returns the special empty record.
For example:
```
People.lookupOne(First_Name="Lewis", Last_Name="Carroll")
```
"""
return self.table.lookup_one_record(**field_value_pairs)
def lookupOrAddDerived(self, **kwargs):
return self.table.lookupOrAddDerived(**kwargs)
def getSummarySourceGroup(self, rec):
return self.table.getSummarySourceGroup(rec)
@property
def all(self):
"""
Name: all
Usage: UserTable.__all__
The list of all the records in this table.
For example, this evaluates to the number of records in the table `Students`.
```
len(Students.all)
```
This evaluates to the sum of the `Population` field for every record in the table `Countries`.
```
sum(r.Population for r in Countries.all)
```
"""
return self.lookupRecords()
def __dir__(self):
# Suppress member properties when listing dir(TableClass). This affects rlcompleter, with the
# result that auto-complete will only return class properties, not member properties added in
# the constructor.
return []
class Table(object):
"""
Table represents a table with all its columns and data.
"""
class RowIDs(object):
"""
Helper container that represents the set of valid row IDs in this table.
"""
def __init__(self, id_column):
self._id_column = id_column
def __contains__(self, row_id):
return row_id < self._id_column.size() and self._id_column.raw_get(row_id) > 0
def __iter__(self):
for row_id in xrange(self._id_column.size()):
if self._id_column.raw_get(row_id) > 0:
yield row_id
def max(self):
last = self._id_column.size() - 1
while last > 0 and last not in self:
last -= 1
return last
def __init__(self, table_id, engine):
# The id of the table is the name of its class.
self.table_id = table_id
# Each table maintains a reference to the engine that owns it.
self._engine = engine
# The UserTable object for this table, set in _rebuild_model
self.user_table = None
# Store the identity Relation for this table.
self._identity_relation = relation_module.IdentityRelation(table_id)
# Set of ReferenceColumn objects that refer to this table
self._back_references = set()
# Store the constant Node for "new columns". Accessing invalid columns creates a dependency
# on this node, and triggers recomputation when columns are added or renamed.
self._new_columns_node = depend.Node(self.table_id, None)
# Collection of special columns that this table maintains, which include LookupMapColumns
# and formula columns for maintaining summary tables. These persist across table rebuilds, and
# get cleaned up with delete_column().
self._special_cols = {}
# Maintain Column objects both as a mapping from col_id and as an ordered list.
self.all_columns = collections.OrderedDict()
# This column is always present.
self._id_column = column.create_column(self, 'id', column.get_col_info(usertypes.Id()))
# The `row_ids` member offers some useful interfaces:
# * if row_id in table.row_ids
# * for row_id in table.row_ids
self.row_ids = self.RowIDs(self._id_column)
# For a summary table, this is a reference to the Table object for the source table.
self._summary_source_table = None
# For a summary table, the name of the special helper column auto-added to the source table.
self._summary_helper_col_id = None
def _rebuild_model(self, user_table):
"""
Sets class-wide properties from a new Model class for the table (inner class within the table
class), and rebuilds self.all_columns from the new Model, reusing columns with existing names.
"""
self.user_table = user_table
self.Model = user_table.Model
self.Record = user_table.Record
self.RecordSet = user_table.RecordSet
new_cols = collections.OrderedDict()
new_cols['id'] = self._id_column
# List of Columns in the same order as they appear in the generated Model definition.
col_items = [c for c in self.Model.__dict__.iteritems() if not c[0].startswith("_")]
col_items.sort(key=lambda c: self._get_sort_order(c[1]))
for col_id, col_model in col_items:
default_func = self.Model.__dict__.get(get_default_func_name(col_id))
new_cols[col_id] = self._create_or_update_col(col_id, col_model, default_func)
# Used for auto-completion as a record with correct properties of correct types.
self.sample_record = _make_sample_record(self.table_id, new_cols.itervalues())
# Note that we reuse previous special columns like lookup maps, since those not affected by
# column changes should stay the same. These get removed when unneeded using other means.
new_cols.update(sorted(self._special_cols.iteritems()))
# Set the new columns.
self.all_columns = new_cols
# Make sure any new columns get resized to the full table size.
self.grow_to_max()
# If this is a summary table, auto-create a necessary helper formula in the source table.
summary_src = getattr(self.Model, '_summarySourceTable', None)
if summary_src not in self._engine.tables:
self._summary_source_table = None
self._summary_helper_col_id = None
else:
self._summary_source_table = self._engine.tables[summary_src]
self._summary_helper_col_id = "#summary#%s" % self.table_id
# Figure out the group-by columns: these are all the non-formula columns.
groupby_cols = tuple(sorted(col_id for (col_id, col_model) in col_items
if not isinstance(col_model, types.FunctionType)))
# Add the special helper column to the source table.
self._summary_source_table._add_update_summary_col(self, groupby_cols)
def _add_update_summary_col(self, summary_table, groupby_cols):
# TODO: things need to be removed also from summary_cols when a summary table is deleted.
@usertypes.formulaType(usertypes.Reference(summary_table.table_id))
def _updateSummary(rec, table): # pylint: disable=unused-argument
return summary_table.lookupOrAddDerived(**{c: getattr(rec, c) for c in groupby_cols})
_updateSummary.is_private = True
col_id = summary_table._summary_helper_col_id
col_obj = self._create_or_update_col(col_id, _updateSummary)
self._special_cols[col_id] = col_obj
self.all_columns[col_id] = col_obj
def get_helper_columns(self):
"""
Returns a list of columns from other tables that are only needed for the sake of this table.
"""
if self._summary_source_table and self._summary_helper_col_id:
helper_col = self._summary_source_table.get_column(self._summary_helper_col_id)
return [helper_col]
return []
def _create_or_update_col(self, col_id, col_model, default_func=None):
"""
Helper to update an existing column with a new model, or create a new column object.
"""
col_info = column.get_col_info(col_model, default_func)
col_obj = self.all_columns.get(col_id)
if col_obj:
# This is important for when a column has NOT changed, since although the formula method is
# unchanged, it's important to use the new instance of it from the newly built module.
col_obj.update_method(col_info.method)
else:
col_obj = column.create_column(self, col_id, col_info)
self._engine.invalidate_column(col_obj)
return col_obj
@staticmethod
def _get_sort_order(col_model):
"""
We sort columns according to the order in which they appear in the model definition. To
detect this order, we sort data columns by _creation_order, and formula columns by the
function's source-code line number.
"""
return ((0, col_model._creation_order)
if not isinstance(col_model, types.FunctionType) else
(1, col_model.func_code.co_firstlineno))
def next_row_id(self):
"""
Returns the ID of the next row that can be added to this table.
"""
return self.row_ids.max() + 1
def grow_to_max(self):
"""
Resizes all columns as needed so that all valid row_ids are valid indices into all columns.
"""
size = self.row_ids.max() + 1
for col_obj in self.all_columns.itervalues():
col_obj.growto(size)
def get_column(self, col_id):
"""
Returns the column with the given column ID.
"""
return self.all_columns[col_id]
def has_column(self, col_id):
"""
Returns whether col_id represents a valid column in the table.
"""
return col_id in self.all_columns
def lookup_records(self, **kwargs):
"""
Returns a Record matching the given column=value arguments. It creates the necessary
dependencies, so that the formula will get re-evaluated if needed. It also creates and starts
maintaining a lookup index to make such lookups fast.
"""
# The tuple of keys used determines the LookupMap we need.
sort_by = kwargs.pop('sort_by', None)
col_ids = tuple(sorted(kwargs.iterkeys()))
key = tuple(kwargs[c] for c in col_ids)
lookup_map = self._get_lookup_map(col_ids)
row_id_set, rel = lookup_map.do_lookup(key)
if sort_by:
row_ids = sorted(row_id_set, key=lambda r: self._get_col_value(sort_by, r, rel))
else:
row_ids = sorted(row_id_set)
return self.RecordSet(self, row_ids, rel, group_by=kwargs, sort_by=sort_by)
def lookup_one_record(self, **kwargs):
return self.lookup_records(**kwargs).get_one()
def _get_lookup_map(self, col_ids_tuple):
"""
Helper which returns the LookupMapColumn for the given combination of lookup columns. A
LookupMap behaves a bit like a formula column in that it depends on the passed-in columns and
gets updated whenever any of them change.
"""
# LookupMapColumn is a Node, so identified by (table_id, col_id) pair, so we make up a col_id
# to identify this lookup object uniquely in this Table.
lookup_col_id = "#lookup#" + ":".join(col_ids_tuple)
lmap = self._special_cols.get(lookup_col_id)
if not lmap:
# Check that the table actually has all the columns we looking up.
for c in col_ids_tuple:
if not self.has_column(c):
raise KeyError("Table %s has no column %s" % (self.table_id, c))
lmap = lookup.LookupMapColumn(self, lookup_col_id, col_ids_tuple)
self._special_cols[lookup_col_id] = lmap
self.all_columns[lookup_col_id] = lmap
return lmap
def delete_column(self, col_obj):
assert col_obj.table_id == self.table_id
self._special_cols.pop(col_obj.col_id, None)
self.all_columns.pop(col_obj.col_id, None)
def lookupOrAddDerived(self, **kwargs):
record = self.lookup_one_record(**kwargs)
if not record._row_id and not self._engine.is_triggered_by_table_action(self.table_id):
record._row_id = self._engine.user_actions.AddRecord(self.table_id, None, kwargs)
return record
def getSummarySourceGroup(self, rec):
return (self._summary_source_table.lookup_records(**{self._summary_helper_col_id: int(rec)})
if self._summary_source_table else None)
def get(self, **kwargs):
"""
Returns the first row_id matching the given column=value arguments. This is intended for grist
internal code rather than for user formulas, because it doesn't create the necessary
dependencies.
"""
# TODO: It should use indices, to avoid linear searching
# TODO: It should create dependencies as needed when used from formulas.
# TODO: It should return Record instead, for convenience of user formulas
col_values = [(self.all_columns[col_id], value) for (col_id, value) in kwargs.iteritems()]
for row_id in self.row_ids:
if all(col.raw_get(row_id) == value for col, value in col_values):
return row_id
raise KeyError("'get' found no matching record")
def filter(self, **kwargs):
"""
Generates all row_ids matching the given column=value arguments. This is intended for grist
internal code rather than for user formulas, because it doesn't create the necessary
dependencies. Use filter_records() to generate Record objects instead.
"""
# TODO: It should use indices, to avoid linear searching
# TODO: It should create dependencies as needed when used from formulas.
# TODO: It should return Record instead, for convenience of user formulas
col_values = [(self.all_columns[col_id], value) for (col_id, value) in kwargs.iteritems()]
for row_id in self.row_ids:
if all(col.raw_get(row_id) == value for col, value in col_values):
yield row_id
def get_record(self, row_id):
"""
Returns a Record object corresponding to the given row_id. This is intended for grist internal
code rather than user formulas.
"""
# We don't set up any dependencies, so it would be incorrect to use this from formulas.
# We no longer assert, however, since such calls may still happen e.g. while applying
# user-actions caused by formula side-effects (e.g. as trigged by lookupOrAddDerived())
if row_id not in self.row_ids:
raise KeyError("'get_record' found no matching record")
return self.Record(self, row_id, None)
def filter_records(self, **kwargs):
"""
Generator for Record objects for all the rows matching the given column=value arguments.
This is intended for grist internal code rather than user formula. You may call this with no
arguments to generate all Records in the table.
"""
# See note in get_record() about using this call from formulas.
for row_id in self.filter(**kwargs):
yield self.Record(self, row_id, None)
# TODO: document everything here.
def _use_column(self, col_id, relation, row_ids):
"""
relation - Describes how the record was obtained, on which col_id is being accessed.
"""
col = self.all_columns[col_id]
# The _use_node call both creates a dependency and brings formula columns up-to-date.
self._engine._use_node(col.node, relation, row_ids)
return col
# Called when record.foo is accessed
def _get_col_value(self, col_id, row_id, relation):
return _adjust_record(relation,
self._use_column(col_id, relation, [row_id]).get_cell_value(row_id))
def _attribute_error(self, col_id, relation):
self._engine._use_node(self._new_columns_node, relation)
raise AttributeError("Table '%s' has no column '%s'" % (self.table_id, col_id))
# Called when record_set.foo is accessed. Should return something like a ColumnView.
def _get_col_subset(self, col_id, row_ids, relation):
# TODO: when column is a reference, we ought to return RecordSet. Otherwise ColumnView
# looks like a RecordSet (returns Records), but doesn't support property access.
return ColumnView(self._use_column(col_id, relation, row_ids), row_ids, relation)

@ -0,0 +1,133 @@
from itertools import izip
import actions
from usertypes import get_type_default
import logger
log = logger.Logger(__name__, logger.INFO)
class TableDataSet(object):
"""
TableDataSet represents the full data of a Grist document as a dictionary mapping tableId to
actions.TableData. It then allows applying arbitrary doc-actions, and updates its representation
of the document accordingly. The dictionary is available as the object's `all_tables` member.
This is used, in particular, for migrations, which need to access data with minimal assumptions
about its interpretation.
Note that to initialize a TableDataSet, the schema is needed, so it should be done by applying
AddTable actions, followed by BulkAddRecord or ReplaceTableData actions.
"""
def __init__(self):
# Dictionary of { tableId: actions.TableData object }
self.all_tables = {}
# Dictionary of { tableId: { colId: values }} where values come from AddTable, as modified by
# Add/ModifyColumn actions.
self._schema = {}
def apply_doc_action(self, action):
try:
getattr(self, action.__class__.__name__)(*action)
except Exception, e:
log.warn("ERROR applying action %s: %s" % (action, e))
raise
def apply_doc_actions(self, doc_actions):
for a in doc_actions:
self.apply_doc_action(a)
return doc_actions
def get_col_info(self, table_id, col_id):
return self._schema[table_id][col_id]
def get_schema(self):
return self._schema
#----------------------------------------
# Actions on records.
#----------------------------------------
def AddRecord(self, table_id, row_id, columns):
self.BulkAddRecord(table_id, [row_id], {key: [val] for key, val in columns.iteritems()})
def BulkAddRecord(self, table_id, row_ids, columns):
table_data = self.all_tables[table_id]
table_data.row_ids.extend(row_ids)
for col, values in table_data.columns.iteritems():
if col in columns:
values.extend(columns[col])
else:
col_info = self._schema[table_id][col]
default = get_type_default(col_info['type'])
values.extend([default] * len(row_ids))
def RemoveRecord(self, table_id, row_id):
return self.BulkRemoveRecord(table_id, [row_id])
def BulkRemoveRecord(self, table_id, row_ids):
table_data = self.all_tables[table_id]
remove_set = set(row_ids)
for col, values in table_data.columns.iteritems():
values[:] = [v for r, v in izip(table_data.row_ids, values) if r not in remove_set]
table_data.row_ids[:] = [r for r in table_data.row_ids if r not in remove_set]
def UpdateRecord(self, table_id, row_id, columns):
self.BulkUpdateRecord(
table_id, [row_id], {key: [val] for key, val in columns.iteritems()})
def BulkUpdateRecord(self, table_id, row_ids, columns):
table_data = self.all_tables[table_id]
rowid_map = {r:i for i, r in enumerate(table_data.row_ids)}
table_indices = [rowid_map[r] for r in row_ids]
for col, values in columns.iteritems():
if col in table_data.columns:
col_values = table_data.columns[col]
for i, v in izip(table_indices, values):
col_values[i] = v
def ReplaceTableData(self, table_id, row_ids, columns):
table_data = self.all_tables[table_id]
del table_data.row_ids[:]
for col, values in table_data.columns.iteritems():
del values[:]
self.BulkAddRecord(table_id, row_ids, columns)
#----------------------------------------
# Actions on columns.
#----------------------------------------
def AddColumn(self, table_id, col_id, col_info):
self._schema[table_id][col_id] = col_info
default = get_type_default(col_info['type'])
table_data = self.all_tables[table_id]
table_data.columns[col_id] = [default] * len(table_data.row_ids)
def RemoveColumn(self, table_id, col_id):
self._schema[table_id].pop(col_id, None)
table_data = self.all_tables[table_id]
table_data.columns.pop(col_id, None)
def RenameColumn(self, table_id, old_col_id, new_col_id):
self._schema[table_id][new_col_id] = self._schema[table_id].pop(old_col_id)
table_data = self.all_tables[table_id]
table_data.columns[new_col_id] = table_data.columns.pop(old_col_id)
def ModifyColumn(self, table_id, col_id, col_info):
self._schema[table_id][col_id].update(col_info)
#----------------------------------------
# Actions on tables.
#----------------------------------------
def AddTable(self, table_id, columns):
self.all_tables[table_id] = actions.TableData(table_id, [], {c['id']: [] for c in columns})
self._schema[table_id] = {c['id']: c.copy() for c in columns}
def RemoveTable(self, table_id):
del self.all_tables[table_id]
del self._schema[table_id]
def RenameTable(self, old_table_id, new_table_id):
table_data = self.all_tables.pop(old_table_id)
self.all_tables[new_table_id] = actions.TableData(new_table_id, table_data.row_ids,
table_data.columns)
self._schema[new_table_id] = self._schema.pop(old_table_id)

@ -0,0 +1,512 @@
"""
Test of ACL rules.
"""
import acl
import actions
import logger
import schema
import testutil
import test_engine
import useractions
log = logger.Logger(__name__, logger.INFO)
class TestACL(test_engine.EngineTestCase):
maxDiff = None # Allow self.assertEqual to display big diffs
starting_table_data = [
["id", "city", "state", "amount" ],
[ 21, "New York", "NY" , 1. ],
[ 22, "Albany", "NY" , 2. ],
[ 23, "Seattle", "WA" , 3. ],
[ 24, "Chicago", "IL" , 4. ],
[ 25, "Bedford", "MA" , 5. ],
[ 26, "New York", "NY" , 6. ],
[ 27, "Buffalo", "NY" , 7. ],
[ 28, "Bedford", "NY" , 8. ],
[ 29, "Boston", "MA" , 9. ],
[ 30, "Yonkers", "NY" , 10. ],
[ 31, "New York", "NY" , 11. ],
]
sample = testutil.parse_test_sample({
"SCHEMA": [
[1, "Address", [
[11, "city", "Text", False, "", "City", ""],
[12, "state", "Text", False, "", "State", "WidgetOptions1"],
[13, "amount", "Numeric", False, "", "Amount", "WidgetOptions2"],
]]
],
"DATA": {
"Address": starting_table_data,
"_grist_ACLRules": [
["id", "resource", "permissions", "principals", "aclFormula", "aclColumn"],
],
"_grist_ACLResources": [
["id", "tableId", "colIds"],
],
"_grist_ACLPrincipals": [
["id", "type", "userEmail", "userName", "groupName", "instanceId"],
],
"_grist_ACLMemberships": [
["id", "parent", "child"],
]
}
})
def _apply_ua(self, *useraction_reprs):
"""Returns an ActionBundle."""
user_actions = [useractions.from_repr(ua) for ua in useraction_reprs]
return self.engine.acl_split(self.engine.apply_user_actions(user_actions))
def test_trivial_action_bundle(self):
# In this test case, we just check that an ActionGroup is packaged unchanged into an
# ActionBundle when there are no ACL rules at all.
self.load_sample(self.sample)
# Verify the starting table; there should be no views yet.
self.assertTableData("Address", self.starting_table_data)
# Check that the raw action group created by an action is as expected.
out_action = self.update_record("Address", 22, amount=20.)
self.assertPartialOutActions(out_action, {
'stored': [['UpdateRecord', 'Address', 22, {'amount': 20.}]],
'undo': [['UpdateRecord', 'Address', 22, {'amount': 2.}]],
'calc': [],
'retValues': [None],
})
# In this case, we have no rules, and the action is packaged unchanged into an ActionBundle.
out_bundle = self.engine.acl_split(out_action)
self.assertEqual(out_bundle.to_json_obj(), {
'envelopes': [{"recipients": []}],
'stored': [(0, ['UpdateRecord', 'Address', 22, {'amount': 20.}])],
'undo': [(0, ['UpdateRecord', 'Address', 22, {'amount': 2.}])],
'calc': [],
'retValues': [None],
'rules': [],
})
# Another similar action.
out_bundle = self._apply_ua(
['UpdateRecord', 'Address', 21, {'amount': 10., 'city': 'NYC'}])
self.assertEqual(out_bundle.to_json_obj(), {
'envelopes': [{"recipients": []}],
'stored': [(0, ['UpdateRecord', 'Address', 21, {'amount': 10., 'city': 'NYC'}])],
'undo': [(0, ['UpdateRecord', 'Address', 21, {'amount': 1., 'city': 'New York'}])],
'calc': [],
'retValues': [None],
'rules': [],
})
def test_bundle_default_rules(self):
# Check that a newly-created document (which should have default rules) produces the same
# bundle as the trivial document without rules.
self._apply_ua(['InitNewDoc', 'UTC'])
# Create a schema for a table, and fill with some data.
self.apply_user_action(["AddTable", "Address", [
{"id": "city", "type": "Text"},
{"id": "state", "type": "Text"},
{"id": "amount", "type": "Numeric"},
]])
self.add_records("Address", self.starting_table_data[0], self.starting_table_data[1:])
self.assertTableData("Address", cols="subset", data=self.starting_table_data)
# Check that an action creates the same bundle as in the trivial case.
out_bundle = self._apply_ua(
['UpdateRecord', 'Address', 21, {'amount': 10., 'city': 'NYC'}])
self.assertEqual(out_bundle.to_json_obj(), {
'envelopes': [{"recipients": []}],
'stored': [(0, ['UpdateRecord', 'Address', 21, {'amount': 10., 'city': 'NYC'}])],
'undo': [(0, ['UpdateRecord', 'Address', 21, {'amount': 1., 'city': 'New York'}])],
'calc': [],
'retValues': [None],
'rules': [1],
})
# Once we add principals to Owners group, they should show up in the recipient list.
self.add_records('_grist_ACLPrincipals', ['id', 'type', 'userName', 'instanceId'], [
[20, 'user', 'foo@grist', ''],
[21, 'instance', '', '12345'],
[22, 'instance', '', '0abcd'],
])
self.add_records('_grist_ACLMemberships', ['parent', 'child'], [
[1, 20], # group 'Owners' contains user 'foo@grist'
[20, 21], # user 'foo@grist', contains instance '12345' and '67890'
[20, 22],
])
# Similar action to before, which is bundled as a single envelope, but includes recipients.
out_bundle = self._apply_ua(
['UpdateRecord', 'Address', 21, {'amount': 11., 'city': 'NYC2'}])
self.assertEqual(out_bundle.to_json_obj(), {
'envelopes': [{"recipients": ['0abcd', '12345']}],
'stored': [(0, ['UpdateRecord', 'Address', 21, {'amount': 11., 'city': 'NYC2'}])],
'undo': [(0, ['UpdateRecord', 'Address', 21, {'amount': 10., 'city': 'NYC'}])],
'calc': [],
'retValues': [None],
'rules': [1],
})
def init_employees_doc(self):
# Create a document with non-trivial rules, and check that actions are split correctly,
# using col/table/default rules, and including undo and calc actions.
#
# This is the structure we create:
# Columns Name, Position
# VIEW permission to group Employees
# EDITOR permission to groups Managers, Owners
# Default for columns
# EDITOR permission to groups Managers, Owners
self._apply_ua(['InitNewDoc', 'UTC'])
self.apply_user_action(["AddTable", "Employees", [
{"id": "name", "type": "Text"},
{"id": "position", "type": "Text"},
{"id": "ssn", "type": "Text"},
{"id": "salary", "type": "Numeric", "isFormula": True,
"formula": "100000 if $position.startswith('Senior') else 60000"},
]])
# Set up some groups and instances (skip Users for simplicity). See the assert below for
# better view of the created structure.
self.add_records('_grist_ACLPrincipals', ['id', 'type', 'groupName', 'instanceId'], [
[21, 'group', 'Managers', ''],
[22, 'group', 'Employees', ''],
[23, 'instance', '', 'alice'],
[24, 'instance', '', 'bob'],
[25, 'instance', '', 'chuck'],
[26, 'instance', '', 'eve'],
[27, 'instance', '', 'zack'],
])
# Set up Alice and Bob as Managers; Alice, Chuck, Eve as Employees; and Zack as an Owner.
self.add_records('_grist_ACLMemberships', ['parent', 'child'], [
[21, 23], [21, 24],
[22, 23], [22, 25], [22, 26],
[1, 27]
])
self.assertTableData('_grist_ACLPrincipals', cols="subset", data=[
['id', 'name', 'allInstances' ],
[1, 'Group:Owners', [27] ],
[2, 'Group:Admins', [] ],
[3, 'Group:Editors', [] ],
[4, 'Group:Viewers', [] ],
[21, 'Group:Managers', [23,24] ],
[22, 'Group:Employees', [23,25,26] ],
[23, 'Inst:alice', [23] ],
[24, 'Inst:bob', [24] ],
[25, 'Inst:chuck', [25] ],
[26, 'Inst:eve', [26] ],
[27, 'Inst:zack', [27] ],
])
# Set up some ACL resources and rules: for columns "name,position", give VIEW permission to
# Employees, EDITOR to Managers+Owners; for the rest, just Editor to Managers+Owners.
self.add_records('_grist_ACLResources', ['id', 'tableId', 'colIds'], [
[2, 'Employees', 'name,position'],
[3, 'Employees', ''],
])
self.add_records('_grist_ACLRules', ['id', 'resource', 'permissions', 'principals'], [
[12, 2, acl.Permissions.VIEW, ['L', 22]],
[13, 2, acl.Permissions.EDITOR, ['L', 21,1]],
[14, 3, acl.Permissions.EDITOR, ['L', 21,1]],
])
# OK, now to some actions. The table starts out empty.
self.assertTableData('Employees', [['id', 'manualSort', 'name', 'position', 'salary', 'ssn']])
def test_rules_order(self):
# Test that shows the problem with the ordering of actions in Envelopes.
self.init_employees_doc()
self._apply_ua(self.add_records_action('Employees', [
['name', 'position', 'ssn'],
['John', 'Scientist', '000-00-0000'],
['Ellen', 'Senior Scientist', '111-11-1111'],
['Susie', 'Manager', '222-22-2222'],
['Frank', 'Senior Manager', '222-22-2222'],
]))
out_bundle = self._apply_ua(['ApplyDocActions', [
['UpdateRecord', 'Employees', 1, {'ssn': 'xxx-xx-0000'}],
['UpdateRecord', 'Employees', 1, {'position': 'Senior Jester'}],
['UpdateRecord', 'Employees', 1, {'ssn': 'yyy-yy-0000'}],
]])
self.assertTableData('Employees', cols="subset", data=[
['id', 'name', 'position', 'salary', 'ssn'],
[1, 'John', 'Senior Jester', 100000.0, 'yyy-yy-0000'],
[2, 'Ellen', 'Senior Scientist', 100000.0, '111-11-1111'],
[3, 'Susie', 'Manager', 60000.0, '222-22-2222'],
[4, 'Frank', 'Senior Manager', 100000.0, '222-22-2222'],
])
# Check the main aspects of the created bundles.
env = out_bundle.envelopes
# We expect two envelopes: one for Managers+Owners, one for all including Employees,
# because 'ssn' and 'position' columns are resources with different permissions.
# Note how non-consecutive actions may belong to the same envelope. This is needed to allow
# users (e.g. alice in this example) to process DocActions in the same order as how they were
# created, even when alice is present in different sets of recipients.
self.assertEqual(env[0].recipients, {"alice", "bob", "zack"})
self.assertEqual(env[1].recipients, {"alice", "bob", "zack", "chuck", "eve"})
self.assertEqual(out_bundle.stored, [
(0, actions.UpdateRecord('Employees', 1, {'ssn': 'xxx-xx-0000'})),
(1, actions.UpdateRecord('Employees', 1, {'position': 'Senior Jester'})),
(0, actions.UpdateRecord('Employees', 1, {'ssn': 'yyy-yy-0000'})),
])
self.assertEqual(out_bundle.calc, [
(0, actions.UpdateRecord('Employees', 1, {'salary': 100000.00}))
])
def test_with_rules(self):
self.init_employees_doc()
out_bundle = self._apply_ua(self.add_records_action('Employees', [
['name', 'position', 'ssn'],
['John', 'Scientist', '000-00-0000'],
['Ellen', 'Senior Scientist', '111-11-1111'],
['Susie', 'Manager', '222-22-2222'],
['Frank', 'Senior Manager', '222-22-2222'],
]))
# Check the main aspects of the output.
env = out_bundle.envelopes
# We expect two envelopes: one for Managers+Owners, one for all including Employees.
self.assertEqual([e.recipients for e in env], [
{"alice","chuck","eve","bob","zack"},
{"alice", "bob", "zack"}
])
# Only "name" and "position" are sent to Employees; the rest only to Managers+Owners.
self.assertEqual([(env, set(a.columns)) for (env, a) in out_bundle.stored], [
(0, {"name", "position"}),
(1, {"ssn", "manualSort"}),
])
self.assertEqual([(env, set(a.columns)) for (env, a) in out_bundle.calc], [
(1, {"salary"})
])
# Full bundle requires careful reading. See the checks above for the essential parts.
self.assertEqual(out_bundle.to_json_obj(), {
"envelopes": [
{"recipients": [ "alice", "bob", "chuck", "eve", "zack" ]},
{"recipients": [ "alice", "bob", "zack" ]},
],
"stored": [
# TODO Yikes, there is a problem here! We have two envelopes, each with BulkAddRecord
# actions, but some recipients receive BOTH envelopes. What is "alice" to do with two
# separate BulkAddRecord actions that both include rowIds 1, 2, 3, 4?
(0, [ "BulkAddRecord", "Employees", [ 1, 2, 3, 4 ], {
"position": [ "Scientist", "Senior Scientist", "Manager", "Senior Manager" ],
"name": [ "John", "Ellen", "Susie", "Frank" ]
}]),
(1, [ "BulkAddRecord", "Employees", [ 1, 2, 3, 4 ], {
"manualSort": [ 1, 2, 3, 4 ],
"ssn": [ "000-00-0000", "111-11-1111", "222-22-2222", "222-22-2222" ]
}]),
],
"undo": [
# TODO All recipients now get BulkRemoveRecord (which is correct), but some get it twice,
# which is a simpler manifestation of the problem with BulkAddRecord.
(0, [ "BulkRemoveRecord", "Employees", [ 1, 2, 3, 4 ] ]),
(1, [ "BulkRemoveRecord", "Employees", [ 1, 2, 3, 4 ] ]),
],
"calc": [
(1, [ "BulkUpdateRecord", "Employees", [ 1, 2, 3, 4 ], {
"salary": [ 60000, 100000, 60000, 100000 ]
}])
],
"retValues": [[1, 2, 3, 4]],
"rules": [12,13,14],
})
def test_empty_add_record(self):
self.init_employees_doc()
out_bundle = self._apply_ua(['AddRecord', 'Employees', None, {}])
self.assertEqual(out_bundle.to_json_obj(), {
"envelopes": [{"recipients": [ "alice", "bob", "chuck", "eve", "zack" ]},
{"recipients": [ "alice", "bob", "zack" ]} ],
# TODO Note the same issues as in previous test case: some recipients receive duplicate or
# near-duplicate AddRecord and RemoveRecord actions, governed by different rules.
"stored": [
(0, [ "AddRecord", "Employees", 1, {}]),
(1, [ "AddRecord", "Employees", 1, {"manualSort": 1.0}]),
],
"undo": [
(0, [ "RemoveRecord", "Employees", 1 ]),
(1, [ "RemoveRecord", "Employees", 1 ]),
],
"calc": [
(1, [ "UpdateRecord", "Employees", 1, { "salary": 60000.0 }])
],
"retValues": [1],
"rules": [12,13,14],
})
out_bundle = self._apply_ua(['UpdateRecord', 'Employees', 1, {"position": "Senior Citizen"}])
self.assertEqual(out_bundle.to_json_obj(), {
"envelopes": [{"recipients": [ "alice", "bob", "chuck", "eve", "zack" ]},
{"recipients": [ "alice", "bob", "zack" ]} ],
"stored": [
(0, [ "UpdateRecord", "Employees", 1, {"position": "Senior Citizen"}]),
],
"undo": [
(0, [ "UpdateRecord", "Employees", 1, {"position": ""}]),
],
"calc": [
(1, [ "UpdateRecord", "Employees", 1, { "salary": 100000.0 }])
],
"retValues": [None],
"rules": [12,13,14],
})
def test_add_user(self):
self.init_employees_doc()
out_bundle = self._apply_ua(['AddUser', 'f@g.c', 'Fred', ['XXX', 'YYY']])
self.assertEqual(out_bundle.to_json_obj(), {
# TODO: Only Owners are getting these metadata changes, but all users should get them.
"envelopes": [{"recipients": [ "XXX", "YYY", "zack" ]}],
"stored": [
(0, [ "AddRecord", "_grist_ACLPrincipals", 28, {
'type': 'user', 'userEmail': 'f@g.c', 'userName': 'Fred'}]),
(0, [ "BulkAddRecord", "_grist_ACLPrincipals", [29, 30], {
'type': ['instance', 'instance'],
'instanceId': ['XXX', 'YYY']
}]),
(0, [ "BulkAddRecord", "_grist_ACLMemberships", [7, 8, 9], {
# Adds instances (29, 30) to user (28), and user (28) to group owners (1)
'parent': [28, 28, 1],
'child': [29, 30, 28],
}]),
],
"undo": [
(0, [ "RemoveRecord", "_grist_ACLPrincipals", 28]),
(0, [ "BulkRemoveRecord", "_grist_ACLPrincipals", [29, 30]]),
(0, [ "BulkRemoveRecord", "_grist_ACLMemberships", [7, 8, 9]]),
],
"calc": [
],
"retValues": [None],
"rules": [1],
})
def test_doc_snapshot(self):
self.init_employees_doc()
# Apply an action to the initial employees doc to make the test case more complex
self.add_records('Employees', ['name', 'position', 'ssn'], [
['John', 'Scientist', '000-00-0000'],
['Ellen', 'Senior Scientist', '111-11-1111'],
['Susie', 'Manager', '222-22-2222'],
['Frank', 'Senior Manager', '222-22-2222']
])
# Retrieve the doc snapshot and split it
snapshot_action_group = self.engine.fetch_snapshot()
snapshot_bundle = self.engine.acl_split(snapshot_action_group)
init_schema_actions = [actions.get_action_repr(a) for a in schema.schema_create_actions()]
# We check that the unsplit doc snapshot bundle includes all the necessary actions
# to rebuild the doc
snapshot = snapshot_action_group.get_repr()
self.assertEqual(snapshot['calc'], [])
self.assertEqual(snapshot['retValues'], [])
self.assertEqual(snapshot['undo'], [])
stored_subset = [
['AddTable', 'Employees',
[{'formula': '','id': 'manualSort','isFormula': False,'type': 'ManualSortPos'},
{'formula': '','id': 'name','isFormula': False,'type': 'Text'},
{'formula': '','id': 'position','isFormula': False,'type': 'Text'},
{'formula': '','id': 'ssn','isFormula': False,'type': 'Text'},
{'formula': "100000 if $position.startswith('Senior') else 60000",
'id': 'salary',
'isFormula': True,
'type': 'Numeric'}]],
['BulkAddRecord', '_grist_Tables', [1],
{'primaryViewId': [1],
'summarySourceTable': [0],
'tableId': ['Employees'],
'onDemand': [False]}],
['BulkAddRecord', 'Employees', [1, 2, 3, 4], {
'manualSort': [1.0, 2.0, 3.0, 4.0],
'name': ['John', 'Ellen', 'Susie', 'Frank'],
'position': ['Scientist', 'Senior Scientist', 'Manager', 'Senior Manager'],
'ssn': ['000-00-0000', '111-11-1111', '222-22-2222', '222-22-2222']
}],
['BulkAddRecord','_grist_Tables_column',[1, 2, 3, 4, 5],
{'colId': ['manualSort', 'name', 'position', 'ssn', 'salary'],
'displayCol': [0, 0, 0, 0, 0],
'formula': ['','','','',"100000 if $position.startswith('Senior') else 60000"],
'isFormula': [False, False, False, False, True],
'label': ['manualSort', 'name', 'position', 'ssn', 'salary'],
'parentId': [1, 1, 1, 1, 1],
'parentPos': [1.0, 2.0, 3.0, 4.0, 5.0],
'summarySourceCol': [0, 0, 0, 0, 0],
'type': ['ManualSortPos', 'Text', 'Text', 'Text', 'Numeric'],
'untieColIdFromLabel': [False, False, False, False, False],
'widgetOptions': ['', '', '', '', ''],
'visibleCol': [0, 0, 0, 0, 0]}]
]
for action in stored_subset:
self.assertIn(action, snapshot['stored'])
# We check that the full doc snapshot bundle is split as expected
snapshot_bundle_json = snapshot_bundle.to_json_obj()
self.assertEqual(snapshot_bundle_json['envelopes'], [
{'recipients': ['#ALL']},
{'recipients': ['zack']},
{'recipients': ['alice', 'bob', 'chuck', 'eve', 'zack']},
{'recipients': ['alice', 'bob', 'zack']}
])
self.assertEqual(snapshot_bundle_json['calc'], [])
self.assertEqual(snapshot_bundle_json['retValues'], [])
self.assertEqual(snapshot_bundle_json['undo'], [])
self.assertEqual(snapshot_bundle_json['rules'], [1, 12, 13, 14])
stored_subset = ([(0, action_repr) for action_repr in init_schema_actions] + [
(0, ['AddTable', 'Employees',
[{'formula': '','id': 'manualSort','isFormula': False,'type': 'ManualSortPos'},
{'formula': '','id': 'name','isFormula': False,'type': 'Text'},
{'formula': '','id': 'position','isFormula': False,'type': 'Text'},
{'formula': '','id': 'ssn','isFormula': False,'type': 'Text'},
{'formula': "100000 if $position.startswith('Senior') else 60000",
'id': 'salary',
'isFormula': True,
'type': 'Numeric'}]]),
# TODO (High-priority): The following action only received by 'zack' when it should be
# received by everyone.
(1, ['BulkAddRecord', '_grist_Tables', [1],
{'primaryViewId': [1],
'summarySourceTable': [0],
'tableId': ['Employees'],
'onDemand': [False]}]),
(2, ['BulkAddRecord', 'Employees', [1, 2, 3, 4],
{'name': ['John', 'Ellen', 'Susie', 'Frank'],
'position': ['Scientist', 'Senior Scientist', 'Manager', 'Senior Manager']}]),
(3, ['BulkAddRecord', 'Employees', [1, 2, 3, 4],
{'manualSort': [1.0, 2.0, 3.0, 4.0],
'ssn': ['000-00-0000', '111-11-1111', '222-22-2222', '222-22-2222']}]),
(1, ['BulkAddRecord', '_grist_Tables_column', [1, 2, 3, 4, 5],
{'colId': ['manualSort','name','position','ssn','salary'],
'displayCol': [0, 0, 0, 0, 0],
'formula': ['','','','',"100000 if $position.startswith('Senior') else 60000"],
'isFormula': [False, False, False, False, True],
'label': ['manualSort','name','position','ssn','salary'],
'parentId': [1, 1, 1, 1, 1],
'parentPos': [1.0, 2.0, 3.0, 4.0, 5.0],
'summarySourceCol': [0, 0, 0, 0, 0],
'type': ['ManualSortPos','Text','Text','Text','Numeric'],
'untieColIdFromLabel': [False, False, False, False, False],
'widgetOptions': ['', '', '', '', ''],
'visibleCol': [0, 0, 0, 0, 0]}])
])
for action in stored_subset:
self.assertIn(action, snapshot_bundle_json['stored'])

@ -0,0 +1,79 @@
import unittest
import actions
class TestActions(unittest.TestCase):
action_obj1 = actions.UpdateRecord("foo", 17, {"bar": "baz"})
doc_action1 = ["UpdateRecord", "foo", 17, {"bar": "baz"}]
def test_convert(self):
self.assertEqual(actions.get_action_repr(self.action_obj1), self.doc_action1)
self.assertEqual(actions.action_from_repr(self.doc_action1), self.action_obj1)
with self.assertRaises(ValueError) as err:
actions.action_from_repr(["Foo", "bar"])
self.assertTrue("Foo" in str(err.exception))
def test_prune_actions(self):
# prune_actions is in-place, so we make a new list every time.
def alist():
return [
actions.BulkUpdateRecord("Table1", [1,2,3], {'Foo': [10,20,30]}),
actions.BulkUpdateRecord("Table2", [1,2,3], {'Foo': [10,20,30], 'Bar': ['a','b','c']}),
actions.UpdateRecord("Table1", 17, {'Foo': 10}),
actions.UpdateRecord("Table2", 18, {'Foo': 10, 'Bar': 'a'}),
actions.AddRecord("Table1", 17, {'Foo': 10}),
actions.BulkAddRecord("Table2", 18, {'Foo': 10, 'Bar': 'a'}),
actions.ReplaceTableData("Table2", 18, {'Foo': 10, 'Bar': 'a'}),
actions.RemoveRecord("Table1", 17),
actions.BulkRemoveRecord("Table2", [17,18]),
actions.AddColumn("Table1", "Foo", {"type": "Text"}),
actions.RenameColumn("Table1", "Foo", "Bar"),
actions.ModifyColumn("Table1", "Foo", {"type": "Text"}),
actions.RemoveColumn("Table1", "Foo"),
actions.AddTable("THello", [{"id": "Foo"}, {"id": "Bar"}]),
actions.RemoveTable("THello"),
actions.RenameTable("THello", "TWorld"),
]
def prune(table_id, col_id):
a = alist()
actions.prune_actions(a, table_id, col_id)
return a
self.assertEqual(prune('Table1', 'Foo'), [
actions.BulkUpdateRecord("Table2", [1,2,3], {'Foo': [10,20,30], 'Bar': ['a','b','c']}),
actions.UpdateRecord("Table2", 18, {'Foo': 10, 'Bar': 'a'}),
actions.BulkAddRecord("Table2", 18, {'Foo': 10, 'Bar': 'a'}),
actions.ReplaceTableData("Table2", 18, {'Foo': 10, 'Bar': 'a'}),
actions.RemoveRecord("Table1", 17),
actions.BulkRemoveRecord("Table2", [17,18]),
# It doesn't do anything with column renames; it can be addressed if needed.
actions.RenameColumn("Table1", "Foo", "Bar"),
# It doesn't do anything with AddTable, which is expected.
actions.AddTable("THello", [{"id": "Foo"}, {"id": "Bar"}]),
actions.RemoveTable("THello"),
actions.RenameTable("THello", "TWorld"),
])
self.assertEqual(prune('Table2', 'Foo'), [
actions.BulkUpdateRecord("Table1", [1,2,3], {'Foo': [10,20,30]}),
actions.BulkUpdateRecord("Table2", [1,2,3], {'Bar': ['a','b','c']}),
actions.UpdateRecord("Table1", 17, {'Foo': 10}),
actions.UpdateRecord("Table2", 18, {'Bar': 'a'}),
actions.AddRecord("Table1", 17, {'Foo': 10}),
actions.BulkAddRecord("Table2", 18, {'Bar': 'a'}),
actions.ReplaceTableData("Table2", 18, {'Bar': 'a'}),
actions.RemoveRecord("Table1", 17),
actions.BulkRemoveRecord("Table2", [17,18]),
actions.AddColumn("Table1", "Foo", {"type": "Text"}),
actions.RenameColumn("Table1", "Foo", "Bar"),
actions.ModifyColumn("Table1", "Foo", {"type": "Text"}),
actions.RemoveColumn("Table1", "Foo"),
actions.AddTable("THello", [{"id": "Foo"}, {"id": "Bar"}]),
actions.RemoveTable("THello"),
actions.RenameTable("THello", "TWorld"),
])
if __name__ == "__main__":
unittest.main()

@ -0,0 +1,183 @@
# -*- coding: utf-8 -*-
import unittest
import codebuilder
def make_body(formula, default=None):
return codebuilder.make_formula_body(formula, default).get_text()
class TestCodeBuilder(unittest.TestCase):
def test_make_formula_body(self):
# Test simple usage.
self.assertEqual(make_body(""), "return None")
self.assertEqual(make_body("", 0.0), "return 0.0")
self.assertEqual(make_body("", ""), "return ''")
self.assertEqual(make_body(" "), "return None")
self.assertEqual(make_body(" ", "-"), "return '-'")
self.assertEqual(make_body("\n\t"), "return None")
self.assertEqual(make_body("$foo"), "return rec.foo")
self.assertEqual(make_body("rec.foo"), "return rec.foo")
self.assertEqual(make_body("return $foo"), "return rec.foo")
self.assertEqual(make_body("return $f123"), "return rec.f123")
self.assertEqual(make_body("return rec.foo"), "return rec.foo")
self.assertEqual(make_body("$foo if $bar else max($foo.bar.baz)"),
"return rec.foo if rec.bar else max(rec.foo.bar.baz)")
# Check that we don't mistake our temporary representation of "$" for the real thing.
self.assertEqual(make_body("return DOLLARfoo"), "return DOLLARfoo")
# Test that we don't translate $foo inside string literals or comments.
self.assertEqual(make_body("$foo or '$foo'"), "return rec.foo or '$foo'")
self.assertEqual(make_body("$foo * 2 # $foo"), "return rec.foo * 2 # $foo")
self.assertEqual(make_body("$foo * 2 # $foo\n$bar"), "rec.foo * 2 # $foo\nreturn rec.bar")
self.assertEqual(make_body("$foo or '\\'$foo\\''"), "return rec.foo or '\\'$foo\\''")
self.assertEqual(make_body('$foo or """$foo"""'), 'return rec.foo or """$foo"""')
self.assertEqual(make_body('$foo or """Some "$foos" stay"""'),
'return rec.foo or """Some "$foos" stay"""')
# Check that we only insert a return appropriately.
self.assertEqual(make_body('if $foo:\n return 1\nelse:\n return 2\n'),
'if rec.foo:\n return 1\nelse:\n return 2\n')
self.assertEqual(make_body('a = $foo\nmax(a, a*2)'), 'a = rec.foo\nreturn max(a, a*2)')
# Check that return gets inserted correctly when there is a multi-line expression.
self.assertEqual(make_body('($foo or\n $bar)'), 'return (rec.foo or\n rec.bar)')
self.assertEqual(make_body('return ($foo or\n $bar)'), 'return (rec.foo or\n rec.bar)')
self.assertEqual(make_body('if $foo: return 17'), 'if rec.foo: return 17')
self.assertEqual(make_body('$foo\n# return $bar'), 'return rec.foo\n# return $bar')
# Test that formulas with a single string literal work, including multi-line string literals.
self.assertEqual(make_body('"test"'), 'return "test"')
self.assertEqual(make_body('("""test1\ntest2\ntest3""")'), 'return ("""test1\ntest2\ntest3""")')
self.assertEqual(make_body('"""test1\ntest2\ntest3"""'), 'return """test1\ntest2\ntest3"""')
self.assertEqual(make_body('"""test1\\ntest2\\ntest3"""'), 'return """test1\\ntest2\\ntest3"""')
# Same, with single quotes.
self.assertEqual(make_body("'test'"), "return 'test'")
self.assertEqual(make_body("('''test1\ntest2\ntest3''')"), "return ('''test1\ntest2\ntest3''')")
self.assertEqual(make_body("'''test1\ntest2\ntest3'''"), "return '''test1\ntest2\ntest3'''")
self.assertEqual(make_body("'''test1\\ntest2\\ntest3'''"), "return '''test1\\ntest2\\ntest3'''")
# And with mixing quotes
self.assertEqual(make_body("'''test1\"\"\" +\\\n \"\"\"test2'''"),
"return '''test1\"\"\" +\\\n \"\"\"test2'''")
self.assertEqual(make_body("'''test1''' +\\\n \"\"\"test2\"\"\""),
"return '''test1''' +\\\n \"\"\"test2\"\"\"")
self.assertEqual(make_body("'''test1\"\"\"\n\"\"\"test2'''"),
"return '''test1\"\"\"\n\"\"\"test2'''")
self.assertEqual(make_body("'''test1'''\n\"\"\"test2\"\"\""),
"'''test1'''\nreturn \"\"\"test2\"\"\"")
# Test that we produce valid code when "$foo" occurs in invalid places.
self.assertEqual(make_body('foo($bar=1)'),
"# foo($bar=1)\nraise SyntaxError('invalid syntax on line 1 col 5')")
self.assertEqual(make_body('def $bar(): pass'),
"# def $bar(): pass\nraise SyntaxError('invalid syntax on line 1 col 5')")
# If $ is a syntax error, we don't want to turn it into a different syntax error.
self.assertEqual(make_body('$foo + ("$%.2f" $ ($17.5))'),
'# $foo + ("$%.2f" $ ($17.5))\n'
"raise SyntaxError('invalid syntax on line 1 col 17')")
self.assertEqual(make_body('if $foo:\n' +
' return $foo\n' +
'else:\n' +
' return $ bar\n'),
'# if $foo:\n' +
'# return $foo\n' +
'# else:\n' +
'# return $ bar\n' +
"raise SyntaxError('invalid syntax on line 4 col 10')")
# Check for reasonable behaviour with non-empty text and no statements.
self.assertEqual(make_body('# comment'), '# comment\npass')
self.assertEqual(make_body('\\'), '\\\npass')
self.assertEqual(make_body('rec = 1'), "# rec = 1\n" +
"raise SyntaxError('Grist disallows assignment " +
"to the special variable \"rec\" on line 1 col 1')")
self.assertEqual(make_body('for rec in []: pass'), "# for rec in []: pass\n" +
"raise SyntaxError('Grist disallows assignment " +
"to the special variable \"rec\" on line 1 col 4')")
# some legitimates use of rec
body = ("""
foo = rec
rec.foo = 1
[rec for x in rec]
for a in rec:
t = a
[rec for x in rec]
return rec
""")
self.assertEqual(make_body(body), body)
# mostly legitimate use of rec but one failing
body = ("""
foo = rec
rec.foo = 1
[1 for rec in []]
for a in rec:
t = a
[rec for x in rec]
return rec
""")
self.assertRegexpMatches(make_body(body),
r"raise SyntaxError\('Grist disallows assignment" +
r" to the special variable \"rec\" on line 4 col 7'\)")
def test_make_formula_body_unicode(self):
# Test that we don't fail when strings include unicode characters
self.assertEqual(make_body("'résumé' + $foo"), u"return 'résumé' + rec.foo")
# Or when a unicode object is passed in, rather than a byte string
self.assertEqual(make_body(u"'résumé' + $foo"), u"return 'résumé' + rec.foo")
# Check the return type of make_body()
self.assertEqual(type(make_body("foo")), unicode)
self.assertEqual(type(make_body(u"foo")), unicode)
def test_wrap_logical(self):
self.assertEqual(make_body("IF($foo, $bar, $baz)"),
"return IF(rec.foo, lambda: (rec.bar), lambda: (rec.baz))")
self.assertEqual(make_body("return IF(FOO(x,y), BAR(x,y) * 2, BAZ(x,y) + 5)"),
"return IF(FOO(x,y), lambda: (BAR(x,y) * 2), lambda: (BAZ(x,y) + 5))")
self.assertEqual(make_body("""
y = $Test
x = IF( FOO(x,y) or 6,
BAR($x,y).blahh ,
Foo.lookupRecords(foo=$foo.bar,
bar=True
).baz
)
return x or y
"""), """
y = rec.Test
x = IF( FOO(x,y) or 6,
lambda: (BAR(rec.x,y).blahh) ,
lambda: (Foo.lookupRecords(foo=rec.foo.bar,
bar=True
).baz)
)
return x or y
""")
self.assertEqual(make_body("IF($A == 0, IF($B > 5, 'Test1'), IF($C < 10, 'Test2', 'Test3'))"),
"return IF(rec.A == 0, " +
"lambda: (IF(rec.B > 5, lambda: ('Test1'))), " +
"lambda: (IF(rec.C < 10, lambda: ('Test2'), lambda: ('Test3'))))"
)
def test_wrap_error(self):
self.assertEqual(make_body("ISERR($foo.bar)"), "return ISERR(lambda: (rec.foo.bar))")
self.assertEqual(make_body("ISERROR(1 / 0)"), "return ISERROR(lambda: (1 / 0))")
self.assertEqual(make_body("IFERROR($foo + #\n 1 / 0, 'XX')"),
"return IFERROR(lambda: (rec.foo + #\n 1 / 0), 'XX')")
# Check that extra parentheses are OK.
self.assertEqual(make_body("IFERROR((($foo + 1) / 0))"),
"return IFERROR((lambda: ((rec.foo + 1) / 0)))")
# Check that missing arguments is OK
self.assertEqual(make_body("ISERR()"), "return ISERR()")

@ -0,0 +1,453 @@
import logger
import testutil
import test_engine
from test_engine import Table, Column, View, Section, Field
log = logger.Logger(__name__, logger.INFO)
class TestColumnActions(test_engine.EngineTestCase):
sample = testutil.parse_test_sample({
"SCHEMA": [
[1, "Address", [
[21, "city", "Text", False, "", "", ""],
]]
],
"DATA": {
"Address": [
["id", "city" ],
[11, "New York" ],
[12, "Colombia" ],
[13, "New Haven" ],
[14, "West Haven" ]],
}
})
@test_engine.test_undo
def test_column_updates(self):
# Verify various automatic adjustments for column updates
# (1) that label gets synced to colId unless untieColIdFromLabel is set.
# (2) that unsetting untieColId syncs the label to colId.
# (3) that a complex BulkUpdateRecord for _grist_Tables_column is processed correctly.
self.load_sample(self.sample)
self.apply_user_action(["AddColumn", "Address", "foo", {"type": "Numeric"}])
self.assertTableData("_grist_Tables_column", cols="subset", data=[
[ "id", "parentId", "colId", "label", "type", "untieColIdFromLabel" ],
[ 21, 1, "city", "", "Text", False ],
[ 22, 1, "foo", "foo", "Numeric", False ],
])
# Check that label is synced to colId, via either ModifyColumn or UpdateRecord useraction.
self.apply_user_action(["ModifyColumn", "Address", "city", {"label": "Hello"}])
self.apply_user_action(["UpdateRecord", "_grist_Tables_column", 22, {"label": "World"}])
self.assertTableData("_grist_Tables_column", cols="subset", data=[
[ "id", "parentId", "colId", "label", "type", "untieColIdFromLabel" ],
[ 21, 1, "Hello", "Hello", "Text", False ],
[ 22, 1, "World", "World", "Numeric", False ],
])
# But check that a rename or an update that includes colId is not affected by label.
self.apply_user_action(["RenameColumn", "Address", "Hello", "Hola"])
self.apply_user_action(["UpdateRecord", "_grist_Tables_column", 22,
{"label": "Foo", "colId": "Bar"}])
self.assertTableData("_grist_Tables_column", cols="subset", data=[
[ "id", "parentId", "colId", "label", "type", "untieColIdFromLabel" ],
[ 21, 1, "Hola", "Hello", "Text", False ],
[ 22, 1, "Bar", "Foo", "Numeric", False ],
])
# Check that setting untieColIdFromLabel doesn't change anything immediately.
self.apply_user_action(["BulkUpdateRecord", "_grist_Tables_column", [21,22],
{"untieColIdFromLabel": [True, True]}])
self.assertTableData("_grist_Tables_column", cols="subset", data=[
[ "id", "parentId", "colId", "label", "type", "untieColIdFromLabel" ],
[ 21, 1, "Hola", "Hello", "Text", True ],
[ 22, 1, "Bar", "Foo", "Numeric", True ],
])
# Check that ModifyColumn and UpdateRecord useractions no longer copy label to colId.
self.apply_user_action(["ModifyColumn", "Address", "Hola", {"label": "Hello"}])
self.apply_user_action(["UpdateRecord", "_grist_Tables_column", 22, {"label": "World"}])
self.assertTableData("_grist_Tables_column", cols="subset", data=[
[ "id", "parentId", "colId", "label", "type", "untieColIdFromLabel" ],
[ 21, 1, "Hola", "Hello", "Text", True ],
[ 22, 1, "Bar", "World", "Numeric", True ],
])
# Check that unsetting untieColIdFromLabel syncs label, whether label is provided or not.
self.apply_user_action(["UpdateRecord", "_grist_Tables_column", 21,
{"untieColIdFromLabel": False, "label": "Alice"}])
self.apply_user_action(["UpdateRecord", "_grist_Tables_column", 22,
{"untieColIdFromLabel": False}])
self.assertTableData("_grist_Tables_column", cols="subset", data=[
[ "id", "parentId", "colId", "label", "type", "untieColIdFromLabel" ],
[ 21, 1, "Alice", "Alice", "Text", False ],
[ 22, 1, "World", "World", "Numeric", False ],
])
# Check that column names still get sanitized and disambiguated.
self.apply_user_action(["UpdateRecord", "_grist_Tables_column", 21, {"label": "Alice M"}])
self.apply_user_action(["UpdateRecord", "_grist_Tables_column", 22, {"label": "Alice-M"}])
self.assertTableData("_grist_Tables_column", cols="subset", data=[
[ "id", "parentId", "colId", "label", "type", "untieColIdFromLabel" ],
[ 21, 1, "Alice_M", "Alice M", "Text", False ],
[ 22, 1, "Alice_M2", "Alice-M", "Numeric", False ],
])
# Check that a column rename doesn't avoid its own name.
self.apply_user_action(["UpdateRecord", "_grist_Tables_column", 21, {"label": "Alice*M"}])
self.assertTableData("_grist_Tables_column", cols="subset", data=[
[ "id", "parentId", "colId", "label", "type", "untieColIdFromLabel" ],
[ 21, 1, "Alice_M", "Alice*M", "Text", False ],
[ 22, 1, "Alice_M2", "Alice-M", "Numeric", False ],
])
# Untie colIds and tie them again, and make sure it doesn't cause unneeded renames.
self.apply_user_action(["BulkUpdateRecord", "_grist_Tables_column", [21,22],
{ "untieColIdFromLabel": [True, True] }])
self.apply_user_action(["BulkUpdateRecord", "_grist_Tables_column", [21,22],
{ "untieColIdFromLabel": [False, False] }])
self.assertTableData("_grist_Tables_column", cols="subset", data=[
[ "id", "parentId", "colId", "label", "type", "untieColIdFromLabel" ],
[ 21, 1, "Alice_M", "Alice*M", "Text", False ],
[ 22, 1, "Alice_M2", "Alice-M", "Numeric", False ],
])
# Check that disambiguating also works correctly for bulk updates.
self.apply_user_action(["BulkUpdateRecord", "_grist_Tables_column", [21,22],
{"label": ["Bob Z", "Bob-Z"]}])
self.assertTableData("_grist_Tables_column", cols="subset", data=[
[ "id", "parentId", "colId", "label", "type", "untieColIdFromLabel" ],
[ 21, 1, "Bob_Z", "Bob Z", "Text", False ],
[ 22, 1, "Bob_Z2", "Bob-Z", "Numeric", False ],
])
# Same for changing colIds directly.
self.apply_user_action(["BulkUpdateRecord", "_grist_Tables_column", [21,22],
{"colId": ["Carol X", "Carol-X"]}])
self.assertTableData("_grist_Tables_column", cols="subset", data=[
[ "id", "parentId", "colId", "label", "type", "untieColIdFromLabel" ],
[ 21, 1, "Carol_X", "Bob Z", "Text", False ],
[ 22, 1, "Carol_X2", "Bob-Z", "Numeric", False ],
])
# Check confusing bulk updates with different keys changing for different records.
out_actions = self.apply_user_action(["BulkUpdateRecord", "_grist_Tables_column", [21,22], {
"label": ["Bob Z", "Bob-Z"], # Unchanged from before.
"untieColIdFromLabel": [True, False]
}])
self.assertPartialOutActions(out_actions, { "stored": [
["RenameColumn", "Address", "Carol_X2", "Bob_Z"],
["BulkUpdateRecord", "_grist_Tables_column", [21, 22],
{"colId": ["Carol_X", "Bob_Z"], # Note that only one column is changing.
"untieColIdFromLabel": [True, False]
# No update to label, they get trimmed as unchanged.
}
],
]})
self.assertTableData("_grist_Tables_column", cols="subset", data=[
[ "id", "parentId", "colId", "label", "type", "untieColIdFromLabel" ],
[ 21, 1, "Carol_X", "Bob Z", "Text", True ],
[ 22, 1, "Bob_Z", "Bob-Z", "Numeric", False ],
])
#----------------------------------------------------------------------
address_table_data = [
["id", "city", "state", "amount" ],
[ 21, "New York", "NY" , 1. ],
[ 22, "Albany", "NY" , 2. ],
[ 23, "Seattle", "WA" , 3. ],
[ 24, "Chicago", "IL" , 4. ],
[ 25, "Bedford", "MA" , 5. ],
[ 26, "New York", "NY" , 6. ],
[ 27, "Buffalo", "NY" , 7. ],
[ 28, "Bedford", "NY" , 8. ],
[ 29, "Boston", "MA" , 9. ],
[ 30, "Yonkers", "NY" , 10. ],
[ 31, "New York", "NY" , 11. ],
]
sample2 = testutil.parse_test_sample({
"SCHEMA": [
[1, "Address", [
[11, "city", "Text", False, "", "", ""],
[12, "state", "Text", False, "", "", ""],
[13, "amount", "Numeric", False, "", "", ""],
]]
],
"DATA": {
"Address": address_table_data
}
})
def init_sample_data(self):
# Add a new view with a section, and a new table to that view, and a summary table.
self.load_sample(self.sample2)
self.apply_user_action(["CreateViewSection", 1, 0, "record", None])
self.apply_user_action(["CreateViewSection", 0, 1, "record", None])
self.apply_user_action(["CreateViewSection", 1, 1, "record", [12]])
self.apply_user_action(["BulkAddRecord", "Table1", [None]*3, {
"A": ["a", "b", "c"],
"B": ["d", "e", "f"],
"C": ["", "", ""]
}])
# Verify the new structure of tables and views.
self.assertTables([
Table(1, "Address", primaryViewId=0, summarySourceTable=0, columns=[
Column(11, "city", "Text", False, "", 0),
Column(12, "state", "Text", False, "", 0),
Column(13, "amount", "Numeric", False, "", 0),
]),
Table(2, "Table1", 2, 0, columns=[
Column(14, "manualSort", "ManualSortPos", False, "", 0),
Column(15, "A", "Text", False, "", 0),
Column(16, "B", "Text", False, "", 0),
Column(17, "C", "Text", False, "", 0),
]),
Table(3, "GristSummary_7_Address", 0, 1, columns=[
Column(18, "state", "Text", False, "", summarySourceCol=12),
Column(19, "group", "RefList:Address", True, summarySourceCol=0,
formula="table.getSummarySourceGroup(rec)"),
Column(20, "count", "Int", True, summarySourceCol=0, formula="len($group)"),
Column(21, "amount", "Numeric", True, summarySourceCol=0, formula="SUM($group.amount)"),
]),
])
self.assertViews([
View(1, sections=[
Section(1, parentKey="record", tableRef=1, fields=[
Field(1, colRef=11),
Field(2, colRef=12),
Field(3, colRef=13),
]),
Section(3, parentKey="record", tableRef=2, fields=[
Field(7, colRef=15),
Field(8, colRef=16),
Field(9, colRef=17),
]),
Section(4, parentKey="record", tableRef=3, fields=[
Field(10, colRef=18),
Field(11, colRef=20),
Field(12, colRef=21),
]),
]),
View(2, sections=[
Section(2, parentKey="record", tableRef=2, fields=[
Field(4, colRef=15),
Field(5, colRef=16),
Field(6, colRef=17),
]),
])
])
self.assertTableData('Address', data=self.address_table_data)
self.assertTableData('Table1', data=[
["id", "A", "B", "C", "manualSort"],
[ 1, "a", "d", "", 1.0],
[ 2, "b", "e", "", 2.0],
[ 3, "c", "f", "", 3.0],
])
self.assertTableData("GristSummary_7_Address", cols="subset", data=[
[ "id", "state", "count", "amount" ],
[ 1, "NY", 7, 1.+2+6+7+8+10+11 ],
[ 2, "WA", 1, 3. ],
[ 3, "IL", 1, 4. ],
[ 4, "MA", 2, 5.+9 ],
])
#----------------------------------------------------------------------
@test_engine.test_undo
def test_column_removals(self):
# Verify removal of fields when columns are removed.
self.init_sample_data()
# Add link{Src,Target}ColRef to ViewSections. These aren't actually meaningful links, but they
# should still get cleared automatically when columns get removed.
self.apply_user_action(['UpdateRecord', '_grist_Views_section', 2, {
'linkSrcSectionRef': 1,
'linkSrcColRef': 11,
'linkTargetColRef': 16
}])
self.assertTableData('_grist_Views_section', cols="subset", rows="subset", data=[
["id", "linkSrcSectionRef", "linkSrcColRef", "linkTargetColRef"],
[2, 1, 11, 16 ],
])
# Test that we can remove multiple columns using BulkUpdateRecord.
self.apply_user_action(["BulkRemoveRecord", '_grist_Tables_column', [11, 16]])
# Test that link{Src,Target}colRef back-references get unset.
self.assertTableData('_grist_Views_section', cols="subset", rows="subset", data=[
["id", "linkSrcSectionRef", "linkSrcColRef", "linkTargetColRef"],
[2, 1, 0, 0 ],
])
# Test that columns and section fields got removed.
self.assertTables([
Table(1, "Address", primaryViewId=0, summarySourceTable=0, columns=[
Column(12, "state", "Text", False, "", 0),
Column(13, "amount", "Numeric", False, "", 0),
]),
Table(2, "Table1", 2, 0, columns=[
Column(14, "manualSort", "ManualSortPos", False, "", 0),
Column(15, "A", "Text", False, "", 0),
Column(17, "C", "Text", False, "", 0),
]),
Table(3, "GristSummary_7_Address", 0, 1, columns=[
Column(18, "state", "Text", False, "", summarySourceCol=12),
Column(19, "group", "RefList:Address", True, summarySourceCol=0,
formula="table.getSummarySourceGroup(rec)"),
Column(20, "count", "Int", True, summarySourceCol=0, formula="len($group)"),
Column(21, "amount", "Numeric", True, summarySourceCol=0, formula="SUM($group.amount)"),
]),
])
self.assertViews([
View(1, sections=[
Section(1, parentKey="record", tableRef=1, fields=[
Field(2, colRef=12),
Field(3, colRef=13),
]),
Section(3, parentKey="record", tableRef=2, fields=[
Field(7, colRef=15),
Field(9, colRef=17),
]),
Section(4, parentKey="record", tableRef=3, fields=[
Field(10, colRef=18),
Field(11, colRef=20),
Field(12, colRef=21),
]),
]),
View(2, sections=[
Section(2, parentKey="record", tableRef=2, fields=[
Field(4, colRef=15),
Field(6, colRef=17),
]),
])
])
#----------------------------------------------------------------------
@test_engine.test_undo
def test_summary_column_removals(self):
# Verify that when we remove a column used for summary-table group-by, it updates summary
# tables appropriately.
self.init_sample_data()
# Test that we cannot remove group-by columns from summary tables directly.
with self.assertRaisesRegexp(ValueError, "cannot remove .* group-by"):
self.apply_user_action(["BulkRemoveRecord", '_grist_Tables_column', [20,18]])
# Test that group-by columns in summary tables get removed.
self.apply_user_action(["BulkRemoveRecord", '_grist_Tables_column', [11,12,16]])
# Verify the new structure of tables and views.
self.assertTables([
Table(1, "Address", primaryViewId=0, summarySourceTable=0, columns=[
Column(13, "amount", "Numeric", False, "", 0),
]),
Table(2, "Table1", 2, 0, columns=[
Column(14, "manualSort", "ManualSortPos", False, "", 0),
Column(15, "A", "Text", False, "", 0),
Column(17, "C", "Text", False, "", 0),
]),
# Note that the summary table here switches to a new one, without the deleted group-by.
Table(4, "GristSummary_7_Address2", 0, 1, columns=[
Column(22, "count", "Int", True, summarySourceCol=0, formula="len($group)"),
Column(23, "amount", "Numeric", True, summarySourceCol=0, formula="SUM($group.amount)"),
Column(24, "group", "RefList:Address", True, summarySourceCol=0,
formula="table.getSummarySourceGroup(rec)"),
]),
])
self.assertViews([
View(1, sections=[
Section(1, parentKey="record", tableRef=1, fields=[
Field(3, colRef=13),
]),
Section(3, parentKey="record", tableRef=2, fields=[
Field(7, colRef=15),
Field(9, colRef=17),
]),
Section(4, parentKey="record", tableRef=4, fields=[
Field(11, colRef=22),
Field(12, colRef=23),
]),
]),
View(2, sections=[
Section(2, parentKey="record", tableRef=2, fields=[
Field(4, colRef=15),
Field(6, colRef=17),
]),
])
])
# Verify the data itself.
self.assertTableData('Address', data=[
["id", "amount" ],
[ 21, 1. ],
[ 22, 2. ],
[ 23, 3. ],
[ 24, 4. ],
[ 25, 5. ],
[ 26, 6. ],
[ 27, 7. ],
[ 28, 8. ],
[ 29, 9. ],
[ 30, 10. ],
[ 31, 11. ],
])
self.assertTableData('Table1', data=[
["id", "A", "C", "manualSort"],
[ 1, "a", "", 1.0],
[ 2, "b", "", 2.0],
[ 3, "c", "", 3.0],
])
self.assertTableData("GristSummary_7_Address2", cols="subset", data=[
[ "id", "count", "amount" ],
[ 1, 7+1+1+2, 1.+2+6+7+8+10+11+3+4+5+9 ],
])
#----------------------------------------------------------------------
@test_engine.test_undo
def test_column_sort_removals(self):
# Verify removal of sort spec entries when columns are removed.
self.init_sample_data()
# Add sortSpecs to ViewSections.
self.apply_user_action(['BulkUpdateRecord', '_grist_Views_section', [2, 3, 4],
{'sortColRefs': ['[15, -16]', '[-15, 16, 17]', '[19]']}
])
self.assertTableData('_grist_Views_section', cols="subset", rows="subset", data=[
["id", "sortColRefs" ],
[2, '[15, -16]' ],
[3, '[-15, 16, 17]'],
[4, '[19]' ],
])
# Remove column, and check that the correct sortColRefs items are removed.
self.apply_user_action(["RemoveRecord", '_grist_Tables_column', 16])
self.assertTableData('_grist_Views_section', cols="subset", rows="subset", data=[
["id", "sortColRefs"],
[2, '[15]' ],
[3, '[-15, 17]' ],
[4, '[19]' ],
])
# Update sortColRefs for next test.
self.apply_user_action(['UpdateRecord', '_grist_Views_section', 3,
{'sortColRefs': '[-15, -16, 17]'}
])
# Remove multiple columns using BulkUpdateRecord, and check that the sortSpecs are updated.
self.apply_user_action(["BulkRemoveRecord", '_grist_Tables_column', [15, 17, 19]])
self.assertTableData('_grist_Views_section', cols="subset", rows="subset", data=[
["id", "sortColRefs"],
[2, '[]' ],
[3, '[-16]' ],
[4, '[]' ],
])

@ -0,0 +1,98 @@
import testsamples
import testutil
import test_engine
class TestCompletion(test_engine.EngineTestCase):
def setUp(self):
super(TestCompletion, self).setUp()
self.load_sample(testsamples.sample_students)
# To test different column types, we add some differently-typed columns to the sample.
self.add_column('Students', 'school', type='Ref:Schools')
self.add_column('Students', 'birthDate', type='Date')
self.add_column('Students', 'lastVisit', type='DateTime:America/New_York')
self.add_column('Schools', 'yearFounded', type='Int')
self.add_column('Schools', 'budget', type='Numeric')
def test_keyword(self):
self.assertEqual(self.engine.autocomplete("for", "Address"),
["for", "format("])
def test_grist(self):
self.assertEqual(self.engine.autocomplete("gri", "Address"),
["grist"])
def test_function(self):
self.assertEqual(self.engine.autocomplete("MEDI", "Address"),
["MEDIAN("])
def test_member(self):
self.assertEqual(self.engine.autocomplete("datetime.tz", "Address"),
["datetime.tzinfo("])
def test_suggest_globals_and_tables(self):
# Should suggest globals and table names.
self.assertEqual(self.engine.autocomplete("ME", "Address"), ['MEDIAN('])
self.assertEqual(self.engine.autocomplete("Ad", "Address"), ['Address'])
self.assertGreaterEqual(set(self.engine.autocomplete("S", "Address")),
{'Schools', 'Students', 'SUM(', 'STDEV('})
self.assertEqual(self.engine.autocomplete("Addr", "Schools"), ['Address'])
def test_suggest_columns(self):
self.assertEqual(self.engine.autocomplete("$ci", "Address"),
["$city"])
self.assertEqual(self.engine.autocomplete("rec.i", "Address"),
["rec.id"])
self.assertEqual(len(self.engine.autocomplete("$", "Address")),
2)
# A few more detailed examples.
self.assertEqual(self.engine.autocomplete("$", "Students"),
['$birthDate', '$firstName', '$id', '$lastName', '$lastVisit',
'$school', '$schoolCities', '$schoolIds', '$schoolName'])
self.assertEqual(self.engine.autocomplete("$fi", "Students"), ['$firstName'])
self.assertEqual(self.engine.autocomplete("$school", "Students"),
['$school', '$schoolCities', '$schoolIds', '$schoolName'])
def test_suggest_lookup_methods(self):
# Should suggest lookup formulas for tables.
self.assertEqual(self.engine.autocomplete("Address.", "Students"),
['Address.all', 'Address.lookupOne(', 'Address.lookupRecords('])
self.assertEqual(self.engine.autocomplete("Address.lookup", "Students"),
['Address.lookupOne(', 'Address.lookupRecords('])
def test_suggest_column_type_methods(self):
# Should treat columns as correct types.
self.assertGreaterEqual(set(self.engine.autocomplete("$firstName.", "Students")),
{'$firstName.startswith(', '$firstName.replace(', '$firstName.title('})
self.assertGreaterEqual(set(self.engine.autocomplete("$birthDate.", "Students")),
{'$birthDate.month', '$birthDate.strftime(', '$birthDate.replace('})
self.assertGreaterEqual(set(self.engine.autocomplete("$lastVisit.m", "Students")),
{'$lastVisit.month', '$lastVisit.minute'})
self.assertGreaterEqual(set(self.engine.autocomplete("$school.", "Students")),
{'$school.address', '$school.name',
'$school.yearFounded', '$school.budget'})
self.assertEqual(self.engine.autocomplete("$school.year", "Students"),
['$school.yearFounded'])
self.assertGreaterEqual(set(self.engine.autocomplete("$yearFounded.", "Schools")),
{'$yearFounded.denominator', # Only integers have this
'$yearFounded.bit_length(', # and this
'$yearFounded.real'})
self.assertGreaterEqual(set(self.engine.autocomplete("$budget.", "Schools")),
{'$budget.is_integer(', # Only floats have this
'$budget.real'})
def test_suggest_follows_references(self):
# Should follow references and autocomplete those types.
self.assertEqual(self.engine.autocomplete("$school.name.st", "Students"),
['$school.name.startswith(', '$school.name.strip('])
self.assertGreaterEqual(set(self.engine.autocomplete("$school.yearFounded.", "Students")),
{'$school.yearFounded.denominator',
'$school.yearFounded.bit_length(',
'$school.yearFounded.real'})
self.assertEqual(self.engine.autocomplete("$school.address.", "Students"),
['$school.address.city', '$school.address.id'])
self.assertEqual(self.engine.autocomplete("$school.address.city.st", "Students"),
['$school.address.city.startswith(', '$school.address.city.strip('])

@ -0,0 +1,294 @@
import actions
import logger
import testutil
import test_engine
log = logger.Logger(__name__, logger.INFO)
def _bulk_update(table_name, col_names, row_data):
return actions.BulkUpdateRecord(
*testutil.table_data_from_rows(table_name, col_names, row_data))
class TestDerived(test_engine.EngineTestCase):
sample = testutil.parse_test_sample({
"SCHEMA": [
[1, "Customers", [
[1, "firstName", "Text", False, "", "", ""],
[2, "lastName", "Text", False, "", "", ""],
[3, "state", "Text", False, "", "", ""],
]],
[2, "Orders", [
[10, "year", "Int", False, "", "", ""],
[11, "customer", "Ref:Customers", False, "", "", ""],
[12, "product", "Text", False, "", "", ""],
[13, "amount", "Numeric", False, "", "", ""],
]],
],
"DATA": {
"Customers": [
["id", "firstName", "lastName", "state"],
[1, "Lois", "Long", "NY"],
[2, "Felix", "Myers", "NY"],
[3, "Grace", "Hawkins", "CT"],
[4, "Bessie", "Green", "NJ"],
[5, "Jerome", "Daniel", "CT"],
],
"Orders": [
["id", "year", "customer", "product", "amount" ],
[1, 2012, 3, "A", 15 ],
[2, 2013, 2, "A", 15 ],
[3, 2013, 3, "A", 15 ],
[4, 2014, 1, "B", 35 ],
[5, 2014, 5, "B", 35 ],
[6, 2014, 3, "A", 16 ],
[7, 2015, 1, "A", 17 ],
[8, 2015, 2, "B", 36 ],
[9, 2015, 3, "B", 36 ],
[10, 2015, 5, "A", 17 ],
]
}
})
def test_group_by_one(self):
"""
Test basic summary table operation, for a table grouped by one columns.
"""
self.load_sample(self.sample)
# Create a derived table summarizing count and total of orders by year.
self.apply_user_action(["CreateViewSection", 2, 0, 'record', [10]])
# Check the results.
self.assertPartialData("GristSummary_6_Orders", ["id", "year", "count", "amount", "group" ], [
[1, 2012, 1, 15, [1]],
[2, 2013, 2, 30, [2,3]],
[3, 2014, 3, 86, [4,5,6]],
[4, 2015, 4, 106, [7,8,9,10]],
])
# Updating amounts should cause totals to be updated in the summary.
out_actions = self.update_records("Orders", ["id", "amount"], [
[1, 14],
[2, 14]
])
self.assertPartialOutActions(out_actions, {
"calc": [actions.BulkUpdateRecord("GristSummary_6_Orders", [1,2], {'amount': [14, 29]})],
"calls": {"GristSummary_6_Orders": {"amount": 2}}
})
# Changing a record from one product to another should cause the two affected lines to change.
out_actions = self.update_record("Orders", 10, year=2012)
self.assertPartialOutActions(out_actions, {
"stored": [actions.UpdateRecord("Orders", 10, {"year": 2012})],
"calc": [
actions.BulkUpdateRecord("GristSummary_6_Orders", [1,4], {"group": [[1,10], [7,8,9]]}),
actions.BulkUpdateRecord("GristSummary_6_Orders", [1,4], {"amount": [31.0, 89.0]}),
actions.BulkUpdateRecord("GristSummary_6_Orders", [1,4], {"count": [2,3]}),
],
"calls": {"GristSummary_6_Orders": {"group": 2, "amount": 2, "count": 2},
"Orders": {"#lookup##summary#GristSummary_6_Orders": 1,
"#summary#GristSummary_6_Orders": 1}}
})
self.assertPartialData("GristSummary_6_Orders", ["id", "year", "count", "amount", "group" ], [
[1, 2012, 2, 31.0, [1,10]],
[2, 2013, 2, 29.0, [2,3]],
[3, 2014, 3, 86.0, [4,5,6]],
[4, 2015, 3, 89.0, [7,8,9]],
])
# Changing a record to a new year that wasn't in the summary should cause an add-record.
out_actions = self.update_record("Orders", 10, year=1999)
self.assertPartialOutActions(out_actions, {
"stored": [
actions.UpdateRecord("Orders", 10, {"year": 1999}),
actions.AddRecord("GristSummary_6_Orders", 5, {'year': 1999}),
],
"calc": [
actions.BulkUpdateRecord("GristSummary_6_Orders", [1,5], {"group": [[1], [10]]}),
actions.BulkUpdateRecord("GristSummary_6_Orders", [1,5], {"amount": [14.0, 17.0]}),
actions.BulkUpdateRecord("GristSummary_6_Orders", [1,5], {"count": [1,1]}),
],
"calls": {
"GristSummary_6_Orders": {'#lookup#year': 1, "group": 2, "amount": 2, "count": 2},
"Orders": {"#lookup##summary#GristSummary_6_Orders": 2,
"#summary#GristSummary_6_Orders": 2}}
})
self.assertPartialData("GristSummary_6_Orders", ["id", "year", "count", "amount", "group" ], [
[1, 2012, 1, 14.0, [1]],
[2, 2013, 2, 29.0, [2,3]],
[3, 2014, 3, 86.0, [4,5,6]],
[4, 2015, 3, 89.0, [7,8,9]],
[5, 1999, 1, 17.0, [10]],
])
def test_group_by_two(self):
"""
Test a summary table created by grouping on two columns.
"""
self.load_sample(self.sample)
self.apply_user_action(["CreateViewSection", 2, 0, 'record', [10, 12]])
self.assertPartialData("GristSummary_6_Orders", [
"id", "year", "product", "count", "amount", "group"
], [
[1, 2012, "A", 1, 15.0, [1]],
[2, 2013, "A", 2, 30.0, [2,3]],
[3, 2014, "B", 2, 70.0, [4,5]],
[4, 2014, "A", 1, 16.0, [6]],
[5, 2015, "A", 2, 34.0, [7,10]],
[6, 2015, "B", 2, 72.0, [8,9]],
])
# Changing a record from one product to another should cause the two affected lines to change,
# or new lines to be created as needed.
out_actions = self.update_records("Orders", ["id", "product"], [
[2, "B"],
[6, "B"],
[7, "C"],
])
self.assertPartialOutActions(out_actions, {
"stored": [
actions.BulkUpdateRecord("Orders", [2, 6, 7], {"product": ["B", "B", "C"]}),
actions.AddRecord("GristSummary_6_Orders", 7, {'year': 2013, 'product': 'B'}),
actions.AddRecord("GristSummary_6_Orders", 8, {'year': 2015, 'product': 'C'}),
],
"calc": [
actions.BulkUpdateRecord("GristSummary_6_Orders", [2,3,4,5,7,8], {
"group": [[3], [4,5,6], [], [10], [2], [7]]
}),
actions.BulkUpdateRecord("GristSummary_6_Orders", [2,3,4,5,7,8], {
"amount": [15.0, 86.0, 0, 17.0, 15.0, 17.0]
}),
actions.BulkUpdateRecord("GristSummary_6_Orders", [2,3,4,5,7,8], {
"count": [1, 3, 0, 1, 1, 1]
}),
],
})
# Verify the results.
self.assertPartialData("GristSummary_6_Orders", [
"id", "year", "product", "count", "amount", "group"
], [
[1, 2012, "A", 1, 15.0, [1]],
[2, 2013, "A", 1, 15.0, [3]],
[3, 2014, "B", 3, 86.0, [4,5,6]],
[4, 2014, "A", 0, 0.0, []],
[5, 2015, "A", 1, 17.0, [10]],
[6, 2015, "B", 2, 72.0, [8,9]],
[7, 2013, "B", 1, 15.0, [2]],
[8, 2015, "C", 1, 17.0, [7]],
])
def test_group_with_references(self):
"""
Test summary tables grouped on indirect values. In this example we want for each
customer.state, the number of customers and the total of their orders, which we can do either
as a summary on the Customers table, or a summary on the Orders table.
"""
self.load_sample(self.sample)
# Create a summary on the Customers table. Adding orders involves a lookup for each customer.
self.apply_user_action(["CreateViewSection", 1, 0, 'record', [3]])
self.add_column("GristSummary_9_Customers", "totalAmount",
formula="sum(sum(Orders.lookupRecords(customer=c).amount) for c in $group)")
self.assertPartialData("GristSummary_9_Customers", ["id", "state", "count", "totalAmount"], [
[1, "NY", 2, 103.0 ],
[2, "CT", 2, 134.0 ],
[3, "NJ", 1, 0.0 ],
])
# # Create the same summary on the Orders table, looking up 'state' via the Customer reference.
# self.apply_user_action(["AddDerivedTableSource", "Summary4", "Orders",
# {"state": "$customer.state"}])
# self.add_column("Summary4", "numCustomers", formula="len(set($source_Orders.customer))")
# self.add_column("Summary4", "totalAmount", formula="sum($source_Orders.amount)")
# self.assertPartialData("Summary4", ["id", "state", "numCustomers", "totalAmount"], [
# [1, "CT", 2, 134.0 ],
# [2, "NY", 2, 103.0 ],
# ])
# In either case, changing an amount (from 36->37 for a CT customer) should update summaries.
out_actions = self.update_record('Orders', 9, amount=37)
self.assertPartialOutActions(out_actions, {
"calc": [actions.UpdateRecord("GristSummary_9_Customers", 2, {"totalAmount": 135.0})]
})
# In either case, changing a customer's state should trigger recomputation too.
# We are changing a NY customer with $51 in orders to MA.
self.update_record('Customers', 2, state="MA")
self.assertPartialData("GristSummary_9_Customers", ["id", "state", "count", "totalAmount"], [
[1, "NY", 1, 52.0 ],
[2, "CT", 2, 135.0 ],
[3, "NJ", 1, 0.0 ],
[4, "MA", 1, 51.0 ],
])
# self.assertPartialData("Summary4", ["id", "state", "numCustomers", "totalAmount"], [
# [1, "CT", 2, 135.0 ],
# [2, "NY", 1, 52.0 ],
# [3, "MA", 1, 51.0 ],
# ])
# Similarly, changing an Order to refer to a different customer should update both tables.
# Here we are changing a $17 order (#7) for a NY customer (#1) to a NJ customer (#4).
out_actions = self.update_record("Orders", 7, customer=4)
# self.assertPartialOutActions(out_actions, {
# "stored": [actions.UpdateRecord("Orders", 7, {"customer": 4}),
# actions.AddRecord("Summary4", 4, {"state": "NJ"}),
# actions.UpdateRecord("Summary4", 4, {"manualSort": 4.0})]
# })
self.assertPartialData("GristSummary_9_Customers", ["id", "state", "count", "totalAmount"], [
[1, "NY", 1, 35.0 ],
[2, "CT", 2, 135.0 ],
[3, "NJ", 1, 17.0 ],
[4, "MA", 1, 51.0 ],
])
# self.assertPartialData("Summary4", ["id", "state", "numCustomers", "totalAmount"], [
# [1, "CT", 2, 135.0 ],
# [2, "NY", 1, 35.0 ],
# [3, "MA", 1, 51.0 ],
# [4, "NJ", 1, 17.0 ],
# ])
def test_deletions(self):
self.load_sample(self.sample)
# Create a summary table summarizing count and total of orders by year.
self.apply_user_action(["CreateViewSection", 2, 0, 'record', [10]])
self.assertPartialData("GristSummary_6_Orders", ["id", "year", "count", "amount", "group" ], [
[1, 2012, 1, 15.0, [1]],
[2, 2013, 2, 30.0, [2,3]],
[3, 2014, 3, 86.0, [4,5,6]],
[4, 2015, 4, 106.0, [7,8,9,10]],
])
# Update a record so that a new line appears in the summary table.
out_actions_update = self.update_record("Orders", 1, year=2007)
self.assertPartialData("GristSummary_6_Orders", ["id", "year", "count", "amount", "group" ], [
[1, 2012, 0, 0.0, []],
[2, 2013, 2, 30.0, [2,3]],
[3, 2014, 3, 86.0, [4,5,6]],
[4, 2015, 4, 106.0, [7,8,9,10]],
[5, 2007, 1, 15.0, [1]],
])
# Undo and ensure that the new line is gone from the summary table.
out_actions_undo = self.apply_undo_actions(out_actions_update.undo)
self.assertPartialData("GristSummary_6_Orders", ["id", "year", "count", "amount", "group" ], [
[1, 2012, 1, 15.0, [1]],
[2, 2013, 2, 30.0, [2,3]],
[3, 2014, 3, 86.0, [4,5,6]],
[4, 2015, 4, 106.0, [7,8,9,10]],
])
self.assertPartialOutActions(out_actions_undo, {
"stored": [actions.RemoveRecord("GristSummary_6_Orders", 5),
actions.UpdateRecord("Orders", 1, {"year": 2012})],
"calls": {"GristSummary_6_Orders": {"group": 1, "amount": 1, "count": 1},
"Orders": {"#lookup##summary#GristSummary_6_Orders": 2,
"#summary#GristSummary_6_Orders": 2}}
})

@ -0,0 +1,492 @@
import logger
import testutil
import test_engine
from test_engine import Table, Column
log = logger.Logger(__name__, logger.INFO)
class TestUserActions(test_engine.EngineTestCase):
ref_sample = testutil.parse_test_sample({
# pylint: disable=line-too-long
"SCHEMA": [
[1, "Television", [
[21, "show", "Text", False, "", "", ""],
[22, "network", "Text", False, "", "", ""],
[23, "viewers", "Int", False, "", "", ""]
]]
],
"DATA": {
"Television": [
["id", "show" , "network", "viewers"],
[11, "Game of Thrones", "HBO" , 100],
[12, "Narcos" , "Netflix", 500],
[13, "Today" , "NBC" , 200],
[14, "Empire" , "Fox" , 300]],
}
})
def test_display_cols(self):
# Test the implementation of display columns which adds a column modified by
# a formula as a display version of the original column.
self.load_sample(self.ref_sample)
# Add a new table for People so that we get the associated views and fields.
self.apply_user_action(['AddTable', 'Favorites', [{'id': 'favorite', 'type':
'Ref:Television'}]])
self.apply_user_action(['BulkAddRecord', 'Favorites', [1,2,3,4,5], {
'favorite': [2, 4, 1, 4, 3]
}])
self.assertTables([
Table(1, "Television", 0, 0, columns=[
Column(21, "show", "Text", False, "", 0),
Column(22, "network", "Text", False, "", 0),
Column(23, "viewers", "Int", False, "", 0),
]),
Table(2, "Favorites", 1, 0, columns=[
Column(24, "manualSort", "ManualSortPos", False, "", 0),
Column(25, "favorite", "Ref:Television", False, "", 0),
]),
])
self.assertTableData("_grist_Views_section_field", cols="subset", data=[
["id", "colRef", "displayCol"],
[1, 25, 0],
])
self.assertTableData("Favorites", cols="subset", data=[
["id", "favorite"],
[1, 2],
[2, 4],
[3, 1],
[4, 4],
[5, 3]
])
# Add an extra view for the new table to test multiple fields at once
self.apply_user_action(['AddView', 'Favorites', 'raw_data', 'Extra View'])
self.assertTableData("_grist_Views_section_field", cols="subset", data=[
["id", "colRef", "displayCol"],
[1, 25, 0],
[2, 25, 0]
])
# Set display formula for 'favorite' column.
# A "gristHelper_Display" column with the requested formula should be added and set as the
# displayCol of the favorite column.
self.apply_user_action(['SetDisplayFormula', 'Favorites', None, 25, '$favorite.show'])
self.assertTableData("_grist_Tables_column", cols="subset", rows=(lambda r: r.id >= 25), data=[
["id", "colId", "parentId", "displayCol", "formula"],
[25, "favorite", 2, 26, ""],
[26, "gristHelper_Display", 2, 0, "$favorite.show"]
])
# Set display formula for 'favorite' column fields.
# A single "gristHelper_Display2" column should be added with the requested formula, since both
# require the same formula. The fields' colRefs should be set to the new column.
self.apply_user_action(['SetDisplayFormula', 'Favorites', 1, None, '$favorite.network'])
self.apply_user_action(['SetDisplayFormula', 'Favorites', 2, None, '$favorite.network'])
self.assertTableData("_grist_Tables_column", cols="subset", rows=(lambda r: r.id >= 25), data=[
["id", "colId", "parentId", "displayCol", "formula"],
[25, "favorite", 2, 26, ""],
[26, "gristHelper_Display", 2, 0, "$favorite.show"],
[27, "gristHelper_Display2", 2, 0, "$favorite.network"],
])
self.assertTableData("_grist_Views_section_field", cols="subset", data=[
["id", "colRef", "displayCol"],
[1, 25, 27],
[2, 25, 27]
])
# Change display formula for a field.
# Since the field is changing to use a formula not yet held by a display column,
# a new display column should be added with the desired formula.
self.apply_user_action(['SetDisplayFormula', 'Favorites', 2, None, '$favorite.viewers'])
self.assertTableData("_grist_Tables_column", cols="subset", rows=(lambda r: r.id >= 25), data=[
["id", "colId", "parentId", "displayCol", "formula"],
[25, "favorite", 2, 26, ""],
[26, "gristHelper_Display", 2, 0, "$favorite.show"],
[27, "gristHelper_Display2", 2, 0, "$favorite.network"],
[28, "gristHelper_Display3", 2, 0, "$favorite.viewers"]
])
self.assertTableData("_grist_Views_section_field", cols="subset", data=[
["id", "colRef", "displayCol"],
[1, 25, 27],
[2, 25, 28]
])
# Remove a field.
# This should also remove the display column used by that field, since it is not used
# by any other fields.
self.apply_user_action(['RemoveRecord', '_grist_Views_section_field', 2])
self.assertTableData("_grist_Tables_column", cols="subset", rows=(lambda r: r.id >= 25), data=[
["id", "colId", "parentId", "displayCol", "formula"],
[25, "favorite", 2, 26, ""],
[26, "gristHelper_Display", 2, 0, "$favorite.show"],
[27, "gristHelper_Display2", 2, 0, "$favorite.network"],
])
self.assertTableData("_grist_Views_section_field", cols="subset", data=[
["id", "colRef", "displayCol"],
[1, 25, 27]
])
# Add a new column with a formula.
self.apply_user_action(['AddColumn', 'Favorites', 'fav_viewers', {
'formula': '$favorite.viewers'
}])
# Add a field back for the favorites table and set its display formula to the
# same formula that the new column has. Make sure that the new column is NOT used as
# the display column.
self.apply_user_action(['AddRecord', '_grist_Views_section_field', None, {
'parentId': 2,
'colRef': 25
}])
self.apply_user_action(['SetDisplayFormula', 'Favorites', 4, None, '$favorite.viewers'])
self.assertTableData("_grist_Tables_column", cols="subset", rows=(lambda r: r.id >= 25), data=[
["id", "colId", "parentId", "displayCol", "formula"],
[25, "favorite", 2, 26, ""],
[26, "gristHelper_Display", 2, 0, "$favorite.show"],
[27, "gristHelper_Display2", 2, 0, "$favorite.network"],
[28, "fav_viewers", 2, 0, "$favorite.viewers"],
[29, "gristHelper_Display3", 2, 0, "$favorite.viewers"]
])
self.assertTableData("_grist_Views_section_field", cols="subset", data=[
["id", "colRef", "displayCol"],
[1, 25, 27],
[2, 28, 0], # fav_viewers field
[3, 28, 0], # fav_viewers field
[4, 25, 29] # re-added field w/ display col
])
# Change the display formula for a field to be the same as the other field, then remove
# the field.
# The display column should not be removed since it is still in use.
self.apply_user_action(['SetDisplayFormula', 'Favorites', 4, None, '$favorite.network'])
self.apply_user_action(['RemoveRecord', '_grist_Views_section_field', 4])
self.assertTableData("_grist_Tables_column", cols="subset", rows=(lambda r: r.id >= 25), data=[
["id", "colId", "parentId", "displayCol", "formula"],
[25, "favorite", 2, 26, ""],
[26, "gristHelper_Display", 2, 0, "$favorite.show"],
[27, "gristHelper_Display2",2, 0, "$favorite.network"],
[28, "fav_viewers", 2, 0, "$favorite.viewers"],
])
self.assertTableData("_grist_Views_section_field", cols="subset", data=[
["id", "colRef", "displayCol"],
[1, 25, 27],
[2, 28, 0],
[3, 28, 0],
])
# Clear field display formula, then set it again.
# Clearing the display formula should remove the display column, since it is no longer
# used by any column or field.
self.apply_user_action(['SetDisplayFormula', 'Favorites', 1, None, ''])
self.assertTableData("_grist_Tables_column", cols="subset", rows=(lambda r: r.id >= 25), data=[
["id", "colId", "parentId", "displayCol", "formula"],
[25, "favorite", 2, 26, ""],
[26, "gristHelper_Display", 2, 0, "$favorite.show"],
[28, "fav_viewers", 2, 0, "$favorite.viewers"],
])
self.assertTableData("_grist_Views_section_field", cols="subset", data=[
["id", "colRef", "displayCol"],
[1, 25, 0],
[2, 28, 0],
[3, 28, 0]
])
# Setting the display formula should add another display column.
self.apply_user_action(['SetDisplayFormula', 'Favorites', 1, None, '$favorite.viewers'])
self.assertTableData("_grist_Tables_column", cols="subset", rows=(lambda r: r.id >= 25), data=[
["id", "colId", "parentId", "displayCol", "formula"],
[25, "favorite", 2, 26, ""],
[26, "gristHelper_Display", 2, 0, "$favorite.show"],
[28, "fav_viewers", 2, 0, "$favorite.viewers"],
[29, "gristHelper_Display2",2, 0, "$favorite.viewers"],
])
self.assertTableData("_grist_Views_section_field", cols="subset", data=[
["id", "colRef", "displayCol"],
[1, 25, 29],
[2, 28, 0],
[3, 28, 0]
])
# Change column display formula.
# This should re-use the current display column since it is only used by the column.
self.apply_user_action(['SetDisplayFormula', 'Favorites', None, 25, '$favorite.network'])
self.assertTableData("_grist_Tables_column", cols="subset", rows=(lambda r: r.id >= 25), data=[
["id", "colId", "parentId", "displayCol", "formula"],
[25, "favorite", 2, 26, ""],
[26, "gristHelper_Display",2, 0, "$favorite.network"],
[28, "fav_viewers", 2, 0, "$favorite.viewers"],
[29, "gristHelper_Display2",2, 0, "$favorite.viewers"],
])
self.assertTableData("_grist_Views_section_field", cols="subset", data=[
["id", "colRef", "displayCol"],
[1, 25, 29],
[2, 28, 0],
[3, 28, 0]
])
# Remove column.
# This should remove the display column used by the column.
self.apply_user_action(['RemoveColumn', "Favorites", "favorite"])
self.assertTableData("_grist_Tables_column", cols="subset", rows=(lambda r: r.id >= 25), data=[
["id", "colId", "parentId", "displayCol", "formula"],
[28, "fav_viewers", 2, 0, "$favorite.viewers"]
])
self.assertTableData("_grist_Views_section_field", cols="subset", data=[
["id", "colRef", "displayCol"],
[2, 28, 0],
[3, 28, 0]
])
def test_display_col_removal(self):
# Test that when removing a column, we don't produce unnecessary calc actions for a display
# column that may also get auto-removed.
self.load_sample(self.ref_sample)
# Create a display column.
self.apply_user_action(['SetDisplayFormula', 'Television', None, 21, '$show.upper()'])
# Verify the state of columns and display columns.
self.assertTableData("_grist_Tables_column", cols="subset", data=[
["id", "colId", "type", "displayCol", "formula" ],
[21, "show", "Text", 24 , "" ],
[22, "network", "Text", 0 , "" ],
[23, "viewers", "Int", 0 , "" ],
[24, "gristHelper_Display", "Any", 0 , "$show.upper()"]
])
self.assertTableData("Television", cols="all", data=[
["id", "show" , "network", "viewers", "gristHelper_Display"],
[11, "Game of Thrones", "HBO" , 100, "GAME OF THRONES"],
[12, "Narcos" , "Netflix", 500, "NARCOS"],
[13, "Today" , "NBC" , 200, "TODAY"],
[14, "Empire" , "Fox" , 300, "EMPIRE"],
])
# Remove the column that has a displayCol referring to it.
out_actions = self.apply_user_action(['RemoveColumn', 'Television', 'show'])
# Verify that the resulting actions don't include any calc actions.
self.assertPartialOutActions(out_actions, {
"stored": [
["BulkRemoveRecord", "_grist_Tables_column", [21, 24]],
["RemoveColumn", "Television", "show"],
["RemoveColumn", "Television", "gristHelper_Display"],
],
"calc": []
})
# Verify the state of columns and display columns afterwards.
self.assertTableData("_grist_Tables_column", cols="subset", data=[
["id", "colId", "type", "displayCol", "formula" ],
[22, "network", "Text", 0 , "" ],
[23, "viewers", "Int", 0 , "" ],
])
self.assertTableData("Television", cols="all", data=[
["id", "network", "viewers" ],
[11, "HBO" , 100 ],
[12, "Netflix", 500 ],
[13, "NBC" , 200 ],
[14, "Fox" , 300 ],
])
def test_display_col_copying(self):
# Test that when switching types and using CopyFromColumn, displayCol is set/unset correctly.
self.load_sample(self.ref_sample)
# Add a new table for People so that we get the associated views and fields.
self.apply_user_action(['AddTable', 'Favorites', [
{'id': 'favorite', 'type': 'Ref:Television'},
{'id': 'favorite2', 'type': 'Text'}]])
self.apply_user_action(['BulkAddRecord', 'Favorites', [1,2,3,4,5], {
'favorite': [2, 4, 1, 4, 3]
}])
# Set a displayCol.
self.apply_user_action(['SetDisplayFormula', 'Favorites', None, 25, '$favorite.show'])
self.assertTableData("_grist_Tables_column", cols="subset", rows=(lambda r: r.id > 24), data=[
["id" , "colId" , "parentId", "displayCol", "type", "formula"],
[25 , "favorite" , 2 , 27 , "Ref:Television", ""],
[26 , "favorite2" , 2 , 0 , "Text", ""],
[27 , "gristHelper_Display" , 2 , 0 , "Any", "$favorite.show"],
])
# Copy 'favorite' to 'favorite2': displayCol should be set on the latter.
self.apply_user_action(['CopyFromColumn', 'Favorites', 'favorite', 'favorite2', None])
self.assertTableData("_grist_Tables_column", cols="subset", rows=(lambda r: r.id > 24), data=[
["id" , "colId" , "parentId", "displayCol", "type", "formula"],
[25 , "favorite" , 2 , 27 , "Ref:Television", ""],
[26 , "favorite2" , 2 , 28 , "Ref:Television", ""],
[27 , "gristHelper_Display" , 2 , 0 , "Any", "$favorite.show"],
[28 , "gristHelper_Display2", 2 , 0 , "Any", "$favorite2.show"],
])
# SetDisplyFormula to a different formula: displayCol should get reused.
self.apply_user_action(['SetDisplayFormula', 'Favorites', None, 25, '$favorite.network'])
self.assertTableData("_grist_Tables_column", cols="subset", rows=(lambda r: r.id > 24), data=[
["id" , "colId" , "parentId", "displayCol", "type", "formula"],
[25 , "favorite" , 2 , 27 , "Ref:Television", ""],
[26 , "favorite2" , 2 , 28 , "Ref:Television", ""],
[27 , "gristHelper_Display" , 2 , 0 , "Any", "$favorite.network"],
[28 , "gristHelper_Display2", 2 , 0 , "Any", "$favorite2.show"],
])
# Copy again; the destination displayCol should get adjusted but reused.
self.apply_user_action(['CopyFromColumn', 'Favorites', 'favorite', 'favorite2', None])
self.assertTableData("_grist_Tables_column", cols="subset", rows=(lambda r: r.id > 24), data=[
["id" , "colId" , "parentId", "displayCol", "type", "formula"],
[25 , "favorite" , 2 , 27 , "Ref:Television", ""],
[26 , "favorite2" , 2 , 28 , "Ref:Television", ""],
[27 , "gristHelper_Display" , 2 , 0 , "Any", "$favorite.network"],
[28 , "gristHelper_Display2", 2 , 0 , "Any", "$favorite2.network"],
])
# If we change column type, the displayCol should get unset and deleted.
out_actions = self.apply_user_action(['ModifyColumn', 'Favorites', 'favorite',
{'type': 'Numeric'}])
self.assertTableData("_grist_Tables_column", cols="subset", rows=(lambda r: r.id > 24), data=[
["id" , "colId" , "parentId", "displayCol", "type", "formula"],
[25 , "favorite" , 2 , 0 , "Numeric", ""],
[26 , "favorite2" , 2 , 28 , "Ref:Television", ""],
[28 , "gristHelper_Display2", 2 , 0 , "Any", "$favorite2.network"],
])
# Copy again; the destination displayCol should now get deleted too.
self.apply_user_action(['CopyFromColumn', 'Favorites', 'favorite', 'favorite2', None])
self.assertTableData("_grist_Tables_column", cols="subset", rows=(lambda r: r.id > 24), data=[
["id" , "colId" , "parentId", "displayCol", "type", "formula"],
[25 , "favorite" , 2 , 0 , "Numeric", ""],
[26 , "favorite2" , 2 , 0 , "Numeric", ""],
])
def test_display_col_table_rename(self):
self.load_sample(self.ref_sample)
# Add a table for people to get an associated view.
self.apply_user_action(['AddTable', 'People', [
{'id': 'name', 'type': 'Text'},
{'id': 'favorite', 'type': 'Ref:Television',
'widgetOptions': '\"{\"alignment\":\"center\",\"visibleCol\":\"show\"}\"'},
{'id': 'network', 'type': 'Any', 'isFormula': True,
'formula': 'Television.lookupOne(show=rec.favorite.show).network'}]])
self.apply_user_action(['BulkAddRecord', 'People', [1,2,3], {
'name': ['Bob', 'Jim', 'Don'],
'favorite': [12, 11, 13]
}])
# Add a display formula for the 'favorite' column.
# A "gristHelper_Display" column with the requested formula should be added and set as the
# displayCol of the favorite column.
self.apply_user_action(['SetDisplayFormula', 'People', None, 26, '$favorite.show'])
# Set display formula for 'favorite' column field.
# A single "gristHelper_Display2" column should be added with the requested formula.
self.apply_user_action(['SetDisplayFormula', 'People', 1, None, '$favorite.network'])
# Check that the tables are set up as expected.
self.assertTables([
Table(1, "Television", 0, 0, columns=[
Column(21, "show", "Text", False, "", 0),
Column(22, "network", "Text", False, "", 0),
Column(23, "viewers", "Int", False, "", 0),
]),
Table(2, "People", 1, 0, columns=[
Column(24, "manualSort", "ManualSortPos", False, "", 0),
Column(25, "name", "Text", False, "", 0),
Column(26, "favorite", "Ref:Television", False, "", 0),
Column(27, "network", "Any", True,
"Television.lookupOne(show=rec.favorite.show).network", 0),
Column(28, "gristHelper_Display", "Any", True, "$favorite.show", 0),
Column(29, "gristHelper_Display2", "Any", True, "$favorite.network", 0)
]),
])
self.assertTableData("People", cols="subset", data=[
["id", "name", "favorite", "network"],
[1, "Bob", 12, "Netflix"],
[2, "Jim", 11, "HBO"],
[3, "Don", 13, "NBC"]
])
self.assertTableData("_grist_Tables_column", cols="subset", rows=(lambda r: r.parentId.id == 2),
data=[
["id", "colId", "parentId", "displayCol", "formula"],
[24, "manualSort", 2, 0, ""],
[25, "name", 2, 0, ""],
[26, "favorite", 2, 28, ""],
[27, "network", 2, 0,
"Television.lookupOne(show=rec.favorite.show).network"],
[28, "gristHelper_Display", 2, 0, "$favorite.show"],
[29, "gristHelper_Display2", 2, 0, "$favorite.network"]
])
self.assertTableData("_grist_Views_section_field", cols="subset", data=[
["id", "colRef", "displayCol"],
[1, 25, 29],
[2, 26, 0],
[3, 27, 0]
])
# Rename the referenced table.
out_actions = self.apply_user_action(['RenameTable', 'Television', 'Television2'])
# Verify the resulting actions.
# This tests a bug fix where table renames would cause widgetOptions and displayCols
# of columns referencing the renamed table to be unset. See https://phab.getgrist.com/T206.
# Ensure that no actions are generated to unset the widgetOptions and the displayCols of the
# field or column.
self.assertPartialOutActions(out_actions, {
"stored": [
["ModifyColumn", "People", "favorite", {"type": "Int"}],
["RenameTable", "Television", "Television2"],
["UpdateRecord", "_grist_Tables", 1, {"tableId": "Television2"}],
["ModifyColumn", "People", "favorite", {"type": "Ref:Television2"}],
["ModifyColumn", "People", "network",
{"formula": "Television2.lookupOne(show=rec.favorite.show).network"}],
["BulkUpdateRecord", "_grist_Tables_column", [26, 27], {
"formula": ["", "Television2.lookupOne(show=rec.favorite.show).network"],
"type": ["Ref:Television2", "Any"]
}]
],
"calc": []
})
# Verify that the tables have responded as expected to the change.
self.assertTables([
Table(1, "Television2", 0, 0, columns=[
Column(21, "show", "Text", False, "", 0),
Column(22, "network", "Text", False, "", 0),
Column(23, "viewers", "Int", False, "", 0),
]),
Table(2, "People", 1, 0, columns=[
Column(24, "manualSort", "ManualSortPos", False, "", 0),
Column(25, "name", "Text", False, "", 0),
Column(26, "favorite", "Ref:Television2", False, "", 0),
Column(27, "network", "Any", True,
"Television2.lookupOne(show=rec.favorite.show).network", 0),
Column(28, "gristHelper_Display", "Any", True, "$favorite.show", 0),
Column(29, "gristHelper_Display2", "Any", True, "$favorite.network", 0)
]),
])
self.assertTableData("People", cols="subset", data=[
["id", "name", "favorite", "network"],
[1, "Bob", 12, "Netflix"],
[2, "Jim", 11, "HBO"],
[3, "Don", 13, "NBC"]
])
self.assertTableData("_grist_Tables_column", cols="subset", rows=(lambda r: r.parentId.id == 2),
data=[
["id", "colId", "parentId", "displayCol", "formula"],
[24, "manualSort", 2, 0, ""],
[25, "name", 2, 0, ""],
[26, "favorite", 2, 28, ""],
[27, "network", 2, 0,
"Television2.lookupOne(show=rec.favorite.show).network"],
[28, "gristHelper_Display", 2, 0, "$favorite.show"],
[29, "gristHelper_Display2", 2, 0, "$favorite.network"]
])
self.assertTableData("_grist_Views_section_field", cols="subset", data=[
["id", "colRef", "displayCol"],
[1, 25, 29],
[2, 26, 0],
[3, 27, 0]
])

@ -0,0 +1,249 @@
import actions
import logger
import testsamples
import test_engine
from test_engine import Table, Column
log = logger.Logger(__name__, logger.INFO)
class TestDocModel(test_engine.EngineTestCase):
def test_meta_tables(self):
"""
Test changes to records accessed via lookup.
"""
self.load_sample(testsamples.sample_students)
self.assertPartialData("_grist_Tables", ["id", "columns"], [
[1, [1,2,4,5,6]],
[2, [10,12]],
[3, [21]],
])
# Test that adding a column produces a change to 'columns' without emitting an action.
out_actions = self.add_column('Students', 'test', type='Text', isFormula=False)
self.assertPartialData("_grist_Tables", ["id", "columns"], [
[1, [1,2,4,5,6,22]],
[2, [10,12]],
[3, [21]],
])
self.assertPartialOutActions(out_actions, {
"calc": [],
"stored": [
["AddColumn", "Students", "test",
{"formula": "", "isFormula": False, "type": "Text"}
],
["AddRecord", "_grist_Tables_column", 22,
{"colId": "test", "formula": "", "isFormula": False, "label": "test",
"parentId": 1, "parentPos": 6.0, "type": "Text", "widgetOptions": ""}
],
],
"undo": [
["RemoveColumn", "Students", "test"],
["RemoveRecord", "_grist_Tables_column", 22],
]
})
# Undo the AddColumn action. Check that actions are in correct order, and still produce undos.
out_actions = self.apply_user_action(
['ApplyUndoActions', [actions.get_action_repr(a) for a in out_actions.undo]])
self.assertPartialOutActions(out_actions, {
"calc": [],
"stored": [
["RemoveRecord", "_grist_Tables_column", 22],
["RemoveColumn", "Students", "test"],
],
"undo": [
["AddRecord", "_grist_Tables_column", 22, {"colId": "test", "label": "test",
"parentId": 1, "parentPos": 6.0, "type": "Text"}],
["AddColumn", "Students", "test", {"formula": "", "isFormula": False, "type": "Text"}],
]
})
# Test that when we add a table, .column is set correctly.
out_actions = self.apply_user_action(['AddTable', 'Test2', [
{'id': 'A', 'type': 'Text'},
{'id': 'B', 'type': 'Numeric'},
{'id': 'C', 'type': 'Numeric', 'formula': 'len($A)', 'isFormula': True}
]])
self.assertPartialData("_grist_Tables", ["id", "columns"], [
[1, [1,2,4,5,6]],
[2, [10,12]],
[3, [21]],
[4, [22,23,24,25]],
])
self.assertPartialData("_grist_Tables_column", ["id", "colId", "parentId"], [
[1, "firstName", 1],
[2, "lastName", 1],
[4, "schoolName", 1],
[5, "schoolIds", 1],
[6, "schoolCities", 1],
[10, "name", 2],
[12, "address", 2],
[21, "city", 3],
# Newly added columns:
[22, 'manualSort', 4],
[23, 'A', 4],
[24, 'B', 4],
[25, 'C', 4],
])
def test_add_column_position(self):
self.load_sample(testsamples.sample_students)
# Client may send AddColumn actions with fractional positions. Test that it works.
# TODO: this should probably use parentPos in the future and be done via metadata AddRecord.
out_actions = self.add_column('Students', 'test', type='Text', _position=2.75)
self.assertPartialData("_grist_Tables", ["id", "columns"], [
[1, [1,2,22,4,5,6]],
[2, [10,12]],
[3, [21]],
])
out_actions = self.add_column('Students', None, type='Text', _position=6)
self.assertPartialData("_grist_Tables", ["id", "columns"], [
[1, [1,2,22,4,5,6,23]],
[2, [10,12]],
[3, [21]],
])
self.assertPartialData("_grist_Tables_column", ["id", "colId", "parentId"], [
[1, "firstName", 1],
[2, "lastName", 1],
[4, "schoolName", 1],
[5, "schoolIds", 1],
[6, "schoolCities", 1],
[10, "name", 2],
[12, "address", 2],
[21, "city", 3],
[22, "test", 1],
[23, "A", 1],
])
def assertRecordSet(self, record_set, expected_row_ids):
self.assertEqual(list(record_set.id), expected_row_ids)
def test_lookup_recompute(self):
self.load_sample(testsamples.sample_students)
self.apply_user_action(['AddTable', 'Test2', [
{'id': 'A', 'type': 'Text'},
{'id': 'B', 'type': 'Numeric'},
]])
self.apply_user_action(['AddTable', 'Test3', [
{'id': 'A', 'type': 'Text'},
{'id': 'B', 'type': 'Numeric'},
]])
self.apply_user_action(['AddViewSection', 'Section2', 'record', 1, 'Test2'])
self.apply_user_action(['AddViewSection', 'Section3', 'record', 1, 'Test3'])
self.assertPartialData('_grist_Views', ["id"], [
[1],
[2],
])
self.assertPartialData('_grist_Views_section', ["id", "parentId", "tableRef"], [
[1, 1, 4],
[2, 2, 5],
[3, 1, 4],
[4, 1, 5],
])
self.assertPartialData('_grist_Views_section_field', ["id", "parentId", "parentPos"], [
[1, 1, 1.0],
[2, 1, 2.0],
[3, 2, 3.0],
[4, 2, 4.0],
[5, 3, 5.0],
[6, 3, 6.0],
[7, 4, 7.0],
[8, 4, 8.0],
])
table = self.engine.docmodel.tables.lookupOne(tableId='Test2')
self.assertRecordSet(table.viewSections, [1,3])
self.assertRecordSet(list(table.viewSections)[0].fields, [1,2])
self.assertRecordSet(list(table.viewSections)[1].fields, [5,6])
view = self.engine.docmodel.views.lookupOne(id=1)
self.assertRecordSet(view.viewSections, [1,3,4])
self.engine.docmodel.remove(f for vs in table.viewSections for f in vs.fields)
self.engine.docmodel.remove(table.viewSections)
self.assertRecordSet(view.viewSections, [4])
def test_modifications(self):
# Test the add/remove/update methods of DocModel.
self.load_sample(testsamples.sample_students)
table = self.engine.docmodel.get_table('Students')
records = table.lookupRecords(lastName='Bush')
self.assertEqual([r.id for r in records], [2, 4])
self.assertEqual([r.schoolName for r in records], ["Yale", "Yale"])
self.assertEqual([r.firstName for r in records], ["George W", "George H"])
# Test the update() method.
self.engine.docmodel.update(records, schoolName="Test", firstName=["george w", "george h"])
self.assertEqual([r.schoolName for r in records], ["Test", "Test"])
self.assertEqual([r.firstName for r in records], ["george w", "george h"])
# Test the remove() method.
self.engine.docmodel.remove(records)
records = table.lookupRecords(lastName='Bush')
self.assertEqual(list(records), [])
self.assertTableData("Students", cols="subset", data=[
["id","firstName","lastName", "schoolName" ],
[1, "Barack", "Obama", "Columbia" ],
[3, "Bill", "Clinton", "Columbia" ],
[5, "Ronald", "Reagan", "Eureka" ],
[6, "Gerald", "Ford", "Yale" ]])
# Test the add() method.
self.engine.docmodel.add(table, schoolName="Foo", firstName=["X", "Y"])
self.assertTableData("Students", cols="subset", data=[
["id","firstName","lastName", "schoolName" ],
[1, "Barack", "Obama", "Columbia" ],
[3, "Bill", "Clinton", "Columbia" ],
[5, "Ronald", "Reagan", "Eureka" ],
[6, "Gerald", "Ford", "Yale" ],
[7, "X", "", "Foo" ],
[8, "Y", "", "Foo" ],
])
def test_inserts(self):
# Test the insert() method. We do this on the columns metadata table, so that we can sort by
# a PositionNumber column.
self.load_sample(testsamples.sample_students)
student_columns = self.engine.docmodel.tables.lookupOne(tableId='Students').columns
school_columns = self.engine.docmodel.tables.lookupOne(tableId='Schools').columns
# Should go at the end of the Students table.
cols = self.engine.docmodel.insert(student_columns, None, colId=["a", "b"], type="Text")
# Should go at the start of the Schools table.
self.engine.docmodel.insert_after(school_columns, None, colId="foo", type="Int")
# Should go before the new "a", "b" columns of the Students table.
self.engine.docmodel.insert(student_columns, cols[0].parentPos, colId="bar", type="Date")
# Verify that the right columns were added to the right tables. This doesn't check positions.
self.assertTables([
Table(1, "Students", 0, 0, columns=[
Column(1, "firstName", "Text", False, "", 0),
Column(2, "lastName", "Text", False, "", 0),
Column(4, "schoolName", "Text", False, "", 0),
Column(5, "schoolIds", "Text", True,
"':'.join(str(id) for id in Schools.lookupRecords(name=$schoolName).id)", 0),
Column(6, "schoolCities", "Text", True,
"':'.join(r.address.city for r in Schools.lookupRecords(name=$schoolName))", 0),
Column(22, "a", "Text", False, "", 0),
Column(23, "b", "Text", False, "", 0),
Column(25, "bar", "Date", False, "", 0),
]),
Table(2, "Schools", 0, 0, columns=[
Column(10, "name", "Text", False, "", 0),
Column(12, "address", "Ref:Address",False, "", 0),
Column(24, "foo", "Int", False, "", 0),
]),
Table(3, "Address", 0, 0, columns=[
Column(21, "city", "Text", False, "", 0),
])
])
# Verify that positions are set such that the order is what we asked for.
student_columns = self.engine.docmodel.tables.lookupOne(tableId='Students').columns
self.assertEqual(map(int, student_columns), [1,2,4,5,6,25,22,23])
school_columns = self.engine.docmodel.tables.lookupOne(tableId='Schools').columns
self.assertEqual(map(int, school_columns), [24,10,12])

@ -0,0 +1,559 @@
import difflib
import functools
import json
import unittest
from collections import namedtuple
import actions
import column
import engine
import logger
import useractions
import testutil
log = logger.Logger(__name__, logger.DEBUG)
# These are for use in verifying metadata using assertTables/assertViews methods. E.g.
# self.assertViews([View(1, sections=[Section(1, parentKey="record", tableRef=1, fields=[
# Field(1, colRef=11) ]) ]) ])
Table = namedtuple('Table', ('id tableId primaryViewId summarySourceTable columns'))
Column = namedtuple('Column', ('id colId type isFormula formula summarySourceCol'))
View = namedtuple('View', 'id sections')
Section = namedtuple('Section', 'id parentKey tableRef fields')
Field = namedtuple('Field', 'id colRef')
class EngineTestCase(unittest.TestCase):
"""
Provides functionality for verifying engine actions and data, which is general enough to be
useful for other tests. It is also used by TestEngine below.
"""
# Place to keep the original log handler (which we modify for the duration of the test).
# We can't use cls._orig_log_handler directly because then Python it's an actual class method.
_orig_log_handler = []
@classmethod
def setUpClass(cls):
cls._orig_log_handler.append(logger.set_handler(testutil.limit_log_stderr(logger.WARN)))
@classmethod
def tearDownClass(cls):
logger.set_handler(cls._orig_log_handler.pop())
def setUp(self):
"""
Initial setup for each test case.
"""
self.engine = engine.Engine()
self.engine.load_empty()
# Set up call tracing to count calls (formula evaluations) for each column for each table.
self.call_counts = {}
def trace_call(col_obj, _rec):
# Ignore formulas in metadata tables for simplicity. Such formulas are mostly private, and
# it would be annoying to fix tests every time we change them.
if not col_obj.table_id.startswith("_grist_"):
tmap = self.call_counts.setdefault(col_obj.table_id, {})
tmap[col_obj.col_id] = tmap.get(col_obj.col_id, 0) + 1
self.engine.formula_tracer = trace_call
# This is set when a test case is wrapped by `test_engine.test_undo`.
self._undo_state_tracker = None
@classmethod
def _getEngineDataLines(cls, engine_data, col_names=[]):
"""
Helper for assertEqualEngineData, which returns engine data represented as lines of text
suitable for diffing. If col_names is given, it determines the order of columns (columns not
found in this list are included in the end and sorted by name).
"""
sort_keys = {c: i for i, c in enumerate(col_names)}
ret = []
engine_data = actions.encode_objects(engine_data)
for table_id, table_data in sorted(engine_data.items()):
ret.append("TABLE %s\n" % table_id)
col_items = sorted(table_data.columns.items(),
key=lambda c: (sort_keys.get(c[0], float('inf')), c))
col_items.insert(0, ('id', table_data.row_ids))
table_rows = zip(*[[col_id] + values for (col_id, values) in col_items])
ret.extend(json.dumps(row) + "\n" for row in table_rows)
return ret
def assertEqualDocData(self, observed, expected, col_names=[]):
"""
Compare full engine data, as a mapping of table_ids to TableData objects, and reporting
differences with a customized diff (similar to the JSON representation in the test script).
"""
if observed != expected:
o_lines = self._getEngineDataLines(observed, col_names)
e_lines = self._getEngineDataLines(expected, col_names)
self.fail("Observed data not as expected:\n" +
"".join(difflib.unified_diff(e_lines, o_lines,
fromfile="expected", tofile="observed")))
def assertCorrectEngineData(self, expected_data):
"""
Verifies that the data engine contains the same data as the given expected data,
which should be a dictionary mapping table names to TableData objects.
"""
expected_output = actions.decode_objects(expected_data)
meta_tables = self.engine.fetch_table("_grist_Tables")
output = {t: self.engine.fetch_table(t) for t in meta_tables.columns["tableId"]}
output = testutil.replace_nans(output)
self.assertEqualDocData(output, expected_output)
def getFullEngineData(self):
return testutil.replace_nans({t: self.engine.fetch_table(t) for t in self.engine.tables})
def assertPartialData(self, table_name, col_names, row_data):
"""
Verifies that the data engine contains the right data for the given col_names (ignoring any
other columns).
"""
expected = testutil.table_data_from_rows(table_name, col_names, row_data)
observed = self.engine.fetch_table(table_name, private=True)
ignore = set(observed.columns) - set(expected.columns)
for col_id in ignore:
del observed.columns[col_id]
self.assertEqualDocData({table_name: observed}, {table_name: expected})
action_group_action_fields = ("stored", "undo", "calc")
@classmethod
def _formatActionGroup(cls, action_group, use_repr=False):
"""
Helper for assertEqualActionGroups below.
"""
lines = ["{"]
for (k, action_list) in sorted(action_group.items()):
if k in cls.action_group_action_fields:
for a in action_list:
rep = repr(a) if use_repr else json.dumps(actions.get_action_repr(a), sort_keys=True)
lines.append("%s: %s," % (k, rep))
else:
lines.append("%s: %s," % (k, json.dumps(action_list)))
lines.append("}")
return lines
def assertEqualActionGroups(self, observed, expected):
"""
Compare grouped doc actions, reporting differences with a customized diff
(a bit more readable than unittest's usual diff).
"""
# Do some clean up on the observed data.
observed = testutil.replace_nans(observed)
# Convert expected actions into a comparable form.
for k in self.action_group_action_fields:
if k in expected:
expected[k] = [actions.action_from_repr(a) if isinstance(a, list) else a
for a in expected[k]]
if observed != expected:
o_lines = self._formatActionGroup(observed)
e_lines = self._formatActionGroup(expected)
extra = ""
if o_lines == e_lines:
o_lines = self._formatActionGroup(observed, use_repr=True)
e_lines = self._formatActionGroup(expected, use_repr=True)
extra = " (BUT HAVE SAME REPR!)"
self.fail(("Observed out actions not as expected%s:\n" % extra) +
"\n".join(difflib.unified_diff(e_lines, o_lines, n=3, lineterm="",
fromfile="expected", tofile="observed")))
def assertOutActions(self, out_action_group, expected_group):
"""
Compares action group returned from engine.apply_user_actions() to expected actions as listed
in testscript. The array of retValues is only checked if present in expected_group.
"""
observed = {k: getattr(out_action_group, k) for k in self.action_group_action_fields }
if "retValue" in expected_group:
observed["retValue"] = out_action_group.retValues
self.assertEqualActionGroups(observed, expected_group)
def assertPartialOutActions(self, out_action_group, expected_group):
"""
Compares a single action group as returned from engine.apply_user_actions() to expected
actions, checking only those fields that are included in the expected_group dict.
"""
observed = {k: getattr(out_action_group, k) for k in expected_group}
self.assertEqualActionGroups(observed, expected_group)
def dump_data(self):
"""
Prints a dump of all engine data, for help in writing / debugging tests.
"""
output = {t: self.engine.fetch_table(t) for t in self.engine.schema}
output = testutil.replace_nans(output)
print ''.join(self._getEngineDataLines(output))
def dump_actions(self, out_actions):
"""
Prints out_actions in human-readable format, for help in writing / debugging tets.
"""
print "\n".join(self._formatActionGroup(out_actions.__dict__))
def assertTableData(self, table_name, data=[], cols="all", rows="all", sort=None):
"""
Verify some or all of the data in the table named `table_name`.
- data: an array of rows, with first row containing column names starting with "id", and
other rows also all starting with row_id.
- cols: may be "all" (default) to match all columns, or "subset" to match only those listed.
- rows: may be "all" (default) to match all rows, or "subset" to match only those listed,
or a function called with a Record to return whether to include it.
- sort: optionally a key function called with a Record, for sorting observed rows.
"""
assert data[0][0] == 'id', "assertRecords requires 'id' as the first column"
col_names = data[0]
row_data = data[1:]
expected = testutil.table_data_from_rows(table_name, col_names, row_data)
table = self.engine.tables[table_name]
columns = [c for c in table.all_columns.values()
if c.col_id != "id" and not column.is_virtual_column(c.col_id)]
if cols == "all":
pass
elif cols == "subset":
columns = [c for c in columns if c.col_id in col_names]
else:
raise ValueError("assertRecords: invalid value for cols: %s" % (cols,))
if rows == "all":
row_ids = list(table.row_ids)
elif rows == "subset":
row_ids = [row[0] for row in row_data]
elif callable(rows):
row_ids = [r.id for r in table.user_table.all if rows(r)]
else:
raise ValueError("assertRecords: invalid value for rows: %s" % (rows,))
if sort:
row_ids.sort(key=lambda r: sort(table.get_record(r)))
observed_col_data = {c.col_id: map(c.raw_get, row_ids) for c in columns if c.col_id != "id"}
observed = actions.TableData(table_name, row_ids, observed_col_data)
self.assertEqualDocData({table_name: observed}, {table_name: expected},
col_names=col_names)
def assertTables(self, list_of_tables):
"""
Verifies that the given Table test-records correspond to the metadata for tables/columns.
"""
self.assertPartialData('_grist_Tables',
["id", "tableId", "primaryViewId", "summarySourceTable"],
sorted((tbl.id, tbl.tableId, tbl.primaryViewId, tbl.summarySourceTable)
for tbl in list_of_tables))
self.assertPartialData('_grist_Tables_column',
["id", "parentId", "colId", "type",
"isFormula", "formula", "summarySourceCol"],
sorted((col.id, tbl.id, col.colId, col.type,
col.isFormula, col.formula, col.summarySourceCol)
for tbl in list_of_tables
for col in tbl.columns))
def assertViews(self, list_of_views):
"""
Verifies that the given View test-records correspond to the metadata for views/sections/fields.
"""
self.assertPartialData('_grist_Views', ["id"],
[[view.id] for view in list_of_views])
self.assertPartialData('_grist_Views_section', ["id", "parentId", "parentKey", "tableRef"],
sorted((sec.id, view.id, sec.parentKey, sec.tableRef)
for view in list_of_views
for sec in view.sections))
self.assertTableData('_grist_Views_section_field', sort=(lambda r: r.parentPos),
cols="subset",
data=[["id", "parentId", "colRef"]] + sorted(
((field.id, sec.id, field.colRef)
for view in list_of_views
for sec in view.sections
for field in sec.fields), key=lambda t: t[1])
)
def load_sample(self, sample):
"""
Load the data engine with given sample data. The sample is a dict with keys "SCHEMA" and
"DATA", each a dictionary mapping table names to actions.TableData objects. "SCHEMA" contains
"_grist_Tables" and "_grist_Tables_column" tables.
"""
schema = sample["SCHEMA"]
self.engine.load_meta_tables(schema['_grist_Tables'], schema['_grist_Tables_column'])
for data in sample["DATA"].itervalues():
self.engine.load_table(data)
self.engine.load_done()
# The following are convenience methods for tests deriving from EngineTestCase.
def add_column(self, table_name, col_name, **kwargs):
return self.apply_user_action(['AddColumn', table_name, col_name, kwargs])
def modify_column(self, table_name, col_name, **kwargs):
return self.apply_user_action(['ModifyColumn', table_name, col_name, kwargs])
def remove_column(self, table_name, col_name):
return self.apply_user_action(['RemoveColumn', table_name, col_name])
def update_record(self, table_name, row_id, **kwargs):
return self.apply_user_action(['UpdateRecord', table_name, row_id, kwargs])
def add_record(self, table_name, row_id=None, **kwargs):
return self.apply_user_action(['AddRecord', table_name, row_id, kwargs])
def remove_record(self, table_name, row_id):
return self.apply_user_action(['RemoveRecord', table_name, row_id])
def update_records(self, table_name, col_names, row_data):
return self.apply_user_action(
('BulkUpdateRecord',) + testutil.table_data_from_rows(table_name, col_names, row_data))
@classmethod
def add_records_action(cls, table_name, data):
"""
Creates a BulkAddRecord action; data should be an array of rows, with first row containing
column names, with "id" column optional.
"""
col_names, row_data = data[0], data[1:]
if "id" not in col_names:
col_names = ["id"] + col_names
row_data = [[None] + r for r in row_data]
return ('BulkAddRecord',) + testutil.table_data_from_rows(table_name, col_names, row_data)
def add_records(self, table_name, col_names, row_data):
return self.apply_user_action(self.add_records_action(table_name, [col_names] + row_data))
def apply_user_action(self, user_action_repr, is_undo=False):
if not is_undo:
log.debug("Applying user action %r" % (user_action_repr,))
if self._undo_state_tracker is not None:
doc_state = self.getFullEngineData()
self.call_counts.clear()
out_actions = self.engine.apply_user_actions([useractions.from_repr(user_action_repr)])
out_actions.calls = self.call_counts.copy()
if not is_undo and self._undo_state_tracker is not None:
self._undo_state_tracker.append((doc_state, out_actions.undo[:]))
return out_actions
def apply_undo_actions(self, undo_actions):
"""
Applies all doc_actions together (as happens e.g. for undo).
"""
action = ["ApplyUndoActions", [actions.get_action_repr(a) for a in undo_actions]]
return self.apply_user_action(action, is_undo=True)
def test_undo(test_method):
"""
If a test method is decorated with `@test_engine.test_undo`, then we will store the state before
each apply_user_action() call, and at the end of the test, undo each user-action and compare the
state. This makes for a fairly comprehensive test of undo.
"""
@functools.wraps(test_method)
def wrapped(self):
self._undo_state_tracker = []
test_method(self)
for (expected_engine_data, undo_actions) in reversed(self._undo_state_tracker):
log.debug("Applying undo actions %r" % (undo_actions,))
self.apply_undo_actions(undo_actions)
self.assertEqualDocData(self.getFullEngineData(), expected_engine_data)
return wrapped
class TestEngine(EngineTestCase):
samples = {}
#----------------------------------------------------------------------
# Implementations of the actual script steps.
#----------------------------------------------------------------------
def process_apply_step(self, data):
"""
Processes the "APPLY" step of a test script, applying a user action, and checking the
resulting action group's return value (if present)
"""
if "USER_ACTION" in data:
user_actions = [useractions.from_repr(data.pop("USER_ACTION"))]
else:
user_actions = [useractions.from_repr(u) for u in data.pop("USER_ACTIONS")]
expected_call_counts = data.pop("CHECK_CALL_COUNTS", None)
expected_actions = data.pop("ACTIONS", {})
expected_actions.setdefault("stored", [])
expected_actions.setdefault("calc", [])
expected_actions.setdefault("undo", [])
if data:
raise ValueError("Unrecognized key %s in APPLY step" % data.popitem()[0])
self.call_counts.clear()
out_actions = self.engine.apply_user_actions(user_actions)
self.assertOutActions(out_actions, expected_actions)
if expected_call_counts:
self.assertEqual(self.call_counts, expected_call_counts)
return out_actions
#----------------------------------------------------------------------
# The runner for scripted test cases.
#----------------------------------------------------------------------
def _run_test_body(self, _name, body):
"""
Runs the actual script defined in the JSON test-script file.
"""
undo_actions = []
loaded_sample = None
for line, step, data in body:
try:
if step == "LOAD_SAMPLE":
if loaded_sample:
# pylint: disable=unsubscriptable-object
self._verify_undo_all(undo_actions, loaded_sample["DATA"])
loaded_sample = self.samples[data]
self.load_sample(loaded_sample)
elif step == "APPLY":
action_group = self.process_apply_step(data)
undo_actions.extend(action_group.undo)
elif step == "CHECK_OUTPUT":
expected_data = {}
if "USE_SAMPLE" in data:
sample = self.samples[data.pop("USE_SAMPLE")]
expected_data = sample["DATA"].copy()
expected_data.update({t: testutil.table_data_from_rows(t, tdata[0], tdata[1:])
for (t, tdata) in data.iteritems()})
self.assertCorrectEngineData(expected_data)
else:
raise ValueError("Unrecognized step %s in test script" % step)
except Exception, e:
prefix = "LINE %s: " % line
e.args = (prefix + e.args[0],) + e.args[1:] if e.args else (prefix,)
raise
self._verify_undo_all(undo_actions, loaded_sample["DATA"])
def _verify_undo_all(self, undo_actions, expected_data):
"""
At the end of each test, undo all and verify we get back to the originally loaded sample.
"""
self.apply_undo_actions(undo_actions)
del undo_actions[:]
self.assertCorrectEngineData(expected_data)
# TODO We need several more tests.
# 1. After a bunch of schema actions, create a new engine from the resulting schema, ensure that
# modified engine and new engine produce the same results AND the same dep_graph.
# 2. Build up a table by adding one column at a time, in "good" order and in "bad" order (with
# references to columns that will be added later)
# 3. Tear down a table in both of the orders above.
# 4. At each intermediate state of 2 and 3, new engine should produce same results as the
# modified engine (and have the same state such as dep_graph).
sample1 = {
"SCHEMA": [
[1, "Address", [
[11, "city", "Text", False, "", "", ""],
[12, "state", "Text", False, "", "", ""],
[13, "amount", "Numeric", False, "", "", ""],
]]
],
"DATA": {
"Address": [
["id", "city", "state", "amount" ],
[ 21, "New York", "NY" , 1 ],
[ 22, "Albany", "NY" , 2 ],
]
}
}
def test_no_private_fields(self):
self.load_sample(testutil.parse_test_sample(self.sample1))
data = self.engine.fetch_table("_grist_Tables", private=True)
self.assertIn('tableId', data.columns)
self.assertIn('columns', data.columns)
self.assertIn('viewSections', data.columns)
data = self.engine.fetch_table("_grist_Tables")
self.assertIn('tableId', data.columns)
self.assertNotIn('columns', data.columns)
self.assertNotIn('viewSections', data.columns)
def test_fetch_table_query(self):
self.load_sample(testutil.parse_test_sample(self.sample1))
col_names = ["id", "city", "state", "amount" ]
data = self.engine.fetch_table('Address', query={'state': ['NY']})
self.assertEqualDocData({'Address': data},
{'Address': testutil.table_data_from_rows('Address', col_names, [
[ 21, "New York", "NY" , 1 ],
[ 22, "Albany", "NY" , 2 ],
])})
data = self.engine.fetch_table('Address', query={'city': ['New York'], 'state': ['NY']})
self.assertEqualDocData({'Address': data},
{'Address': testutil.table_data_from_rows('Address', col_names, [
[ 21, "New York", "NY" , 1 ],
])})
data = self.engine.fetch_table('Address', query={'amount': [2.0]})
self.assertEqualDocData({'Address': data},
{'Address': testutil.table_data_from_rows('Address', col_names, [
[ 22, "Albany", "NY" , 2 ],
])})
data = self.engine.fetch_table('Address', query={'city': ['New York'], 'amount': [2.0]})
self.assertEqualDocData({'Address': data},
{'Address': testutil.table_data_from_rows('Address', col_names, [])})
data = self.engine.fetch_table('Address', query={'city': ['New York'], 'amount': [1.0, 2.0]})
self.assertEqualDocData({'Address': data},
{'Address': testutil.table_data_from_rows('Address', col_names, [
[ 21, "New York", "NY" , 1 ],
])})
# Ensure empty filter list works too.
data = self.engine.fetch_table('Address', query={'city': ['New York'], 'amount': []})
self.assertEqualDocData({'Address': data},
{'Address': testutil.table_data_from_rows('Address', col_names, [])})
def test_schema_restore_on_error(self):
# Simulate an error inside a DocAction, and make sure we restore the schema (don't leave it in
# inconsistent with metadata).
self.load_sample(testutil.parse_test_sample(self.sample1))
with self.assertRaisesRegexp(AttributeError, r"'BAD'"):
self.add_column('Address', 'bad', isFormula=False, type="BAD")
self.engine.assert_schema_consistent()
def create_tests_from_script(samples, test_cases):
"""
Dynamically create tests from a file containing a JSON spec for test cases. The reason for doing
it this way is because the same JSON spec is used to test Python and JS code.
Tests are created as methods to a TestCase. It's done on import, so that python unittest feature
to run only particular test cases can apply to these cases too.
"""
TestEngine.samples = samples
for case in test_cases:
create_test_case("test_" + case["TEST_CASE"], case["BODY"])
def create_test_case(name, body):
"""
Helper for create_tests_from_script, which creates a single test case.
"""
def run(self):
self._run_test_body(name, body)
setattr(TestEngine, name, run)
# Parse and create test cases on module load. This way the python unittest feature to run only
# particular test cases can apply to these cases too.
create_tests_from_script(*testutil.parse_testscript())
if __name__ == "__main__":
unittest.main()

@ -0,0 +1,47 @@
import testsamples
import test_engine
class TestFindCol(test_engine.EngineTestCase):
def test_find_col_from_values(self):
# Test basic funtionality.
self.load_sample(testsamples.sample_students)
self.assertEqual(self.engine.find_col_from_values(("Columbia", "Yale", "Eureka"), 0),
[4, 10])
self.assertEqual(self.engine.find_col_from_values(("Columbia", "Yale", "Eureka"), 1),
[4])
self.assertEqual(self.engine.find_col_from_values(["Yale"], 2),
[10, 4])
self.assertEqual(self.engine.find_col_from_values(("Columbia", "Yale", "Eureka"), 0, "Schools"),
[10])
def test_find_col_with_nonhashable(self):
self.load_sample(testsamples.sample_students)
# Add a couple of columns returning list, which is not hashable. There used to be a bug where
# non-hashable values would cause an exception.
self.add_column("Students", "foo", formula="list(Schools.lookupRecords(name=$schoolName))")
# This column returns a non-hashable value, but is otherwise the best match.
self.add_column("Students", "bar", formula=
"[1,2,3] if $firstName == 'Bill' else $schoolName.lower()")
# Check the columns are added with expected colRefs
self.assertTableData('_grist_Tables_column', cols="subset", rows="subset", data=[
["id", "colId", "type", "isFormula" ],
[22, "foo", "Any", True ],
[23, "bar", "Any", True ],
])
self.assertTableData("Students", cols="subset", data=[
["id","firstName","lastName", "schoolName", "bar", ],
[1, "Barack", "Obama", "Columbia", "columbia" ],
[2, "George W", "Bush", "Yale", "yale" ],
[3, "Bill", "Clinton", "Columbia", [1,2,3] ],
[4, "George H", "Bush", "Yale", "yale" ],
[5, "Ronald", "Reagan", "Eureka", "eureka" ],
[6, "Gerald", "Ford", "Yale", "yale" ],
])
self.assertEqual(self.engine.find_col_from_values(("Columbia", "Yale", "Eureka"), 0), [4, 10])
self.assertEqual(self.engine.find_col_from_values(("columbia", "yale", "Eureka"), 0), [23, 4])
# Test that it's safe to include a non-hashable value in the request.
self.assertEqual(self.engine.find_col_from_values(("columbia", "yale", ["Eureka"]), 0), [23])

@ -0,0 +1,646 @@
"""
Tests that formula error messages (traceback) are correct
"""
import depend
import textwrap
import test_engine
import testutil
import objtypes
class TestErrorMessage(test_engine.EngineTestCase):
syntax_err = \
"""
if sum(3, 5) > 6:
return 6
else:
return: 0
"""
indent_err = \
"""
if sum(3, 5) > 6:
return 6
"""
sample = testutil.parse_test_sample({
"SCHEMA": [
[1, "Math", [
[11, "excel_formula", "Text", True, "SQRT(16, 2)", "", ""],
[12, "built_in_formula", "Text", True, "max(5)", "", ""],
[13, "syntax_err", "Text", True, syntax_err, "", ""],
[14, "indent_err", "Text", True, indent_err, "", ""],
[15, "other_err", "Text", True, textwrap.dedent(indent_err), "", ""],
[15, "custom_err", "Text", True, "raise Exception('hello')", "", ""],
]]
],
"DATA": {
"Math": [
["id"],
[3],
]
}
})
def assertFormulaError(self, exc, type_, message, tracebackRegexp=None):
self.assertIsInstance(exc, objtypes.RaisedException)
self.assertIsInstance(exc.error, type_)
self.assertEqual(str(exc.error), message)
if tracebackRegexp:
self.assertRegexpMatches(exc.details, tracebackRegexp)
def test_formula_errors(self):
self.load_sample(self.sample)
self.assertFormulaError(self.engine.get_formula_error('Math', 'excel_formula', 3),
TypeError, 'SQRT() takes exactly 1 argument (2 given)',
r"TypeError: SQRT\(\) takes exactly 1 argument \(2 given\)")
self.assertFormulaError(self.engine.get_formula_error('Math', 'built_in_formula', 3),
TypeError, "'int' object is not iterable")
self.assertFormulaError(self.engine.get_formula_error('Math', 'syntax_err', 3),
SyntaxError, "invalid syntax on line 5 col 9")
self.assertFormulaError(self.engine.get_formula_error('Math', 'indent_err', 3),
IndentationError, "unexpected indent on line 2 col 2")
self.assertFormulaError(self.engine.get_formula_error('Math', 'other_err', 3),
TypeError, "'int' object is not iterable",
r"line \d+, in other_err")
self.assertFormulaError(self.engine.get_formula_error('Math', 'custom_err', 3),
Exception, "hello")
def test_lookup_state(self):
# Bug https://phab.getgrist.com/T297 was caused by lookup maps getting corrupted while
# re-evaluating a formula for the sake of getting error details. This test case reproduces the
# bug in the old code and verifies that it is fixed.
sample = testutil.parse_test_sample({
"SCHEMA": [
[1, "LookupTest", [
[11, "A", "Numeric", False, "", "", ""],
[12, "B", "Text", True, "LookupTest.lookupOne(A=2).x.upper()", "", ""],
]]
],
"DATA": {
"LookupTest": [
["id", "A"],
[7, 2],
]
}
})
self.load_sample(sample)
self.assertTableData('LookupTest', data=[
['id', 'A', 'B'],
[ 7, 2., objtypes.RaisedException(AttributeError())],
])
# Updating a dependency shouldn't cause problems.
self.update_record('LookupTest', 7, A=3)
self.assertTableData('LookupTest', data=[
['id', 'A', 'B'],
[ 7, 3., objtypes.RaisedException(AttributeError())],
])
# Fetch the error details.
self.assertFormulaError(self.engine.get_formula_error('LookupTest', 'B', 7),
AttributeError, "Table 'LookupTest' has no column 'x'")
# Updating a dependency after the fetch used to cause the error
# "AttributeError: 'Table' object has no attribute 'col_id'". Check that it's fixed.
self.update_record('LookupTest', 7, A=2) # Should NOT raise an exception.
self.assertTableData('LookupTest', data=[
['id', 'A', 'B'],
[ 7, 2., objtypes.RaisedException(AttributeError())],
])
# Add the column that will fix the attribute error.
self.add_column('LookupTest', 'x', type='Text')
self.assertTableData('LookupTest', data=[
['id', 'A', 'x', 'B'],
[ 7, 2., '', '' ],
])
# And check that the dependency still works and is recomputed.
self.update_record('LookupTest', 7, x='hello')
self.assertTableData('LookupTest', data=[
['id', 'A', 'x', 'B'],
[ 7, 2., 'hello', 'HELLO'],
])
self.update_record('LookupTest', 7, A=3)
self.assertTableData('LookupTest', data=[
['id', 'A', 'x', 'B'],
[ 7, 3., 'hello', ''],
])
def test_undo_side_effects(self):
# Ensures that side-effects (i.e. generated doc actions) produced while evaluating
# get_formula_errors() get reverted.
sample = testutil.parse_test_sample({
"SCHEMA": [
[1, "Address", [
[11, "city", "Text", False, "", "", ""],
[12, "state", "Text", False, "", "", ""],
]],
[2, "Foo", [
# Note: the formula below is a terrible example of a formula, which intentionally
# creates a new record every time it evaluates.
[21, "B", "Any", True,
"Address.lookupOrAddDerived(city=str(len(Address.all)))", "", ""],
]]
],
"DATA": {
"Foo": [["id"], [1]]
}
})
self.load_sample(sample)
self.assertTableData('Address', data=[
['id', 'city', 'state'],
[1, '0', ''],
])
# Note that evaluating the formula again would add a new record (Address[2]), but when done as
# part of get_formula_error(), that action gets undone.
self.assertEqual(str(self.engine.get_formula_error('Foo', 'B', 1)), "Address[2]")
self.assertTableData('Address', data=[
['id', 'city', 'state'],
[1, '0', ''],
])
def test_formula_reading_from_an_errored_formula(self):
# There was a bug whereby if one formula (call it D) referred to
# another (call it T), and that other formula was in error, the
# error values of that second formula would not be passed on the
# client as a BulkUpdateRecord. The bug was dependent on order of
# evaluation of columns. D would be evaluated first, and evaluate
# T in a nested way. When evaluating T, a BulkUpdateRecord would
# be prepared correctly, and when popping back to evaluate D,
# the BulkUpdateRecord for D would be prepared correctly, but since
# D was an error, any nested actions would be reverted (this is
# logic related to undoing potential side-effects on failure).
# First, set up a table with a sequence in A, a formula to do cumulative sums in T,
# and a formula D to copy T.
formula = "recs = UpdateTest.lookupRecords()\nsum(r.A for r in recs if r.A <= $A)"
sample = testutil.parse_test_sample({
"SCHEMA": [
[1, "UpdateTest", [
[20, "A", "Numeric", False, "", "", ""],
[21, "T", "Numeric", True, formula, "", ""],
[22, "D", "Numeric", True, "$T", "", ""],
]]
],
"DATA": {
"UpdateTest": [
["id", "A"],
[1, 1],
[2, 2],
[3, 3],
]
}
})
# Check the setup is working correctly.
self.load_sample(sample)
self.assertTableData('UpdateTest', data=[
['id', 'A', 'T', 'D'],
[ 1, 1., 1., 1.],
[ 2, 2., 3., 3.],
[ 3, 3., 6., 6.],
])
# Now rename the data column. This rename results in a partial
# update to the T formula that leaves it broken (not all the As are caught).
out_actions = self.apply_user_action(["RenameColumn", "UpdateTest", "A", "AA"])
# Make sure the we have bulk updates for both T and D, and not just D.
err = ["E", "AttributeError"]
self.assertPartialOutActions(out_actions, { "calc": [
[
"BulkUpdateRecord", "UpdateTest", [1, 2, 3], {
"T": [err, err, err]
}
],
[
"BulkUpdateRecord", "UpdateTest", [1, 2, 3], {
"D": [err, err, err]
}
]
]})
# Make sure the table is in the correct state.
errVal = objtypes.RaisedException(AttributeError())
self.assertTableData('UpdateTest', data=[
['id', 'AA', 'T', 'D'],
[ 1, 1., errVal, errVal],
[ 2, 2., errVal, errVal],
[ 3, 3., errVal, errVal],
])
def test_undo_side_effects_with_reordering(self):
# As for test_undo_side_effects, but now after creating a row in a
# formula we try to access a cell that hasn't been recomputed yet.
# That will result in the formula evalution being abandoned, the
# desired cell being calculated, then the formula being retried.
# All going well, we should end up with one row, not two.
sample = testutil.parse_test_sample({
"SCHEMA": [
[1, "Address", [
[11, "city", "Text", False, "", "", ""],
[12, "state", "Text", False, "", "", ""],
]],
[2, "Foo", [
# Note: the formula below is a terrible example of a formula, which intentionally
# creates a new record every time it evaluates.
[21, "B", "Any", True,
"Address.lookupOrAddDerived(city=str(len(Address.all)))\nreturn $C", "", ""],
[22, "C", "Numeric", True, "42", "", ""],
]]
],
"DATA": {
"Foo": [["id"], [1]]
}
})
self.load_sample(sample)
self.assertTableData('Address', data=[
['id', 'city', 'state'],
[1, '0', ''],
])
def test_attribute_error(self):
sample = testutil.parse_test_sample({
"SCHEMA": [
[1, "AttrTest", [
[30, "A", "Numeric", False, "", "", ""],
[31, "B", "Numeric", True, "$AA", "", ""],
[32, "C", "Numeric", True, "$B", "", ""],
]]
],
"DATA": {
"AttrTest": [
["id", "A"],
[1, 1],
[2, 2],
]
}
})
self.load_sample(sample)
errVal = objtypes.RaisedException(AttributeError())
self.assertTableData('AttrTest', data=[
['id', 'A', 'B', 'C'],
[1, 1, errVal, errVal],
[2, 2, errVal, errVal],
])
self.assertFormulaError(self.engine.get_formula_error('AttrTest', 'B', 1),
AttributeError, "Table 'AttrTest' has no column 'AA'",
r"AttributeError: Table 'AttrTest' has no column 'AA'")
self.assertFormulaError(self.engine.get_formula_error('AttrTest', 'C', 1),
AttributeError, "Table 'AttrTest' has no column 'AA'",
r"AttributeError: Table 'AttrTest' has no column 'AA'")
def test_cumulative_formula(self):
formula = ("Table1.lookupOne(A=$A-1).Principal + Table1.lookupOne(A=$A-1).Interest " +
"if $A > 1 else 1000")
sample = testutil.parse_test_sample({
"SCHEMA": [
[1, "Table1", [
[30, "A", "Numeric", False, "", "", ""],
[31, "Principal", "Numeric", True, formula, "", ""],
[32, "Interest", "Numeric", True, "int($Principal * 0.1)", "", ""],
]]
],
"DATA": {
"Table1": [
["id", "A"],
[1, 1],
[2, 2],
[3, 3],
[4, 4],
[5, 5],
]
}
})
self.load_sample(sample)
self.assertTableData('Table1', data=[
['id', 'A', 'Principal', 'Interest'],
[ 1, 1, 1000.0, 100.0],
[ 2, 2, 1100.0, 110.0],
[ 3, 3, 1210.0, 121.0],
[ 4, 4, 1331.0, 133.0],
[ 5, 5, 1464.0, 146.0],
])
self.update_records('Table1', ['id', 'A'], [
[1, 5], [2, 3], [3, 4], [4, 2], [5, 1]
])
self.assertTableData('Table1', data=[
['id', 'A', 'Principal', 'Interest'],
[ 1, 5, 1464.0, 146.0],
[ 2, 3, 1210.0, 121.0],
[ 3, 4, 1331.0, 133.0],
[ 4, 2, 1100.0, 110.0],
[ 5, 1, 1000.0, 100.0],
])
def test_trivial_cycle(self):
sample = testutil.parse_test_sample({
"SCHEMA": [
[1, "Table1", [
[31, "A", "Numeric", False, "", "", ""],
[31, "B", "Numeric", True, "$B", "", ""],
]]
],
"DATA": {
"Table1": [
["id", "A"],
[1, 1],
[2, 2],
[3, 3],
]
}
})
self.load_sample(sample)
circle = objtypes.RaisedException(depend.CircularRefError())
self.assertTableData('Table1', data=[
['id', 'A', 'B'],
[ 1, 1, circle],
[ 2, 2, circle],
[ 3, 3, circle],
])
def test_cycle(self):
sample = testutil.parse_test_sample({
"SCHEMA": [
[1, "Table1", [
[30, "A", "Numeric", False, "", "", ""],
[31, "Principal", "Numeric", True, "$Interest", "", ""],
[32, "Interest", "Numeric", True, "$Principal", "", ""],
[33, "A2", "Numeric", True, "$A", "", ""],
]]
],
"DATA": {
"Table1": [
["id", "A"],
[1, 1],
[2, 2],
[3, 3],
]
}
})
self.load_sample(sample)
circle = objtypes.RaisedException(depend.CircularRefError())
self.assertTableData('Table1', data=[
['id', 'A', 'Principal', 'Interest', 'A2'],
[ 1, 1, circle, circle, 1],
[ 2, 2, circle, circle, 2],
[ 3, 3, circle, circle, 3],
])
def test_cycle_and_copy(self):
sample = testutil.parse_test_sample({
"SCHEMA": [
[1, "Table1", [
[31, "A", "Numeric", False, "", "", ""],
[31, "B", "Numeric", True, "$C", "", ""],
[32, "C", "Numeric", True, "$C", "", ""],
]]
],
"DATA": {
"Table1": [
["id", "A"],
[1, 1],
[2, 2],
[3, 3],
]
}
})
self.load_sample(sample)
circle = objtypes.RaisedException(depend.CircularRefError())
self.assertTableData('Table1', data=[
['id', 'A', 'B', 'C'],
[ 1, 1, circle, circle],
[ 2, 2, circle, circle],
[ 3, 3, circle, circle],
])
def test_cycle_and_reference(self):
sample = testutil.parse_test_sample({
"SCHEMA": [
[2, "ATable", [
[32, "A", "Ref:ZTable", False, "", "", ""],
[33, "B", "Numeric", True, "$A.B", "", ""],
]],
[1, "ZTable", [
[31, "A", "Numeric", False, "", "", ""],
[31, "B", "Numeric", True, "$B", "", ""],
]],
],
"DATA": {
"ATable": [
["id", "A"],
[1, 1],
[2, 2],
[3, 3],
],
"ZTable": [
["id", "A"],
[1, 6],
[2, 7],
[3, 8],
]
}
})
self.load_sample(sample)
circle = objtypes.RaisedException(depend.CircularRefError())
self.assertTableData('ATable', data=[
['id', 'A', 'B'],
[ 1, 1, circle],
[ 2, 2, circle],
[ 3, 3, circle],
])
self.assertTableData('ZTable', data=[
['id', 'A', 'B'],
[ 1, 6, circle],
[ 2, 7, circle],
[ 3, 8, circle],
])
def test_cumulative_efficiency(self):
# Make sure cumulative formula evaluation doesn't fall over after more than a few rows.
top = 250
# Compute compound interest in ascending order of A
formula = ("Table1.lookupOne(A=$A-1).Principal + Table1.lookupOne(A=$A-1).Interest " +
"if $A > 1 else 1000")
# Compute compound interest in descending order of A
rformula = ("Table1.lookupOne(A=$A+1).RPrincipal + Table1.lookupOne(A=$A+1).RInterest " +
"if $A < %d else 1000" % top)
rows = [["id", "A"]]
for i in range(1, top + 1):
rows.append([i, i])
sample = testutil.parse_test_sample({
"SCHEMA": [
[1, "Table1", [
[30, "A", "Numeric", False, "", "", ""],
[31, "Principal", "Numeric", True, formula, "", ""],
[32, "Interest", "Numeric", True, "int($Principal * 0.1)", "", ""],
[33, "RPrincipal", "Numeric", True, rformula, "", ""],
[34, "RInterest", "Numeric", True, "int($RPrincipal * 0.1)", "", ""],
[35, "Total", "Numeric", True, "$Principal + $RPrincipal", "", ""],
]],
[2, "Readout", [
[36, "LastPrincipal", "Numeric", True, "Table1.lookupOne(A=%d).Principal" % top, "", ""],
[37, "LastRPrincipal", "Numeric", True, "Table1.lookupOne(A=1).RPrincipal", "", ""],
[38, "FirstTotal", "Numeric", True, "Table1.lookupOne(A=1).Total", "", ""],
[39, "LastTotal", "Numeric", True, "Table1.lookupOne(A=%d).Total" % top, "", ""],
]]
],
"DATA": {
"Table1": rows,
"Readout": [["id"], [1]],
}
})
self.load_sample(sample)
principal = 20213227788876
self.assertTableData('Readout', data=[
['id', 'LastPrincipal', 'LastRPrincipal', 'FirstTotal', 'LastTotal'],
[1, principal, principal, principal + 1000, principal + 1000],
])
def test_cumulative_formula_with_references(self):
top = 100
formula = "max($Prev.Principal + $Prev.Interest, 1000)"
sample = testutil.parse_test_sample({
"SCHEMA": [
[1, "Table1", [
[41, "Prev", "Ref:Table1", True, "$id - 1", "", ""],
[42, "Principal", "Numeric", True, formula, "", ""],
[43, "Interest", "Numeric", True, "int($Principal * 0.1)", "", ""],
]],
[2, "Readout", [
[46, "LastPrincipal", "Numeric", True, "Table1.lookupOne(id=%d).Principal" % top, "", ""],
]]
],
"DATA": {
"Table1": [["id"]] + [[r] for r in range(1, top + 1)],
"Readout": [["id"], [1]],
}
})
self.load_sample(sample)
self.assertTableData('Readout', data=[
['id', 'LastPrincipal'],
[1, 12494908.0],
])
self.modify_column("Table1", "Prev", formula="$id - 1 if $id > 1 else 100")
self.assertTableData('Readout', data=[
['id', 'LastPrincipal'],
[1, objtypes.RaisedException(depend.CircularRefError())],
])
def test_catch_all_in_formula(self):
sample = testutil.parse_test_sample({
"SCHEMA": [
[1, "Table1", [
[51, "A", "Numeric", False, "", "", ""],
[52, "B1", "Numeric", True, "try:\n return $A+$C\nexcept:\n return 42", "", ""],
[53, "B2", "Numeric", True, "try:\n return $D+None\nexcept:\n return 42", "", ""],
[54, "B3", "Numeric", True, "try:\n return $A+$B4+$D\nexcept:\n return 42", "", ""],
[55, "B4", "Numeric", True, "try:\n return $A+$B3+$D\nexcept:\n return 42", "", ""],
[56, "B5", "Numeric", True,
"try:\n return $E+1\nexcept:\n raise Exception('monkeys!')", "", ""],
[56, "B6", "Numeric", True,
"try:\n return $F+1\nexcept Exception as e:\n e.node = e.row_id = 'monkey'", "", ""],
[57, "C", "Numeric", False, "", "", ""],
[58, "D", "Numeric", True, "$A", "", ""],
[59, "E", "Numeric", True, "$A", "", ""],
[59, "F", "Numeric", True, "$A", "", ""],
]],
],
"DATA": {
"Table1": [["id", "A", "C"], [1, 1, 2], [2, 20, 10]],
}
})
self.load_sample(sample)
circle = objtypes.RaisedException(depend.CircularRefError())
# B4 is a subtle case. B3 and B4 refer to each other. B3 is recomputed first,
# and cells evaluate to a CircularRefError. Now B3 has a value, so B4 can be
# evaluated, and results in 42 when addition of an integer and an exception value
# fails.
self.assertTableData('Table1', data=[
['id', 'A', 'B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'C', 'D', 'E', 'F'],
[1, 1, 3, 42, circle, 42, 2, 2, 2, 1, 1, 1],
[2, 20, 30, 42, circle, 42, 21, 21, 10, 20, 20, 20],
])
def test_reference_column(self):
# There was a bug where self-references could result in a column being prematurely
# considered complete.
sample = testutil.parse_test_sample({
"SCHEMA": [
[1, "Table1", [
[40, "Ident", "Text", False, "", "", ""],
[41, "Prev", "Ref:Table1", False, "", "", ""],
[42, "Calc", "Numeric", True, "$Prev.Calc * 1.5 if $Prev else 1", "", ""]
]]],
"DATA": {
"Table1": [
['id', 'Ident', 'Prev'],
[1, 'a', 0],
[2, 'b', 1],
[3, 'c', 4],
[4, 'd', 0],
]
}
})
self.load_sample(sample)
self.assertTableData('Table1', data=[
['id', 'Ident', 'Prev', 'Calc'],
[1, 'a', 0, 1.0],
[2, 'b', 1, 1.5],
[3, 'c', 4, 1.5],
[4, 'd', 0, 1.0]
])
def test_loop(self):
sample = testutil.parse_test_sample({
"SCHEMA": [
[1, "Table1", [
[31, "A", "Numeric", False, "", "", ""],
[31, "B", "Numeric", True, "$C", "", ""],
[32, "C", "Numeric", True, "$B", "", ""],
]]
],
"DATA": {
"Table1": [
["id", "A"],
[1, 1],
[2, 2],
[3, 3],
]
}
})
self.load_sample(sample)
circle = objtypes.RaisedException(depend.CircularRefError())
self.assertTableData('Table1', data=[
['id', 'A', 'B', 'C'],
[ 1, 1, circle, circle],
[ 2, 2, circle, circle],
[ 3, 3, circle, circle],
])

@ -0,0 +1,29 @@
import doctest
import functions
import moment
_old_date_get_global_tz = None
def date_setUp(doc_test):
# pylint: disable=unused-argument
global _old_date_get_global_tz # pylint: disable=global-statement
_old_date_get_global_tz = functions.date._get_global_tz
functions.date._get_global_tz = lambda: moment.tzinfo('America/New_York')
def date_tearDown(doc_test):
# pylint: disable=unused-argument
functions.date._get_global_tz = _old_date_get_global_tz
# This works with the unittest module to turn all the doctests in the functions' doc-comments into
# unittest test cases.
def load_tests(loader, tests, ignore):
tests.addTests(doctest.DocTestSuite(functions.date, setUp = date_setUp, tearDown = date_tearDown))
tests.addTests(doctest.DocTestSuite(functions.info, setUp = date_setUp, tearDown = date_tearDown))
tests.addTests(doctest.DocTestSuite(functions.logical))
tests.addTests(doctest.DocTestSuite(functions.math))
tests.addTests(doctest.DocTestSuite(functions.stats))
tests.addTests(doctest.DocTestSuite(functions.text))
tests.addTests(doctest.DocTestSuite(functions.schedule,
setUp = date_setUp, tearDown = date_tearDown))
return tests

@ -0,0 +1,204 @@
import unittest
import difflib
import re
import gencode
import identifiers
import records
import schema
import table
import testutil
schema_data = [
[1, "Students", [
[1, "firstName", "Text", False, '', "firstName", ''],
[2, "lastName", "Text", False, '', "lastName", ''],
[3, "fullName", "Any", True,
"rec.firstName + ' ' + rec.lastName", "fullName", ''],
[4, "fullNameLen", "Any", True, "len(rec.fullName)", "fullNameLen", ''],
[5, "school", "Ref:Schools", False, '', "school", ''],
[6, "schoolShort", "Any", True, "rec.school.name.split(' ')[0]", "schoolShort", ''],
[9, "schoolRegion", "Any", True,
"addr = $school.address\naddr.state if addr.country == 'US' else addr.region",
"schoolRegion", ''],
[8, "school2", "Ref:Schools", True, "Schools.lookupFirst(name=rec.school.name)", "", ""]
]],
[2, "Schools", [
[10, "name", "Text", False, '', "name", ''],
[12, "address", "Ref:Address",False, '', "address", '']
]],
[3, "Address", [
[21, "city", "Text", False, '', "city", ''],
[27, "state", "Text", False, '', "state", ''],
[28, "country", "Text", False, "'US'", "country", ''],
[29, "region", "Any", True,
"{'US': 'North America', 'UK': 'Europe'}.get(rec.country, 'N/A')", "region", ''],
[30, "badSyntax", "Any", True, "for a in b\n10", "", ""],
]]
]
class TestGenCode(unittest.TestCase):
def setUp(self):
# Convert the meta tables to appropriate table representations for loading.
meta_tables = testutil.table_data_from_rows(
'_grist_Tables',
("id", "tableId"),
[(table_row_id, table_id) for (table_row_id, table_id, _) in schema_data])
meta_columns = testutil.table_data_from_rows(
'_grist_Tables_column',
("parentId", "parentPos", "id", "colId", "type",
"isFormula", "formula", "label", "widgetOptions"),
[[table_row_id, i] + e for (table_row_id, _, entries) in schema_data
for (i, e) in enumerate(entries)])
self.schema = schema.build_schema(meta_tables, meta_columns, include_builtin=False)
def test_make_module_text(self):
"""
Test that make_module_text produces the exact sample output that we have stored
in the docstring of usercode.py.
"""
import usercode
usercode_sample_re = re.compile(r'^==========*\n', re.M)
saved_sample = usercode_sample_re.split(usercode.__doc__)[1]
gcode = gencode.GenCode()
gcode.make_module(self.schema)
generated = gcode.get_user_text()
self.assertEqual(generated, saved_sample, "Generated code doesn't match sample:\n" +
"".join(difflib.unified_diff(generated.splitlines(True),
saved_sample.splitlines(True),
fromfile="generated",
tofile="usercode.py")))
def test_make_module(self):
"""
Test that the generated module has the classes and nested classes we expect.
"""
gcode = gencode.GenCode()
gcode.make_module(self.schema)
module = gcode.usercode
self.assertTrue(isinstance(module.Students, table.UserTable))
self.assertTrue(issubclass(module.Students.Record, records.Record))
self.assertTrue(issubclass(module.Students.RecordSet, records.RecordSet))
self.assertIs(module.Students.RecordSet.Record, module.Students.Record)
def test_pick_col_ident(self):
self.assertEqual(identifiers.pick_col_ident("asdf"), "asdf")
self.assertEqual(identifiers.pick_col_ident(" a s==d!~@#$%^f"), "a_s_d_f")
self.assertEqual(identifiers.pick_col_ident("123asdf"), "c123asdf")
self.assertEqual(identifiers.pick_col_ident("!@#"), "A")
self.assertEqual(identifiers.pick_col_ident("!@#1"), "c1")
self.assertEqual(identifiers.pick_col_ident("heLLO world"), "heLLO_world")
self.assertEqual(identifiers.pick_col_ident("!@#", avoid={"A"}), "B")
self.assertEqual(identifiers.pick_col_ident("foo", avoid={"bar"}), "foo")
self.assertEqual(identifiers.pick_col_ident("foo", avoid={"foo"}), "foo2")
self.assertEqual(identifiers.pick_col_ident("foo", avoid={"foo", "foo2", "foo3"}), "foo4")
self.assertEqual(identifiers.pick_col_ident("foo1", avoid={"foo1", "foo2", "foo1_2"}), "foo1_3")
self.assertEqual(identifiers.pick_col_ident(""), "A")
self.assertEqual(identifiers.pick_table_ident(""), "Table1")
self.assertEqual(identifiers.pick_col_ident("", avoid={"A"}), "B")
self.assertEqual(identifiers.pick_col_ident("", avoid={"A","B"}), "C")
self.assertEqual(identifiers.pick_col_ident(None, avoid={"A","B"}), "C")
self.assertEqual(identifiers.pick_col_ident("", avoid={'a','b','c','d','E'}), 'F')
self.assertEqual(identifiers.pick_col_ident(2, avoid={"c2"}), "c2_2")
large_set = set()
for i in xrange(730):
large_set.add(identifiers._gen_ident(large_set))
self.assertEqual(identifiers.pick_col_ident("", avoid=large_set), "ABC")
def test_pick_table_ident(self):
self.assertEqual(identifiers.pick_table_ident("123asdf"), "T123asdf")
self.assertEqual(identifiers.pick_table_ident("!@#"), "Table1")
self.assertEqual(identifiers.pick_table_ident("!@#1"), "T1")
self.assertEqual(identifiers.pick_table_ident("heLLO world"), "HeLLO_world")
self.assertEqual(identifiers.pick_table_ident("foo", avoid={"Foo"}), "Foo2")
self.assertEqual(identifiers.pick_table_ident("foo", avoid={"Foo", "Foo2"}), "Foo3")
self.assertEqual(identifiers.pick_table_ident("FOO", avoid={"foo", "foo2"}), "FOO3")
self.assertEqual(identifiers.pick_table_ident(None, avoid={"Table"}), "Table1")
self.assertEqual(identifiers.pick_table_ident(None, avoid={"Table1"}), "Table2")
self.assertEqual(identifiers.pick_table_ident("!@#", avoid={"Table1"}), "Table2")
self.assertEqual(identifiers.pick_table_ident(None, avoid={"Table1", "Table2"}), "Table3")
large_set = set()
for i in xrange(730):
large_set.add("Table%d" % i)
self.assertEqual(identifiers.pick_table_ident("", avoid=large_set), "Table730")
def test_pick_col_ident_list(self):
self.assertEqual(identifiers.pick_col_ident_list(["foo", "bar"], avoid={"bar"}),
["foo", "bar2"])
self.assertEqual(identifiers.pick_col_ident_list(["bar", "bar"], avoid={"foo"}),
["bar", "bar2"])
self.assertEqual(identifiers.pick_col_ident_list(["bar", "bar"], avoid={"bar"}),
["bar2", "bar3"])
self.assertEqual(identifiers.pick_col_ident_list(["bAr", "BAR"], avoid={"bar"}),
["bAr2", "BAR3"])
def test_gen_ident(self):
self.assertEqual(identifiers._gen_ident(set()), 'A')
self.assertEqual(identifiers._gen_ident({'A'}), 'B')
self.assertEqual(identifiers._gen_ident({'foo','E','F','H'}), 'A')
self.assertEqual(identifiers._gen_ident({'a','b','c','d','E'}), 'F')
def test_get_grist_type(self):
self.assertEqual(gencode.get_grist_type("Ref:Foo"), "grist.Reference('Foo')")
self.assertEqual(gencode.get_grist_type("RefList:Foo"), "grist.ReferenceList('Foo')")
self.assertEqual(gencode.get_grist_type("Int"), "grist.Int()")
self.assertEqual(gencode.get_grist_type("DateTime:America/NewYork"),
"grist.DateTime('America/NewYork')")
self.assertEqual(gencode.get_grist_type("DateTime:"), "grist.DateTime()")
self.assertEqual(gencode.get_grist_type("DateTime"), "grist.DateTime()")
self.assertEqual(gencode.get_grist_type("DateTime: foo bar "), "grist.DateTime('foo bar')")
self.assertEqual(gencode.get_grist_type("DateTime: "), "grist.DateTime()")
self.assertEqual(gencode.get_grist_type("RefList:\n ~!@#$%^&*'\":;,\t"),
"grist.ReferenceList('~!@#$%^&*\\'\":;,')")
def test_grist_names(self):
# Verifies that we can correctly extract the names of Grist objects that occur in formulas.
# This is used by automatic formula adjustments when columns or tables get renamed.
gcode = gencode.GenCode()
gcode.make_module(self.schema)
# The output of grist_names is described in codebuilder.py, and copied here:
# col_info: (table_id, col_id) for the formula the name is found in. It is the value passed
# in by gencode.py to codebuilder.make_formula_body().
# start_pos: Index of the start character of the name in the text of the formula.
# table_id: Parsed name when the tuple is for a table name; the name of the column's table
# when the tuple is for a column name.
# col_id: None when tuple is for a table name; col_id when the tuple is for a column name.
expected_names = [
(('Students', 'fullName'), 4, 'Students', 'firstName'),
(('Students', 'fullName'), 26, 'Students', 'lastName'),
(('Students', 'fullNameLen'), 8, 'Students', 'fullName'),
(('Students', 'schoolShort'), 11, 'Schools', 'name'),
(('Students', 'schoolShort'), 4, 'Students', 'school'),
(('Students', 'schoolRegion'), 15, 'Schools', 'address'),
(('Students', 'schoolRegion'), 8, 'Students', 'school'),
(('Students', 'schoolRegion'), 42, 'Address', 'country'),
(('Students', 'schoolRegion'), 28, 'Address', 'state'),
(('Students', 'schoolRegion'), 68, 'Address', 'region'),
(('Students', 'school2'), 0, 'Schools', None),
(('Students', 'school2'), 36, 'Schools', 'name'),
(('Students', 'school2'), 29, 'Students', 'school'),
(('Address', 'region'), 48, 'Address', 'country'),
]
self.assertEqual(gcode.grist_names(), expected_names)
# Test the case of a bare-word function with a keyword argument appearing in a formula. This
# case had a bug with code parsing.
self.schema['Address'].columns['testcol'] = schema.SchemaColumn(
'testcol', 'Any', True, 'foo(bar=$region) or max(Students.all, key=lambda n: -n)')
gcode.make_module(self.schema)
self.assertEqual(gcode.grist_names(), expected_names + [
(('Address', 'testcol'), 9, 'Address', 'region'),
(('Address', 'testcol'), 24, 'Students', None),
])
if __name__ == "__main__":
unittest.main()

@ -0,0 +1,159 @@
import unittest
import gpath
class TestGpath(unittest.TestCase):
def setUp(self):
self.obj = {
"foo": [{"bar": 1}, {"bar": 2}, {"baz": 3}],
"hello": "world"
}
def test_get(self):
self.assertEqual(gpath.get(self.obj, ["foo", 0, "bar"]), 1)
self.assertEqual(gpath.get(self.obj, ["foo", 2]), {"baz": 3})
self.assertEqual(gpath.get(self.obj, ["hello"]), "world")
self.assertEqual(gpath.get(self.obj, []), self.obj)
self.assertEqual(gpath.get(self.obj, ["foo", 0, "baz"]), None)
self.assertEqual(gpath.get(self.obj, ["foo", 4]), None)
self.assertEqual(gpath.get(self.obj, ["foo", 4, "baz"]), None)
self.assertEqual(gpath.get(self.obj, [0]), None)
def test_set(self):
gpath.place(self.obj, ["foo"], {"bar": 1, "baz": 2})
self.assertEqual(self.obj["foo"], {"bar": 1, "baz": 2})
gpath.place(self.obj, ["foo", "bar"], 17)
self.assertEqual(self.obj["foo"], {"bar": 17, "baz": 2})
gpath.place(self.obj, ["foo", "baz"], None)
self.assertEqual(self.obj["foo"], {"bar": 17})
self.assertEqual(self.obj["hello"], "world")
gpath.place(self.obj, ["hello"], None)
self.assertFalse("hello" in self.obj)
gpath.place(self.obj, ["hello"], None) # OK to remove a non-existent property.
self.assertFalse("hello" in self.obj)
gpath.place(self.obj, ["hello"], "blah")
self.assertEqual(self.obj["hello"], "blah")
def test_set_strict(self):
with self.assertRaisesRegexp(Exception, r"non-existent"):
gpath.place(self.obj, ["bar", 4], 17)
with self.assertRaisesRegexp(Exception, r"not a plain object"):
gpath.place(self.obj, ["foo", 0], 17)
def test_insert(self):
self.assertEqual(self.obj["foo"], [{"bar": 1}, {"bar": 2}, {"baz": 3}])
gpath.insert(self.obj, ["foo", 0], "asdf")
self.assertEqual(self.obj["foo"], ["asdf", {"bar": 1}, {"bar": 2}, {"baz": 3}])
gpath.insert(self.obj, ["foo", 3], "hello")
self.assertEqual(self.obj["foo"], ["asdf", {"bar": 1}, {"bar": 2}, "hello", {"baz": 3}])
gpath.insert(self.obj, ["foo", None], "world")
self.assertEqual(self.obj["foo"],
["asdf", {"bar": 1}, {"bar": 2}, "hello", {"baz": 3}, "world"])
def test_insert_strict(self):
with self.assertRaisesRegexp(Exception, r'not an array'):
gpath.insert(self.obj, ["foo"], "asdf")
with self.assertRaisesRegexp(Exception, r'invalid.*index'):
gpath.insert(self.obj, ["foo", -1], 17)
with self.assertRaisesRegexp(Exception, r'invalid.*index'):
gpath.insert(self.obj, ["foo", "foo"], 17)
def test_update(self):
"""update should update array items"""
self.assertEqual(self.obj["foo"], [{"bar": 1}, {"bar": 2}, {"baz": 3}])
gpath.update(self.obj, ["foo", 0], "asdf")
self.assertEqual(self.obj["foo"], ["asdf", {"bar": 2}, {"baz": 3}])
gpath.update(self.obj, ["foo", 2], "hello")
self.assertEqual(self.obj["foo"], ["asdf", {"bar": 2}, "hello"])
gpath.update(self.obj, ["foo", 1], None)
self.assertEqual(self.obj["foo"], ["asdf", None, "hello"])
def test_update_strict(self):
"""update should be strict"""
with self.assertRaisesRegexp(Exception, r'non-existent'):
gpath.update(self.obj, ["bar", 4], 17)
with self.assertRaisesRegexp(Exception, r'not an array'):
gpath.update(self.obj, ["foo"], 17)
with self.assertRaisesRegexp(Exception, r'invalid.*index'):
gpath.update(self.obj, ["foo", -1], 17)
with self.assertRaisesRegexp(Exception, r'invalid.*index'):
gpath.update(self.obj, ["foo", None], 17)
def test_remove(self):
"""remove should remove indices"""
self.assertEqual(self.obj["foo"], [{"bar": 1}, {"bar": 2}, {"baz": 3}])
gpath.remove(self.obj, ["foo", 0])
self.assertEqual(self.obj["foo"], [{"bar": 2}, {"baz": 3}])
gpath.remove(self.obj, ["foo", 1])
self.assertEqual(self.obj["foo"], [{"bar": 2}])
gpath.remove(self.obj, ["foo", 0])
self.assertEqual(self.obj["foo"], [])
def test_remove_strict(self):
"""remove should be strict"""
with self.assertRaisesRegexp(Exception, r'non-existent'):
gpath.remove(self.obj, ["bar", 4])
with self.assertRaisesRegexp(Exception, r'not an array'):
gpath.remove(self.obj, ["foo"])
with self.assertRaisesRegexp(Exception, r'invalid.*index'):
gpath.remove(self.obj, ["foo", -1])
with self.assertRaisesRegexp(Exception, r'invalid.*index'):
gpath.remove(self.obj, ["foo", None])
def test_glob(self):
"""glob should scan arrays"""
self.assertEqual(self.obj["foo"], [{"bar": 1}, {"bar": 2}, {"baz": 3}])
self.assertEqual(gpath.place(self.obj, ["foo", "*", "bar"], 17), 3)
self.assertEqual(self.obj["foo"], [{"bar": 17}, {"bar": 17}, {"baz": 3, "bar": 17}])
with self.assertRaisesRegexp(Exception, r'non-existent object at \/foo\/\*\/bad'):
gpath.place(self.obj, ["foo", "*", "bad", "test"], 10)
self.assertEqual(gpath.update(self.obj, ["foo", "*"], "hello"), 3)
self.assertEqual(self.obj["foo"], ["hello", "hello", "hello"])
def test_glob_strict_wildcard(self):
"""should only support tail wildcard for updates"""
with self.assertRaisesRegexp(Exception, r'invalid array index'):
gpath.remove(self.obj, ["foo", "*"])
with self.assertRaisesRegexp(Exception, r'invalid array index'):
gpath.insert(self.obj, ["foo", "*"], 1)
def test_glob_wildcard_keys(self):
"""should not scan object keys"""
self.assertEqual(self.obj["foo"], [{"bar": 1}, {"bar": 2}, {"baz": 3}])
self.assertEqual(gpath.place(self.obj, ["foo", 0, "*"], 17), 1)
self.assertEqual(self.obj["foo"], [{"bar": 1, '*': 17}, {"bar": 2}, {"baz": 3}])
with self.assertRaisesRegexp(Exception, r'non-existent'):
gpath.place(self.obj, ["*", 0, "bar"], 17)
def test_glob_nested(self):
"""should scan nested arrays"""
self.obj = [{"a": [1,2,3]}, {"a": [4,5,6]}, {"a": [7,8,9]}]
self.assertEqual(gpath.update(self.obj, ["*", "a", "*"], 5), 9)
self.assertEqual(self.obj, [{"a": [5,5,5]}, {"a": [5,5,5]}, {"a": [5,5,5]}])
def test_dirname(self):
"""dirname should return path without last component"""
self.assertEqual(gpath.dirname(["foo", "bar", "baz"]), ["foo", "bar"])
self.assertEqual(gpath.dirname([1, 2]), [1])
self.assertEqual(gpath.dirname(["foo"]), [])
self.assertEqual(gpath.dirname([]), [])
def test_basename(self):
"""basename should return the last component of path"""
self.assertEqual(gpath.basename(["foo", "bar", "baz"]), "baz")
self.assertEqual(gpath.basename([1, 2]), 2)
self.assertEqual(gpath.basename(["foo"]), "foo")
self.assertEqual(gpath.basename([]), None)
if __name__ == "__main__":
unittest.main()

@ -0,0 +1,150 @@
# pylint: disable=line-too-long
import logger
import test_engine
log = logger.Logger(__name__, logger.INFO)
class TestImportActions(test_engine.EngineTestCase):
def init_state(self):
# Add source table
self.apply_user_action(['AddTable', 'Source', [{'id': 'Name', 'type': 'Text'},
{'id': 'City', 'type': 'Text'},
{'id': 'Zip', 'type': 'Int'}]])
self.apply_user_action(['BulkAddRecord', 'Source', [1, 2], {'Name': ['John', 'Alison'],
'City': ['New York', 'Boston'],
'Zip': [03011, 07003]}])
self.assertTableData('_grist_Tables_column', cols="subset", data=[
["id", "colId", "type", "isFormula", "formula"],
[1, "manualSort", "ManualSortPos", False, ""],
[2, "Name", "Text", False, ""],
[3, "City", "Text", False, ""],
[4, "Zip", "Int", False, ""],
], rows=lambda r: r.parentId.id == 1)
# Add destination table which contains columns corresponding to source table
self.apply_user_action(['AddTable', 'Destination1', [{'id': 'Name', 'type': 'Text'},
{'id': 'City', 'type': 'Text'}]])
self.apply_user_action(['BulkAddRecord', 'Destination1', [1, 2], {'Name': ['Bob'],
'City': ['New York']}])
self.assertTableData('_grist_Tables_column', cols="subset", data=[
["id", "colId", "type", "isFormula", "formula"],
[5, "manualSort", "ManualSortPos", False, ""],
[6, "Name", "Text", False, ""],
[7, "City", "Text", False, ""],
], rows=lambda r: r.parentId.id == 2)
# Add destination table which has no columns corresponding to source table
self.apply_user_action(['AddTable', 'Destination2', [{'id': 'State', 'type': 'Text'}]])
self.apply_user_action(['BulkAddRecord', 'Destination2', [1, 2], {'State': ['NY']}])
self.assertTableData('_grist_Tables_column', cols="subset", data=[
["id", "colId", "type", "isFormula", "formula"],
[8, "manualSort", "ManualSortPos", False, ""],
[9, "State", "Text", False, ""]
], rows=lambda r: r.parentId.id == 3)
# Verify created tables
self.assertPartialData("_grist_Tables", ["id", "tableId"], [
[1, "Source"],
[2, "Destination1"],
[3, "Destination2"],
])
# Verify created sections
self.assertPartialData("_grist_Views_section", ["id", "tableRef", 'fields'], [
[1, 1, [1, 2, 3]], # section for "Source" table
[2, 2, [4, 5]], # section for "Destination1" table
[3, 3, [6]] # section for "Destination2" table
])
def test_transform(self):
# Add source and destination tables
self.init_state()
# Update transform while importing to destination table which have
# columns with the same names as source
self.apply_user_action(['GenImporterView', 'Source', 'Destination1', None])
# Verify the new structure of source table and sections
# (two columns with special names were added)
self.assertTableData('_grist_Tables_column', cols="subset", data=[
["id", "colId", "type", "isFormula", "formula"],
[1, "manualSort", "ManualSortPos", False, ""],
[2, "Name", "Text", False, ""],
[3, "City", "Text", False, ""],
[4, "Zip", "Int", False, ""],
[10, "gristHelper_Import_Name", "Text", True, "$Name"],
[11, "gristHelper_Import_City", "Text", True, "$City"],
], rows=lambda r: r.parentId.id == 1)
self.assertTableData('Source', cols="all", data=[
["id", "Name", "City", "Zip", "gristHelper_Import_Name", "gristHelper_Import_City", "manualSort"],
[1, "John", "New York", 03011, "John", "New York", 1.0],
[2, "Alison", "Boston", 07003, "Alison", "Boston", 2.0],
])
self.assertPartialData("_grist_Views_section", ["id", "tableRef", 'fields'], [
[1, 1, [1, 2, 3]],
[2, 2, [4, 5]],
[3, 3, [6]],
[4, 1, [7, 8]] # new section for transform preview
])
# Apply useraction again to verify that old columns and sections are removing
# Update transform while importing to destination table which has no common columns with source
self.apply_user_action(['GenImporterView', 'Source', 'Destination2', None])
# Verify the new structure of source table and sections (old special columns were removed
# and one new columns with empty formula were added)
self.assertTableData('_grist_Tables_column', cols="subset", data=[
["id", "colId", "type", "isFormula", "formula"],
[1, "manualSort", "ManualSortPos", False, ""],
[2, "Name", "Text", False, ""],
[3, "City", "Text", False, ""],
[4, "Zip", "Int", False, ""],
[10, "gristHelper_Import_State", "Text", True, ""]
], rows=lambda r: r.parentId.id == 1)
self.assertTableData('Source', cols="all", data=[
["id", "Name", "City", "Zip", "gristHelper_Import_State", "manualSort"],
[1, "John", "New York", 03011, "", 1.0],
[2, "Alison", "Boston", 07003, "", 2.0],
])
self.assertPartialData("_grist_Views_section", ["id", "tableRef", 'fields'], [
[1, 1, [1, 2, 3]],
[2, 2, [4, 5]],
[3, 3, [6]],
[4, 1, [7]] # new section for transform preview
])
def test_transform_destination_new_table(self):
# Add source and destination tables
self.init_state()
# Update transform while importing to destination table which is "New Table"
self.apply_user_action(['GenImporterView', 'Source', None, None])
# Verify the new structure of source table and sections (old special columns were removed
# and three new columns, which are the same as in source table were added)
self.assertTableData('_grist_Tables_column', cols="subset", data=[
["id", "colId", "type", "isFormula", "formula"],
[1, "manualSort", "ManualSortPos", False, ""],
[2, "Name", "Text", False, ""],
[3, "City", "Text", False, ""],
[4, "Zip", "Int", False, ""],
[10, "gristHelper_Import_Name", "Text", True, "$Name"],
[11, "gristHelper_Import_City", "Text", True, "$City"],
[12, "gristHelper_Import_Zip", "Int", True, "$Zip"],
], rows=lambda r: r.parentId.id == 1)
self.assertTableData('Source', cols="all", data=[
["id", "Name", "City", "Zip", "gristHelper_Import_Name", "gristHelper_Import_City", "gristHelper_Import_Zip", "manualSort"],
[1, "John", "New York", 03011, "John", "New York", 03011, 1.0],
[2, "Alison", "Boston", 07003, "Alison", "Boston", 07003, 2.0],
])
self.assertPartialData("_grist_Views_section", ["id", "tableRef", 'fields'], [
[1, 1, [1, 2, 3]],
[2, 2, [4, 5]],
[3, 3, [6]],
[4, 1, [7, 8, 9]], # new section for transform preview
])

@ -0,0 +1,174 @@
# pylint: disable=line-too-long
import logger
import test_engine
log = logger.Logger(__name__, logger.INFO)
#TODO: test naming (basics done, maybe check numbered column renaming)
#TODO: check autoimport into existing table (match up column names)
class TestImportTransform(test_engine.EngineTestCase):
def init_state(self):
# Add source table
self.apply_user_action(['AddTable', 'Hidden_table', [
{'id': 'fname', 'type': 'Text'},
{'id': 'mname', 'type': 'Text'},
{'id': 'lname', 'type': 'Text'},
]])
self.apply_user_action(['BulkAddRecord', 'Hidden_table', [1, 2], {'fname': ['Carry', 'Don'],
'mname': ['M.', 'B.'],
'lname': ['Jonson', "Yoon"]
}])
self.assertTableData('_grist_Tables_column', cols="subset", data=[
["id", "colId", "type", "isFormula", "formula"],
[1, "manualSort", "ManualSortPos", False, ""],
[2, "fname", "Text", False, ""],
[3, "mname", "Text", False, ""],
[4, "lname", "Text", False, ""],
], rows=lambda r: r.parentId.id == 1)
#Filled in colids for existing table
self.TEMP_transform_rule_colids = {
"destCols": [
{ "colId": "First_Name", "label": "First Name",
"type": "Text", "formula": "$fname" },
{ "colId": "Last_Name", "label": "Last Name",
"type": "Text", "formula": "$lname" },
{ "colId": "Middle_Initial", "label": "Middle Initial",
"type": "Text", "formula": "$mname[0]" },
#{ "colId": "Blank", "label": "Blank", //destination1 has no blank column
# "type": "Text", "formula": "" },
]
}
#Then try it with blank in colIds (for new tables)
self.TEMP_transform_rule_no_colids = {
"destCols": [
{ "colId": None, "label": "First Name",
"type": "Text", "formula": "$fname" },
{ "colId": None, "label": "Last Name",
"type": "Text", "formula": "$lname" },
{ "colId": None, "label": "Middle Initial",
"type": "Text", "formula": "$mname[0]" },
{ "colId": None, "label": "Blank",
"type": "Text", "formula": "" },
]
}
# Add destination table which contains columns corresponding to source table with different names
self.apply_user_action(['AddTable', 'Destination1', [
{'label': 'First Name', 'id': 'First_Name', 'type': 'Text'},
{'label': 'Last Name', 'id': 'Last_Name', 'type': 'Text'},
{'label': 'Middle Initial', 'id': 'Middle_Initial', 'type': 'Text'}]])
self.apply_user_action(['BulkAddRecord', 'Destination1', [1], {'First_Name': ['Bob'],
'Last_Name': ['Nike'],
'Middle_Initial': ['F.']}])
self.assertTableData('_grist_Tables_column', cols="subset", data=[
["id", "colId", "type", "isFormula", "formula"],
[5, "manualSort", "ManualSortPos", False, ""],
[6, "First_Name", "Text", False, ""],
[7, "Last_Name", "Text", False, ""],
[8, "Middle_Initial","Text", False, ""],
], rows=lambda r: r.parentId.id == 2)
# Verify created tables
self.assertPartialData("_grist_Tables", ["id", "tableId"], [
[1, "Hidden_table"],
[2, "Destination1"]
])
def test_finish_import_into_new_table(self):
# Add source and destination tables
self.init_state()
#into_new_table = True, transform_rule : no colids (will be generated for new table)
self.apply_user_action(['TransformAndFinishImport', 'Hidden_table', 'NewTable', True, self.TEMP_transform_rule_no_colids])
#1-4 in hidden table, 5-8 in destTable, 9-13 for new table
self.assertTableData('_grist_Tables_column', cols="subset", data=[
["id", "colId", "type", "isFormula", "formula"],
[ 9, "manualSort", "ManualSortPos", False, ""],
[10, "First_Name", "Text", False, ""],
[11, "Last_Name", "Text", False, ""],
[12, "Middle_Initial", "Text", False, ""],
[13, "Blank", "Text", False, ""],
], rows=lambda r: r.parentId.id == 3)
self.assertTableData('NewTable', cols="all", data=[
["id", "First_Name", "Last_Name", "Middle_Initial", "Blank", "manualSort"],
[1, "Carry", "Jonson", "M", "", 1.0],
[2, "Don", "Yoon", "B", "", 2.0]
])
# Verify removed hidden table and add the new one
self.assertPartialData("_grist_Tables", ["id", "tableId"], [
[2, "Destination1"],
[3, "NewTable"]
])
def test_finish_import_into_existing_table(self):
# Add source and destination tables
self.init_state()
#into_new_table false, transform_rule=null
self.apply_user_action(['TransformAndFinishImport', 'Hidden_table', 'Destination1', False, self.TEMP_transform_rule_colids])
#1-4 in hidden table, 5-8 in destTable
self.assertTableData('_grist_Tables_column', cols="subset", data=[
["id", "colId", "type", "isFormula", "formula"],
[5, "manualSort", "ManualSortPos", False, ""],
[6, "First_Name", "Text", False, ""],
[7, "Last_Name", "Text", False, ""],
[8, "Middle_Initial", "Text", False, ""],
], rows=lambda r: r.parentId.id == 2)
self.assertTableData('Destination1', cols="all", data=[
["id", "First_Name", "Last_Name", "Middle_Initial", "manualSort"],
[1, "Bob", "Nike", "F.", 1.0], #F. was there to begin with
[2, "Carry", "Jonson", "M", 2.0], #others imported with $mname[0]
[3, "Don", "Yoon", "B", 3.0],
])
# Verify removed hidden table
self.assertPartialData("_grist_Tables", ["id", "tableId"], [[2, "Destination1"]])
#does the same thing using a blank transform rule
def test_finish_import_into_new_table_blank(self):
# Add source and destination tables
self.init_state()
#into_new_table = True, transform_rule : no colids (will be generated for new table)
self.apply_user_action(['TransformAndFinishImport', 'Hidden_table', 'NewTable', True, None])
#1-4 in src table, 5-8 in hiddentable
self.assertTableData('_grist_Tables_column', cols="subset", data=[
["id", "colId", "type", "isFormula", "formula"],
[9, "manualSort", "ManualSortPos", False, ""],
[10, "fname", "Text", False, ""],
[11, "mname", "Text", False, ""],
[12, "lname", "Text", False, ""],
], rows=lambda r: r.parentId.id == 3)
self.assertTableData('NewTable', cols="all", data=[
["id", "fname", "lname", "mname", "manualSort"],
[1, "Carry", "Jonson", "M.", 1.0],
[2, "Don", "Yoon", "B.", 2.0]
])
# Verify removed hidden table and add the new one
self.assertPartialData("_grist_Tables", ["id", "tableId"], [
[2, "Destination1"],
[3, "NewTable"]
])

@ -0,0 +1,38 @@
import unittest
import logger
class TestLogger(unittest.TestCase):
def _log_handler(self, level, name, msg):
self.messages.append((level, name, msg))
def setUp(self):
self.messages = []
self.orig_handler = logger.set_handler(self._log_handler)
def tearDown(self):
logger.set_handler(self.orig_handler)
def test_logger(self):
log = logger.Logger("foo", logger.INFO)
log.info("Hello Info")
log.debug("Hello Debug")
log.warn("Hello Warn")
self.assertEqual(self.messages, [
(logger.INFO, 'foo', 'Hello Info'),
(logger.WARN, 'foo', 'Hello Warn'),
])
del self.messages[:]
log = logger.Logger("baz", logger.DEBUG)
log.debug("Hello Debug")
log.info("Hello Info")
self.assertEqual(self.messages, [
(logger.DEBUG, 'baz', 'Hello Debug'),
(logger.INFO, 'baz', 'Hello Info'),
])
if __name__ == "__main__":
unittest.main()

@ -0,0 +1,646 @@
import actions
import logger
import testsamples
import testutil
import test_engine
log = logger.Logger(__name__, logger.INFO)
def _bulk_update(table_name, col_names, row_data):
return actions.BulkUpdateRecord(
*testutil.table_data_from_rows(table_name, col_names, row_data))
class TestLookups(test_engine.EngineTestCase):
def test_verify_sample(self):
self.load_sample(testsamples.sample_students)
self.assertPartialData("Students", ["id", "schoolIds", "schoolCities" ], [
[1, "1:2", "New York:Colombia" ],
[2, "3:4", "New Haven:West Haven" ],
[3, "1:2", "New York:Colombia" ],
[4, "3:4", "New Haven:West Haven" ],
[5, "", ""],
[6, "3:4", "New Haven:West Haven" ]
])
#----------------------------------------
def test_lookup_dependencies(self, pre_loaded=False):
"""
Test changes to records accessed via lookup.
"""
if not pre_loaded:
self.load_sample(testsamples.sample_students)
out_actions = self.update_record("Address", 14, city="Bedford")
self.assertPartialOutActions(out_actions, {
"calc": [_bulk_update("Students", ["id", "schoolCities" ], [
[2, "New Haven:Bedford" ],
[4, "New Haven:Bedford" ],
[6, "New Haven:Bedford" ]]
)],
"calls": {"Students": {"schoolCities": 3}}
})
out_actions = self.update_record("Schools", 4, address=13)
self.assertPartialOutActions(out_actions, {
"calc": [_bulk_update("Students", ["id", "schoolCities" ], [
[2, "New Haven:New Haven" ],
[4, "New Haven:New Haven" ],
[6, "New Haven:New Haven" ]]
)],
"calls": {"Students": {"schoolCities": 3}}
})
out_actions = self.update_record("Address", 14, city="Hartford")
# No schoolCities need to be recalculatd here, since nothing depends on Address 14 any more.
self.assertPartialOutActions(out_actions, {
"calls": {}
})
# Confirm the final result.
self.assertPartialData("Students", ["id", "schoolIds", "schoolCities" ], [
[1, "1:2", "New York:Colombia" ],
[2, "3:4", "New Haven:New Haven" ],
[3, "1:2", "New York:Colombia" ],
[4, "3:4", "New Haven:New Haven" ],
[5, "", ""],
[6, "3:4", "New Haven:New Haven" ]
])
#----------------------------------------
def test_dependency_reset(self, pre_loaded=False):
"""
A somewhat tricky case. We know that Student 2 depends on Schools 3,4 and on Address 13,14.
If we change Student 2 to depend on nothing, then changing Address 13 should not cause it to
recompute.
"""
if not pre_loaded:
self.load_sample(testsamples.sample_students)
out_actions = self.update_record("Address", 13, city="AAA")
self.assertPartialOutActions(out_actions, {
"calls": {"Students": {"schoolCities": 3}} # Initially 3 students depend on Address 13.
})
out_actions = self.update_record("Students", 2, schoolName="Invalid")
out_actions = self.update_record("Address", 13, city="BBB")
# If the count below is 3, then the engine forgot to reset the dependencies of Students 2.
self.assertPartialOutActions(out_actions, {
"calls": {"Students": {"schoolCities": 2}} # Now only 2 Students depend on Address 13.
})
#----------------------------------------
def test_lookup_key_changes(self, pre_loaded=False):
"""
Test changes to lookup values in the target table. Note that student #3 does not depend on
any records, but depends on the value "Eureka", so gets updated when this value appears.
"""
if not pre_loaded:
self.load_sample(testsamples.sample_students)
out_actions = self.update_record("Schools", 2, name="Eureka")
self.assertPartialOutActions(out_actions, {
"calc": [
actions.BulkUpdateRecord("Students", [1,3,5], {
'schoolCities': ["New York", "New York", "Colombia"]
}),
actions.BulkUpdateRecord("Students", [1,3,5], {
'schoolIds': ["1", "1","2"]
}),
],
"calls": {"Students": { 'schoolCities': 3, 'schoolIds': 3 },
"Schools": {'#lookup#name': 1} },
})
# Test changes to lookup values in the table doing the lookup.
out_actions = self.update_records("Students", ["id", "schoolName"], [
[3, ""],
[5, "Yale"]
])
self.assertPartialOutActions(out_actions, {
"calc": [
actions.BulkUpdateRecord("Students", [3,5], {'schoolCities': ["", "New Haven:West Haven"]}),
actions.BulkUpdateRecord("Students", [3,5], {'schoolIds': ["", "3:4"]}),
],
"calls": { "Students": { 'schoolCities': 2, 'schoolIds': 2 } },
})
# Confirm the final result.
self.assertPartialData("Students", ["id", "schoolIds", "schoolCities" ], [
[1, "1", "New York" ],
[2, "3:4", "New Haven:West Haven" ],
[3, "", "" ],
[4, "3:4", "New Haven:West Haven" ],
[5, "3:4", "New Haven:West Haven" ],
[6, "3:4", "New Haven:West Haven" ]
])
#----------------------------------------
def test_lookup_formula_after_schema_change(self):
self.load_sample(testsamples.sample_students)
self.add_column("Schools", "state", type="Text")
# Make a change that causes recomputation of a lookup formula after a schema change.
# We should NOT get attribute errors in the values.
out_actions = self.update_record("Schools", 4, address=13)
self.assertPartialOutActions(out_actions, {
"calc": [_bulk_update("Students", ["id", "schoolCities" ], [
[2, "New Haven:New Haven" ],
[4, "New Haven:New Haven" ],
[6, "New Haven:New Haven" ]]
)],
"calls": { "Students": { 'schoolCities': 3 } }
})
#----------------------------------------
def test_lookup_formula_changes(self):
self.load_sample(testsamples.sample_students)
self.add_column("Schools", "state", type="Text")
self.update_records("Schools", ["id", "state"], [
[1, "NY"],
[2, "MO"],
[3, "CT"],
[4, "CT"]
])
# Verify that when we change a formula, we get appropriate changes.
out_actions = self.modify_column("Students", "schoolCities", formula=(
"','.join(Schools.lookupRecords(name=$schoolName).state)"))
self.assertPartialOutActions(out_actions, {
"calc": [_bulk_update("Students", ["id", "schoolCities" ], [
[1, "NY,MO" ],
[2, "CT,CT" ],
[3, "NY,MO" ],
[4, "CT,CT" ],
[6, "CT,CT" ]]
)],
# Note that it got computed 6 times (once for each record), but one value remained unchanged
# (because no schools matched).
"calls": { "Students": { 'schoolCities': 6 } }
})
# Check that we've created new dependencies, and removed old ones.
out_actions = self.update_record("Schools", 4, address=13)
self.assertPartialOutActions(out_actions, {
"calls": {}
})
out_actions = self.update_record("Schools", 4, state="MA")
self.assertPartialOutActions(out_actions, {
"calc": [_bulk_update("Students", ["id", "schoolCities" ], [
[2, "CT,MA" ],
[4, "CT,MA" ],
[6, "CT,MA" ]]
)],
"calls": { "Students": { 'schoolCities': 3 } }
})
# If we change to look up uppercase values, we shouldn't find anything.
out_actions = self.modify_column("Students", "schoolCities", formula=(
"','.join(Schools.lookupRecords(name=$schoolName.upper()).state)"))
self.assertPartialOutActions(out_actions, {
"calc": [actions.BulkUpdateRecord("Students", [1,2,3,4,6],
{'schoolCities': ["","","","",""]})],
"calls": { "Students": { 'schoolCities': 6 } }
})
# Changes to dependencies should cause appropriate recalculations.
out_actions = self.update_record("Schools", 4, state="KY", name="EUREKA")
self.assertPartialOutActions(out_actions, {
"calc": [
actions.UpdateRecord("Students", 5, {'schoolCities': "KY"}),
actions.BulkUpdateRecord("Students", [2,4,6], {'schoolIds': ["3","3","3"]}),
],
"calls": {"Students": { 'schoolCities': 1, 'schoolIds': 3 },
'Schools': {'#lookup#name': 1 } }
})
self.assertPartialData("Students", ["id", "schoolIds", "schoolCities" ], [
# schoolCities aren't found here because we changed formula to lookup uppercase names.
[1, "1:2", "" ],
[2, "3", "" ],
[3, "1:2", "" ],
[4, "3", "" ],
[5, "", "KY" ],
[6, "3", "" ]
])
def test_add_remove_lookup(self):
# Verify that when we add or remove a lookup formula, we get appropriate changes.
self.load_sample(testsamples.sample_students)
# Add another lookup formula.
out_actions = self.add_column("Schools", "lastNames", formula=(
"','.join(Students.lookupRecords(schoolName=$name).lastName)"))
self.assertPartialOutActions(out_actions, {
"calc": [_bulk_update("Schools", ["id", "lastNames"], [
[1, "Obama,Clinton"],
[2, "Obama,Clinton"],
[3, "Bush,Bush,Ford"],
[4, "Bush,Bush,Ford"]]
)],
"calls": {"Schools": {"lastNames": 4}, "Students": {"#lookup#schoolName": 6}},
})
# Make sure it responds to changes.
out_actions = self.update_record("Students", 5, schoolName="Columbia")
self.assertPartialOutActions(out_actions, {
"calc": [
_bulk_update("Schools", ["id", "lastNames"], [
[1, "Obama,Clinton,Reagan"],
[2, "Obama,Clinton,Reagan"]]
),
actions.UpdateRecord("Students", 5, {"schoolCities": "New York:Colombia"}),
actions.UpdateRecord("Students", 5, {"schoolIds": "1:2"}),
],
"calls": {"Students": {'schoolCities': 1, 'schoolIds': 1, '#lookup#schoolName': 1},
"Schools": { 'lastNames': 2 }},
})
# Modify the column: in the process, the LookupMapColumn on Students.schoolName becomes unused
# while the old formula column is removed, but used again when it's added. It should not have
# to be rebuilt (so there should be no calls to recalculate the LookupMapColumn.
out_actions = self.modify_column("Schools", "lastNames", formula=(
"','.join(Students.lookupRecords(schoolName=$name).firstName)"))
self.assertPartialOutActions(out_actions, {
"calc": [_bulk_update("Schools", ["id", "lastNames"], [
[1, "Barack,Bill,Ronald"],
[2, "Barack,Bill,Ronald"],
[3, "George W,George H,Gerald"],
[4, "George W,George H,Gerald"]]
)],
"calls": {"Schools": {"lastNames": 4}}
})
# Remove the new lookup formula.
out_actions = self.remove_column("Schools", "lastNames")
self.assertPartialOutActions(out_actions, {}) # No calc actions
# Make sure that changes still work without errors.
out_actions = self.update_record("Students", 5, schoolName="Eureka")
self.assertPartialOutActions(out_actions, {
"calc": [
actions.UpdateRecord("Students", 5, {"schoolCities": ""}),
actions.UpdateRecord("Students", 5, {"schoolIds": ""}),
],
# This should NOT have '#lookup#schoolName' recalculation because there are no longer any
# formulas which do such a lookup.
"calls": { "Students": {'schoolCities': 1, 'schoolIds': 1}}
})
def test_multi_column_lookups(self):
"""
Check that we can do lookups by multiple columns.
"""
self.load_sample(testsamples.sample_students)
# Add a lookup formula which looks up a student matching on both first and last names.
self.add_column("Schools", "bestStudent", type="Text")
self.update_record("Schools", 1, bestStudent="Bush,George W")
self.add_column("Schools", "bestStudentId", formula=("""
if not $bestStudent: return ""
ln, fn = $bestStudent.split(",")
return ",".join(str(r.id) for r in Students.lookupRecords(firstName=fn, lastName=ln))
"""))
# Check data so far: only one record is filled.
self.assertPartialData("Schools", ["id", "bestStudent", "bestStudentId" ], [
[1, "Bush,George W", "2" ],
[2, "", "" ],
[3, "", "" ],
[4, "", "" ],
])
# Fill a few more records and check that we find records we should, and don't find those we
# shouldn't.
out_actions = self.update_records("Schools", ["id", "bestStudent"], [
[2, "Clinton,Bill"],
[3, "Norris,Chuck"],
[4, "Bush,George H"],
])
self.assertPartialOutActions(out_actions, {
"calc": [actions.BulkUpdateRecord("Schools", [2, 4], {"bestStudentId": ["3", "4"]})],
"calls": {"Schools": {"bestStudentId": 3}}
})
self.assertPartialData("Schools", ["id", "bestStudent", "bestStudentId" ], [
[1, "Bush,George W", "2" ],
[2, "Clinton,Bill", "3" ],
[3, "Norris,Chuck", "" ],
[4, "Bush,George H", "4" ],
])
# Now add more records, first matching only some of the lookup fields.
out_actions = self.add_record("Students", firstName="Chuck", lastName="Morris")
self.assertPartialOutActions(out_actions, {
"calls": {
# No calculations of anything Schools because nothing depends on the incomplete value.
"Students": {"#lookup#firstName:lastName": 2, "schoolIds": 1, "schoolCities": 1}
},
"retValues": [7],
})
# If we add a matching record, then we get a calculation of a record in Schools
out_actions = self.add_record("Students", firstName="Chuck", lastName="Norris")
self.assertPartialOutActions(out_actions, {
"calls": {
"Students": {"#lookup#firstName:lastName": 2, "schoolIds": 1, "schoolCities": 1},
"Schools": {"bestStudentId": 1}
},
"retValues": [8],
})
# And the data should be correct.
self.assertPartialData("Schools", ["id", "bestStudent", "bestStudentId" ], [
[1, "Bush,George W", "2" ],
[2, "Clinton,Bill", "3" ],
[3, "Norris,Chuck", "8" ],
[4, "Bush,George H", "4" ],
])
def test_record_removal(self):
# Remove a record, make sure that lookup maps get updated.
self.load_sample(testsamples.sample_students)
out_actions = self.remove_record("Schools", 3)
self.assertPartialOutActions(out_actions, {
"calc": [
actions.BulkUpdateRecord("Students", [2,4,6], {
"schoolCities": ["West Haven","West Haven","West Haven"]}),
actions.BulkUpdateRecord("Students", [2,4,6], {
"schoolIds": ["4","4","4"]}),
],
"calls": {
"Students": {"schoolIds": 3, "schoolCities": 3},
# LookupMapColumn is also updated but via a different path (unset() vs method() call), so
# it's not included in the count of formula calls.
}
})
self.assertPartialData("Students", ["id", "schoolIds", "schoolCities" ], [
[1, "1:2", "New York:Colombia" ],
[2, "4", "West Haven" ],
[3, "1:2", "New York:Colombia" ],
[4, "4", "West Haven" ],
[5, "", ""],
[6, "4", "West Haven" ]
])
def test_empty_relation(self):
# Make sure that when a relation becomes empty, it doesn't get messed up.
self.load_sample(testsamples.sample_students)
# Clear out dependencies.
self.update_records("Students", ["id", "schoolName"],
[ [i, ""] for i in [1,2,3,4,5,6] ])
self.assertPartialData("Students", ["id", "schoolIds", "schoolCities" ],
[ [i, "", ""] for i in [1,2,3,4,5,6] ])
# Make a number of changeas, to ensure they reuse rather than re-create _LookupRelations.
self.update_record("Students", 2, schoolName="Yale")
self.update_record("Students", 2, schoolName="Columbia")
self.update_record("Students", 3, schoolName="Columbia")
self.assertPartialData("Students", ["id", "schoolIds", "schoolCities" ], [
[1, "", ""],
[2, "1:2", "New York:Colombia" ],
[3, "1:2", "New York:Colombia" ],
[4, "", ""],
[5, "", ""],
[6, "", ""],
])
# When we messed up the dependencies, this change didn't cause a corresponding update. Check
# that it now does.
self.remove_record("Schools", 2)
self.assertPartialData("Students", ["id", "schoolIds", "schoolCities" ], [
[1, "", ""],
[2, "1", "New York" ],
[3, "1", "New York" ],
[4, "", ""],
[5, "", ""],
[6, "", ""],
])
def test_lookups_of_computed_values(self):
"""
Make sure that lookups get updated when the value getting looked up is a formula result.
"""
self.load_sample(testsamples.sample_students)
# Add a column like Schools.name, but computed, and change schoolIds to use that one instead.
self.add_column("Schools", "cname", formula="$name")
self.modify_column("Students", "schoolIds", formula=
"':'.join(str(id) for id in Schools.lookupRecords(cname=$schoolName).id)")
self.assertPartialData("Students", ["id", "schoolIds" ], [
[1, "1:2" ],
[2, "3:4" ],
[3, "1:2" ],
[4, "3:4" ],
[5, "" ],
[6, "3:4" ],
])
# Check that a change to School.name, which triggers a change to School.cname, causes a change
# to the looked-up ids. The changes here should be the same as in test_lookup_key_changes
# test, even though schoolIds depends on name indirectly.
out_actions = self.update_record("Schools", 2, name="Eureka")
self.assertPartialOutActions(out_actions, {
"calc": [
actions.UpdateRecord("Schools", 2, {"cname": "Eureka"}),
actions.BulkUpdateRecord("Students", [1,3,5], {
'schoolCities': ["New York", "New York", "Colombia"]
}),
actions.BulkUpdateRecord("Students", [1,3,5], {
'schoolIds': ["1", "1","2"]
}),
],
"calls": {"Students": { 'schoolCities': 3, 'schoolIds': 3 },
"Schools": {'#lookup#name': 1, '#lookup#cname': 1, "cname": 1} },
})
def use_saved_lookup_results(self):
"""
This sets up data so that lookupRecord results are stored in a column and used in another. Key
tests that check lookup dependencies should work unchanged with this setup.
"""
self.load_sample(testsamples.sample_students)
# Split up Students.schoolCities into Students.schools and Students.schoolCities.
self.add_column("Students", "schools", formula="Schools.lookupRecords(name=$schoolName)",
type="RefList:Schools")
self.modify_column("Students", "schoolCities",
formula="':'.join(r.address.city for r in $schools)")
# The following tests check correctness of dependencies when lookupResults are stored in one
# column and used in another. They reuse existing test cases with modified data.
def test_lookup_dependencies_reflist(self):
self.use_saved_lookup_results()
self.test_lookup_dependencies(pre_loaded=True)
# Confirm the final result including the additional 'schools' column.
self.assertPartialData("Students", ["id", "schools", "schoolIds", "schoolCities" ], [
[1, [1,2], "1:2", "New York:Colombia" ],
[2, [3,4], "3:4", "New Haven:New Haven" ],
[3, [1,2], "1:2", "New York:Colombia" ],
[4, [3,4], "3:4", "New Haven:New Haven" ],
[5, [], "", ""],
[6, [3,4], "3:4", "New Haven:New Haven" ]
])
def test_dependency_reset_reflist(self):
self.use_saved_lookup_results()
self.test_dependency_reset(pre_loaded=True)
def test_lookup_key_changes_reflist(self):
# We can't run this test case unchanged since our new column changes too in this test.
self.use_saved_lookup_results()
out_actions = self.update_record("Schools", 2, name="Eureka")
self.assertPartialOutActions(out_actions, {
"calc": [
actions.BulkUpdateRecord('Students', [1,3,5], {'schools': [[1],[1],[2]]}),
actions.BulkUpdateRecord("Students", [1,3,5], {
'schoolCities': ["New York", "New York", "Colombia"]
}),
actions.BulkUpdateRecord("Students", [1,3,5], {
'schoolIds': ["1", "1","2"]
}),
],
"calls": {"Students": { 'schools': 3, 'schoolCities': 3, 'schoolIds': 3 },
"Schools": {'#lookup#name': 1} },
})
# Test changes to lookup values in the table doing the lookup.
out_actions = self.update_records("Students", ["id", "schoolName"], [
[3, ""],
[5, "Yale"]
])
self.assertPartialOutActions(out_actions, {
"calc": [
actions.BulkUpdateRecord("Students", [3,5], {'schools': [[], [3,4]]}),
actions.BulkUpdateRecord("Students", [3,5], {'schoolCities': ["", "New Haven:West Haven"]}),
actions.BulkUpdateRecord("Students", [3,5], {'schoolIds': ["", "3:4"]}),
],
"calls": { "Students": { 'schools': 2, 'schoolCities': 2, 'schoolIds': 2 } },
})
# Confirm the final result.
self.assertPartialData("Students", ["id", "schools", "schoolIds", "schoolCities" ], [
[1, [1], "1", "New York" ],
[2, [3,4], "3:4", "New Haven:West Haven" ],
[3, [], "", "" ],
[4, [3,4], "3:4", "New Haven:West Haven" ],
[5, [3,4], "3:4", "New Haven:West Haven" ],
[6, [3,4], "3:4", "New Haven:West Haven" ]
])
def test_dependencies_relations_bug(self):
# We had a serious bug with dependencies, for which this test verifies a fix. Imagine Table2
# has a formula a=Table1.lookupOne(A=$A), and b=$a.foo. When col A changes in Table1, columns
# a and b in Table2 get recomputed. Each recompute triggers reset_rows() which is there to
# clear lookup relations (it actually triggers reset_dependencies() which resets rows for the
# relation on each dependency edge).
#
# The first recompute (of a) triggers reset_rows() on the LookupRelation, then recomputes the
# lookup formula which re-populates the relation correctly. The second recompute (of b) also
# triggers reset_rows(). The bug was that it was triggering it in the same LookupRelation, but
# since it doesn't get followed with recomputing the lookup formula, the relation remains
# incomplete.
#
# It's important that a formula like "b=$a.foo" doesn't reuse the LookupRelation by itself on
# the edge between b and $a, but a composition of IdentityRelation and LookupRelation. The
# composition will correctly forward reset_rows() to only the first half of the relation.
# Set up two tables with a situation as described above. Here, the role of column Table2.a
# above is taken by "Students.schools=Schools.lookupRecords(name=$schoolName)".
self.use_saved_lookup_results()
# We intentionally try behavior with type Any formulas too, without converting to a reference
# type, in case that affects relations.
self.modify_column("Students", "schools", type="Any")
self.add_column("Students", "schoolsCount", formula="len($schools.name)")
self.add_column("Students", "oneSchool", formula="Schools.lookupOne(name=$schoolName)")
self.add_column("Students", "oneSchoolName", formula="$oneSchool.name")
# A helper for comparing Record objects below.
schools_table = self.engine.tables['Schools']
def SchoolsRec(row_id):
return schools_table.Record(schools_table, row_id, None)
# We'll play with schools "Columbia" and "Eureka", which are rows 1,3,5 in the Students table.
self.assertTableData("Students", cols="subset", rows="subset", data=[
["id", "schoolName", "schoolsCount", "oneSchool", "oneSchoolName"],
[1, "Columbia", 2, SchoolsRec(1), "Columbia"],
[3, "Columbia", 2, SchoolsRec(1), "Columbia"],
[5, "Eureka", 0, SchoolsRec(0), ""],
])
# Now change Schools.schoolName which should trigger recomputations.
self.update_record("Schools", 1, name="Eureka")
self.assertTableData("Students", cols="subset", rows="subset", data=[
["id", "schoolName", "schoolsCount", "oneSchool", "oneSchoolName"],
[1, "Columbia", 1, SchoolsRec(2), "Columbia"],
[3, "Columbia", 1, SchoolsRec(2), "Columbia"],
[5, "Eureka", 1, SchoolsRec(1), "Eureka"],
])
# The first change is expected to work. The important check is that the relations don't get
# corrupted afterwards. So we do a second change to see if that still updates.
self.update_record("Schools", 1, name="Columbia")
self.assertTableData("Students", cols="subset", rows="subset", data=[
["id", "schoolName", "schoolsCount", "oneSchool", "oneSchoolName"],
[1, "Columbia", 2, SchoolsRec(1), "Columbia"],
[3, "Columbia", 2, SchoolsRec(1), "Columbia"],
[5, "Eureka", 0, SchoolsRec(0), ""],
])
# One more time, for good measure.
self.update_record("Schools", 1, name="Eureka")
self.assertTableData("Students", cols="subset", rows="subset", data=[
["id", "schoolName", "schoolsCount", "oneSchool", "oneSchoolName"],
[1, "Columbia", 1, SchoolsRec(2), "Columbia"],
[3, "Columbia", 1, SchoolsRec(2), "Columbia"],
[5, "Eureka", 1, SchoolsRec(1), "Eureka"],
])
def test_vlookup(self):
self.load_sample(testsamples.sample_students)
self.add_column("Students", "school", formula="VLOOKUP(Schools, name=$schoolName)")
self.add_column("Students", "schoolCity",
formula="VLOOKUP(Schools, name=$schoolName).address.city")
# A helper for comparing Record objects below.
schools_table = self.engine.tables['Schools']
def SchoolsRec(row_id):
return schools_table.Record(schools_table, row_id, None)
# We'll play with schools "Columbia" and "Eureka", which are rows 1,3,5 in the Students table.
self.assertTableData("Students", cols="subset", rows="all", data=[
["id", "schoolName", "school", "schoolCity"],
[1, "Columbia", SchoolsRec(1), "New York" ],
[2, "Yale", SchoolsRec(3), "New Haven" ],
[3, "Columbia", SchoolsRec(1), "New York" ],
[4, "Yale", SchoolsRec(3), "New Haven" ],
[5, "Eureka", SchoolsRec(0), "" ],
[6, "Yale", SchoolsRec(3), "New Haven" ],
])
# Now change some values which should trigger recomputations.
self.update_record("Schools", 1, name="Eureka")
self.update_record("Students", 2, schoolName="Unknown")
self.assertTableData("Students", cols="subset", rows="all", data=[
["id", "schoolName", "school", "schoolCity"],
[1, "Columbia", SchoolsRec(2), "Colombia" ],
[2, "Unknown", SchoolsRec(0), "" ],
[3, "Columbia", SchoolsRec(2), "Colombia" ],
[4, "Yale", SchoolsRec(3), "New Haven" ],
[5, "Eureka", SchoolsRec(1), "New York" ],
[6, "Yale", SchoolsRec(3), "New Haven" ],
])

@ -0,0 +1,147 @@
import random
import string
import timeit
import unittest
from collections import Hashable
import match_counter
from testutil import repeat_until_passes
# Here's an alternative implementation. Unlike the simple one, it never constructs a new data
# structure, or modifies dictionary keys while iterating, but it is still slower.
class MatchCounterOther(object):
def __init__(self, _sample):
self.sample_counts = {v: 0 for v in _sample}
def count_unique(self, iterable):
for v in iterable:
try:
n = self.sample_counts.get(v)
if n is not None:
self.sample_counts[v] = n + 1
except TypeError:
pass
matches = 0
for v, n in self.sample_counts.iteritems():
if n > 0:
matches += 1
self.sample_counts[v] = 0
return matches
# If not for dealing with unhashable errors, `.intersection(iterable)` would be by far the
# fastest. But with the extra iteration and especially checking for Hashable, it's super slow.
class MatchCounterIntersection(object):
def __init__(self, _sample):
self.sample = set(_sample)
def count_unique(self, iterable):
return len(self.sample.intersection(v for v in iterable if isinstance(v, Hashable)))
# This implementation doesn't measure the intersection, but it's interesting to compare its
# timings: this is still slower! Presumably because set intersection is native code that's more
# optimized than checking membership many times from Python.
class MatchCounterSimple(object):
def __init__(self, _sample):
self.sample = set(_sample)
def count_all(self, iterable):
return sum(1 for r in iterable if present(r, self.sample))
# This is much faster than using `isinstance(v, Hashable) and v in value_set`
def present(v, value_set):
try:
return v in value_set
except TypeError:
return False
# Set up a predictable random number generator.
r = random.Random(17)
def random_string():
length = r.randint(10,20)
return ''.join(r.choice(string.ascii_letters) for x in xrange(length))
def sample_with_repl(population, n):
return [r.choice(population) for x in xrange(n)]
# Here's some sample generated data.
sample = [random_string() for x in xrange(200)]
data1 = sample_with_repl([random_string() for x in xrange(20)] + r.sample(sample, 5), 1000)
data2 = sample_with_repl([random_string() for x in xrange(100)] + r.sample(sample, 15), 500)
# Include an example with an unhashable value, to ensure all implementation can handle it.
data3 = sample_with_repl([random_string() for x in xrange(10)] + sample, 2000) + [[1,2,3]]
class TestMatchCounter(unittest.TestCase):
def test_match_counter(self):
m = match_counter.MatchCounter(sample)
self.assertEqual(m.count_unique(data1), 5)
self.assertEqual(m.count_unique(data2), 15)
self.assertEqual(m.count_unique(data3), 200)
m = MatchCounterOther(sample)
self.assertEqual(m.count_unique(data1), 5)
self.assertEqual(m.count_unique(data2), 15)
self.assertEqual(m.count_unique(data3), 200)
# Do it again to ensure that we clear out state between counting.
self.assertEqual(m.count_unique(data1), 5)
self.assertEqual(m.count_unique(data2), 15)
self.assertEqual(m.count_unique(data3), 200)
m = MatchCounterIntersection(sample)
self.assertEqual(m.count_unique(data1), 5)
self.assertEqual(m.count_unique(data2), 15)
self.assertEqual(m.count_unique(data3), 200)
m = MatchCounterSimple(sample)
self.assertGreaterEqual(m.count_all(data1), 5)
self.assertGreaterEqual(m.count_all(data2), 15)
self.assertGreaterEqual(m.count_all(data3), 200)
@repeat_until_passes(3)
def test_timing(self):
setup='''
import match_counter
import test_match_counter as t
m1 = match_counter.MatchCounter(t.sample)
m2 = t.MatchCounterOther(t.sample)
m3 = t.MatchCounterSimple(t.sample)
m4 = t.MatchCounterIntersection(t.sample)
'''
N = 100
t1 = min(timeit.repeat(stmt='m1.count_unique(t.data1)', setup=setup, number=N, repeat=3)) / N
t2 = min(timeit.repeat(stmt='m2.count_unique(t.data1)', setup=setup, number=N, repeat=3)) / N
t3 = min(timeit.repeat(stmt='m3.count_all(t.data1)', setup=setup, number=N, repeat=3)) / N
t4 = min(timeit.repeat(stmt='m4.count_unique(t.data1)', setup=setup, number=N, repeat=3)) / N
#print "Timings/iter data1: %.3fus %.3fus %.3fus %.3fus" % (t1 * 1e6, t2 * 1e6, t3*1e6, t4*1e6)
self.assertLess(t1, t2)
self.assertLess(t1, t3)
self.assertLess(t1, t4)
t1 = min(timeit.repeat(stmt='m1.count_unique(t.data2)', setup=setup, number=N, repeat=3)) / N
t2 = min(timeit.repeat(stmt='m2.count_unique(t.data2)', setup=setup, number=N, repeat=3)) / N
t3 = min(timeit.repeat(stmt='m3.count_all(t.data2)', setup=setup, number=N, repeat=3)) / N
t4 = min(timeit.repeat(stmt='m4.count_unique(t.data2)', setup=setup, number=N, repeat=3)) / N
#print "Timings/iter data2: %.3fus %.3fus %.3fus %.3fus" % (t1 * 1e6, t2 * 1e6, t3*1e6, t4*1e6)
self.assertLess(t1, t2)
self.assertLess(t1, t3)
self.assertLess(t1, t4)
t1 = min(timeit.repeat(stmt='m1.count_unique(t.data3)', setup=setup, number=N, repeat=3)) / N
t2 = min(timeit.repeat(stmt='m2.count_unique(t.data3)', setup=setup, number=N, repeat=3)) / N
t3 = min(timeit.repeat(stmt='m3.count_all(t.data3)', setup=setup, number=N, repeat=3)) / N
t4 = min(timeit.repeat(stmt='m4.count_unique(t.data3)', setup=setup, number=N, repeat=3)) / N
#print "Timings/iter data3: %.3fus %.3fus %.3fus %.3fus" % (t1 * 1e6, t2 * 1e6, t3*1e6, t4*1e6)
self.assertLess(t1, t2)
self.assertLess(t1, t3)
self.assertLess(t1, t4)
if __name__ == "__main__":
unittest.main()

@ -0,0 +1,195 @@
import unittest
import actions
import schema
import table_data_set
import migrations
class TestMigrations(unittest.TestCase):
def test_migrations(self):
tdset = table_data_set.TableDataSet()
tdset.apply_doc_actions(schema_version0())
migration_actions = migrations.create_migrations(tdset.all_tables)
tdset.apply_doc_actions(migration_actions)
# Compare schema derived from migrations to the current schema.
migrated_schema = tdset.get_schema()
current_schema = {a.table_id: {c['id']: c for c in a.columns}
for a in schema.schema_create_actions()}
# pylint: disable=too-many-nested-blocks
if migrated_schema != current_schema:
# Figure out the version of new migration to suggest, and whether to update SCHEMA_VERSION.
new_version = max(schema.SCHEMA_VERSION, migrations.get_last_migration_version() + 1)
# Figure out the missing actions.
doc_actions = []
for table_id in sorted(current_schema.viewkeys() | migrated_schema.viewkeys()):
if table_id not in migrated_schema:
doc_actions.append(actions.AddTable(table_id, current_schema[table_id].values()))
elif table_id not in current_schema:
doc_actions.append(actions.RemoveTable(table_id))
else:
current_cols = current_schema[table_id]
migrated_cols = migrated_schema[table_id]
for col_id in sorted(current_cols.viewkeys() | migrated_cols.viewkeys()):
if col_id not in migrated_cols:
doc_actions.append(actions.AddColumn(table_id, col_id, current_cols[col_id]))
elif col_id not in current_cols:
doc_actions.append(actions.RemoveColumn(table_id, col_id))
else:
current_info = current_cols[col_id]
migrated_info = migrated_cols[col_id]
delta = {k: v for k, v in current_info.iteritems() if v != migrated_info.get(k)}
if delta:
doc_actions.append(actions.ModifyColumn(table_id, col_id, delta))
suggested_migration = (
"----------------------------------------------------------------------\n" +
"*** migrations.py ***\n" +
"----------------------------------------------------------------------\n" +
"@migration(schema_version=%s)\n" % new_version +
"def migration%s(tdset):\n" % new_version +
" return tdset.apply_doc_actions([\n" +
"".join(stringify(a) + ",\n" for a in doc_actions) +
" ])\n"
)
if new_version != schema.SCHEMA_VERSION:
suggested_schema_update = (
"----------------------------------------------------------------------\n" +
"*** schema.py ***\n" +
"----------------------------------------------------------------------\n" +
"SCHEMA_VERSION = %s\n" % new_version
)
else:
suggested_schema_update = ""
self.fail("Migrations are incomplete. Suggested migration to add:\n" +
suggested_schema_update + suggested_migration)
def stringify(doc_action):
if isinstance(doc_action, actions.AddColumn):
return ' add_column(%r, %s)' % (doc_action.table_id, col_info_args(doc_action.col_info))
elif isinstance(doc_action, actions.AddTable):
return (' actions.AddTable(%r, [\n' % doc_action.table_id +
''.join(' schema.make_column(%s),\n' % col_info_args(c)
for c in doc_action.columns) +
' ])')
else:
return " actions.%s(%s)" % (doc_action.__class__.__name__, ", ".join(map(repr, doc_action)))
def col_info_args(col_info):
extra = ""
for k in ("formula", "isFormula"):
v = col_info.get(k)
if v:
extra += ", %s=%r" % (k, v)
return "%r, %r%s" % (col_info['id'], col_info['type'], extra)
def schema_version0():
# This is the initial version of the schema before the very first migration. It's a historical
# snapshot, and thus should not be edited. The test verifies that starting with this v0,
# migrations bring the schema to the current version.
def make_column(col_id, col_type, formula='', isFormula=False):
return { "id": col_id, "type": col_type, "isFormula": isFormula, "formula": formula }
return [
actions.AddTable("_grist_DocInfo", [
make_column("docId", "Text"),
make_column("peers", "Text"),
make_column("schemaVersion", "Int"),
]),
actions.AddTable("_grist_Tables", [
make_column("tableId", "Text"),
]),
actions.AddTable("_grist_Tables_column", [
make_column("parentId", "Ref:_grist_Tables"),
make_column("parentPos", "PositionNumber"),
make_column("colId", "Text"),
make_column("type", "Text"),
make_column("widgetOptions","Text"),
make_column("isFormula", "Bool"),
make_column("formula", "Text"),
make_column("label", "Text")
]),
actions.AddTable("_grist_Imports", [
make_column("tableRef", "Ref:_grist_Tables"),
make_column("origFileName", "Text"),
make_column("parseFormula", "Text", isFormula=True,
formula="grist.parseImport(rec, table._engine)"),
make_column("delimiter", "Text", formula="','"),
make_column("doublequote", "Bool", formula="True"),
make_column("escapechar", "Text"),
make_column("quotechar", "Text", formula="'\"'"),
make_column("skipinitialspace", "Bool"),
make_column("encoding", "Text", formula="'utf8'"),
make_column("hasHeaders", "Bool"),
]),
actions.AddTable("_grist_External_database", [
make_column("host", "Text"),
make_column("port", "Int"),
make_column("username", "Text"),
make_column("dialect", "Text"),
make_column("database", "Text"),
make_column("storage", "Text"),
]),
actions.AddTable("_grist_External_table", [
make_column("tableRef", "Ref:_grist_Tables"),
make_column("databaseRef", "Ref:_grist_External_database"),
make_column("tableName", "Text"),
]),
actions.AddTable("_grist_TabItems", [
make_column("tableRef", "Ref:_grist_Tables"),
make_column("viewRef", "Ref:_grist_Views"),
]),
actions.AddTable("_grist_Views", [
make_column("name", "Text"),
make_column("type", "Text"),
make_column("layoutSpec", "Text"),
]),
actions.AddTable("_grist_Views_section", [
make_column("tableRef", "Ref:_grist_Tables"),
make_column("parentId", "Ref:_grist_Views"),
make_column("parentKey", "Text"),
make_column("title", "Text"),
make_column("defaultWidth", "Int", formula="100"),
make_column("borderWidth", "Int", formula="1"),
make_column("theme", "Text"),
make_column("chartType", "Text"),
make_column("layoutSpec", "Text"),
make_column("filterSpec", "Text"),
make_column("sortColRefs", "Text"),
make_column("linkSrcSectionRef", "Ref:_grist_Views_section"),
make_column("linkSrcColRef", "Ref:_grist_Tables_column"),
make_column("linkTargetColRef", "Ref:_grist_Tables_column"),
]),
actions.AddTable("_grist_Views_section_field", [
make_column("parentId", "Ref:_grist_Views_section"),
make_column("parentPos", "PositionNumber"),
make_column("colRef", "Ref:_grist_Tables_column"),
make_column("width", "Int"),
make_column("widgetOptions","Text"),
]),
actions.AddTable("_grist_Validations", [
make_column("formula", "Text"),
make_column("name", "Text"),
make_column("tableRef", "Int")
]),
actions.AddTable("_grist_REPL_Hist", [
make_column("code", "Text"),
make_column("outputText", "Text"),
make_column("errorText", "Text")
]),
actions.AddTable("_grist_Attachments", [
make_column("fileIdent", "Text"),
make_column("fileName", "Text"),
make_column("fileType", "Text"),
make_column("fileSize", "Int"),
make_column("timeUploaded", "DateTime")
]),
actions.AddRecord("_grist_DocInfo", 1, {})
]
if __name__ == "__main__":
unittest.main()

@ -0,0 +1,328 @@
from datetime import datetime, date, timedelta
import unittest
import moment
import moment_parse
# Helpful strftime() format that imcludes all parts of the date including the time zone.
fmt = "%Y-%m-%d %H:%M:%S %Z"
class TestMoment(unittest.TestCase):
new_york = [
# - 1918 -
[datetime(1918, 3, 31, 6, 59, 59), -1633280401000, "EST", 300, 1, 59],
[datetime(1918, 3, 31, 7, 0, 0), -1633280400000, "EDT", 240, 3, 0],
[datetime(1918, 10, 27, 5, 59, 59), -1615140001000, "EDT", 240, 1, 59],
[datetime(1918, 10, 27, 6, 0, 0), -1615140000000, "EST", 300, 1, 0],
# - 1979 -
[datetime(1979, 4, 29, 6, 59, 59), 294217199000, "EST", 300, 1, 59],
[datetime(1979, 4, 29, 7, 0, 0), 294217200000, "EDT", 240, 3, 0],
[datetime(1979, 10, 28, 5, 59, 59), 309938399000, "EDT", 240, 1, 59],
[datetime(1979, 10, 28, 6, 0, 0), 309938400000, "EST", 300, 1, 0],
# - 2037 -
[datetime(2037, 3, 8, 6, 59, 59), 2120108399000, "EST", 300, 1, 59],
[datetime(2037, 03, 8, 7, 0, 0), 2120108400000, "EDT", 240, 3, 0],
[datetime(2037, 11, 1, 5, 59, 59), 2140667999000, "EDT", 240, 1, 59]
]
new_york_errors = [
["America/New_York", "2037-3-8 6:59:59", TypeError],
["America/New_York", [2037, 3, 8, 6, 59, 59], TypeError],
["America/new_york", datetime(1979, 4, 29, 6, 59, 59), KeyError]
]
los_angeles = [
# - 1918 -
# Spanning non-existent hour
[datetime(1918, 3, 31, 1, 59, 59, 0), -1633269601000, "PST", 480, 1, 59],
[datetime(1918, 3, 31, 2, 0, 0, 0), -1633273200000, "PST", 480, 1, 0],
[datetime(1918, 3, 31, 2, 59, 59, 0), -1633269601000, "PST", 480, 1, 59],
[datetime(1918, 3, 31, 3, 0, 0, 0), -1633269600000, "PDT", 420, 3, 0],
# Spanning doubly-existent hour
[datetime(1918, 10, 27, 0, 59, 59, 0), -1615132801000, "PDT", 420, 0, 59],
[datetime(1918, 10, 27, 1, 0, 0, 0), -1615132800000, "PDT", 420, 1, 0],
[datetime(1918, 10, 27, 1, 59, 59, 0), -1615129201000, "PDT", 420, 1, 59],
[datetime(1918, 10, 27, 2, 0, 0, 0), -1615125600000, "PST", 480, 2, 0],
# - 2008 -
# Spanning non-existent hour
[datetime(2008, 3, 9, 1, 59, 59, 0), 1205056799000, "PST", 480, 1, 59],
[datetime(2008, 3, 9, 2, 0, 0, 0), 1205053200000, "PST", 480, 1, 0],
[datetime(2008, 3, 9, 2, 59, 59, 0), 1205056799000, "PST", 480, 1, 59],
[datetime(2008, 3, 9, 3, 0, 0, 0), 1205056800000, "PDT", 420, 3, 0],
# Spanning doubly-existent hour
[datetime(2008, 11, 2, 0, 59, 59, 0), 1225612799000, "PDT", 420, 0, 59],
[datetime(2008, 11, 2, 1, 0, 0, 0), 1225612800000, "PDT", 420, 1, 0],
[datetime(2008, 11, 2, 1, 59, 59, 0), 1225616399000, "PDT", 420, 1, 59],
[datetime(2008, 11, 2, 2, 0, 0, 0), 1225620000000, "PST", 480, 2, 0],
# - 2037 -
[datetime(2037, 3, 8, 1, 59, 59, 0), 2120119199000, "PST", 480, 1, 59],
[datetime(2037, 3, 8, 2, 0, 0, 0), 2120115600000, "PST", 480, 1, 0],
[datetime(2037, 11, 1, 0, 59, 59, 0), 2140675199000, "PDT", 420, 0, 59],
[datetime(2037, 11, 1, 1, 0, 0, 0), 2140675200000, "PDT", 420, 1, 0],
]
parse_samples = [
# Basic set
['MM-DD-YYYY', '12-02-1999', 944092800.000000],
['DD-MM-YYYY', '12-02-1999', 918777600.000000],
['DD/MM/YYYY', '12/02/1999', 918777600.000000],
['DD_MM_YYYY', '12_02_1999', 918777600.000000],
['DD:MM:YYYY', '12:02:1999', 918777600.000000],
['D-M-YY', '2-2-99', 917913600.000000],
['YY', '99', 922060800.000000],
['DD-MM-YYYY h:m:s', '12-02-1999 2:45:10', 918787510.000000],
['DD-MM-YYYY h:m:s a', '12-02-1999 2:45:10 am', 918787510.000000],
['DD-MM-YYYY h:m:s a', '12-02-1999 2:45:10 pm', 918830710.000000],
['h:mm a', '12:00 pm', 1458648000.000000],
['h:mm a', '12:30 pm', 1458649800.000000],
['h:mm a', '12:00 am', 1458604800.000000],
['h:mm a', '12:30 am', 1458606600.000000],
['HH:mm', '12:00', 1458648000.000000],
['YYYY-MM-DDTHH:mm:ss', '2011-11-11T11:11:11', 1321009871.000000],
['ddd MMM DD HH:mm:ss YYYY', 'Tue Apr 07 22:52:51 2009', 1239144771.000000],
['ddd MMMM DD HH:mm:ss YYYY', 'Tue April 07 22:52:51 2009', 1239144771.000000],
['HH:mm:ss', '12:00:00', 1458648000.000000],
['HH:mm:ss', '12:30:00', 1458649800.000000],
['HH:mm:ss', '00:00:00', 1458604800.000000],
['HH:mm:ss S', '00:30:00 1', 1458606600.100000],
['HH:mm:ss SS', '00:30:00 12', 1458606600.120000],
['HH:mm:ss SSS', '00:30:00 123', 1458606600.123000],
['HH:mm:ss S', '00:30:00 7', 1458606600.700000],
['HH:mm:ss SS', '00:30:00 78', 1458606600.780000],
['HH:mm:ss SSS', '00:30:00 789', 1458606600.789000],
# Dropped m
['MM/DD/YYYY h:m:s a', '05/1/2012 12:25:00 p', 1335875100.000000],
['MM/DD/YYYY h:m:s a', '05/1/2012 12:25:00 a', 1335831900.000000],
# 2 digit year with YYYY
['D/M/YYYY', '9/2/99', 918518400.000000],
['D/M/YYYY', '9/2/1999', 918518400.000000],
['D/M/YYYY', '9/2/66', -122860800.000000],
['D/M/YYYY', '9/2/65', 3001363200.000000],
# No separators
['MMDDYYYY', '12021999', 944092800.000000],
['DDMMYYYY', '12021999', 918777600.000000],
['YYYYMMDD', '19991202', 944092800.000000],
['DDMMMYYYY', '10Sep2001', 1000080000.000000],
# Error forgiveness
['MM/DD/YYYY', '12-02-1999', 944092800.000000],
['DD/MM/YYYY', '12/02 /1999', 918777600.000000],
['DD:MM:YYYY', '12:02 :1999', 918777600.000000],
['D-M-YY', '2 2 99', 917913600.000000],
['DD-MM-YYYY h:m:s', '12-02-1999 2:45:10.00', 918787510.000000],
['h:mm a', '12:00pm', 1458648000.000000],
['HH:mm', '1200', 1458648000.000000],
['dddd MMMM DD HH:mm:ss YYYY', 'Tue Apr 7 22:52:51 2009', 1239144771.000000],
['ddd MMM DD HH:mm:ss YYYY', 'Tuesday April 7 22:52:51 2009', 1239144771.000000],
['ddd MMM Do HH:mm:ss YYYY', 'Tuesday April 7th 22:52:51 2009', 1239144771.000000]
]
parse_timezone_samples = [
# Timezone corner cases
['MM-DD-YYYY h:ma', '3-13-2016 1:59am', 'America/New_York', 1457852340], # EST
['MM-DD-YYYY h:ma', '3-13-2016 2:00am', 'America/New_York', 1457848800], # Invalid, -1hr
['MM-DD-YYYY h:ma', '3-13-2016 2:59am', 'America/New_York', 1457852340], # Invalid, -1hr
['MM-DD-YYYY h:ma', '3-13-2016 3:00am', 'America/New_York', 1457852400], # EDT
['MM-DD-YYYY h:ma', '3-13-2016 1:59am', 'America/Los_Angeles', 1457863140], # PST
['MM-DD-YYYY h:ma', '3-13-2016 2:00am', 'America/Los_Angeles', 1457859600], # Invalid, -1hr
['MM-DD-YYYY h:ma', '3-13-2016 2:59am', 'America/Los_Angeles', 1457863140], # Invalid, -1hr
['MM-DD-YYYY h:ma', '3-13-2016 3:00am', 'America/Los_Angeles', 1457863200] # PDT
]
def assertMatches(self, data_entry, moment_obj):
date, timestamp, abbr, offset, hour, minute = data_entry
dt = moment_obj.datetime()
self.assertEqual(moment_obj.timestamp, timestamp)
self.assertEqual(moment_obj.zoneAbbr(), abbr)
self.assertEqual(moment_obj.zoneOffset(), timedelta(minutes=-offset))
self.assertEqual(dt.hour, hour)
self.assertEqual(dt.minute, minute)
# For each UTC date, convert to New York time and compare with expected values
def test_standard_entry(self):
name = "America/New_York"
data = self.new_york
for entry in data:
date = entry[0]
timestamp = entry[1]
m = moment.tz(date).tz(name)
mts = moment.tz(timestamp, name)
self.assertMatches(entry, m)
self.assertMatches(entry, mts)
error_data = self.new_york_errors
for entry in error_data:
name = entry[0]
date = entry[1]
error = entry[2]
self.assertRaises(error, moment.tz, date, name)
# For each Los Angeles date, check that the returned date matches expected values
def test_zone_entry(self):
name = "America/Los_Angeles"
data = self.los_angeles
for entry in data:
date = entry[0]
timestamp = entry[1]
m = moment.tz(date, name)
self.assertMatches(entry, m)
def test_zone(self):
name = "America/New_York"
tzinfo = moment.tzinfo(name)
data = self.new_york
for entry in data:
date = entry[0]
ts = entry[1]
abbr = entry[2]
offset = entry[3]
dt = moment.tz(ts, name).datetime()
self.assertEqual(dt.tzname(), abbr)
self.assertEqual(dt.utcoffset(), timedelta(minutes=-offset))
def test_parse(self):
for s in self.parse_samples:
self.assertEqual(moment_parse.parse(s[1], s[0], 'UTC', date(2016, 3, 22)), s[2])
for s in self.parse_timezone_samples:
self.assertEqual(moment_parse.parse(s[1], s[0], s[2], date(2016, 3, 22)), s[3])
def test_ts_to_dt(self):
# Verify that ts_to_dt works as expected.
value_sec = 1426291200 # 2015-03-14 00:00:00 in UTC
value_dt_utc = moment.ts_to_dt(value_sec, moment.get_zone('UTC'))
value_dt_aware = moment.ts_to_dt(value_sec, moment.get_zone('America/New_York'))
self.assertEqual(value_dt_utc.strftime("%Y-%m-%d %H:%M:%S %Z"), '2015-03-14 00:00:00 UTC')
self.assertEqual(value_dt_aware.strftime("%Y-%m-%d %H:%M:%S %Z"), '2015-03-13 20:00:00 EDT')
def test_dst_switches(self):
# Verify that conversions around DST switches happen correctly. (This is tested in other tests
# as well, but this test case is more focused and easier to debug.)
dst_before = -1633280401
dst_begin = -1633280400
dst_end = -1615140001
dst_after = -1615140000
# Should have no surprises in converting to UTC, since there are not DST dfferences.
def ts_to_dt_utc(dt):
return moment.ts_to_dt(dt, moment.get_zone('UTC'))
self.assertEqual(ts_to_dt_utc(dst_before).strftime(fmt), "1918-03-31 06:59:59 UTC")
self.assertEqual(ts_to_dt_utc(dst_begin ).strftime(fmt), "1918-03-31 07:00:00 UTC")
self.assertEqual(ts_to_dt_utc(dst_end ).strftime(fmt), "1918-10-27 05:59:59 UTC")
self.assertEqual(ts_to_dt_utc(dst_after ).strftime(fmt), "1918-10-27 06:00:00 UTC")
# Converting to America/New_York should produce correct jumps.
def ts_to_dt_nyc(dt):
return moment.ts_to_dt(dt, moment.get_zone('America/New_York'))
self.assertEqual(ts_to_dt_nyc(dst_before).strftime(fmt), "1918-03-31 01:59:59 EST")
self.assertEqual(ts_to_dt_nyc(dst_begin ).strftime(fmt), "1918-03-31 03:00:00 EDT")
self.assertEqual(ts_to_dt_nyc(dst_end ).strftime(fmt), "1918-10-27 01:59:59 EDT")
self.assertEqual(ts_to_dt_nyc(dst_after ).strftime(fmt), "1918-10-27 01:00:00 EST")
self.assertEqual(ts_to_dt_nyc(dst_after + 3599).strftime(fmt), "1918-10-27 01:59:59 EST")
def test_tzinfo(self):
# Verify that tzinfo works correctly.
ts1 = 294217199000 # In EST
ts2 = 294217200000 # In EDT (spring forward, we skip ahead by 1 hour)
utc_dt1 = datetime(1979, 4, 29, 6, 59, 59)
utc_dt2 = datetime(1979, 4, 29, 7, 0, 0)
self.assertEqual(moment.tz(ts1).datetime().strftime(fmt), '1979-04-29 06:59:59 UTC')
self.assertEqual(moment.tz(ts2).datetime().strftime(fmt), '1979-04-29 07:00:00 UTC')
# Verify that we get correct time zone variation depending on DST status.
nyc_dt1 = moment.tz(ts1, 'America/New_York').datetime()
nyc_dt2 = moment.tz(ts2, 'America/New_York').datetime()
self.assertEqual(nyc_dt1.strftime(fmt), '1979-04-29 01:59:59 EST')
self.assertEqual(nyc_dt2.strftime(fmt), '1979-04-29 03:00:00 EDT')
# Make sure we can get timestamps back from these datatimes.
self.assertEqual(moment.dt_to_ts(nyc_dt1)*1000, ts1)
self.assertEqual(moment.dt_to_ts(nyc_dt2)*1000, ts2)
# Verify that the datetime objects we get produce correct time zones in terms of DST when we
# manipulate them. NOTE: it is a bit unexpected that we add 1hr + 1sec rather than just 1sec,
# but it seems like that is how Python datetime works. Note that timezone does get switched
# correctly between EDT and EST.
self.assertEqual(nyc_dt1 + timedelta(seconds=3601), nyc_dt2)
self.assertEqual(nyc_dt2 - timedelta(seconds=3601), nyc_dt1)
self.assertEqual((nyc_dt1 + timedelta(seconds=3601)).strftime(fmt), '1979-04-29 03:00:00 EDT')
self.assertEqual((nyc_dt2 - timedelta(seconds=3601)).strftime(fmt), '1979-04-29 01:59:59 EST')
def test_dt_to_ds(self):
# Verify that dt_to_ts works for both naive and aware datetime objects.
value_dt = datetime(2015, 03, 14, 0, 0) # In UTC
value_sec = 1426291200
tzla = moment.get_zone('America/Los_Angeles')
def format_utc(ts):
return moment.ts_to_dt(ts, moment.get_zone('UTC')).strftime(fmt)
# Check that a naive datetime is interpreted in UTC.
self.assertEqual(value_dt.strftime("%Y-%m-%d %H:%M:%S %Z"), '2015-03-14 00:00:00 ')
self.assertEqual(moment.dt_to_ts(value_dt), value_sec) # Interpreted in UTC
# Get an explicit UTC version and make sure that also works.
value_dt_utc = value_dt.replace(tzinfo=moment.TZ_UTC)
self.assertEqual(value_dt_utc.strftime(fmt), '2015-03-14 00:00:00 UTC')
self.assertEqual(moment.dt_to_ts(value_dt_utc), value_sec)
# Get an aware datetime, and make sure that works too.
value_dt_aware = moment.ts_to_dt(value_sec, moment.get_zone('America/New_York'))
self.assertEqual(value_dt_aware.strftime(fmt), '2015-03-13 20:00:00 EDT')
self.assertEqual(moment.dt_to_ts(value_dt_aware), value_sec)
# Check that dt_to_ts pays attention to the timezone.
# If we interpret midnight in LA time, it's a later timestamp.
self.assertEqual(format_utc(moment.dt_to_ts(value_dt, tzla)), '2015-03-14 07:00:00 UTC')
# The second argument is ignored if the datetime is aware.
self.assertEqual(format_utc(moment.dt_to_ts(value_dt_utc, tzla)), '2015-03-14 00:00:00 UTC')
self.assertEqual(format_utc(moment.dt_to_ts(value_dt_aware, tzla)), '2015-03-14 00:00:00 UTC')
# If we modify an aware datetime, we may get a new timezone abbreviation.
value_dt_aware -= timedelta(days=28)
self.assertEqual(value_dt_aware.strftime(fmt), '2015-02-13 20:00:00 EST')
def test_date_to_ts(self):
d = date(2015, 03, 14)
tzla = moment.get_zone('America/Los_Angeles')
def format_utc(ts):
return moment.ts_to_dt(ts, moment.get_zone('UTC')).strftime(fmt)
self.assertEqual(format_utc(moment.date_to_ts(d)), '2015-03-14 00:00:00 UTC')
self.assertEqual(format_utc(moment.date_to_ts(d, tzla)), '2015-03-14 07:00:00 UTC')
self.assertEqual(moment.ts_to_dt(moment.date_to_ts(d, tzla), tzla).strftime(fmt),
'2015-03-14 00:00:00 PDT')
def test_parse_iso(self):
tzny = moment.get_zone('America/New_York')
iso = moment.parse_iso
self.assertEqual(iso('2011-11-11T11:11:11'), 1321009871.000000)
self.assertEqual(iso('2019-01-22T00:47:39.219071-05:00'), 1548136059.219071)
self.assertEqual(iso('2019-01-22T00:47:39.219071-0500'), 1548136059.219071)
self.assertEqual(iso('2019-01-22T00:47:39.219071', timezone=tzny), 1548136059.219071)
self.assertEqual(iso('2019-01-22T00:47:39.219071'), 1548118059.219071)
self.assertEqual(iso('2019-01-22T00:47:39.219071Z'), 1548118059.219071)
self.assertEqual(iso('2019-01-22T00:47:39.219071Z', timezone=tzny), 1548118059.219071)
self.assertEqual(iso('2019-01-22T00:47:39.219'), 1548118059.219)
self.assertEqual(iso('2019-01-22T00:47:39'), 1548118059)
self.assertEqual(iso('2019-01-22 00:47:39.219071'), 1548118059.219071)
self.assertEqual(iso('2019-01-22 00:47:39'), 1548118059)
self.assertEqual(iso('2019-01-22'), 1548115200)
def test_parse_iso_date(self):
tzny = moment.get_zone('America/New_York')
iso = moment.parse_iso_date
# Note that time components and time zone do NOT affect the returned timestamp.
self.assertEqual(iso('2019-01-22'), 1548115200)
self.assertEqual(iso('2019-01-22T00:47:39.219071'), 1548115200)
self.assertEqual(iso('2019-01-22 00:47:39Z'), 1548115200)
self.assertEqual(iso('2019-01-22T00:47:39.219071-05:00'), 1548115200)
self.assertEqual(iso('2019-01-22T00:47:39.219071+05:00'), 1548115200)
if __name__ == "__main__":
unittest.main()

@ -0,0 +1,361 @@
import relabeling
from sortedcontainers import SortedListWithKey
from itertools import izip
import unittest
import sys
# Shortcut to keep code more concise.
r = relabeling
def skipfloats(x, n):
for i in xrange(n):
x = relabeling.nextfloat(x)
return x
class Item(object):
"""
Tests use Item for items of the sorted lists we maintain.
"""
def __init__(self, value, key):
self.value = value
self.key = key
def __repr__(self):
return "Item(v=%s,k=%s)" % (self.value, self.key)
class ItemList(object):
def __init__(self, val_key_pairs):
self._slist = SortedListWithKey(key=lambda item: item.key)
self._slist.update(Item(v, k) for (v, k) in val_key_pairs)
self.num_update_events = 0
self.num_updated_keys = 0
def get_values(self):
return [item.value for item in self._slist]
def get_list(self):
return self._slist
def find_value(self, value):
return next((item for item in self._slist if item.value == value), None)
def avg_updated_keys(self):
return float(self.num_updated_keys) / len(self._slist)
def next(self, item):
return self._slist[self._slist.index(item) + 1]
def prev(self, item):
return self._slist[self._slist.index(item) - 1]
def insert_items(self, val_key_pairs, prepare_inserts=r.prepare_inserts):
keys = [k for (v, k) in val_key_pairs]
adjustments, new_keys = prepare_inserts(self._slist, keys)
if adjustments:
self.num_update_events += 1
self.num_updated_keys += len(adjustments)
# Updating items is a bit tricky: we have to do it without violating order (just changing
# key of an existing item easily might), so we remove items first. And we can only rely on
# indices if we scan items in a backwards order.
items = [self._slist.pop(index) for (index, key) in reversed(adjustments)]
items.reverse()
for (index, key), item in izip(adjustments, items):
item.key = key
self._slist.update(items)
# Now add the new items.
self._slist.update(Item(val, new_key) for (val, _), new_key in izip(val_key_pairs, new_keys))
# For testing, pass along the return value from prepare_inserts.
return adjustments, new_keys
class TestRelabeling(unittest.TestCase):
def test_nextfloat(self):
def verify_nextfloat(x):
nx = r.nextfloat(x)
self.assertNotEqual(nx, x)
self.assertGreater(nx, x)
self.assertEqual(r.prevfloat(nx), x)
average = (nx + x) / 2
self.assertTrue(average == nx or average == x)
verify_nextfloat(1)
verify_nextfloat(-1)
verify_nextfloat(417)
verify_nextfloat(-417)
verify_nextfloat(12312422)
verify_nextfloat(-12312422)
verify_nextfloat(0.1234)
verify_nextfloat(-0.1234)
verify_nextfloat(0.00005)
verify_nextfloat(-0.00005)
verify_nextfloat(0.0)
verify_nextfloat(r.nextfloat(0.0))
verify_nextfloat(sys.float_info.min)
verify_nextfloat(-sys.float_info.min)
def test_prevfloat(self):
def verify_prevfloat(x):
nx = r.prevfloat(x)
self.assertNotEqual(nx, x)
self.assertLess(nx, x)
self.assertEqual(r.nextfloat(nx), x)
average = (nx + x) / 2
self.assertTrue(average == nx or average == x)
verify_prevfloat(1)
verify_prevfloat(-1)
verify_prevfloat(417)
verify_prevfloat(-417)
verify_prevfloat(12312422)
verify_prevfloat(-12312422)
verify_prevfloat(0.1234)
verify_prevfloat(-0.1234)
verify_prevfloat(0.00005)
verify_prevfloat(-0.00005)
verify_prevfloat(r.nextfloat(0.0))
verify_prevfloat(sys.float_info.min)
verify_prevfloat(-sys.float_info.min)
def test_range_around_float(self):
def verify_range(bits, begin, end):
self.assertEqual(r.range_around_float(begin, bits), (begin, end))
self.assertEqual(r.range_around_float((end + begin) / 2, bits), (begin, end))
delta = r.nextfloat(begin) - begin
if begin + delta < end:
self.assertEqual(r.range_around_float(begin + delta, bits), (begin, end))
if end - delta >= begin:
self.assertEqual(r.range_around_float(end - delta, bits), (begin, end))
def verify_small_range_at(begin):
verify_range(0, begin, skipfloats(begin, 1))
verify_range(1, begin, skipfloats(begin, 2))
verify_range(4, begin, skipfloats(begin, 16))
verify_range(10, begin, skipfloats(begin, 1024))
verify_small_range_at(1.0)
verify_small_range_at(0.5)
verify_small_range_at(0.25)
verify_small_range_at(0.75)
verify_small_range_at(17.0)
verify_range(52, 1.0, 2.0)
self.assertEqual(r.range_around_float(1.4, 52), (1.0, 2.0))
verify_range(52, 0.5, 1.0)
self.assertEqual(r.range_around_float(0.75, 52), (0.5, 1.0))
self.assertEqual(r.range_around_float(17, 48), (17.0, 18.0))
self.assertEqual(r.range_around_float(17, 49), (16.0, 18.0))
self.assertEqual(r.range_around_float(17, 50), (16.0, 20.0))
self.assertEqual(r.range_around_float(17, 51), (16.0, 24.0))
self.assertEqual(r.range_around_float(17, 52), (16.0, 32.0))
verify_range(51, 0.25, 0.375)
self.assertEqual(r.range_around_float(0.27, 51), (0.25, 0.375))
self.assertEqual(r.range_around_float(0.30, 51), (0.25, 0.375))
self.assertEqual(r.range_around_float(0.37, 51), (0.25, 0.375))
verify_range(51, 0.50, 0.75)
verify_range(51, 0.75, 1.0)
verify_range(52, 0.25, 0.5)
# Range around 0 isn't quite right, and possibly can't be. But we test that it's at least
# something meaningful.
self.assertEqual(r.range_around_float(0.00, 52), (0.00, 0.5))
self.assertEqual(r.range_around_float(0.25, 52), (0.25, 0.5))
self.assertEqual(r.range_around_float(0.00, 50), (0.00, 0.125))
self.assertEqual(r.range_around_float(0.10, 50), (0.09375, 0.109375))
self.assertEqual(r.range_around_float(0.0, 53), (0.00, 1))
self.assertEqual(r.range_around_float(0.5, 53), (0.00, 1))
self.assertEqual(r.range_around_float(0, 0), (0.0, skipfloats(0.5, 1) - 0.5))
self.assertEqual(r.range_around_float(0, 1), (0.0, skipfloats(0.5, 2) - 0.5))
self.assertEqual(r.range_around_float(0, 4), (0.0, skipfloats(0.5, 16) - 0.5))
self.assertEqual(r.range_around_float(0, 10), (0.0, skipfloats(0.5, 1024) - 0.5))
def test_all_distinct(self):
# Just like r.get_range, but includes endpoints.
def full_range(start, end, count):
return [start] + r.get_range(start, end, count) + [end]
self.assertTrue(r.all_distinct(range(1000)))
self.assertTrue(r.all_distinct([]))
self.assertTrue(r.all_distinct([1.0]))
self.assertFalse(r.all_distinct([1.0, 1.0]))
self.assertTrue(r.all_distinct(full_range(0, 1, 1000)))
self.assertFalse(r.all_distinct(full_range(1.0, r.nextfloat(1.0), 1)))
self.assertFalse(r.all_distinct(full_range(1.0, skipfloats(1.0, 10), 10)))
self.assertTrue(r.all_distinct(full_range(1.0, skipfloats(1.0, 11), 10)))
self.assertTrue(r.all_distinct(full_range(0.1, skipfloats(0.1, 100), 99)))
self.assertFalse(r.all_distinct(full_range(0.1, skipfloats(0.1, 100), 100)))
def test_get_range(self):
self.assertEqual(r.get_range(0.0, 2.0, 3), [0.5, 1, 1.5])
self.assertEqual(r.get_range(1, 17, 7), [3,5,7,9,11,13,15])
self.assertEqual(r.get_range(-1, 1.5, 4), [-0.5, 0, 0.5, 1])
def test_prepare_inserts_simple(self):
slist = SortedListWithKey(key=lambda i: i.key)
self.assertEqual(r.prepare_inserts(slist, [4.0]), ([], [1.0]))
self.assertEqual(r.prepare_inserts(slist, [0.0]), ([], [1.0]))
self.assertEqual(r.prepare_inserts(slist, [4.0, 4.0, 5, 6]), ([], [1.0, 2.0, 3.0, 4.0]))
self.assertEqual(r.prepare_inserts(slist, [4, 5, 6, 5, 4]), ([], [1,3,5,4,2]))
slist.update(Item(v, k) for (v, k) in zip(['a','b','c'], [3.0, 4.0, 5.0]))
self.assertEqual(r.prepare_inserts(slist, [0.0]), ([], [1.5]))
values = 'defgijkl'
to_update, to_add = r.prepare_inserts(slist, [3,3,4,5,6,4,6,4])
self.assertEqual(to_add, [1., 2., 3.25, 4.5, 6., 3.5, 7., 3.75])
self.assertEqual(to_update, [])
slist.update(Item(v, k) for (v, k) in zip(values, to_add))
self.assertEqual([i.value for i in slist], list('deafjlbgcik'))
def test_with_invalid(self):
slist = SortedListWithKey(key=lambda i: i.key)
slist.add(Item('a', 0))
self.assertEqual(r.prepare_inserts(slist, [0.0]), ([(0, 2.0)], [1.0]))
self.assertEqual(r.prepare_inserts(slist, [1.0]), ([], [1.0]))
slist = SortedListWithKey(key=lambda i: i.key)
slist.update(Item(v, k) for (v, k) in zip('abcdef', [0, 0, 0, 1, 1, 1]))
# We expect the whole range to be renumbered.
self.assertEqual(r.prepare_inserts(slist, [0.0, 0.0]),
([(0, 3.0), (1, 4.0), (2, 5.0), (3, 6.0), (4, 7.0), (5, 8.0)],
[1.0, 2.0]))
# We also expect a renumbering if there are negative or infinite values.
slist = SortedListWithKey(key=lambda i: i.key)
slist.add(Item('a', float('inf')))
self.assertEqual(r.prepare_inserts(slist, [0.0]), ([(0, 2.0)], [1.0]))
self.assertEqual(r.prepare_inserts(slist, [float('inf')]), ([(0, 2.0)], [1.0]))
slist = SortedListWithKey(key=lambda i: i.key)
slist.add(Item('a', -17.0))
self.assertEqual(r.prepare_inserts(slist, [0.0]), ([(0, 1.0)], [2.0]))
self.assertEqual(r.prepare_inserts(slist, [float('-inf')]), ([(0, 2.0)], [1.0]))
def test_with_dups(self):
slist = SortedListWithKey(key=lambda i: i.key)
slist.update(Item(v, k) for (v, k) in zip('abcdef', [1, 1, 1, 2, 2, 2]))
self.assertEqual(r.prepare_inserts(slist, [0.0]), ([], [0.5]))
def test_renumber_endpoints1(self):
self._do_test_renumber_ends([])
def test_renumber_endpoints2(self):
self._do_test_renumber_ends(zip("abcd", [40,50,60,70]))
def _do_test_renumber_ends(self, initial):
# Test insertions that happen together on the left and on the right.
slist = ItemList(initial)
for i in xrange(2000):
slist.insert_items([(i, float('-inf')), (-i, float('inf'))])
self.assertEqual(slist.get_values(),
rev_range(2000) + [v for v,k in initial] + range(0, -2000, -1))
#print slist.num_update_events, slist.num_updated_keys
self.assertLess(slist.avg_updated_keys(), 3)
self.assertLess(slist.num_update_events, 80)
def test_renumber_left(self):
slist = ItemList(zip("abcd", [4,5,6,7]))
ins_item = slist.find_value('c')
for i in xrange(1000):
slist.insert_items([(i, ins_item.key)])
# Check the end result
self.assertEqual(slist.get_values(), ['a', 'b'] + range(1000) + ['c', 'd'])
self.assertAlmostEqual(slist.avg_updated_keys(), 3.5, delta=1)
self.assertLess(slist.num_update_events, 40)
def test_renumber_right(self):
slist = ItemList(zip("abcd", [4,5,6,7]))
ins_item = slist.find_value('b')
for i in xrange(1000):
slist.insert_items([(i, r.nextfloat(ins_item.key))])
# Check the end result
self.assertEqual(slist.get_values(), ['a', 'b'] + rev_range(1000) + ['c', 'd'])
self.assertAlmostEqual(slist.avg_updated_keys(), 3.5, delta=1)
self.assertLess(slist.num_update_events, 40)
def test_renumber_left_dumb(self):
# Here we use the "dumb" approach, and see that in our test case it's significantly worse.
# (The badness increases with the number of insertions, but we'll keep numbers small to keep
# the test fast.)
slist = ItemList(zip("abcd", [4,5,6,7]))
ins_item = slist.find_value('c')
for i in xrange(1000):
slist.insert_items([(i, ins_item.key)], prepare_inserts=r.prepare_inserts_dumb)
self.assertEqual(slist.get_values(), ['a', 'b'] + range(1000) + ['c', 'd'])
self.assertGreater(slist.avg_updated_keys(), 8)
def test_renumber_right_dumb(self):
slist = ItemList(zip("abcd", [4,5,6,7]))
ins_item = slist.find_value('b')
for i in xrange(1000):
slist.insert_items([(i, r.nextfloat(ins_item.key))], prepare_inserts=r.prepare_inserts_dumb)
self.assertEqual(slist.get_values(), ['a', 'b'] + rev_range(1000) + ['c', 'd'])
self.assertGreater(slist.avg_updated_keys(), 8)
def test_renumber_multiple(self):
# In this test, we make multiple difficult insertions at each step: to the left and to the
# right of each value. This should involve some adjustments that get affected by subsequent
# adjustments during the same prepare_inserts() call.
slist = ItemList(zip("abcd", [4,5,6,7]))
# We insert items on either side of each of the original items (a, b, c, d).
ins_items = list(slist.get_list())
N = 250
for i in xrange(N):
slist.insert_items([("%sr%s" % (x.value, i), r.nextfloat(x.key)) for x in ins_items] +
[("%sl%s" % (x.value, i), x.key) for x in ins_items] +
# After the first insertion, also insert items next on either side of the
# neighbors of the original a, b, c, d items.
([("%sR%s" % (x.value, i), r.nextfloat(slist.next(x).key))
for x in ins_items] +
[("%sL%s" % (x.value, i), slist.prev(x).key) for x in ins_items]
if i > 0 else []))
# The list should grow like this:
# a, b, c, d
# al0, a, ar0, ... (same for b, c, d)
# aL1, al0, al1, a, ar1, ar0, aR1, ...
# aL1, al0, aL2, al1, al2, a, ar2, ar1, aR2, ar0, aR1, ...
def left_half(val):
half = range(2*N - 1)
half[0::2] = ['%sL%d' % (val, i) for i in xrange(1, N + 1)]
half[1::2] = ['%sl%d' % (val, i) for i in xrange(0, N - 1)]
half[-1] = '%sl%d' % (val, N - 1)
return half
def right_half(val):
# Best described as the reverse of left_half
return [v.replace('l', 'r').replace('L', 'R') for v in reversed(left_half(val))]
# The list we expect to see is of the form [aL1, al1, aL2, al2, ... aL1000, al1000, a,
# ar1000, aR1000, ..., aR1],
# followed by the same sequence for b, c, and d.
self.assertEqual(slist.get_values(), sum([left_half(v) + [v] + right_half(v)
for v in ('a', 'b', 'c', 'd')], []))
self.assertAlmostEqual(slist.avg_updated_keys(), 2.5, delta=1)
self.assertLess(slist.num_update_events, 40)
def rev_range(n):
return list(reversed(range(n)))
if __name__ == "__main__":
unittest.main()

@ -0,0 +1,390 @@
# -*- coding: utf-8 -*-
import logger
import testutil
import test_engine
log = logger.Logger(__name__, logger.INFO)
class TestRenames(test_engine.EngineTestCase):
# Simpler cases of column renames in formulas. Here's the list of cases we support and test.
# $COLUMN where NAME is a column (formula or non-formula)
# $ref.COLUMN when $ref is a non-formula Reference column
# $ref.column.COLUMN
# $ref.COLUMN when $ref is a function with a Ref type.
# $ref.COLUMN when $ref is a function with Any type but clearly returning a Ref.
# Table.lookupFunc(COLUMN1=value, COLUMN2=value) and for .lookupRecords
# Table.lookupFunc(...).COLUMN and for .lookupRecords
# Table.lookupFunc(...).foo.COLUMN and for .lookupRecords
# [x.COLUMN for x in Table.lookupRecords(...)] for different kinds of comprehensions
# TABLE.lookupFunc(...) where TABLE is a user-defined table.
sample = testutil.parse_test_sample({
"SCHEMA": [
[1, "Address", [
[21, "city", "Text", False, "", "", ""],
]],
[2, "People", [
[22, "name", "Text", False, "", "", ""],
[23, "addr", "Ref:Address", False, "", "", ""],
[24, "city", "Any", True, "$addr.city", "", ""],
]]
],
"DATA": {
"Address": [
["id", "city" ],
[11, "New York" ],
[12, "Colombia" ],
[13, "New Haven" ],
[14, "West Haven" ],
],
"People": [
["id", "name" , "addr" ],
[1, "Bob" , 12 ],
[2, "Alice" , 13 ],
[3, "Doug" , 12 ],
[4, "Sam" , 11 ],
],
}
})
def test_rename_rec_attribute(self):
# Simple case: we are renaming `$COLUMN`.
self.load_sample(self.sample)
out_actions = self.apply_user_action(["RenameColumn", "People", "addr", "address"])
self.assertPartialOutActions(out_actions, { "stored": [
["RenameColumn", "People", "addr", "address"],
["ModifyColumn", "People", "city", {"formula": "$address.city"}],
["BulkUpdateRecord", "_grist_Tables_column", [23, 24], {
"colId": ["address", "city"],
"formula": ["", "$address.city"]
}],
],
# Things should get recomputed, but produce same results, hence no calc actions.
"calc": []
})
# Make sure renames of formula columns are also recognized.
self.add_column("People", "CityUpper", formula="$city.upper()")
out_actions = self.apply_user_action(["RenameColumn", "People", "city", "ciudad"])
self.assertPartialOutActions(out_actions, { "stored": [
["RenameColumn", "People", "city", "ciudad"],
["ModifyColumn", "People", "CityUpper", {"formula": "$ciudad.upper()"}],
["BulkUpdateRecord", "_grist_Tables_column", [24, 25], {
"colId": ["ciudad", "CityUpper"],
"formula": ["$address.city", "$ciudad.upper()"]
}]
]})
def test_rename_reference_attribute(self):
# Slightly harder: renaming `$ref.COLUMN`
self.load_sample(self.sample)
out_actions = self.apply_user_action(["RenameColumn", "Address", "city", "ciudad"])
self.assertPartialOutActions(out_actions, { "stored": [
["RenameColumn", "Address", "city", "ciudad"],
["ModifyColumn", "People", "city", {"formula": "$addr.ciudad"}],
["BulkUpdateRecord", "_grist_Tables_column", [21, 24], {
"colId": ["ciudad", "city"],
"formula": ["", "$addr.ciudad"]
}],
]})
def test_rename_ref_ref_attr(self):
# Slightly harder still: renaming $ref.column.COLUMN.
self.load_sample(self.sample)
self.add_column("Address", "person", type="Ref:People")
self.add_column("Address", "person_city", formula="$person.addr.city")
self.add_column("Address", "person_city2", formula="a = $person.addr\nreturn a.city")
out_actions = self.apply_user_action(["RenameColumn", "Address", "city", "ciudad"])
self.assertPartialOutActions(out_actions, { "stored": [
["RenameColumn", "Address", "city", "ciudad"],
["ModifyColumn", "People", "city", {"formula": "$addr.ciudad"}],
["ModifyColumn", "Address", "person_city", {"formula": "$person.addr.ciudad"}],
["ModifyColumn", "Address", "person_city2", {"formula":
"a = $person.addr\nreturn a.ciudad"}],
["BulkUpdateRecord", "_grist_Tables_column", [21, 24, 26, 27], {
"colId": ["ciudad", "city", "person_city", "person_city2"],
"formula": ["", "$addr.ciudad", "$person.addr.ciudad", "a = $person.addr\nreturn a.ciudad"]
}],
]})
def test_rename_typed_ref_func_attr(self):
# Renaming `$ref.COLUMN` when $ref is a function with a Ref type.
self.load_sample(self.sample)
self.add_column("People", "addr_func", type="Ref:Address", isFormula=True, formula="$addr")
self.add_column("People", "city2", formula="$addr_func.city")
out_actions = self.apply_user_action(["RenameColumn", "Address", "city", "ciudad"])
self.assertPartialOutActions(out_actions, { "stored": [
["RenameColumn", "Address", "city", "ciudad"],
["ModifyColumn", "People", "city", {"formula": "$addr.ciudad"}],
["ModifyColumn", "People", "city2", {"formula": "$addr_func.ciudad"}],
["BulkUpdateRecord", "_grist_Tables_column", [21, 24, 26], {
"colId": ["ciudad", "city", "city2"],
"formula": ["", "$addr.ciudad", "$addr_func.ciudad"]
}],
]})
def test_rename_any_ref_func_attr(self):
# Renaming `$ref.COLUMN` when $ref is a function with Any type but clearly returning a Ref.
self.load_sample(self.sample)
self.add_column("People", "addr_func", isFormula=True, formula="$addr")
self.add_column("People", "city3", formula="$addr_func.city")
out_actions = self.apply_user_action(["RenameColumn", "Address", "city", "ciudad"])
self.assertPartialOutActions(out_actions, { "stored": [
["RenameColumn", "Address", "city", "ciudad"],
["ModifyColumn", "People", "city", {"formula": "$addr.ciudad"}],
["ModifyColumn", "People", "city3", {"formula": "$addr_func.ciudad"}],
["BulkUpdateRecord", "_grist_Tables_column", [21, 24, 26], {
"colId": ["ciudad", "city", "city3"],
"formula": ["", "$addr.ciudad", "$addr_func.ciudad"]
}],
]})
def test_rename_reflist_attr(self):
# Renaming `$ref.COLUMN` where $ref is a data or function with RefList type (most importantly
# applies to the $group column of summary tables).
self.load_sample(self.sample)
self.add_column("People", "addr_list", type="RefList:Address", isFormula=False)
self.add_column("People", "addr_func", type="RefList:Address", isFormula=True, formula="[1,2]")
self.add_column("People", "citysum", formula="sum($addr_func.city) + sum($addr_list.city)")
out_actions = self.apply_user_action(["RenameColumn", "Address", "city", "ciudad"])
self.assertPartialOutActions(out_actions, { "stored": [
["RenameColumn", "Address", "city", "ciudad"],
["ModifyColumn", "People", "city", {"formula": "$addr.ciudad"}],
["ModifyColumn", "People", "citysum", {"formula":
"sum($addr_func.ciudad) + sum($addr_list.ciudad)"}],
["BulkUpdateRecord", "_grist_Tables_column", [21, 24, 27], {
"colId": ["ciudad", "city", "citysum"],
"formula": ["", "$addr.ciudad", "sum($addr_func.ciudad) + sum($addr_list.ciudad)"]
}],
]})
def test_rename_lookup_param(self):
# Renaming `Table.lookupOne(COLUMN1=value, COLUMN2=value)` and for `.lookupRecords`
self.load_sample(self.sample)
self.add_column("Address", "people", formula="People.lookupOne(addr=$id, city=$city)")
self.add_column("Address", "people2", formula="People.lookupRecords(addr=$id)")
out_actions = self.apply_user_action(["RenameColumn", "People", "addr", "ADDRESS"])
self.assertPartialOutActions(out_actions, { "stored": [
["RenameColumn", "People", "addr", "ADDRESS"],
["ModifyColumn", "People", "city", {"formula": "$ADDRESS.city"}],
["ModifyColumn", "Address", "people",
{"formula": "People.lookupOne(ADDRESS=$id, city=$city)"}],
["ModifyColumn", "Address", "people2",
{"formula": "People.lookupRecords(ADDRESS=$id)"}],
["BulkUpdateRecord", "_grist_Tables_column", [23, 24, 25, 26], {
"colId": ["ADDRESS", "city", "people", "people2"],
"formula": ["", "$ADDRESS.city",
"People.lookupOne(ADDRESS=$id, city=$city)",
"People.lookupRecords(ADDRESS=$id)"]
}],
]})
# Another rename that should affect the second parameter.
out_actions = self.apply_user_action(["RenameColumn", "People", "city", "ciudad"])
self.assertPartialOutActions(out_actions, { "stored": [
["RenameColumn", "People", "city", "ciudad"],
["ModifyColumn", "Address", "people",
{"formula": "People.lookupOne(ADDRESS=$id, ciudad=$city)"}],
["BulkUpdateRecord", "_grist_Tables_column", [24, 25], {
"colId": ["ciudad", "people"],
"formula": ["$ADDRESS.city", "People.lookupOne(ADDRESS=$id, ciudad=$city)"]
}],
]})
# This is kind of unnecessary, but checks how the values of params are affected separately.
out_actions = self.apply_user_action(["RenameColumn", "Address", "city", "city2"])
self.assertPartialOutActions(out_actions, { "stored": [
["RenameColumn", "Address", "city", "city2"],
["ModifyColumn", "People", "ciudad", {"formula": "$ADDRESS.city2"}],
["ModifyColumn", "Address", "people",
{"formula": "People.lookupOne(ADDRESS=$id, ciudad=$city2)"}],
["BulkUpdateRecord", "_grist_Tables_column", [21, 24, 25], {
"colId": ["city2", "ciudad", "people"],
"formula": ["", "$ADDRESS.city2", "People.lookupOne(ADDRESS=$id, ciudad=$city2)"]
}],
]})
def test_rename_lookup_result_attr(self):
# Renaming `Table.lookupOne(...).COLUMN` and for `.lookupRecords`
self.load_sample(self.sample)
self.add_column("Address", "people", formula="People.lookupOne(addr=$id, city=$city).name")
self.add_column("Address", "people2", formula="People.lookupRecords(addr=$id).name")
out_actions = self.apply_user_action(["RenameColumn", "People", "name", "nombre"])
self.assertPartialOutActions(out_actions, { "stored": [
["RenameColumn", "People", "name", "nombre"],
["ModifyColumn", "Address", "people", {"formula":
"People.lookupOne(addr=$id, city=$city).nombre"}],
["ModifyColumn", "Address", "people2", {"formula":
"People.lookupRecords(addr=$id).nombre"}],
["BulkUpdateRecord", "_grist_Tables_column", [22, 25, 26], {
"colId": ["nombre", "people", "people2"],
"formula": ["",
"People.lookupOne(addr=$id, city=$city).nombre",
"People.lookupRecords(addr=$id).nombre"]
}]
]})
def test_rename_lookup_ref_attr(self):
# Renaming `Table.lookupOne(...).foo.COLUMN` and for `.lookupRecords`
self.load_sample(self.sample)
self.add_column("Address", "people", formula="People.lookupOne(addr=$id, city=$city).addr.city")
self.add_column("Address", "people2", formula="People.lookupRecords(addr=$id).addr.city")
out_actions = self.apply_user_action(["RenameColumn", "Address", "city", "ciudad"])
self.assertPartialOutActions(out_actions, { "stored": [
["RenameColumn", "Address", "city", "ciudad"],
["ModifyColumn", "People", "city", {"formula": "$addr.ciudad"}],
["ModifyColumn", "Address", "people", {"formula":
"People.lookupOne(addr=$id, city=$ciudad).addr.ciudad"}],
["ModifyColumn", "Address", "people2", {"formula":
"People.lookupRecords(addr=$id).addr.ciudad"}],
["BulkUpdateRecord", "_grist_Tables_column", [21, 24, 25, 26], {
"colId": ["ciudad", "city", "people", "people2"],
"formula": ["", "$addr.ciudad",
"People.lookupOne(addr=$id, city=$ciudad).addr.ciudad",
"People.lookupRecords(addr=$id).addr.ciudad"]
}]
]})
def test_rename_lookup_iter_attr(self):
# Renaming `[x.COLUMN for x in Table.lookupRecords(...)]`.
self.load_sample(self.sample)
self.add_column("Address", "people",
formula="','.join(x.addr.city for x in People.lookupRecords(addr=$id))")
self.add_column("Address", "people2",
formula="','.join([x.addr.city for x in People.lookupRecords(addr=$id)])")
self.add_column("Address", "people3",
formula="','.join({x.addr.city for x in People.lookupRecords(addr=$id)})")
self.add_column("Address", "people4",
formula="{x.addr.city:x.addr for x in People.lookupRecords(addr=$id)}")
out_actions = self.apply_user_action(["RenameColumn", "People", "addr", "ADDRESS"])
self.assertPartialOutActions(out_actions, { "stored": [
["RenameColumn", "People", "addr", "ADDRESS"],
["ModifyColumn", "People", "city", {"formula": "$ADDRESS.city"}],
["ModifyColumn", "Address", "people",
{"formula": "','.join(x.ADDRESS.city for x in People.lookupRecords(ADDRESS=$id))"}],
["ModifyColumn", "Address", "people2",
{"formula": "','.join([x.ADDRESS.city for x in People.lookupRecords(ADDRESS=$id)])"}],
["ModifyColumn", "Address", "people3",
{"formula": "','.join({x.ADDRESS.city for x in People.lookupRecords(ADDRESS=$id)})"}],
["ModifyColumn", "Address", "people4",
{"formula": "{x.ADDRESS.city:x.ADDRESS for x in People.lookupRecords(ADDRESS=$id)}"}],
["BulkUpdateRecord", "_grist_Tables_column", [23, 24, 25, 26, 27, 28], {
"colId": ["ADDRESS", "city", "people", "people2", "people3", "people4"],
"formula": ["", "$ADDRESS.city",
"','.join(x.ADDRESS.city for x in People.lookupRecords(ADDRESS=$id))",
"','.join([x.ADDRESS.city for x in People.lookupRecords(ADDRESS=$id)])",
"','.join({x.ADDRESS.city for x in People.lookupRecords(ADDRESS=$id)})",
"{x.ADDRESS.city:x.ADDRESS for x in People.lookupRecords(ADDRESS=$id)}"],
}],
]})
def test_rename_table(self):
# Renaming TABLE.lookupFunc(...) where TABLE is a user-defined table.
self.load_sample(self.sample)
self.add_column("Address", "people", formula="People.lookupRecords(addr=$id)")
self.add_column("Address", "people2", type="Ref:People", formula="People.lookupOne(addr=$id)")
out_actions = self.apply_user_action(["RenameTable", "People", "Persons"])
self.assertPartialOutActions(out_actions, { "stored": [
["ModifyColumn", "Address", "people2", {"type": "Int"}],
["RenameTable", "People", "Persons"],
["UpdateRecord", "_grist_Tables", 2, {"tableId": "Persons"}],
["ModifyColumn", "Address", "people2", {
"type": "Ref:Persons", "formula": "Persons.lookupOne(addr=$id)" }],
["ModifyColumn", "Address", "people", {"formula": "Persons.lookupRecords(addr=$id)"}],
["BulkUpdateRecord", "_grist_Tables_column", [26, 25], {
"type": ["Ref:Persons", "Any"],
"formula": ["Persons.lookupOne(addr=$id)", "Persons.lookupRecords(addr=$id)"]
}],
]})
def test_rename_table_autocomplete(self):
# Renaming a table should not leave the old name available for auto-complete.
self.load_sample(self.sample)
names = {"People", "Persons"}
self.assertEqual(names.intersection(self.engine.autocomplete("Pe", "Address")), {"People"})
# Rename the table and ensure that "People" is no longer present among top-level names.
out_actions = self.apply_user_action(["RenameTable", "People", "Persons"])
self.assertEqual(names.intersection(self.engine.autocomplete("Pe", "Address")), {"Persons"})
def test_rename_to_id(self):
# Check that we renaming a column to "Id" disambiguates it with a suffix.
self.load_sample(self.sample)
out_actions = self.apply_user_action(["RenameColumn", "People", "name", "Id"])
self.assertPartialOutActions(out_actions, { "stored": [
["RenameColumn", "People", "name", "Id2"],
["UpdateRecord", "_grist_Tables_column", 22, {"colId": "Id2"}],
]})
def test_renames_with_non_ascii(self):
# Test that presence of unicode does not interfere with formula adjustments for renaming.
self.load_sample(self.sample)
self.add_column("Address", "CityUpper", formula="'Øî'+$city.upper()+'áü'")
out_actions = self.apply_user_action(["RenameColumn", "Address", "city", "ciudad"])
self.assertPartialOutActions(out_actions, { "stored": [
["RenameColumn", "Address", "city", "ciudad"],
["ModifyColumn", "People", "city", {"formula": "$addr.ciudad"}],
["ModifyColumn", "Address", "CityUpper", {"formula": "'Øî'+$ciudad.upper()+'áü'"}],
["BulkUpdateRecord", "_grist_Tables_column", [21, 24, 25], {
"colId": ["ciudad", "city", "CityUpper"],
"formula": ["", "$addr.ciudad", "'Øî'+$ciudad.upper()+'áü'"],
}]
]})
self.assertTableData("Address", cols="all", data=[
["id", "ciudad", "CityUpper"],
[11, "New York", "ØîNEW YORKáü"],
[12, "Colombia", "ØîCOLOMBIAáü"],
[13, "New Haven", "ØîNEW HAVENáü"],
[14, "West Haven", "ØîWEST HAVENáü"],
])
def test_rename_updates_properties(self):
# This tests for the following bug: a column A of type Any with formula Table1.lookupOne(B=$B)
# will return a correct reference; when column Table1.X is renamed to Y, $A.X will be changed
# to $A.Y correctly. The bug was that the fixed $A.Y formula would fail incorrectly with
# "Table1 has no column 'Y'".
#
# The cause was that Record objects created by $A were not affected by the
# rename, or recomputed after it, and contained a stale list of allowed column names (the fix
# removes reliance on storing column names in the Record class).
self.load_sample(self.sample)
self.add_column("Address", "person", formula="People.lookupOne(addr=$id)")
self.add_column("Address", "name", formula="$person.name")
from datetime import date
# A helper for comparing Record objects below.
people_table = self.engine.tables['People']
people_rec = lambda row_id: people_table.Record(people_table, row_id, None)
# Verify the data and calculations are correct.
self.assertTableData("Address", cols="all", data=[
["id", "city", "person", "name"],
[11, "New York", people_rec(4), "Sam"],
[12, "Colombia", people_rec(1), "Bob"],
[13, "New Haven", people_rec(2), "Alice"],
[14, "West Haven", people_rec(0), ""],
])
# Do the rename.
out_actions = self.apply_user_action(["RenameColumn", "People", "name", "name2"])
self.assertPartialOutActions(out_actions, { "stored": [
["RenameColumn", "People", "name", "name2"],
["ModifyColumn", "Address", "name", {"formula": "$person.name2"}],
["BulkUpdateRecord", "_grist_Tables_column", [22, 26], {
"colId": ["name2", "name"],
"formula": ["", "$person.name2"],
}]
]})
# Verify the data and calculations are correct after the rename.
self.assertTableData("Address", cols="all", data=[
["id", "city", "person", "name"],
[11, "New York", people_rec(4), "Sam"],
[12, "Colombia", people_rec(1), "Bob"],
[13, "New Haven", people_rec(2), "Alice"],
[14, "West Haven", people_rec(0), ""],
])

@ -0,0 +1,396 @@
import logger
import textwrap
import test_engine
log = logger.Logger(__name__, logger.INFO)
def _replace_col_name(data, old_name, new_name):
"""For verifying data, renames a column in the header in-place."""
data[0] = [(new_name if c == old_name else c) for c in data[0]]
class TestRenames2(test_engine.EngineTestCase):
# Another test for column renames, which tests crazier interconnected formulas.
# This one includes a bunch of cases where renames fail, marked as TODOs.
def setUp(self):
super(TestRenames2, self).setUp()
# Create a schema with several tables including some references and lookups.
self.apply_user_action(["AddTable", "People", [
{"id": "name", "type": "Text"}
]])
self.apply_user_action(["AddTable", "Games", [
{"id": "name", "type": "Text"},
{"id": "winner", "type": "Ref:People", "isFormula": True,
"formula": "Entries.lookupOne(game=$id, rank=1).person"},
{"id": "second", "type": "Ref:People", "isFormula": True,
"formula": "Entries.lookupOne(game=$id, rank=2).person"},
]])
self.apply_user_action(["AddTable", "Entries", [
{"id": "game", "type": "Ref:Games"},
{"id": "person", "type": "Ref:People"},
{"id": "rank", "type": "Int"},
]])
# Fill it with some sample data.
self.add_records("People", ["name"], [
["Bob"], ["Alice"], ["Carol"], ["Doug"], ["Eve"]])
self.add_records("Games", ["name"], [
["ChessA"], ["GoA"], ["ChessB"], ["CheckersA"]])
self.add_records("Entries", ["game", "person", "rank"], [
[ 1, 2, 1],
[ 1, 4, 2],
[ 2, 1, 2],
[ 2, 2, 1],
[ 3, 4, 1],
[ 3, 3, 2],
[ 4, 5, 1],
[ 4, 1, 2],
])
# Check the data, to see it, and confirm that lookups work.
self.assertTableData("People", cols="subset", data=[
[ "id", "name" ],
[ 1, "Bob" ],
[ 2, "Alice" ],
[ 3, "Carol" ],
[ 4, "Doug" ],
[ 5, "Eve" ],
])
self.assertTableData("Games", cols="subset", data=[
[ "id", "name" , "winner", "second" ],
[ 1, "ChessA" , 2, 4, ],
[ 2, "GoA" , 2, 1, ],
[ 3, "ChessB" , 4, 3, ],
[ 4, "CheckersA" , 5, 1 ],
])
# This was just setpu. Now create some crazy formulas that overuse referenes in crazy ways.
self.partner_names = textwrap.dedent(
"""
games = Entries.lookupRecords(person=$id).game
partners = [e.person for g in games for e in Entries.lookupRecords(game=g)]
return ' '.join(p.name for p in partners if p.id != $id)
""")
self.partner = textwrap.dedent(
"""
game = Entries.lookupOne(person=$id).game
next(e.person for e in Entries.lookupRecords(game=game) if e.person != rec)
""").strip()
self.add_column("People", "N", formula="$name.upper()")
self.add_column("People", "Games_Won", formula=(
"' '.join(e.game.name for e in Entries.lookupRecords(person=$id, rank=1))"))
self.add_column("People", "PartnerNames", formula=self.partner_names)
self.add_column("People", "partner", type="Ref:People", formula=self.partner)
self.add_column("People", "partner4", type="Ref:People", formula=(
"$partner.partner.partner.partner"))
# Make it hard to follow references by using the same names in different tables.
self.add_column("People", "win", type="Ref:Games",
formula="Entries.lookupOne(person=$id, rank=1).game")
self.add_column("Games", "win", type="Ref:People", formula="$winner")
self.add_column("Games", "win3_person_name", formula="$win.win.win.name")
self.add_column("Games", "win4_game_name", formula="$win.win.win.win.name")
# This is just for help us know which columns have which rowIds.
self.assertTableData("_grist_Tables_column", cols="subset", data=[
[ "id", "parentId", "colId" ],
[ 1, 1, "manualSort" ],
[ 2, 1, "name" ],
[ 3, 2, "manualSort" ],
[ 4, 2, "name" ],
[ 5, 2, "winner" ],
[ 6, 2, "second" ],
[ 7, 3, "manualSort" ],
[ 8, 3, "game" ],
[ 9, 3, "person" ],
[ 10, 3, "rank" ],
[ 11, 1, "N" ],
[ 12, 1, "Games_Won" ],
[ 13, 1, "PartnerNames" ],
[ 14, 1, "partner" ],
[ 15, 1, "partner4" ],
[ 16, 1, "win" ],
[ 17, 2, "win" ],
[ 18, 2, "win3_person_name" ],
[ 19, 2, "win4_game_name" ],
])
# Check the data before we start on the renaming.
self.people_data = [
[ "id", "name" , "N", "Games_Won", "PartnerNames", "partner", "partner4", "win" ],
[ 1, "Bob" , "BOB", "", "Alice Eve" , 2, 4 , 0 ],
[ 2, "Alice", "ALICE", "ChessA GoA", "Doug Bob" , 4, 2 , 1 ],
[ 3, "Carol", "CAROL", "", "Doug" , 4, 2 , 0 ],
[ 4, "Doug" , "DOUG", "ChessB", "Alice Carol" , 2, 4 , 3 ],
[ 5, "Eve" , "EVE", "CheckersA", "Bob" , 1, 2 , 4 ],
]
self.games_data = [
[ "id", "name" , "winner", "second", "win", "win3_person_name", "win4_game_name" ],
[ 1, "ChessA" , 2, 4 , 2 , "Alice" , "ChessA" ],
[ 2, "GoA" , 2, 1 , 2 , "Alice" , "ChessA" ],
[ 3, "ChessB" , 4, 3 , 4 , "Doug" , "ChessB" ],
[ 4, "CheckersA" , 5, 1 , 5 , "Eve" , "CheckersA" ],
]
self.assertTableData("People", cols="subset", data=self.people_data)
self.assertTableData("Games", cols="subset", data=self.games_data)
def test_renames_a(self):
# Rename Entries.game: affects Games.winner, Games.second, People.Games_Won,
# People.PartnerNames, People.partner.
out_actions = self.apply_user_action(["RenameColumn", "Entries", "game", "juego"])
self.partner_names = textwrap.dedent(
"""
games = Entries.lookupRecords(person=$id).juego
partners = [e.person for g in games for e in Entries.lookupRecords(juego=g)]
return ' '.join(p.name for p in partners if p.id != $id)
""")
self.partner = textwrap.dedent(
"""
game = Entries.lookupOne(person=$id).juego
next(e.person for e in Entries.lookupRecords(juego=game) if e.person != rec)
""").strip()
self.assertPartialOutActions(out_actions, { "stored": [
["RenameColumn", "Entries", "game", "juego"],
["ModifyColumn", "Games", "winner",
{"formula": "Entries.lookupOne(juego=$id, rank=1).person"}],
["ModifyColumn", "Games", "second",
{"formula": "Entries.lookupOne(juego=$id, rank=2).person"}],
["ModifyColumn", "People", "Games_Won", {
"formula": "' '.join(e.juego.name for e in Entries.lookupRecords(person=$id, rank=1))"
}],
["ModifyColumn", "People", "PartnerNames", { "formula": self.partner_names }],
["ModifyColumn", "People", "partner", {"formula": self.partner}],
["ModifyColumn", "People", "win",
{"formula": "Entries.lookupOne(person=$id, rank=1).juego"}],
["BulkUpdateRecord", "_grist_Tables_column", [8, 5, 6, 12, 13, 14, 16], {
"colId": ["juego", "winner", "second", "Games_Won", "PartnerNames", "partner", "win"],
"formula": ["",
"Entries.lookupOne(juego=$id, rank=1).person",
"Entries.lookupOne(juego=$id, rank=2).person",
"' '.join(e.juego.name for e in Entries.lookupRecords(person=$id, rank=1))",
self.partner_names,
self.partner,
"Entries.lookupOne(person=$id, rank=1).juego"
]
}],
]})
# Verify data to ensure there are no AttributeErrors.
self.assertTableData("People", cols="subset", data=self.people_data)
self.assertTableData("Games", cols="subset", data=self.games_data)
def test_renames_b(self):
# Rename Games.name: affects People.Games_Won, Games.win4_game_name
# TODO: win4_game_name isn't updated due to astroid avoidance of looking up the same attr on
# the same class during inference.
out_actions = self.apply_user_action(["RenameColumn", "Games", "name", "nombre"])
self.assertPartialOutActions(out_actions, { "stored": [
["RenameColumn", "Games", "name", "nombre"],
["ModifyColumn", "People", "Games_Won", {
"formula": "' '.join(e.game.nombre for e in Entries.lookupRecords(person=$id, rank=1))"
}],
["BulkUpdateRecord", "_grist_Tables_column", [4, 12], {
"colId": ["nombre", "Games_Won"],
"formula": [
"", "' '.join(e.game.nombre for e in Entries.lookupRecords(person=$id, rank=1))"]
}],
]})
# Fix up things missed due to the TODOs above.
self.modify_column("Games", "win4_game_name", formula="$win.win.win.win.nombre")
# Verify data to ensure there are no AttributeErrors.
_replace_col_name(self.games_data, "name", "nombre")
self.assertTableData("People", cols="subset", data=self.people_data)
self.assertTableData("Games", cols="subset", data=self.games_data)
def test_renames_c(self):
# Rename Entries.person: affects People.ParnerNames
out_actions = self.apply_user_action(["RenameColumn", "Entries", "person", "persona"])
self.partner_names = textwrap.dedent(
"""
games = Entries.lookupRecords(persona=$id).game
partners = [e.persona for g in games for e in Entries.lookupRecords(game=g)]
return ' '.join(p.name for p in partners if p.id != $id)
""")
self.partner = textwrap.dedent(
"""
game = Entries.lookupOne(persona=$id).game
next(e.persona for e in Entries.lookupRecords(game=game) if e.persona != rec)
""").strip()
self.assertPartialOutActions(out_actions, { "stored": [
["RenameColumn", "Entries", "person", "persona"],
["ModifyColumn", "Games", "winner",
{"formula": "Entries.lookupOne(game=$id, rank=1).persona"}],
["ModifyColumn", "Games", "second",
{"formula": "Entries.lookupOne(game=$id, rank=2).persona"}],
["ModifyColumn", "People", "Games_Won", {
"formula": "' '.join(e.game.name for e in Entries.lookupRecords(persona=$id, rank=1))"
}],
["ModifyColumn", "People", "PartnerNames", { "formula": self.partner_names }],
["ModifyColumn", "People", "partner", {"formula": self.partner}],
["ModifyColumn", "People", "win",
{"formula": "Entries.lookupOne(persona=$id, rank=1).game"}],
["BulkUpdateRecord", "_grist_Tables_column", [9, 5, 6, 12, 13, 14, 16], {
"colId": ["persona", "winner", "second", "Games_Won", "PartnerNames", "partner", "win"],
"formula": ["",
"Entries.lookupOne(game=$id, rank=1).persona",
"Entries.lookupOne(game=$id, rank=2).persona",
"' '.join(e.game.name for e in Entries.lookupRecords(persona=$id, rank=1))",
self.partner_names,
self.partner,
"Entries.lookupOne(persona=$id, rank=1).game"
]
}],
]})
self.assertTableData("People", cols="subset", data=self.people_data)
self.assertTableData("Games", cols="subset", data=self.games_data)
def test_renames_d(self):
# Rename People.name: affects People.N, People.ParnerNames
# TODO: win3_person_name ($win.win.win.name) does NOT get updated correctly with astroid
# because of a limitation in astroid inference: it refuses to look up the same attr on the
# same class during inference (in order to protect against too much recursion).
# TODO: PartnerNames does NOT get updated correctly because astroid doesn't infer meanings of
# lists very well.
out_actions = self.apply_user_action(["RenameColumn", "People", "name", "nombre"])
self.assertPartialOutActions(out_actions, { "stored": [
["RenameColumn", "People", "name", "nombre"],
["ModifyColumn", "People", "N", {"formula": "$nombre.upper()"}],
["BulkUpdateRecord", "_grist_Tables_column", [2, 11], {
"colId": ["nombre", "N"],
"formula": ["", "$nombre.upper()"]
}]
]})
# Fix up things missed due to the TODOs above.
self.modify_column("Games", "win3_person_name", formula="$win.win.win.nombre")
self.modify_column("People", "PartnerNames",
formula=self.partner_names.replace("name", "nombre"))
_replace_col_name(self.people_data, "name", "nombre")
self.assertTableData("People", cols="subset", data=self.people_data)
self.assertTableData("Games", cols="subset", data=self.games_data)
def test_renames_e(self):
# Rename People.partner: affects People.partner4
# TODO: partner4 ($partner.partner.partner.partner) only gets updated partly because of
# astroid's avoidance of looking up the same attr on the same class during inference.
out_actions = self.apply_user_action(["RenameColumn", "People", "partner", "companero"])
self.assertPartialOutActions(out_actions, { "stored": [
["RenameColumn", "People", "partner", "companero"],
["ModifyColumn", "People", "partner4", {
"formula": "$companero.companero.partner.partner"
}],
["BulkUpdateRecord", "_grist_Tables_column", [14, 15], {
"colId": ["companero", "partner4"],
"formula": [self.partner, "$companero.companero.partner.partner"]
}]
]})
# Fix up things missed due to the TODOs above.
self.modify_column("People", "partner4", formula="$companero.companero.companero.companero")
_replace_col_name(self.people_data, "partner", "companero")
self.assertTableData("People", cols="subset", data=self.people_data)
self.assertTableData("Games", cols="subset", data=self.games_data)
def test_renames_f(self):
# Rename People.win -> People.pwin. Make sure only Game.win is not affected.
out_actions = self.apply_user_action(["RenameColumn", "People", "win", "pwin"])
self.assertPartialOutActions(out_actions, { "stored": [
["RenameColumn", "People", "win", "pwin"],
["ModifyColumn", "Games", "win3_person_name", {"formula": "$win.pwin.win.name"}],
# TODO: the omission of the 4th win's update is due to the same astroid bug mentioned above.
["ModifyColumn", "Games", "win4_game_name", {"formula": "$win.pwin.win.win.name"}],
["BulkUpdateRecord", "_grist_Tables_column", [16, 18, 19], {
"colId": ["pwin", "win3_person_name", "win4_game_name"],
"formula": ["Entries.lookupOne(person=$id, rank=1).game",
"$win.pwin.win.name", "$win.pwin.win.win.name"]}],
]})
# Fix up things missed due to the TODOs above.
self.modify_column("Games", "win4_game_name", formula="$win.pwin.win.pwin.name")
_replace_col_name(self.people_data, "win", "pwin")
self.assertTableData("People", cols="subset", data=self.people_data)
self.assertTableData("Games", cols="subset", data=self.games_data)
def test_renames_g(self):
# Rename Games.win -> Games.gwin.
out_actions = self.apply_user_action(["RenameColumn", "Games", "win", "gwin"])
self.assertPartialOutActions(out_actions, { "stored": [
["RenameColumn", "Games", "win", "gwin"],
["ModifyColumn", "Games", "win3_person_name", {"formula": "$gwin.win.gwin.name"}],
["ModifyColumn", "Games", "win4_game_name", {"formula": "$gwin.win.gwin.win.name"}],
["BulkUpdateRecord", "_grist_Tables_column", [17, 18, 19], {
"colId": ["gwin", "win3_person_name", "win4_game_name"],
"formula": ["$winner", "$gwin.win.gwin.name", "$gwin.win.gwin.win.name"]}],
]})
_replace_col_name(self.games_data, "win", "gwin")
self.assertTableData("People", cols="subset", data=self.people_data)
self.assertTableData("Games", cols="subset", data=self.games_data)
def test_renames_h(self):
# Rename Entries -> Entradas. Affects Games.winner, Games.second, People.Games_Won,
# People.PartnerNames, People.partner, People.win.
out_actions = self.apply_user_action(["RenameTable", "Entries", "Entradas"])
self.partner_names = textwrap.dedent(
"""
games = Entradas.lookupRecords(person=$id).game
partners = [e.person for g in games for e in Entradas.lookupRecords(game=g)]
return ' '.join(p.name for p in partners if p.id != $id)
""")
self.partner = textwrap.dedent(
"""
game = Entradas.lookupOne(person=$id).game
next(e.person for e in Entradas.lookupRecords(game=game) if e.person != rec)
""").strip()
self.assertPartialOutActions(out_actions, { "stored": [
["RenameTable", "Entries", "Entradas"],
["UpdateRecord", "_grist_Tables", 3, {"tableId": "Entradas"}],
["ModifyColumn", "Games", "winner",
{"formula": "Entradas.lookupOne(game=$id, rank=1).person"}],
["ModifyColumn", "Games", "second",
{"formula": "Entradas.lookupOne(game=$id, rank=2).person"}],
["ModifyColumn", "People", "Games_Won", {
"formula": "' '.join(e.game.name for e in Entradas.lookupRecords(person=$id, rank=1))"
}],
["ModifyColumn", "People", "PartnerNames", { "formula": self.partner_names }],
["ModifyColumn", "People", "partner", {"formula": self.partner}],
["ModifyColumn", "People", "win",
{"formula": "Entradas.lookupOne(person=$id, rank=1).game"}],
["BulkUpdateRecord", "_grist_Tables_column", [5, 6, 12, 13, 14, 16], {
"formula": [
"Entradas.lookupOne(game=$id, rank=1).person",
"Entradas.lookupOne(game=$id, rank=2).person",
"' '.join(e.game.name for e in Entradas.lookupRecords(person=$id, rank=1))",
self.partner_names,
self.partner,
"Entradas.lookupOne(person=$id, rank=1).game"
]}],
]})
self.assertTableData("People", cols="subset", data=self.people_data)
self.assertTableData("Games", cols="subset", data=self.games_data)

@ -0,0 +1,119 @@
# This test verifies behavior when a formula produces side effects. The prime example is
# lookupOrAddDerived() function, which adds new records (and is the basis for summary tables).
import objtypes
import test_engine
import testutil
class TestSideEffects(test_engine.EngineTestCase):
address_table_data = [
["id", "city", "state", "amount" ],
[ 21, "New York", "NY" , 1 ],
[ 22, "Albany", "NY" , 2 ],
]
schools_table_data = [
["id", "city" , "name" ],
[1, "Boston" , "MIT" ],
[2, "New York" , "NYU" ],
]
sample = testutil.parse_test_sample({
"SCHEMA": [
[1, "Address", [
[1, "city", "Text", False, "", "", ""],
[2, "state", "Text", False, "", "", ""],
[3, "amount", "Numeric", False, "", "", ""],
]],
[2, "Schools", [
[11, "name", "Text", False, "", "", ""],
[12, "city", "Text", False, "", "", ""],
]],
],
"DATA": {
"Address": address_table_data,
"Schools": schools_table_data,
}
})
def test_failure_after_side_effect(self):
# Verify that when a formula fails after a side-effect, the effect is reverted.
self.load_sample(self.sample)
formula = 'Schools.lookupOrAddDerived(city="TESTCITY")\nraise Exception("test-error")'
out_actions = self.apply_user_action(['AddColumn', 'Address', "A", { 'formula': formula }])
self.assertPartialOutActions(out_actions, { "stored": [
["AddColumn", "Address", "A", {"formula": formula, "isFormula": True, "type": "Any"}],
["AddRecord", "_grist_Tables_column", 13, {
"colId": "A", "formula": formula, "isFormula": True, "label": "A",
"parentId": 1, "parentPos": 4.0, "type": "Any", "widgetOptions": ""
}],
# The thing to note here is that while lookupOrAddDerived() should have added a row to
# Schools, the Exception negated it, and there is no action to add that row.
]})
# Check that data is as expected: no new records in Schools, one new column in Address.
self.assertTableData('Schools', cols="all", data=self.schools_table_data)
self.assertTableData('Address', cols="all", data=[
["id", "city", "state", "amount", "A" ],
[ 21, "New York", "NY" , 1, objtypes.RaisedException(Exception()) ],
[ 22, "Albany", "NY" , 2, objtypes.RaisedException(Exception()) ],
])
def test_calc_actions_in_side_effect_rollback(self):
self.load_sample(self.sample)
# Formula which allows a side effect to be conditionally rolled back.
formula = '''
Schools.lookupOrAddDerived(city=$city)
if $amount < 0:
raise Exception("test-error")
'''
self.add_column('Schools', 'ucity', formula='$city.upper()')
self.add_column('Address', 'A', formula=formula)
self.assertTableData('Schools', cols="all", data=[
["id", "city", "name", "ucity"],
[1, "Boston", "MIT", "BOSTON"],
[2, "New York", "NYU", "NEW YORK"],
[3, "Albany", "", "ALBANY"],
])
# Check that a successful side-effect which adds a row triggers calc actions for that row.
out_actions = self.update_record('Address', 22, city="aaa", amount=1000)
self.assertPartialOutActions(out_actions, {
"stored": [
["UpdateRecord", "Address", 22, {"amount": 1000.0, "city": "aaa"}],
["AddRecord", "Schools", 4, {"city": "aaa"}],
],
"calc": [
["UpdateRecord", "Schools", 4, {"ucity": "AAA"}],
],
})
self.assertTableData('Schools', cols="all", data=[
["id", "city", "name", "ucity"],
[1, "Boston", "MIT", "BOSTON"],
[2, "New York", "NYU", "NEW YORK"],
[3, "Albany", "", "ALBANY"],
[4, "aaa", "", "AAA"],
])
# Check that a side effect that failed and got rolled back does not include calc actions for
# the rows that didn't stay.
out_actions = self.update_record('Address', 22, city="bbb", amount=-3)
self.assertPartialOutActions(out_actions, {
"stored": [
["UpdateRecord", "Address", 22, {"amount": -3.0, "city": "bbb"}],
],
"calc": [
["UpdateRecord", "Address", 22, {"A": ["E", "Exception"]}],
],
})
self.assertTableData('Schools', cols="all", data=[
["id", "city", "name", "ucity"],
[1, "Boston", "MIT", "BOSTON"],
[2, "New York", "NYU", "NEW YORK"],
[3, "Albany", "", "ALBANY"],
[4, "aaa", "", "AAA"],
])

@ -0,0 +1,843 @@
"""
Test of Summary tables. This has many test cases, so to keep files smaller, it's split into two
files: test_summary.py and test_summary2.py.
"""
import actions
import logger
import summary
import testutil
import test_engine
from test_engine import Table, Column, View, Section, Field
log = logger.Logger(__name__, logger.INFO)
class TestSummary(test_engine.EngineTestCase):
sample = testutil.parse_test_sample({
"SCHEMA": [
[1, "Address", [
[11, "city", "Text", False, "", "City", ""],
[12, "state", "Text", False, "", "State", "WidgetOptions1"],
[13, "amount", "Numeric", False, "", "Amount", "WidgetOptions2"],
]]
],
"DATA": {
"Address": [
["id", "city", "state", "amount" ],
[ 21, "New York", "NY" , 1. ],
[ 22, "Albany", "NY" , 2. ],
[ 23, "Seattle", "WA" , 3. ],
[ 24, "Chicago", "IL" , 4. ],
[ 25, "Bedford", "MA" , 5. ],
[ 26, "New York", "NY" , 6. ],
[ 27, "Buffalo", "NY" , 7. ],
[ 28, "Bedford", "NY" , 8. ],
[ 29, "Boston", "MA" , 9. ],
[ 30, "Yonkers", "NY" , 10. ],
[ 31, "New York", "NY" , 11. ],
]
}
})
starting_table = Table(1, "Address", primaryViewId=0, summarySourceTable=0, columns=[
Column(11, "city", "Text", isFormula=False, formula="", summarySourceCol=0),
Column(12, "state", "Text", isFormula=False, formula="", summarySourceCol=0),
Column(13, "amount", "Numeric", isFormula=False, formula="", summarySourceCol=0),
])
starting_table_data = [
["id", "city", "state", "amount" ],
[ 21, "New York", "NY" , 1 ],
[ 22, "Albany", "NY" , 2 ],
[ 23, "Seattle", "WA" , 3 ],
[ 24, "Chicago", "IL" , 4 ],
[ 25, "Bedford", "MA" , 5 ],
[ 26, "New York", "NY" , 6 ],
[ 27, "Buffalo", "NY" , 7 ],
[ 28, "Bedford", "NY" , 8 ],
[ 29, "Boston", "MA" , 9 ],
[ 30, "Yonkers", "NY" , 10 ],
[ 31, "New York", "NY" , 11 ],
]
#----------------------------------------------------------------------
def test_encode_summary_table_name(self):
self.assertEqual(summary.encode_summary_table_name("Foo"), "GristSummary_3_Foo")
self.assertEqual(summary.encode_summary_table_name("Foo2"), "GristSummary_4_Foo2")
self.assertEqual(summary.decode_summary_table_name("GristSummary_3_Foo"), "Foo")
self.assertEqual(summary.decode_summary_table_name("GristSummary_4_Foo2"), "Foo2")
self.assertEqual(summary.decode_summary_table_name("GristSummary_3_Foo2"), "Foo")
self.assertEqual(summary.decode_summary_table_name("GristSummary_4_Foo2_2"), "Foo2")
# Test that underscore in the name is OK.
self.assertEqual(summary.decode_summary_table_name("GristSummary_5_Foo_234"), "Foo_2")
self.assertEqual(summary.decode_summary_table_name("GristSummary_4_Foo_234"), "Foo_")
self.assertEqual(summary.decode_summary_table_name("GristSummary_6__Foo_234"), "_Foo_2")
# Test that we return None for invalid values.
self.assertEqual(summary.decode_summary_table_name("Foo2"), None)
self.assertEqual(summary.decode_summary_table_name("GristSummary_3Foo"), None)
self.assertEqual(summary.decode_summary_table_name("GristSummary_4_Foo"), None)
self.assertEqual(summary.decode_summary_table_name("GristSummary_3X_Foo"), None)
self.assertEqual(summary.decode_summary_table_name("_5_Foo_234"), None)
self.assertEqual(summary.decode_summary_table_name("_GristSummary_3_Foo"), None)
self.assertEqual(summary.decode_summary_table_name("gristsummary_3_Foo"), None)
self.assertEqual(summary.decode_summary_table_name("GristSummary3_Foo"), None)
#----------------------------------------------------------------------
def test_create_view_section(self):
self.load_sample(self.sample)
# Verify the starting table; there should be no views yet.
self.assertTables([self.starting_table])
self.assertViews([])
# Create a view + section for the initial table.
self.apply_user_action(["CreateViewSection", 1, 0, "record", None])
# Verify that we got a new view, with one section, and three fields.
self.assertTables([self.starting_table])
basic_view = View(1, sections=[
Section(1, parentKey="record", tableRef=1, fields=[
Field(1, colRef=11),
Field(2, colRef=12),
Field(3, colRef=13),
])
])
self.assertViews([basic_view])
self.assertTableData("Address", self.starting_table_data)
# Create a "Totals" section, i.e. a summary with no group-by columns.
self.apply_user_action(["CreateViewSection", 1, 0, "record", []])
# Verify that a new table gets created, and a new view, with a section for that table,
# and some auto-generated summary fields.
summary_table1 = Table(2, "GristSummary_7_Address", primaryViewId=0, summarySourceTable=1,
columns=[
Column(14, "group", "RefList:Address", isFormula=True, summarySourceCol=0,
formula="table.getSummarySourceGroup(rec)"),
Column(15, "count", "Int", isFormula=True, summarySourceCol=0,
formula="len($group)"),
Column(16, "amount", "Numeric", isFormula=True, summarySourceCol=0,
formula="SUM($group.amount)"),
])
summary_view1 = View(2, sections=[
Section(2, parentKey="record", tableRef=2, fields=[
Field(4, colRef=15),
Field(5, colRef=16),
])
])
self.assertTables([self.starting_table, summary_table1])
self.assertViews([basic_view, summary_view1])
# Verify the summarized data.
self.assertTableData('GristSummary_7_Address', cols="subset", data=[
[ "id", "count", "amount"],
[ 1, 11, 66.0 ],
])
# Create a summary section, grouped by the "State" column.
self.apply_user_action(["CreateViewSection", 1, 0, "record", [12]])
# Verify that a new table gets created again, a new view, and a section for that table.
# Note that we also check that summarySourceTable and summarySourceCol fields are correct.
summary_table2 = Table(3, "GristSummary_7_Address2", primaryViewId=0, summarySourceTable=1,
columns=[
Column(17, "state", "Text", isFormula=False, formula="", summarySourceCol=12),
Column(18, "group", "RefList:Address", isFormula=True, summarySourceCol=0,
formula="table.getSummarySourceGroup(rec)"),
Column(19, "count", "Int", isFormula=True, summarySourceCol=0,
formula="len($group)"),
Column(20, "amount", "Numeric", isFormula=True, summarySourceCol=0,
formula="SUM($group.amount)"),
])
summary_view2 = View(3, sections=[
Section(3, parentKey="record", tableRef=3, fields=[
Field(6, colRef=17),
Field(7, colRef=19),
Field(8, colRef=20),
])
])
self.assertTables([self.starting_table, summary_table1, summary_table2])
self.assertViews([basic_view, summary_view1, summary_view2])
# Verify more fields of the new column objects.
self.assertTableData('_grist_Tables_column', rows="subset", cols="subset", data=[
['id', 'colId', 'type', 'formula', 'widgetOptions', 'label'],
[17, 'state', 'Text', '', 'WidgetOptions1', 'State'],
[20, 'amount', 'Numeric', 'SUM($group.amount)', 'WidgetOptions2', 'Amount'],
])
# Verify the summarized data.
self.assertTableData('GristSummary_7_Address2', cols="subset", data=[
[ "id", "state", "count", "amount" ],
[ 1, "NY", 7, 1.+2+6+7+8+10+11 ],
[ 2, "WA", 1, 3. ],
[ 3, "IL", 1, 4. ],
[ 4, "MA", 2, 5.+9 ],
])
# Create a summary section grouped by two columns ("city" and "state").
self.apply_user_action(["CreateViewSection", 1, 0, "record", [11,12]])
# Verify the new table and views.
summary_table3 = Table(4, "GristSummary_7_Address3", primaryViewId=0, summarySourceTable=1,
columns=[
Column(21, "city", "Text", isFormula=False, formula="", summarySourceCol=11),
Column(22, "state", "Text", isFormula=False, formula="", summarySourceCol=12),
Column(23, "group", "RefList:Address", isFormula=True, summarySourceCol=0,
formula="table.getSummarySourceGroup(rec)"),
Column(24, "count", "Int", isFormula=True, summarySourceCol=0,
formula="len($group)"),
Column(25, "amount", "Numeric", isFormula=True, summarySourceCol=0,
formula="SUM($group.amount)"),
])
summary_view3 = View(4, sections=[
Section(4, parentKey="record", tableRef=4, fields=[
Field(9, colRef=21),
Field(10, colRef=22),
Field(11, colRef=24),
Field(12, colRef=25),
])
])
self.assertTables([self.starting_table, summary_table1, summary_table2, summary_table3])
self.assertViews([basic_view, summary_view1, summary_view2, summary_view3])
# Verify the summarized data.
self.assertTableData('GristSummary_7_Address3', cols="subset", data=[
[ "id", "city", "state", "count", "amount" ],
[ 1, "New York", "NY" , 3, 1.+6+11 ],
[ 2, "Albany", "NY" , 1, 2. ],
[ 3, "Seattle", "WA" , 1, 3. ],
[ 4, "Chicago", "IL" , 1, 4. ],
[ 5, "Bedford", "MA" , 1, 5. ],
[ 6, "Buffalo", "NY" , 1, 7. ],
[ 7, "Bedford", "NY" , 1, 8. ],
[ 8, "Boston", "MA" , 1, 9. ],
[ 9, "Yonkers", "NY" , 1, 10. ],
])
# The original table's data should not have changed.
self.assertTableData("Address", self.starting_table_data)
#----------------------------------------------------------------------
def test_summary_gencode(self):
self.maxDiff = 1000 # If there is a discrepancy, allow the bigger diff.
self.load_sample(self.sample)
self.apply_user_action(["CreateViewSection", 1, 0, "record", []])
self.apply_user_action(["CreateViewSection", 1, 0, "record", [11,12]])
self.assertMultiLineEqual(self.engine.fetch_table_schema(),
"""import grist
from functions import * # global uppercase functions
import datetime, math, re # modules commonly needed in formulas
@grist.UserTable
class Address:
city = grist.Text()
state = grist.Text()
amount = grist.Numeric()
class _Summary:
@grist.formulaType(grist.ReferenceList('Address'))
def group(rec, table):
return table.getSummarySourceGroup(rec)
@grist.formulaType(grist.Int())
def count(rec, table):
return len(rec.group)
@grist.formulaType(grist.Numeric())
def amount(rec, table):
return SUM(rec.group.amount)
""")
#----------------------------------------------------------------------
def test_summary_table_reuse(self):
# Test that we'll reuse a suitable summary table when already available.
self.load_sample(self.sample)
# Create a summary section grouped by two columns ("city" and "state").
self.apply_user_action(["CreateViewSection", 1, 0, "record", [11,12]])
# Verify the new table and views.
summary_table = Table(2, "GristSummary_7_Address", primaryViewId=0, summarySourceTable=1,
columns=[
Column(14, "city", "Text", isFormula=False, formula="", summarySourceCol=11),
Column(15, "state", "Text", isFormula=False, formula="", summarySourceCol=12),
Column(16, "group", "RefList:Address", isFormula=True, summarySourceCol=0,
formula="table.getSummarySourceGroup(rec)"),
Column(17, "count", "Int", isFormula=True, summarySourceCol=0,
formula="len($group)"),
Column(18, "amount", "Numeric", isFormula=True, summarySourceCol=0,
formula="SUM($group.amount)"),
])
summary_view = View(1, sections=[
Section(1, parentKey="record", tableRef=2, fields=[
Field(1, colRef=14),
Field(2, colRef=15),
Field(3, colRef=17),
Field(4, colRef=18),
])
])
self.assertTables([self.starting_table, summary_table])
self.assertViews([summary_view])
# Create twoo other views + view sections with the same breakdown (in different order
# of group-by fields, which should still reuse the same table).
self.apply_user_action(["CreateViewSection", 1, 0, "record", [12,11]])
self.apply_user_action(["CreateViewSection", 1, 0, "record", [11,12]])
summary_view2 = View(2, sections=[
Section(2, parentKey="record", tableRef=2, fields=[
Field(5, colRef=15),
Field(6, colRef=14),
Field(7, colRef=17),
Field(8, colRef=18),
])
])
summary_view3 = View(3, sections=[
Section(3, parentKey="record", tableRef=2, fields=[
Field(9, colRef=14),
Field(10, colRef=15),
Field(11, colRef=17),
Field(12, colRef=18),
])
])
# Verify that we have a new view, but are reusing the table.
self.assertTables([self.starting_table, summary_table])
self.assertViews([summary_view, summary_view2, summary_view3])
# Verify the summarized data.
self.assertTableData('GristSummary_7_Address', cols="subset", data=[
[ "id", "city", "state", "count", "amount" ],
[ 1, "New York", "NY" , 3, 1.+6+11 ],
[ 2, "Albany", "NY" , 1, 2. ],
[ 3, "Seattle", "WA" , 1, 3. ],
[ 4, "Chicago", "IL" , 1, 4. ],
[ 5, "Bedford", "MA" , 1, 5. ],
[ 6, "Buffalo", "NY" , 1, 7. ],
[ 7, "Bedford", "NY" , 1, 8. ],
[ 8, "Boston", "MA" , 1, 9. ],
[ 9, "Yonkers", "NY" , 1, 10. ],
])
#----------------------------------------------------------------------
def test_summary_no_invalid_reuse(self):
# Verify that if we have some summary tables for one table, they don't mistakenly get used
# when we need a summary for another table.
# Load table and create a couple summary sections, for totals, and grouped by "state".
self.load_sample(self.sample)
self.apply_user_action(["CreateViewSection", 1, 0, "record", []])
self.apply_user_action(["CreateViewSection", 1, 0, "record", [12]])
self.assertTables([
self.starting_table,
Table(2, "GristSummary_7_Address", 0, 1, columns=[
Column(14, "group", "RefList:Address", True, "table.getSummarySourceGroup(rec)", 0),
Column(15, "count", "Int", True, "len($group)", 0),
Column(16, "amount", "Numeric", True, "SUM($group.amount)", 0),
]),
Table(3, "GristSummary_7_Address2", 0, 1, columns=[
Column(17, "state", "Text", False, "", 12),
Column(18, "group", "RefList:Address", True, "table.getSummarySourceGroup(rec)", 0),
Column(19, "count", "Int", True, "len($group)", 0),
Column(20, "amount", "Numeric", True, "SUM($group.amount)", 0),
]),
])
# Create another table similar to the first one.
self.apply_user_action(["AddTable", "Address2", [
{ "id": "city", "type": "Text" },
{ "id": "state", "type": "Text" },
{ "id": "amount", "type": "Numeric" },
]])
data = self.sample["DATA"]["Address"]
self.apply_user_action(["BulkAddRecord", "Address2", data.row_ids, data.columns])
# Check that we've loaded the right data, and have the new table.
self.assertTableData("Address", cols="subset", data=self.starting_table_data)
self.assertTableData("Address2", cols="subset", data=self.starting_table_data)
self.assertTableData("_grist_Tables", cols="subset", data=[
['id', 'tableId', 'summarySourceTable'],
[ 1, 'Address', 0],
[ 2, 'GristSummary_7_Address', 1],
[ 3, 'GristSummary_7_Address2', 1],
[ 4, 'Address2', 0],
])
# Now create similar summary sections for the new table.
self.apply_user_action(["CreateViewSection", 4, 0, "record", []])
self.apply_user_action(["CreateViewSection", 4, 0, "record", [23]])
# Make sure this creates new section rather than reuses similar ones for the wrong table.
self.assertTables([
self.starting_table,
Table(2, "GristSummary_7_Address", 0, 1, columns=[
Column(14, "group", "RefList:Address", True, "table.getSummarySourceGroup(rec)", 0),
Column(15, "count", "Int", True, "len($group)", 0),
Column(16, "amount", "Numeric", True, "SUM($group.amount)", 0),
]),
Table(3, "GristSummary_7_Address2", 0, 1, columns=[
Column(17, "state", "Text", False, "", 12),
Column(18, "group", "RefList:Address", True, "table.getSummarySourceGroup(rec)", 0),
Column(19, "count", "Int", True, "len($group)", 0),
Column(20, "amount", "Numeric", True, "SUM($group.amount)", 0),
]),
Table(4, "Address2", primaryViewId=3, summarySourceTable=0, columns=[
Column(21, "manualSort", "ManualSortPos",False, "", 0),
Column(22, "city", "Text", False, "", 0),
Column(23, "state", "Text", False, "", 0),
Column(24, "amount", "Numeric", False, "", 0),
]),
Table(5, "GristSummary_8_Address2", 0, 4, columns=[
Column(25, "group", "RefList:Address2", True, "table.getSummarySourceGroup(rec)", 0),
Column(26, "count", "Int", True, "len($group)", 0),
Column(27, "amount", "Numeric", True, "SUM($group.amount)", 0),
]),
Table(6, "GristSummary_8_Address2_2", 0, 4, columns=[
Column(28, "state", "Text", False, "", 23),
Column(29, "group", "RefList:Address2", True, "table.getSummarySourceGroup(rec)", 0),
Column(30, "count", "Int", True, "len($group)", 0),
Column(31, "amount", "Numeric", True, "SUM($group.amount)", 0),
]),
])
#----------------------------------------------------------------------
def test_summary_updates(self):
# Verify that summary tables update automatically when we change a value used in a summary
# formula; or a value in a group-by column; or add/remove a record; that records get
# auto-added when new group-by combinations appear.
# Load sample and create a summary section grouped by two columns ("city" and "state").
self.load_sample(self.sample)
self.apply_user_action(["CreateViewSection", 1, 0, "record", [11,12]])
# Verify that the summary table respects all updates to the source table.
self._do_test_updates("Address", "GristSummary_7_Address")
def _do_test_updates(self, source_tbl_name, summary_tbl_name):
# This is the main part of test_summary_updates(). It's moved to its own method so that
# updates can be verified the same way after a table rename.
# Verify the summarized data.
self.assertTableData(summary_tbl_name, cols="subset", data=[
[ "id", "city", "state", "count", "amount" ],
[ 1, "New York", "NY" , 3, 1.+6+11 ],
[ 2, "Albany", "NY" , 1, 2. ],
[ 3, "Seattle", "WA" , 1, 3. ],
[ 4, "Chicago", "IL" , 1, 4. ],
[ 5, "Bedford", "MA" , 1, 5. ],
[ 6, "Buffalo", "NY" , 1, 7. ],
[ 7, "Bedford", "NY" , 1, 8. ],
[ 8, "Boston", "MA" , 1, 9. ],
[ 9, "Yonkers", "NY" , 1, 10. ],
])
# Change an amount (New York, NY, 6 -> 106), check that the right calc action gets emitted.
out_actions = self.update_record(source_tbl_name, 26, amount=106)
self.assertPartialOutActions(out_actions, {
"stored": [actions.UpdateRecord(source_tbl_name, 26, {'amount': 106})],
"calc": [actions.UpdateRecord(summary_tbl_name, 1, {'amount': 1.+106+11})]
})
# Change a groupby value so that a record moves from one summary group to another.
# Bedford, NY, 8.0 -> Bedford, MA, 8.0
out_actions = self.update_record(source_tbl_name, 28, state="MA")
self.assertPartialOutActions(out_actions, {
"stored": [actions.UpdateRecord(source_tbl_name, 28, {'state': 'MA'})],
"calc": [
actions.BulkUpdateRecord(summary_tbl_name, [5,7], {'group': [[25, 28], []]}),
actions.BulkUpdateRecord(summary_tbl_name, [5,7], {'amount': [5.0 + 8.0, 0.0]}),
actions.BulkUpdateRecord(summary_tbl_name, [5,7], {'count': [2, 0]}),
]
})
# Add a record to an existing group (Bedford, MA, 108.0)
out_actions = self.add_record(source_tbl_name, city="Bedford", state="MA", amount=108.0)
self.assertPartialOutActions(out_actions, {
"stored": [actions.AddRecord(source_tbl_name, 32,
{'city': 'Bedford', 'state': 'MA', 'amount': 108.0})],
"calc": [
actions.UpdateRecord(summary_tbl_name, 5, {'group': [25, 28, 32]}),
actions.UpdateRecord(summary_tbl_name, 5, {'amount': 5.0 + 8.0 + 108.0}),
actions.UpdateRecord(summary_tbl_name, 5, {'count': 3}),
]
})
# Remove a record (rowId=28, Bedford, MA, 8.0)
out_actions = self.remove_record(source_tbl_name, 28)
self.assertPartialOutActions(out_actions, {
"stored": [actions.RemoveRecord(source_tbl_name, 28)],
"calc": [
actions.UpdateRecord(summary_tbl_name, 5, {'group': [25, 32]}),
actions.UpdateRecord(summary_tbl_name, 5, {'amount': 5.0 + 108.0}),
actions.UpdateRecord(summary_tbl_name, 5, {'count': 2}),
]
})
# Change groupby value to create a new combination (rowId 25, Bedford, MA, 5.0 -> Salem, MA).
# A new summary record should be added.
out_actions = self.update_record(source_tbl_name, 25, city="Salem")
self.assertPartialOutActions(out_actions, {
"stored": [
actions.UpdateRecord(source_tbl_name, 25, {'city': 'Salem'}),
actions.AddRecord(summary_tbl_name, 10, {'city': 'Salem', 'state': 'MA'}),
],
"calc": [
actions.BulkUpdateRecord(summary_tbl_name, [5,10], {'group': [[32], [25]]}),
actions.BulkUpdateRecord(summary_tbl_name, [5,10], {'amount': [108.0, 5.0]}),
actions.BulkUpdateRecord(summary_tbl_name, [5,10], {'count': [1, 1]}),
]
})
# Add a record with a new combination (Amherst, MA, 17)
out_actions = self.add_record(source_tbl_name, city="Amherst", state="MA", amount=17.0)
self.assertPartialOutActions(out_actions, {
"stored": [
actions.AddRecord(source_tbl_name, 33, {'city': 'Amherst', 'state': 'MA', 'amount': 17.}),
actions.AddRecord(summary_tbl_name, 11, {'city': 'Amherst', 'state': 'MA'}),
],
"calc": [
actions.UpdateRecord(summary_tbl_name, 11, {'group': [33]}),
actions.UpdateRecord(summary_tbl_name, 11, {'amount': 17.0}),
actions.UpdateRecord(summary_tbl_name, 11, {'count': 1}),
]
})
# Verify the resulting data after all the updates.
self.assertTableData(summary_tbl_name, cols="subset", data=[
[ "id", "city", "state", "count", "amount" ],
[ 1, "New York", "NY" , 3, 1.+106+11 ],
[ 2, "Albany", "NY" , 1, 2. ],
[ 3, "Seattle", "WA" , 1, 3. ],
[ 4, "Chicago", "IL" , 1, 4. ],
[ 5, "Bedford", "MA" , 1, 108. ],
[ 6, "Buffalo", "NY" , 1, 7. ],
[ 7, "Bedford", "NY" , 0, 0. ],
[ 8, "Boston", "MA" , 1, 9. ],
[ 9, "Yonkers", "NY" , 1, 10. ],
[ 10, "Salem", "MA" , 1, 5.0 ],
[ 11, "Amherst", "MA" , 1, 17.0 ],
])
#----------------------------------------------------------------------
def test_table_rename(self):
# Verify that summary tables keep working and updating when source table is renamed.
# Load sample and create a couple of summary sections.
self.load_sample(self.sample)
self.apply_user_action(["CreateViewSection", 1, 0, "record", [11,12]])
# Check what tables we have now.
self.assertPartialData("_grist_Tables", ["id", "tableId", "summarySourceTable"], [
[1, "Address", 0],
[2, "GristSummary_7_Address", 1],
])
# Rename the table: this is what we are really testing in this test case.
self.apply_user_action(["RenameTable", "Address", "Location"])
self.assertPartialData("_grist_Tables", ["id", "tableId", "summarySourceTable"], [
[1, "Location", 0],
[2, "GristSummary_8_Location", 1],
])
# Verify that the bigger summary table respects all updates to the renamed source table.
self._do_test_updates("Location", "GristSummary_8_Location")
#----------------------------------------------------------------------
def test_table_rename_multiple(self):
# Similar to the above, verify renames, but now with two summary tables.
self.load_sample(self.sample)
self.apply_user_action(["CreateViewSection", 1, 0, "record", [11,12]])
self.apply_user_action(["CreateViewSection", 1, 0, "record", []])
self.assertPartialData("_grist_Tables", ["id", "tableId", "summarySourceTable"], [
[1, "Address", 0],
[2, "GristSummary_7_Address", 1],
[3, "GristSummary_7_Address2", 1],
])
# Verify the data in the simple totals-only summary table.
self.assertTableData('GristSummary_7_Address2', cols="subset", data=[
[ "id", "count", "amount"],
[ 1, 11, 66.0 ],
])
# Do a rename.
self.apply_user_action(["RenameTable", "Address", "Addresses"])
self.assertPartialData("_grist_Tables", ["id", "tableId", "summarySourceTable"], [
[1, "Addresses", 0],
[2, "GristSummary_9_Addresses", 1],
[3, "GristSummary_9_Addresses2", 1],
])
self.assertTableData('GristSummary_9_Addresses2', cols="subset", data=[
[ "id", "count", "amount"],
[ 1, 11, 66.0 ],
])
# Remove one of the tables so that we can use _do_test_updates to verify updates still work.
self.apply_user_action(["RemoveTable", "GristSummary_9_Addresses2"])
self.assertPartialData("_grist_Tables", ["id", "tableId", "summarySourceTable"], [
[1, "Addresses", 0],
[2, "GristSummary_9_Addresses", 1],
])
self._do_test_updates("Addresses", "GristSummary_9_Addresses")
#----------------------------------------------------------------------
def test_change_summary_formula(self):
# Verify that changing a summary formula affects all group-by variants, and adding a new
# summary table gets the changed formula.
#
# (Recall that all summaries of a single table are *conceptually* variants of a single summary
# table, sharing all formulas and differing only in the group-by columns.)
self.load_sample(self.sample)
self.apply_user_action(["CreateViewSection", 1, 0, "record", [11,12]])
self.apply_user_action(["CreateViewSection", 1, 0, "record", []])
# These are the tables and columns we automatically get.
self.assertTables([
self.starting_table,
Table(2, "GristSummary_7_Address", 0, 1, columns=[
Column(14, "city", "Text", False, "", 11),
Column(15, "state", "Text", False, "", 12),
Column(16, "group", "RefList:Address", True, "table.getSummarySourceGroup(rec)", 0),
Column(17, "count", "Int", True, "len($group)", 0),
Column(18, "amount", "Numeric", True, "SUM($group.amount)", 0),
]),
Table(3, "GristSummary_7_Address2", 0, 1, columns=[
Column(19, "group", "RefList:Address", True, "table.getSummarySourceGroup(rec)", 0),
Column(20, "count", "Int", True, "len($group)", 0),
Column(21, "amount", "Numeric", True, "SUM($group.amount)", 0),
])
])
# Now change a formula using one of the summary tables. It should trigger an equivalent
# change in the other.
self.apply_user_action(["ModifyColumn", "GristSummary_7_Address", "amount",
{"formula": "10*sum($group.amount)"}])
self.assertTableData('_grist_Tables_column', rows="subset", cols="subset", data=[
['id', 'colId', 'type', 'formula', 'widgetOptions', 'label'],
[18, 'amount', 'Numeric', '10*sum($group.amount)', 'WidgetOptions2', 'Amount'],
[21, 'amount', 'Numeric', '10*sum($group.amount)', 'WidgetOptions2', 'Amount'],
])
# Change a formula and a few other fields in the other table, and verify a change to both.
self.apply_user_action(["ModifyColumn", "GristSummary_7_Address2", "amount",
{"formula": "100*sum($group.amount)",
"type": "Text",
"widgetOptions": "hello",
"label": "AMOUNT",
"untieColIdFromLabel": True
}])
self.assertTableData('_grist_Tables_column', rows="subset", cols="subset", data=[
['id', 'colId', 'type', 'formula', 'widgetOptions', 'label'],
[18, 'amount', 'Text', '100*sum($group.amount)', 'hello', 'AMOUNT'],
[21, 'amount', 'Text', '100*sum($group.amount)', 'hello', 'AMOUNT'],
])
# Check the values in the summary tables: they should reflect the new formula.
self.assertTableData('GristSummary_7_Address', cols="subset", data=[
[ "id", "city", "state", "count", "amount" ],
[ 1, "New York", "NY" , 3, str(100*(1.+6+11))],
[ 2, "Albany", "NY" , 1, "200.0" ],
[ 3, "Seattle", "WA" , 1, "300.0" ],
[ 4, "Chicago", "IL" , 1, "400.0" ],
[ 5, "Bedford", "MA" , 1, "500.0" ],
[ 6, "Buffalo", "NY" , 1, "700.0" ],
[ 7, "Bedford", "NY" , 1, "800.0" ],
[ 8, "Boston", "MA" , 1, "900.0" ],
[ 9, "Yonkers", "NY" , 1, "1000.0" ],
])
self.assertTableData('GristSummary_7_Address2', cols="subset", data=[
[ "id", "count", "amount"],
[ 1, 11, "6600.0"],
])
# Add a new summary table, and check that it gets the new formula.
self.apply_user_action(["CreateViewSection", 1, 0, "record", [12]])
self.assertTables([
self.starting_table,
Table(2, "GristSummary_7_Address", 0, 1, columns=[
Column(14, "city", "Text", False, "", 11),
Column(15, "state", "Text", False, "", 12),
Column(16, "group", "RefList:Address", True, "table.getSummarySourceGroup(rec)", 0),
Column(17, "count", "Int", True, "len($group)", 0),
Column(18, "amount", "Text", True, "100*sum($group.amount)", 0),
]),
Table(3, "GristSummary_7_Address2", 0, 1, columns=[
Column(19, "group", "RefList:Address", True, "table.getSummarySourceGroup(rec)", 0),
Column(20, "count", "Int", True, "len($group)", 0),
Column(21, "amount", "Text", True, "100*sum($group.amount)", 0),
]),
Table(4, "GristSummary_7_Address3", 0, 1, columns=[
Column(22, "state", "Text", False, "", 12),
Column(23, "group", "RefList:Address", True, "table.getSummarySourceGroup(rec)", 0),
Column(24, "count", "Int", True, "len($group)", 0),
Column(25, "amount", "Text", True, "100*sum($group.amount)", 0),
])
])
self.assertTableData('_grist_Tables_column', rows="subset", cols="subset", data=[
['id', 'colId', 'type', 'formula', 'widgetOptions', 'label'],
[18, 'amount', 'Text', '100*sum($group.amount)', 'hello', 'AMOUNT'],
[21, 'amount', 'Text', '100*sum($group.amount)', 'hello', 'AMOUNT'],
[25, 'amount', 'Text', '100*sum($group.amount)', 'hello', 'AMOUNT'],
])
# Verify the summarized data.
self.assertTableData('GristSummary_7_Address3', cols="subset", data=[
[ "id", "state", "count", "amount" ],
[ 1, "NY", 7, str(100*(1.+2+6+7+8+10+11)) ],
[ 2, "WA", 1, "300.0" ],
[ 3, "IL", 1, "400.0" ],
[ 4, "MA", 2, str(500.+900) ],
])
#----------------------------------------------------------------------
def test_convert_source_column(self):
# Verify that we can convert the type of a column when there is a summary table using that
# column to group by. Since converting generates extra summary records, this may cause bugs.
self.apply_user_action(["AddEmptyTable"])
self.apply_user_action(["BulkAddRecord", "Table1", [None]*3, {"A": [10,20,10], "B": [1,2,3]}])
self.apply_user_action(["CreateViewSection", 1, 0, "record", [2]])
# Verify metadata and actual data initially.
self.assertTables([
Table(1, "Table1", summarySourceTable=0, primaryViewId=1, columns=[
Column(1, "manualSort", "ManualSortPos", False, "", 0),
Column(2, "A", "Numeric", False, "", 0),
Column(3, "B", "Numeric", False, "", 0),
Column(4, "C", "Any", True, "", 0),
]),
Table(2, "GristSummary_6_Table1", summarySourceTable=1, primaryViewId=0, columns=[
Column(5, "A", "Numeric", False, "", 2),
Column(6, "group", "RefList:Table1", True, "table.getSummarySourceGroup(rec)", 0),
Column(7, "count", "Int", True, "len($group)", 0),
Column(8, "B", "Numeric", True, "SUM($group.B)", 0),
])
])
self.assertTableData('Table1', data=[
[ "id", "manualSort", "A", "B", "C" ],
[ 1, 1.0, 10, 1.0, None ],
[ 2, 2.0, 20, 2.0, None ],
[ 3, 3.0, 10, 3.0, None ],
])
self.assertTableData('GristSummary_6_Table1', data=[
[ "id", "A", "group", "count", "B" ],
[ 1, 10, [1,3], 2, 4 ],
[ 2, 20, [2], 1, 2 ],
])
# Do a conversion.
self.apply_user_action(["UpdateRecord", "_grist_Tables_column", 2, {"type": "Text"}])
# Verify that the conversion's result is as expected.
self.assertTables([
Table(1, "Table1", summarySourceTable=0, primaryViewId=1, columns=[
Column(1, "manualSort", "ManualSortPos", False, "", 0),
Column(2, "A", "Text", False, "", 0),
Column(3, "B", "Numeric", False, "", 0),
Column(4, "C", "Any", True, "", 0),
]),
Table(2, "GristSummary_6_Table1", summarySourceTable=1, primaryViewId=0, columns=[
Column(5, "A", "Text", False, "", 2),
Column(6, "group", "RefList:Table1", True, "table.getSummarySourceGroup(rec)", 0),
Column(7, "count", "Int", True, "len($group)", 0),
Column(8, "B", "Numeric", True, "SUM($group.B)", 0),
])
])
self.assertTableData('Table1', data=[
[ "id", "manualSort", "A", "B", "C" ],
[ 1, 1.0, "10.0", 1.0, None ],
[ 2, 2.0, "20.0", 2.0, None ],
[ 3, 3.0, "10.0", 3.0, None ],
])
self.assertTableData('GristSummary_6_Table1', data=[
[ "id", "A", "group", "count", "B" ],
[ 1, "10.0", [1,3], 2, 4 ],
[ 2, "20.0", [2], 1, 2 ],
])
#----------------------------------------------------------------------
@test_engine.test_undo
def test_remove_source_column(self):
# Verify that we can remove a column when there is a summary table using that column to group
# by. (Bug T188.)
self.apply_user_action(["AddEmptyTable"])
self.apply_user_action(["BulkAddRecord", "Table1", [None]*3,
{"A": ['a','b','c'], "B": [1,1,2], "C": [4,5,6]}])
self.apply_user_action(["CreateViewSection", 1, 0, "record", [2,3]])
# Verify metadata and actual data initially.
self.assertTables([
Table(1, "Table1", summarySourceTable=0, primaryViewId=1, columns=[
Column(1, "manualSort", "ManualSortPos", False, "", 0),
Column(2, "A", "Text", False, "", 0),
Column(3, "B", "Numeric", False, "", 0),
Column(4, "C", "Numeric", False, "", 0),
]),
Table(2, "GristSummary_6_Table1", summarySourceTable=1, primaryViewId=0, columns=[
Column(5, "A", "Text", False, "", 2),
Column(6, "B", "Numeric", False, "", 3),
Column(7, "group", "RefList:Table1", True, "table.getSummarySourceGroup(rec)", 0),
Column(8, "count", "Int", True, "len($group)", 0),
Column(9, "C", "Numeric", True, "SUM($group.C)", 0),
])
])
self.assertTableData('Table1', data=[
[ "id", "manualSort", "A", "B", "C" ],
[ 1, 1.0, 'a', 1.0, 4 ],
[ 2, 2.0, 'b', 1.0, 5 ],
[ 3, 3.0, 'c', 2.0, 6 ],
])
self.assertTableData('GristSummary_6_Table1', data=[
[ "id", "A", "B", "group", "count", "C" ],
[ 1, 'a', 1.0, [1], 1, 4 ],
[ 2, 'b', 1.0, [2], 1, 5 ],
[ 3, 'c', 2.0, [3], 1, 6 ],
])
# Remove column A, used for group-by.
self.apply_user_action(["RemoveColumn", "Table1", "A"])
# Verify that the conversion's result is as expected.
self.assertTables([
Table(1, "Table1", summarySourceTable=0, primaryViewId=1, columns=[
Column(1, "manualSort", "ManualSortPos", False, "", 0),
Column(3, "B", "Numeric", False, "", 0),
Column(4, "C", "Numeric", False, "", 0),
]),
Table(3, "GristSummary_6_Table1_2", summarySourceTable=1, primaryViewId=0, columns=[
Column(10, "B", "Numeric", False, "", 3),
Column(11, "count", "Int", True, "len($group)", 0),
Column(12, "C", "Numeric", True, "SUM($group.C)", 0),
Column(13, "group", "RefList:Table1", True, "table.getSummarySourceGroup(rec)", 0),
])
])
self.assertTableData('Table1', data=[
[ "id", "manualSort", "B", "C" ],
[ 1, 1.0, 1.0, 4 ],
[ 2, 2.0, 1.0, 5 ],
[ 3, 3.0, 2.0, 6 ],
])
self.assertTableData('GristSummary_6_Table1_2', data=[
[ "id", "B", "group", "count", "C" ],
[ 1, 1.0, [1,2], 2, 9 ],
[ 2, 2.0, [3], 1, 6 ],
])

File diff suppressed because it is too large Load Diff

@ -0,0 +1,317 @@
import logger
import testutil
import test_engine
from test_engine import Table, Column, View, Section, Field
log = logger.Logger(__name__, logger.INFO)
class TestTableActions(test_engine.EngineTestCase):
address_table_data = [
["id", "city", "state", "amount" ],
[ 21, "New York", "NY" , 1. ],
[ 22, "Albany", "NY" , 2. ],
[ 23, "Seattle", "WA" , 3. ],
[ 24, "Chicago", "IL" , 4. ],
[ 25, "Bedford", "MA" , 5. ],
[ 26, "New York", "NY" , 6. ],
[ 27, "Buffalo", "NY" , 7. ],
[ 28, "Bedford", "NY" , 8. ],
[ 29, "Boston", "MA" , 9. ],
[ 30, "Yonkers", "NY" , 10. ],
[ 31, "New York", "NY" , 11. ],
]
people_table_data = [
["id", "name", "address" ],
[ 1, "Alice", 22 ],
[ 2, "Bob", 25 ],
[ 3, "Carol", 27 ],
]
def init_sample_data(self):
# Add a couple of tables, including references.
self.apply_user_action(["AddTable", "Address", [
{"id": "city", "type": "Text"},
{"id": "state", "type": "Text"},
{"id": "amount", "type": "Numeric"},
]])
self.apply_user_action(["AddTable", "People", [
{"id": "name", "type": "Text"},
{"id": "address", "type": "Ref:Address"},
{"id": "city", "type": "Any", "formula": "$address.city" }
]])
# Populate some data.
d = testutil.table_data_from_rows("Address", self.address_table_data[0],
self.address_table_data[1:])
self.apply_user_action(["BulkAddRecord", "Address", d.row_ids, d.columns])
d = testutil.table_data_from_rows("People", self.people_table_data[0],
self.people_table_data[1:])
self.apply_user_action(["BulkAddRecord", "People", d.row_ids, d.columns])
# Add a view with several sections, including a summary table.
self.apply_user_action(["CreateViewSection", 1, 0, 'record', None])
self.apply_user_action(["CreateViewSection", 1, 3, 'record', [3]])
self.apply_user_action(["CreateViewSection", 2, 3, 'record', None])
# Verify the new structure of tables and views.
self.assertTables([
Table(1, "Address", primaryViewId=1, summarySourceTable=0, columns=[
Column(1, "manualSort", "ManualSortPos", False, "", 0),
Column(2, "city", "Text", False, "", 0),
Column(3, "state", "Text", False, "", 0),
Column(4, "amount", "Numeric", False, "", 0),
]),
Table(2, "People", primaryViewId=2, summarySourceTable=0, columns=[
Column(5, "manualSort", "ManualSortPos", False, "", 0),
Column(6, "name", "Text", False, "", 0),
Column(7, "address", "Ref:Address", False, "", 0),
Column(8, "city", "Any", True, "$address.city", 0),
]),
Table(3, "GristSummary_7_Address", 0, 1, columns=[
Column(9, "state", "Text", False, "", summarySourceCol=3),
Column(10, "group", "RefList:Address", True, summarySourceCol=0,
formula="table.getSummarySourceGroup(rec)"),
Column(11, "count", "Int", True, summarySourceCol=0, formula="len($group)"),
Column(12, "amount", "Numeric", True, summarySourceCol=0, formula="SUM($group.amount)"),
]),
])
self.assertViews([
View(1, sections=[
Section(1, parentKey="record", tableRef=1, fields=[
Field(1, colRef=2),
Field(2, colRef=3),
Field(3, colRef=4),
]),
]),
View(2, sections=[
Section(2, parentKey="record", tableRef=2, fields=[
Field(4, colRef=6),
Field(5, colRef=7),
Field(6, colRef=8),
]),
]),
View(3, sections=[
Section(3, parentKey="record", tableRef=1, fields=[
Field(7, colRef=2),
Field(8, colRef=3),
Field(9, colRef=4),
]),
Section(4, parentKey="record", tableRef=3, fields=[
Field(10, colRef=9),
Field(11, colRef=11),
Field(12, colRef=12),
]),
Section(5, parentKey="record", tableRef=2, fields=[
Field(13, colRef=6),
Field(14, colRef=7),
Field(15, colRef=8),
]),
]),
])
# Verify the data we've loaded.
self.assertTableData('Address', cols="subset", data=self.address_table_data)
self.assertTableData('People', cols="subset", data=self.people_table_data)
self.assertTableData("GristSummary_7_Address", cols="subset", data=[
[ "id", "state", "count", "amount" ],
[ 1, "NY", 7, 1.+2+6+7+8+10+11 ],
[ 2, "WA", 1, 3. ],
[ 3, "IL", 1, 4. ],
[ 4, "MA", 2, 5.+9 ],
])
#----------------------------------------------------------------------
@test_engine.test_undo
def test_table_updates(self):
# Verify table renames triggered by UpdateRecord actions, and related behavior.
# Load a sample with a few table and views.
self.init_sample_data()
# Verify that we can rename tables via UpdatRecord actions, including multiple tables.
self.apply_user_action(["BulkUpdateRecord", "_grist_Tables", [1,2],
{"tableId": ["Location", "Persons"]}])
# Check that requested tables and summary tables got renamed correctly.
self.assertTableData('_grist_Tables', cols="subset", data=[
["id", "tableId"],
[1, "Location"],
[2, "Persons"],
[3, "GristSummary_8_Location"],
])
# Check that reference columns to renamed tables get their type modified.
self.assertTableData('_grist_Tables_column', rows="subset", cols="subset", data=[
["id", "colId", "type"],
[7, "address", "Ref:Location"],
[10, "group", "RefList:Location"],
])
# Do a bulk update to rename A and B to conflicting names.
self.apply_user_action(["AddTable", "A", [{"id": "a", "type": "Text"}]])
out_actions = self.apply_user_action(["BulkUpdateRecord", "_grist_Tables", [1,2],
{"tableId": ["A", "A"]}])
# See what doc-actions get generated.
self.assertPartialOutActions(out_actions, {
"stored": [
["ModifyColumn", "Persons", "address", {"type": "Int"}],
["ModifyColumn", "GristSummary_8_Location", "group", {"type": "Int"}],
["RenameTable", "Location", "A2"],
["RenameTable", "GristSummary_8_Location", "GristSummary_2_A2"],
["RenameTable", "Persons", "A3"],
["BulkUpdateRecord", "_grist_Tables", [1, 3, 2],
{"tableId": ["A2", "GristSummary_2_A2", "A3"]}],
["ModifyColumn", "A3", "address", {"type": "Ref:A2"}],
["ModifyColumn", "GristSummary_2_A2", "group", {"type": "RefList:A2"}],
["BulkUpdateRecord", "_grist_Tables_column", [7, 10], {"type": ["Ref:A2", "RefList:A2"]}],
]
})
# Check that requested tables and summary tables got renamed correctly.
self.assertTableData('_grist_Tables', cols="subset", data=[
["id", "tableId"],
[1, "A2"],
[2, "A3"],
[3, "GristSummary_2_A2"],
[4, "A"],
])
# Check that reference columns to renamed tables get their type modified.
self.assertTableData('_grist_Tables_column', rows="subset", cols="subset", data=[
["id", "colId", "type"],
[7, "address", "Ref:A2"],
[10, "group", "RefList:A2"],
])
# Verify the data we've loaded.
self.assertTableData('A2', cols="subset", data=self.address_table_data)
self.assertTableData('A3', cols="subset", data=self.people_table_data)
self.assertTableData("GristSummary_2_A2", cols="subset", data=[
[ "id", "state", "count", "amount" ],
[ 1, "NY", 7, 1.+2+6+7+8+10+11 ],
[ 2, "WA", 1, 3. ],
[ 3, "IL", 1, 4. ],
[ 4, "MA", 2, 5.+9 ],
])
#----------------------------------------------------------------------
@test_engine.test_undo
def test_table_renames_summary_by_ref(self):
# Verify table renames when there is a group-by column that's a Reference.
# This tests a potential bug since a table rename needs to modify Reference types, but
# group-by columns aren't supposed to be modifiable.
self.init_sample_data()
# Add a table grouped by a reference column (the 'Ref:Address' column named 'address').
self.apply_user_action(["CreateViewSection", 2, 0, 'record', [7]])
self.assertTableData('_grist_Tables_column', cols="subset", data=[
["id", "colId", "type", "isFormula", "formula" ],
[ 13, "address", "Ref:Address", False, "" ],
[ 14, "group", "RefList:People", True, "table.getSummarySourceGroup(rec)" ],
[ 15, "count", "Int", True, "len($group)" ],
], rows=lambda r: (r.parentId.id == 4))
# Now rename the table Address -> Location.
out_actions = self.apply_user_action(["RenameTable", "Address", "Location"])
# See what doc-actions get generated.
self.assertPartialOutActions(out_actions, {
"stored": [
["ModifyColumn", "People", "address", {"type": "Int"}],
["ModifyColumn", "GristSummary_7_Address", "group", {"type": "Int"}],
["ModifyColumn", "GristSummary_6_People", "address", {"type": "Int"}],
["RenameTable", "Address", "Location"],
["RenameTable", "GristSummary_7_Address", "GristSummary_8_Location"],
["BulkUpdateRecord", "_grist_Tables", [1, 3],
{"tableId": ["Location", "GristSummary_8_Location"]}],
["ModifyColumn", "People", "address", {"type": "Ref:Location"}],
["ModifyColumn", "GristSummary_8_Location", "group", {"type": "RefList:Location"}],
["ModifyColumn", "GristSummary_6_People", "address", {"type": "Ref:Location"}],
["BulkUpdateRecord", "_grist_Tables_column", [7, 10, 13],
{"type": ["Ref:Location", "RefList:Location", "Ref:Location"]}],
]
})
self.assertTableData('_grist_Tables_column', cols="subset", data=[
["id", "colId", "type", "isFormula", "formula" ],
[ 13, "address", "Ref:Location", False, "" ],
[ 14, "group", "RefList:People", True, "table.getSummarySourceGroup(rec)" ],
[ 15, "count", "Int", True, "len($group)" ],
], rows=lambda r: (r.parentId.id == 4))
#----------------------------------------------------------------------
@test_engine.test_undo
def test_table_removes(self):
# Verify table removals triggered by UpdateRecord actions, and related behavior.
# Same setup as previous test.
self.init_sample_data()
# Add one more table, and one more view for tables #1 and #4 (those we are about to delete).
self.apply_user_action(["AddEmptyTable"])
out_actions = self.apply_user_action(["CreateViewSection", 1, 0, 'detail', None])
self.assertEqual(out_actions.retValues[0]["viewRef"], 5)
self.apply_user_action(["CreateViewSection", 4, 5, 'detail', None])
# See what's in TableViews and TabBar tables, to verify after we remove a table.
self.assertTableData('_grist_TableViews', data=[
["id", "tableRef", "viewRef"],
[1, 1, 3],
[2, 1, 5],
])
self.assertTableData('_grist_TabBar', cols="subset", data=[
["id", "viewRef"],
[1, 1],
[2, 2],
[3, 3],
[4, 4],
[5, 5],
])
# Remove two tables, ensure certain views get removed.
self.apply_user_action(["BulkRemoveRecord", "_grist_Tables", [1, 4]])
# See that some TableViews/TabBar entries disappear, or tableRef gets unset.
self.assertTableData('_grist_TableViews', data=[
["id", "tableRef", "viewRef"],
[1, 0, 3],
])
self.assertTableData('_grist_TabBar', cols="subset", data=[
["id", "viewRef"],
[2, 2],
[3, 3],
])
# Check that reference columns to this table get removed, with associated fields.
self.assertTables([
Table(2, "People", primaryViewId=2, summarySourceTable=0, columns=[
Column(5, "manualSort", "ManualSortPos", False, "", 0),
Column(6, "name", "Text", False, "", 0),
Column(8, "city", "Any", True, "$address.city", 0),
]),
# Note that the summary table is also gone.
])
self.assertViews([
View(2, sections=[
Section(2, parentKey="record", tableRef=2, fields=[
Field(4, colRef=6),
Field(6, colRef=8),
]),
]),
View(3, sections=[
Section(5, parentKey="record", tableRef=2, fields=[
Field(13, colRef=6),
Field(15, colRef=8),
]),
]),
])

@ -0,0 +1,174 @@
import actions
import schema
import table_data_set
import testutil
import difflib
import json
import unittest
class TestTableDataSet(unittest.TestCase):
"""
Tests functionality of TableDataSet by running through all the test cases in testscript.json.
"""
@classmethod
def init_test_cases(cls):
# Create a test_* method for each case in testscript, which runs `self._run_test_body()`.
cls.samples, test_cases = testutil.parse_testscript()
for case in test_cases:
cls._create_test_case(case["TEST_CASE"], case["BODY"])
@classmethod
def _create_test_case(cls, name, body):
setattr(cls, "test_" + name, lambda self: self._run_test_body(body))
def setUp(self):
self._table_data_set = None
def load_sample(self, sample):
"""
Load _table_data_set with given sample data. The sample is a dict with keys "SCHEMA" and
"DATA", each a dictionary mapping table names to actions.TableData objects. "SCHEMA" contains
"_grist_Tables" and "_grist_Tables_column" tables.
"""
self._table_data_set = table_data_set.TableDataSet()
for a in schema.schema_create_actions():
if a.table_id not in self._table_data_set.all_tables:
self._table_data_set.apply_doc_action(a)
for a in sample["SCHEMA"].itervalues():
self._table_data_set.BulkAddRecord(*a)
# Create AddTable actions for each table described in the metadata.
meta_tables = self._table_data_set.all_tables['_grist_Tables']
meta_columns = self._table_data_set.all_tables['_grist_Tables_column']
add_tables = {} # maps the row_id of the table to the schema object for the table.
for rec in actions.transpose_bulk_action(meta_tables):
add_tables[rec.id] = actions.AddTable(rec.tableId, [])
# Go through all columns, adding them to the appropriate tables.
for rec in actions.transpose_bulk_action(meta_columns):
add_tables[rec.parentId].columns.append({
"id": rec.colId,
"type": rec.type,
"widgetOptions": rec.widgetOptions,
"isFormula": rec.isFormula,
"formula": rec.formula,
"label" : rec.label,
"parentPos": rec.parentPos,
})
# Sort the columns in the schema according to the parentPos field from the column records.
for action in add_tables.itervalues():
action.columns.sort(key=lambda r: r["parentPos"])
self._table_data_set.AddTable(*action)
for a in sample["DATA"].itervalues():
self._table_data_set.ReplaceTableData(*a)
def _run_test_body(self, body):
"""Runs the actual script defined in the JSON test-script file."""
undo_actions = []
loaded_sample = None
for line, step, data in body:
try:
if step == "LOAD_SAMPLE":
if loaded_sample:
# Pylint's type checking gives a false positive for loaded_sample.
# pylint: disable=unsubscriptable-object
self._verify_undo_all(undo_actions, loaded_sample["DATA"])
loaded_sample = self.samples[data]
self.load_sample(loaded_sample)
elif step == "APPLY":
self._apply_stored_actions(data['ACTIONS']['stored'])
if 'calc' in data['ACTIONS']:
self._apply_stored_actions(data['ACTIONS']['calc'])
undo_actions.extend(data['ACTIONS']['undo'])
elif step == "CHECK_OUTPUT":
expected_data = {}
if "USE_SAMPLE" in data:
expected_data = self.samples[data.pop("USE_SAMPLE")]["DATA"].copy()
expected_data.update({t: testutil.table_data_from_rows(t, tdata[0], tdata[1:])
for (t, tdata) in data.iteritems()})
self._verify_data(expected_data)
else:
raise ValueError("Unrecognized step %s in test script" % step)
except Exception, e:
new_args0 = "LINE %s: %s" % (line, e.args[0])
e.args = (new_args0,) + e.args[1:]
raise
self._verify_undo_all(undo_actions, loaded_sample["DATA"])
def _apply_stored_actions(self, stored_actions):
for action in stored_actions:
self._table_data_set.apply_doc_action(actions.action_from_repr(action))
def _verify_undo_all(self, undo_actions, expected_data):
"""
At the end of each test, undo all and verify we get back to the originally loaded sample.
"""
self._apply_stored_actions(reversed(undo_actions))
del undo_actions[:]
self._verify_data(expected_data, ignore_formulas=True)
def _verify_data(self, expected_data, ignore_formulas=False):
observed_data = {t: self._prep_data(*data)
for t, data in self._table_data_set.all_tables.iteritems()
if not t.startswith("_grist_")}
if ignore_formulas:
observed_data = self._strip_formulas(observed_data)
expected_data = self._strip_formulas(expected_data)
if observed_data != expected_data:
lines = []
for table in sorted(observed_data.viewkeys() | expected_data.viewkeys()):
if table not in expected_data:
lines.append("*** Table %s observed but not expected\n" % table)
elif table not in observed_data:
lines.append("*** Table %s not observed but was expected\n" % table)
else:
obs, exp = observed_data[table], expected_data[table]
if obs != exp:
o_lines = self._get_text_lines(obs)
e_lines = self._get_text_lines(exp)
lines.append("*** Table %s differs\n" % table)
lines.extend(difflib.unified_diff(e_lines, o_lines,
fromfile="expected", tofile="observed"))
self.fail("\n" + "".join(lines))
def _strip_formulas(self, all_data):
return {t: self._strip_formulas_table(*data) for t, data in all_data.iteritems()}
def _strip_formulas_table(self, table_id, row_ids, columns):
return actions.TableData(table_id, row_ids, {
col_id: col for col_id, col in columns.iteritems()
if not self._table_data_set.get_col_info(table_id, col_id)["isFormula"]
})
@classmethod
def _prep_data(cls, table_id, row_ids, columns):
def sort(col):
return [v for r, v in sorted(zip(row_ids, col))]
sorted_data = actions.TableData(table_id, sorted(row_ids),
{c: sort(col) for c, col in columns.iteritems()})
return actions.encode_objects(testutil.replace_nans(sorted_data))
@classmethod
def _get_text_lines(cls, table_data):
col_items = sorted(table_data.columns.items())
col_items.insert(0, ('id', table_data.row_ids))
table_rows = zip(*[[col_id] + values for (col_id, values) in col_items])
return [json.dumps(row) + "\n" for row in table_rows]
# Parse and create test cases on module load. This way the python unittest feature to run only
# particular test cases can apply to these cases too.
TestTableDataSet.init_test_cases()
if __name__ == "__main__":
unittest.main()

@ -0,0 +1,85 @@
import unittest
import asttokens
import re
import textbuilder
from textbuilder import make_patch, make_regexp_patches, Patch
class TestTextBuilder(unittest.TestCase):
def test_validate_patch(self):
text = "To be or not to be"
patch = make_patch(text, 3, 8, "SEE OR")
self.assertEquals(textbuilder.validate_patch(text, patch), None)
with self.assertRaises(ValueError):
textbuilder.validate_patch('X' + text, patch)
def test_replacer(self):
value = object()
t1 = textbuilder.Text("To be or not\n to be?\n", value)
patches = make_regexp_patches(t1.get_text(), re.compile(r'be|to', re.I),
lambda m: (m.group() + m.group()).upper())
t2 = textbuilder.Replacer(t1, patches)
self.assertEquals(t2.get_text(), "TOTO BEBE or not\n TOTO BEBE?\n")
self.assertEquals(t2.map_back_patch(make_patch(t2.get_text(), 0, 4, "xxx")),
(t1.get_text(), value, Patch(0, 2, "To", "xxx")))
self.assertEquals(t2.map_back_patch(make_patch(t2.get_text(), 5, 9, "xxx")),
(t1.get_text(), value, Patch(3, 5, "be", "xxx")))
self.assertEquals(t2.map_back_patch(make_patch(t2.get_text(), 18, 23, "xxx")),
(t1.get_text(), value, Patch(14, 17, " to", "xxx")))
# Match the entire second line
self.assertEquals(t2.map_back_patch(make_patch(t2.get_text(), 17, 29, "xxx")),
(t1.get_text(), value, Patch(13, 21, " to be?", "xxx")))
def test_combiner(self):
valueA, valueB = object(), object()
t1 = textbuilder.Text("To be or not\n to be?\n", valueA)
patches = make_regexp_patches(t1.get_text(), re.compile(r'be|to', re.I),
lambda m: (m.group() + m.group()).upper())
t2 = textbuilder.Replacer(t1, patches)
t3 = textbuilder.Text("That is the question", valueB)
t4 = textbuilder.Combiner(["[", t2, t3, "]"])
self.assertEqual(t4.get_text(), "[TOTO BEBE or not\n TOTO BEBE?\nThat is the question]")
self.assertEqual(t4.map_back_patch(make_patch(t4.get_text(), 1, 5, "xxx")),
(t1.get_text(), valueA, Patch(0, 2, "To", "xxx")))
self.assertEqual(t4.map_back_patch(make_patch(t4.get_text(), 18, 30, "xxx")),
(t1.get_text(), valueA, Patch(13, 21, " to be?", "xxx")))
self.assertEqual(t4.map_back_patch(make_patch(t4.get_text(), 0, 1, "xxx")),
None)
self.assertEqual(t4.map_back_patch(make_patch(t4.get_text(), 31, 38, "xxx")),
(t3.get_text(), valueB, Patch(0, 7, "That is", "xxx")))
def test_linenumbers(self):
ln = asttokens.LineNumbers("Hello\nworld\nThis\n\nis\n\na test.\n")
self.assertEqual(ln.line_to_offset(1, 0), 0)
self.assertEqual(ln.line_to_offset(1, 5), 5)
self.assertEqual(ln.line_to_offset(2, 0), 6)
self.assertEqual(ln.line_to_offset(2, 5), 11)
self.assertEqual(ln.line_to_offset(3, 0), 12)
self.assertEqual(ln.line_to_offset(4, 0), 17)
self.assertEqual(ln.line_to_offset(5, 0), 18)
self.assertEqual(ln.line_to_offset(6, 0), 21)
self.assertEqual(ln.line_to_offset(7, 0), 22)
self.assertEqual(ln.line_to_offset(7, 7), 29)
self.assertEqual(ln.offset_to_line(0), (1, 0))
self.assertEqual(ln.offset_to_line(5), (1, 5))
self.assertEqual(ln.offset_to_line(6), (2, 0))
self.assertEqual(ln.offset_to_line(11), (2, 5))
self.assertEqual(ln.offset_to_line(12), (3, 0))
self.assertEqual(ln.offset_to_line(17), (4, 0))
self.assertEqual(ln.offset_to_line(18), (5, 0))
self.assertEqual(ln.offset_to_line(21), (6, 0))
self.assertEqual(ln.offset_to_line(22), (7, 0))
self.assertEqual(ln.offset_to_line(29), (7, 7))
# Test that out-of-bounds inputs still return something sensible.
self.assertEqual(ln.line_to_offset(6, 19), 30)
self.assertEqual(ln.line_to_offset(100, 99), 30)
self.assertEqual(ln.line_to_offset(2, -1), 6)
self.assertEqual(ln.line_to_offset(-1, 99), 0)
self.assertEqual(ln.offset_to_line(30), (8, 0))
self.assertEqual(ln.offset_to_line(100), (8, 0))
self.assertEqual(ln.offset_to_line(-100), (1, 0))
if __name__ == "__main__":
unittest.main()

@ -0,0 +1,59 @@
from collections import namedtuple
import unittest
from treeview import fix_indents
Item = namedtuple('Item', 'id indentation')
def fix_and_check(items, changes):
# convert from strings to items with ids and indents (e.g. "A0" -> {id: "A", indent: 0} returns
# the pair (adjustments, resulting items converted to strings) for verification
all_items = [Item(i[0], int(i[1:])) for i in items]
adjustments = fix_indents(all_items, changes)
fix_map = {id: indentation for id, indentation in adjustments}
all_items = [i for i in all_items if i.id not in changes]
result = ['%s%s' % (i.id, fix_map.get(i.id, i.indentation)) for i in all_items]
return (adjustments, result)
class TestTreeView(unittest.TestCase):
def test_fix_indents(self):
self.assertEqual(fix_and_check(["A0", "B0", "C1", "D1"], {"B"}), (
[("C", 0)],
["A0", "C0", "D1"]))
self.assertEqual(fix_and_check(["A0", "B1", "C1", "D1"], {"B"}), (
[],
["A0", "C1", "D1"]))
self.assertEqual(fix_and_check(["A0", "B0", "C1", "D2", "E3", "F2", "G1", "H0"], {"B"}), (
[("C", 0), ("D", 1), ("E", 2)],
["A0", "C0", "D1", "E2", "F2", "G1", "H0"]))
self.assertEqual(fix_and_check(["A0", "B1", "C1", "D1"], {"A", "B"}), (
[("C", 0)],
["C0", "D1"]))
self.assertEqual(fix_and_check(["A0", "B0", "C1", "D1"], {"A", "B"}), (
[("C", 0)],
["C0", "D1"]))
self.assertEqual(fix_and_check(["A0", "B1", "C2", "D0"], {"A", "B"}), (
[("C", 0)],
["C0", "D0"]))
self.assertEqual(fix_and_check(["A0", "B1", "C2", "D0"], {"A", "C"}), (
[("B", 0)],
["B0", "D0"]))
self.assertEqual(fix_and_check(["A0", "B1", "C2", "D0"], {"B", "C"}), (
[],
["A0", "D0"]))
self.assertEqual(fix_and_check(["A0", "B1", "C2", "D0", "E0"], {"B", "D"}), (
[("C", 1)],
["A0", "C1", "E0"]))
self.assertEqual(fix_and_check(["A0", "B1", "C2", "D0", "E1"], {"B", "D"}), (
[("C", 1), ("E", 0)],
["A0", "C1", "E0"]))

@ -0,0 +1,154 @@
import unittest
import twowaymap
class TestTwoWayMap(unittest.TestCase):
def assertTwoWayMap(self, twmap, forward, reverse):
map_repr = (
{ k: twmap.lookup_left(k) for k in twmap.left_all() },
{ k: twmap.lookup_right(k) for k in twmap.right_all() }
)
self.assertEqual(map_repr, (forward, reverse))
def test_set_list(self):
tmap = twowaymap.TwoWayMap(left=set, right=list)
self.assertFalse(tmap)
tmap.insert(1, "a")
self.assertTrue(tmap)
self.assertTwoWayMap(tmap, {1: ["a"]}, {"a": {1}})
tmap.insert(1, "a") # should be a no-op, since this pair already exists
tmap.insert(1, "b")
tmap.insert(2, "a")
self.assertTwoWayMap(tmap, {1: ["a", "b"], 2: ["a"]}, {"a": {1,2}, "b": {1}})
tmap.insert(1, "b")
tmap.insert(2, "b")
self.assertTwoWayMap(tmap, {1: ["a", "b"], 2: ["a", "b"]}, {"a": {1,2}, "b": {1,2}})
tmap.remove(1, "b")
tmap.remove(2, "b")
self.assertTwoWayMap(tmap, {1: ["a"], 2: ["a"]}, {"a": {1,2}})
tmap.insert(1, "b")
tmap.insert(2, "b")
tmap.remove_left(1)
self.assertTwoWayMap(tmap, {2: ["a", "b"]}, {"a": {2}, "b": {2}})
tmap.insert(1, "a")
tmap.insert(2, "b")
tmap.remove_right("b")
self.assertTwoWayMap(tmap, {1: ["a"], 2: ["a"]}, {"a": {1,2}})
self.assertTrue(tmap)
tmap.clear()
self.assertTwoWayMap(tmap, {}, {})
self.assertFalse(tmap)
def test_set_single(self):
tmap = twowaymap.TwoWayMap(left=set, right="single")
self.assertFalse(tmap)
tmap.insert(1, "a")
self.assertTrue(tmap)
self.assertTwoWayMap(tmap, {1: "a"}, {"a": {1}})
tmap.insert(1, "a") # should be a no-op, since this pair already exists
tmap.insert(1, "b")
tmap.insert(2, "a")
self.assertTwoWayMap(tmap, {1: "b", 2: "a"}, {"a": {2}, "b": {1}})
tmap.insert(1, "b")
tmap.insert(2, "b")
self.assertTwoWayMap(tmap, {1: "b", 2: "b"}, {"b": {1,2}})
tmap.remove(1, "b")
self.assertTwoWayMap(tmap, {2: "b"}, {"b": {2}})
tmap.remove(2, "b")
self.assertTwoWayMap(tmap, {}, {})
tmap.insert(1, "b")
tmap.insert(2, "b")
self.assertTwoWayMap(tmap, {1: "b", 2: "b"}, {"b": {1,2}})
tmap.remove_left(1)
self.assertTwoWayMap(tmap, {2: "b"}, {"b": {2}})
tmap.insert(1, "a")
tmap.insert(2, "b")
tmap.remove_right("b")
self.assertTwoWayMap(tmap, {1: "a"}, {"a": {1}})
self.assertTrue(tmap)
tmap.clear()
self.assertTwoWayMap(tmap, {}, {})
self.assertFalse(tmap)
def test_strict_list(self):
tmap = twowaymap.TwoWayMap(left="strict", right=list)
self.assertFalse(tmap)
tmap.insert(1, "a")
self.assertTrue(tmap)
self.assertTwoWayMap(tmap, {1: ["a"]}, {"a": 1})
tmap.insert(1, "a") # should be a no-op, since this pair already exists
tmap.insert(1, "b")
with self.assertRaises(ValueError):
tmap.insert(2, "a")
self.assertTwoWayMap(tmap, {1: ["a", "b"]}, {"a": 1, "b": 1})
tmap.insert(1, "b")
with self.assertRaises(ValueError):
tmap.insert(2, "b")
tmap.insert(2, "c")
self.assertTwoWayMap(tmap, {1: ["a", "b"], 2: ["c"]}, {"a": 1, "b": 1, "c": 2})
tmap.remove(1, "b")
self.assertTwoWayMap(tmap, {1: ["a"], 2: ["c"]}, {"a": 1, "c": 2})
tmap.remove(2, "b")
self.assertTwoWayMap(tmap, {1: ["a"], 2: ["c"]}, {"a": 1, "c": 2})
tmap.insert(1, "b")
with self.assertRaises(ValueError):
tmap.insert(2, "b")
self.assertTwoWayMap(tmap, {1: ["a", "b"], 2: ["c"]}, {"a": 1, "b": 1, "c": 2})
tmap.remove_left(1)
self.assertTwoWayMap(tmap, {2: ["c"]}, {"c": 2})
tmap.insert(1, "a")
tmap.insert(2, "b")
tmap.remove_right("b")
self.assertTwoWayMap(tmap, {1: ["a"], 2: ["c"]}, {"a": 1, "c": 2})
self.assertTrue(tmap)
tmap.clear()
self.assertTwoWayMap(tmap, {}, {})
self.assertFalse(tmap)
def test_strict_single(self):
tmap = twowaymap.TwoWayMap(left="strict", right="single")
tmap.insert(1, "a")
tmap.insert(2, "b")
tmap.insert(2, "c")
self.assertTwoWayMap(tmap, {1: "a", 2: "c"}, {"a": 1, "c": 2})
with self.assertRaises(ValueError):
tmap.insert(2, "a")
tmap.insert(2, "c") # This pair already exists, so not an error.
self.assertTwoWayMap(tmap, {1: "a", 2: "c"}, {"a": 1, "c": 2})
def test_nonhashable(self):
# Test that we don't get into an inconsistent state if we attempt to use a non-hashable value.
tmap = twowaymap.TwoWayMap(left=list, right=list)
tmap.insert(1, "a")
self.assertTwoWayMap(tmap, {1: ["a"]}, {"a": [1]})
with self.assertRaises(TypeError):
tmap.insert(1, {})
with self.assertRaises(TypeError):
tmap.insert({}, "a")
self.assertTwoWayMap(tmap, {1: ["a"]}, {"a": [1]})
if __name__ == "__main__":
unittest.main()

@ -0,0 +1,595 @@
# -*- coding: utf-8 -*-
# pylint: disable=line-too-long
import logger
import testutil
import test_engine
log = logger.Logger(__name__, logger.INFO)
class TestTypes(test_engine.EngineTestCase):
sample = testutil.parse_test_sample({
"SCHEMA": [
[1, "Types", [
[21, "text", "Text", False, "", "", ""],
[22, "numeric", "Numeric", False, "", "", ""],
[23, "int", "Int", False, "", "", ""],
[24, "bool", "Bool", False, "", "", ""],
[25, "date", "Date", False, "", "", ""]
]],
[2, "Formulas", [
[30, "division", "Any", True, "Types.lookupOne(id=18).numeric / 2", "", ""]
]]
],
"DATA": {
"Types": [
["id", "text", "numeric", "int", "bool", "date"],
[11, "New York", "New York", "New York", "New York", "New York"],
[12, "Chîcágö", "Chîcágö", "Chîcágö", "Chîcágö", "Chîcágö"],
[13, False, False, False, False, False],
[14, True, True, True, True, True],
[15, 1509556595, 1509556595, 1509556595, 1509556595, 1509556595],
[16, 8.153, 8.153, 8.153, 8.153, 8.153],
[17, 0, 0, 0, 0, 0],
[18, 1, 1, 1, 1, 1],
[19, "", "", "", "", ""],
[20, None, None, None, None, None]],
"Formulas": [
["id"],
[1]]
},
})
all_row_ids = [11, 12, 13, 14, 15, 16, 17, 18, 19, 20]
def test_update_typed_cells(self):
"""
Tests that updated typed values are set as expected in the sandbox. Types should follow
the rules:
- After updating a cell with a value of a type compatible to the column type,
the cell value should have the column's standard type
- Otherwise, the cell value should have the type AltText
"""
self.load_sample(self.sample)
out_actions = self.apply_user_action(["BulkUpdateRecord", "Types", self.all_row_ids, {
"text": [None, "", 1, 0, 8.153, 1509556595, True, False, u"Chîcágö", "New York"],
"numeric": [None, "", 1, 0, 8.153, 1509556595, True, False, u"Chîcágö", "New York"],
"int": [None, "", 1, 0, 8.153, 1509556595, True, False, u"Chîcágö", "New York"],
"bool": [None, "", 1, 0, 8.153, 1509556595, True, False, u"Chîcágö", "New York"],
"date": [None, "", 1, 0, 8.153, 1509556595, True, False, u"2019-01-22 00:47:39", "New York"]
}])
self.assertPartialOutActions(out_actions, {
"stored": [["BulkUpdateRecord", "Types", self.all_row_ids, {
"text": [None,"","1","0","8.153","1509556595","True","False","Chîcágö","New York"],
"numeric": [None, None, 1.0, 0.0, 8.153, 1509556595.0, 1.0, 0.0, "Chîcágö", "New York"],
"int": [None, None, 1, 0, 8, 1509556595, 1, 0, "Chîcágö", "New York"],
"bool": [False, False, True, False, True, True, True, False, "Chîcágö", "New York"],
"date": [None, None, 1.0, 0.0, 8.153, 1509556595.0, 1.0, 0.0, 1548115200.0, "New York"]
}]],
"undo": [["BulkUpdateRecord", "Types", self.all_row_ids, {
"text": ["New York", "Chîcágö", False, True, 1509556595, 8.153, 0, 1, "", None],
"numeric": ["New York", "Chîcágö", False, True, 1509556595, 8.153, 0, 1, "", None],
"int": ["New York", "Chîcágö", False, True, 1509556595, 8.153, 0, 1, "", None],
"bool": ["New York", "Chîcágö", False, True, 1509556595, 8.153, False, True, "", None],
"date": ["New York", "Chîcágö", False, True, 1509556595, 8.153, 0, 1, "", None]
}]]
})
self.assertTableData("Types", data=[
["id", "text", "numeric", "int", "bool", "date"],
[11, None, None, None, False, None],
[12, "", None, None, False, None],
[13, "1", 1.0, 1, True, 1.0],
[14, "0", 0.0, 0, False, 0.0],
[15, "8.153", 8.153, 8, True, 8.153],
[16, "1509556595", 1509556595, 1509556595, True, 1509556595.0],
[17, "True", 1.0, 1, True, 1.0],
[18, "False", 0.0, 0, False, 0.0],
[19, "Chîcágö", "Chîcágö", "Chîcágö", "Chîcágö", 1548115200.0],
[20, "New York", "New York", "New York", "New York", "New York"]
])
def test_text_conversions(self):
"""
Tests that column type changes occur as expected in the sandbox:
- Resulting cell values should all be Text
- Only non-compatible values should appear in the resulting BulkUpdateRecord
"""
self.load_sample(self.sample)
# Test Text -> Text conversion
out_actions = self.apply_user_action(["ModifyColumn", "Types", "text", { "type" : "Text" }])
self.assertPartialOutActions(out_actions, {
"stored": [],
"undo": []
})
# Test Numeric -> Text conversion
out_actions = self.apply_user_action(["ModifyColumn", "Types", "numeric", { "type" : "Text" }])
self.assertPartialOutActions(out_actions, {
"stored": [
["ModifyColumn", "Types", "numeric", {"type": "Text"}],
["BulkUpdateRecord", "Types", [13, 14, 15, 16, 17, 18],
{"numeric": ["False", "True", "1509556595.0", "8.153", "0.0", "1.0"]}],
["UpdateRecord", "_grist_Tables_column", 22, {"type": "Text"}],
],
"undo": [
["BulkUpdateRecord", "Types", [13, 14, 15, 16, 17, 18],
{"numeric": [False, True, 1509556595, 8.153, 0, 1]}],
["ModifyColumn", "Types", "numeric", {"type": "Numeric"}],
["UpdateRecord", "_grist_Tables_column", 22, {"type": "Numeric"}],
]
})
# Test Int -> Text conversion
out_actions = self.apply_user_action(["ModifyColumn", "Types", "int", { "type" : "Text" }])
self.assertPartialOutActions(out_actions, {
"stored": [
["ModifyColumn", "Types", "int", {"type": "Text"}],
["BulkUpdateRecord", "Types", [13, 14, 15, 16, 17, 18],
{"int": ["False", "True", "1509556595", "8.153", "0", "1"]}],
["UpdateRecord", "_grist_Tables_column", 23, {"type": "Text"}],
],
"undo": [
["BulkUpdateRecord", "Types", [13, 14, 15, 16, 17, 18],
{"int": [False, True, 1509556595, 8.153, 0, 1]}],
["ModifyColumn", "Types", "int", {"type": "Int"}],
["UpdateRecord", "_grist_Tables_column", 23, {"type": "Int"}],
]
})
# Test Bool -> Text
out_actions = self.apply_user_action(["ModifyColumn", "Types", "bool", { "type" : "Text" }])
self.assertPartialOutActions(out_actions, {
"stored": [
["ModifyColumn", "Types", "bool", {"type": "Text"}],
["BulkUpdateRecord", "Types", [13, 14, 15, 16, 17, 18],
{"bool": ["False", "True", "1509556595", "8.153", "False", "True"]}],
["UpdateRecord", "_grist_Tables_column", 24, {"type": "Text"}],
],
"undo": [
["BulkUpdateRecord", "Types", [13, 14, 15, 16, 17, 18],
{"bool": [False, True, 1509556595, 8.153, False, True]}],
["ModifyColumn", "Types", "bool", {"type": "Bool"}],
["UpdateRecord", "_grist_Tables_column", 24, {"type": "Bool"}],
]
})
# Test Date -> Text
out_actions = self.apply_user_action(["ModifyColumn", "Types", "date", { "type" : "Text" }])
self.assertPartialOutActions(out_actions, {
"stored": [
["ModifyColumn", "Types", "date", {"type": "Text"}],
["BulkUpdateRecord", "Types", [13, 14, 15, 16, 17, 18],
{"date": ["False", "True", "1509556595", "8.153", "0", "1"]}],
["UpdateRecord", "_grist_Tables_column", 25, {"type": "Text"}]
],
"undo": [
["BulkUpdateRecord", "Types", [13, 14, 15, 16, 17, 18],
{"date": [False, True, 1509556595, 8.153, 0, 1]}],
["ModifyColumn", "Types", "date", {"type": "Date"}],
["UpdateRecord", "_grist_Tables_column", 25, {"type": "Date"}]
]
})
# Assert that the final table is as expected
self.assertTableData("Types", data=[
["id", "text", "numeric", "int", "bool", "date"],
[11, "New York", "New York", "New York", "New York", "New York"],
[12, "Chîcágö", "Chîcágö", "Chîcágö", "Chîcágö", "Chîcágö"],
[13, False, "False", "False", "False", "False"],
[14, True, "True", "True", "True", "True"],
[15, 1509556595, "1509556595.0","1509556595","1509556595","1509556595"],
[16, 8.153, "8.153", "8.153", "8.153", "8.153"],
[17, 0, "0.0", "0", "False", "0"],
[18, 1, "1.0", "1", "True", "1"],
[19, "", "", "", "", ""],
[20, None, None, None, None, None]
])
def test_numeric_conversions(self):
"""
Tests that column type changes occur as expected in the sandbox:
- Resulting cell values should all be of type Numeric or AltText
- Only non-compatible values should appear in the resulting BulkUpdateRecord
"""
self.load_sample(self.sample)
# Test Text -> Numeric conversion
out_actions = self.apply_user_action(["ModifyColumn", "Types", "text", { "type" : "Numeric" }])
self.assertPartialOutActions(out_actions, {
"stored": [
["ModifyColumn", "Types", "text", {"type": "Numeric"}],
["BulkUpdateRecord", "Types", [13, 14, 15, 17, 18, 19],
{"text": [0.0, 1.0, 1509556595.0, 0.0, 1.0, None]}],
["UpdateRecord", "_grist_Tables_column", 21, {"type": "Numeric"}],
],
"undo": [
["BulkUpdateRecord", "Types", [13, 14, 15, 17, 18, 19],
{"text": [False, True, 1509556595, 0, 1, ""]}],
["ModifyColumn", "Types", "text", {"type": "Text"}],
["UpdateRecord", "_grist_Tables_column", 21, {"type": "Text"}],
]
})
# Test Numeric -> Numeric conversion
out_actions = self.apply_user_action(["ModifyColumn", "Types", "numeric", {"type": "Numeric"}])
self.assertPartialOutActions(out_actions, {
"stored": [],
"undo": []
})
# Test Int -> Numeric conversion
out_actions = self.apply_user_action(["ModifyColumn", "Types", "int", { "type" : "Numeric" }])
self.assertPartialOutActions(out_actions, {
"stored": [
["ModifyColumn", "Types", "int", {"type": "Numeric"}],
["BulkUpdateRecord", "Types", [13, 14, 15, 17, 18, 19],
{"int": [0.0, 1.0, 1509556595.0, 0.0, 1.0, None]}],
["UpdateRecord", "_grist_Tables_column", 23, {"type": "Numeric"}],
],
"undo": [
["BulkUpdateRecord", "Types", [13, 14, 15, 17, 18, 19],
{"int": [False, True, 1509556595, 0, 1, ""]}],
["ModifyColumn", "Types", "int", {"type": "Int"}],
["UpdateRecord", "_grist_Tables_column", 23, {"type": "Int"}],
]
})
# Test Bool -> Numeric conversion
out_actions = self.apply_user_action(["ModifyColumn", "Types", "bool", { "type" : "Numeric" }])
self.assertPartialOutActions(out_actions, {
"stored": [
["ModifyColumn", "Types", "bool", {"type": "Numeric"}],
["BulkUpdateRecord", "Types", [13, 14, 15, 17, 18, 19],
{"bool": [0.0, 1.0, 1509556595.0, 0.0, 1.0, None]}],
["UpdateRecord", "_grist_Tables_column", 24, {"type": "Numeric"}],
],
"undo": [
["BulkUpdateRecord", "Types", [13, 14, 15, 17, 18, 19],
{"bool": [False, True, 1509556595, False, True, ""]}],
["ModifyColumn", "Types", "bool", {"type": "Bool"}],
["UpdateRecord", "_grist_Tables_column", 24, {"type": "Bool"}],
]
})
# Test Date -> Numeric conversion
out_actions = self.apply_user_action(["ModifyColumn", "Types", "date", { "type" : "Numeric" }])
self.assertPartialOutActions(out_actions, {
"stored": [
["ModifyColumn", "Types", "date", {"type": "Numeric"}],
["BulkUpdateRecord", "Types", [13, 14, 15, 17, 18, 19],
{"date": [0.0, 1.0, 1509556595.0, 0.0, 1.0, None]}],
["UpdateRecord", "_grist_Tables_column", 25, {"type": "Numeric"}]
],
"undo": [
["BulkUpdateRecord", "Types", [13, 14, 15, 17, 18, 19],
{"date": [False, True, 1509556595, 0, 1, ""]}],
["ModifyColumn", "Types", "date", {"type": "Date"}],
["UpdateRecord", "_grist_Tables_column", 25, {"type": "Date"}]
]
})
# Assert that the final table is as expected
self.assertTableData("Types", data=[
["id", "text", "numeric", "int", "bool", "date"],
[11, "New York", "New York", "New York", "New York", "New York"],
[12, "Chîcágö", "Chîcágö", "Chîcágö", "Chîcágö", "Chîcágö"],
[13, 0.0, False, 0.0, 0.0, 0.0],
[14, 1.0, True, 1.0, 1.0, 1.0],
[15, 1509556595, 1509556595, 1509556595, 1509556595, 1509556595],
[16, 8.153, 8.153, 8.153, 8.153, 8.153],
[17, 0.0, 0.0, 0.0, 0.0, 0.0],
[18, 1.0, 1.0, 1.0, 1.0, 1.0],
[19, None, "", None, None, None],
[20, None, None, None, None, None],
])
def test_int_conversions(self):
"""
Tests that column type changes occur as expected in the sandbox:
- Resulting cell values should all be of type Int or AltText
- Only non-compatible values should appear in the resulting BulkUpdateRecord
"""
self.load_sample(self.sample)
# Test Text -> Int conversion
out_actions = self.apply_user_action(["ModifyColumn", "Types", "text", { "type" : "Int" }])
self.assertPartialOutActions(out_actions, {
"stored": [
["ModifyColumn", "Types", "text", {"type": "Int"}],
["BulkUpdateRecord", "Types", [13, 14, 16, 19], {"text": [0, 1, 8, None]}],
["UpdateRecord", "_grist_Tables_column", 21, {"type": "Int"}],
],
"undo": [
["BulkUpdateRecord", "Types", [13, 14, 16, 19],
{"text": [False, True, 8.153, ""]}],
["ModifyColumn", "Types", "text", {"type": "Text"}],
["UpdateRecord", "_grist_Tables_column", 21, {"type": "Text"}],
]
})
# Test Numeric -> Int conversion
out_actions = self.apply_user_action(["ModifyColumn", "Types", "numeric", { "type" : "Int" }])
self.assertPartialOutActions(out_actions, {
"stored": [
["ModifyColumn", "Types", "numeric", {"type": "Int"}],
["BulkUpdateRecord", "Types", [13, 14, 15, 16, 17, 18, 19],
{"numeric": [0, 1, 1509556595, 8, 0, 1, None]}],
["UpdateRecord", "_grist_Tables_column", 22, {"type": "Int"}],
],
"undo": [
["BulkUpdateRecord", "Types", [13, 14, 15, 16, 17, 18, 19],
{"numeric": [False, True, 1509556595.0, 8.153, 0.0, 1.0, ""]}],
["ModifyColumn", "Types", "numeric", {"type": "Numeric"}],
["UpdateRecord", "_grist_Tables_column", 22, {"type": "Numeric"}],
]
})
# Test Int -> Int conversion
out_actions = self.apply_user_action(["ModifyColumn", "Types", "int", { "type" : "Int" }])
self.assertPartialOutActions(out_actions, {
"stored": [],
"undo": []
})
# Test Bool -> Int conversion
out_actions = self.apply_user_action(["ModifyColumn", "Types", "bool", { "type" : "Int" }])
self.assertPartialOutActions(out_actions, {
"stored": [
["ModifyColumn", "Types", "bool", {"type": "Int"}],
["BulkUpdateRecord", "Types", [13, 14, 16, 17, 18, 19],
{"bool": [0, 1, 8, 0, 1, None]}],
["UpdateRecord", "_grist_Tables_column", 24, {"type": "Int"}],
],
"undo": [
["BulkUpdateRecord", "Types", [13, 14, 16, 17, 18, 19],
{"bool": [False, True, 8.153, False, True, ""]}],
["ModifyColumn", "Types", "bool", {"type": "Bool"}],
["UpdateRecord", "_grist_Tables_column", 24, {"type": "Bool"}],
]
})
# Test Date -> Int conversion
out_actions = self.apply_user_action(["ModifyColumn", "Types", "date", { "type" : "Int" }])
self.assertPartialOutActions(out_actions, {
"stored": [
["ModifyColumn", "Types", "date", {"type": "Int"}],
["BulkUpdateRecord", "Types", [13, 14, 16, 19], {"date": [0, 1, 8, None]}],
["UpdateRecord", "_grist_Tables_column", 25, {"type": "Int"}]
],
"undo": [
["BulkUpdateRecord", "Types", [13, 14, 16, 19],
{"date": [False, True, 8.153, ""]}],
["ModifyColumn", "Types", "date", {"type": "Date"}],
["UpdateRecord", "_grist_Tables_column", 25, {"type": "Date"}]
]
})
# Assert that the final table is as expected
self.assertTableData("Types", data=[
["id", "text", "numeric", "int", "bool", "date"],
[11, "New York", "New York", "New York", "New York", "New York"],
[12, "Chîcágö", "Chîcágö", "Chîcágö", "Chîcágö", "Chîcágö"],
[13, 0, 0, False, 0, 0],
[14, 1, 1, True, 1, 1],
[15, 1509556595, 1509556595, 1509556595, 1509556595, 1509556595],
[16, 8, 8, 8.153, 8, 8],
[17, 0, 0, 0, 0, 0],
[18, 1, 1, 1, 1, 1],
[19, None, None, "", None, None],
[20, None, None, None, None, None]
])
def test_bool_conversions(self):
"""
Tests that column type changes occur as expected in the sandbox:
- Resulting cell values should all be of type Bool or AltText
- Only non-compatible values should appear in the resulting BulkUpdateRecord
"""
self.load_sample(self.sample)
# Test Text -> Bool conversion
out_actions = self.apply_user_action(["ModifyColumn", "Types", "text", { "type" : "Bool" }])
self.assertPartialOutActions(out_actions, {
"stored": [
["ModifyColumn", "Types", "text", {"type": "Bool"}],
["BulkUpdateRecord", "Types", [15, 16, 17, 18, 19, 20],
{"text": [True, True, False, True, False, False]}],
["UpdateRecord", "_grist_Tables_column", 21, {"type": "Bool"}],
],
"undo": [
["BulkUpdateRecord", "Types", [15, 16, 17, 18, 19, 20],
{"text": [1509556595, 8.153, 0, 1, "", None]}],
["ModifyColumn", "Types", "text", {"type": "Text"}],
["UpdateRecord", "_grist_Tables_column", 21, {"type": "Text"}],
]
})
# Test Numeric -> Bool conversion
out_actions = self.apply_user_action(["ModifyColumn", "Types", "numeric", { "type" : "Bool" }])
self.assertPartialOutActions(out_actions, {
"stored": [
["ModifyColumn", "Types", "numeric", {"type": "Bool"}],
["BulkUpdateRecord", "Types", [15, 16, 17, 18, 19, 20],
{"numeric": [True, True, False, True, False, False]}],
["UpdateRecord", "_grist_Tables_column", 22, {"type": "Bool"}],
],
"undo": [
["BulkUpdateRecord", "Types", [15, 16, 17, 18, 19, 20],
{"numeric": [1509556595.0, 8.153, 0.0, 1.0, "", None]}],
["ModifyColumn", "Types", "numeric", {"type": "Numeric"}],
["UpdateRecord", "_grist_Tables_column", 22, {"type": "Numeric"}],
]
})
# Test Int -> Bool conversion
out_actions = self.apply_user_action(["ModifyColumn", "Types", "int", { "type" : "Bool" }])
self.assertPartialOutActions(out_actions, {
"stored": [
["ModifyColumn", "Types", "int", {"type": "Bool"}],
["BulkUpdateRecord", "Types", [15, 16, 17, 18, 19, 20],
{"int": [True, True, False, True, False, False]}],
["UpdateRecord", "_grist_Tables_column", 23, {"type": "Bool"}],
],
"undo": [
["BulkUpdateRecord", "Types", [15, 16, 17, 18, 19, 20],
{"int": [1509556595, 8.153, 0, 1, "", None]}],
["ModifyColumn", "Types", "int", {"type": "Int"}],
["UpdateRecord", "_grist_Tables_column", 23, {"type": "Int"}],
]
})
# Test Bool -> Bool conversion
out_actions = self.apply_user_action(["ModifyColumn", "Types", "bool", { "type" : "Bool" }])
self.assertPartialOutActions(out_actions, {
"stored": [],
"undo": []
})
# Test Date -> Bool conversion
out_actions = self.apply_user_action(["ModifyColumn", "Types", "date", { "type" : "Bool" }])
self.assertPartialOutActions(out_actions, {
"stored": [
["ModifyColumn", "Types", "date", {"type": "Bool"}],
["BulkUpdateRecord", "Types", [15, 16, 17, 18, 19, 20],
{"date": [True, True, False, True, False, False]}],
["UpdateRecord", "_grist_Tables_column", 25, {"type": "Bool"}]
],
"undo": [
["BulkUpdateRecord", "Types", [15, 16, 17, 18, 19, 20],
{"date": [1509556595, 8.153, 0, 1, "", None]}],
["ModifyColumn", "Types", "date", {"type": "Date"}],
["UpdateRecord", "_grist_Tables_column", 25, {"type": "Date"}]
]
})
# Assert that the final table is as expected
self.assertTableData("Types", data=[
["id", "text", "numeric", "int", "bool", "date"],
[11, "New York", "New York", "New York", "New York", "New York"],
[12, "Chîcágö", "Chîcágö", "Chîcágö", "Chîcágö", "Chîcágö"],
[13, False, False, False, False, False],
[14, True, True, True, True, True],
[15, True, True, True, 1509556595, True],
[16, True, True, True, 8.153, True],
[17, False, False, False, 0, False],
[18, True, True, True, 1, True],
[19, False, False, False, "", False],
[20, False, False, False, None, False]
])
def test_date_conversions(self):
"""
Tests that column type changes occur as expected in the sandbox:
- Resulting cell values should all be of type Date or AltText
- Only non-compatible values should appear in the resulting BulkUpdateRecord
"""
self.load_sample(self.sample)
# Test Text -> Date conversion
out_actions = self.apply_user_action(["ModifyColumn", "Types", "text", { "type" : "Date" }])
self.assertPartialOutActions(out_actions, {
"stored": [
["ModifyColumn", "Types", "text", {"type": "Date"}],
["BulkUpdateRecord", "Types", [13, 14, 15, 17, 18, 19],
{"text": [0.0, 1.0, 1509556595.0, 0.0, 1.0, None]}],
["UpdateRecord", "_grist_Tables_column", 21, {"type": "Date"}],
],
"undo": [
["BulkUpdateRecord", "Types", [13, 14, 15, 17, 18, 19],
{"text": [False, True, 1509556595, 0, 1, ""]}],
["ModifyColumn", "Types", "text", {"type": "Text"}],
["UpdateRecord", "_grist_Tables_column", 21, {"type": "Text"}],
]
})
# Test Numeric -> Date conversion
out_actions = self.apply_user_action(["ModifyColumn", "Types", "numeric", { "type" : "Date" }])
self.assertPartialOutActions(out_actions, {
"stored": [
["ModifyColumn", "Types", "numeric", {"type": "Date"}],
["BulkUpdateRecord", "Types", [13, 14, 19],
{"numeric": [0.0, 1.0, None]}],
["UpdateRecord", "_grist_Tables_column", 22, {"type": "Date"}],
],
"undo": [
["BulkUpdateRecord", "Types", [13, 14, 19],
{"numeric": [False, True, ""]}],
["ModifyColumn", "Types", "numeric", {"type": "Numeric"}],
["UpdateRecord", "_grist_Tables_column", 22, {"type": "Numeric"}],
]
})
# Test Int -> Date conversion
out_actions = self.apply_user_action(["ModifyColumn", "Types", "int", { "type" : "Date" }])
self.assertPartialOutActions(out_actions, {
"stored": [
["ModifyColumn", "Types", "int", {"type": "Date"}],
["BulkUpdateRecord", "Types", [13, 14, 15, 17, 18, 19],
{"int": [0.0, 1.0, 1509556595.0, 0.0, 1.0, None]}],
["UpdateRecord", "_grist_Tables_column", 23, {"type": "Date"}],
],
"undo": [
["BulkUpdateRecord", "Types", [13, 14, 15, 17, 18, 19],
{"int": [False, True, 1509556595, 0, 1, ""]}],
["ModifyColumn", "Types", "int", {"type": "Int"}],
["UpdateRecord", "_grist_Tables_column", 23, {"type": "Int"}],
]
})
# Test Bool -> Date conversion
out_actions = self.apply_user_action(["ModifyColumn", "Types", "bool", { "type" : "Date" }])
self.assertPartialOutActions(out_actions, {
"stored": [
["ModifyColumn", "Types", "bool", {"type": "Date"}],
["BulkUpdateRecord", "Types", [13, 14, 15, 17, 18, 19],
{"bool": [0.0, 1.0, 1509556595.0, 0.0, 1.0, None]}],
["UpdateRecord", "_grist_Tables_column", 24, {"type": "Date"}]
],
"undo": [
["BulkUpdateRecord", "Types", [13, 14, 15, 17, 18, 19],
{"bool": [False, True, 1509556595, False, True, ""]}],
["ModifyColumn", "Types", "bool", {"type": "Bool"}],
["UpdateRecord", "_grist_Tables_column", 24, {"type": "Bool"}]
]
})
# Test Date -> Date conversion
out_actions = self.apply_user_action(["ModifyColumn", "Types", "date", { "type" : "Date" }])
self.assertPartialOutActions(out_actions, {
"stored": [],
"undo": []
})
# Assert that the final table is as expected
self.assertTableData("Types", data=[
["id", "text", "numeric", "int", "bool", "date"],
[11, "New York", "New York", "New York", "New York", "New York"],
[12, "Chîcágö", "Chîcágö", "Chîcágö", "Chîcágö", "Chîcágö"],
[13, 0.0, 0.0, 0.0, 0.0, False],
[14, 1.0, 1.0, 1.0, 1.0, True],
[15, 1509556595, 1509556595, 1509556595, 1509556595, 1509556595],
[16, 8.153, 8.153, 8.153, 8.153, 8.153],
[17, 0.0, 0.0, 0.0, 0.0, 0],
[18, 1.0, 1.0, 1.0, 1.0, 1],
[19, None, None, None, None, ""],
[20, None, None, None, None, None]
])
def test_numerics_are_floats(self):
"""
Tests that in formulas, numeric values are floats, not integers.
Important to avoid truncation.
"""
self.load_sample(self.sample)
self.assertTableData('Formulas', data=[
['id', 'division'],
[ 1, 0.5],
])

@ -0,0 +1,914 @@
import types
import logger
import useractions
import testutil
import test_engine
from test_engine import Table, Column, View, Section, Field
log = logger.Logger(__name__, logger.INFO)
class TestUserActions(test_engine.EngineTestCase):
sample = testutil.parse_test_sample({
"SCHEMA": [
[1, "Address", [
[21, "city", "Text", False, "", "", ""],
]]
],
"DATA": {
"Address": [
["id", "city" ],
[11, "New York" ],
[12, "Colombia" ],
[13, "New Haven" ],
[14, "West Haven" ]],
}
})
starting_table = Table(1, "Address", primaryViewId=0, summarySourceTable=0, columns=[
Column(21, "city", "Text", isFormula=False, formula="", summarySourceCol=0)
])
#----------------------------------------------------------------------
def test_conversions(self):
# Test the sequence of user actions as used for transform-based conversions. This is actually
# not exactly what the client emits, but more like what the client should ideally emit.
# Our sample has a Schools.city text column; we'll convert it to Ref:Address.
self.load_sample(self.sample)
# Add a new table for Schools so that we get the associated views and fields.
self.apply_user_action(['AddTable', 'Schools', [{'id': 'city', 'type': 'Text'}]])
self.apply_user_action(['BulkAddRecord', 'Schools', [1,2,3,4], {
'city': ['New York', 'Colombia', 'New York', '']
}])
self.assertPartialData("_grist_Tables", ["id", "tableId"], [
[1, "Address"],
[2, "Schools"],
])
self.assertPartialData("_grist_Tables_column",
["id", "colId", "parentId", "parentPos", "widgetOptions"], [
[21, "city", 1, 1.0, ""],
[22, "manualSort", 2, 2.0, ""],
[23, "city", 2, 3.0, ""],
])
self.assertPartialData("_grist_Views_section_field", ["id", "colRef", "widgetOptions"], [
[1, 23, ""]
])
self.assertPartialData("Schools", ["id", "city"], [
[1, "New York" ],
[2, "Colombia" ],
[3, "New York" ],
[4, "" ],
])
# Our sample has a text column city.
out_actions = self.add_column('Schools', 'grist_Transform',
isFormula=True, formula='return $city', type='Text')
self.assertPartialOutActions(out_actions, { "stored": [
['AddColumn', 'Schools', 'grist_Transform', {
'type': 'Text', 'isFormula': True, 'formula': 'return $city',
}],
['AddRecord', '_grist_Tables_column', 24, {
'widgetOptions': '', 'parentPos': 4.0, 'isFormula': True, 'parentId': 2, 'colId':
'grist_Transform', 'formula': 'return $city', 'label': 'grist_Transform',
'type': 'Text'
}],
["AddRecord", "_grist_Views_section_field", 2, {
"colRef": 24, "parentId": 1, "parentPos": 2.0
}],
]})
out_actions = self.update_record('_grist_Tables_column', 24,
type='Ref:Address',
formula='return Address.lookupOne(city=$city).id')
self.assertPartialOutActions(out_actions, { "stored": [
['ModifyColumn', 'Schools', 'grist_Transform', {
'formula': 'return Address.lookupOne(city=$city).id', 'type': 'Ref:Address'}],
['UpdateRecord', '_grist_Tables_column', 24, {
'formula': 'return Address.lookupOne(city=$city).id', 'type': 'Ref:Address'}],
]})
# It seems best if TypeTransform sets widgetOptions on grist_Transform column, so that they
# can be copied in CopyFromColumn; rather than updating them after the copy is done.
self.update_record('_grist_Views_section_field', 1, widgetOptions="hello")
self.update_record('_grist_Tables_column', 24, widgetOptions="world")
out_actions = self.apply_user_action(
['CopyFromColumn', 'Schools', 'grist_Transform', 'city', None])
self.assertPartialOutActions(out_actions, { "stored": [
['ModifyColumn', 'Schools', 'city', {'type': 'Ref:Address'}],
['UpdateRecord', 'Schools', 4, {'city': 0}],
['UpdateRecord', '_grist_Views_section_field', 1, {'widgetOptions': ''}],
['UpdateRecord', '_grist_Tables_column', 23, {
'type': 'Ref:Address', 'widgetOptions': 'world'
}],
['BulkUpdateRecord', 'Schools', [1, 2, 3], {'city': [11, 12, 11]}],
]})
out_actions = self.update_record('_grist_Tables_column', 23,
widgetOptions='{"widget":"Reference","visibleCol":"city"}')
self.assertPartialOutActions(out_actions, { "stored": [
['UpdateRecord', '_grist_Tables_column', 23, {
'widgetOptions': '{"widget":"Reference","visibleCol":"city"}'}],
]})
out_actions = self.remove_column('Schools', 'grist_Transform')
self.assertPartialOutActions(out_actions, { "stored": [
['RemoveRecord', '_grist_Views_section_field', 2],
['RemoveRecord', '_grist_Tables_column', 24],
['RemoveColumn', 'Schools', 'grist_Transform'],
]})
#----------------------------------------------------------------------
def test_create_section_existing_view(self):
# Test that CreateViewSection works for an existing view.
self.load_sample(self.sample)
self.assertTables([self.starting_table])
# Create a view + section for the initial table.
self.apply_user_action(["CreateViewSection", 1, 0, "record", None])
# Verify that we got a new view, with one section, and three fields.
self.assertViews([View(1, sections=[
Section(1, parentKey="record", tableRef=1, fields=[
Field(1, colRef=21),
])
]) ])
# Create a new section for the same view, check that only a section is added.
self.apply_user_action(["CreateViewSection", 1, 1, "record", None])
self.assertTables([self.starting_table])
self.assertViews([View(1, sections=[
Section(1, parentKey="record", tableRef=1, fields=[
Field(1, colRef=21),
]),
Section(2, parentKey="record", tableRef=1, fields=[
Field(2, colRef=21),
])
]) ])
# Create another section for the same view, this time summarized.
self.apply_user_action(["CreateViewSection", 1, 1, "record", [21]])
summary_table = Table(2, "GristSummary_7_Address", 0, summarySourceTable=1, columns=[
Column(22, "city", "Text", isFormula=False, formula="", summarySourceCol=21),
Column(23, "group", "RefList:Address", isFormula=True,
formula="table.getSummarySourceGroup(rec)", summarySourceCol=0),
Column(24, "count", "Int", isFormula=True, formula="len($group)", summarySourceCol=0),
])
self.assertTables([self.starting_table, summary_table])
# Check that we still have one view, with sections for different tables.
view = View(1, sections=[
Section(1, parentKey="record", tableRef=1, fields=[
Field(1, colRef=21),
]),
Section(2, parentKey="record", tableRef=1, fields=[
Field(2, colRef=21),
]),
Section(3, parentKey="record", tableRef=2, fields=[
Field(3, colRef=22),
Field(4, colRef=24),
]),
])
self.assertTables([self.starting_table, summary_table])
self.assertViews([view])
# Try to create a summary table for an invalid column, and check that it fails.
with self.assertRaises(ValueError):
self.apply_user_action(["CreateViewSection", 1, 1, "record", [23]])
self.assertTables([self.starting_table, summary_table])
self.assertViews([view])
#----------------------------------------------------------------------
def test_creates_section_new_table(self):
# Test that CreateViewSection works for adding a new table.
self.load_sample(self.sample)
self.assertTables([self.starting_table])
self.assertViews([])
# When we create a section/view for new table, we get both a primary view, and the new view we
# are creating.
self.apply_user_action(["CreateViewSection", 0, 0, "record", None])
new_table = Table(2, "Table1", primaryViewId=1, summarySourceTable=0, columns=[
Column(22, "manualSort", "ManualSortPos", isFormula=False, formula="", summarySourceCol=0),
Column(23, "A", "Any", isFormula=True, formula="", summarySourceCol=0),
Column(24, "B", "Any", isFormula=True, formula="", summarySourceCol=0),
Column(25, "C", "Any", isFormula=True, formula="", summarySourceCol=0),
])
primary_view = View(1, sections=[
Section(1, parentKey="record", tableRef=2, fields=[
Field(1, colRef=23),
Field(2, colRef=24),
Field(3, colRef=25),
])
])
new_view = View(2, sections=[
Section(2, parentKey="record", tableRef=2, fields=[
Field(4, colRef=23),
Field(5, colRef=24),
Field(6, colRef=25),
])
])
self.assertTables([self.starting_table, new_table])
self.assertViews([primary_view, new_view])
# Create another section in an existing view for a new table.
self.apply_user_action(["CreateViewSection", 0, 2, "record", None])
new_table2 = Table(3, "Table2", primaryViewId=3, summarySourceTable=0, columns=[
Column(26, "manualSort", "ManualSortPos", isFormula=False, formula="", summarySourceCol=0),
Column(27, "A", "Any", isFormula=True, formula="", summarySourceCol=0),
Column(28, "B", "Any", isFormula=True, formula="", summarySourceCol=0),
Column(29, "C", "Any", isFormula=True, formula="", summarySourceCol=0),
])
primary_view2 = View(3, sections=[
Section(3, parentKey="record", tableRef=3, fields=[
Field(7, colRef=27),
Field(8, colRef=28),
Field(9, colRef=29),
])
])
new_view.sections.append(
Section(4, parentKey="record", tableRef=3, fields=[
Field(10, colRef=27),
Field(11, colRef=28),
Field(12, colRef=29),
])
)
# Check that we have a new table, only the primary view as new view; and a new section.
self.assertTables([self.starting_table, new_table, new_table2])
self.assertViews([primary_view, new_view, primary_view2])
# Check that we can't create a summary of a table grouped by a column that doesn't exist yet.
with self.assertRaises(ValueError):
self.apply_user_action(["CreateViewSection", 0, 2, "record", [31]])
self.assertTables([self.starting_table, new_table, new_table2])
self.assertViews([primary_view, new_view, primary_view2])
# But creating a new table and showing totals for it is possible though dumb.
self.apply_user_action(["CreateViewSection", 0, 2, "record", []])
# We expect a new table.
new_table3 = Table(4, "Table3", primaryViewId=4, summarySourceTable=0, columns=[
Column(30, "manualSort", "ManualSortPos", isFormula=False, formula="", summarySourceCol=0),
Column(31, "A", "Any", isFormula=True, formula="", summarySourceCol=0),
Column(32, "B", "Any", isFormula=True, formula="", summarySourceCol=0),
Column(33, "C", "Any", isFormula=True, formula="", summarySourceCol=0),
])
# A summary of it.
summary_table = Table(5, "GristSummary_6_Table3", 0, summarySourceTable=4, columns=[
Column(34, "group", "RefList:Table3", isFormula=True,
formula="table.getSummarySourceGroup(rec)", summarySourceCol=0),
Column(35, "count", "Int", isFormula=True, formula="len($group)", summarySourceCol=0),
])
# The primary view of the new table.
primary_view3 = View(4, sections=[
Section(5, parentKey="record", tableRef=4, fields=[
Field(13, colRef=31),
Field(14, colRef=32),
Field(15, colRef=33),
])
])
# And a new view section for the summary.
new_view.sections.append(Section(6, parentKey="record", tableRef=5, fields=[
Field(16, colRef=35)
]))
self.assertTables([self.starting_table, new_table, new_table2, new_table3, summary_table])
self.assertViews([primary_view, new_view, primary_view2, primary_view3])
#----------------------------------------------------------------------
def init_views_sample(self):
# Add a new table and a view, to get some Views/Sections/Fields, and TableView/TabBar items.
self.apply_user_action(['AddTable', 'Schools', [
{'id': 'city', 'type': 'Text'},
{'id': 'state', 'type': 'Text'},
{'id': 'size', 'type': 'Numeric'},
]])
self.apply_user_action(['BulkAddRecord', 'Schools', [1,2,3,4], {
'city': ['New York', 'Colombia', 'New York', ''],
'state': ['NY', 'NY', 'NY', ''],
'size': [1000, 2000, 3000, 4000],
}])
# Add a new view; a second section (summary) to it; and a third view.
self.apply_user_action(['CreateViewSection', 1, 0, 'detail', None])
self.apply_user_action(['CreateViewSection', 1, 2, 'record', [3]])
self.apply_user_action(['CreateViewSection', 1, 0, 'chart', None])
self.apply_user_action(['CreateViewSection', 0, 2, 'record', None])
# Verify the new structure of tables and views.
self.assertTables([
Table(1, "Schools", 1, 0, columns=[
Column(1, "manualSort", "ManualSortPos", False, "", 0),
Column(2, "city", "Text", False, "", 0),
Column(3, "state", "Text", False, "", 0),
Column(4, "size", "Numeric", False, "", 0),
]),
Table(2, "GristSummary_7_Schools", 0, 1, columns=[
Column(5, "state", "Text", False, "", 3),
Column(6, "group", "RefList:Schools", True, "table.getSummarySourceGroup(rec)", 0),
Column(7, "count", "Int", True, "len($group)", 0),
Column(8, "size", "Numeric", True, "SUM($group.size)", 0),
]),
Table(3, 'Table1', 4, 0, columns=[
Column(9, "manualSort", "ManualSortPos", False, "", 0),
Column(10, "A", "Any", True, "", 0),
Column(11, "B", "Any", True, "", 0),
Column(12, "C", "Any", True, "", 0),
]),
])
self.assertViews([
View(1, sections=[
Section(1, parentKey="record", tableRef=1, fields=[
Field(1, colRef=2),
Field(2, colRef=3),
Field(3, colRef=4),
]),
]),
View(2, sections=[
Section(2, parentKey="detail", tableRef=1, fields=[
Field(4, colRef=2),
Field(5, colRef=3),
Field(6, colRef=4),
]),
Section(3, parentKey="record", tableRef=2, fields=[
Field(7, colRef=5),
Field(8, colRef=7),
Field(9, colRef=8),
]),
Section(6, parentKey='record', tableRef=3, fields=[
Field(15, colRef=10),
Field(16, colRef=11),
Field(17, colRef=12),
]),
]),
View(3, sections=[
Section(4, parentKey="chart", tableRef=1, fields=[
Field(10, colRef=2),
Field(11, colRef=3),
]),
]),
View(4, sections=[
Section(5, parentKey='record', tableRef=3, fields=[
Field(12, colRef=10),
Field(13, colRef=11),
Field(14, colRef=12),
]),
]),
])
self.assertTableData('_grist_TableViews', data=[
["id", "tableRef", "viewRef"],
[1, 1, 2],
[2, 1, 3],
])
self.assertTableData('_grist_TabBar', cols="subset", data=[
["id", "viewRef"],
[1, 1],
[2, 2],
[3, 3],
[4, 4],
])
self.assertTableData('_grist_Pages', cols="subset", data=[
["id", "viewRef"],
[1, 1],
[2, 2],
[3, 3],
[4, 4]
])
#----------------------------------------------------------------------
def test_view_remove(self):
# Add a couple of tables and views, to trigger creation of some related items.
self.init_views_sample()
# Remove a view. Ensure related items, sections, fields get removed.
self.apply_user_action(["BulkRemoveRecord", "_grist_Views", [2,3]])
# Verify the new structure of tables and views.
self.assertTables([
Table(1, "Schools", 1, 0, columns=[
Column(1, "manualSort", "ManualSortPos", False, "", 0),
Column(2, "city", "Text", False, "", 0),
Column(3, "state", "Text", False, "", 0),
Column(4, "size", "Numeric", False, "", 0),
]),
# Note that the summary table is gone.
Table(3, 'Table1', 4, 0, columns=[
Column(9, "manualSort", "ManualSortPos", False, "", 0),
Column(10, "A", "Any", True, "", 0),
Column(11, "B", "Any", True, "", 0),
Column(12, "C", "Any", True, "", 0),
]),
])
self.assertViews([
View(1, sections=[
Section(1, parentKey="record", tableRef=1, fields=[
Field(1, colRef=2),
Field(2, colRef=3),
Field(3, colRef=4),
]),
]),
View(4, sections=[
Section(5, parentKey='record', tableRef=3, fields=[
Field(12, colRef=10),
Field(13, colRef=11),
Field(14, colRef=12),
]),
]),
])
self.assertTableData('_grist_TableViews', data=[
["id", "tableRef", "viewRef"],
])
self.assertTableData('_grist_TabBar', cols="subset", data=[
["id", "viewRef"],
[1, 1],
[4, 4],
])
self.assertTableData('_grist_Pages', cols="subset", data=[
["id", "viewRef"],
[1, 1],
[4, 4],
])
#----------------------------------------------------------------------
def test_view_rename(self):
# Add a couple of tables and views, to trigger creation of some related items.
self.init_views_sample()
# Verify the new structure of tables and views.
self.assertTableData('_grist_Tables', cols="subset", data=[
[ 'id', 'tableId', 'primaryViewId' ],
[ 1, 'Schools', 1],
[ 2, 'GristSummary_7_Schools', 0],
[ 3, 'Table1', 4],
])
self.assertTableData('_grist_Views', cols="subset", data=[
[ 'id', 'name', 'primaryViewTable' ],
[ 1, 'Schools', 1],
[ 2, 'New page', 0],
[ 3, 'New page', 0],
[ 4, 'Table1', 3],
])
# Update the names in a few views, and ensure that primary ones cause tables to get renamed.
self.apply_user_action(['BulkUpdateRecord', '_grist_Views', [2,3,4],
{'name': ['A', 'B', 'C']}])
self.assertTableData('_grist_Tables', cols="subset", data=[
[ 'id', 'tableId', 'primaryViewId' ],
[ 1, 'Schools', 1],
[ 2, 'GristSummary_7_Schools', 0],
[ 3, 'C', 4],
])
self.assertTableData('_grist_Views', cols="subset", data=[
[ 'id', 'name', 'primaryViewTable' ],
[ 1, 'Schools', 1],
[ 2, 'A', 0],
[ 3, 'B', 0],
[ 4, 'C', 3]
])
#----------------------------------------------------------------------
def test_section_removes(self):
# Add a couple of tables and views, to trigger creation of some related items.
self.init_views_sample()
# Remove a couple of sections. Ensure their fields get removed.
self.apply_user_action(['BulkRemoveRecord', '_grist_Views_section', [3,6]])
self.assertViews([
View(1, sections=[
Section(1, parentKey="record", tableRef=1, fields=[
Field(1, colRef=2),
Field(2, colRef=3),
Field(3, colRef=4),
]),
]),
View(2, sections=[
Section(2, parentKey="detail", tableRef=1, fields=[
Field(4, colRef=2),
Field(5, colRef=3),
Field(6, colRef=4),
]),
]),
View(3, sections=[
Section(4, parentKey="chart", tableRef=1, fields=[
Field(10, colRef=2),
Field(11, colRef=3),
]),
]),
View(4, sections=[
Section(5, parentKey='record', tableRef=3, fields=[
Field(12, colRef=10),
Field(13, colRef=11),
Field(14, colRef=12),
]),
]),
])
#----------------------------------------------------------------------
def test_schema_consistency_check(self):
# Verify that schema consistency check actually runs, but only when schema is affected.
self.init_views_sample()
# Replace the engine's assert_schema_consistent() method with a mocked version.
orig_method = self.engine.assert_schema_consistent
count_calls = [0]
def override(self): # pylint: disable=unused-argument
count_calls[0] += 1
# pylint: disable=not-callable
orig_method()
self.engine.assert_schema_consistent = types.MethodType(override, self.engine)
# Do a non-sschema action to ensure it doesn't get called.
self.apply_user_action(['UpdateRecord', '_grist_Views', 2, {'name': 'A'}])
self.assertEqual(count_calls[0], 0)
# Do a schema action to ensure it gets called: this causes a table rename.
self.apply_user_action(['UpdateRecord', '_grist_Views', 4, {'name': 'C'}])
self.assertEqual(count_calls[0], 1)
self.assertTableData('_grist_Tables', cols="subset", data=[
[ 'id', 'tableId', 'primaryViewId' ],
[ 1, 'Schools', 1],
[ 2, 'GristSummary_7_Schools', 0],
[ 3, 'C', 4],
])
# Do another schema and non-schema action.
self.apply_user_action(['UpdateRecord', 'Schools', 1, {'city': 'Seattle'}])
self.assertEqual(count_calls[0], 1)
self.apply_user_action(['UpdateRecord', '_grist_Tables_column', 2, {'colId': 'city2'}])
self.assertEqual(count_calls[0], 2)
#----------------------------------------------------------------------
def test_new_column_conversions(self):
self.init_views_sample()
self.apply_user_action(['AddColumn', 'Schools', None, {}])
self.assertTableData('_grist_Tables_column', cols="subset", data=[
["id", "colId", "type", "isFormula", "formula"],
[1, "manualSort", "ManualSortPos",False, ""],
[2, "city", "Text", False, ""],
[3, "state", "Text", False, ""],
[4, "size", "Numeric", False, ""],
[13, "A", "Any", True, ""],
], rows=lambda r: r.parentId.id == 1)
self.assertTableData('Schools', cols="subset", data=[
["id", "city", "A"],
[1, "New York", None],
[2, "Colombia", None],
[3, "New York", None],
[4, "", None],
])
# Check that typing in text into the column produces a text column.
out_actions = self.apply_user_action(['UpdateRecord', 'Schools', 3, {"A": "foo"}])
self.assertTableData('_grist_Tables_column', cols="subset", rows="subset", data=[
["id", "colId", "type", "isFormula", "formula"],
[13, "A", "Text", False, ""],
])
self.assertTableData('Schools', cols="subset", data=[
["id", "city", "A" ],
[1, "New York", "" ],
[2, "Colombia", "" ],
[3, "New York", "foo" ],
[4, "", "" ],
])
# Undo, and check that typing in a number produces a numeric column.
self.apply_undo_actions(out_actions.undo)
out_actions = self.apply_user_action(['UpdateRecord', 'Schools', 3, {"A": " -17.6"}])
self.assertTableData('_grist_Tables_column', cols="subset", rows="subset", data=[
["id", "colId", "type", "isFormula", "formula"],
[13, "A", "Numeric", False, ""],
])
self.assertTableData('Schools', cols="subset", data=[
["id", "city", "A" ],
[1, "New York", 0.0 ],
[2, "Colombia", 0.0 ],
[3, "New York", -17.6 ],
[4, "", 0.0 ],
])
# Undo, and set a formula for the new column instead.
self.apply_undo_actions(out_actions.undo)
self.apply_user_action(['UpdateRecord', '_grist_Tables_column', 13, {'formula': 'len($city)'}])
self.assertTableData('_grist_Tables_column', cols="subset", rows="subset", data=[
["id", "colId", "type", "isFormula", "formula"],
[13, "A", "Any", True, "len($city)"],
])
self.assertTableData('Schools', cols="subset", data=[
["id", "city", "A" ],
[1, "New York", 8 ],
[2, "Colombia", 8 ],
[3, "New York", 8 ],
[4, "", 0 ],
])
# Convert the formula column to non-formula.
self.apply_user_action(['UpdateRecord', '_grist_Tables_column', 13, {'isFormula': False}])
self.assertTableData('_grist_Tables_column', cols="subset", rows="subset", data=[
["id", "colId", "type", "isFormula", "formula"],
[13, "A", "Numeric", False, "len($city)"],
])
self.assertTableData('Schools', cols="subset", data=[
["id", "city", "A" ],
[1, "New York", 8 ],
[2, "Colombia", 8 ],
[3, "New York", 8 ],
[4, "", 0 ],
])
# Add some more formula columns of type 'Any'.
self.apply_user_action(['AddColumn', 'Schools', None, {"formula": "1"}])
self.apply_user_action(['AddColumn', 'Schools', None, {"formula": "'x'"}])
self.apply_user_action(['AddColumn', 'Schools', None, {"formula": "$city == 'New York'"}])
self.apply_user_action(['AddColumn', 'Schools', None, {"formula": "$city=='New York' or '-'"}])
self.assertTableData('_grist_Tables_column', cols="subset", data=[
["id", "colId", "type", "isFormula", "formula"],
[1, "manualSort", "ManualSortPos",False, ""],
[2, "city", "Text", False, ""],
[3, "state", "Text", False, ""],
[4, "size", "Numeric", False, ""],
[13, "A", "Numeric", False, "len($city)"],
[14, "B", "Any", True, "1"],
[15, "C", "Any", True, "'x'"],
[16, "D", "Any", True, "$city == 'New York'"],
[17, "E", "Any", True, "$city=='New York' or '-'"],
], rows=lambda r: r.parentId.id == 1)
self.assertTableData('Schools', cols="subset", data=[
["id", "city", "A", "B", "C", "D", "E"],
[1, "New York", 8, 1, "x", True, True],
[2, "Colombia", 8, 1, "x", False, '-' ],
[3, "New York", 8, 1, "x", True, True],
[4, "", 0, 1, "x", False, '-' ],
])
# Convert all these formulas to non-formulas, and see that their types get guessed OK.
# TODO: We should also guess Int, Bool, Reference, ReferenceList, Date, and DateTime.
# TODO: It is possibly better if B became Int, and D became Bool.
self.apply_user_action(['BulkUpdateRecord', '_grist_Tables_column', [14,15,16,17],
{'isFormula': [False, False, False, False]}])
self.assertTableData('_grist_Tables_column', cols="subset", data=[
["id", "colId", "type", "isFormula", "formula"],
[1, "manualSort", "ManualSortPos",False, ""],
[2, "city", "Text", False, ""],
[3, "state", "Text", False, ""],
[4, "size", "Numeric", False, ""],
[13, "A", "Numeric", False, "len($city)"],
[14, "B", "Numeric", False, "1"],
[15, "C", "Text", False, "'x'"],
[16, "D", "Text", False, "$city == 'New York'"],
[17, "E", "Text", False, "$city=='New York' or '-'"],
], rows=lambda r: r.parentId.id == 1)
self.assertTableData('Schools', cols="subset", data=[
["id", "city", "A", "B", "C", "D", "E"],
[1, "New York", 8, 1.0, "x", "True", 'True'],
[2, "Colombia", 8, 1.0, "x", "False", '-' ],
[3, "New York", 8, 1.0, "x", "True", 'True'],
[4, "", 0, 1.0, "x", "False", '-' ],
])
#----------------------------------------------------------------------
def test_useraction_failures(self):
# Verify that when a useraction fails, we revert any changes already applied.
self.load_sample(self.sample)
# Simple failure: bad action (last argument should be a dict). It shouldn't cause any actions
# in the first place, just raise an exception about the argument being an int.
with self.assertRaisesRegexp(AttributeError, r"'int'"):
self.apply_user_action(['AddColumn', 'Address', "A", 17])
# Do some successful actions, just to make sure we know what they look like.
self.engine.apply_user_actions([useractions.from_repr(ua) for ua in (
['AddColumn', 'Address', "B", {"isFormula": True}],
['UpdateRecord', 'Address', 11, {"city": "New York2"}],
)])
# More complicated: here some actions should succeed, but get reverted when a later one fails.
with self.assertRaisesRegexp(AttributeError, r"'int'"):
self.engine.apply_user_actions([useractions.from_repr(ua) for ua in (
['UpdateRecord', 'Address', 11, {"city": "New York3"}],
['AddColumn', 'Address', "C", {"isFormula": True}],
['AddColumn', 'Address', "D", 17]
)])
with self.assertRaisesRegexp(Exception, r"non-existent record #77"):
self.engine.apply_user_actions([useractions.from_repr(ua) for ua in (
['UpdateRecord', 'Address', 11, {"city": "New York4"}],
['UpdateRecord', 'Address', 77, {"city": "Chicago"}],
)])
# Make sure that no columns got added except the intentionally successful one.
self.assertTableData('_grist_Tables_column', cols="subset", data=[
["id", "colId", "type", "isFormula", "formula"],
[21, "city", "Text", False, ""],
[22, "B", "Any", True, ""],
], rows=lambda r: r.parentId.id == 1)
# Make sure that no columns got added here either, and the only change to "New York" is the
# one in the successful user-action.
self.assertTableData('Address', cols="all", data=[
["id", "city" , "B" ],
[11, "New York2" , None ],
[12, "Colombia" , None ],
[13, "New Haven" , None ],
[14, "West Haven", None ],
])
#----------------------------------------------------------------------
def test_acl_principal_actions(self):
# Test the AddUser, RemoveUser, AddInstance and RemoveInstance actions.
self.load_sample(self.sample)
# Add two users
out_actions = self.apply_user_action(['AddUser', 'jake@grist.com', 'Jake', ['i001', 'i002']])
self.assertPartialOutActions(out_actions, { "stored": [
["AddRecord", "_grist_ACLPrincipals", 1, {
"type": "user",
"userEmail": "jake@grist.com",
"userName": "Jake"
}],
["BulkAddRecord", "_grist_ACLPrincipals", [2, 3], {
"instanceId": ["i001", "i002"],
"type": ["instance", "instance"]
}],
["BulkAddRecord", "_grist_ACLMemberships", [1, 2], {
"child": [2, 3],
"parent": [1, 1]
}]
]})
out_actions = self.apply_user_action(['AddUser', 'steve@grist.com', 'Steve', ['i003']])
self.assertPartialOutActions(out_actions, { "stored": [
["AddRecord", "_grist_ACLPrincipals", 4, {
"type": "user",
"userEmail": "steve@grist.com",
"userName": "Steve"
}],
["AddRecord", "_grist_ACLPrincipals", 5, {
"instanceId": "i003",
"type": "instance"
}],
["AddRecord", "_grist_ACLMemberships", 3, {
"child": 5,
"parent": 4
}]
]})
self.assertTableData('_grist_ACLPrincipals', cols="subset", data=[
["id", "type", "userEmail", "userName", "groupName", "instanceId"],
[1, "user", "jake@grist.com", "Jake", "", ""],
[2, "instance", "", "", "", "i001"],
[3, "instance", "", "", "", "i002"],
[4, "user", "steve@grist.com", "Steve", "", ""],
[5, "instance", "", "", "", "i003"],
])
self.assertTableData('_grist_ACLMemberships', cols="subset", data=[
["id", "parent", "child"],
[1, 1, 2],
[2, 1, 3],
[3, 4, 5]
])
# Add an instance to a non-existent user
with self.assertRaisesRegexp(ValueError, "Cannot find existing user with email null@grist.com"):
self.apply_user_action(['AddInstance', 'null@grist.com', 'i003'])
# Add an instance to an existing user
out_actions = self.apply_user_action(['AddInstance', 'jake@grist.com', 'i004'])
self.assertPartialOutActions(out_actions, { "stored": [
["AddRecord", "_grist_ACLPrincipals", 6, {
"instanceId": "i004",
"type": "instance"
}],
["AddRecord", "_grist_ACLMemberships", 4, {
"child": 6,
"parent": 1
}]
]})
self.assertTableData('_grist_ACLPrincipals', cols="subset", data=[
["id", "type", "userEmail", "userName", "groupName", "instanceId"],
[1, "user", "jake@grist.com", "Jake", "", ""],
[2, "instance", "", "", "", "i001"],
[3, "instance", "", "", "", "i002"],
[4, "user", "steve@grist.com", "Steve", "", ""],
[5, "instance", "", "", "", "i003"],
[6, "instance", "", "", "", "i004"],
])
self.assertTableData('_grist_ACLMemberships', cols="subset", data=[
["id", "parent", "child"],
[1, 1, 2],
[2, 1, 3],
[3, 4, 5],
[4, 1, 6]
])
# Remove a non-existent instance from a user
with self.assertRaisesRegexp(ValueError, "Cannot find existing instance id i000"):
self.apply_user_action(['RemoveInstance', 'i000'])
# Remove an instance from a user
out_actions = self.apply_user_action(['RemoveInstance', 'i002'])
self.assertPartialOutActions(out_actions, { "stored": [
["RemoveRecord", "_grist_ACLMemberships", 2],
["RemoveRecord", "_grist_ACLPrincipals", 3]
]})
self.assertTableData('_grist_ACLPrincipals', cols="subset", data=[
["id", "type", "userEmail", "userName", "groupName", "instanceId"],
[1, "user", "jake@grist.com", "Jake", "", ""],
[2, "instance", "", "", "", "i001"],
[4, "user", "steve@grist.com", "Steve", "", ""],
[5, "instance", "", "", "", "i003"],
[6, "instance", "", "", "", "i004"],
])
self.assertTableData('_grist_ACLMemberships', cols="subset", data=[
["id", "parent", "child"],
[1, 1, 2],
[3, 4, 5],
[4, 1, 6]
])
# Remove a non-existent user
with self.assertRaisesRegexp(ValueError, "Cannot find existing user with email null@grist.com"):
self.apply_user_action(['RemoveUser', 'null@grist.com'])
# Remove an existing user
out_actions = self.apply_user_action(['RemoveUser', 'jake@grist.com'])
self.assertPartialOutActions(out_actions, { "stored": [
["BulkRemoveRecord", "_grist_ACLMemberships", [1, 4]],
["BulkRemoveRecord", "_grist_ACLPrincipals", [2, 6, 1]]
]})
self.assertTableData('_grist_ACLPrincipals', cols="subset", data=[
["id", "type", "userEmail", "userName", "groupName", "instanceId"],
[4, "user", "steve@grist.com", "Steve", "", ""],
[5, "instance", "", "", "", "i003"],
])
self.assertTableData('_grist_ACLMemberships', cols="subset", data=[
["id", "parent", "child"],
[3, 4, 5]
])
# Remove the only instance of an existing user, removing that user
out_actions = self.apply_user_action(['RemoveInstance', 'i003'])
self.assertPartialOutActions(out_actions, { "stored": [
["RemoveRecord", "_grist_ACLMemberships", 3],
["BulkRemoveRecord", "_grist_ACLPrincipals", [4, 5]]
]})
self.assertTableData('_grist_ACLPrincipals', cols="subset", data=[
["id", "type", "userEmail", "userName", "groupName", "instanceId"]
])
self.assertTableData('_grist_ACLMemberships', cols="subset", data=[
["id", "parent", "child"]
])
#----------------------------------------------------------------------
def test_pages_remove(self):
# Test that orphan pages get fixed after removing a page
self.init_views_sample()
# Moves page 2 to children of page 1.
self.apply_user_action(['BulkUpdateRecord', '_grist_Pages', [2], {'indentation': [1]}])
self.assertTableData('_grist_Pages', cols='subset', data=[
['id', 'indentation'],
[ 1, 0],
[ 2, 1],
[ 3, 0],
[ 4, 0],
])
# Verify that removing page 1 fixes page 2 indentation.
self.apply_user_action(['RemoveRecord', '_grist_Pages', 1])
self.assertTableData('_grist_Pages', cols='subset', data=[
['id', 'indentation'],
[ 2, 0],
[ 3, 0],
[ 4, 0],
])
# Removing last page should not fail
# Verify that removing page 1 fixes page 2 indentation.
self.apply_user_action(['RemoveRecord', '_grist_Pages', 4])
self.assertTableData('_grist_Pages', cols='subset', data=[
['id', 'indentation'],
[ 2, 0],
[ 3, 0],
])
# Removing a page that has no children should do nothing
self.apply_user_action(['RemoveRecord', '_grist_Pages', 2])
self.assertTableData('_grist_Pages', cols='subset', data=[
['id', 'indentation'],
[ 3, 0],
])

@ -0,0 +1,43 @@
import testutil
# pylint: disable=line-too-long
sample_students = testutil.parse_test_sample({
"SCHEMA": [
[1, "Students", [
[1, "firstName", "Text", False, "", "", ""],
[2, "lastName", "Text", False, "", "", ""],
[4, "schoolName", "Text", False, "", "", ""],
[5, "schoolIds", "Text", True, "':'.join(str(id) for id in Schools.lookupRecords(name=$schoolName).id)", "", ""],
[6, "schoolCities","Text", True, "':'.join(r.address.city for r in Schools.lookupRecords(name=$schoolName))", "", ""],
]],
[2, "Schools", [
[10, "name", "Text", False, "", "", ""],
[12, "address", "Ref:Address",False, "", "", ""]
]],
[3, "Address", [
[21, "city", "Text", False, "", "", ""],
]]
],
"DATA": {
"Students": [
["id","firstName","lastName", "schoolName" ],
[1, "Barack", "Obama", "Columbia" ],
[2, "George W", "Bush", "Yale" ],
[3, "Bill", "Clinton", "Columbia" ],
[4, "George H", "Bush", "Yale" ],
[5, "Ronald", "Reagan", "Eureka" ],
[6, "Gerald", "Ford", "Yale" ]],
"Schools": [
["id", "name", "address"],
[1, "Columbia", 11],
[2, "Columbia", 12],
[3, "Yale", 13],
[4, "Yale", 14]],
"Address": [
["id", "city" ],
[11, "New York" ],
[12, "Colombia" ],
[13, "New Haven" ],
[14, "West Haven" ]],
}
})

File diff suppressed because it is too large Load Diff

@ -0,0 +1,149 @@
import json
import math
import os
import re
import actions
import logger
def limit_log_stderr(min_level):
"""
Returns a log handler suitable for logger.set_handler(), which logs using log_stderr but only
messages at the given level or higher.
"""
def handler(level, name, msg):
if level >= min_level:
logger.log_stderr(level, name, msg)
return handler
def table_data_from_rows(table_id, col_names, rows):
"""
Returns a TableData object built from a table_id, a list of column names, and corresponding
row-oriented data.
"""
column_values = {}
for i, col in enumerate(col_names):
# Strip leading @ from column headers
column_values[col.lstrip('@')] = [row[i] for row in rows]
return actions.TableData(table_id, column_values.pop('id'), column_values)
def parse_testscript(script_path=None):
"""
Parses JSON spec for test cases, and returns a tuple of (samples, test_cases). Lines starting
with '//' are comments and are skipped.
Samples are objects with keys "SCHEMA" and "DATA", each a dictionary mapping table name to
actions.TableData object. "SCHEMA" contains "_grist_Tables" and "_grist_Tables_column" tables.
Test cases are a list of objects with "TEST_CASE" and "BODY", and the body is a list of steps of
the form [line_number, step_name, data], with line_number being an addition by this parser (or
None if not available).
"""
if not script_path:
script_path = os.path.join(os.path.dirname(__file__), "testscript.json")
comment_re = re.compile(r'^\s*//')
add_line_no_re = re.compile(r'"(APPLY|CHECK_OUTPUT|LOAD_SAMPLE)"\s*,')
all_lines = []
with open(script_path, "r") as testfile:
for i, line in enumerate(testfile):
if comment_re.match(line):
all_lines.append("\n")
else:
line = add_line_no_re.sub(r'"\1@%s",' % (i + 1), line)
all_lines.append(line)
full_text = "".join(all_lines)
script = byteify(json.loads(full_text))
samples = {}
test_cases = []
for obj in script:
if "TEST_CASE" in obj:
body = []
for step, data in obj["BODY"]:
step_line = step.split('@', 1)
step = step_line[0]
line = step_line[1] if len(step_line) > 1 else None
body.append([line, step, data])
obj["BODY"] = body
test_cases.append(obj)
elif "SAMPLE_NAME" in obj:
samples[obj["SAMPLE_NAME"]] = parse_test_sample(obj, samples=samples)
else:
raise ValueError("Unrecognized object in test script: %s" % obj)
return (samples, test_cases)
def parse_test_sample(obj, samples={}):
"""
Parses human-readable sample data (with "SCHEMA" or "SCHEMA_FROM", and "DATA" dictionaries; see
testscript.json for an example) into a sample containing "SCHEMA" and "DATA" keys, each a
dictionary mapping table name to TableData object.
"""
if "SCHEMA_FROM" in obj:
schema = samples[obj["SCHEMA_FROM"]]["SCHEMA"].copy()
else:
raw_schema = obj["SCHEMA"]
# Convert the meta tables to appropriate table representations for loading.
schema = {
'_grist_Tables': table_data_from_rows(
'_grist_Tables',
("id", "tableId"),
[(table_row_id, table_id) for (table_row_id, table_id, _) in raw_schema]),
'_grist_Tables_column': table_data_from_rows(
'_grist_Tables_column',
("parentId", "parentPos", "id", "colId", "type", "isFormula",
"formula", "label", "widgetOptions"),
[[table_row_id, i+1] + e for (table_row_id, _, entries) in raw_schema
for (i, e) in enumerate(entries)])
}
data = {t: table_data_from_rows(t, data[0], data[1:])
for t, data in obj["DATA"].iteritems()}
return {"SCHEMA": schema, "DATA": data}
def byteify(data):
"""
Convert all unicode strings in a parsed JSON object into utf8-encoded strings. We deal with
utf8-encoded strings throughout the test.
"""
if isinstance(data, unicode):
return data.encode('utf-8')
return actions.convert_recursive_helper(byteify, data)
def replace_nans(data):
"""
Convert all NaNs and Infinities in the data to descriptive strings, since they cannot be
serialized to JS-compliant JSON. (But we can serialize them using marshalling, so this
workaround is just for the testscript-based tests.)
"""
if isinstance(data, float) and (math.isnan(data) or math.isinf(data)):
return "@+Infinity" if data > 0 else "@-Infinity" if data < 0 else "@NaN"
return actions.convert_recursive_in_action(replace_nans, data)
def repeat_until_passes(count):
"""
Use as a decorator on test cases to repeat a failing test case up to count times, until it
passes. The resulting test cases will fail only if every repetition failed. This is suitable for
flaky timing test when unexpected load spikes could cause spurious failures.
"""
def decorator(f):
def wrapped(*args):
for i in range(0, count):
try:
f(*args)
return
except AssertionError as e:
pass
# Raises the last caught exception, even outside try/except (see
# https://stackoverflow.com/questions/25632147/raise-at-the-end-of-a-python-function-outside-try-or-except-block)
raise # pylint: disable=misplaced-bare-raise
return wrapped
return decorator

@ -0,0 +1,179 @@
"""
This module allows building text with transformations. It is used specifically for transforming
code, such as replacing "$foo" with "rec.foo" in formulas, and composing formulas into a full
usercode module.
The importance of this module is in allowing to map back replacements (or patches) to output code,
such as those generated to rename column references, into patches to the original inputs. It
allows us to deal with the complete valid usercode module text when searching for renames.
"""
import bisect
import re
from collections import namedtuple
Patch = namedtuple('Patch', ('start', 'end', 'old_text', 'new_text'))
line_start_re = re.compile(r'^', re.M)
def make_patch(full_text, start, end, new_text):
"""
Returns a patch to `full_text` to replace `full_text[start:end]` with `new_text`.
"""
return Patch(start, end, full_text[start:end], new_text)
def make_regexp_patches(full_text, regexp, repl):
"""
Returns a list of patches to `full_text` to replace each occurrence of `regexp` with `repl`. If
repl is a function, will replace with `repl(match_object)`. If repl is a string, it is used
verbatim, without interpreting any special characters.
"""
repl_func = repl if callable(repl) else (lambda m: repl)
return [make_patch(full_text, m.start(0), m.end(0), repl_func(m))
for m in regexp.finditer(full_text)]
def validate_patch(text, patch):
"""
Ensures that the given patch fits the given text, raising ValueError if not.
"""
found = text[patch.start : patch.end]
if found != patch.old_text:
before = text[patch.start - 10 : patch.start]
after = text[patch.end : patch.end + 10]
raise ValueError("Invalid patch to '%s[%s]%s' at %s; expected '%s'" % (
before, found, after, patch.start, patch.old_text))
class Builder(object):
"""
The base for classes that produce text and can map back a text patch to some useful value. A
series of Builders transforms text, and when we know what to change in the result, we use
map_back_patch() to get the source of the original `Text` object.
"""
def map_back_patch(self, patch):
"""
See Text.map_back_patch.
"""
raise NotImplementedError()
def get_text(self):
"""
Returns the output text of this Builder.
"""
raise NotImplementedError()
class Text(Builder):
"""
The lowest Builder that holds a simple string with an optional associated arbitrary value (e.g.
which column a formula came from). When we map back a patch of transformed text, we get a tuple
(text, value, patch) with text and value from the constructor, and patch that applies to text.
"""
def __init__(self, text, value=None):
self._text = text
self._value = value
def map_back_patch(self, patch):
"""
Returns the tuple (text, value, patch) with text and value from the constructor, and patch
that applies to text.
"""
assert self._text[patch.start:patch.end] == patch.old_text
return (self._text, self._value, patch)
def get_text(self):
return self._text
class Replacer(Builder):
"""
Builder that transforms an input Builder with some patches to produce output. It remembers
positions of replacements, so it can map patches of its output back to its input.
"""
def __init__(self, in_builder, patches):
self._in_builder = in_builder
# Two parallel lists of input and output offsets, with corresponding offsets at the same index
# in the two lists. Each list is ordered by offset.
self._input_offsets = [0]
self._output_offsets = [0]
out_parts = []
in_pos = 0
out_pos = 0
text = self._in_builder.get_text()
# Note that we have to go through patches in sorted order.
for in_patch in sorted(patches):
validate_patch(text, in_patch)
out_parts.append(text[in_pos:in_patch.start])
out_parts.append(in_patch.new_text)
out_pos += (in_patch.start - in_pos) + len(in_patch.new_text)
in_pos = in_patch.end
# If the replacement text is shorter or longer than the original, insert a new pair of
# offsets corresponding to the patch's end position in the input and output text.
if len(in_patch.new_text) != in_patch.end - in_patch.start:
self._input_offsets.append(in_pos)
self._output_offsets.append(out_pos)
out_parts.append(text[in_pos:])
self._output_text = ''.join(out_parts)
def get_text(self):
return self._output_text
def map_back_patch(self, patch):
validate_patch(self._output_text, patch)
in_start = self.get_input_pos(patch.start)
in_end = self.get_input_pos(patch.end)
in_patch = make_patch(self._in_builder.get_text(), in_start, in_end, patch.new_text)
return self._in_builder.map_back_patch(in_patch)
def get_input_pos(self, out_pos):
"""Returns the position in the input text corresponding to the given position in output."""
index = bisect.bisect_right(self._output_offsets, out_pos) - 1
offset = out_pos - self._output_offsets[index]
return self._input_offsets[index] + offset
def map_back_offset(self, out_pos):
"""
Returns the position corresponding to out_pos in the original input, in case it was
processed by a series of Replacers.
"""
input_pos = self.get_input_pos(out_pos)
if isinstance(self._in_builder, Replacer):
return self._in_builder.map_back_offset(input_pos)
return input_pos
class Combiner(Builder):
"""
Combiner allows building output text from a sequence of other Builders. When a patch is mapped
back, it gets passed to the Builder it came from, and must not span more than one input Builder.
"""
def __init__(self, parts):
self._parts = parts
self._offsets = []
text_parts = [(p if isinstance(p, basestring) else p.get_text()) for p in self._parts]
self._text = ''.join(text_parts)
offset = 0
self._offsets = []
for t in text_parts:
self._offsets.append(offset)
offset += len(t)
def get_text(self):
return self._text
def map_back_patch(self, patch):
validate_patch(self._text, patch)
start_index = bisect.bisect_right(self._offsets, patch.start)
end_index = bisect.bisect_right(self._offsets, patch.end - 1)
if start_index <= 0 or end_index <= 0 or start_index != end_index:
raise ValueError("Invalid patch to Combiner: %s" % (patch,))
offset = self._offsets[start_index - 1]
part = self._parts[start_index - 1]
in_patch = Patch(patch.start - offset, patch.end - offset, patch.old_text, patch.new_text)
return None if isinstance(part, basestring) else part.map_back_patch(in_patch)

@ -0,0 +1,32 @@
"""
Grist supports organizing a list of records as a tree view which allows for grouping records as
children of some other record.
On the client, the .indentation is used to measure the distance between the left margin of the
container and where we want the record to be. The variation of .indentation gives the parent-child
relationship between consecutive records. For instance in ["A0", "B1", "C1"] (where "A0" stands for
the record {'id': "A", 'indentation': 0}), "B" and "C" are children of "A". In ["A0", "B1", "C2"],
"C" is a child of "B", which is a child of "A".
The order for the records is typically handled using a field of type "PositionNumber", ie: .pagePos
in _grist_Pages table.
Because user can remove records that invalidate the tree, the module exposes fix_indents. For
example if user removes "C" from ["A0", "B1", "C0", "D1"] the resulting table holds ["A0", "B1",
"D1"] and "D" became child of "A", which is unfortunate because we'd rather have "C" become a
sibling of "A" instead. Using fix_indents helps with keeping the tree consistent by returning [("D",
0)] which indicate that the indentation of row "D" needs to be set to 0.
"""
# Items is an array of items with .id and .indentation properties. Returns a list of (item_id,
# new_indent) pairs.
def fix_indents(items, deleted_ids):
max_next_indent = 0
adjustments = []
for item in items:
indent = min(max_next_indent, item.indentation)
is_deleted = item.id in deleted_ids
if indent != item.indentation and not is_deleted:
adjustments.append((item.id, indent))
max_next_indent = indent if is_deleted else indent + 1
return adjustments

@ -0,0 +1,252 @@
"""
TwoWayMap implements mapping from keys to values, and values back to keys. Since keys and values
are not really different here, they are referred to throughout as 'left' and 'right' values.
TwoWayMap supports different types of containers when one value maps to multiple. You may add
support for additional container types using register_container() module function.
It's implemented using Python dictionaries, so both 'left' and 'right' values must be hashable.
For example, to create a dictionary-like structure mapping one key to one value, and which allows
to quickly tell the set of keys that map to a given value, we can use m=TwoWayMap(left=set,
right="single"). Then m.insert(key, value) sets the given key to the given value (overwriting the
value previously set, since the "right" dataset is "single" values), m.lookup_left(key) returns
that value, and m.lookup_right(value) returns a `set` of keys that map to the value.
"""
# Special sentinel value which can never be legitimately stored in TwoWayMap, to easily tell the
# difference between a present and absent value.
_NIL = object()
class TwoWayMap(object):
def __init__(self, left=set, right=set):
"""
Create a new TwoWayMap. The `left` and `right` parameters determine the type of bin for
storing multiple values on the respective side of the map. E.g. if right=set, then
lookup_left() will return a set (what's on the right side). Supported values are:
set: a set of values.
list: a list of values, with new items added at the end of the list.
"single": a single value, new items overwrite previous ones.
"strict": a single value, new items must not overwrite previous ones.
To add support for another bin type, use twowaymap.register_container().
E.g. for TwoWayMap(left="single", right="strict"),
after insert(1, "a"), insert(1, "b") will succeed, but insert(2, "a") will fail.
E.g. for TwoWayMap(left=list, right="single"),
after insert(1, "a"), insert(1, "b"), insert(2, "a"),
lookup_left(1) will return ["a", "b"], and lookup_right("a") will return 2.
"""
self._left_bin = _mapper_types[left]
self._right_bin = _mapper_types[right]
self._fwd = {}
self._bwd = {}
def __nonzero__(self):
return bool(self._fwd)
def lookup_left(self, left, default=None):
""" Returns the value(s) on the right corresponding to the given value on the left. """
return self._fwd.get(left, default)
def lookup_right(self, right, default=None):
""" Returns the value(s) on the left corresponding to the given value on the right. """
return self._bwd.get(right, default)
def count_left(self):
""" Returns the count of unique values on the left."""
return len(self._fwd)
def count_right(self):
""" Returns the count of unique values on the right."""
return len(self._bwd)
def left_all(self):
""" Returns an iterable over all values on the left."""
return self._fwd.iterkeys()
def right_all(self):
""" Returns an iterable over all values on the right."""
return self._bwd.iterkeys()
def insert(self, left, right):
""" Insert the (left, right) value pair. """
# The tricky thing here is to keep the two maps consistent if an update to the second one
# raises an exception. To handle it, add_item must return what got added and removed, so that
# we can restore things after an exception. An exception could be caused by a "strict" bin
# type, or by using an un-hashable key (on either left or right side), or by using a custom
# container that can throw.
right_removed, right_added = self._right_bin.add_item(self._fwd, left, right)
try:
left_removed, _ = self._left_bin.add_item(self._bwd, right, left)
except:
# _left_bin is responsible to stay unchanged if there was an exception. Now we need to bring
# _right_bin back in sync with _left_bin.
if right_added is not _NIL:
self._right_bin.remove_item(self._fwd, left, right_added)
if right_removed is not _NIL:
self._right_bin.add_item(self._fwd, left, right_removed)
raise
# It's possible for add_item to overwrite elements, in which case we need to remove the
# other side of the mapping for the removed element.
if right_removed is not _NIL:
self._left_bin.remove_item(self._bwd, right_removed, left)
if left_removed is not _NIL:
self._right_bin.remove_item(self._fwd, left_removed, right)
def remove(self, left, right):
""" Remove the (left, right) value pair. """
self._right_bin.remove_item(self._fwd, left, right)
self._left_bin.remove_item(self._bwd, right, left)
def remove_left(self, left):
""" Remove all values on the right corresponding to the given value on the left. """
right_removed = self._right_bin.remove_key(self._fwd, left)
for x in right_removed:
self._left_bin.remove_item(self._bwd, x, left)
def remove_right(self, right):
""" Remove all values on the left corresponding to the given value on the right. """
left_removed = self._left_bin.remove_key(self._bwd, right)
for x in left_removed:
self._right_bin.remove_item(self._fwd, x, right)
def clear(self):
""" Clear the entire map. """
self._fwd.clear()
self._bwd.clear()
#----------------------------------------------------------------------
# The private classes below implement the different container types.
class _BaseBinType(object):
""" Base class for other BinTypes. """
def add_item(self, mapping, key, value):
pass
def remove_item(self, mapping, key, value):
pass
def remove_key(self, mapping, key):
pass
class _SingleValueBin(_BaseBinType):
""" Bin that contains a single value, with new values overwriting previous ones."""
def add_item(self, mapping, key, value):
stored = mapping.get(key, _NIL)
mapping[key] = value
if stored is _NIL:
return _NIL, value
elif stored == value:
return _NIL, _NIL
else:
return stored, value
def remove_item(self, mapping, key, value):
stored = mapping.get(key, _NIL)
if stored == value:
del mapping[key]
def remove_key(self, mapping, key):
stored = mapping.pop(key, _NIL)
return () if stored is _NIL else (stored,)
class _SingleValueStrictBin(_SingleValueBin):
""" Bin that contains a single value, overwriting which raises ValueError."""
def add_item(self, mapping, key, value):
stored = mapping.get(key, _NIL)
if stored is _NIL:
mapping[key] = value
return _NIL, value
elif stored == value:
return _NIL, _NIL
else:
raise ValueError("twowaymap: one-to-one map violation for key %s" % key)
class _ContainerBin(_BaseBinType):
"""
Bin that contains a container of values managed by the passed-in functions. See
register_container() for documentation of the arguments.
"""
def __init__(self, make_func, add_func, remove_func):
self.make = make_func
self.add = add_func
self.remove = remove_func
def add_item(self, mapping, key, value):
stored = mapping.get(key, _NIL)
if stored is _NIL:
mapping[key] = self.make(value)
return _NIL, value
else:
return _NIL, (value if self.add(stored, value) else _NIL)
def remove_item(self, mapping, key, value):
stored = mapping.get(key, _NIL)
if stored is not _NIL:
self.remove(stored, value)
if not stored:
del mapping[key]
def remove_key(self, mapping, key):
return mapping.pop(key, ())
#----------------------------------------------------------------------
_mapper_types = {
'single': _SingleValueBin(),
'strict': _SingleValueStrictBin(),
}
def register_container(cls, make_func, add_func, remove_func):
"""
Register another container type. The first argument can be the container's class object, but
really can be any hashable value, which you can then give as an argument to left= or right=
arguments when constructing a TwoWayMap. The other arguments are:
make_func(value) - must return a new instance of the container with a single value.
This container must support iteration through values, and in boolean context must
evaluate to whether it's non-empty.
add_func(container, value) - must add value to container, only if it's not already there,
and return True if the value was added, False if it was already there.
remove_func(container, value) - must remove value from container if present.
This must never raise an exception, since that could leave the map in inconsistent state.
"""
_mapper_types[cls] = _ContainerBin(make_func, add_func, remove_func)
# Allow `set` to be used as a bin type.
def _set_make(value):
return {value}
def _set_add(container, value):
if value not in container:
container.add(value)
return True
return False
def _set_remove(container, value):
container.discard(value)
register_container(set, _set_make, _set_add, _set_remove)
# Allow `list` to be used as a bin type.
def _list_make(value):
return [value]
def _list_add(container, value):
if value not in container:
container.append(value)
return True
return False
def _list_remove(container, value):
try:
container.remove(value)
except ValueError:
pass
register_container(list, _list_make, _list_add, _list_remove)

Binary file not shown.

File diff suppressed because it is too large Load Diff

@ -0,0 +1,68 @@
"""
usercode.py isn't a real module, but an example of a module produced by gencode.py from the
user-defined document schema.
It is the same code that's produced from the test schema in test_gencode.py. In fact, it is used
as part of that test.
User-defined Tables (i.e. classes that derive from grist.Table) automatically get some additional
members:
Record - a class derived from grist.Record, with a property for each table column.
RecordSet - a class derived from grist.Record, with a property for each table column.
RecordSet.Record - a reference to the Record class above
======================================================================
import grist
from functions import * # global uppercase functions
import datetime, math, re # modules commonly needed in formulas
@grist.UserTable
class Students:
firstName = grist.Text()
lastName = grist.Text()
school = grist.Reference('Schools')
def fullName(rec, table):
return rec.firstName + ' ' + rec.lastName
def fullNameLen(rec, table):
return len(rec.fullName)
def schoolShort(rec, table):
return rec.school.name.split(' ')[0]
def schoolRegion(rec, table):
addr = rec.school.address
return addr.state if addr.country == 'US' else addr.region
@grist.formulaType(grist.Reference('Schools'))
def school2(rec, table):
return Schools.lookupFirst(name=rec.school.name)
@grist.UserTable
class Schools:
name = grist.Text()
address = grist.Reference('Address')
@grist.UserTable
class Address:
city = grist.Text()
state = grist.Text()
def _default_country(rec, table):
return 'US'
country = grist.Text()
def region(rec, table):
return {'US': 'North America', 'UK': 'Europe'}.get(rec.country, 'N/A')
def badSyntax(rec, table):
# for a in b
# 10
raise SyntaxError('invalid syntax on line 1 col 11')
======================================================================
"""

@ -0,0 +1,461 @@
"""
The basic types in Grist include Numeric, Text, Reference, Date, and others. Each type needs a
representation in storage (database), in communication messages, and in the memory of JS and
Python interpreters. Each type also needs a convenient Python representation when used in
formulas. Any typed column may also contain values of a wrong type, and those also need a
representation. Finally, every type defines a default value, used when the column is first
created, and for new records.
For values of type int or bool, It's possible to save some memory by using JS typed arrays or
Python's array.array. However, at least on the Python side, it means that we need an additional
data structure for values of the wrong type, and the memory savings aren't that great to be worth
the extra complexity.
"""
import datetime
import six
import objtypes
import moment
import logger
from records import Record, RecordSet
log = logger.Logger(__name__, logger.INFO)
NoneType = type(None)
def strict_equal(a, b):
"""Checks the equality of the types of the values as well as the values."""
# pylint: disable=unidiomatic-typecheck
return type(a) == type(b) and a == b
# Note that this matches the defaults in app/common/gristTypes.js
_type_defaults = {
'Any': None,
'Attachments': None,
'Blob': None,
'Bool': False,
'Choice': '',
'Date': None,
'DateTime': None,
'Id': 0,
'Int': 0,
'ManualSortPos': float('inf'),
'Numeric': 0.0,
'PositionNumber': float('inf'),
'Ref': 0,
'RefList': None,
'Text': '',
}
def get_type_default(col_type):
col_type = col_type.split(':', 1)[0] # Strip suffix for Ref:, DateTime:, etc.
return _type_defaults.get(col_type, None)
def formulaType(grist_type):
"""
formulaType(gristType) is a decorator which saves the type as the 'grist_type' attribute
on the decorated formula function. It allows the formula columns to be typed.
"""
def wrapper(method):
method.grist_type = grist_type
return method
return wrapper
class AltText(object):
"""
Represents a text value in a non-text column. The separate class allows formulas to access
wrong-type values. We use a wrapper rather than expose text directly to formulas, because with
text there is a risk that e.g. a formula that's supposed to add numbers would add two strings
with unexpected result.
"""
def __init__(self, text, typename=None):
self._text = text
self._typename = typename
def __str__(self):
return self._text
def __int__(self):
# This ensures that AltText values that look like ints may be cast back to int.
# Convert to float first, since python does not allow casting strings with decimals to int.
return int(float(self._text))
def __float__(self):
# This ensures that AltText values that look like floats may be cast back to float.
return float(self._text)
def __repr__(self):
return '%s(%r)' % (self.__class__.__name__, self._text)
# Allow comparing to AltText("something")
def __eq__(self, other):
return isinstance(other, self.__class__) and self._text == other._text
def __ne__(self, other):
return not self.__eq__(other)
def __hash__(self):
return hash((self.__class__, self._text))
def __getattr__(self, name):
# On attempt to do $foo.Bar on an AltText value such as "hello", raise an exception that will
# show up as e.g. "Invalid Ref: hello" or "Invalid Date: hello".
raise objtypes.InvalidTypedValue(self._typename, self._text)
def ifError(value, value_if_error):
"""
Return `value` if it is valid, or `value_if_error` otherwise. Similar to Excel's IFERROR.
"""
# TODO: this should ideally handle exception values and values of wrong type returned by
# formulas, but it's unclear how to make that work.
return value_if_error if isinstance(value, AltText) else value
# Unique sentinel object to tell BaseColumnType constructor to use get_type_default().
_use_type_default = object()
class BaseColumnType(object):
"""
Base class for all column types.
"""
_global_creation_order = 0
def __init__(self, default=_use_type_default):
self.default = get_type_default(self.typename()) if default is _use_type_default else default
self.default_func = None
# Slightly silly, but it allows us to extract the order in which fields are listed in the
# model definition, without looking back at the schema.
self._creation_order = BaseColumnType._global_creation_order
BaseColumnType._global_creation_order += 1
@classmethod
def typename(cls):
"""
Returns the name of the type, e.g. "Int", "Ref", or "RefList".
"""
return cls.__name__
@classmethod
def is_right_type(cls, _value):
"""
Returns whether the given value belongs to this type. A cell may contain a wrong-type value
(e.g. alttext, error), but formulas will only see right-type values, defaulting to the
column's default.
If is_right_type returns true, it must be possible to store the value (so with typed arrays,
it must fit the type's restrictions).
"""
return True
@classmethod
def do_convert(cls, value):
"""
Converts a value of any type to one of our type (for which is_right_type is true) and returns
it, or throws an exception. This is the method that should be overridden by subclasses.
"""
return value
def convert(self, value_to_convert):
"""
Converts a value of any type to this type, returning either a value of the right type, or
alttext, or error. It never throws, and should not be overridden by subclasses (override
do_convert instead).
"""
# Don't try to convert errors, although some day we may want to attempt it (e.g. if an error
# contains original text, we may want to try to convert the original text).
if isinstance(value_to_convert, objtypes.RaisedException):
return value_to_convert
try:
return self.do_convert(value_to_convert)
except Exception as e:
# If conversion failed, return a string to serve as alttext.
if isinstance(value_to_convert, six.text_type):
# str() will fail for a non-ascii unicode object, which needs an explicit encoding.
return value_to_convert.encode('utf8')
return str(value_to_convert)
# This is a user-facing method, hence the camel-case naming, as for `lookupRecords` and such.
@classmethod
def typeConvert(cls, value):
"""
Convert a value from a different type to something that this type can accept, as when
explicitly converting a column type. Note that usual conversion (such as converting numbers to
strings or vice versa) will still apply to the returned value.
"""
return value
class Text(BaseColumnType):
"""
Text is the type for a field holding string (text) data.
"""
@classmethod
def do_convert(cls, value):
return str(value) if value is not None else None
@classmethod
def is_right_type(cls, value):
return isinstance(value, (basestring, NoneType))
@classmethod
def typeConvert(cls, value):
# When converting NULLs (that typically show up as a plain empty cell for Numeric or Date
# columns) to Text, it makes more sense to end up with a plain blank text cell.
return '' if value is None else value
class Blob(BaseColumnType):
"""
Blob hold binary data.
"""
@classmethod
def do_convert(cls, value):
return str(value) if value is not None else None
@classmethod
def is_right_type(cls, value):
return isinstance(value, (basestring, NoneType))
class Any(BaseColumnType):
"""
Any is the type that can hold any kind of value. It's used to hold computed values.
"""
@classmethod
def do_convert(cls, value):
# Convert AltText values to plain text when assigning to type Any.
return str(value) if isinstance(value, AltText) else value
class Bool(BaseColumnType):
"""
Bool is the type for a field holding boolean data.
"""
@classmethod
def do_convert(cls, value):
# We'll convert any falsy value to False, non-zero numbers to True, and only strings we
# recognize. Everything else will result in alttext.
if not value:
return False
if isinstance(value, (float, int, long)):
return True
if isinstance(value, AltText):
value = str(value)
if isinstance(value, basestring):
if value.lower() in ("false", "no", "0"):
return False
if value.lower() in ("true", "yes", "1"):
return True
raise objtypes.ConversionError("Bool")
@classmethod
def is_right_type(cls, value):
return isinstance(value, (bool, NoneType))
class Int(BaseColumnType):
"""
Int is the type for a field holding integer data.
"""
@classmethod
def do_convert(cls, value):
if value in ("", None):
return None
# Convert to float first, since python does not allow casting strings with decimals to int
ret = int(float(value))
if not objtypes.is_int_short(ret):
raise OverflowError("Integer value too large")
return ret
@classmethod
def is_right_type(cls, value):
return value is None or (isinstance(value, (int, long)) and not isinstance(value, bool) and
objtypes.is_int_short(value))
class Numeric(BaseColumnType):
"""
Numeric is the type for a field holding numerical data.
"""
@classmethod
def do_convert(cls, value):
return float(value) if value not in ("", None) else None
@classmethod
def is_right_type(cls, value):
# TODO: Python distinguishes ints from floats, while JS only has floats. A value that can be
# interpreted as an int will upon being entered have type 'float', but after database reload
# will have type 'int'.
return isinstance(value, (float, int, long, NoneType)) and not isinstance(value, bool)
class Date(Numeric):
"""
Date is the type for a field holding date data (no timezone).
"""
@classmethod
def do_convert(cls, value):
if value in ("", None):
return None
elif isinstance(value, datetime.datetime):
return moment.dt_to_ts(value)
elif isinstance(value, datetime.date):
return moment.date_to_ts(value)
elif isinstance(value, (float, int, long)):
return float(value)
elif isinstance(value, basestring):
# We also accept a date in ISO format (YYYY-MM-DD), the time portion is optional and ignored
return moment.parse_iso_date(value)
else:
raise objtypes.ConversionError('Date')
@classmethod
def is_right_type(cls, value):
return isinstance(value, (float, int, long, NoneType))
@classmethod
def typeConvert(cls, value, date_format, timezone='UTC'): # pylint: disable=arguments-differ
# Note: the timezone argument is used in DateTime conversions, allows sharing this method.
try:
return moment.parse(value, date_format, timezone)
except Exception:
return value
class DateTime(Date):
"""
DateTime is the type for a field holding date and time data.
"""
def __init__(self, timezone="America/New_York", default=_use_type_default):
super(DateTime, self).__init__(default)
try:
self.timezone = moment.Zone(timezone)
except KeyError:
self.timezone = moment.Zone('UTC')
def do_convert(self, value):
if value in ("", None):
return None
elif isinstance(value, datetime.datetime):
return moment.dt_to_ts(value, self.timezone)
elif isinstance(value, datetime.date):
return moment.date_to_ts(value, self.timezone)
elif isinstance(value, (float, int, long)):
return float(value)
elif isinstance(value, basestring):
# We also accept a datetime in ISO format (YYYY-MM-DD[T]HH:mm:ss)
return moment.parse_iso(value, self.timezone)
else:
raise objtypes.ConversionError('DateTime')
class Choice(Text):
"""
Choice is the type for a field holding one of a set of acceptable string (text) values.
TODO: Type should possibly be aware of the allowed choices, and be considered invalid
when its value isn't one of them
"""
pass
class PositionNumber(BaseColumnType):
"""
PositionNumber is the type for a position field used to order records in record lists.
"""
# The 'inf' default is used by prepare_new_values() in column.py, which always changes it to
# finite numbers, but relies on it to keep newly-added records below existing ones by default.
@classmethod
def do_convert(cls, value):
return float(value) if value not in ("", None) else float('inf')
@classmethod
def is_right_type(cls, value):
# Same as Numeric, but does not support None.
return isinstance(value, (float, int, long)) and not isinstance(value, bool)
class ManualSortPos(PositionNumber):
pass
class Id(BaseColumnType):
"""
Id is the type for the record ID field, present automatically in each table.
The default of 0 points to the always-present empty record. Real records start at index 1.
"""
@classmethod
def do_convert(cls, value):
# Just like Int.do_convert, but skips conversion via float. This also makes it work for Record
# types, which override int() conversion to yield the row ID. Arbitrary values should not be
# cast to ints as it results in false hits when converting numerical values to reference ids.
if not value:
return 0
if not isinstance(value, (int, Record)):
raise TypeError("Cannot convert to Id type")
ret = int(value)
if not objtypes.is_int_short(ret):
raise OverflowError("Integer value too large")
return ret
@classmethod
def is_right_type(cls, value):
return (isinstance(value, (int, long)) and not isinstance(value, bool) and
objtypes.is_int_short(value))
class Reference(Id):
"""
Reference is the type for a field holding a reference into another table.
Note that if `foo` is a Reference('Foo'), then `rec.foo` is of type `Foo.Record`. The ID of that
record is available as `rec.foo._row_id`. It is equivalent to `rec.foo.id`, except that
accessing `id`, as other public properties, involves a lookup in `Foo` table.
"""
def __init__(self, table_id):
super(Reference, self).__init__()
self.table_id = table_id
@classmethod
def typename(cls):
return "Ref"
@classmethod
def typeConvert(cls, value, ref_table, visible_col=None): # pylint: disable=arguments-differ
if ref_table and visible_col:
return ref_table.lookupOne(**{visible_col: value}) or str(value)
else:
return value
class ReferenceList(BaseColumnType):
"""
ReferenceList stores a list of references into another table.
"""
def __init__(self, table_id):
super(ReferenceList, self).__init__()
self.table_id = table_id
@classmethod
def typename(cls):
return "RefList"
def do_convert(self, value):
if isinstance(value, RecordSet):
assert value._table.table_id == self.table_id
return objtypes.RecordList(value._row_ids, group_by=value._group_by, sort_by=value._sort_by)
elif not value:
return []
return [Reference.do_convert(val) for val in value]
@classmethod
def is_right_type(cls, value):
return value is None or (isinstance(value, list) and
all(Reference.is_right_type(val) for val in value))
class Attachments(ReferenceList):
"""
Currently attachment type is the field for holding data for attachments.
"""
def __init__(self):
super(Attachments, self).__init__('_grist_Attachments')

@ -0,0 +1,31 @@
const path = require('path');
require('app-module-path').addPath(path.dirname(__dirname));
require('ts-node').register();
/**
* This script converts the timezone data from moment-timezone to marshalled format, for fast
* loading by Python.
*/
const marshal = require('app/common/marshal');
const fse = require('fs-extra');
const moment = require('moment-timezone');
const DEST_FILE = 'sandbox/grist/tzdata.data';
function main() {
const zones = moment.tz.names().map((name) => {
const z = moment.tz.zone(name);
return marshal.wrap('TUPLE', [z.name, z.abbrs, z.offsets, z.untils]);
});
const marshaller = new marshal.Marshaller({version: 2});
marshaller.marshal(zones);
const contents = marshaller.dumpAsBuffer();
return fse.writeFile(DEST_FILE, contents);
}
if (require.main === module) {
main().catch((e) => {
console.log("ERROR", e.message);
process.exit(1);
});
}

@ -0,0 +1,17 @@
astroid==1.4.9
asttokens==1.1.4
chardet==2.3.0
html5lib==0.999999999
iso8601==0.1.12
json_table_schema==0.2.1
lazy_object_proxy==1.2.2
messytables==0.15.2
python_dateutil==2.6.0
python_magic==0.4.12
roman==2.0.0
six==1.10.0
sortedcontainers==1.5.7
webencodings==0.5
wrapt==1.10.8
xlrd==1.2.0
unittest-xml-reporting==2.0.0
Loading…
Cancel
Save