(core) Allow assistant to evaluate current formula

Summary:
Replaces https://phab.getgrist.com/D3940, particularly to avoid doing potentially unwanted things automatically.

Adds optional fields `evaluateCurrentFormula?: boolean; rowId?: number` to `FormulaAssistanceContext` (part of `AssistanceRequest`). When `evaluateCurrentFormula` is `true`, calls a new function `evaluate_formula` in the sandbox which computes the existing formula in the column (regardless of anything the AI may have suggested) and uses that to generate an additional system message which is added before the user's message. In theory this could be used in an interface where users ask why a formula doesn't work, including possibly a formula suggested by the AI. For now, it's only used in `runCompletion_impl.ts` for experimenting.

Also cleaned up a bit, removing `_chatMode` which is always `true` now, and uses of `regenerate` which is always `false`.

Test Plan: Updated `runCompletion_impl` to optionally use the new feature, in which case it now scores 51/68 instead of 49/68.

Reviewers: paulfitz

Reviewed By: paulfitz

Differential Revision: https://phab.getgrist.com/D3970
This commit is contained in:
Alex Hall 2023-07-24 20:56:38 +02:00
parent 14b14f116e
commit 391c8ee087
8 changed files with 257 additions and 141 deletions

View File

@ -27,6 +27,8 @@ export interface FormulaAssistanceContext {
type: 'formula';
tableId: string;
colId: string;
evaluateCurrentFormula?: boolean;
rowId?: number;
}
export type AssistanceContext = FormulaAssistanceContext;
@ -39,10 +41,8 @@ export interface AssistanceRequest {
context: AssistanceContext;
state?: AssistanceState;
text: string;
regenerate?: boolean; // Set if there was a previous request
// and response that should be omitted
// from history, or (if available) an
// alternative response generated.
// TODO this is no longer used and should be removed
regenerate?: boolean;
}
/**

View File

@ -85,6 +85,7 @@ import {ParseFileResult, ParseOptions} from 'app/plugin/FileParserAPI';
import {AccessTokenOptions, AccessTokenResult, GristDocAPI} from 'app/plugin/GristAPI';
import {compileAclFormula} from 'app/server/lib/ACLFormula';
import {AssistanceSchemaPromptV1Context} from 'app/server/lib/Assistance';
import {AssistanceContext} from 'app/common/AssistancePrompts';
import {Authorizer} from 'app/server/lib/Authorizer';
import {checksumFile} from 'app/server/lib/checksumFile';
import {Client} from 'app/server/lib/Client';
@ -1289,6 +1290,16 @@ export class ActiveDoc extends EventEmitter {
return this._pyCall('convert_formula_completion', txt);
}
// Callback to compute an existing formula and return the result along with recorded values
// of (possibly nested) attributes of `rec`.
// Used by AI assistance to fix an incorrect formula.
public assistanceEvaluateFormula(options: AssistanceContext) {
if (!options.evaluateCurrentFormula) {
throw new Error('evaluateCurrentFormula must be true');
}
return this._pyCall('evaluate_formula', options.tableId, options.colId, options.rowId);
}
public fetchURL(docSession: DocSession, url: string, options?: FetchUrlOptions): Promise<UploadResult> {
return fetchURL(url, this.makeAccessId(docSession.authorizer.getUserId()), options);
}

View File

@ -2,7 +2,12 @@
* Module with functions used for AI formula assistance.
*/
import {AssistanceMessage, AssistanceRequest, AssistanceResponse} from 'app/common/AssistancePrompts';
import {
AssistanceContext,
AssistanceMessage,
AssistanceRequest,
AssistanceResponse
} from 'app/common/AssistancePrompts';
import {delay} from 'app/common/delay';
import {DocAction} from 'app/common/DocActions';
import {ActiveDoc} from 'app/server/lib/ActiveDoc';
@ -37,10 +42,29 @@ interface AssistanceDoc extends ActiveDoc {
* be great to try variants.
*/
assistanceSchemaPromptV1(session: OptDocSession, options: AssistanceSchemaPromptV1Context): Promise<string>;
/**
* Some tweaks to a formula after it has been generated.
*/
assistanceFormulaTweak(txt: string): Promise<string>;
/**
* Compute the existing formula and return the result along with recorded values
* of (possibly nested) attributes of `rec`.
* Used by AI assistance to fix an incorrect formula.
*/
assistanceEvaluateFormula(options: AssistanceContext): Promise<AssistanceFormulaEvaluationResult>;
}
export interface AssistanceFormulaEvaluationResult {
error: boolean; // true if an exception was raised
result: string; // repr of the return value OR exception message
// Recorded attributes of `rec` at the time of evaluation.
// Keys may be e.g. "rec.foo.bar" for nested attributes.
attributes: Record<string, string>;
formula: string; // the code that was evaluated, without special grist syntax
}
export interface AssistanceSchemaPromptV1Context {
@ -101,7 +125,6 @@ export class OpenAIAssistant implements Assistant {
public static LONGER_CONTEXT_MODEL = "gpt-3.5-turbo-16k-0613";
private _apiKey: string;
private _chatMode: boolean;
private _endpoint: string;
public constructor() {
@ -110,19 +133,13 @@ export class OpenAIAssistant implements Assistant {
throw new Error('OPENAI_API_KEY not set');
}
this._apiKey = apiKey;
this._chatMode = true;
if (!this._chatMode) {
throw new Error('Only turbo models are currently supported');
}
this._endpoint = `https://api.openai.com/v1/${this._chatMode ? 'chat/' : ''}completions`;
this._endpoint = `https://api.openai.com/v1/chat/completions`;
}
public async apply(
optSession: OptDocSession, doc: AssistanceDoc, request: AssistanceRequest): Promise<AssistanceResponse> {
const messages = request.state?.messages || [];
const newMessages = [];
const chatMode = this._chatMode;
if (chatMode) {
if (messages.length === 0) {
newMessages.push({
role: 'system',
@ -145,25 +162,23 @@ export class OpenAIAssistant implements Assistant {
await makeSchemaPromptV1(optSession, doc, request) +
'\n```',
});
newMessages.push({
role: 'user', content: request.text,
});
} else {
if (request.regenerate) {
if (messages[messages.length - 1].role !== 'user') {
messages.pop();
}
if (request.context.evaluateCurrentFormula) {
const result = await doc.assistanceEvaluateFormula(request.context);
let message = "Evaluating this code:\n\n```python\n" + result.formula + "\n```\n\n";
if (Object.keys(result.attributes).length > 0) {
const attributes = Object.entries(result.attributes).map(([k, v]) => `${k} = ${v}`).join('\n');
message += `where:\n\n${attributes}\n\n`;
}
message += `${result.error ? 'raises an exception' : 'returns'}: ${result.result}`;
newMessages.push({
role: 'system',
content: message,
});
}
newMessages.push({
role: 'user', content: request.text,
});
}
} else {
messages.length = 0;
newMessages.push({
role: 'user', content: await makeSchemaPromptV1(optSession, doc, request),
});
}
messages.push(...newMessages);
const newMessagesStartIndex = messages.length - newMessages.length;
@ -184,9 +199,7 @@ export class OpenAIAssistant implements Assistant {
const userIdHash = getUserHash(optSession);
const completion: string = await this._getCompletion(messages, userIdHash);
const response = await completionToResponse(doc, request, completion, completion);
if (chatMode) {
response.state = {messages};
}
doc.logTelemetryEvent(optSession, 'assistantReceive', {
full: {
conversationId: request.conversationId,
@ -211,12 +224,9 @@ export class OpenAIAssistant implements Assistant {
"Content-Type": "application/json",
},
body: JSON.stringify({
...(!this._chatMode ? {
prompt: messages[messages.length - 1].content,
} : {messages}),
messages,
temperature: 0,
model: longerContext ? OpenAIAssistant.LONGER_CONTEXT_MODEL : OpenAIAssistant.DEFAULT_MODEL,
stop: this._chatMode ? undefined : ["\n\n"],
user: userIdHash,
}),
},
@ -267,11 +277,9 @@ export class OpenAIAssistant implements Assistant {
private async _getCompletion(messages: AssistanceMessage[], userIdHash: string) {
const result = await this._fetchCompletionWithRetries(messages, userIdHash, false);
const completion: string = String(this._chatMode ? result.choices[0].message.content : result.choices[0].text);
if (this._chatMode) {
messages.push(result.choices[0].message);
}
return completion;
const {message} = result.choices[0];
messages.push(message);
return message.content;
}
}
@ -404,6 +412,9 @@ export async function sendForCompletion(
doc: AssistanceDoc,
request: AssistanceRequest,
): Promise<AssistanceResponse> {
if (request.regenerate) {
throw new Error('regenerate no longer supported');
}
const assistant = getAssistant();
return await assistant.apply(optSession, doc, request);
}

View File

@ -0,0 +1,59 @@
from six.moves import reprlib
import records
class AttributeRecorder(object):
"""
Wrapper around a Record that records attribute accesses.
Used to generate a prompt for the AI with basic 'debugging' info.
"""
def __init__(self, inner, name, attributes):
assert isinstance(inner, records.Record)
self._inner = inner
self._name = name
self._attributes = attributes
def __getattr__(self, name):
"""
Record attribute access.
If the result is a Record or RecordSet, wrap that with AttributeRecorder
to also record nested attribute values.
"""
result = getattr(self._inner, name)
full_name = "{}.{}".format(self._name, name)
if isinstance(result, records.Record):
result = AttributeRecorder(result, full_name, self._attributes)
elif isinstance(result, records.RecordSet):
# Use a tuple to imply immutability so that the AI doesn't try appending.
# Don't try recording attributes of all contained records, just record the first access.
# Pretend that the attribute is always accessed from the first record for simplicity.
result = tuple(AttributeRecorder(r, full_name + "[0]", self._attributes) for r in result)
self._attributes.setdefault(full_name, safe_repr(result))
return result
def __repr__(self):
# The usual Record repr looks like Table1[2] which may surprise the AI.
return "{}(id={})".format(self._inner._table.table_id, self._inner._row_id)
arepr = reprlib.Repr()
arepr.maxlevel = 3
arepr.maxtuple = 3
arepr.maxlist = 3
arepr.maxarray = 3
arepr.maxdict = 4
arepr.maxset = 3
arepr.maxfrozenset = 3
arepr.maxdeque = 3
arepr.maxstring = 40
arepr.maxlong = 20
arepr.maxother = 60
def safe_repr(x):
try:
return arepr.repr(x)
except Exception:
# Copied from Repr.repr_instance in Python 3.
return '<%s instance at %#x>' % (x.__class__.__name__, id(x))

View File

@ -20,6 +20,7 @@ from sortedcontainers import SortedSet
import acl
import actions
import action_obj
from attribute_recorder import AttributeRecorder
from autocomplete_context import AutocompleteContext, lookup_autocomplete_options, eval_suggestion
from codebuilder import DOLLAR_REGEX
import depend
@ -694,13 +695,9 @@ class Engine(object):
not recomputing the whole column and dependent columns as well. So it recomputes the formula
for this cell and returns error message with details.
"""
result = self.get_formula_value(table_id, col_id, row_id)
table = self.tables[table_id]
col = table.get_column(col_id)
checkpoint = self._get_undo_checkpoint()
# Makes calls to REQUEST synchronous, since raising a RequestingError can't work here.
self._sync_request = True
try:
result = self._recompute_one_cell(table, col, row_id)
# If the error is gone for a trigger formula
if col.has_formula() and not col.is_formula():
if not isinstance(result, objtypes.RaisedException):
@ -710,6 +707,15 @@ class Engine(object):
assert isinstance(error_in_cell, objtypes.RaisedException)
return error_in_cell.no_traceback()
return result
def get_formula_value(self, table_id, col_id, row_id, record_attributes=None):
table = self.tables[table_id]
col = table.get_column(col_id)
checkpoint = self._get_undo_checkpoint()
# Makes calls to REQUEST synchronous, since raising a RequestingError can't work here.
self._sync_request = True
try:
return self._recompute_one_cell(table, col, row_id, record_attributes=record_attributes)
finally:
# It is possible for formula evaluation to have side-effects that produce DocActions (e.g.
# lookupOrAddDerived() creates those). In case of get_formula_error(), these aren't fully
@ -920,7 +926,7 @@ class Engine(object):
raise RequestingError()
def _recompute_one_cell(self, table, col, row_id, cycle=False, node=None):
def _recompute_one_cell(self, table, col, row_id, cycle=False, node=None, record_attributes=None):
"""
Recomputes an one formula cell and returns a value.
The value can be:
@ -939,6 +945,11 @@ class Engine(object):
checkpoint = self._get_undo_checkpoint()
record = table.Record(row_id, table._identity_relation)
if record_attributes is not None:
assert isinstance(record_attributes, dict)
assert col.is_formula()
assert not cycle
record = AttributeRecorder(record, "rec", record_attributes)
value = None
try:
if cycle:

View File

@ -7,6 +7,9 @@ import asttokens
import asttokens.util
import six
import attribute_recorder
import objtypes
from codebuilder import make_formula_body
from column import is_visible_column, BaseReferenceColumn
from objtypes import RaisedException
import records
@ -255,3 +258,25 @@ def convert_completion(completion):
result = asttokens.util.replace(result, replacements)
return result.strip()
def evaluate_formula(engine, table_id, col_id, row_id):
grist_formula = engine.docmodel.get_column_rec(table_id, col_id).formula
assert grist_formula
plain_formula = make_formula_body(grist_formula, default_value=None).get_text()
attributes = {}
result = engine.get_formula_value(table_id, col_id, row_id, record_attributes=attributes)
if isinstance(result, objtypes.RaisedException):
name, message = result.encode_args()[:2]
result = "%s: %s" % (name, message)
error = True
else:
result = attribute_recorder.safe_repr(result)
error = False
return dict(
error=error,
formula=plain_formula,
result=result,
attributes=attributes,
)

View File

@ -153,6 +153,10 @@ def run(sandbox):
def convert_formula_completion(completion):
return formula_prompt.convert_completion(completion)
@export
def evaluate_formula(table_id, col_id, row_id):
return formula_prompt.evaluate_formula(eng, table_id, col_id, row_id)
export(parse_acl_formula)
export(eng.load_empty)
export(eng.load_done)

View File

@ -64,8 +64,8 @@ const _stats = {
callCount: 0,
};
const SEEMS_CHATTY = (process.env.COMPLETION_MODEL || '').includes('turbo');
const SIMULATE_CONVERSATION = SEEMS_CHATTY;
const SIMULATE_CONVERSATION = true;
const FOLLOWUP_EVALUATE = false;
export async function runCompletion() {
ActiveDocDeps.ACTIVEDOC_TIMEOUT = 600;
@ -132,13 +132,10 @@ export async function runCompletion() {
let success: boolean = false;
let suggestedActions: AssistanceResponse['suggestedActions'] | undefined;
let newValues: CellValue[] | undefined;
let expected: CellValue[] | undefined;
let formula: string | undefined;
let history: AssistanceState = {messages: []};
let lastFollowUp: string | undefined;
try {
async function sendMessage(followUp?: string) {
// load new document
if (!activeDoc || activeDoc.docName !== rec.doc_id) {
const docPath = path.join(PATH_TO_DOC, rec.doc_id + '.grist');
@ -146,26 +143,37 @@ export async function runCompletion() {
await activeDoc.waitForInitialization();
}
if (!activeDoc) { throw new Error("No doc"); }
// get values
await activeDoc.docData!.fetchTable(rec.table_id);
expected = activeDoc.docData!.getTable(rec.table_id)!.getColValues(rec.col_id)!.slice();
const expected = activeDoc.docData!.getTable(rec.table_id)!.getColValues(rec.col_id)!.slice();
async function sendMessage(followUp?: string, rowId?: number) {
if (!activeDoc) {
throw new Error("No doc");
}
// send prompt
const tableId = rec.table_id;
const colId = rec.col_id;
const description = rec.Description;
const colInfo = await activeDoc.docStorage.get(`
select * from _grist_Tables_column as c
left join _grist_Tables as t on t.id = c.parentId
where c.colId = ? and t.tableId = ?
`, rec.col_id, rec.table_id);
select *
from _grist_Tables_column as c
left join _grist_Tables as t on t.id = c.parentId
where c.colId = ?
and t.tableId = ?
`, rec.col_id, rec.table_id);
formula = colInfo?.formula;
const result = await sendForCompletion(session, activeDoc, {
conversationId: 'conversationId',
context: {type: 'formula', tableId, colId},
context: {
type: 'formula',
tableId,
colId,
evaluateCurrentFormula: Boolean(followUp) && FOLLOWUP_EVALUATE,
rowId,
},
state: history,
text: followUp || description,
});
@ -174,40 +182,24 @@ where c.colId = ? and t.tableId = ?
}
if (rec.no_formula == "1") {
success = result.suggestedActions.length === 0;
return null;
return;
}
suggestedActions = result.suggestedActions;
if (!suggestedActions.length) {
success = false;
return;
}
// apply modification
const {actionNum} = await activeDoc.applyUserActions(session, suggestedActions);
// get new values
newValues = activeDoc.docData!.getTable(rec.table_id)!.getColValues(rec.col_id)!.slice();
// revert modification
const [bundle] = await activeDoc.getActions([actionNum]);
await activeDoc.applyUserActionsById(session, [bundle!.actionNum], [bundle!.actionHash!], true);
// compare values
success = isEqual(expected, newValues);
if (!success) {
const rowIds = activeDoc.docData!.getTable(rec.table_id)!.getRowIds();
const result = await activeDoc.getFormulaError({client: null, mode: 'system'} as any, rec.table_id,
rec.col_id, rowIds[0]);
if (Array.isArray(result) && result[0] === 'E') {
result.shift();
const txt = `I got a \`${result.shift()}\` error:\n` +
'```\n' +
result.shift() + '\n' +
result.shift() + '\n' +
'```\n' +
'Please answer with the code block you (the assistant) just gave, ' +
'revised based on this error. Your answer must include a code block. ' +
'If you have to explain anything, do it after. ' +
'It is perfectly acceptable (and may be necessary) to do ' +
'imports from within a method body.\n';
return { followUp: txt };
} else {
if (!success && SIMULATE_CONVERSATION) {
for (let i = 0; i < expected.length; i++) {
const e = expected[i];
const v = newValues[i];
@ -216,21 +208,24 @@ where c.colId = ? and t.tableId = ?
'Please answer with the code block you (the assistant) just gave, ' +
'revised based on this information. Your answer must include a code ' +
'block. If you have to explain anything, do it after.\n';
return { followUp: txt };
const rowIds = activeDoc.docData!.getTable(rec.table_id)!.getRowIds();
const rowId = rowIds[i];
if (followUp) {
lastFollowUp = txt;
} else {
await sendMessage(txt, rowId);
}
break;
}
}
}
// revert modification
const [bundle] = await activeDoc.getActions([actionNum]);
await activeDoc.applyUserActionsById(session, [bundle!.actionNum], [bundle!.actionHash!], true);
}
return null;
}
const result = await sendMessage();
if (result?.followUp && SIMULATE_CONVERSATION) {
// Allow one follow up message, based on error or differences.
const result2 = await sendMessage(result.followUp);
if (result2?.followUp) {
lastFollowUp = result2.followUp;
}
}
try {
await sendMessage();
} catch (e) {
console.error(e);
}