mirror of
https://github.com/gristlabs/grist-core.git
synced 2024-10-27 20:44:07 +00:00
bb7cf6ba20
Summary: This tweaks the prompting so that the user's message is given on its own instead of as a docstring within Python. This is so that the prompt makes sense when: - the user asks a question such as "Can you write me a formula which does ...?" rather than describing their formula as a docstring would, or - the user sends a message that doesn't ask for a formula at all (https://grist.slack.com/archives/C0234CPPXPA/p1687699944315069?thread_ts=1687698078.832209&cid=C0234CPPXPA). Also added wording for the model to refuse when the user asks for something that the model cannot do. Because the code (and maybe in some cases the model) for non-ChatGPT models relies on the prompt consisting entirely of Python code produced by the data engine (which no longer contains the user's message), those code paths have been disabled for now. Updating them now seems like undesirable drag; I think it'd be better to revisit this when iteration/experimentation has slowed down and stabilised. Test Plan: Added entries to the formula dataset where the response shouldn't contain a formula, indicated by the value `1` for the new column `no_formula`. This is somewhat successful, as the model does refuse to help in some of the new test cases, but not all. Performance on existing entries also seems a bit worse, but it's hard to distinguish this from random noise. Hopefully this can be remedied in the future with more work, e.g. automatic follow-up messages containing example inputs and outputs. Reviewers: paulfitz Reviewed By: paulfitz Subscribers: dsagal Differential Revision: https://phab.getgrist.com/D3936
258 lines
8.1 KiB
Python
258 lines
8.1 KiB
Python
import ast
|
|
import json
|
|
import re
|
|
import textwrap
|
|
|
|
import asttokens
|
|
import asttokens.util
|
|
import six
|
|
|
|
from column import is_visible_column, BaseReferenceColumn
|
|
from objtypes import RaisedException
|
|
import records
|
|
|
|
|
|
def column_type(engine, table_id, col_id):
    """
    Return a Python-style type annotation (as a string) for column `col_id`
    of table `table_id`, for use in the generated formula prompt.

    - "Ref:<Table>" maps to "<Table>", "RefList:<Table>" to "List[<Table>]".
    - "Choice" / "ChoiceList" are described via `choices`.
    - "Any" columns get a type inferred from the values currently stored in
      the column (via `values_type`).
    - Other base types (the part before any ":") are looked up in a fixed
      mapping; a base type missing from that mapping raises KeyError.
    """
    col_rec = engine.docmodel.get_column_rec(table_id, col_id)
    full_type = col_rec.type
    parts = full_type.split(":")
    base_type = parts[0]

    # Reference types carry the target table id after the colon.
    if base_type == "Ref":
        return parts[1]
    if base_type == "RefList":
        return "List[{}]".format(parts[1])

    # Choice types are compared against the full type string.
    if full_type == "Choice":
        return choices(col_rec)
    if full_type == "ChoiceList":
        return "Tuple[{}, ...]".format(choices(col_rec))

    if full_type == "Any":
        # No declared type: infer one from the column's current values.
        table = engine.tables[table_id]
        column = table.get_column(col_id)
        return values_type([column.raw_get(row_id) for row_id in table.row_ids])

    simple_types = {
        "Text": "str",
        "Numeric": "float",
        "Int": "int",
        "Bool": "bool",
        "Date": "datetime.date",
        "DateTime": "datetime.datetime",
        "Any": "Any",
        "Attachments": "Any",
    }
    # Unlisted base types raise KeyError here.
    return simple_types[base_type]
|
|
|
|
def choices(col_rec):
    """
    Return a `Literal[...]` annotation listing the choices configured for a
    Choice/ChoiceList column, or 'str' when no usable choice list is found.

    `col_rec.widgetOptions` is parsed as JSON. It is treated as having no
    choices when it is not valid JSON, when it does not decode to an object
    with a "choices" entry, or when that entry is not a non-empty list.
    """
    try:
        widget_options = json.loads(col_rec.widgetOptions)
        choice_list = widget_options["choices"]
    except (ValueError, KeyError, TypeError):
        # ValueError: empty or malformed JSON.
        # KeyError: no "choices" entry.
        # TypeError: widgetOptions is not a string, or it decodes to something
        #   that can't be indexed by "choices" (e.g. "null" or a JSON list).
        return 'str'
    if not isinstance(choice_list, list) or not choice_list:
        # `Literal[]` is not a valid annotation, and a non-list value would be
        # formatted into a meaningless string.
        return 'str'
    return "Literal{}".format(choice_list)
|
|
|
|
|
|
def values_type(values):
    """
    Infer a type annotation string from a collection of sample values.

    RaisedException values are ignored, and None values only cause the result
    to be wrapped in `Optional[...]`. A mix of ints and floats is reported as
    float. If the remaining values don't share exactly one type, the result is
    "Any".

    Records are described by their table id, RecordSets as `List[<table id>]`.
    Lists, sets, tuples and dicts are described recursively from one
    representative value: the first value that is an instance of the chosen
    type.
    """
    none_type = type(None)
    seen_types = {type(v) for v in values}
    seen_types.discard(RaisedException)
    is_optional = none_type in seen_types
    seen_types.discard(none_type)

    # Treat a mix of ints and floats as plain float.
    if seen_types == {int, float}:
        seen_types = {float}

    if len(seen_types) != 1:
        return "Any"

    (only_type,) = seen_types
    sample = next(v for v in values if isinstance(v, only_type))

    if isinstance(sample, records.Record):
        name = sample._table.table_id
    elif isinstance(sample, records.RecordSet):
        name = "List[{}]".format(sample._table.table_id)
    elif isinstance(sample, list):
        name = "List[{}]".format(values_type(sample))
    elif isinstance(sample, set):
        name = "Set[{}]".format(values_type(sample))
    elif isinstance(sample, tuple):
        name = "Tuple[{}, ...]".format(values_type(sample))
    elif isinstance(sample, dict):
        key_name = values_type(sample.keys())
        value_name = values_type(sample.values())
        name = "Dict[{}, {}]".format(key_name, value_name)
    else:
        name = only_type.__name__

    return "Optional[{}]".format(name) if is_optional else name
|
|
|
|
|
|
def referenced_tables(engine, table_id):
    """
    Return the set of table ids reachable from `table_id` by following
    reference columns transitively, not including `table_id` itself.

    Only columns returned by `visible_columns` are followed. Reference targets
    whose table id starts with "_" are not followed.
    """
    visited = set()
    pending = [table_id]
    while pending:
        current = pending.pop()
        if current in visited:
            continue
        visited.add(current)
        pending.extend(
            col._target_table.table_id
            for _, col in visible_columns(engine, current)
            if isinstance(col, BaseReferenceColumn)
            and not col._target_table.table_id.startswith("_")
        )
    return visited - {table_id}
|
|
|
|
def all_other_tables(engine, table_id):
    """
    Return the set of ids of all tables in `engine.tables` except `table_id`.

    Tables whose id starts with '_grist' and the 'GristDocTour' table are
    left out.
    """
    excluded = {table_id, 'GristDocTour'}
    return {
        other_id
        for other_id in engine.tables.keys()
        if not other_id.startswith('_grist') and other_id not in excluded
    }
|
|
|
|
def visible_columns(engine, table_id):
    """
    Return a list of `(col_id, column)` pairs for the columns of table
    `table_id` whose id passes `is_visible_column`, in the iteration order of
    the table's `all_columns` mapping.
    """
    columns = engine.tables[table_id].all_columns
    return [item for item in columns.items() if is_visible_column(item[0])]
|
|
|
|
|
|
def class_schema(engine, table_id, exclude_col_id=None, lookups=False):
    """
    Return prompt text describing table `table_id` as a Python `@dataclass`.

    The class gets one annotated field per visible column (typed via
    `column_type`), except the column `exclude_col_id`.

    When `lookups` is true, stub `__len__`, `lookupRecords` and `lookupOne`
    methods are emitted first. Their keyword arguments are the table's visible
    columns (again except `exclude_col_id`) plus `sort_by`, as hints for the
    model.
    """
    result = "@dataclass\nclass {}:\n".format(table_id)

    if lookups:
        # Build a lookupRecords and lookupOne method for each table, providing
        # some argument hints for the columns that are visible.
        lookupRecords_args = []
        lookupOne_args = []
        for col_id, _col in visible_columns(engine, table_id):
            if col_id != exclude_col_id:
                lookupOne_args.append(col_id + '=None')
                lookupRecords_args.append('%s=%s' % (col_id, col_id))
        lookupOne_args.append('sort_by=None')
        lookupRecords_args.append('sort_by=sort_by')
        lookupOne_args_line = ', '.join(lookupOne_args)
        lookupRecords_args_line = ', '.join(lookupRecords_args)

        # NOTE(review): every emitted line below starts with a single space, so
        # e.g. `def __len__` and its `return` end up at the same indentation in
        # the generated text; confirm the intended indentation of the class body.
        result += " def __len__(self):\n"
        result += " return len(%s.lookupRecords())\n" % table_id
        result += " @staticmethod\n"
        # lookupRecords deliberately advertises the same signature as lookupOne.
        result += " def lookupRecords(%s) -> List[%s]:\n" % (lookupOne_args_line, table_id)
        result += " # ...\n"
        result += " @staticmethod\n"
        result += " def lookupOne(%s) -> %s:\n" % (lookupOne_args_line, table_id)
        result += " '''\n"
        result += " Filter for one result matching the keys provided.\n"
        # Each example is wrapped in its own pair of backticks.
        result += " To control order, use e.g. `sort_by='Key'` or `sort_by='-Key'`.\n"
        result += " '''\n"
        result += " return %s.lookupRecords(%s)[0]\n" % (table_id, lookupRecords_args_line)
        result += "\n"

    for col_id, _col in visible_columns(engine, table_id):
        if col_id != exclude_col_id:
            result += " {}: {}\n".format(col_id, column_type(engine, table_id, col_id))
    result += "\n"
    return result
|
|
|
|
|
|
def get_formula_prompt(engine, table_id, col_id, _description,
                       include_all_tables=True,
                       lookups=True):
    """
    Build the schema part of the prompt for writing a formula for column
    `col_id` of table `table_id`.

    The text contains, in order:
    - a `class_schema` for each other table, sorted by id. These are all
      tables from `all_other_tables` when `include_all_tables` is true,
      otherwise the tables from `referenced_tables`;
    - the schema of `table_id` itself, without `col_id`;
    - the stub of a `@property` named `col_id`, annotated with the column's
      type from `column_type`.

    `lookups` controls whether lookup method stubs are emitted for every
    table. `_description` is accepted but not used here.
    """
    result = ""
    if include_all_tables:
        other_tables = all_other_tables(engine, table_id)
    else:
        other_tables = referenced_tables(engine, table_id)
    for other_table_id in sorted(other_tables):
        # Pass `lookups` by keyword: positionally it would be taken as
        # `exclude_col_id`, leaving `lookups` at its default.
        result += class_schema(engine, other_table_id, lookups=lookups)

    result += class_schema(engine, table_id, col_id, lookups)

    return_type = column_type(engine, table_id, col_id)
    result += " @property\n"
    result += " # rec is alias for self\n"
    result += " def {}(rec) -> {}:\n".format(col_id, return_type)
    result += " # Please fill in code only after this line, not the `def`\n"
    return result
|
|
|
|
def indent(text, prefix, predicate=None):
    """
    Add `prefix` to the beginning of selected lines of `text`.

    On Python 3 this delegates to `textwrap.indent`. On Python 2 an equivalent
    fallback is used, adapted from
    https://github.com/python/cpython/blob/main/Lib/textwrap.py for python2
    compatibility.

    By default only lines containing non-whitespace characters are prefixed.
    Pass `predicate(line)` to choose the lines yourself. Line endings are kept
    as they are.
    """
    if six.PY3:
        return textwrap.indent(text, prefix, predicate)  # pylint: disable = no-member

    should_prefix = predicate if predicate is not None else (lambda line: line.strip())
    return ''.join(
        prefix + line if should_prefix(line) else line
        for line in text.splitlines(True)
    )
|
|
|
|
|
|
def _is_string_expr(stmt):
    """
    Return True if `stmt` is an expression statement made of a single string
    literal (e.g. a docstring).

    Checks `ast.Constant` first, which is what Python 3.8+ produces.
    `ast.Str` is deprecated there and removed in newer versions. Falls back to
    `ast.Str` for Python 2 and early Python 3.
    """
    if not isinstance(stmt, ast.Expr):
        return False
    value = stmt.value
    constant_type = getattr(ast, "Constant", None)
    if constant_type is not None and isinstance(value, constant_type):
        return isinstance(value.value, six.string_types)
    str_type = getattr(ast, "Str", None)
    return str_type is not None and isinstance(value, str_type)


def convert_completion(completion):
    """
    Turn a raw model completion into the text of a Grist formula.

    1. If the completion contains a markdown code fence (an opening fence,
       optional language tag and newline, then a closing fence), keep only
       the code inside the first such block.
    2. Parse it as Python; return "" if it isn't valid Python.
    3. If, after any leading imports, the code is a single function
       definition, keep only the function body, placing the leading imports
       before it. A leading docstring is skipped when other statements follow.
    4. Rewrite `rec.<attr>` as `$<attr>`. If the last top-level statement is a
       return written as `return <value>`, drop the `return ` keyword.

    Returns the resulting text stripped of surrounding whitespace, or "" when
    the code can't be parsed.
    """
    # Extract code from a markdown code block if needed. Non-greedy, so that a
    # reply with several fenced blocks yields only the first block instead of
    # everything from the first opening fence to the last closing fence.
    match = re.search(r"```\w*\n(.*?)```", completion, re.DOTALL)
    if match:
        completion = match.group(1)

    result = textwrap.dedent(completion)

    try:
        atok = asttokens.ASTTokens(result, parse=True)
    except SyntaxError:
        # If we don't have valid Python code, don't suggest a formula at all
        return ""

    stmts = atok.tree.body

    # If the code starts with imports, save them for later.
    # In particular, the model may return something like:
    #   from datetime import date
    #   def my_column():
    #       ...
    # We want to return just the function body, but we need to keep the import,
    # i.e. move it 'inside the function'.
    imports = ""
    while stmts and isinstance(stmts[0], (ast.Import, ast.ImportFrom)):
        imports += atok.get_text(stmts.pop(0)) + "\n"

    # If the non-import code consists only of a function definition, extract the body.
    if len(stmts) == 1 and isinstance(stmts[0], ast.FunctionDef):
        func_body_stmts = stmts[0].body
        if len(func_body_stmts) > 1 and _is_string_expr(func_body_stmts[0]):
            # Skip the docstring.
            first_stmt = func_body_stmts[1]
        else:
            first_stmt = func_body_stmts[0]
        result_lines = result.splitlines()[first_stmt.lineno - 1:]
        result = textwrap.dedent("\n".join(result_lines))

        # Re-attach the imports only in this branch. Otherwise `result` still
        # contains them, and prepending them again would duplicate them.
        if imports:
            result = imports + "\n" + result

    # Now convert `rec.` to `$` and remove redundant `return ` at the end.
    try:
        atok = asttokens.ASTTokens(result, parse=True)
    except SyntaxError:
        # In case the above extraction somehow messed things up
        return ""

    replacements = []
    for node in ast.walk(atok.tree):
        if isinstance(node, ast.Attribute):
            start, end = atok.get_text_range(node.value)
            end += 1  # include the "." right after the attribute's base
            if result[start:end] == "rec.":
                replacements.append((start, end, "$"))

    body = atok.tree.body
    # `body` is empty for blank or comment-only code; there is no final
    # statement to inspect in that case.
    if body and isinstance(body[-1], ast.Return):
        start, _ = atok.get_text_range(body[-1])
        expected = "return "
        end = start + len(expected)
        if result[start:end] == expected:
            replacements.append((start, end, ""))

    result = asttokens.util.replace(result, replacements)

    return result.strip()
|