(core) Python optimizations to speed up data engine

Summary:
- A bunch of optimizations guided by Python profiling (esp. py-spy)
- The biggest one is optimizing Record/RecordSet attribute access
- Adds tracemalloc printout when running test_replay with PYTHONTRACEMALLOC=1 (on PY3)
  (but memory size is barely affected by these changes)

- Testing with RECORD_SANDBOX_BUFFERS_DIR, loading and calculating a particular
  very large doc (CRM), time taken improved from 73.9s to 54.8s (26% less time)

Test Plan: No behavior changes intended; relying on existing tests to verify that.

Reviewers: georgegevoian

Reviewed By: georgegevoian

Differential Revision: https://phab.getgrist.com/D3781
This commit is contained in:
Dmitry S 2023-02-04 11:20:13 -05:00
parent 7c448d746f
commit 9d4eeda480
8 changed files with 144 additions and 84 deletions

View File

@ -60,6 +60,7 @@ class BaseColumn(object):
""" """
def __init__(self, table, col_id, col_info): def __init__(self, table, col_id, col_info):
self.type_obj = col_info.type_obj self.type_obj = col_info.type_obj
self._is_right_type = self.type_obj.is_right_type
self._data = [] self._data = []
self.col_id = col_id self.col_id = col_id
self.table_id = table.table_id self.table_id = table.table_id
@ -154,10 +155,14 @@ class BaseColumn(object):
else: else:
raise objtypes.CellError(self.table_id, self.col_id, row_id, raw.error) raise objtypes.CellError(self.table_id, self.col_id, row_id, raw.error)
return self._convert_raw_value(raw) # Inline _convert_raw_value here because this is particularly hot code, called on every access
# of any data field in a formula.
if self._is_right_type(raw):
return self._make_rich_value(raw)
return self._alt_text(raw)
def _convert_raw_value(self, raw): def _convert_raw_value(self, raw):
if self.type_obj.is_right_type(raw): if self._is_right_type(raw):
return self._make_rich_value(raw) return self._make_rich_value(raw)
return self._alt_text(raw) return self._alt_text(raw)

View File

@ -161,9 +161,9 @@ class SimpleLookupMapColumn(BaseLookupMapColumn):
def _recalc_rec_method(self, rec, table): def _recalc_rec_method(self, rec, table):
old_key = self._row_key_map.lookup_left(rec._row_id) old_key = self._row_key_map.lookup_left(rec._row_id)
# Note that rec._get_col(_col_id) is what creates the correct dependency, as well as ensures # Note that getattr(rec, _col_id) is what creates the correct dependency, as well as ensures
# that the columns used to index by are brought up-to-date (in case they are formula columns). # that the columns used to index by are brought up-to-date (in case they are formula columns).
new_key = tuple(_extract(rec._get_col(_col_id)) for _col_id in self._col_ids_tuple) new_key = tuple(_extract(getattr(rec, _col_id)) for _col_id in self._col_ids_tuple)
try: try:
self._row_key_map.insert(rec._row_id, new_key) self._row_key_map.insert(rec._row_id, new_key)
@ -188,9 +188,9 @@ class ContainsLookupMapColumn(BaseLookupMapColumn):
# looked up with CONTAINS() # looked up with CONTAINS()
new_keys_groups = [] new_keys_groups = []
for col_id in self._col_ids_tuple: for col_id in self._col_ids_tuple:
# Note that _get_col is what creates the correct dependency, as well as ensures # Note that getattr() is what creates the correct dependency, as well as ensures
# that the columns used to index by are brought up-to-date (in case they are formula columns). # that the columns used to index by are brought up-to-date (in case they are formula columns).
group = rec._get_col(extract_column_id(col_id)) group = getattr(rec, extract_column_id(col_id))
if isinstance(col_id, _Contains): if isinstance(col_id, _Contains):
# Check that the cell targeted by CONTAINS() has an appropriate type. # Check that the cell targeted by CONTAINS() has an appropriate type.

View File

@ -121,10 +121,8 @@ class CensoredValue(object):
_censored_sentinel = CensoredValue() _censored_sentinel = CensoredValue()
_max_js_int = 1<<31
def is_int_short(value): def is_int_short(value):
return -_max_js_int <= value < _max_js_int return -(1<<31) <= value < (1<<31)
def safe_shift(arg, default=None): def safe_shift(arg, default=None):
value = arg.pop(0) if arg else None value = arg.pop(0) if arg else None

View File

@ -39,8 +39,8 @@ class Record(object):
Usage: __$group__ Usage: __$group__
In a [summary table](summary-tables.md), `$group` is a special field In a [summary table](summary-tables.md), `$group` is a special field
containing the list of Records that are summarized by the current summary line. E.g. the formula containing the list of Records that are summarized by the current summary line. E.g. the
`len($group)` counts the number of those records being summarized in each row. formula `len($group)` counts the number of those records being summarized in each row.
See [RecordSet](#recordset) for useful properties offered by the returned object. See [RecordSet](#recordset) for useful properties offered by the returned object.
@ -54,9 +54,15 @@ class Record(object):
""" """
) )
# Slots are an optimization to avoid the need for a per-object __dict__.
__slots__ = ('_row_id', '_source_relation')
# Per-table derived classes override this and set it to the appropriate Table object.
_table = None
# Record is always a thin class, containing essentially a reference to a row in the table. The # Record is always a thin class, containing essentially a reference to a row in the table. The
# properties to access individual fields of a row are provided in per-table derived classes. # properties to access individual fields of a row are provided in per-table derived classes.
def __init__(self, table, row_id, relation=None): def __init__(self, row_id, relation=None):
""" """
Creates a Record object. Creates a Record object.
table - Table object, in which this record lives. table - Table object, in which this record lives.
@ -69,20 +75,12 @@ class Record(object):
which provides the table argument automatically. which provides the table argument automatically.
""" """
self._table = table
self._row_id = row_id self._row_id = row_id
self._source_relation = relation or table._identity_relation self._source_relation = relation or self._table._identity_relation
def _get_col(self, col_id): # Existing fields are added as @property methods in table.py. When no field is found, raise a
return self._table._get_col_value(col_id, self._row_id, self._source_relation) # more informative AttributeError.
# Look up a property of the record. Internal properties are simple.
# For columns, we explicitly check that we have them before attempting to access.
# Otherwise AttributeError is ambiguous - it could be because we don't have the
# column, or because the column threw an AttributeError when evaluated.
def __getattr__(self, name): def __getattr__(self, name):
if name in self._table.all_columns:
return self._get_col(name)
return self._table._attribute_error(name, self._source_relation) return self._table._attribute_error(name, self._source_relation)
def __hash__(self): def __hash__(self):
@ -134,17 +132,23 @@ class RecordSet(object):
You can get the number of records in a RecordSet using `len`, e.g. `len($group)`. You can get the number of records in a RecordSet using `len`, e.g. `len($group)`.
""" """
# Slots are an optimization to avoid the need for a per-object __dict__.
__slots__ = ('_row_ids', '_source_relation', '_group_by', '_sort_by')
# Per-table derived classes override this and set it to the appropriate Table object.
_table = None
# Methods should be named with a leading underscore to avoid interfering with access to # Methods should be named with a leading underscore to avoid interfering with access to
# user-defined fields. # user-defined fields.
def __init__(self, table, row_ids, relation=None, group_by=None, sort_by=None): def __init__(self, row_ids, relation=None, group_by=None, sort_by=None):
""" """
group_by may be a dictionary mapping column names to values that are all the same for the given group_by may be a dictionary mapping column names to values that are all the same for the given
RecordSet. sort_by may be the column name used for sorting this record set. Both are set by RecordSet. sort_by may be the column name used for sorting this record set. Both are set by
lookupRecords, and used when using RecordSet to insert new records. lookupRecords, and used when using RecordSet to insert new records.
""" """
self._table = table
self._row_ids = row_ids self._row_ids = row_ids
self._source_relation = relation or table._identity_relation self._source_relation = relation or self._table._identity_relation
# If row_ids is itself a RecordList, default to its _group_by and _sort_by properties. # If row_ids is itself a RecordList, default to its _group_by and _sort_by properties.
self._group_by = group_by or getattr(row_ids, '_group_by', None) self._group_by = group_by or getattr(row_ids, '_group_by', None)
self._sort_by = sort_by or getattr(row_ids, '_sort_by', None) self._sort_by = sort_by or getattr(row_ids, '_sort_by', None)
@ -188,12 +192,7 @@ class RecordSet(object):
row_id = min(self._row_ids) row_id = min(self._row_ids)
return self._table.Record(row_id, self._source_relation) return self._table.Record(row_id, self._source_relation)
def _get_col(self, col_id):
return self._table._get_col_subset(col_id, self._row_ids, self._source_relation)
def __getattr__(self, name): def __getattr__(self, name):
if name in self._table.all_columns:
return self._get_col(name)
return self._table._attribute_error(name, self._source_relation) return self._table._attribute_error(name, self._source_relation)
def __repr__(self): def __repr__(self):

View File

@ -11,7 +11,7 @@ import docmodel
import functions import functions
import logger import logger
import lookup import lookup
import records from records import adjust_record, Record as BaseRecord, RecordSet as BaseRecordSet
import relation as relation_module # "relation" is used too much as a variable name below. import relation as relation_module # "relation" is used too much as a variable name below.
import usertypes import usertypes
@ -64,7 +64,8 @@ class UserTable(object):
""" """
Name: lookupRecords Name: lookupRecords
Usage: UserTable.__lookupRecords__(Field_In_Lookup_Table=value, ...) Usage: UserTable.__lookupRecords__(Field_In_Lookup_Table=value, ...)
Returns a [RecordSet](#recordset) matching the given field=value arguments. The value may be any expression, Returns a [RecordSet](#recordset) matching the given field=value arguments. The value may be
any expression,
most commonly a field in the current row (e.g. `$SomeField`) or a constant (e.g. a quoted string most commonly a field in the current row (e.g. `$SomeField`) or a constant (e.g. a quoted string
like `"Some Value"`) (examples below). like `"Some Value"`) (examples below).
If `sort_by=field` is given, sort the results by that field. If `sort_by=field` is given, sort the results by that field.
@ -88,7 +89,8 @@ class UserTable(object):
""" """
Name: lookupOne Name: lookupOne
Usage: UserTable.__lookupOne__(Field_In_Lookup_Table=value, ...) Usage: UserTable.__lookupOne__(Field_In_Lookup_Table=value, ...)
Returns a [Record](#record) matching the given field=value arguments. The value may be any expression, Returns a [Record](#record) matching the given field=value arguments. The value may be any
expression,
most commonly a field in the current row (e.g. `$SomeField`) or a constant (e.g. a quoted string most commonly a field in the current row (e.g. `$SomeField`) or a constant (e.g. a quoted string
like `"Some Value"`). If multiple records match, returns one of them. If none match, returns the like `"Some Value"`). If multiple records match, returns one of them. If none match, returns the
special empty record. special empty record.
@ -222,23 +224,24 @@ class Table(object):
# which are 'flattened' so source records may appear in multiple groups # which are 'flattened' so source records may appear in multiple groups
self._summary_simple = None self._summary_simple = None
# Add Record and RecordSet subclasses with correct `_table` attribute, which will also hold a
# field attribute for each column.
class Record(BaseRecord):
__slots__ = ()
_table = self
class RecordSet(BaseRecordSet):
__slots__ = ()
_table = self
self.Record = Record
self.RecordSet = RecordSet
# For use in _num_rows. The attribute isn't strictly needed, # For use in _num_rows. The attribute isn't strictly needed,
# but it makes _num_rows slightly faster, and only creating the lookup map when _num_rows # but it makes _num_rows slightly faster, and only creating the lookup map when _num_rows
# is called seems to be too late, at least for unit tests. # is called seems to be too late, at least for unit tests.
self._empty_lookup_column = self._get_lookup_map(()) self._empty_lookup_column = self._get_lookup_map(())
# Add Record and RecordSet subclasses which fill in this table as the first argument
class Record(records.Record):
def __init__(inner_self, *args, **kwargs): # pylint: disable=no-self-argument
super(Record, inner_self).__init__(self, *args, **kwargs)
class RecordSet(records.RecordSet):
def __init__(inner_self, *args, **kwargs): # pylint: disable=no-self-argument
super(RecordSet, inner_self).__init__(self, *args, **kwargs)
self.Record = Record
self.RecordSet = RecordSet
def _num_rows(self): def _num_rows(self):
""" """
Similar to `len(self.lookup_records())` but faster and doesn't create dependencies. Similar to `len(self.lookup_records())` but faster and doesn't create dependencies.
@ -301,6 +304,8 @@ class Table(object):
# column changes should stay the same. These get removed when unneeded using other means. # column changes should stay the same. These get removed when unneeded using other means.
new_cols.update(sorted(six.iteritems(self._special_cols))) new_cols.update(sorted(six.iteritems(self._special_cols)))
self._update_record_classes(self.all_columns, new_cols)
# Set the new columns. # Set the new columns.
self.all_columns = new_cols self.all_columns = new_cols
@ -411,8 +416,7 @@ class Table(object):
if type(self.get_column(col_id).type_obj) != type(_updateSummary.grist_type): if type(self.get_column(col_id).type_obj) != type(_updateSummary.grist_type):
self.delete_column(self.get_column(col_id)) self.delete_column(self.get_column(col_id))
col_obj = self._create_or_update_col(col_id, _updateSummary) col_obj = self._create_or_update_col(col_id, _updateSummary)
self._special_cols[col_id] = col_obj self._add_special_col(col_obj)
self.all_columns[col_id] = col_obj
def get_helper_columns(self): def get_helper_columns(self):
""" """
@ -509,9 +513,10 @@ class Table(object):
raise TypeError("sort_by must be a column ID (string)") raise TypeError("sort_by must be a column ID (string)")
reverse = sort_by.startswith("-") reverse = sort_by.startswith("-")
sort_col = sort_by.lstrip("-") sort_col = sort_by.lstrip("-")
sort_col_obj = self.all_columns[sort_col]
row_ids = sorted( row_ids = sorted(
row_id_set, row_id_set,
key=lambda r: column.SafeSortKey(self._get_col_value(sort_col, r, rel)), key=lambda r: column.SafeSortKey(self._get_col_obj_value(sort_col_obj, r, rel)),
reverse=reverse, reverse=reverse,
) )
else: else:
@ -542,14 +547,20 @@ class Table(object):
else: else:
column_class = lookup.SimpleLookupMapColumn column_class = lookup.SimpleLookupMapColumn
lmap = column_class(self, lookup_col_id, col_ids_tuple) lmap = column_class(self, lookup_col_id, col_ids_tuple)
self._special_cols[lookup_col_id] = lmap self._add_special_col(lmap)
self.all_columns[lookup_col_id] = lmap
return lmap return lmap
def delete_column(self, col_obj): def delete_column(self, col_obj):
assert col_obj.table_id == self.table_id assert col_obj.table_id == self.table_id
self._special_cols.pop(col_obj.col_id, None) self._special_cols.pop(col_obj.col_id, None)
self.all_columns.pop(col_obj.col_id, None) self.all_columns.pop(col_obj.col_id, None)
self._remove_field_from_record_classes(col_obj.col_id)
def _add_special_col(self, col_obj):
assert col_obj.table_id == self.table_id
self._special_cols[col_obj.col_id] = col_obj
self.all_columns[col_obj.col_id] = col_obj
self._add_field_to_record_classes(col_obj)
def lookupOrAddDerived(self, **kwargs): def lookupOrAddDerived(self, **kwargs):
record = self.lookup_one_record(**kwargs) record = self.lookup_one_record(**kwargs)
@ -629,39 +640,73 @@ class Table(object):
# TODO: document everything here. # TODO: document everything here.
# Called when record.foo is accessed # Equivalent to accessing record.foo, but only used in very limited cases now (field accessor is
def _get_col_value(self, col_id, row_id, relation): # more optimized).
[value] = self._get_col_subset_raw(col_id, [row_id], relation) def _get_col_obj_value(self, col_obj, row_id, relation):
return records.adjust_record(relation, value) # creates a dependency and brings formula columns up-to-date.
self._engine._use_node(col_obj.node, relation, (row_id,))
value = col_obj.get_cell_value(row_id)
return adjust_record(relation, value)
def _attribute_error(self, col_id, relation): def _attribute_error(self, col_id, relation):
self._engine._use_node(self._new_columns_node, relation) self._engine._use_node(self._new_columns_node, relation)
raise AttributeError("Table '%s' has no column '%s'" % (self.table_id, col_id)) raise AttributeError("Table '%s' has no column '%s'" % (self.table_id, col_id))
# Called when record_set.foo is accessed # Called when record_set.foo is accessed
def _get_col_subset(self, col_id, row_ids, relation): def _get_col_obj_subset(self, col_obj, row_ids, relation):
values = self._get_col_subset_raw(col_id, row_ids, relation) self._engine._use_node(col_obj.node, relation, row_ids)
values = [col_obj.get_cell_value(row_id) for row_id in row_ids]
# When all the values are the same type of Record (i.e. all references to the same table) # When all the values are the same type of Record (i.e. all references to the same table)
# combine them into a single RecordSet for that table instead of a list # combine them into a single RecordSet for that table instead of a list
# so that more attribute accesses can be chained, # so that more attribute accesses can be chained,
# e.g. record_set.foo.bar where `foo` is a Reference column. # e.g. record_set.foo.bar where `foo` is a Reference column.
value_types = list(set(map(type, values))) value_types = list(set(map(type, values)))
if len(value_types) == 1 and issubclass(value_types[0], records.Record): if len(value_types) == 1 and issubclass(value_types[0], BaseRecord):
return records.RecordSet( return values[0]._table.RecordSet(
values[0]._table,
# This is different from row_ids: these are the row IDs referenced by these Records, # This is different from row_ids: these are the row IDs referenced by these Records,
# whereas row_ids are where the values were being stored. # whereas row_ids are where the values were being stored.
[val._row_id for val in values], [val._row_id for val in values],
relation.compose(values[0]._source_relation), relation.compose(values[0]._source_relation),
) )
else: else:
return [records.adjust_record(relation, value) for value in values] return [adjust_record(relation, value) for value in values]
# Internal helper to optimise _get_col_value #----------------------------------------
# so that it doesn't make a singleton RecordSet just to immediately unpack it
def _get_col_subset_raw(self, col_id, row_ids, relation): def _update_record_classes(self, old_columns, new_columns):
col = self.all_columns[col_id] for col_id in old_columns:
# creates a dependency and brings formula columns up-to-date. if col_id not in new_columns:
self._engine._use_node(col.node, relation, row_ids) self._remove_field_from_record_classes(col_id)
return [col.get_cell_value(row_id) for row_id in row_ids]
for col_id, col_obj in six.iteritems(new_columns):
if col_obj != old_columns.get(col_id):
self._add_field_to_record_classes(col_obj)
def _add_field_to_record_classes(self, col_obj):
node = col_obj.node
use_node = self._engine._use_node
@property
def record_field(rec):
# This is equivalent to _get_col_obj_value(), but is extra-optimized with _get_col_obj_value()
# and adjust_record() inlined, since this is particularly hot code, called on every access of
# any data field in a formula.
use_node(node, rec._source_relation, (rec._row_id,))
value = col_obj.get_cell_value(rec._row_id)
if isinstance(value, (BaseRecord, BaseRecordSet)):
return value._clone_with_relation(rec._source_relation)
return value
@property
def recordset_field(recset):
return self._get_col_obj_subset(col_obj, recset._row_ids, recset._source_relation)
setattr(self.Record, col_obj.col_id, record_field)
setattr(self.RecordSet, col_obj.col_id, recordset_field)
def _remove_field_from_record_classes(self, col_id):
if hasattr(self.Record, col_id):
delattr(self.Record, col_id)
if hasattr(self.RecordSet, col_id):
delattr(self.RecordSet, col_id)

View File

@ -26,8 +26,9 @@ View = namedtuple('View', 'id sections')
Section = namedtuple('Section', 'id parentKey tableRef fields') Section = namedtuple('Section', 'id parentKey tableRef fields')
Field = namedtuple('Field', 'id colRef') Field = namedtuple('Field', 'id colRef')
unittest.TestCase.assertRaisesRegex = unittest.TestCase.assertRaisesRegexp if six.PY2:
unittest.TestCase.assertRegex = unittest.TestCase.assertRegexpMatches unittest.TestCase.assertRaisesRegex = unittest.TestCase.assertRaisesRegexp
unittest.TestCase.assertRegex = unittest.TestCase.assertRegexpMatches
class EngineTestCase(unittest.TestCase): class EngineTestCase(unittest.TestCase):
""" """

View File

@ -44,7 +44,7 @@ import unittest
from main import run from main import run
from sandbox import Sandbox from sandbox import Sandbox
import six
def marshal_load_all(path): def marshal_load_all(path):
result = [] result = []
@ -65,6 +65,7 @@ class TestReplay(unittest.TestCase):
root = os.environ.get("RECORD_SANDBOX_BUFFERS_DIR") root = os.environ.get("RECORD_SANDBOX_BUFFERS_DIR")
if not root: if not root:
self.skipTest("RECORD_SANDBOX_BUFFERS_DIR not set") self.skipTest("RECORD_SANDBOX_BUFFERS_DIR not set")
for dirpath, dirnames, filenames in os.walk(root): for dirpath, dirnames, filenames in os.walk(root):
if "input" not in filenames: if "input" not in filenames:
continue continue
@ -76,9 +77,18 @@ class TestReplay(unittest.TestCase):
new_output_path = os.path.join(dirpath, "new_output") new_output_path = os.path.join(dirpath, "new_output")
with open(input_path, "rb") as external_input: with open(input_path, "rb") as external_input:
with open(new_output_path, "wb") as external_output: with open(new_output_path, "wb") as external_output:
if six.PY3:
import tracemalloc # pylint: disable=import-error
tracemalloc.reset_peak()
sandbox = Sandbox(external_input, external_output) sandbox = Sandbox(external_input, external_output)
run(sandbox) run(sandbox)
# Run with env PYTHONTRACEMALLOC=1 to trace and print peak memory (runs much slower).
if six.PY3 and tracemalloc.is_tracing():
mem_size, mem_peak = tracemalloc.get_traced_memory()
print("mem_size {}, mem_peak {}".format(mem_size, mem_peak))
original_output = marshal_load_all(output_path) original_output = marshal_load_all(output_path)
# _send_to_js does two layers of marshalling, # _send_to_js does two layers of marshalling,

View File

@ -11,14 +11,16 @@ Python's array.array. However, at least on the Python side, it means that we nee
data structure for values of the wrong type, and the memory savings aren't that great to be worth data structure for values of the wrong type, and the memory savings aren't that great to be worth
the extra complexity. the extra complexity.
""" """
# pylint: disable=unidiomatic-typecheck
import csv import csv
import datetime import datetime
import json import json
import math import math
import six import six
from six import integer_types
import objtypes import objtypes
from objtypes import AltText from objtypes import AltText, is_int_short
import moment import moment
import logger import logger
from records import Record, RecordSet from records import Record, RecordSet
@ -69,6 +71,8 @@ def ifError(value, value_if_error):
# formulas, but it's unclear how to make that work. # formulas, but it's unclear how to make that work.
return value_if_error if isinstance(value, AltText) else value return value_if_error if isinstance(value, AltText) else value
_numeric_types = (float,) + six.integer_types
_numeric_or_none = (float, NoneType) + six.integer_types
# Unique sentinel object to tell BaseColumnType constructor to use get_type_default(). # Unique sentinel object to tell BaseColumnType constructor to use get_type_default().
_use_type_default = object() _use_type_default = object()
@ -203,7 +207,7 @@ class Bool(BaseColumnType):
# recognize. Everything else will result in alttext. # recognize. Everything else will result in alttext.
if not value: if not value:
return False return False
if isinstance(value, (float, six.integer_types)): if isinstance(value, _numeric_types):
return True return True
if isinstance(value, AltText): if isinstance(value, AltText):
value = six.text_type(value) value = six.text_type(value)
@ -229,14 +233,13 @@ class Int(BaseColumnType):
return None return None
# Convert to float first, since python does not allow casting strings with decimals to int # Convert to float first, since python does not allow casting strings with decimals to int
ret = int(float(value)) ret = int(float(value))
if not objtypes.is_int_short(ret): if not is_int_short(ret):
raise OverflowError("Integer value too large") raise OverflowError("Integer value too large")
return ret return ret
@classmethod @classmethod
def is_right_type(cls, value): def is_right_type(cls, value):
return value is None or (isinstance(value, six.integer_types) and not isinstance(value, bool) and return value is None or (type(value) in integer_types and is_int_short(value))
objtypes.is_int_short(value))
class Numeric(BaseColumnType): class Numeric(BaseColumnType):
@ -252,7 +255,7 @@ class Numeric(BaseColumnType):
# TODO: Python distinguishes ints from floats, while JS only has floats. A value that can be # TODO: Python distinguishes ints from floats, while JS only has floats. A value that can be
# interpreted as an int will upon being entered have type 'float', but after database reload # interpreted as an int will upon being entered have type 'float', but after database reload
# will have type 'int'. # will have type 'int'.
return isinstance(value, (float, six.integer_types, NoneType)) and not isinstance(value, bool) return type(value) in _numeric_or_none
class Date(Numeric): class Date(Numeric):
@ -267,7 +270,7 @@ class Date(Numeric):
return moment.date_to_ts(value.date()) return moment.date_to_ts(value.date())
elif isinstance(value, datetime.date): elif isinstance(value, datetime.date):
return moment.date_to_ts(value) return moment.date_to_ts(value)
elif isinstance(value, (float, six.integer_types)): elif isinstance(value, _numeric_types):
return float(value) return float(value)
elif isinstance(value, six.string_types): elif isinstance(value, six.string_types):
# We also accept a date in ISO format (YYYY-MM-DD), the time portion is optional and ignored # We also accept a date in ISO format (YYYY-MM-DD), the time portion is optional and ignored
@ -277,7 +280,7 @@ class Date(Numeric):
@classmethod @classmethod
def is_right_type(cls, value): def is_right_type(cls, value):
return isinstance(value, (float, six.integer_types, NoneType)) return isinstance(value, _numeric_or_none)
class DateTime(Date): class DateTime(Date):
@ -299,7 +302,7 @@ class DateTime(Date):
return moment.dt_to_ts(value, self.timezone) return moment.dt_to_ts(value, self.timezone)
elif isinstance(value, datetime.date): elif isinstance(value, datetime.date):
return moment.date_to_ts(value, self.timezone) return moment.date_to_ts(value, self.timezone)
elif isinstance(value, (float, six.integer_types)): elif isinstance(value, _numeric_types):
return float(value) return float(value)
elif isinstance(value, six.string_types): elif isinstance(value, six.string_types):
# We also accept a datetime in ISO format (YYYY-MM-DD[T]HH:mm:ss) # We also accept a datetime in ISO format (YYYY-MM-DD[T]HH:mm:ss)
@ -365,7 +368,7 @@ class PositionNumber(BaseColumnType):
@classmethod @classmethod
def is_right_type(cls, value): def is_right_type(cls, value):
# Same as Numeric, but does not support None. # Same as Numeric, but does not support None.
return isinstance(value, (float, six.integer_types)) and not isinstance(value, bool) return type(value) in _numeric_types
class ManualSortPos(PositionNumber): class ManualSortPos(PositionNumber):
@ -387,14 +390,13 @@ class Id(BaseColumnType):
if not isinstance(value, (int, Record)): if not isinstance(value, (int, Record)):
raise TypeError("Cannot convert to Id type") raise TypeError("Cannot convert to Id type")
ret = int(value) ret = int(value)
if not objtypes.is_int_short(ret): if not is_int_short(ret):
raise OverflowError("Integer value too large") raise OverflowError("Integer value too large")
return ret return ret
@classmethod @classmethod
def is_right_type(cls, value): def is_right_type(cls, value):
return (isinstance(value, six.integer_types) and not isinstance(value, bool) and return (type(value) in integer_types and is_int_short(value))
objtypes.is_int_short(value))
class Reference(Id): class Reference(Id):