mirror of
https://github.com/gristlabs/grist-core.git
synced 2026-03-02 04:09:24 +00:00
(core) Porting back AI formula backend
Summary: This is a backend part for the formula AI. Test Plan: New tests Reviewers: paulfitz Reviewed By: paulfitz Subscribers: cyprien Differential Revision: https://phab.getgrist.com/D3786
This commit is contained in:
186
sandbox/grist/formula_prompt.py
Normal file
186
sandbox/grist/formula_prompt.py
Normal file
@@ -0,0 +1,186 @@
|
||||
import json
|
||||
import textwrap
|
||||
|
||||
import six
|
||||
|
||||
from column import is_visible_column, BaseReferenceColumn
|
||||
from objtypes import RaisedException
|
||||
import records
|
||||
|
||||
|
||||
def column_type(engine, table_id, col_id):
|
||||
col_rec = engine.docmodel.get_column_rec(table_id, col_id)
|
||||
typ = col_rec.type
|
||||
parts = typ.split(":")
|
||||
if parts[0] == "Ref":
|
||||
return parts[1]
|
||||
elif parts[0] == "RefList":
|
||||
return "List[{}]".format(parts[1])
|
||||
elif typ == "Choice":
|
||||
return choices(col_rec)
|
||||
elif typ == "ChoiceList":
|
||||
return "Tuple[{}, ...]".format(choices(col_rec))
|
||||
elif typ == "Any":
|
||||
table = engine.tables[table_id]
|
||||
col = table.get_column(col_id)
|
||||
values = [col.raw_get(row_id) for row_id in table.row_ids]
|
||||
return values_type(values)
|
||||
else:
|
||||
return dict(
|
||||
Text="str",
|
||||
Numeric="float",
|
||||
Int="int",
|
||||
Bool="bool",
|
||||
Date="datetime.date",
|
||||
DateTime="datetime.datetime",
|
||||
Any="Any",
|
||||
Attachments="Any",
|
||||
)[parts[0]]
|
||||
|
||||
|
||||
def choices(col_rec):
|
||||
try:
|
||||
widget_options = json.loads(col_rec.widgetOptions)
|
||||
return "Literal{}".format(widget_options["choices"])
|
||||
except (ValueError, KeyError):
|
||||
return 'str'
|
||||
|
||||
|
||||
def values_type(values):
|
||||
types = set(type(v) for v in values) - {RaisedException}
|
||||
optional = type(None) in types # pylint: disable=unidiomatic-typecheck
|
||||
types.discard(type(None))
|
||||
|
||||
if types == {int, float}:
|
||||
types = {float}
|
||||
|
||||
if len(types) != 1:
|
||||
return "Any"
|
||||
|
||||
[typ] = types
|
||||
val = next(v for v in values if isinstance(v, typ))
|
||||
|
||||
if isinstance(val, records.Record):
|
||||
type_name = val._table.table_id
|
||||
elif isinstance(val, records.RecordSet):
|
||||
type_name = "List[{}]".format(val._table.table_id)
|
||||
elif isinstance(val, list):
|
||||
type_name = "List[{}]".format(values_type(val))
|
||||
elif isinstance(val, set):
|
||||
type_name = "Set[{}]".format(values_type(val))
|
||||
elif isinstance(val, tuple):
|
||||
type_name = "Tuple[{}, ...]".format(values_type(val))
|
||||
elif isinstance(val, dict):
|
||||
type_name = "Dict[{}, {}]".format(values_type(val.keys()), values_type(val.values()))
|
||||
else:
|
||||
type_name = typ.__name__
|
||||
|
||||
if optional:
|
||||
type_name = "Optional[{}]".format(type_name)
|
||||
|
||||
return type_name
|
||||
|
||||
|
||||
def referenced_tables(engine, table_id):
|
||||
result = set()
|
||||
queue = [table_id]
|
||||
while queue:
|
||||
cur_table_id = queue.pop()
|
||||
if cur_table_id in result:
|
||||
continue
|
||||
result.add(cur_table_id)
|
||||
for col_id, col in visible_columns(engine, cur_table_id):
|
||||
if isinstance(col, BaseReferenceColumn):
|
||||
target_table_id = col._target_table.table_id
|
||||
if not target_table_id.startswith("_"):
|
||||
queue.append(target_table_id)
|
||||
return result - {table_id}
|
||||
|
||||
def all_other_tables(engine, table_id):
|
||||
result = set(t for t in engine.tables.keys() if not t.startswith('_grist'))
|
||||
return result - {table_id} - {'GristDocTour'}
|
||||
|
||||
def visible_columns(engine, table_id):
|
||||
return [
|
||||
(col_id, col)
|
||||
for col_id, col in engine.tables[table_id].all_columns.items()
|
||||
if is_visible_column(col_id)
|
||||
]
|
||||
|
||||
|
||||
def class_schema(engine, table_id, exclude_col_id=None, lookups=False):
|
||||
result = "@dataclass\nclass {}:\n".format(table_id)
|
||||
|
||||
if lookups:
|
||||
|
||||
# Build a lookupRecords and lookupOne method for each table, providing some arguments hints
|
||||
# for the columns that are visible.
|
||||
lookupRecords_args = []
|
||||
lookupOne_args = []
|
||||
for col_id, col in visible_columns(engine, table_id):
|
||||
if col_id != exclude_col_id:
|
||||
lookupOne_args.append(col_id + '=None')
|
||||
lookupRecords_args.append('%s=%s' % (col_id, col_id))
|
||||
lookupOne_args.append('sort_by=None')
|
||||
lookupRecords_args.append('sort_by=sort_by')
|
||||
lookupOne_args_line = ', '.join(lookupOne_args)
|
||||
lookupRecords_args_line = ', '.join(lookupRecords_args)
|
||||
|
||||
result += " def __len__(self):\n"
|
||||
result += " return len(%s.lookupRecords())\n" % table_id
|
||||
result += " @staticmethod\n"
|
||||
result += " def lookupRecords(%s) -> List[%s]:\n" % (lookupOne_args_line, table_id)
|
||||
result += " # ...\n"
|
||||
result += " @staticmethod\n"
|
||||
result += " def lookupOne(%s) -> %s:\n" % (lookupOne_args_line, table_id)
|
||||
result += " '''\n"
|
||||
result += " Filter for one result matching the keys provided.\n"
|
||||
result += " To control order, use e.g. `sort_by='Key' or `sort_by='-Key'`.\n"
|
||||
result += " '''\n"
|
||||
result += " return %s.lookupRecords(%s)[0]\n" % (table_id, lookupRecords_args_line)
|
||||
result += "\n"
|
||||
|
||||
for col_id, col in visible_columns(engine, table_id):
|
||||
if col_id != exclude_col_id:
|
||||
result += " {}: {}\n".format(col_id, column_type(engine, table_id, col_id))
|
||||
result += "\n"
|
||||
return result
|
||||
|
||||
|
||||
def get_formula_prompt(engine, table_id, col_id, description,
|
||||
include_all_tables=True,
|
||||
lookups=True):
|
||||
result = ""
|
||||
other_tables = (all_other_tables(engine, table_id)
|
||||
if include_all_tables else referenced_tables(engine, table_id))
|
||||
for other_table_id in sorted(other_tables):
|
||||
result += class_schema(engine, other_table_id, lookups)
|
||||
|
||||
result += class_schema(engine, table_id, col_id, lookups)
|
||||
|
||||
return_type = column_type(engine, table_id, col_id)
|
||||
result += " @property\n"
|
||||
result += " # rec is alias for self\n"
|
||||
result += " def {}(rec) -> {}:\n".format(col_id, return_type)
|
||||
result += ' """\n'
|
||||
result += '{}\n'.format(indent(description, " "))
|
||||
result += ' """\n'
|
||||
return result
|
||||
|
||||
def indent(text, prefix, predicate=None):
|
||||
"""
|
||||
Copied from https://github.com/python/cpython/blob/main/Lib/textwrap.py for python2 compatibility.
|
||||
"""
|
||||
if six.PY3:
|
||||
return textwrap.indent(text, prefix, predicate) # pylint: disable = no-member
|
||||
if predicate is None:
|
||||
def predicate(line):
|
||||
return line.strip()
|
||||
def prefixed_lines():
|
||||
for line in text.splitlines(True):
|
||||
yield (prefix + line if predicate(line) else line)
|
||||
return ''.join(prefixed_lines())
|
||||
|
||||
def convert_completion(completion):
|
||||
result = textwrap.dedent(completion)
|
||||
return result
|
||||
@@ -16,6 +16,7 @@ import six
|
||||
|
||||
import actions
|
||||
import engine
|
||||
import formula_prompt
|
||||
import migrations
|
||||
import schema
|
||||
import useractions
|
||||
@@ -135,6 +136,14 @@ def run(sandbox):
|
||||
def get_formula_error(table_id, col_id, row_id):
|
||||
return objtypes.encode_object(eng.get_formula_error(table_id, col_id, row_id))
|
||||
|
||||
@export
|
||||
def get_formula_prompt(table_id, col_id, description):
|
||||
return formula_prompt.get_formula_prompt(eng, table_id, col_id, description)
|
||||
|
||||
@export
|
||||
def convert_formula_completion(completion):
|
||||
return formula_prompt.convert_completion(completion)
|
||||
|
||||
export(parse_acl_formula)
|
||||
export(eng.load_empty)
|
||||
export(eng.load_done)
|
||||
|
||||
217
sandbox/grist/test_formula_prompt.py
Normal file
217
sandbox/grist/test_formula_prompt.py
Normal file
@@ -0,0 +1,217 @@
|
||||
import unittest
|
||||
import six
|
||||
|
||||
import test_engine
|
||||
import testutil
|
||||
|
||||
from formula_prompt import (
|
||||
values_type, column_type, referenced_tables, get_formula_prompt,
|
||||
)
|
||||
from objtypes import RaisedException
|
||||
from records import Record, RecordSet
|
||||
|
||||
|
||||
class FakeTable(object):
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
table_id = "Table1"
|
||||
_identity_relation = None
|
||||
|
||||
|
||||
@unittest.skipUnless(six.PY3, "Python 3 only")
|
||||
class TestFormulaPrompt(test_engine.EngineTestCase):
|
||||
def test_values_type(self):
|
||||
self.assertEqual(values_type([1, 2, 3]), "int")
|
||||
self.assertEqual(values_type([1.0, 2.0, 3.0]), "float")
|
||||
self.assertEqual(values_type([1, 2, 3.0]), "float")
|
||||
|
||||
self.assertEqual(values_type([1, 2, None]), "Optional[int]")
|
||||
self.assertEqual(values_type([1, 2, 3.0, None]), "Optional[float]")
|
||||
|
||||
self.assertEqual(values_type([1, RaisedException(None), 3]), "int")
|
||||
self.assertEqual(values_type([1, RaisedException(None), None]), "Optional[int]")
|
||||
|
||||
self.assertEqual(values_type(["1", "2", "3"]), "str")
|
||||
self.assertEqual(values_type([1, 2, "3"]), "Any")
|
||||
self.assertEqual(values_type([1, 2, "3", None]), "Any")
|
||||
|
||||
self.assertEqual(values_type([
|
||||
Record(FakeTable(), None),
|
||||
Record(FakeTable(), None),
|
||||
]), "Table1")
|
||||
self.assertEqual(values_type([
|
||||
Record(FakeTable(), None),
|
||||
Record(FakeTable(), None),
|
||||
None,
|
||||
]), "Optional[Table1]")
|
||||
|
||||
self.assertEqual(values_type([
|
||||
RecordSet(FakeTable(), None),
|
||||
RecordSet(FakeTable(), None),
|
||||
]), "List[Table1]")
|
||||
self.assertEqual(values_type([
|
||||
RecordSet(FakeTable(), None),
|
||||
RecordSet(FakeTable(), None),
|
||||
None,
|
||||
]), "Optional[List[Table1]]")
|
||||
|
||||
self.assertEqual(values_type([[1, 2, 3]]), "List[int]")
|
||||
self.assertEqual(values_type([[1, 2, 3], None]), "Optional[List[int]]")
|
||||
self.assertEqual(values_type([[1, 2, None]]), "List[Optional[int]]")
|
||||
self.assertEqual(values_type([[1, 2, None], None]), "Optional[List[Optional[int]]]")
|
||||
self.assertEqual(values_type([[1, 2, "3"]]), "List[Any]")
|
||||
|
||||
self.assertEqual(values_type([{1, 2, 3}]), "Set[int]")
|
||||
self.assertEqual(values_type([(1, 2, 3)]), "Tuple[int, ...]")
|
||||
self.assertEqual(values_type([{1: ["2"]}]), "Dict[int, List[str]]")
|
||||
|
||||
def assert_column_type(self, col_id, expected_type):
|
||||
self.assertEqual(column_type(self.engine, "Table2", col_id), expected_type)
|
||||
|
||||
def assert_prompt(self, table_name, col_id, expected_prompt):
|
||||
prompt = get_formula_prompt(self.engine, table_name, col_id, "description here",
|
||||
include_all_tables=False, lookups=False)
|
||||
# print(prompt)
|
||||
self.assertEqual(prompt, expected_prompt)
|
||||
|
||||
def test_column_type(self):
|
||||
sample = testutil.parse_test_sample({
|
||||
"SCHEMA": [
|
||||
[1, "Table2", [
|
||||
[1, "text", "Text", False, "", "", ""],
|
||||
[2, "numeric", "Numeric", False, "", "", ""],
|
||||
[3, "int", "Int", False, "", "", ""],
|
||||
[4, "bool", "Bool", False, "", "", ""],
|
||||
[5, "date", "Date", False, "", "", ""],
|
||||
[6, "datetime", "DateTime", False, "", "", ""],
|
||||
[7, "attachments", "Attachments", False, "", "", ""],
|
||||
[8, "ref", "Ref:Table2", False, "", "", ""],
|
||||
[9, "reflist", "RefList:Table2", False, "", "", ""],
|
||||
[10, "choice", "Choice", False, "", "", '{"choices": ["a", "b", "c"]}'],
|
||||
[11, "choicelist", "ChoiceList", False, "", "", '{"choices": ["x", "y", "z"]}'],
|
||||
[12, "ref_formula", "Any", True, "$ref or None", "", ""],
|
||||
[13, "numeric_formula", "Any", True, "1 / $numeric", "", ""],
|
||||
[14, "new_formula", "Numeric", True, "'to be generated...'", "", ""],
|
||||
]],
|
||||
],
|
||||
"DATA": {
|
||||
"Table2": [
|
||||
["id", "numeric", "ref"],
|
||||
[1, 0, 0],
|
||||
[2, 1, 1],
|
||||
],
|
||||
},
|
||||
})
|
||||
self.load_sample(sample)
|
||||
|
||||
self.assert_column_type("text", "str")
|
||||
self.assert_column_type("numeric", "float")
|
||||
self.assert_column_type("int", "int")
|
||||
self.assert_column_type("bool", "bool")
|
||||
self.assert_column_type("date", "datetime.date")
|
||||
self.assert_column_type("datetime", "datetime.datetime")
|
||||
self.assert_column_type("attachments", "Any")
|
||||
self.assert_column_type("ref", "Table2")
|
||||
self.assert_column_type("reflist", "List[Table2]")
|
||||
self.assert_column_type("choice", "Literal['a', 'b', 'c']")
|
||||
self.assert_column_type("choicelist", "Tuple[Literal['x', 'y', 'z'], ...]")
|
||||
self.assert_column_type("ref_formula", "Optional[Table2]")
|
||||
self.assert_column_type("numeric_formula", "float")
|
||||
|
||||
self.assertEqual(referenced_tables(self.engine, "Table2"), set())
|
||||
|
||||
self.assert_prompt("Table2", "new_formula",
|
||||
'''\
|
||||
@dataclass
|
||||
class Table2:
|
||||
text: str
|
||||
numeric: float
|
||||
int: int
|
||||
bool: bool
|
||||
date: datetime.date
|
||||
datetime: datetime.datetime
|
||||
attachments: Any
|
||||
ref: Table2
|
||||
reflist: List[Table2]
|
||||
choice: Literal['a', 'b', 'c']
|
||||
choicelist: Tuple[Literal['x', 'y', 'z'], ...]
|
||||
ref_formula: Optional[Table2]
|
||||
numeric_formula: float
|
||||
|
||||
@property
|
||||
# rec is alias for self
|
||||
def new_formula(rec) -> float:
|
||||
"""
|
||||
description here
|
||||
"""
|
||||
''')
|
||||
|
||||
def test_get_formula_prompt(self):
|
||||
sample = testutil.parse_test_sample({
|
||||
"SCHEMA": [
|
||||
[1, "Table1", [
|
||||
[1, "text", "Text", False, "", "", ""],
|
||||
]],
|
||||
[2, "Table2", [
|
||||
[2, "ref", "Ref:Table1", False, "", "", ""],
|
||||
]],
|
||||
[3, "Table3", [
|
||||
[3, "reflist", "RefList:Table2", False, "", "", ""],
|
||||
]],
|
||||
],
|
||||
"DATA": {},
|
||||
})
|
||||
self.load_sample(sample)
|
||||
self.assertEqual(referenced_tables(self.engine, "Table3"), {"Table1", "Table2"})
|
||||
self.assertEqual(referenced_tables(self.engine, "Table2"), {"Table1"})
|
||||
self.assertEqual(referenced_tables(self.engine, "Table1"), set())
|
||||
|
||||
self.assert_prompt("Table1", "text", '''\
|
||||
@dataclass
|
||||
class Table1:
|
||||
|
||||
@property
|
||||
# rec is alias for self
|
||||
def text(rec) -> str:
|
||||
"""
|
||||
description here
|
||||
"""
|
||||
''')
|
||||
|
||||
self.assert_prompt("Table2", "ref", '''\
|
||||
@dataclass
|
||||
class Table1:
|
||||
text: str
|
||||
|
||||
@dataclass
|
||||
class Table2:
|
||||
|
||||
@property
|
||||
# rec is alias for self
|
||||
def ref(rec) -> Table1:
|
||||
"""
|
||||
description here
|
||||
"""
|
||||
''')
|
||||
|
||||
self.assert_prompt("Table3", "reflist", '''\
|
||||
@dataclass
|
||||
class Table1:
|
||||
text: str
|
||||
|
||||
@dataclass
|
||||
class Table2:
|
||||
ref: Table1
|
||||
|
||||
@dataclass
|
||||
class Table3:
|
||||
|
||||
@property
|
||||
# rec is alias for self
|
||||
def reflist(rec) -> List[Table2]:
|
||||
"""
|
||||
description here
|
||||
"""
|
||||
''')
|
||||
Reference in New Issue
Block a user