(core) Porting back AI formula backend

Summary: This is a backend part for the formula AI.

Test Plan: New tests

Reviewers: paulfitz

Reviewed By: paulfitz

Subscribers: cyprien

Differential Revision: https://phab.getgrist.com/D3786
This commit is contained in:
Jarosław Sadziński
2023-02-08 16:46:34 +01:00
parent ef0a55ced1
commit 6e3f0f2b35
10 changed files with 595 additions and 0 deletions

View File

@@ -0,0 +1,186 @@
import json
import textwrap
import six
from column import is_visible_column, BaseReferenceColumn
from objtypes import RaisedException
import records
def column_type(engine, table_id, col_id):
col_rec = engine.docmodel.get_column_rec(table_id, col_id)
typ = col_rec.type
parts = typ.split(":")
if parts[0] == "Ref":
return parts[1]
elif parts[0] == "RefList":
return "List[{}]".format(parts[1])
elif typ == "Choice":
return choices(col_rec)
elif typ == "ChoiceList":
return "Tuple[{}, ...]".format(choices(col_rec))
elif typ == "Any":
table = engine.tables[table_id]
col = table.get_column(col_id)
values = [col.raw_get(row_id) for row_id in table.row_ids]
return values_type(values)
else:
return dict(
Text="str",
Numeric="float",
Int="int",
Bool="bool",
Date="datetime.date",
DateTime="datetime.datetime",
Any="Any",
Attachments="Any",
)[parts[0]]
def choices(col_rec):
try:
widget_options = json.loads(col_rec.widgetOptions)
return "Literal{}".format(widget_options["choices"])
except (ValueError, KeyError):
return 'str'
def values_type(values):
types = set(type(v) for v in values) - {RaisedException}
optional = type(None) in types # pylint: disable=unidiomatic-typecheck
types.discard(type(None))
if types == {int, float}:
types = {float}
if len(types) != 1:
return "Any"
[typ] = types
val = next(v for v in values if isinstance(v, typ))
if isinstance(val, records.Record):
type_name = val._table.table_id
elif isinstance(val, records.RecordSet):
type_name = "List[{}]".format(val._table.table_id)
elif isinstance(val, list):
type_name = "List[{}]".format(values_type(val))
elif isinstance(val, set):
type_name = "Set[{}]".format(values_type(val))
elif isinstance(val, tuple):
type_name = "Tuple[{}, ...]".format(values_type(val))
elif isinstance(val, dict):
type_name = "Dict[{}, {}]".format(values_type(val.keys()), values_type(val.values()))
else:
type_name = typ.__name__
if optional:
type_name = "Optional[{}]".format(type_name)
return type_name
def referenced_tables(engine, table_id):
result = set()
queue = [table_id]
while queue:
cur_table_id = queue.pop()
if cur_table_id in result:
continue
result.add(cur_table_id)
for col_id, col in visible_columns(engine, cur_table_id):
if isinstance(col, BaseReferenceColumn):
target_table_id = col._target_table.table_id
if not target_table_id.startswith("_"):
queue.append(target_table_id)
return result - {table_id}
def all_other_tables(engine, table_id):
result = set(t for t in engine.tables.keys() if not t.startswith('_grist'))
return result - {table_id} - {'GristDocTour'}
def visible_columns(engine, table_id):
return [
(col_id, col)
for col_id, col in engine.tables[table_id].all_columns.items()
if is_visible_column(col_id)
]
def class_schema(engine, table_id, exclude_col_id=None, lookups=False):
result = "@dataclass\nclass {}:\n".format(table_id)
if lookups:
# Build a lookupRecords and lookupOne method for each table, providing some arguments hints
# for the columns that are visible.
lookupRecords_args = []
lookupOne_args = []
for col_id, col in visible_columns(engine, table_id):
if col_id != exclude_col_id:
lookupOne_args.append(col_id + '=None')
lookupRecords_args.append('%s=%s' % (col_id, col_id))
lookupOne_args.append('sort_by=None')
lookupRecords_args.append('sort_by=sort_by')
lookupOne_args_line = ', '.join(lookupOne_args)
lookupRecords_args_line = ', '.join(lookupRecords_args)
result += " def __len__(self):\n"
result += " return len(%s.lookupRecords())\n" % table_id
result += " @staticmethod\n"
result += " def lookupRecords(%s) -> List[%s]:\n" % (lookupOne_args_line, table_id)
result += " # ...\n"
result += " @staticmethod\n"
result += " def lookupOne(%s) -> %s:\n" % (lookupOne_args_line, table_id)
result += " '''\n"
result += " Filter for one result matching the keys provided.\n"
result += " To control order, use e.g. `sort_by='Key' or `sort_by='-Key'`.\n"
result += " '''\n"
result += " return %s.lookupRecords(%s)[0]\n" % (table_id, lookupRecords_args_line)
result += "\n"
for col_id, col in visible_columns(engine, table_id):
if col_id != exclude_col_id:
result += " {}: {}\n".format(col_id, column_type(engine, table_id, col_id))
result += "\n"
return result
def get_formula_prompt(engine, table_id, col_id, description,
include_all_tables=True,
lookups=True):
result = ""
other_tables = (all_other_tables(engine, table_id)
if include_all_tables else referenced_tables(engine, table_id))
for other_table_id in sorted(other_tables):
result += class_schema(engine, other_table_id, lookups)
result += class_schema(engine, table_id, col_id, lookups)
return_type = column_type(engine, table_id, col_id)
result += " @property\n"
result += " # rec is alias for self\n"
result += " def {}(rec) -> {}:\n".format(col_id, return_type)
result += ' """\n'
result += '{}\n'.format(indent(description, " "))
result += ' """\n'
return result
def indent(text, prefix, predicate=None):
"""
Copied from https://github.com/python/cpython/blob/main/Lib/textwrap.py for python2 compatibility.
"""
if six.PY3:
return textwrap.indent(text, prefix, predicate) # pylint: disable = no-member
if predicate is None:
def predicate(line):
return line.strip()
def prefixed_lines():
for line in text.splitlines(True):
yield (prefix + line if predicate(line) else line)
return ''.join(prefixed_lines())
def convert_completion(completion):
result = textwrap.dedent(completion)
return result

View File

@@ -16,6 +16,7 @@ import six
import actions
import engine
import formula_prompt
import migrations
import schema
import useractions
@@ -135,6 +136,14 @@ def run(sandbox):
def get_formula_error(table_id, col_id, row_id):
return objtypes.encode_object(eng.get_formula_error(table_id, col_id, row_id))
@export
def get_formula_prompt(table_id, col_id, description):
return formula_prompt.get_formula_prompt(eng, table_id, col_id, description)
@export
def convert_formula_completion(completion):
return formula_prompt.convert_completion(completion)
export(parse_acl_formula)
export(eng.load_empty)
export(eng.load_done)

View File

@@ -0,0 +1,217 @@
import unittest
import six
import test_engine
import testutil
from formula_prompt import (
values_type, column_type, referenced_tables, get_formula_prompt,
)
from objtypes import RaisedException
from records import Record, RecordSet
class FakeTable(object):
def __init__(self):
pass
table_id = "Table1"
_identity_relation = None
@unittest.skipUnless(six.PY3, "Python 3 only")
class TestFormulaPrompt(test_engine.EngineTestCase):
def test_values_type(self):
self.assertEqual(values_type([1, 2, 3]), "int")
self.assertEqual(values_type([1.0, 2.0, 3.0]), "float")
self.assertEqual(values_type([1, 2, 3.0]), "float")
self.assertEqual(values_type([1, 2, None]), "Optional[int]")
self.assertEqual(values_type([1, 2, 3.0, None]), "Optional[float]")
self.assertEqual(values_type([1, RaisedException(None), 3]), "int")
self.assertEqual(values_type([1, RaisedException(None), None]), "Optional[int]")
self.assertEqual(values_type(["1", "2", "3"]), "str")
self.assertEqual(values_type([1, 2, "3"]), "Any")
self.assertEqual(values_type([1, 2, "3", None]), "Any")
self.assertEqual(values_type([
Record(FakeTable(), None),
Record(FakeTable(), None),
]), "Table1")
self.assertEqual(values_type([
Record(FakeTable(), None),
Record(FakeTable(), None),
None,
]), "Optional[Table1]")
self.assertEqual(values_type([
RecordSet(FakeTable(), None),
RecordSet(FakeTable(), None),
]), "List[Table1]")
self.assertEqual(values_type([
RecordSet(FakeTable(), None),
RecordSet(FakeTable(), None),
None,
]), "Optional[List[Table1]]")
self.assertEqual(values_type([[1, 2, 3]]), "List[int]")
self.assertEqual(values_type([[1, 2, 3], None]), "Optional[List[int]]")
self.assertEqual(values_type([[1, 2, None]]), "List[Optional[int]]")
self.assertEqual(values_type([[1, 2, None], None]), "Optional[List[Optional[int]]]")
self.assertEqual(values_type([[1, 2, "3"]]), "List[Any]")
self.assertEqual(values_type([{1, 2, 3}]), "Set[int]")
self.assertEqual(values_type([(1, 2, 3)]), "Tuple[int, ...]")
self.assertEqual(values_type([{1: ["2"]}]), "Dict[int, List[str]]")
def assert_column_type(self, col_id, expected_type):
self.assertEqual(column_type(self.engine, "Table2", col_id), expected_type)
def assert_prompt(self, table_name, col_id, expected_prompt):
prompt = get_formula_prompt(self.engine, table_name, col_id, "description here",
include_all_tables=False, lookups=False)
# print(prompt)
self.assertEqual(prompt, expected_prompt)
def test_column_type(self):
sample = testutil.parse_test_sample({
"SCHEMA": [
[1, "Table2", [
[1, "text", "Text", False, "", "", ""],
[2, "numeric", "Numeric", False, "", "", ""],
[3, "int", "Int", False, "", "", ""],
[4, "bool", "Bool", False, "", "", ""],
[5, "date", "Date", False, "", "", ""],
[6, "datetime", "DateTime", False, "", "", ""],
[7, "attachments", "Attachments", False, "", "", ""],
[8, "ref", "Ref:Table2", False, "", "", ""],
[9, "reflist", "RefList:Table2", False, "", "", ""],
[10, "choice", "Choice", False, "", "", '{"choices": ["a", "b", "c"]}'],
[11, "choicelist", "ChoiceList", False, "", "", '{"choices": ["x", "y", "z"]}'],
[12, "ref_formula", "Any", True, "$ref or None", "", ""],
[13, "numeric_formula", "Any", True, "1 / $numeric", "", ""],
[14, "new_formula", "Numeric", True, "'to be generated...'", "", ""],
]],
],
"DATA": {
"Table2": [
["id", "numeric", "ref"],
[1, 0, 0],
[2, 1, 1],
],
},
})
self.load_sample(sample)
self.assert_column_type("text", "str")
self.assert_column_type("numeric", "float")
self.assert_column_type("int", "int")
self.assert_column_type("bool", "bool")
self.assert_column_type("date", "datetime.date")
self.assert_column_type("datetime", "datetime.datetime")
self.assert_column_type("attachments", "Any")
self.assert_column_type("ref", "Table2")
self.assert_column_type("reflist", "List[Table2]")
self.assert_column_type("choice", "Literal['a', 'b', 'c']")
self.assert_column_type("choicelist", "Tuple[Literal['x', 'y', 'z'], ...]")
self.assert_column_type("ref_formula", "Optional[Table2]")
self.assert_column_type("numeric_formula", "float")
self.assertEqual(referenced_tables(self.engine, "Table2"), set())
self.assert_prompt("Table2", "new_formula",
'''\
@dataclass
class Table2:
text: str
numeric: float
int: int
bool: bool
date: datetime.date
datetime: datetime.datetime
attachments: Any
ref: Table2
reflist: List[Table2]
choice: Literal['a', 'b', 'c']
choicelist: Tuple[Literal['x', 'y', 'z'], ...]
ref_formula: Optional[Table2]
numeric_formula: float
@property
# rec is alias for self
def new_formula(rec) -> float:
"""
description here
"""
''')
def test_get_formula_prompt(self):
sample = testutil.parse_test_sample({
"SCHEMA": [
[1, "Table1", [
[1, "text", "Text", False, "", "", ""],
]],
[2, "Table2", [
[2, "ref", "Ref:Table1", False, "", "", ""],
]],
[3, "Table3", [
[3, "reflist", "RefList:Table2", False, "", "", ""],
]],
],
"DATA": {},
})
self.load_sample(sample)
self.assertEqual(referenced_tables(self.engine, "Table3"), {"Table1", "Table2"})
self.assertEqual(referenced_tables(self.engine, "Table2"), {"Table1"})
self.assertEqual(referenced_tables(self.engine, "Table1"), set())
self.assert_prompt("Table1", "text", '''\
@dataclass
class Table1:
@property
# rec is alias for self
def text(rec) -> str:
"""
description here
"""
''')
self.assert_prompt("Table2", "ref", '''\
@dataclass
class Table1:
text: str
@dataclass
class Table2:
@property
# rec is alias for self
def ref(rec) -> Table1:
"""
description here
"""
''')
self.assert_prompt("Table3", "reflist", '''\
@dataclass
class Table1:
text: str
@dataclass
class Table2:
ref: Table1
@dataclass
class Table3:
@property
# rec is alias for self
def reflist(rec) -> List[Table2]:
"""
description here
"""
''')