(core) Prevent the AI assistant response from including class definitions

Summary:
Sometimes the model repeats the classes given in the prompt which would mess up extracting the actual formula. This diff solves this by:

1. Changes the generated Python schema so that (a) the thing that needs completing is a plain top level function instead of a property/method inside the class and (2) the classes are fully valid syntax, which makes it easier to
2. Remove classes from the parsed Python code when converting the completion to a formula.
3. Tweak the prompt wording to discourage including classes in general, especially because sometimes the model tries to solve the problem by defining extra methods/attributes/classes.

While I was at it, I changed type hints to use builtins (e.g. `list` instead of `List`) to prevent `from typing import List` which was happening sometimes and would look weird in a formula. Similarly I removed `@dataclass` since that also implies an import, and this also fits with the tweaked wording that the classes are fake.

Test Plan:
Added a new test case to the formula dataset which triggers the unwanted behaviour. The factors that seem to trigger the problem are (1) a small schema so the classes are easier to repeat and (2) the need to import modules, which the model wants to place before all other code. The case failed before this diff and succeeded after. The tweaked wording reduces the chances of repeating the classes but didn't eliminate it, so forcibly removing the classes in Python was needed.

There were also a couple of other existing cases where repeating the classes was observed before but not after.

Overall the score increased from 49 to 51 out of 69 (including the new case). At one point the score was 53, but changes in whitespace were enough to make it drop again.

Reviewers: georgegevoian

Reviewed By: georgegevoian

Differential Revision: https://phab.getgrist.com/D4000
This commit is contained in:
Alex Hall
2023-08-18 12:48:47 +02:00
parent 999d723d14
commit a1d31e41ad
4 changed files with 74 additions and 65 deletions

View File

@@ -22,11 +22,11 @@ def column_type(engine, table_id, col_id):
if parts[0] == "Ref":
return parts[1]
elif parts[0] == "RefList":
return "List[{}]".format(parts[1])
return "list[{}]".format(parts[1])
elif typ == "Choice":
return choices(col_rec)
elif typ == "ChoiceList":
return "Tuple[{}, ...]".format(choices(col_rec))
return "tuple[{}, ...]".format(choices(col_rec))
elif typ == "Any":
table = engine.tables[table_id]
col = table.get_column(col_id)
@@ -69,15 +69,15 @@ def values_type(values):
if isinstance(val, records.Record):
type_name = val._table.table_id
elif isinstance(val, records.RecordSet):
type_name = "List[{}]".format(val._table.table_id)
type_name = "list[{}]".format(val._table.table_id)
elif isinstance(val, list):
type_name = "List[{}]".format(values_type(val))
type_name = "list[{}]".format(values_type(val))
elif isinstance(val, set):
type_name = "Set[{}]".format(values_type(val))
type_name = "set[{}]".format(values_type(val))
elif isinstance(val, tuple):
type_name = "Tuple[{}, ...]".format(values_type(val))
type_name = "tuple[{}, ...]".format(values_type(val))
elif isinstance(val, dict):
type_name = "Dict[{}, {}]".format(values_type(val.keys()), values_type(val.values()))
type_name = "dict[{}, {}]".format(values_type(val.keys()), values_type(val.values()))
else:
type_name = typ.__name__
@@ -115,7 +115,7 @@ def visible_columns(engine, table_id):
def class_schema(engine, table_id, exclude_col_id=None, lookups=False):
result = "@dataclass\nclass {}:\n".format(table_id)
result = "class {}:\n".format(table_id)
if lookups:
@@ -132,11 +132,11 @@ def class_schema(engine, table_id, exclude_col_id=None, lookups=False):
lookupOne_args_line = ', '.join(lookupOne_args)
lookupRecords_args_line = ', '.join(lookupRecords_args)
result += " def __len__(self):\n"
result += " def __len__(self):\n"
result += " return len(%s.lookupRecords())\n" % table_id
result += " @staticmethod\n"
result += " def lookupRecords(%s) -> List[%s]:\n" % (lookupOne_args_line, table_id)
result += " # ...\n"
result += " def lookupRecords(%s) -> list[%s]:\n" % (lookupOne_args_line, table_id)
result += " ...\n"
result += " @staticmethod\n"
result += " def lookupOne(%s) -> %s:\n" % (lookupOne_args_line, table_id)
result += " '''\n"
@@ -165,10 +165,7 @@ def get_formula_prompt(engine, table_id, col_id, _description,
result += class_schema(engine, table_id, col_id, lookups)
return_type = column_type(engine, table_id, col_id)
result += " @property\n"
result += " # rec is alias for self\n"
result += " def {}(rec) -> {}:\n".format(col_id, return_type)
result += " # Please fill in code only after this line, not the `def`\n"
result += "def {}(rec: {}) -> {}:\n".format(col_id, table_id, return_type)
return result
def indent(text, prefix, predicate=None):
@@ -213,7 +210,10 @@ def convert_completion(completion):
while stmts and isinstance(stmts[0], (ast.Import, ast.ImportFrom)):
imports += atok.get_text(stmts.pop(0)) + "\n"
# If the non-import code consists only of a function definition, extract the body.
# Sometimes the model repeats the provided classes, remove them.
stmts = [stmt for stmt in stmts if not isinstance(stmt, ast.ClassDef)]
# If the remaining code consists only of a function definition, extract the body.
if len(stmts) == 1 and isinstance(stmts[0], ast.FunctionDef):
func_body_stmts = stmts[0].body
if (

View File

@@ -58,29 +58,29 @@ class TestFormulaPrompt(test_engine.EngineTestCase):
self.assertEqual(values_type([
fake_table.RecordSet(None),
fake_table.RecordSet(None),
]), "List[Table1]")
]), "list[Table1]")
self.assertEqual(values_type([
fake_table.RecordSet(None),
fake_table.RecordSet(None),
None,
]), "Optional[List[Table1]]")
]), "Optional[list[Table1]]")
self.assertEqual(values_type([[1, 2, 3]]), "List[int]")
self.assertEqual(values_type([[1, 2, 3], None]), "Optional[List[int]]")
self.assertEqual(values_type([[1, 2, None]]), "List[Optional[int]]")
self.assertEqual(values_type([[1, 2, None], None]), "Optional[List[Optional[int]]]")
self.assertEqual(values_type([[1, 2, "3"]]), "List[Any]")
self.assertEqual(values_type([[1, 2, 3]]), "list[int]")
self.assertEqual(values_type([[1, 2, 3], None]), "Optional[list[int]]")
self.assertEqual(values_type([[1, 2, None]]), "list[Optional[int]]")
self.assertEqual(values_type([[1, 2, None], None]), "Optional[list[Optional[int]]]")
self.assertEqual(values_type([[1, 2, "3"]]), "list[Any]")
self.assertEqual(values_type([{1, 2, 3}]), "Set[int]")
self.assertEqual(values_type([(1, 2, 3)]), "Tuple[int, ...]")
self.assertEqual(values_type([{1: ["2"]}]), "Dict[int, List[str]]")
self.assertEqual(values_type([{1, 2, 3}]), "set[int]")
self.assertEqual(values_type([(1, 2, 3)]), "tuple[int, ...]")
self.assertEqual(values_type([{1: ["2"]}]), "dict[int, list[str]]")
def assert_column_type(self, col_id, expected_type):
self.assertEqual(column_type(self.engine, "Table2", col_id), expected_type)
def assert_prompt(self, table_name, col_id, expected_prompt):
def assert_prompt(self, table_name, col_id, expected_prompt, lookups=False):
prompt = get_formula_prompt(self.engine, table_name, col_id, "description here",
include_all_tables=False, lookups=False)
include_all_tables=False, lookups=lookups)
# print(prompt)
self.assertEqual(prompt, expected_prompt)
@@ -122,9 +122,9 @@ class TestFormulaPrompt(test_engine.EngineTestCase):
self.assert_column_type("datetime", "datetime.datetime")
self.assert_column_type("attachments", "Any")
self.assert_column_type("ref", "Table2")
self.assert_column_type("reflist", "List[Table2]")
self.assert_column_type("reflist", "list[Table2]")
self.assert_column_type("choice", "Literal['a', 'b', 'c']")
self.assert_column_type("choicelist", "Tuple[Literal['x', 'y', 'z'], ...]")
self.assert_column_type("choicelist", "tuple[Literal['x', 'y', 'z'], ...]")
self.assert_column_type("ref_formula", "Optional[Table2]")
self.assert_column_type("numeric_formula", "float")
@@ -132,7 +132,6 @@ class TestFormulaPrompt(test_engine.EngineTestCase):
self.assert_prompt("Table2", "new_formula",
'''\
@dataclass
class Table2:
text: str
numeric: float
@@ -142,16 +141,13 @@ class Table2:
datetime: datetime.datetime
attachments: Any
ref: Table2
reflist: List[Table2]
reflist: list[Table2]
choice: Literal['a', 'b', 'c']
choicelist: Tuple[Literal['x', 'y', 'z'], ...]
choicelist: tuple[Literal['x', 'y', 'z'], ...]
ref_formula: Optional[Table2]
numeric_formula: float
@property
# rec is alias for self
def new_formula(rec) -> float:
# Please fill in code only after this line, not the `def`
def new_formula(rec: Table2) -> float:
''')
def test_get_formula_prompt(self):
@@ -175,45 +171,52 @@ class Table2:
self.assertEqual(referenced_tables(self.engine, "Table1"), set())
self.assert_prompt("Table1", "text", '''\
@dataclass
class Table1:
@property
# rec is alias for self
def text(rec) -> str:
# Please fill in code only after this line, not the `def`
def text(rec: Table1) -> str:
''')
# Test the same thing but include the lookup methods as in a real case,
# just to show that the table class would never actually be empty
# (which would be invalid Python and might confuse the model).
self.assert_prompt("Table1", "text", """\
class Table1:
def __len__(self):
return len(Table1.lookupRecords())
@staticmethod
def lookupRecords(sort_by=None) -> list[Table1]:
...
@staticmethod
def lookupOne(sort_by=None) -> Table1:
'''
Filter for one result matching the keys provided.
To control order, use e.g. `sort_by='Key' or `sort_by='-Key'`.
'''
return Table1.lookupRecords(sort_by=sort_by)[0]
def text(rec: Table1) -> str:
""", lookups=True)
self.assert_prompt("Table2", "ref", '''\
@dataclass
class Table1:
text: str
@dataclass
class Table2:
@property
# rec is alias for self
def ref(rec) -> Table1:
# Please fill in code only after this line, not the `def`
def ref(rec: Table2) -> Table1:
''')
self.assert_prompt("Table3", "reflist", '''\
@dataclass
class Table1:
text: str
@dataclass
class Table2:
ref: Table1
@dataclass
class Table3:
@property
# rec is alias for self
def reflist(rec) -> List[Table2]:
# Please fill in code only after this line, not the `def`
def reflist(rec: Table3) -> list[Table2]:
''')
def test_convert_completion(self):
@@ -227,6 +230,9 @@ from x import (
z,
)
class Foo:
bar: Bar
@property
def foo(rec):
'''This is a docstring'''