gristlabs_grist-core/sandbox/grist/formula_prompt.py
Alex Hall a1d31e41ad (core) Prevent the AI assistant response from including class definitions
Summary:
Sometimes the model repeats the classes given in the prompt which would mess up extracting the actual formula. This diff solves this by:

1. Change the generated Python schema so that (a) the thing that needs completing is a plain top-level function instead of a property/method inside the class and (b) the classes are fully valid syntax, which makes step 2 easier.
2. Remove classes from the parsed Python code when converting the completion to a formula.
3. Tweak the prompt wording to discourage including classes in general, especially because sometimes the model tries to solve the problem by defining extra methods/attributes/classes.

While I was at it, I changed type hints to use builtins (e.g. `list` instead of `List`) to stop the model from adding `from typing import List`, which happened sometimes and would look weird in a formula. Similarly, I removed `@dataclass` since that also implies an import; this also fits the tweaked wording saying that the classes are fake.

Test Plan:
Added a new test case to the formula dataset which triggers the unwanted behaviour. The factors that seem to trigger the problem are (1) a small schema so the classes are easier to repeat and (2) the need to import modules, which the model wants to place before all other code. The case failed before this diff and succeeded after. The tweaked wording reduces the chances of repeating the classes but didn't eliminate it, so forcibly removing the classes in Python was needed.

There were also a couple of other existing cases where repeating the classes was observed before but not after.

Overall the score increased from 49 to 51 out of 69 (including the new case). At one point the score was 53, but changes in whitespace were enough to make it drop again.

Reviewers: georgegevoian

Reviewed By: georgegevoian

Differential Revision: https://phab.getgrist.com/D4000
2023-08-18 12:50:09 +02:00

283 lines
8.8 KiB
Python

import ast
import json
import re
import textwrap
import asttokens
import asttokens.util
import six
import attribute_recorder
import objtypes
from codebuilder import make_formula_body
from column import is_visible_column, BaseReferenceColumn
from objtypes import RaisedException
import records
def column_type(engine, table_id, col_id):
    """
    Return a Python type-hint string describing column `col_id` of table `table_id`,
    for use in the schema shown to the model.

    - "Ref:<Table>" becomes "<Table>" and "RefList:<Table>" becomes "list[<Table>]".
    - Choice/ChoiceList columns are described with `choices` (a `Literal[...]` or 'str').
    - "Any" columns get a type inferred from the values currently stored (`values_type`).
    - Other types are looked up by their base name (the part before any ':', e.g.
      "DateTime:America/New_York" -> "DateTime"); unrecognized ones map to "Any"
      instead of raising KeyError.
    """
    col_rec = engine.docmodel.get_column_rec(table_id, col_id)
    typ = col_rec.type
    parts = typ.split(":")
    base = parts[0]
    if base == "Ref":
        return parts[1]
    if base == "RefList":
        return "list[{}]".format(parts[1])
    if typ == "Choice":
        return choices(col_rec)
    if typ == "ChoiceList":
        return "tuple[{}, ...]".format(choices(col_rec))
    if typ == "Any":
        # No declared type: infer one from the values actually stored in the column.
        table = engine.tables[table_id]
        col = table.get_column(col_id)
        values = [col.raw_get(row_id) for row_id in table.row_ids]
        return values_type(values)
    simple_types = dict(
        Text="str",
        Numeric="float",
        Int="int",
        Bool="bool",
        Date="datetime.date",
        DateTime="datetime.datetime",
        Attachments="Any",
    )
    # Fall back to "Any" so that one unusual column type doesn't break the whole prompt.
    return simple_types.get(base, "Any")
def choices(col_rec):
    """
    Return a `Literal[...]` type hint listing the choices configured for a
    Choice/ChoiceList column, or 'str' when no usable choices are configured.

    `col_rec.widgetOptions` is parsed as JSON and its "choices" entry is used when it
    is a non-empty list. Empty, malformed, or non-object options, a missing "choices"
    key, and an empty choice list all fall back to 'str'.
    """
    try:
        widget_options = json.loads(col_rec.widgetOptions)
        options = widget_options["choices"]
    except (ValueError, KeyError, TypeError):
        # ValueError: not valid JSON (e.g. an empty string).
        # KeyError: no "choices" key.
        # TypeError: widgetOptions is not a string, or the JSON is not an object (e.g. "null").
        return 'str'
    if not isinstance(options, list) or not options:
        # `Literal[]` is not valid Python, so describe a missing/empty choice list as plain text.
        return 'str'
    return "Literal{}".format(options)
def values_type(values):
    """
    Infer a type-hint string that describes every value in `values`.

    RaisedException values are ignored, and the presence of None wraps the result in
    `Optional[...]`. A mix of ints and floats is described as float. If the remaining
    values don't all share one exact type, the result is "Any". The description is
    built from the first value of the shared type (see `_value_type_hint`); container
    element types are inferred recursively from that one value.
    """
    kinds = {type(value) for value in values}
    kinds.discard(RaisedException)
    none_type = type(None)
    is_optional = none_type in kinds
    kinds.discard(none_type)
    if kinds == {int, float}:
        # Numeric data often mixes ints and floats; describe both as float.
        kinds = {float}
    if len(kinds) != 1:
        return "Any"
    (kind,) = kinds
    sample = next(value for value in values if isinstance(value, kind))
    hint = _value_type_hint(sample, kind)
    if not is_optional:
        return hint
    return "Optional[{}]".format(hint)


def _value_type_hint(sample, kind):
    """Describe one representative value `sample`, whose exact type is `kind`."""
    if isinstance(sample, records.Record):
        return sample._table.table_id
    if isinstance(sample, records.RecordSet):
        return "list[{}]".format(sample._table.table_id)
    if isinstance(sample, list):
        return "list[{}]".format(values_type(sample))
    if isinstance(sample, set):
        return "set[{}]".format(values_type(sample))
    if isinstance(sample, tuple):
        return "tuple[{}, ...]".format(values_type(sample))
    if isinstance(sample, dict):
        return "dict[{}, {}]".format(values_type(sample.keys()), values_type(sample.values()))
    return kind.__name__
def referenced_tables(engine, table_id):
    """
    Return the set of table ids reachable from `table_id` by following reference
    columns transitively, excluding `table_id` itself.

    Only visible columns are followed, and target tables whose id starts with "_"
    are not visited.
    """
    visited = set()
    pending = [table_id]
    while pending:
        current = pending.pop()
        if current in visited:
            continue
        visited.add(current)
        targets = (
            col._target_table.table_id
            for _, col in visible_columns(engine, current)
            if isinstance(col, BaseReferenceColumn)
        )
        pending.extend(target for target in targets if not target.startswith("_"))
    visited.discard(table_id)
    return visited
def all_other_tables(engine, table_id):
    """
    Return the set of ids of every table in `engine.tables` other than `table_id`,
    leaving out tables whose id starts with '_grist' and the 'GristDocTour' table.
    """
    excluded = {table_id, 'GristDocTour'}
    return {
        other for other in engine.tables.keys()
        if not other.startswith('_grist') and other not in excluded
    }
def visible_columns(engine, table_id):
    """
    Return a list of (col_id, column) pairs for the columns of `table_id` whose id
    passes `is_visible_column`, in the order of the table's `all_columns` mapping.
    """
    column_items = engine.tables[table_id].all_columns.items()
    return [item for item in column_items if is_visible_column(item[0])]
def class_schema(engine, table_id, exclude_col_id=None, lookups=False):
    """
    Return Python source for a class named `table_id` that describes the table to the model.

    Every visible column except `exclude_col_id` becomes an annotated attribute
    (`col_id: <type hint>`, see `column_type`). When `lookups` is true, the class also
    gets `__len__`, `lookupRecords` and `lookupOne` methods whose keyword arguments are
    those columns plus `sort_by`. The generated text is syntactically valid Python
    (an otherwise empty class gets a `pass` body) and ends with a blank line.
    """
    columns = [
        col_id for col_id, _ in visible_columns(engine, table_id)
        if col_id != exclude_col_id
    ]
    result = "class {}:\n".format(table_id)
    if lookups:
        # Build a lookupRecords and lookupOne method for the table, providing some argument
        # hints for the columns that are visible.
        lookupOne_args = [col_id + '=None' for col_id in columns] + ['sort_by=None']
        lookupRecords_args = ['%s=%s' % (col_id, col_id) for col_id in columns]
        lookupRecords_args.append('sort_by=sort_by')
        lookupOne_args_line = ', '.join(lookupOne_args)
        lookupRecords_args_line = ', '.join(lookupRecords_args)
        # Method bodies must be indented relative to their `def` for the schema to be valid Python.
        result += "    def __len__(self):\n"
        result += "        return len(%s.lookupRecords())\n" % table_id
        result += "    @staticmethod\n"
        result += "    def lookupRecords(%s) -> list[%s]:\n" % (lookupOne_args_line, table_id)
        result += "        ...\n"
        result += "    @staticmethod\n"
        result += "    def lookupOne(%s) -> %s:\n" % (lookupOne_args_line, table_id)
        result += "        '''\n"
        result += "        Filter for one result matching the keys provided.\n"
        result += "        To control order, use e.g. `sort_by='Key'` or `sort_by='-Key'`.\n"
        result += "        '''\n"
        result += "        return %s.lookupRecords(%s)[0]\n" % (table_id, lookupRecords_args_line)
        result += "\n"
    for col_id in columns:
        result += "    {}: {}\n".format(col_id, column_type(engine, table_id, col_id))
    if not lookups and not columns:
        # A class statement with no body is a syntax error.
        result += "    pass\n"
    result += "\n"
    return result
def get_formula_prompt(engine, table_id, col_id, _description,
                       include_all_tables=True,
                       lookups=True):
    """
    Build the Python-style prompt asking the model to write a formula for `col_id`.

    The prompt contains a class schema (see `class_schema`) for each other table, in
    sorted order: every user table when `include_all_tables` is true, otherwise only the
    tables reachable through reference columns. These are followed by the schema of
    `table_id` itself without `col_id`. The prompt ends with the signature
    `def <col_id>(rec: <table_id>) -> <type>:` for the model to complete. `lookups` is
    forwarded to every `class_schema` call. `_description` is currently unused.
    """
    if include_all_tables:
        other_tables = all_other_tables(engine, table_id)
    else:
        other_tables = referenced_tables(engine, table_id)
    result = ""
    for other_table_id in sorted(other_tables):
        # Pass `lookups` by keyword: as the third positional argument it would be taken
        # as `exclude_col_id`, and the other tables would never get lookup methods.
        result += class_schema(engine, other_table_id, exclude_col_id=None, lookups=lookups)
    result += class_schema(engine, table_id, exclude_col_id=col_id, lookups=lookups)
    return_type = column_type(engine, table_id, col_id)
    result += "def {}(rec: {}) -> {}:\n".format(col_id, table_id, return_type)
    return result
def indent(text, prefix, predicate=None):
    """
    Add `prefix` to the beginning of selected lines in `text`.

    Mirrors Python 3's `textwrap.indent` (copied from
    https://github.com/python/cpython/blob/main/Lib/textwrap.py for python2 compatibility).
    By default only lines that contain non-whitespace characters are prefixed; pass
    `predicate` to choose lines differently. Line endings are kept as they are.
    """
    if six.PY3:
        return textwrap.indent(text, prefix, predicate)  # pylint: disable = no-member
    if predicate is None:
        should_prefix = lambda line: line.strip()
    else:
        should_prefix = predicate
    pieces = []
    for line in text.splitlines(True):
        if should_prefix(line):
            pieces.append(prefix + line)
        else:
            pieces.append(line)
    return ''.join(pieces)
def _is_docstring(stmt):
    """
    Return True if `stmt` is an expression statement consisting only of a string
    literal (i.e. a docstring when it is the first statement of a function body).
    """
    if not isinstance(stmt, ast.Expr):
        return False
    value = stmt.value
    # Python 3.8+ parses string literals as ast.Constant. ast.Str is deprecated since 3.8 and
    # removed in 3.14, but it is the node type produced by Python 2 and Python < 3.8.
    if isinstance(value, getattr(ast, "Constant", ())):
        return isinstance(value.value, six.string_types)
    return isinstance(value, getattr(ast, "Str", ()))


def convert_completion(completion):
    """
    Convert a model completion of the prompt from `get_formula_prompt` into a Grist formula.

    The completion may be wrapped in a markdown code block, in which case the first block is
    used. Class definitions (the model sometimes repeats the schema classes) are discarded.
    If what remains is a single function definition, optionally preceded by imports, only
    the function's body is kept (without a leading docstring), with the imports placed in
    front of it. Finally, `rec.` attribute access is rewritten as `$` and the `return `
    keyword of a final return statement is removed. Returns "" if the code is not valid Python.
    """
    # Extract code from a markdown code block if needed. The match is non-greedy so that, if
    # the completion contains several code blocks, only the first one is used rather than
    # everything from the first opening fence to the last closing one.
    match = re.search(r"```\w*\n(.*?)```", completion, re.DOTALL)
    if match:
        completion = match.group(1)
    result = textwrap.dedent(completion)
    try:
        atok = asttokens.ASTTokens(result, parse=True)
    except SyntaxError:
        # If we don't have valid Python code, don't suggest a formula at all.
        return ""

    # Sometimes the model repeats the provided classes; remove them wherever they appear.
    # This happens before collecting imports because repeated classes usually come first,
    # and they would otherwise stop the leading-import scan below.
    stmts = [stmt for stmt in atok.tree.body if not isinstance(stmt, ast.ClassDef)]

    # If the code starts with imports, save them for later.
    # In particular, the model may return something like:
    #     from datetime import date
    #     def my_column():
    #         ...
    # We want to return just the function body, but we need to keep the import,
    # i.e. move it 'inside the function'.
    imports = ""
    while stmts and isinstance(stmts[0], (ast.Import, ast.ImportFrom)):
        imports += atok.get_text(stmts.pop(0)) + "\n"

    # If the remaining code consists only of a function definition, extract the body.
    if len(stmts) == 1 and isinstance(stmts[0], ast.FunctionDef):
        func = stmts[0]
        func_body_stmts = func.body
        if len(func_body_stmts) > 1 and _is_docstring(func_body_stmts[0]):
            # Skip the docstring.
            first_stmt = func_body_stmts[1]
        else:
            first_stmt = func_body_stmts[0]
        # Stop at the function's own last line (asttokens' `last_token.end` is (row, col) with
        # a 1-based row), so that code following the function, e.g. a class the model
        # repeated after it, does not leak into the formula.
        last_line = func.last_token.end[0]
        result_lines = result.splitlines()[first_stmt.lineno - 1:last_line]
        result = textwrap.dedent("\n".join(result_lines))
        if imports:
            result = imports + "\n" + result

    # Now convert `rec.` to `$` and remove redundant `return ` at the end.
    try:
        atok = asttokens.ASTTokens(result, parse=True)
    except SyntaxError:
        # In case the above extraction somehow messed things up.
        return ""
    replacements = []
    for node in ast.walk(atok.tree):
        if isinstance(node, ast.Attribute):
            start, end = atok.get_text_range(node.value)
            # Include the character right after the object expression, expected to be the dot.
            end += 1
            if result[start:end] == "rec.":
                replacements.append((start, end, "$"))
    top_level = atok.tree.body
    # The code may be empty or contain only comments, in which case there is no last statement.
    if top_level and isinstance(top_level[-1], ast.Return):
        start, _ = atok.get_text_range(top_level[-1])
        expected = "return "
        end = start + len(expected)
        if result[start:end] == expected:
            replacements.append((start, end, ""))
    result = asttokens.util.replace(result, replacements)
    return result.strip()
def evaluate_formula(engine, table_id, col_id, row_id):
    """
    Evaluate the formula of column `col_id` in `table_id` for row `row_id` and summarize it.

    Returns a dict with keys:
      - error: True if evaluation returned a RaisedException, else False.
      - formula: the column's formula as plain Python, via make_formula_body(...).get_text().
      - result: "<name>: <message>" from the exception's encode_args() on error, otherwise
        attribute_recorder.safe_repr of the value.
      - attributes: the dict passed as `record_attributes` to engine.get_formula_value
        (presumably filled in during evaluation — confirm against the engine).
    The column must have a non-empty formula (checked with an assert).
    """
    formula_text = engine.docmodel.get_column_rec(table_id, col_id).formula
    assert formula_text
    plain_formula = make_formula_body(formula_text, default_value=None).get_text()
    attributes = {}
    value = engine.get_formula_value(table_id, col_id, row_id, record_attributes=attributes)
    is_error = isinstance(value, objtypes.RaisedException)
    if is_error:
        name, message = value.encode_args()[:2]
        shown = "%s: %s" % (name, message)
    else:
        shown = attribute_recorder.safe_repr(value)
    return {
        "error": is_error,
        "formula": plain_formula,
        "result": shown,
        "attributes": attributes,
    }