From 391c8ee087cfe71c186e25fd4bcdd2660fbf3465 Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Mon, 24 Jul 2023 20:56:38 +0200 Subject: [PATCH] (core) Allow assistant to evaluate current formula Summary: Replaces https://phab.getgrist.com/D3940, particularly to avoid doing potentially unwanted things automatically. Adds optional fields `evaluateCurrentFormula?: boolean; rowId?: number` to `FormulaAssistanceContext` (part of `AssistanceRequest`). When `evaluateCurrentFormula` is `true`, calls a new function `evaluate_formula` in the sandbox which computes the existing formula in the column (regardless of anything the AI may have suggested) and uses that to generate an additional system message which is added before the user's message. In theory this could be used in an interface where users ask why a formula doesn't work, including possibly a formula suggested by the AI. For now, it's only used in `runCompletion_impl.ts` for experimenting. Also cleaned up a bit, removing `_chatMode` which is always `true` now, and uses of `regenerate` which is always `false`. Test Plan: Updated `runCompletion_impl` to optionally use the new feature, in which case it now scores 51/68 instead of 49/68. Reviewers: paulfitz Reviewed By: paulfitz Differential Revision: https://phab.getgrist.com/D3970 --- app/common/AssistancePrompts.ts | 8 +- app/server/lib/ActiveDoc.ts | 11 ++ app/server/lib/Assistance.ts | 129 +++++++++++---------- sandbox/grist/attribute_recorder.py | 59 ++++++++++ sandbox/grist/engine.py | 33 ++++-- sandbox/grist/formula_prompt.py | 25 ++++ sandbox/grist/main.py | 4 + test/formula-dataset/runCompletion_impl.ts | 129 ++++++++++----------- 8 files changed, 257 insertions(+), 141 deletions(-) create mode 100644 sandbox/grist/attribute_recorder.py diff --git a/app/common/AssistancePrompts.ts b/app/common/AssistancePrompts.ts index fce742f8..abe4e915 100644 --- a/app/common/AssistancePrompts.ts +++ b/app/common/AssistancePrompts.ts @@ -27,6 +27,8 @@ export interface FormulaAssistanceContext { type: 'formula'; tableId: string; colId: string; + evaluateCurrentFormula?: boolean; + rowId?: number; } export type AssistanceContext = FormulaAssistanceContext; @@ -39,10 +41,8 @@ export interface AssistanceRequest { context: AssistanceContext; state?: AssistanceState; text: string; - regenerate?: boolean; // Set if there was a previous request - // and response that should be omitted - // from history, or (if available) an - // alternative response generated. + // TODO this is no longer used and should be removed + regenerate?: boolean; } /** diff --git a/app/server/lib/ActiveDoc.ts b/app/server/lib/ActiveDoc.ts index 4111a43a..83eb128c 100644 --- a/app/server/lib/ActiveDoc.ts +++ b/app/server/lib/ActiveDoc.ts @@ -85,6 +85,7 @@ import {ParseFileResult, ParseOptions} from 'app/plugin/FileParserAPI'; import {AccessTokenOptions, AccessTokenResult, GristDocAPI} from 'app/plugin/GristAPI'; import {compileAclFormula} from 'app/server/lib/ACLFormula'; import {AssistanceSchemaPromptV1Context} from 'app/server/lib/Assistance'; +import {AssistanceContext} from 'app/common/AssistancePrompts'; import {Authorizer} from 'app/server/lib/Authorizer'; import {checksumFile} from 'app/server/lib/checksumFile'; import {Client} from 'app/server/lib/Client'; @@ -1289,6 +1290,16 @@ export class ActiveDoc extends EventEmitter { return this._pyCall('convert_formula_completion', txt); } + // Callback to compute an existing formula and return the result along with recorded values + // of (possibly nested) attributes of `rec`. + // Used by AI assistance to fix an incorrect formula. + public assistanceEvaluateFormula(options: AssistanceContext) { + if (!options.evaluateCurrentFormula) { + throw new Error('evaluateCurrentFormula must be true'); + } + return this._pyCall('evaluate_formula', options.tableId, options.colId, options.rowId); + } + public fetchURL(docSession: DocSession, url: string, options?: FetchUrlOptions): Promise { return fetchURL(url, this.makeAccessId(docSession.authorizer.getUserId()), options); } diff --git a/app/server/lib/Assistance.ts b/app/server/lib/Assistance.ts index 6a97d0fd..6eb94bb9 100644 --- a/app/server/lib/Assistance.ts +++ b/app/server/lib/Assistance.ts @@ -2,7 +2,12 @@ * Module with functions used for AI formula assistance. */ -import {AssistanceMessage, AssistanceRequest, AssistanceResponse} from 'app/common/AssistancePrompts'; +import { + AssistanceContext, + AssistanceMessage, + AssistanceRequest, + AssistanceResponse +} from 'app/common/AssistancePrompts'; import {delay} from 'app/common/delay'; import {DocAction} from 'app/common/DocActions'; import {ActiveDoc} from 'app/server/lib/ActiveDoc'; @@ -37,10 +42,29 @@ interface AssistanceDoc extends ActiveDoc { * be great to try variants. */ assistanceSchemaPromptV1(session: OptDocSession, options: AssistanceSchemaPromptV1Context): Promise; + /** * Some tweaks to a formula after it has been generated. */ assistanceFormulaTweak(txt: string): Promise; + + /** + * Compute the existing formula and return the result along with recorded values + * of (possibly nested) attributes of `rec`. + * Used by AI assistance to fix an incorrect formula. + */ + assistanceEvaluateFormula(options: AssistanceContext): Promise; +} + +export interface AssistanceFormulaEvaluationResult { + error: boolean; // true if an exception was raised + result: string; // repr of the return value OR exception message + + // Recorded attributes of `rec` at the time of evaluation. + // Keys may be e.g. "rec.foo.bar" for nested attributes. + attributes: Record; + + formula: string; // the code that was evaluated, without special grist syntax } export interface AssistanceSchemaPromptV1Context { @@ -101,7 +125,6 @@ export class OpenAIAssistant implements Assistant { public static LONGER_CONTEXT_MODEL = "gpt-3.5-turbo-16k-0613"; private _apiKey: string; - private _chatMode: boolean; private _endpoint: string; public constructor() { @@ -110,60 +133,52 @@ export class OpenAIAssistant implements Assistant { throw new Error('OPENAI_API_KEY not set'); } this._apiKey = apiKey; - this._chatMode = true; - if (!this._chatMode) { - throw new Error('Only turbo models are currently supported'); - } - this._endpoint = `https://api.openai.com/v1/${this._chatMode ? 'chat/' : ''}completions`; + this._endpoint = `https://api.openai.com/v1/chat/completions`; } public async apply( optSession: OptDocSession, doc: AssistanceDoc, request: AssistanceRequest): Promise { const messages = request.state?.messages || []; const newMessages = []; - const chatMode = this._chatMode; - if (chatMode) { - if (messages.length === 0) { - newMessages.push({ - role: 'system', - content: 'You are a helpful assistant for a user of software called Grist. ' + - 'Below are one or more Python classes. ' + - 'The last method needs completing. ' + - "The user will probably give a description of what they want the method (a 'formula') to return. " + - 'If so, your response should include the method body as Python code in a markdown block. ' + - 'Do not include the class or method signature, just the method body. ' + - 'If your code starts with `class`, `@dataclass`, or `def` it will fail. Only give the method body. ' + - 'You can import modules inside the method body if needed. ' + - 'You cannot define additional functions or methods. ' + - 'The method should be a pure function that performs some computation and returns a result. ' + - 'It CANNOT perform any side effects such as adding/removing/modifying rows/columns/cells/tables/etc. ' + - 'It CANNOT interact with files/databases/networks/etc. ' + - 'It CANNOT display images/charts/graphs/maps/etc. ' + - 'If the user asks for these things, tell them that you cannot help. ' + - 'The method uses `rec` instead of `self` as the first parameter.\n\n' + - '```python\n' + - await makeSchemaPromptV1(optSession, doc, request) + - '\n```', - }); - newMessages.push({ - role: 'user', content: request.text, - }); - } else { - if (request.regenerate) { - if (messages[messages.length - 1].role !== 'user') { - messages.pop(); - } - } - newMessages.push({ - role: 'user', content: request.text, - }); - } - } else { - messages.length = 0; + if (messages.length === 0) { newMessages.push({ - role: 'user', content: await makeSchemaPromptV1(optSession, doc, request), + role: 'system', + content: 'You are a helpful assistant for a user of software called Grist. ' + + 'Below are one or more Python classes. ' + + 'The last method needs completing. ' + + "The user will probably give a description of what they want the method (a 'formula') to return. " + + 'If so, your response should include the method body as Python code in a markdown block. ' + + 'Do not include the class or method signature, just the method body. ' + + 'If your code starts with `class`, `@dataclass`, or `def` it will fail. Only give the method body. ' + + 'You can import modules inside the method body if needed. ' + + 'You cannot define additional functions or methods. ' + + 'The method should be a pure function that performs some computation and returns a result. ' + + 'It CANNOT perform any side effects such as adding/removing/modifying rows/columns/cells/tables/etc. ' + + 'It CANNOT interact with files/databases/networks/etc. ' + + 'It CANNOT display images/charts/graphs/maps/etc. ' + + 'If the user asks for these things, tell them that you cannot help. ' + + 'The method uses `rec` instead of `self` as the first parameter.\n\n' + + '```python\n' + + await makeSchemaPromptV1(optSession, doc, request) + + '\n```', }); } + if (request.context.evaluateCurrentFormula) { + const result = await doc.assistanceEvaluateFormula(request.context); + let message = "Evaluating this code:\n\n```python\n" + result.formula + "\n```\n\n"; + if (Object.keys(result.attributes).length > 0) { + const attributes = Object.entries(result.attributes).map(([k, v]) => `${k} = ${v}`).join('\n'); + message += `where:\n\n${attributes}\n\n`; + } + message += `${result.error ? 'raises an exception' : 'returns'}: ${result.result}`; + newMessages.push({ + role: 'system', + content: message, + }); + } + newMessages.push({ + role: 'user', content: request.text, + }); messages.push(...newMessages); const newMessagesStartIndex = messages.length - newMessages.length; @@ -184,9 +199,7 @@ export class OpenAIAssistant implements Assistant { const userIdHash = getUserHash(optSession); const completion: string = await this._getCompletion(messages, userIdHash); const response = await completionToResponse(doc, request, completion, completion); - if (chatMode) { - response.state = {messages}; - } + response.state = {messages}; doc.logTelemetryEvent(optSession, 'assistantReceive', { full: { conversationId: request.conversationId, @@ -211,12 +224,9 @@ export class OpenAIAssistant implements Assistant { "Content-Type": "application/json", }, body: JSON.stringify({ - ...(!this._chatMode ? { - prompt: messages[messages.length - 1].content, - } : {messages}), + messages, temperature: 0, model: longerContext ? OpenAIAssistant.LONGER_CONTEXT_MODEL : OpenAIAssistant.DEFAULT_MODEL, - stop: this._chatMode ? undefined : ["\n\n"], user: userIdHash, }), }, @@ -267,11 +277,9 @@ export class OpenAIAssistant implements Assistant { private async _getCompletion(messages: AssistanceMessage[], userIdHash: string) { const result = await this._fetchCompletionWithRetries(messages, userIdHash, false); - const completion: string = String(this._chatMode ? result.choices[0].message.content : result.choices[0].text); - if (this._chatMode) { - messages.push(result.choices[0].message); - } - return completion; + const {message} = result.choices[0]; + messages.push(message); + return message.content; } } @@ -404,6 +412,9 @@ export async function sendForCompletion( doc: AssistanceDoc, request: AssistanceRequest, ): Promise { + if (request.regenerate) { + throw new Error('regenerate no longer supported'); + } const assistant = getAssistant(); return await assistant.apply(optSession, doc, request); } diff --git a/sandbox/grist/attribute_recorder.py b/sandbox/grist/attribute_recorder.py new file mode 100644 index 00000000..6aa82171 --- /dev/null +++ b/sandbox/grist/attribute_recorder.py @@ -0,0 +1,59 @@ +from six.moves import reprlib + +import records + + +class AttributeRecorder(object): + """ + Wrapper around a Record that records attribute accesses. + Used to generate a prompt for the AI with basic 'debugging' info. + """ + def __init__(self, inner, name, attributes): + assert isinstance(inner, records.Record) + self._inner = inner + self._name = name + self._attributes = attributes + + def __getattr__(self, name): + """ + Record attribute access. + If the result is a Record or RecordSet, wrap that with AttributeRecorder + to also record nested attribute values. + """ + result = getattr(self._inner, name) + full_name = "{}.{}".format(self._name, name) + if isinstance(result, records.Record): + result = AttributeRecorder(result, full_name, self._attributes) + elif isinstance(result, records.RecordSet): + # Use a tuple to imply immutability so that the AI doesn't try appending. + # Don't try recording attributes of all contained records, just record the first access. + # Pretend that the attribute is always accessed from the first record for simplicity. + result = tuple(AttributeRecorder(r, full_name + "[0]", self._attributes) for r in result) + self._attributes.setdefault(full_name, safe_repr(result)) + return result + + def __repr__(self): + # The usual Record repr looks like Table1[2] which may surprise the AI. + return "{}(id={})".format(self._inner._table.table_id, self._inner._row_id) + + +arepr = reprlib.Repr() +arepr.maxlevel = 3 +arepr.maxtuple = 3 +arepr.maxlist = 3 +arepr.maxarray = 3 +arepr.maxdict = 4 +arepr.maxset = 3 +arepr.maxfrozenset = 3 +arepr.maxdeque = 3 +arepr.maxstring = 40 +arepr.maxlong = 20 +arepr.maxother = 60 + + +def safe_repr(x): + try: + return arepr.repr(x) + except Exception: + # Copied from Repr.repr_instance in Python 3. + return '<%s instance at %#x>' % (x.__class__.__name__, id(x)) diff --git a/sandbox/grist/engine.py b/sandbox/grist/engine.py index 15b70de8..1e2dae2e 100644 --- a/sandbox/grist/engine.py +++ b/sandbox/grist/engine.py @@ -20,6 +20,7 @@ from sortedcontainers import SortedSet import acl import actions import action_obj +from attribute_recorder import AttributeRecorder from autocomplete_context import AutocompleteContext, lookup_autocomplete_options, eval_suggestion from codebuilder import DOLLAR_REGEX import depend @@ -694,22 +695,27 @@ class Engine(object): not recomputing the whole column and dependent columns as well. So it recomputes the formula for this cell and returns error message with details. """ + result = self.get_formula_value(table_id, col_id, row_id) + table = self.tables[table_id] + col = table.get_column(col_id) + # If the error is gone for a trigger formula + if col.has_formula() and not col.is_formula(): + if not isinstance(result, objtypes.RaisedException): + # Get the error stored in the cell + # and change it to show to the user that no traceback is available + error_in_cell = objtypes.decode_object(col.raw_get(row_id)) + assert isinstance(error_in_cell, objtypes.RaisedException) + return error_in_cell.no_traceback() + return result + + def get_formula_value(self, table_id, col_id, row_id, record_attributes=None): table = self.tables[table_id] col = table.get_column(col_id) checkpoint = self._get_undo_checkpoint() # Makes calls to REQUEST synchronous, since raising a RequestingError can't work here. self._sync_request = True try: - result = self._recompute_one_cell(table, col, row_id) - # If the error is gone for a trigger formula - if col.has_formula() and not col.is_formula(): - if not isinstance(result, objtypes.RaisedException): - # Get the error stored in the cell - # and change it to show to the user that no traceback is available - error_in_cell = objtypes.decode_object(col.raw_get(row_id)) - assert isinstance(error_in_cell, objtypes.RaisedException) - return error_in_cell.no_traceback() - return result + return self._recompute_one_cell(table, col, row_id, record_attributes=record_attributes) finally: # It is possible for formula evaluation to have side-effects that produce DocActions (e.g. # lookupOrAddDerived() creates those). In case of get_formula_error(), these aren't fully @@ -920,7 +926,7 @@ class Engine(object): raise RequestingError() - def _recompute_one_cell(self, table, col, row_id, cycle=False, node=None): + def _recompute_one_cell(self, table, col, row_id, cycle=False, node=None, record_attributes=None): """ Recomputes an one formula cell and returns a value. The value can be: @@ -939,6 +945,11 @@ class Engine(object): checkpoint = self._get_undo_checkpoint() record = table.Record(row_id, table._identity_relation) + if record_attributes is not None: + assert isinstance(record_attributes, dict) + assert col.is_formula() + assert not cycle + record = AttributeRecorder(record, "rec", record_attributes) value = None try: if cycle: diff --git a/sandbox/grist/formula_prompt.py b/sandbox/grist/formula_prompt.py index c3125135..63308126 100644 --- a/sandbox/grist/formula_prompt.py +++ b/sandbox/grist/formula_prompt.py @@ -7,6 +7,9 @@ import asttokens import asttokens.util import six +import attribute_recorder +import objtypes +from codebuilder import make_formula_body from column import is_visible_column, BaseReferenceColumn from objtypes import RaisedException import records @@ -255,3 +258,25 @@ def convert_completion(completion): result = asttokens.util.replace(result, replacements) return result.strip() + + +def evaluate_formula(engine, table_id, col_id, row_id): + grist_formula = engine.docmodel.get_column_rec(table_id, col_id).formula + assert grist_formula + plain_formula = make_formula_body(grist_formula, default_value=None).get_text() + + attributes = {} + result = engine.get_formula_value(table_id, col_id, row_id, record_attributes=attributes) + if isinstance(result, objtypes.RaisedException): + name, message = result.encode_args()[:2] + result = "%s: %s" % (name, message) + error = True + else: + result = attribute_recorder.safe_repr(result) + error = False + return dict( + error=error, + formula=plain_formula, + result=result, + attributes=attributes, + ) diff --git a/sandbox/grist/main.py b/sandbox/grist/main.py index 03968120..df9c6443 100644 --- a/sandbox/grist/main.py +++ b/sandbox/grist/main.py @@ -153,6 +153,10 @@ def run(sandbox): def convert_formula_completion(completion): return formula_prompt.convert_completion(completion) + @export + def evaluate_formula(table_id, col_id, row_id): + return formula_prompt.evaluate_formula(eng, table_id, col_id, row_id) + export(parse_acl_formula) export(eng.load_empty) export(eng.load_done) diff --git a/test/formula-dataset/runCompletion_impl.ts b/test/formula-dataset/runCompletion_impl.ts index 930ae6e5..2eb92594 100644 --- a/test/formula-dataset/runCompletion_impl.ts +++ b/test/formula-dataset/runCompletion_impl.ts @@ -64,8 +64,8 @@ const _stats = { callCount: 0, }; -const SEEMS_CHATTY = (process.env.COMPLETION_MODEL || '').includes('turbo'); -const SIMULATE_CONVERSATION = SEEMS_CHATTY; +const SIMULATE_CONVERSATION = true; +const FOLLOWUP_EVALUATE = false; export async function runCompletion() { ActiveDocDeps.ACTIVEDOC_TIMEOUT = 600; @@ -132,40 +132,48 @@ export async function runCompletion() { let success: boolean = false; let suggestedActions: AssistanceResponse['suggestedActions'] | undefined; let newValues: CellValue[] | undefined; - let expected: CellValue[] | undefined; let formula: string | undefined; let history: AssistanceState = {messages: []}; let lastFollowUp: string | undefined; - try { - async function sendMessage(followUp?: string) { - // load new document - if (!activeDoc || activeDoc.docName !== rec.doc_id) { - const docPath = path.join(PATH_TO_DOC, rec.doc_id + '.grist'); - activeDoc = await docTools.loadLocalDoc(docPath); - await activeDoc.waitForInitialization(); - } + // load new document + if (!activeDoc || activeDoc.docName !== rec.doc_id) { + const docPath = path.join(PATH_TO_DOC, rec.doc_id + '.grist'); + activeDoc = await docTools.loadLocalDoc(docPath); + await activeDoc.waitForInitialization(); + } - if (!activeDoc) { throw new Error("No doc"); } + // get values + await activeDoc.docData!.fetchTable(rec.table_id); + const expected = activeDoc.docData!.getTable(rec.table_id)!.getColValues(rec.col_id)!.slice(); - // get values - await activeDoc.docData!.fetchTable(rec.table_id); - expected = activeDoc.docData!.getTable(rec.table_id)!.getColValues(rec.col_id)!.slice(); + async function sendMessage(followUp?: string, rowId?: number) { + if (!activeDoc) { + throw new Error("No doc"); + } - // send prompt - const tableId = rec.table_id; - const colId = rec.col_id; - const description = rec.Description; - const colInfo = await activeDoc.docStorage.get(` -select * from _grist_Tables_column as c -left join _grist_Tables as t on t.id = c.parentId -where c.colId = ? and t.tableId = ? -`, rec.col_id, rec.table_id); + // send prompt + const tableId = rec.table_id; + const colId = rec.col_id; + const description = rec.Description; + const colInfo = await activeDoc.docStorage.get(` + select * + from _grist_Tables_column as c + left join _grist_Tables as t on t.id = c.parentId + where c.colId = ? + and t.tableId = ? + `, rec.col_id, rec.table_id); formula = colInfo?.formula; const result = await sendForCompletion(session, activeDoc, { conversationId: 'conversationId', - context: {type: 'formula', tableId, colId}, + context: { + type: 'formula', + tableId, + colId, + evaluateCurrentFormula: Boolean(followUp) && FOLLOWUP_EVALUATE, + rowId, + }, state: history, text: followUp || description, }); @@ -174,63 +182,50 @@ where c.colId = ? and t.tableId = ? } if (rec.no_formula == "1") { success = result.suggestedActions.length === 0; - return null; + return; } suggestedActions = result.suggestedActions; + if (!suggestedActions.length) { + success = false; + return; + } + // apply modification const {actionNum} = await activeDoc.applyUserActions(session, suggestedActions); // get new values newValues = activeDoc.docData!.getTable(rec.table_id)!.getColValues(rec.col_id)!.slice(); - // revert modification - const [bundle] = await activeDoc.getActions([actionNum]); - await activeDoc.applyUserActionsById(session, [bundle!.actionNum], [bundle!.actionHash!], true); - // compare values success = isEqual(expected, newValues); - if (!success) { - const rowIds = activeDoc.docData!.getTable(rec.table_id)!.getRowIds(); - const result = await activeDoc.getFormulaError({client: null, mode: 'system'} as any, rec.table_id, - rec.col_id, rowIds[0]); - if (Array.isArray(result) && result[0] === 'E') { - result.shift(); - const txt = `I got a \`${result.shift()}\` error:\n` + - '```\n' + - result.shift() + '\n' + - result.shift() + '\n' + - '```\n' + - 'Please answer with the code block you (the assistant) just gave, ' + - 'revised based on this error. Your answer must include a code block. ' + - 'If you have to explain anything, do it after. ' + - 'It is perfectly acceptable (and may be necessary) to do ' + - 'imports from within a method body.\n'; - return { followUp: txt }; - } else { - for (let i = 0; i < expected.length; i++) { - const e = expected[i]; - const v = newValues[i]; - if (String(e) !== String(v)) { - const txt = `I got \`${v}\` where I expected \`${e}\`\n` + - 'Please answer with the code block you (the assistant) just gave, ' + - 'revised based on this information. Your answer must include a code ' + - 'block. If you have to explain anything, do it after.\n'; - return { followUp: txt }; + if (!success && SIMULATE_CONVERSATION) { + for (let i = 0; i < expected.length; i++) { + const e = expected[i]; + const v = newValues[i]; + if (String(e) !== String(v)) { + const txt = `I got \`${v}\` where I expected \`${e}\`\n` + + 'Please answer with the code block you (the assistant) just gave, ' + + 'revised based on this information. Your answer must include a code ' + + 'block. If you have to explain anything, do it after.\n'; + const rowIds = activeDoc.docData!.getTable(rec.table_id)!.getRowIds(); + const rowId = rowIds[i]; + if (followUp) { + lastFollowUp = txt; + } else { + await sendMessage(txt, rowId); } + break; } } } - return null; - } - const result = await sendMessage(); - if (result?.followUp && SIMULATE_CONVERSATION) { - // Allow one follow up message, based on error or differences. - const result2 = await sendMessage(result.followUp); - if (result2?.followUp) { - lastFollowUp = result2.followUp; - } - } + // revert modification + const [bundle] = await activeDoc.getActions([actionNum]); + await activeDoc.applyUserActionsById(session, [bundle!.actionNum], [bundle!.actionHash!], true); + } + + try { + await sendMessage(); } catch (e) { console.error(e); }