(core) Refactor Table.Record[Set] classes

Summary:
Dealing with some things that bothered and sometimes confused me:

Make Table.Record[Set] provide the table argument automatically
Remove the classes from UserTable because they're not used anywhere and the Table/UserTable distinction is already confusing. They're not documented for users and they don't show up in autocomplete.
Remove RecordSet.Record because it was confusing me where that attribute was being set, but also this means .Record will work properly for users with columns named 'Record'.

Test Plan: existing tests

Reviewers: dsagal

Reviewed By: dsagal

Differential Revision: https://phab.getgrist.com/D2913
This commit is contained in:
Alex Hall 2021-07-16 20:15:04 +02:00
parent 5b2666a88a
commit a9d5b4d5af
10 changed files with 48 additions and 26 deletions

View File

@ -42,6 +42,8 @@ class AutocompleteContext(object):
# Add in the important UserTable methods, with custom friendlier descriptions.
self._functions['.lookupOne'] = Completion('.lookupOne', '(colName=<value>, ...)', True)
self._functions['.lookupRecords'] = Completion('.lookupRecords', '(colName=<value>, ...)', True)
self._functions['.Record'] = Completion('.Record', '', True)
self._functions['.RecordSet'] = Completion('.RecordSet', '', True)
# Remember the original name for each lowercase one.
self._lowercase = {}

View File

@ -421,7 +421,7 @@ class ReferenceColumn(BaseReferenceColumn):
return typed_value
# For a Reference, values must either refer to an existing record, or be 0. In all tables,
# the 0 index will contain the all-defaults record.
return self._target_table.Record(self._target_table, typed_value, self._relation)
return self._target_table.Record(typed_value, self._relation)
def _update_references(self, row_id, old_value, new_value):
if old_value:
@ -452,7 +452,7 @@ class ReferenceListColumn(BaseReferenceColumn):
# If we refer to an invalid table, return integers rather than fail completely.
if not self._target_table:
return typed_value
return self._target_table.RecordSet(self._target_table, typed_value, self._relation)
return self._target_table.RecordSet(typed_value, self._relation)
def _raw_get_without(self, row_id, target_row_ids):
"""

View File

@ -253,7 +253,7 @@ class DocModel(object):
table_obj = record_set_or_table.table
row_ids = self._engine.user_actions.BulkAddRecord(table_obj.table_id, [None] * count, values)
return [table_obj.Record(table_obj, r, None) for r in row_ids]
return [table_obj.Record(r, None) for r in row_ids]
def insert(self, record_set, position, **col_values):
"""

View File

@ -860,7 +860,7 @@ class Engine(object):
usercode_reference = self.gencode.usercode
checkpoint = self._get_undo_checkpoint()
record = table.Record(table, row_id, table._identity_relation)
record = table.Record(row_id, table._identity_relation)
try:
if cycle:
raise depend.CircularRefError("Circular Reference")

View File

@ -61,6 +61,12 @@ class Record(object):
table - Table object, in which this record lives.
row_id - The ID of the record within table.
relation - Relation object for how this record was obtained; used in dependency tracking.
In general you shouldn't call this constructor directly, but rather:
table.Record(row_id, relation)
which provides the table argument automatically.
"""
self._table = table
self._row_id = row_id
@ -103,7 +109,7 @@ class Record(object):
return "%s[%s]" % (self._table.table_id, self._row_id)
def _clone_with_relation(self, src_relation):
return self.__class__(self._table, self._row_id,
return self._table.Record(self._row_id,
relation=src_relation.compose(self._source_relation))
@ -157,7 +163,7 @@ class RecordSet(object):
def __iter__(self):
for row_id in self._row_ids:
yield self.Record(self._table, row_id, self._source_relation)
yield self._table.Record(row_id, self._source_relation)
def __contains__(self, item):
"""item may be a Record or its row_id."""
@ -169,7 +175,7 @@ class RecordSet(object):
def get_one(self):
row_id = min(self._row_ids) if self._row_ids else 0
return self.Record(self._table, row_id, self._source_relation)
return self._table.Record(row_id, self._source_relation)
def _get_col(self, col_id):
return self._table._get_col_subset(col_id, self._row_ids, self._source_relation)
@ -180,7 +186,7 @@ class RecordSet(object):
return self._table._attribute_error(name, self._source_relation)
def _clone_with_relation(self, src_relation):
return self.__class__(self._table, self._row_ids,
return self._table.RecordSet(self._row_ids,
relation=src_relation.compose(self._source_relation),
group_by=self._group_by,
sort_by=self._sort_by)

View File

@ -61,14 +61,19 @@ class UserTable(object):
self.Model = model_class
column_ids = {col for col in model_class.__dict__ if not col.startswith("_")}
column_ids.add('id')
self.Record = type('Record', (records.Record,), {})
self.RecordSet = type('RecordSet', (records.RecordSet,), {})
self.RecordSet.Record = self.Record
self.table = None
def _set_table_impl(self, table_impl):
self.table = table_impl
@property
def Record(self):
return self.table.Record
@property
def RecordSet(self):
return self.table.RecordSet
# Note these methods are named camelCase since they are a public interface exposed to formulas,
# and we decided camelCase was a more user-friendly choice for user-facing functions.
def lookupRecords(self, **field_value_pairs):
@ -200,6 +205,18 @@ class Table(object):
# For a summary table, the name of the special helper column auto-added to the source table.
self._summary_helper_col_id = None
# Add Record and RecordSet subclasses which fill in this table as the first argument
class Record(records.Record):
def __init__(inner_self, *args, **kwargs): # pylint: disable=no-self-argument
super(Record, inner_self).__init__(self, *args, **kwargs)
class RecordSet(records.RecordSet):
def __init__(inner_self, *args, **kwargs): # pylint: disable=no-self-argument
super(RecordSet, inner_self).__init__(self, *args, **kwargs)
self.Record = Record
self.RecordSet = RecordSet
def _rebuild_model(self, user_table):
"""
Sets class-wide properties from a new Model class for the table (inner class within the table
@ -207,8 +224,6 @@ class Table(object):
"""
self.user_table = user_table
self.Model = user_table.Model
self.Record = user_table.Record
self.RecordSet = user_table.RecordSet
new_cols = collections.OrderedDict()
new_cols['id'] = self._id_column
@ -337,7 +352,7 @@ class Table(object):
row_ids = sorted(row_id_set, key=lambda r: self._get_col_value(sort_by, r, rel))
else:
row_ids = sorted(row_id_set)
return self.RecordSet(self, row_ids, rel, group_by=kwargs, sort_by=sort_by)
return self.RecordSet(row_ids, rel, group_by=kwargs, sort_by=sort_by)
def lookup_one_record(self, **kwargs):
return self.lookup_records(**kwargs).get_one()
@ -416,7 +431,7 @@ class Table(object):
# user-actions caused by formula side-effects (e.g. as trigged by lookupOrAddDerived())
if row_id not in self.row_ids:
raise KeyError("'get_record' found no matching record")
return self.Record(self, row_id, None)
return self.Record(row_id, None)
def filter_records(self, **kwargs):
"""
@ -427,7 +442,7 @@ class Table(object):
# See note in get_record() about using this call from formulas.
for row_id in self.filter(**kwargs):
yield self.Record(self, row_id, None)
yield self.Record(row_id, None)
# TODO: document everything here.

View File

@ -213,6 +213,8 @@ class TestCompletion(test_engine.EngineTestCase):
def test_suggest_lookup_methods(self):
# Should suggest lookup formulas for tables.
self.assertEqual(self.engine.autocomplete("Address.", "Students", "firstName", self.user), [
('Address.Record', '', True),
('Address.RecordSet', '', True),
'Address.all',
('Address.lookupOne', '(colName=<value>, ...)', True),
('Address.lookupRecords', '(colName=<value>, ...)', True),

View File

@ -82,9 +82,6 @@ class TestGenCode(unittest.TestCase):
gcode.make_module(self.schema)
module = gcode.usercode
self.assertTrue(isinstance(module.Students, table.UserTable))
self.assertTrue(issubclass(module.Students.Record, records.Record))
self.assertTrue(issubclass(module.Students.RecordSet, records.RecordSet))
self.assertIs(module.Students.RecordSet.Record, module.Students.Record)
def test_pick_col_ident(self):
self.assertEqual(identifiers.pick_col_ident("asdf"), "asdf")

View File

@ -633,7 +633,7 @@ return ",".join(str(r.id) for r in Students.lookupRecords(firstName=fn, lastName
# A helper for comparing Record objects below.
schools_table = self.engine.tables['Schools']
def SchoolsRec(row_id):
return schools_table.Record(schools_table, row_id, None)
return schools_table.Record(row_id, None)
# We'll play with schools "Columbia" and "Eureka", which are rows 1,3,5 in the Students table.
self.assertTableData("Students", cols="subset", rows="subset", data=[
@ -680,7 +680,7 @@ return ",".join(str(r.id) for r in Students.lookupRecords(firstName=fn, lastName
# A helper for comparing Record objects below.
schools_table = self.engine.tables['Schools']
def SchoolsRec(row_id):
return schools_table.Record(schools_table, row_id, None)
return schools_table.Record(row_id, None)
# We'll play with schools "Columbia" and "Eureka", which are rows 1,3,5 in the Students table.
self.assertTableData("Students", cols="subset", rows="all", data=[

View File

@ -380,7 +380,7 @@ class TestRenames(test_engine.EngineTestCase):
from datetime import date
# A helper for comparing Record objects below.
people_table = self.engine.tables['People']
people_rec = lambda row_id: people_table.Record(people_table, row_id, None)
people_rec = lambda row_id: people_table.Record(row_id, None)
# Verify the data and calculations are correct.
self.assertTableData("Address", cols="all", data=[