gristlabs_grist-core/sandbox/grist/formula_prompt.py
George Gevoian 94eec5e906 (core) Add AI Assistant retry with shorter prompt
Summary:
If the prompt exceeds the context length of even the longer OpenAI model, we now perform another retry
with a shorter variant of the formula prompt. The shorter prompt excludes non-referenced tables and
lookup method definitions, which should help reduce token usage in documents with larger schemas.

Test Plan: Server test.

Reviewers: JakubSerafin

Reviewed By: JakubSerafin

Subscribers: JakubSerafin

Differential Revision: https://phab.getgrist.com/D4184
2024-02-12 11:06:52 -05:00

285 lines
9.0 KiB
Python

import ast
import json
import re
import sys
import textwrap

import asttokens
import asttokens.util
import six

import attribute_recorder
import objtypes
from codebuilder import make_formula_body
from column import is_visible_column, BaseReferenceColumn
from objtypes import RaisedException
import records
def column_type(engine, table_id, col_id):
  """
  Return a Python type-hint string for column `col_id` of table `table_id`, as used in the class
  stubs of the formula prompt.

  - "Ref:T" gives "T"; "RefList:T" gives "list[T]".
  - "Choice" gives the hint from `choices`; "ChoiceList" gives a tuple of that hint.
  - "Any" is inferred from the column's current values via `values_type`.
  - Other types are looked up by the part before any ':' (so "DateTime:<tz>" maps like
    "DateTime"). Types not in the table below fall back to "Any" instead of raising KeyError,
    which would abort building the entire prompt.
  """
  col_rec = engine.docmodel.get_column_rec(table_id, col_id)
  typ = col_rec.type
  parts = typ.split(":")
  if parts[0] == "Ref":
    return parts[1]
  elif parts[0] == "RefList":
    return "list[{}]".format(parts[1])
  elif typ == "Choice":
    return choices(col_rec)
  elif typ == "ChoiceList":
    return "tuple[{}, ...]".format(choices(col_rec))
  elif typ == "Any":
    table = engine.tables[table_id]
    col = table.get_column(col_id)
    values = [col.raw_get(row_id) for row_id in table.row_ids]
    return values_type(values)

  simple_type_hints = {
    "Text": "str",
    "Numeric": "float",
    "Int": "int",
    "Bool": "bool",
    "Date": "datetime.date",
    "DateTime": "datetime.datetime",
    "Attachments": "Any",
  }
  return simple_type_hints.get(parts[0], "Any")
def choices(col_rec):
  """
  Return a type hint for the values of a Choice column.

  The result is "Literal[...]" listing the choices stored under "choices" in the JSON
  `col_rec.widgetOptions`. It is 'str' when no usable non-empty choice list is configured.
  """
  try:
    widget_options = json.loads(col_rec.widgetOptions)
    choice_list = widget_options["choices"]
  except (ValueError, KeyError, TypeError):
    # ValueError: widgetOptions is not valid JSON (e.g. an empty string).
    # KeyError: the options have no "choices" key.
    # TypeError: widgetOptions is not a string, or its JSON is not an object (e.g. "null").
    return 'str'
  if not choice_list:
    # "Literal[]" would not be a valid type hint.
    return 'str'
  return "Literal{}".format(choice_list)
def values_type(values):
  """
  Return a type-hint string describing the given values, e.g. "str", "Optional[int]",
  "list[People]" or "Any".

  RaisedException values are ignored. A None among the values wraps the hint in Optional[...],
  and a mix of ints and floats counts as float. If the remaining values don't share exactly one
  type, the result is "Any". Record, RecordSet and container values are described from the first
  value of the shared type, recursing into containers.
  """
  kinds = {type(v) for v in values}
  kinds.discard(RaisedException)
  has_none = type(None) in kinds  # pylint: disable=unidiomatic-typecheck
  kinds.discard(type(None))

  if kinds == {int, float}:
    kinds = {float}
  if len(kinds) != 1:
    return "Any"

  typ = kinds.pop()
  sample = next(v for v in values if isinstance(v, typ))

  # Container checks run in this order; the first matching type wins.
  container_templates = (
    (list, "list[{}]"),
    (set, "set[{}]"),
    (tuple, "tuple[{}, ...]"),
  )
  if isinstance(sample, records.Record):
    hint = sample._table.table_id
  elif isinstance(sample, records.RecordSet):
    hint = "list[{}]".format(sample._table.table_id)
  else:
    template = next(
      (tmpl for container, tmpl in container_templates if isinstance(sample, container)), None)
    if template is not None:
      hint = template.format(values_type(sample))
    elif isinstance(sample, dict):
      hint = "dict[{}, {}]".format(values_type(sample.keys()), values_type(sample.values()))
    else:
      hint = typ.__name__

  return "Optional[{}]".format(hint) if has_none else hint
def referenced_tables(engine, table_id):
  """
  Return the set of table ids reachable from `table_id` by following visible reference columns
  (instances of BaseReferenceColumn), transitively.

  Target tables whose id starts with "_" are not followed. `table_id` itself is excluded from
  the result.
  """
  def ref_targets(source_id):
    # Yield the ids of non-underscore tables that `source_id`'s visible reference columns point to.
    for _col_id, col in visible_columns(engine, source_id):
      if isinstance(col, BaseReferenceColumn):
        target_id = col._target_table.table_id
        if not target_id.startswith("_"):
          yield target_id

  reached = set()
  stack = [table_id]
  while stack:
    current = stack.pop()
    if current not in reached:
      reached.add(current)
      stack.extend(ref_targets(current))
  reached.discard(table_id)
  return reached
def all_other_tables(engine, table_id):
  """
  Return the set of ids of tables in `engine.tables` other than `table_id`.

  Metadata tables (ids starting with '_grist') and 'GristDocTour' are left out.
  """
  excluded = {table_id, 'GristDocTour'}
  return {
    other_id for other_id in engine.tables
    if not other_id.startswith('_grist') and other_id not in excluded
  }
def visible_columns(engine, table_id):
  """
  Return a list of (col_id, column) pairs for the columns of `table_id` accepted by
  `is_visible_column`, in the iteration order of the table's `all_columns` mapping.
  """
  table = engine.tables[table_id]
  return [(name, column) for name, column in table.all_columns.items()
          if is_visible_column(name)]
def class_schema(engine, table_id, exclude_col_id=None, lookups=False):
  """
  Render a Python class stub for `table_id`, for use in the formula prompt.

  The stub lists each visible column with its type hint from `column_type`.

  :param exclude_col_id: a column to leave out of the stub (e.g. the column whose formula is
    being generated).
  :param lookups: if true, also include `__len__`, `lookupRecords` and `lookupOne` stubs. These
    take one keyword argument per listed column, plus `sort_by`.
  :returns: the class definition text, followed by a blank line.
  """
  col_ids = [
    col_id for col_id, _col in visible_columns(engine, table_id)
    if col_id != exclude_col_id
  ]

  # Class members are indented 4 spaces and method bodies 8, so the stub is valid Python.
  body_lines = []
  if lookups:
    lookupOne_args_line = ', '.join([c + '=None' for c in col_ids] + ['sort_by=None'])
    lookupRecords_args_line = ', '.join(
      ['%s=%s' % (c, c) for c in col_ids] + ['sort_by=sort_by'])
    body_lines += [
      "    def __len__(self):",
      "        return len(%s.lookupRecords())" % table_id,
      "    @staticmethod",
      "    def lookupRecords(%s) -> list[%s]:" % (lookupOne_args_line, table_id),
      "        ...",
      "    @staticmethod",
      "    def lookupOne(%s) -> %s:" % (lookupOne_args_line, table_id),
      "        '''",
      "        Filter for one result matching the keys provided.",
      "        To control order, use e.g. `sort_by='Key'` or `sort_by='-Key'`.",
      "        '''",
      "        return %s.lookupRecords(%s)[0]" % (table_id, lookupRecords_args_line),
      "",
    ]

  body_lines += ["    {}: {}".format(c, column_type(engine, table_id, c)) for c in col_ids]
  if not body_lines:
    # A class needs at least one statement in its body (e.g. when the only column is excluded).
    body_lines.append("    pass")

  return "class {}:\n".format(table_id) + "".join(line + "\n" for line in body_lines) + "\n"
def get_formula_prompt(engine, table_id, col_id, _description,
                       include_all_tables=True,
                       lookups=True):
  """
  Build the prompt text asking for a formula for column `col_id` of table `table_id`.

  The prompt has three parts, in order:
  - class stubs (see `class_schema`) for the other tables, sorted by table id;
  - a stub for `table_id` itself that omits `col_id`;
  - the header `def <col_id>(rec: <table_id>) -> <type>:` for the model to complete.

  :param _description: unused.
  :param include_all_tables: if true, describe every other user table (`all_other_tables`).
    Otherwise describe only tables reachable through reference columns (`referenced_tables`),
    which gives a shorter prompt.
  :param lookups: passed to `class_schema` to include or omit the lookup method stubs.
  """
  if include_all_tables:
    other_tables = all_other_tables(engine, table_id)
  else:
    other_tables = referenced_tables(engine, table_id)

  sections = [class_schema(engine, other_id, None, lookups) for other_id in sorted(other_tables)]
  sections.append(class_schema(engine, table_id, col_id, lookups))
  return_hint = column_type(engine, table_id, col_id)
  sections.append("def {}(rec: {}) -> {}:\n".format(col_id, table_id, return_hint))
  return "".join(sections)
def indent(text, prefix, predicate=None):
  """
  Add `prefix` to the beginning of selected lines of `text`, like `textwrap.indent`.

  Without a `predicate`, lines consisting solely of whitespace are left unchanged. On Python 3
  this delegates to `textwrap.indent`. The fallback is copied from
  https://github.com/python/cpython/blob/main/Lib/textwrap.py for python2 compatibility.
  """
  if six.PY3:
    return textwrap.indent(text, prefix, predicate)  # pylint: disable = no-member

  def should_prefix(line):
    # Mirror textwrap.indent: by default, prefix only lines that aren't whitespace-only.
    return predicate(line) if predicate is not None else line.strip()

  return ''.join(prefix + line if should_prefix(line) else line
                 for line in text.splitlines(True))
def _is_str_literal(node):
  """
  Return whether the AST expression `node` is a string literal, such as a docstring.

  Python 3.8+ parses literals as `ast.Constant`. Older versions (including Python 2) use
  `ast.Str`, which is deprecated since 3.8 and removed in 3.14, so it is only consulted on those
  older versions.
  """
  constant_type = getattr(ast, "Constant", None)
  if constant_type is not None and isinstance(node, constant_type):
    return isinstance(node.value, str)
  if sys.version_info >= (3, 8):
    # Every string literal is an ast.Constant here, and was handled above.
    return False
  return isinstance(node, ast.Str)


def convert_completion(completion):
  """
  Convert an AI model's completion into the text of a Grist formula.

  Steps:
  - If the completion contains a markdown code fence, use the text between the first opening
    fence and the last closing fence.
  - Dedent the code. If it isn't valid Python, return "".
  - Set aside leading import statements and ignore class definitions (the model sometimes
    repeats the prompt's class stubs).
  - If what remains is a single function definition, keep only its body, starting at its first
    statement (after a docstring, if other statements follow it), with the imports prepended.
  - Rewrite `rec.` attribute access as `$`, and drop the `return ` keyword of a trailing return
    statement.

  Returns the stripped result.
  """
  # Extract code from a markdown code block if needed.
  match = re.search(r"```\w*\n(.*)```", completion, re.DOTALL)
  if match:
    completion = match.group(1)

  result = textwrap.dedent(completion)

  atok = asttokens.ASTText(result)
  try:
    # Constructing ASTText doesn't parse the code, but the .tree property does.
    stmts = atok.tree.body
  except SyntaxError:
    # If we don't have valid Python code, don't suggest a formula at all
    return ""

  # If the code starts with imports, save them for later.
  # In particular, the model may return something like:
  #   from datetime import date
  #   def my_column():
  #     ...
  # We want to return just the function body, but we need to keep the import,
  # i.e. move it 'inside the function'.
  imports = ""
  while stmts and isinstance(stmts[0], (ast.Import, ast.ImportFrom)):
    imports += atok.get_text(stmts.pop(0)) + "\n"

  # Sometimes the model repeats the provided classes, remove them.
  stmts = [stmt for stmt in stmts if not isinstance(stmt, ast.ClassDef)]

  # If the remaining code consists only of a function definition, extract the body.
  if len(stmts) == 1 and isinstance(stmts[0], ast.FunctionDef):
    func_body_stmts = stmts[0].body
    if (
      len(func_body_stmts) > 1 and
      isinstance(func_body_stmts[0], ast.Expr) and
      _is_str_literal(func_body_stmts[0].value)
    ):
      # Skip the docstring.
      first_stmt = func_body_stmts[1]
    else:
      first_stmt = func_body_stmts[0]

    # Keep everything from the first body statement's line to the end of the text.
    result_lines = result.splitlines()[first_stmt.lineno - 1:]
    result = "\n".join(result_lines)
    result = textwrap.dedent(result)
    if imports:
      result = imports + "\n" + result

  # Now convert `rec.` to `$` and remove redundant `return ` at the end.
  atok = asttokens.ASTText(result)
  try:
    # Constructing ASTText doesn't parse the code, but the .tree property does.
    tree = atok.tree
  except SyntaxError:
    # In case the above extraction somehow messed things up
    return ""

  replacements = []
  for node in ast.walk(tree):
    if isinstance(node, ast.Attribute):
      start, end = atok.get_text_range(node.value)
      # Include the character right after the object, so that we only match `rec.`.
      end += 1
      if result[start:end] == "rec.":
        replacements.append((start, end, "$"))

  # The code may contain no statements at all (e.g. an empty completion, or only comments).
  if tree.body and isinstance(tree.body[-1], ast.Return):
    start, _ = atok.get_text_range(tree.body[-1])
    expected = "return "
    end = start + len(expected)
    if result[start:end] == expected:
      replacements.append((start, end, ""))

  result = asttokens.util.replace(result, replacements)
  return result.strip()
def evaluate_formula(engine, table_id, col_id, row_id):
  """
  Evaluate the formula of column `col_id` of `table_id` for row `row_id`, and describe the outcome.

  Returns a dict with these keys:
    error: True if the evaluation produced a RaisedException, else False.
    formula: the formula text from make_formula_body(..., default_value=None).get_text().
    result: "<name>: <message>" from the exception's encode_args() on error, otherwise
      attribute_recorder.safe_repr(value).
    attributes: the dict passed as `record_attributes` to engine.get_formula_value (presumably
      filled with the record attributes the formula accessed; TODO confirm).
  """
  formula_text = engine.docmodel.get_column_rec(table_id, col_id).formula
  assert formula_text
  plain_formula = make_formula_body(formula_text, default_value=None).get_text()

  attributes = {}
  value = engine.get_formula_value(table_id, col_id, row_id, record_attributes=attributes)
  failed = isinstance(value, objtypes.RaisedException)
  if failed:
    exc_name, exc_message = value.encode_args()[:2]
    shown = "%s: %s" % (exc_name, exc_message)
  else:
    shown = attribute_recorder.safe_repr(value)

  return {
    "error": failed,
    "formula": plain_formula,
    "result": shown,
    "attributes": attributes,
  }