(core) Allow assistant to evaluate current formula

Summary:
Replaces https://phab.getgrist.com/D3940, particularly to avoid doing potentially unwanted things automatically.

Adds optional fields `evaluateCurrentFormula?: boolean; rowId?: number` to `FormulaAssistanceContext` (part of `AssistanceRequest`). When `evaluateCurrentFormula` is `true`, the assistant calls a new sandbox function, `evaluate_formula`, which computes the existing formula in the column for the given `rowId` (regardless of anything the AI may have suggested). The evaluation result is used to generate an additional system message, which is added before the user's message. In theory, this could be used in an interface where users ask why a formula doesn't work, possibly including a formula suggested by the AI. For now, it's only used in `runCompletion_impl.ts` for experimentation.

Also cleaned up a bit, removing `_chatMode`, which is always `true` now, and uses of `regenerate`, which is always `false`.

Test Plan: Updated `runCompletion_impl` to optionally use the new feature, in which case it now scores 51/68 instead of 49/68.

Reviewers: paulfitz

Reviewed By: paulfitz

Differential Revision: https://phab.getgrist.com/D3970
pull/590/head
Alex Hall 10 months ago
parent 14b14f116e
commit 391c8ee087

@ -27,6 +27,8 @@ export interface FormulaAssistanceContext {
type: 'formula';
tableId: string;
colId: string;
evaluateCurrentFormula?: boolean;
rowId?: number;
}
export type AssistanceContext = FormulaAssistanceContext;
@ -39,10 +41,8 @@ export interface AssistanceRequest {
context: AssistanceContext;
state?: AssistanceState;
text: string;
regenerate?: boolean; // Set if there was a previous request
// and response that should be omitted
// from history, or (if available) an
// alternative response generated.
// TODO this is no longer used and should be removed
regenerate?: boolean;
}
/**

@ -85,6 +85,7 @@ import {ParseFileResult, ParseOptions} from 'app/plugin/FileParserAPI';
import {AccessTokenOptions, AccessTokenResult, GristDocAPI} from 'app/plugin/GristAPI';
import {compileAclFormula} from 'app/server/lib/ACLFormula';
import {AssistanceSchemaPromptV1Context} from 'app/server/lib/Assistance';
import {AssistanceContext} from 'app/common/AssistancePrompts';
import {Authorizer} from 'app/server/lib/Authorizer';
import {checksumFile} from 'app/server/lib/checksumFile';
import {Client} from 'app/server/lib/Client';
@ -1289,6 +1290,16 @@ export class ActiveDoc extends EventEmitter {
return this._pyCall('convert_formula_completion', txt);
}
/**
 * Computes the existing formula of a column in the sandbox and returns the result
 * along with recorded values of (possibly nested) attributes of `rec`.
 * Used by AI assistance to fix an incorrect formula.
 *
 * @param options - Assistance context; `tableId`, `colId` and `rowId` are forwarded
 *   to the sandbox function `evaluate_formula`.
 * @returns Whatever the `evaluate_formula` sandbox call resolves to.
 * @throws Error if `options.evaluateCurrentFormula` is not truthy.
 */
public assistanceEvaluateFormula(options: AssistanceContext) {
  const {evaluateCurrentFormula, tableId, colId, rowId} = options;
  if (evaluateCurrentFormula) {
    return this._pyCall('evaluate_formula', tableId, colId, rowId);
  }
  // Callers must opt in explicitly; anything else is a programming error.
  throw new Error('evaluateCurrentFormula must be true');
}
public fetchURL(docSession: DocSession, url: string, options?: FetchUrlOptions): Promise<UploadResult> {
return fetchURL(url, this.makeAccessId(docSession.authorizer.getUserId()), options);
}

@ -2,7 +2,12 @@
* Module with functions used for AI formula assistance.
*/
import {AssistanceMessage, AssistanceRequest, AssistanceResponse} from 'app/common/AssistancePrompts';
import {
AssistanceContext,
AssistanceMessage,
AssistanceRequest,
AssistanceResponse
} from 'app/common/AssistancePrompts';
import {delay} from 'app/common/delay';
import {DocAction} from 'app/common/DocActions';
import {ActiveDoc} from 'app/server/lib/ActiveDoc';
@ -37,10 +42,29 @@ interface AssistanceDoc extends ActiveDoc {
* be great to try variants.
*/
assistanceSchemaPromptV1(session: OptDocSession, options: AssistanceSchemaPromptV1Context): Promise<string>;
/**
* Some tweaks to a formula after it has been generated.
*/
assistanceFormulaTweak(txt: string): Promise<string>;
/**
* Compute the existing formula and return the result along with recorded values
* of (possibly nested) attributes of `rec`.
* Used by AI assistance to fix an incorrect formula.
*/
assistanceEvaluateFormula(options: AssistanceContext): Promise<AssistanceFormulaEvaluationResult>;
}
/**
 * Result of evaluating a column's existing formula for AI assistance.
 */
export interface AssistanceFormulaEvaluationResult {
  /** The code that was evaluated, without special grist syntax. */
  formula: string;

  /**
   * Recorded attributes of `rec` at the time of evaluation.
   * Keys may be e.g. "rec.foo.bar" for nested attributes.
   */
  attributes: Record<string, string>;

  /** repr of the return value OR exception message. */
  result: string;

  /** true if an exception was raised. */
  error: boolean;
}
export interface AssistanceSchemaPromptV1Context {
@ -101,7 +125,6 @@ export class OpenAIAssistant implements Assistant {
public static LONGER_CONTEXT_MODEL = "gpt-3.5-turbo-16k-0613";
private _apiKey: string;
private _chatMode: boolean;
private _endpoint: string;
public constructor() {
@ -110,60 +133,52 @@ export class OpenAIAssistant implements Assistant {
throw new Error('OPENAI_API_KEY not set');
}
this._apiKey = apiKey;
this._chatMode = true;
if (!this._chatMode) {
throw new Error('Only turbo models are currently supported');
}
this._endpoint = `https://api.openai.com/v1/${this._chatMode ? 'chat/' : ''}completions`;
this._endpoint = `https://api.openai.com/v1/chat/completions`;
}
public async apply(
optSession: OptDocSession, doc: AssistanceDoc, request: AssistanceRequest): Promise<AssistanceResponse> {
const messages = request.state?.messages || [];
const newMessages = [];
const chatMode = this._chatMode;
if (chatMode) {
if (messages.length === 0) {
newMessages.push({
role: 'system',
content: 'You are a helpful assistant for a user of software called Grist. ' +
'Below are one or more Python classes. ' +
'The last method needs completing. ' +
"The user will probably give a description of what they want the method (a 'formula') to return. " +
'If so, your response should include the method body as Python code in a markdown block. ' +
'Do not include the class or method signature, just the method body. ' +
'If your code starts with `class`, `@dataclass`, or `def` it will fail. Only give the method body. ' +
'You can import modules inside the method body if needed. ' +
'You cannot define additional functions or methods. ' +
'The method should be a pure function that performs some computation and returns a result. ' +
'It CANNOT perform any side effects such as adding/removing/modifying rows/columns/cells/tables/etc. ' +
'It CANNOT interact with files/databases/networks/etc. ' +
'It CANNOT display images/charts/graphs/maps/etc. ' +
'If the user asks for these things, tell them that you cannot help. ' +
'The method uses `rec` instead of `self` as the first parameter.\n\n' +
'```python\n' +
await makeSchemaPromptV1(optSession, doc, request) +
'\n```',
});
newMessages.push({
role: 'user', content: request.text,
});
} else {
if (request.regenerate) {
if (messages[messages.length - 1].role !== 'user') {
messages.pop();
}
}
newMessages.push({
role: 'user', content: request.text,
});
if (messages.length === 0) {
newMessages.push({
role: 'system',
content: 'You are a helpful assistant for a user of software called Grist. ' +
'Below are one or more Python classes. ' +
'The last method needs completing. ' +
"The user will probably give a description of what they want the method (a 'formula') to return. " +
'If so, your response should include the method body as Python code in a markdown block. ' +
'Do not include the class or method signature, just the method body. ' +
'If your code starts with `class`, `@dataclass`, or `def` it will fail. Only give the method body. ' +
'You can import modules inside the method body if needed. ' +
'You cannot define additional functions or methods. ' +
'The method should be a pure function that performs some computation and returns a result. ' +
'It CANNOT perform any side effects such as adding/removing/modifying rows/columns/cells/tables/etc. ' +
'It CANNOT interact with files/databases/networks/etc. ' +
'It CANNOT display images/charts/graphs/maps/etc. ' +
'If the user asks for these things, tell them that you cannot help. ' +
'The method uses `rec` instead of `self` as the first parameter.\n\n' +
'```python\n' +
await makeSchemaPromptV1(optSession, doc, request) +
'\n```',
});
}
if (request.context.evaluateCurrentFormula) {
const result = await doc.assistanceEvaluateFormula(request.context);
let message = "Evaluating this code:\n\n```python\n" + result.formula + "\n```\n\n";
if (Object.keys(result.attributes).length > 0) {
const attributes = Object.entries(result.attributes).map(([k, v]) => `${k} = ${v}`).join('\n');
message += `where:\n\n${attributes}\n\n`;
}
} else {
messages.length = 0;
message += `${result.error ? 'raises an exception' : 'returns'}: ${result.result}`;
newMessages.push({
role: 'user', content: await makeSchemaPromptV1(optSession, doc, request),
role: 'system',
content: message,
});
}
newMessages.push({
role: 'user', content: request.text,
});
messages.push(...newMessages);
const newMessagesStartIndex = messages.length - newMessages.length;
@ -184,9 +199,7 @@ export class OpenAIAssistant implements Assistant {
const userIdHash = getUserHash(optSession);
const completion: string = await this._getCompletion(messages, userIdHash);
const response = await completionToResponse(doc, request, completion, completion);
if (chatMode) {
response.state = {messages};
}
response.state = {messages};
doc.logTelemetryEvent(optSession, 'assistantReceive', {
full: {
conversationId: request.conversationId,
@ -211,12 +224,9 @@ export class OpenAIAssistant implements Assistant {
"Content-Type": "application/json",
},
body: JSON.stringify({
...(!this._chatMode ? {
prompt: messages[messages.length - 1].content,
} : {messages}),
messages,
temperature: 0,
model: longerContext ? OpenAIAssistant.LONGER_CONTEXT_MODEL : OpenAIAssistant.DEFAULT_MODEL,
stop: this._chatMode ? undefined : ["\n\n"],
user: userIdHash,
}),
},
@ -267,11 +277,9 @@ export class OpenAIAssistant implements Assistant {
private async _getCompletion(messages: AssistanceMessage[], userIdHash: string) {
const result = await this._fetchCompletionWithRetries(messages, userIdHash, false);
const completion: string = String(this._chatMode ? result.choices[0].message.content : result.choices[0].text);
if (this._chatMode) {
messages.push(result.choices[0].message);
}
return completion;
const {message} = result.choices[0];
messages.push(message);
return message.content;
}
}
@ -404,6 +412,9 @@ export async function sendForCompletion(
doc: AssistanceDoc,
request: AssistanceRequest,
): Promise<AssistanceResponse> {
if (request.regenerate) {
throw new Error('regenerate no longer supported');
}
const assistant = getAssistant();
return await assistant.apply(optSession, doc, request);
}

@ -0,0 +1,59 @@
from six.moves import reprlib
import records
class AttributeRecorder(object):
  """
  Wrapper around a Record that records attribute accesses.
  Used to generate a prompt for the AI with basic 'debugging' info.

  `inner` is the wrapped `records.Record`, `name` is the expression prefix used for
  recorded keys (e.g. "rec"), and `attributes` is a dict, shared with nested recorders,
  mapping each accessed expression (e.g. "rec.foo.bar") to the repr of the value that
  was first seen for it.
  """
  def __init__(self, inner, name, attributes):
    assert isinstance(inner, records.Record)
    self._inner = inner
    self._name = name
    self._attributes = attributes

  def __getattr__(self, name):
    """
    Record attribute access.
    If the result is a Record or RecordSet, wrap that with AttributeRecorder
    to also record nested attribute values.

    Raises AttributeError if `name` is one of this wrapper's own fields, or if the
    wrapped record raises it.
    """
    # __getattr__ only runs when normal lookup fails, so being asked for one of our own
    # fields means the instance was created without __init__ (as copy/pickle do when they
    # probe for hooks such as __setstate__). Fail fast instead of recursing forever
    # through `self._inner` below.
    if name in ("_inner", "_name", "_attributes"):
      raise AttributeError(name)

    result = getattr(self._inner, name)
    full_name = "{}.{}".format(self._name, name)
    if isinstance(result, records.Record):
      result = AttributeRecorder(result, full_name, self._attributes)
    elif isinstance(result, records.RecordSet):
      # Use a tuple to imply immutability so that the AI doesn't try appending.
      # Don't try recording attributes of all contained records, just record the first access.
      # Pretend that the attribute is always accessed from the first record for simplicity.
      result = tuple(AttributeRecorder(r, full_name + "[0]", self._attributes) for r in result)
    # setdefault keeps the value from the first access to this expression.
    self._attributes.setdefault(full_name, safe_repr(result))
    return result

  def __repr__(self):
    # The usual Record repr looks like Table1[2] which may surprise the AI.
    return "{}(id={})".format(self._inner._table.table_id, self._inner._row_id)
# Shared size-limited repr, used by safe_repr() to keep recorded values short.
arepr = reprlib.Repr()
for _limit_name, _limit in (
  ("maxlevel", 3),
  ("maxtuple", 3),
  ("maxlist", 3),
  ("maxarray", 3),
  ("maxdict", 4),
  ("maxset", 3),
  ("maxfrozenset", 3),
  ("maxdeque", 3),
  ("maxstring", 40),
  ("maxlong", 20),
  ("maxother", 60),
):
  setattr(arepr, _limit_name, _limit)
# Don't leave the loop variables behind as module attributes.
del _limit_name, _limit
def safe_repr(x):
  """
  Return the size-limited repr of `x` produced by `arepr`.
  If computing that repr raises, fall back to a generic '<ClassName instance at 0x...>' string.
  """
  try:
    text = arepr.repr(x)
  except Exception:
    # Copied from Repr.repr_instance in Python 3.
    text = '<%s instance at %#x>' % (x.__class__.__name__, id(x))
  return text

@ -20,6 +20,7 @@ from sortedcontainers import SortedSet
import acl
import actions
import action_obj
from attribute_recorder import AttributeRecorder
from autocomplete_context import AutocompleteContext, lookup_autocomplete_options, eval_suggestion
from codebuilder import DOLLAR_REGEX
import depend
@ -694,22 +695,27 @@ class Engine(object):
not recomputing the whole column and dependent columns as well. So it recomputes the formula
for this cell and returns error message with details.
"""
result = self.get_formula_value(table_id, col_id, row_id)
table = self.tables[table_id]
col = table.get_column(col_id)
# If the error is gone for a trigger formula
if col.has_formula() and not col.is_formula():
if not isinstance(result, objtypes.RaisedException):
# Get the error stored in the cell
# and change it to show to the user that no traceback is available
error_in_cell = objtypes.decode_object(col.raw_get(row_id))
assert isinstance(error_in_cell, objtypes.RaisedException)
return error_in_cell.no_traceback()
return result
def get_formula_value(self, table_id, col_id, row_id, record_attributes=None):
table = self.tables[table_id]
col = table.get_column(col_id)
checkpoint = self._get_undo_checkpoint()
# Makes calls to REQUEST synchronous, since raising a RequestingError can't work here.
self._sync_request = True
try:
result = self._recompute_one_cell(table, col, row_id)
# If the error is gone for a trigger formula
if col.has_formula() and not col.is_formula():
if not isinstance(result, objtypes.RaisedException):
# Get the error stored in the cell
# and change it to show to the user that no traceback is available
error_in_cell = objtypes.decode_object(col.raw_get(row_id))
assert isinstance(error_in_cell, objtypes.RaisedException)
return error_in_cell.no_traceback()
return result
return self._recompute_one_cell(table, col, row_id, record_attributes=record_attributes)
finally:
# It is possible for formula evaluation to have side-effects that produce DocActions (e.g.
# lookupOrAddDerived() creates those). In case of get_formula_error(), these aren't fully
@ -920,7 +926,7 @@ class Engine(object):
raise RequestingError()
def _recompute_one_cell(self, table, col, row_id, cycle=False, node=None):
def _recompute_one_cell(self, table, col, row_id, cycle=False, node=None, record_attributes=None):
"""
Recomputes a single formula cell and returns a value.
The value can be:
@ -939,6 +945,11 @@ class Engine(object):
checkpoint = self._get_undo_checkpoint()
record = table.Record(row_id, table._identity_relation)
if record_attributes is not None:
assert isinstance(record_attributes, dict)
assert col.is_formula()
assert not cycle
record = AttributeRecorder(record, "rec", record_attributes)
value = None
try:
if cycle:

@ -7,6 +7,9 @@ import asttokens
import asttokens.util
import six
import attribute_recorder
import objtypes
from codebuilder import make_formula_body
from column import is_visible_column, BaseReferenceColumn
from objtypes import RaisedException
import records
@ -255,3 +258,25 @@ def convert_completion(completion):
result = asttokens.util.replace(result, replacements)
return result.strip()
def evaluate_formula(engine, table_id, col_id, row_id):
  """
  Evaluate the existing formula of column `col_id` in table `table_id` for row `row_id`.

  Returns a dict with:
    error: True if evaluation produced a RaisedException, else False.
    formula: text of the formula body built by make_formula_body().
    result: "<name>: <message>" for an exception, otherwise safe_repr() of the value.
    attributes: dict filled in by engine.get_formula_value() via `record_attributes`.

  Asserts that the column has a non-empty formula.
  """
  grist_formula = engine.docmodel.get_column_rec(table_id, col_id).formula
  assert grist_formula
  plain_formula = make_formula_body(grist_formula, default_value=None).get_text()

  recorded = {}
  value = engine.get_formula_value(table_id, col_id, row_id, record_attributes=recorded)

  failed = isinstance(value, objtypes.RaisedException)
  if failed:
    # Only the first two encoded args (exception name and message) go into the summary.
    exc_name, exc_message = value.encode_args()[:2]
    summary = "%s: %s" % (exc_name, exc_message)
  else:
    summary = attribute_recorder.safe_repr(value)

  return {
    "error": failed,
    "formula": plain_formula,
    "result": summary,
    "attributes": recorded,
  }

@ -153,6 +153,10 @@ def run(sandbox):
def convert_formula_completion(completion):
return formula_prompt.convert_completion(completion)
# Registered via @export under the name `evaluate_formula`; forwards its arguments,
# together with the engine instance `eng`, to formula_prompt.evaluate_formula().
@export
def evaluate_formula(table_id, col_id, row_id):
return formula_prompt.evaluate_formula(eng, table_id, col_id, row_id)
export(parse_acl_formula)
export(eng.load_empty)
export(eng.load_done)

@ -64,8 +64,8 @@ const _stats = {
callCount: 0,
};
const SEEMS_CHATTY = (process.env.COMPLETION_MODEL || '').includes('turbo');
const SIMULATE_CONVERSATION = SEEMS_CHATTY;
const SIMULATE_CONVERSATION = true;
const FOLLOWUP_EVALUATE = false;
export async function runCompletion() {
ActiveDocDeps.ACTIVEDOC_TIMEOUT = 600;
@ -132,40 +132,48 @@ export async function runCompletion() {
let success: boolean = false;
let suggestedActions: AssistanceResponse['suggestedActions'] | undefined;
let newValues: CellValue[] | undefined;
let expected: CellValue[] | undefined;
let formula: string | undefined;
let history: AssistanceState = {messages: []};
let lastFollowUp: string | undefined;
try {
async function sendMessage(followUp?: string) {
// load new document
if (!activeDoc || activeDoc.docName !== rec.doc_id) {
const docPath = path.join(PATH_TO_DOC, rec.doc_id + '.grist');
activeDoc = await docTools.loadLocalDoc(docPath);
await activeDoc.waitForInitialization();
}
// load new document
if (!activeDoc || activeDoc.docName !== rec.doc_id) {
const docPath = path.join(PATH_TO_DOC, rec.doc_id + '.grist');
activeDoc = await docTools.loadLocalDoc(docPath);
await activeDoc.waitForInitialization();
}
// get values
await activeDoc.docData!.fetchTable(rec.table_id);
const expected = activeDoc.docData!.getTable(rec.table_id)!.getColValues(rec.col_id)!.slice();
if (!activeDoc) { throw new Error("No doc"); }
// get values
await activeDoc.docData!.fetchTable(rec.table_id);
expected = activeDoc.docData!.getTable(rec.table_id)!.getColValues(rec.col_id)!.slice();
// send prompt
const tableId = rec.table_id;
const colId = rec.col_id;
const description = rec.Description;
const colInfo = await activeDoc.docStorage.get(`
select * from _grist_Tables_column as c
left join _grist_Tables as t on t.id = c.parentId
where c.colId = ? and t.tableId = ?
`, rec.col_id, rec.table_id);
async function sendMessage(followUp?: string, rowId?: number) {
if (!activeDoc) {
throw new Error("No doc");
}
// send prompt
const tableId = rec.table_id;
const colId = rec.col_id;
const description = rec.Description;
const colInfo = await activeDoc.docStorage.get(`
select *
from _grist_Tables_column as c
left join _grist_Tables as t on t.id = c.parentId
where c.colId = ?
and t.tableId = ?
`, rec.col_id, rec.table_id);
formula = colInfo?.formula;
const result = await sendForCompletion(session, activeDoc, {
conversationId: 'conversationId',
context: {type: 'formula', tableId, colId},
context: {
type: 'formula',
tableId,
colId,
evaluateCurrentFormula: Boolean(followUp) && FOLLOWUP_EVALUATE,
rowId,
},
state: history,
text: followUp || description,
});
@ -174,63 +182,50 @@ where c.colId = ? and t.tableId = ?
}
if (rec.no_formula == "1") {
success = result.suggestedActions.length === 0;
return null;
return;
}
suggestedActions = result.suggestedActions;
if (!suggestedActions.length) {
success = false;
return;
}
// apply modification
const {actionNum} = await activeDoc.applyUserActions(session, suggestedActions);
// get new values
newValues = activeDoc.docData!.getTable(rec.table_id)!.getColValues(rec.col_id)!.slice();
// revert modification
const [bundle] = await activeDoc.getActions([actionNum]);
await activeDoc.applyUserActionsById(session, [bundle!.actionNum], [bundle!.actionHash!], true);
// compare values
success = isEqual(expected, newValues);
if (!success) {
const rowIds = activeDoc.docData!.getTable(rec.table_id)!.getRowIds();
const result = await activeDoc.getFormulaError({client: null, mode: 'system'} as any, rec.table_id,
rec.col_id, rowIds[0]);
if (Array.isArray(result) && result[0] === 'E') {
result.shift();
const txt = `I got a \`${result.shift()}\` error:\n` +
'```\n' +
result.shift() + '\n' +
result.shift() + '\n' +
'```\n' +
'Please answer with the code block you (the assistant) just gave, ' +
'revised based on this error. Your answer must include a code block. ' +
'If you have to explain anything, do it after. ' +
'It is perfectly acceptable (and may be necessary) to do ' +
'imports from within a method body.\n';
return { followUp: txt };
} else {
for (let i = 0; i < expected.length; i++) {
const e = expected[i];
const v = newValues[i];
if (String(e) !== String(v)) {
const txt = `I got \`${v}\` where I expected \`${e}\`\n` +
'Please answer with the code block you (the assistant) just gave, ' +
'revised based on this information. Your answer must include a code ' +
'block. If you have to explain anything, do it after.\n';
return { followUp: txt };
if (!success && SIMULATE_CONVERSATION) {
for (let i = 0; i < expected.length; i++) {
const e = expected[i];
const v = newValues[i];
if (String(e) !== String(v)) {
const txt = `I got \`${v}\` where I expected \`${e}\`\n` +
'Please answer with the code block you (the assistant) just gave, ' +
'revised based on this information. Your answer must include a code ' +
'block. If you have to explain anything, do it after.\n';
const rowIds = activeDoc.docData!.getTable(rec.table_id)!.getRowIds();
const rowId = rowIds[i];
if (followUp) {
lastFollowUp = txt;
} else {
await sendMessage(txt, rowId);
}
break;
}
}
}
return null;
}
const result = await sendMessage();
if (result?.followUp && SIMULATE_CONVERSATION) {
// Allow one follow up message, based on error or differences.
const result2 = await sendMessage(result.followUp);
if (result2?.followUp) {
lastFollowUp = result2.followUp;
}
}
// revert modification
const [bundle] = await activeDoc.getActions([actionNum]);
await activeDoc.applyUserActionsById(session, [bundle!.actionNum], [bundle!.actionHash!], true);
}
try {
await sendMessage();
} catch (e) {
console.error(e);
}

Loading…
Cancel
Save