(core) Refactor Table.Record[Set] classes

Summary:
Dealing with some things that bothered and sometimes confused me:

Make Table.Record[Set] provide the table argument automatically
Remove the classes from UserTable because they're not used anywhere and the Table/UserTable distinction is already confusing. They're not documented for users and they don't show up in autocomplete.
Remove RecordSet.Record because it was confusing me where that attribute was being set, but also this means .Record will work properly for users with columns named 'Record'.

Test Plan: existing tests

Reviewers: dsagal

Reviewed By: dsagal

Differential Revision: https://phab.getgrist.com/D2913
This commit is contained in:
Alex Hall 2021-07-16 20:15:04 +02:00
parent 5b2666a88a
commit a9d5b4d5af
10 changed files with 48 additions and 26 deletions

View File

@ -42,6 +42,8 @@ class AutocompleteContext(object):
# Add in the important UserTable methods, with custom friendlier descriptions. # Add in the important UserTable methods, with custom friendlier descriptions.
self._functions['.lookupOne'] = Completion('.lookupOne', '(colName=<value>, ...)', True) self._functions['.lookupOne'] = Completion('.lookupOne', '(colName=<value>, ...)', True)
self._functions['.lookupRecords'] = Completion('.lookupRecords', '(colName=<value>, ...)', True) self._functions['.lookupRecords'] = Completion('.lookupRecords', '(colName=<value>, ...)', True)
self._functions['.Record'] = Completion('.Record', '', True)
self._functions['.RecordSet'] = Completion('.RecordSet', '', True)
# Remember the original name for each lowercase one. # Remember the original name for each lowercase one.
self._lowercase = {} self._lowercase = {}

View File

@ -421,7 +421,7 @@ class ReferenceColumn(BaseReferenceColumn):
return typed_value return typed_value
# For a Reference, values must either refer to an existing record, or be 0. In all tables, # For a Reference, values must either refer to an existing record, or be 0. In all tables,
# the 0 index will contain the all-defaults record. # the 0 index will contain the all-defaults record.
return self._target_table.Record(self._target_table, typed_value, self._relation) return self._target_table.Record(typed_value, self._relation)
def _update_references(self, row_id, old_value, new_value): def _update_references(self, row_id, old_value, new_value):
if old_value: if old_value:
@ -452,7 +452,7 @@ class ReferenceListColumn(BaseReferenceColumn):
# If we refer to an invalid table, return integers rather than fail completely. # If we refer to an invalid table, return integers rather than fail completely.
if not self._target_table: if not self._target_table:
return typed_value return typed_value
return self._target_table.RecordSet(self._target_table, typed_value, self._relation) return self._target_table.RecordSet(typed_value, self._relation)
def _raw_get_without(self, row_id, target_row_ids): def _raw_get_without(self, row_id, target_row_ids):
""" """

View File

@ -253,7 +253,7 @@ class DocModel(object):
table_obj = record_set_or_table.table table_obj = record_set_or_table.table
row_ids = self._engine.user_actions.BulkAddRecord(table_obj.table_id, [None] * count, values) row_ids = self._engine.user_actions.BulkAddRecord(table_obj.table_id, [None] * count, values)
return [table_obj.Record(table_obj, r, None) for r in row_ids] return [table_obj.Record(r, None) for r in row_ids]
def insert(self, record_set, position, **col_values): def insert(self, record_set, position, **col_values):
""" """

View File

@ -860,7 +860,7 @@ class Engine(object):
usercode_reference = self.gencode.usercode usercode_reference = self.gencode.usercode
checkpoint = self._get_undo_checkpoint() checkpoint = self._get_undo_checkpoint()
record = table.Record(table, row_id, table._identity_relation) record = table.Record(row_id, table._identity_relation)
try: try:
if cycle: if cycle:
raise depend.CircularRefError("Circular Reference") raise depend.CircularRefError("Circular Reference")

View File

@ -61,6 +61,12 @@ class Record(object):
table - Table object, in which this record lives. table - Table object, in which this record lives.
row_id - The ID of the record within table. row_id - The ID of the record within table.
relation - Relation object for how this record was obtained; used in dependency tracking. relation - Relation object for how this record was obtained; used in dependency tracking.
In general you shouldn't call this constructor directly, but rather:
table.Record(row_id, relation)
which provides the table argument automatically.
""" """
self._table = table self._table = table
self._row_id = row_id self._row_id = row_id
@ -103,8 +109,8 @@ class Record(object):
return "%s[%s]" % (self._table.table_id, self._row_id) return "%s[%s]" % (self._table.table_id, self._row_id)
def _clone_with_relation(self, src_relation): def _clone_with_relation(self, src_relation):
return self.__class__(self._table, self._row_id, return self._table.Record(self._row_id,
relation=src_relation.compose(self._source_relation)) relation=src_relation.compose(self._source_relation))
class RecordSet(object): class RecordSet(object):
@ -157,7 +163,7 @@ class RecordSet(object):
def __iter__(self): def __iter__(self):
for row_id in self._row_ids: for row_id in self._row_ids:
yield self.Record(self._table, row_id, self._source_relation) yield self._table.Record(row_id, self._source_relation)
def __contains__(self, item): def __contains__(self, item):
"""item may be a Record or its row_id.""" """item may be a Record or its row_id."""
@ -169,7 +175,7 @@ class RecordSet(object):
def get_one(self): def get_one(self):
row_id = min(self._row_ids) if self._row_ids else 0 row_id = min(self._row_ids) if self._row_ids else 0
return self.Record(self._table, row_id, self._source_relation) return self._table.Record(row_id, self._source_relation)
def _get_col(self, col_id): def _get_col(self, col_id):
return self._table._get_col_subset(col_id, self._row_ids, self._source_relation) return self._table._get_col_subset(col_id, self._row_ids, self._source_relation)
@ -180,10 +186,10 @@ class RecordSet(object):
return self._table._attribute_error(name, self._source_relation) return self._table._attribute_error(name, self._source_relation)
def _clone_with_relation(self, src_relation): def _clone_with_relation(self, src_relation):
return self.__class__(self._table, self._row_ids, return self._table.RecordSet(self._row_ids,
relation=src_relation.compose(self._source_relation), relation=src_relation.compose(self._source_relation),
group_by=self._group_by, group_by=self._group_by,
sort_by=self._sort_by) sort_by=self._sort_by)
class ColumnView(object): class ColumnView(object):

View File

@ -61,14 +61,19 @@ class UserTable(object):
self.Model = model_class self.Model = model_class
column_ids = {col for col in model_class.__dict__ if not col.startswith("_")} column_ids = {col for col in model_class.__dict__ if not col.startswith("_")}
column_ids.add('id') column_ids.add('id')
self.Record = type('Record', (records.Record,), {})
self.RecordSet = type('RecordSet', (records.RecordSet,), {})
self.RecordSet.Record = self.Record
self.table = None self.table = None
def _set_table_impl(self, table_impl): def _set_table_impl(self, table_impl):
self.table = table_impl self.table = table_impl
@property
def Record(self):
return self.table.Record
@property
def RecordSet(self):
return self.table.RecordSet
# Note these methods are named camelCase since they are a public interface exposed to formulas, # Note these methods are named camelCase since they are a public interface exposed to formulas,
# and we decided camelCase was a more user-friendly choice for user-facing functions. # and we decided camelCase was a more user-friendly choice for user-facing functions.
def lookupRecords(self, **field_value_pairs): def lookupRecords(self, **field_value_pairs):
@ -200,6 +205,18 @@ class Table(object):
# For a summary table, the name of the special helper column auto-added to the source table. # For a summary table, the name of the special helper column auto-added to the source table.
self._summary_helper_col_id = None self._summary_helper_col_id = None
# Add Record and RecordSet subclasses which fill in this table as the first argument
class Record(records.Record):
def __init__(inner_self, *args, **kwargs): # pylint: disable=no-self-argument
super(Record, inner_self).__init__(self, *args, **kwargs)
class RecordSet(records.RecordSet):
def __init__(inner_self, *args, **kwargs): # pylint: disable=no-self-argument
super(RecordSet, inner_self).__init__(self, *args, **kwargs)
self.Record = Record
self.RecordSet = RecordSet
def _rebuild_model(self, user_table): def _rebuild_model(self, user_table):
""" """
Sets class-wide properties from a new Model class for the table (inner class within the table Sets class-wide properties from a new Model class for the table (inner class within the table
@ -207,8 +224,6 @@ class Table(object):
""" """
self.user_table = user_table self.user_table = user_table
self.Model = user_table.Model self.Model = user_table.Model
self.Record = user_table.Record
self.RecordSet = user_table.RecordSet
new_cols = collections.OrderedDict() new_cols = collections.OrderedDict()
new_cols['id'] = self._id_column new_cols['id'] = self._id_column
@ -337,7 +352,7 @@ class Table(object):
row_ids = sorted(row_id_set, key=lambda r: self._get_col_value(sort_by, r, rel)) row_ids = sorted(row_id_set, key=lambda r: self._get_col_value(sort_by, r, rel))
else: else:
row_ids = sorted(row_id_set) row_ids = sorted(row_id_set)
return self.RecordSet(self, row_ids, rel, group_by=kwargs, sort_by=sort_by) return self.RecordSet(row_ids, rel, group_by=kwargs, sort_by=sort_by)
def lookup_one_record(self, **kwargs): def lookup_one_record(self, **kwargs):
return self.lookup_records(**kwargs).get_one() return self.lookup_records(**kwargs).get_one()
@ -416,7 +431,7 @@ class Table(object):
# user-actions caused by formula side-effects (e.g. as trigged by lookupOrAddDerived()) # user-actions caused by formula side-effects (e.g. as trigged by lookupOrAddDerived())
if row_id not in self.row_ids: if row_id not in self.row_ids:
raise KeyError("'get_record' found no matching record") raise KeyError("'get_record' found no matching record")
return self.Record(self, row_id, None) return self.Record(row_id, None)
def filter_records(self, **kwargs): def filter_records(self, **kwargs):
""" """
@ -427,7 +442,7 @@ class Table(object):
# See note in get_record() about using this call from formulas. # See note in get_record() about using this call from formulas.
for row_id in self.filter(**kwargs): for row_id in self.filter(**kwargs):
yield self.Record(self, row_id, None) yield self.Record(row_id, None)
# TODO: document everything here. # TODO: document everything here.

View File

@ -213,6 +213,8 @@ class TestCompletion(test_engine.EngineTestCase):
def test_suggest_lookup_methods(self): def test_suggest_lookup_methods(self):
# Should suggest lookup formulas for tables. # Should suggest lookup formulas for tables.
self.assertEqual(self.engine.autocomplete("Address.", "Students", "firstName", self.user), [ self.assertEqual(self.engine.autocomplete("Address.", "Students", "firstName", self.user), [
('Address.Record', '', True),
('Address.RecordSet', '', True),
'Address.all', 'Address.all',
('Address.lookupOne', '(colName=<value>, ...)', True), ('Address.lookupOne', '(colName=<value>, ...)', True),
('Address.lookupRecords', '(colName=<value>, ...)', True), ('Address.lookupRecords', '(colName=<value>, ...)', True),

View File

@ -82,9 +82,6 @@ class TestGenCode(unittest.TestCase):
gcode.make_module(self.schema) gcode.make_module(self.schema)
module = gcode.usercode module = gcode.usercode
self.assertTrue(isinstance(module.Students, table.UserTable)) self.assertTrue(isinstance(module.Students, table.UserTable))
self.assertTrue(issubclass(module.Students.Record, records.Record))
self.assertTrue(issubclass(module.Students.RecordSet, records.RecordSet))
self.assertIs(module.Students.RecordSet.Record, module.Students.Record)
def test_pick_col_ident(self): def test_pick_col_ident(self):
self.assertEqual(identifiers.pick_col_ident("asdf"), "asdf") self.assertEqual(identifiers.pick_col_ident("asdf"), "asdf")

View File

@ -633,7 +633,7 @@ return ",".join(str(r.id) for r in Students.lookupRecords(firstName=fn, lastName
# A helper for comparing Record objects below. # A helper for comparing Record objects below.
schools_table = self.engine.tables['Schools'] schools_table = self.engine.tables['Schools']
def SchoolsRec(row_id): def SchoolsRec(row_id):
return schools_table.Record(schools_table, row_id, None) return schools_table.Record(row_id, None)
# We'll play with schools "Columbia" and "Eureka", which are rows 1,3,5 in the Students table. # We'll play with schools "Columbia" and "Eureka", which are rows 1,3,5 in the Students table.
self.assertTableData("Students", cols="subset", rows="subset", data=[ self.assertTableData("Students", cols="subset", rows="subset", data=[
@ -680,7 +680,7 @@ return ",".join(str(r.id) for r in Students.lookupRecords(firstName=fn, lastName
# A helper for comparing Record objects below. # A helper for comparing Record objects below.
schools_table = self.engine.tables['Schools'] schools_table = self.engine.tables['Schools']
def SchoolsRec(row_id): def SchoolsRec(row_id):
return schools_table.Record(schools_table, row_id, None) return schools_table.Record(row_id, None)
# We'll play with schools "Columbia" and "Eureka", which are rows 1,3,5 in the Students table. # We'll play with schools "Columbia" and "Eureka", which are rows 1,3,5 in the Students table.
self.assertTableData("Students", cols="subset", rows="all", data=[ self.assertTableData("Students", cols="subset", rows="all", data=[

View File

@ -380,7 +380,7 @@ class TestRenames(test_engine.EngineTestCase):
from datetime import date from datetime import date
# A helper for comparing Record objects below. # A helper for comparing Record objects below.
people_table = self.engine.tables['People'] people_table = self.engine.tables['People']
people_rec = lambda row_id: people_table.Record(people_table, row_id, None) people_rec = lambda row_id: people_table.Record(row_id, None)
# Verify the data and calculations are correct. # Verify the data and calculations are correct.
self.assertTableData("Address", cols="all", data=[ self.assertTableData("Address", cols="all", data=[