(core) Add AI Assistant retry with shorter prompt

Summary:
If even the longer-context OpenAI model exceeds the context length, we now perform one
more retry with a shorter variant of the formula prompt. The shorter prompt excludes
non-referenced tables and lookup method definitions, which should reduce token usage in
documents with larger schemas.
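
In outline, the completion flow now escalates through up to three configurations before giving up. The sketch below is a simplified, hypothetical rendering of that ladder; names like `fetchCompletion`, `DEFAULT_MODEL`, and `LONGER_CONTEXT_MODEL` are stand-ins rather than the actual private API (see `_getCompletion` in the diff below for the real implementation):

```typescript
// Hypothetical sketch of the retry ladder; not the real signatures.
type Message = {role: 'system' | 'user' | 'assistant', content: string};
type PromptOptions = {includeAllTables?: boolean, includeLookups?: boolean};

class TokensExceededError extends Error {}

declare const DEFAULT_MODEL: string;
declare const LONGER_CONTEXT_MODEL: string | undefined;
// Stand-in for the private completion call; throws TokensExceededError on overflow.
declare function fetchCompletion(messages: Message[], opts: {model?: string}): Promise<string>;

async function getCompletion(
  messages: Message[],
  generatePrompt: (options?: PromptOptions) => Promise<Message>,
): Promise<string> {
  // 1. Default model with the full schema prompt.
  try {
    return await fetchCompletion(messages, {model: DEFAULT_MODEL});
  } catch (e) {
    if (!(e instanceof TokensExceededError)) { throw e; }
  }
  // 2. Longer-context model (if configured) with the same prompt.
  if (LONGER_CONTEXT_MODEL) {
    try {
      return await fetchCompletion(messages, {model: LONGER_CONTEXT_MODEL});
    } catch (e) {
      if (!(e instanceof TokensExceededError)) { throw e; }
    }
  }
  // 3. Last resort: regenerate the system prompt without non-referenced
  // tables and lookup method definitions, then retry once more.
  const shortPrompt = await generatePrompt({includeAllTables: false, includeLookups: false});
  return fetchCompletion([shortPrompt, ...messages.slice(1)], {
    model: LONGER_CONTEXT_MODEL || DEFAULT_MODEL,
  });
}
```

Note that only the system message is replaced on the final attempt; the rest of the conversation (`messages.slice(1)`) is carried over unchanged.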

Test Plan: Server test.

Reviewers: JakubSerafin

Reviewed By: JakubSerafin

Subscribers: JakubSerafin

Differential Revision: https://phab.getgrist.com/D4184
Commit: 94eec5e906 (parent 43a76235c7)
Author: George Gevoian
Date: 2024-02-11 22:11:06 -05:00
5 changed files with 175 additions and 65 deletions


@@ -1298,12 +1298,22 @@ export class ActiveDoc extends EventEmitter {
   // Callback to generate a prompt containing schema info for assistance.
   public async assistanceSchemaPromptV1(
-    docSession: OptDocSession, options: AssistanceSchemaPromptV1Context): Promise<string> {
+    docSession: OptDocSession,
+    context: AssistanceSchemaPromptV1Context
+  ): Promise<string> {
     // Making a prompt leaks names of tables and columns etc.
     if (!await this._granularAccess.canScanData(docSession)) {
       throw new Error("Permission denied");
     }
-    return await this._pyCall('get_formula_prompt', options.tableId, options.colId, options.docString);
+
+    return await this._pyCall(
+      'get_formula_prompt',
+      context.tableId,
+      context.colId,
+      context.docString,
+      context.includeAllTables ?? true,
+      context.includeLookups ?? true
+    );
   }

   // Callback to make a data-engine formula tweak for assistance.
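
For illustration, a hedged sketch of how a caller could request the trimmed prompt through the new context fields; the ids and doc string are made-up examples, and both flags default to true via the `?? true` fallbacks above:

```typescript
// Hypothetical usage; the table/column ids and doc string are examples only.
declare const activeDoc: {
  assistanceSchemaPromptV1(session: unknown, context: {
    tableId: string, colId: string, docString: string,
    includeAllTables?: boolean, includeLookups?: boolean,
  }): Promise<string>;
};
declare const docSession: unknown;

async function example(): Promise<string> {
  // Omitting the flags keeps the old full-schema behavior; passing false
  // opts in to the shorter, schema-trimmed prompt.
  return activeDoc.assistanceSchemaPromptV1(docSession, {
    tableId: 'Table1',
    colId: 'A',
    docString: 'Return the sum of B and C.',
    includeAllTables: false,  // omit tables not referenced by Table1
    includeLookups: false,    // omit lookupOne/lookupRecords definitions
  });
}
```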


@@ -67,19 +67,26 @@ export interface AssistanceFormulaEvaluationResult {
   formula: string; // the code that was evaluated, without special grist syntax
 }

-export interface AssistanceSchemaPromptV1Context {
-  tableId: string,
-  colId: string,
-  docString: string,
+export interface AssistanceSchemaPromptV1Options {
+  includeAllTables?: boolean;
+  includeLookups?: boolean;
 }

-class SwitchToLongerContext extends Error {
+export interface AssistanceSchemaPromptV1Context extends AssistanceSchemaPromptV1Options {
+  tableId: string;
+  colId: string;
+  docString: string;
 }

+type AssistanceSchemaPromptGenerator = (options?: AssistanceSchemaPromptV1Options) => Promise<AssistanceMessage>;
+
 class NonRetryableError extends Error {
 }

-class TokensExceededFirstMessage extends NonRetryableError {
+class TokensExceededError extends NonRetryableError {
+}
+
+class TokensExceededFirstMessageError extends TokensExceededError {
   constructor() {
     super(
       "Sorry, there's too much information for the AI to process. " +

@@ -88,7 +95,7 @@ class TokensExceededFirstMessage extends NonRetryableError {
   }
 }

-class TokensExceededLaterMessage extends NonRetryableError {
+class TokensExceededLaterMessageError extends TokensExceededError {
   constructor() {
     super(
       "Sorry, there's too much information for the AI to process. " +

@@ -168,29 +175,13 @@ export class OpenAIAssistant implements Assistant {
   }

   public async apply(
-    optSession: OptDocSession, doc: AssistanceDoc, request: AssistanceRequest): Promise<AssistanceResponse> {
+    optSession: OptDocSession, doc: AssistanceDoc, request: AssistanceRequest
+  ): Promise<AssistanceResponse> {
+    const generatePrompt = this._buildSchemaPromptGenerator(optSession, doc, request);
     const messages = request.state?.messages || [];
-    const newMessages = [];
+    const newMessages: AssistanceMessage[] = [];
     if (messages.length === 0) {
-      newMessages.push({
-        role: 'system',
-        content: 'You are a helpful assistant for a user of software called Grist. ' +
-          "Below are one or more fake Python classes representing the structure of the user's data. " +
-          'The function at the end needs completing. ' +
-          "The user will probably give a description of what they want the function (a 'formula') to return. " +
-          'If so, your response should include the function BODY as Python code in a markdown block. ' +
-          "Your response will be automatically concatenated to the code below, so you mustn't repeat any of it. " +
-          'You cannot change the function signature or define additional functions or classes. ' +
-          'It should be a pure function that performs some computation and returns a result. ' +
-          'It CANNOT perform any side effects such as adding/removing/modifying rows/columns/cells/tables/etc. ' +
-          'It CANNOT interact with files/databases/networks/etc. ' +
-          'It CANNOT display images/charts/graphs/maps/etc. ' +
-          'If the user asks for these things, tell them that you cannot help. ' +
-          "\n\n" +
-          '```python\n' +
-          await makeSchemaPromptV1(optSession, doc, request) +
-          '\n```',
-      });
+      newMessages.push(await generatePrompt());
     }
     if (request.context.evaluateCurrentFormula) {
       const result = await doc.assistanceEvaluateFormula(request.context);

@@ -225,8 +216,11 @@
       });
     }

-    const userIdHash = getUserHash(optSession);
-    const completion: string = await this._getCompletion(messages, userIdHash);
+    const completion = await this._getCompletion(messages, {
+      generatePrompt,
+      user: getUserHash(optSession),
+    });
+    messages.push({role: 'assistant', content: completion});

     // It's nice to have this ready to uncomment for debugging.
     // console.log(completion);

@@ -254,8 +248,8 @@
     return response;
   }

-  private async _fetchCompletion(messages: AssistanceMessage[], userIdHash: string, longerContext: boolean) {
-    const model = longerContext ? this._longerContextModel : this._model;
+  private async _fetchCompletion(messages: AssistanceMessage[], params: {user: string, model?: string}) {
+    const {user, model} = params;
     const apiResponse = await DEPS.fetch(
       this._endpoint,
       {

@@ -270,7 +264,7 @@
           messages,
           temperature: 0,
           ...(model ? { model } : undefined),
-          user: userIdHash,
+          user,
           ...(this._maxTokens ? {
             max_tokens: this._maxTokens,
           } : undefined),

@@ -280,14 +274,13 @@
     const resultText = await apiResponse.text();
     const result = JSON.parse(resultText);
     const errorCode = result.error?.code;
+    const errorMessage = result.error?.message;
     if (errorCode === "context_length_exceeded" || result.choices?.[0].finish_reason === "length") {
-      if (!longerContext && this._longerContextModel) {
-        log.info("Switching to longer context model...");
-        throw new SwitchToLongerContext();
-      } else if (messages.length <= 2) {
-        throw new TokensExceededFirstMessage();
+      log.warn("OpenAI context length exceeded: ", errorMessage);
+      if (messages.length <= 2) {
+        throw new TokensExceededFirstMessageError();
       } else {
-        throw new TokensExceededLaterMessage();
+        throw new TokensExceededLaterMessageError();
       }
     }
     if (errorCode === "insufficient_quota") {

@@ -297,35 +290,99 @@
     if (apiResponse.status !== 200) {
       throw new Error(`OpenAI API returned status ${apiResponse.status}: ${resultText}`);
     }
-    return result;
+    return result.choices[0].message.content;
   }

-  private async _fetchCompletionWithRetries(
-    messages: AssistanceMessage[], userIdHash: string, longerContext: boolean
-  ): Promise<any> {
+  private async _fetchCompletionWithRetries(messages: AssistanceMessage[], params: {
+    user: string,
+    model?: string,
+  }): Promise<any> {
+    let attempts = 0;
     const maxAttempts = 3;
-    for (let attempt = 1; ; attempt++) {
+    while (attempts < maxAttempts) {
       try {
-        return await this._fetchCompletion(messages, userIdHash, longerContext);
+        return await this._fetchCompletion(messages, params);
       } catch (e) {
-        if (e instanceof SwitchToLongerContext) {
-          return await this._fetchCompletionWithRetries(messages, userIdHash, true);
-        } else if (e instanceof NonRetryableError) {
+        if (e instanceof NonRetryableError) {
           throw e;
-        } else if (attempt === maxAttempts) {
+        }
+
+        attempts += 1;
+        if (attempts === maxAttempts) {
           throw new RetryableError(e.toString());
         }
+
         log.warn(`Waiting and then retrying after error: ${e}`);
         await delay(DEPS.delayTime);
       }
     }
   }

-  private async _getCompletion(messages: AssistanceMessage[], userIdHash: string) {
-    const result = await this._fetchCompletionWithRetries(messages, userIdHash, false);
-    const {message} = result.choices[0];
-    messages.push(message);
-    return message.content;
+  private async _getCompletion(
+    messages: AssistanceMessage[],
+    params: {
+      generatePrompt: AssistanceSchemaPromptGenerator,
+      user: string,
+    }
+  ): Promise<string> {
+    const {generatePrompt, user} = params;
+
+    // First try fetching the completion with the default model.
+    try {
+      return await this._fetchCompletionWithRetries(messages, {user, model: this._model});
+    } catch (e) {
+      if (!(e instanceof TokensExceededError)) {
+        throw e;
+      }
+    }
+
+    // If we hit the token limit and a model with a longer context length is
+    // available, try it.
+    if (this._longerContextModel) {
+      try {
+        return await this._fetchCompletionWithRetries(messages, {
+          user,
+          model: this._longerContextModel,
+        });
+      } catch (e) {
+        if (!(e instanceof TokensExceededError)) {
+          throw e;
+        }
+      }
+    }
+
+    // If we (still) hit the token limit, try a shorter schema prompt as a last resort.
+    const prompt = await generatePrompt({includeAllTables: false, includeLookups: false});
+    return await this._fetchCompletionWithRetries([prompt, ...messages.slice(1)], {
+      user,
+      model: this._longerContextModel || this._model,
+    });
+  }
+
+  private _buildSchemaPromptGenerator(
+    optSession: OptDocSession,
+    doc: AssistanceDoc,
+    request: AssistanceRequest
+  ): AssistanceSchemaPromptGenerator {
+    return async (options) => ({
+      role: 'system',
+      content: 'You are a helpful assistant for a user of software called Grist. ' +
+        "Below are one or more fake Python classes representing the structure of the user's data. " +
+        'The function at the end needs completing. ' +
+        "The user will probably give a description of what they want the function (a 'formula') to return. " +
+        'If so, your response should include the function BODY as Python code in a markdown block. ' +
+        "Your response will be automatically concatenated to the code below, so you mustn't repeat any of it. " +
+        'You cannot change the function signature or define additional functions or classes. ' +
+        'It should be a pure function that performs some computation and returns a result. ' +
+        'It CANNOT perform any side effects such as adding/removing/modifying rows/columns/cells/tables/etc. ' +
+        'It CANNOT interact with files/databases/networks/etc. ' +
+        'It CANNOT display images/charts/graphs/maps/etc. ' +
+        'If the user asks for these things, tell them that you cannot help. ' +
+        "\n\n" +
+        '```python\n' +
+        await makeSchemaPromptV1(optSession, doc, request, options) +
+        '\n```',
+    });
   }
 }

@@ -457,14 +514,21 @@ export function replaceMarkdownCode(markdown: string, replaceValue: string) {
   return markdown.replace(/```\w*\n(.*)```/s, '```python\n' + replaceValue + '\n```');
 }

-async function makeSchemaPromptV1(session: OptDocSession, doc: AssistanceDoc, request: AssistanceRequest) {
+async function makeSchemaPromptV1(
+  session: OptDocSession,
+  doc: AssistanceDoc,
+  request: AssistanceRequest,
+  options: AssistanceSchemaPromptV1Options = {}
+) {
   if (request.context.type !== 'formula') {
     throw new Error('makeSchemaPromptV1 only works for formulas');
   }

   return doc.assistanceSchemaPromptV1(session, {
     tableId: request.context.tableId,
     colId: request.context.colId,
     docString: request.text,
+    ...options,
   });
 }


@@ -160,7 +160,7 @@ def get_formula_prompt(engine, table_id, col_id, _description,
   other_tables = (all_other_tables(engine, table_id)
                   if include_all_tables else referenced_tables(engine, table_id))
   for other_table_id in sorted(other_tables):
-    result += class_schema(engine, other_table_id, lookups)
+    result += class_schema(engine, other_table_id, None, lookups)
   result += class_schema(engine, table_id, col_id, lookups)


@@ -146,8 +146,9 @@ def run(sandbox):
     return objtypes.encode_object(eng.get_formula_error(table_id, col_id, row_id))

   @export
-  def get_formula_prompt(table_id, col_id, description):
-    return formula_prompt.get_formula_prompt(eng, table_id, col_id, description)
+  def get_formula_prompt(table_id, col_id, description, include_all_tables=True, lookups=True):
+    return formula_prompt.get_formula_prompt(eng, table_id, col_id, description,
+                                             include_all_tables, lookups)

   @export
   def convert_formula_completion(completion):


@@ -18,7 +18,8 @@ describe('Assistance', function () {
   this.timeout(10000);

   const docTools = createDocTools({persistAcrossCases: true});
-  const tableId = "Table1";
+  const table1Id = "Table1";
+  const table2Id = "Table2";
   let session: DocSession;
   let doc: ActiveDoc;
   before(async () => {

@@ -26,7 +27,8 @@
     session = docTools.createFakeSession();
     doc = await docTools.createDoc('test.grist');
     await doc.applyUserActions(session, [
-      ["AddTable", tableId, [{id: "A"}, {id: "B"}, {id: "C"}]],
+      ["AddTable", table1Id, [{id: "A"}, {id: "B"}, {id: "C"}]],
+      ["AddTable", table2Id, [{id: "A"}, {id: "B"}, {id: "C"}]],
     ]);
   });

@@ -36,7 +38,7 @@
   function checkSendForCompletion(state?: AssistanceState) {
     return sendForCompletion(session, doc, {
       conversationId: 'conversationId',
-      context: {type: 'formula', tableId, colId},
+      context: {type: 'formula', tableId: table1Id, colId},
       state,
       text: userMessageContent,
     });

@@ -108,7 +110,7 @@
       + "\n\nLet me know if there's anything else I can help with.";
     assert.deepEqual(result, {
       suggestedActions: [
-        ["ModifyColumn", tableId, colId, {formula: suggestedFormula}]
+        ["ModifyColumn", table1Id, colId, {formula: suggestedFormula}]
       ],
       suggestedFormula,
       reply: replyWithSuggestedFormula,

@@ -203,9 +205,40 @@
     checkModels([
       OpenAIAssistant.DEFAULT_MODEL,
       OpenAIAssistant.DEFAULT_LONGER_CONTEXT_MODEL,
+      OpenAIAssistant.DEFAULT_LONGER_CONTEXT_MODEL,
     ]);
   });

+  it('switches to a shorter prompt if the longer model exceeds its token limit', async function () {
+    fakeResponse = () => ({
+      error: {
+        code: "context_length_exceeded",
+      },
+      status: 400,
+    });
+    await assert.isRejected(
+      checkSendForCompletion(),
+      /You'll need to either shorten your message or delete some columns/
+    );
+    fakeFetch.getCalls().map((callInfo, i) => {
+      const [, request] = callInfo.args;
+      const {messages} = JSON.parse(request.body);
+      const systemMessageContent = messages[0].content;
+      const shortCallIndex = 2;
+      if (i === shortCallIndex) {
+        assert.match(systemMessageContent, /class Table1/);
+        assert.notMatch(systemMessageContent, /class Table2/);
+        assert.notMatch(systemMessageContent, /def lookupOne/);
+        assert.lengthOf(systemMessageContent, 1001);
+      } else {
+        assert.match(systemMessageContent, /class Table1/);
+        assert.match(systemMessageContent, /class Table2/);
+        assert.match(systemMessageContent, /def lookupOne/);
+        assert.lengthOf(systemMessageContent, 1982);
+      }
+    });
+  });
+
   it('switches to a longer model with no retries if the model runs out of tokens while responding', async function () {
     fakeResponse = () => ({
       "choices": [{

@@ -222,6 +255,7 @@
     checkModels([
       OpenAIAssistant.DEFAULT_MODEL,
       OpenAIAssistant.DEFAULT_LONGER_CONTEXT_MODEL,
+      OpenAIAssistant.DEFAULT_LONGER_CONTEXT_MODEL,
     ]);
   });

@@ -245,6 +279,7 @@
     checkModels([
       OpenAIAssistant.DEFAULT_MODEL,
       OpenAIAssistant.DEFAULT_LONGER_CONTEXT_MODEL,
+      OpenAIAssistant.DEFAULT_LONGER_CONTEXT_MODEL,
     ]);
   });

@@ -279,7 +314,7 @@
       OpenAIAssistant.DEFAULT_LONGER_CONTEXT_MODEL,
     ]);
     assert.deepEqual(result.suggestedActions, [
-      ["ModifyColumn", tableId, colId, {formula: "123"}]
+      ["ModifyColumn", table1Id, colId, {formula: "123"}]
     ]);
   });
 });
}); });