diff --git a/app/server/lib/ActiveDoc.ts b/app/server/lib/ActiveDoc.ts
index 44341992..a7eb4c30 100644
--- a/app/server/lib/ActiveDoc.ts
+++ b/app/server/lib/ActiveDoc.ts
@@ -1298,12 +1298,22 @@ export class ActiveDoc extends EventEmitter {
 
   // Callback to generate a prompt containing schema info for assistance.
   public async assistanceSchemaPromptV1(
-    docSession: OptDocSession, options: AssistanceSchemaPromptV1Context): Promise<string> {
+    docSession: OptDocSession,
+    context: AssistanceSchemaPromptV1Context
+  ): Promise<string> {
     // Making a prompt leaks names of tables and columns etc.
     if (!await this._granularAccess.canScanData(docSession)) {
       throw new Error("Permission denied");
     }
-    return await this._pyCall('get_formula_prompt', options.tableId, options.colId, options.docString);
+
+    return await this._pyCall(
+      'get_formula_prompt',
+      context.tableId,
+      context.colId,
+      context.docString,
+      context.includeAllTables ?? true,
+      context.includeLookups ?? true
+    );
   }
 
   // Callback to make a data-engine formula tweak for assistance.
diff --git a/app/server/lib/Assistance.ts b/app/server/lib/Assistance.ts
index 38f06f5d..875752ad 100644
--- a/app/server/lib/Assistance.ts
+++ b/app/server/lib/Assistance.ts
@@ -67,19 +67,26 @@ export interface AssistanceFormulaEvaluationResult {
   formula: string;  // the code that was evaluated, without special grist syntax
 }
 
-export interface AssistanceSchemaPromptV1Context {
-  tableId: string,
-  colId: string,
-  docString: string,
+export interface AssistanceSchemaPromptV1Options {
+  includeAllTables?: boolean;
+  includeLookups?: boolean;
 }
 
-class SwitchToLongerContext extends Error {
+export interface AssistanceSchemaPromptV1Context extends AssistanceSchemaPromptV1Options {
+  tableId: string;
+  colId: string;
+  docString: string;
 }
 
+type AssistanceSchemaPromptGenerator = (options?: AssistanceSchemaPromptV1Options) => Promise<AssistanceMessage>;
+
 class NonRetryableError extends Error {
 }
 
-class TokensExceededFirstMessage extends NonRetryableError {
+class TokensExceededError extends NonRetryableError {
+}
+
+class TokensExceededFirstMessageError extends TokensExceededError {
   constructor() {
     super(
       "Sorry, there's too much information for the AI to process. " +
@@ -88,7 +95,7 @@ class TokensExceededFirstMessage extends NonRetryableError {
   }
 }
 
-class TokensExceededLaterMessage extends NonRetryableError {
+class TokensExceededLaterMessageError extends TokensExceededError {
   constructor() {
     super(
       "Sorry, there's too much information for the AI to process. " +
@@ -168,29 +175,13 @@ export class OpenAIAssistant implements Assistant {
   }
 
   public async apply(
-    optSession: OptDocSession, doc: AssistanceDoc, request: AssistanceRequest): Promise<AssistanceResponse> {
+    optSession: OptDocSession, doc: AssistanceDoc, request: AssistanceRequest
+  ): Promise<AssistanceResponse> {
+    const generatePrompt = this._buildSchemaPromptGenerator(optSession, doc, request);
     const messages = request.state?.messages || [];
-    const newMessages = [];
+    const newMessages: AssistanceMessage[] = [];
     if (messages.length === 0) {
-      newMessages.push({
-        role: 'system',
-        content: 'You are a helpful assistant for a user of software called Grist. ' +
-          "Below are one or more fake Python classes representing the structure of the user's data. " +
-          'The function at the end needs completing. ' +
-          "The user will probably give a description of what they want the function (a 'formula') to return. " +
-          'If so, your response should include the function BODY as Python code in a markdown block. ' +
-          "Your response will be automatically concatenated to the code below, so you mustn't repeat any of it. " +
-          'You cannot change the function signature or define additional functions or classes. ' +
-          'It should be a pure function that performs some computation and returns a result. ' +
-          'It CANNOT perform any side effects such as adding/removing/modifying rows/columns/cells/tables/etc. ' +
-          'It CANNOT interact with files/databases/networks/etc. ' +
-          'It CANNOT display images/charts/graphs/maps/etc. ' +
-          'If the user asks for these things, tell them that you cannot help. ' +
-          "\n\n" +
-          '```python\n' +
-          await makeSchemaPromptV1(optSession, doc, request) +
-          '\n```',
-      });
+      newMessages.push(await generatePrompt());
     }
     if (request.context.evaluateCurrentFormula) {
       const result = await doc.assistanceEvaluateFormula(request.context);
@@ -225,8 +216,11 @@ export class OpenAIAssistant implements Assistant {
       });
     }
 
-    const userIdHash = getUserHash(optSession);
-    const completion: string = await this._getCompletion(messages, userIdHash);
+    const completion = await this._getCompletion(messages, {
+      generatePrompt,
+      user: getUserHash(optSession),
+    });
+
     messages.push({role: 'assistant', content: completion});
     // It's nice to have this ready to uncomment for debugging.
     // console.log(completion);
@@ -254,8 +248,8 @@ export class OpenAIAssistant implements Assistant {
     return response;
   }
 
-  private async _fetchCompletion(messages: AssistanceMessage[], userIdHash: string, longerContext: boolean) {
-    const model = longerContext ? this._longerContextModel : this._model;
+  private async _fetchCompletion(messages: AssistanceMessage[], params: {user: string, model?: string}) {
+    const {user, model} = params;
     const apiResponse = await DEPS.fetch(
       this._endpoint,
       {
@@ -270,7 +264,7 @@ export class OpenAIAssistant implements Assistant {
           messages,
           temperature: 0,
           ...(model ? { model } : undefined),
-          user: userIdHash,
+          user,
           ...(this._maxTokens ? {
             max_tokens: this._maxTokens,
           } : undefined),
@@ -280,14 +274,13 @@ export class OpenAIAssistant implements Assistant {
     const resultText = await apiResponse.text();
     const result = JSON.parse(resultText);
     const errorCode = result.error?.code;
+    const errorMessage = result.error?.message;
     if (errorCode === "context_length_exceeded" || result.choices?.[0].finish_reason === "length") {
-      if (!longerContext && this._longerContextModel) {
-        log.info("Switching to longer context model...");
-        throw new SwitchToLongerContext();
-      } else if (messages.length <= 2) {
-        throw new TokensExceededFirstMessage();
+      log.warn("OpenAI context length exceeded: ", errorMessage);
+      if (messages.length <= 2) {
+        throw new TokensExceededFirstMessageError();
       } else {
-        throw new TokensExceededLaterMessage();
+        throw new TokensExceededLaterMessageError();
      }
     }
     if (errorCode === "insufficient_quota") {
@@ -297,35 +290,99 @@ export class OpenAIAssistant implements Assistant {
     if (apiResponse.status !== 200) {
       throw new Error(`OpenAI API returned status ${apiResponse.status}: ${resultText}`);
     }
-    return result;
+    return result.choices[0].message.content;
   }
 
-  private async _fetchCompletionWithRetries(
-    messages: AssistanceMessage[], userIdHash: string, longerContext: boolean
-  ): Promise<any> {
+  private async _fetchCompletionWithRetries(messages: AssistanceMessage[], params: {
+    user: string,
+    model?: string,
+  }): Promise<string> {
+    let attempts = 0;
     const maxAttempts = 3;
-    for (let attempt = 1; ; attempt++) {
+    while (attempts < maxAttempts) {
       try {
-        return await this._fetchCompletion(messages, userIdHash, longerContext);
+        return await this._fetchCompletion(messages, params);
       } catch (e) {
-        if (e instanceof SwitchToLongerContext) {
-          return await this._fetchCompletionWithRetries(messages, userIdHash, true);
-        } else if (e instanceof NonRetryableError) {
+        if (e instanceof NonRetryableError) {
          throw e;
-        } else if (attempt === maxAttempts) {
+        }
+
+        attempts += 1;
+        if (attempts === maxAttempts) {
          throw new RetryableError(e.toString());
        }

+        log.warn(`Waiting and then retrying after error: ${e}`);
        await delay(DEPS.delayTime);
      }
    }
  }

-  private async _getCompletion(messages: AssistanceMessage[], userIdHash: string) {
-    const result = await this._fetchCompletionWithRetries(messages, userIdHash, false);
-    const {message} = result.choices[0];
-    messages.push(message);
-    return message.content;
+  private async _getCompletion(
+    messages: AssistanceMessage[],
+    params: {
+      generatePrompt: AssistanceSchemaPromptGenerator,
+      user: string,
+    }
+  ): Promise<string> {
+    const {generatePrompt, user} = params;
+
+    // First try fetching the completion with the default model.
+    try {
+      return await this._fetchCompletionWithRetries(messages, {user, model: this._model});
+    } catch (e) {
+      if (!(e instanceof TokensExceededError)) {
+        throw e;
+      }
+    }
+
+    // If we hit the token limit and a model with a longer context length is
+    // available, try it.
+    if (this._longerContextModel) {
+      try {
+        return await this._fetchCompletionWithRetries(messages, {
+          user,
+          model: this._longerContextModel,
+        });
+      } catch (e) {
+        if (!(e instanceof TokensExceededError)) {
+          throw e;
+        }
+      }
+    }
+
+    // If we (still) hit the token limit, try a shorter schema prompt as a last resort.
+    const prompt = await generatePrompt({includeAllTables: false, includeLookups: false});
+    return await this._fetchCompletionWithRetries([prompt, ...messages.slice(1)], {
+      user,
+      model: this._longerContextModel || this._model,
+    });
+  }
+
+  private _buildSchemaPromptGenerator(
+    optSession: OptDocSession,
+    doc: AssistanceDoc,
+    request: AssistanceRequest
+  ): AssistanceSchemaPromptGenerator {
+    return async (options) => ({
+      role: 'system',
+      content: 'You are a helpful assistant for a user of software called Grist. ' +
+        "Below are one or more fake Python classes representing the structure of the user's data. " +
+        'The function at the end needs completing. ' +
+        "The user will probably give a description of what they want the function (a 'formula') to return. " +
+        'If so, your response should include the function BODY as Python code in a markdown block. ' +
+        "Your response will be automatically concatenated to the code below, so you mustn't repeat any of it. " +
+        'You cannot change the function signature or define additional functions or classes. ' +
+        'It should be a pure function that performs some computation and returns a result. ' +
+        'It CANNOT perform any side effects such as adding/removing/modifying rows/columns/cells/tables/etc. ' +
+        'It CANNOT interact with files/databases/networks/etc. ' +
+        'It CANNOT display images/charts/graphs/maps/etc. ' +
+        'If the user asks for these things, tell them that you cannot help. ' +
+        "\n\n" +
+        '```python\n' +
+        await makeSchemaPromptV1(optSession, doc, request, options) +
+        '\n```',
+    });
  }
 }
 
@@ -457,14 +514,21 @@ export function replaceMarkdownCode(markdown: string, replaceValue: string) {
   return markdown.replace(/```\w*\n(.*)```/s, '```python\n' + replaceValue + '\n```');
 }
 
-async function makeSchemaPromptV1(session: OptDocSession, doc: AssistanceDoc, request: AssistanceRequest) {
+async function makeSchemaPromptV1(
+  session: OptDocSession,
+  doc: AssistanceDoc,
+  request: AssistanceRequest,
+  options: AssistanceSchemaPromptV1Options = {}
+) {
   if (request.context.type !== 'formula') {
     throw new Error('makeSchemaPromptV1 only works for formulas');
   }
+
   return doc.assistanceSchemaPromptV1(session, {
     tableId: request.context.tableId,
     colId: request.context.colId,
     docString: request.text,
+    ...options,
   });
 }
 
diff --git a/sandbox/grist/formula_prompt.py b/sandbox/grist/formula_prompt.py
index 6a72d8b5..6e1087a3 100644
--- a/sandbox/grist/formula_prompt.py
+++ b/sandbox/grist/formula_prompt.py
@@ -160,7 +160,7 @@ def get_formula_prompt(engine, table_id, col_id, _description,
   other_tables = (all_other_tables(engine, table_id)
                   if include_all_tables else referenced_tables(engine, table_id))
   for other_table_id in sorted(other_tables):
-    result += class_schema(engine, other_table_id, lookups)
+    result += class_schema(engine, other_table_id, None, lookups)
 
   result += class_schema(engine, table_id, col_id, lookups)
 
diff --git a/sandbox/grist/main.py b/sandbox/grist/main.py
index df9c6443..ab87952b 100644
--- a/sandbox/grist/main.py
+++ b/sandbox/grist/main.py
@@ -146,8 +146,9 @@ def run(sandbox):
     return objtypes.encode_object(eng.get_formula_error(table_id, col_id, row_id))
 
   @export
-  def get_formula_prompt(table_id, col_id, description):
-    return formula_prompt.get_formula_prompt(eng, table_id, col_id, description)
+  def get_formula_prompt(table_id, col_id, description, include_all_tables=True, lookups=True):
+    return formula_prompt.get_formula_prompt(eng, table_id, col_id, description,
+                                             include_all_tables, lookups)
 
   @export
   def convert_formula_completion(completion):
diff --git a/test/server/lib/Assistance.ts b/test/server/lib/Assistance.ts
index 9a60acb3..6ad8c6cb 100644
--- a/test/server/lib/Assistance.ts
+++ b/test/server/lib/Assistance.ts
@@ -18,7 +18,8 @@ describe('Assistance', function () {
   this.timeout(10000);
 
   const docTools = createDocTools({persistAcrossCases: true});
-  const tableId = "Table1";
+  const table1Id = "Table1";
+  const table2Id = "Table2";
   let session: DocSession;
   let doc: ActiveDoc;
   before(async () => {
@@ -26,7 +27,8 @@ describe('Assistance', function () {
     session = docTools.createFakeSession();
     doc = await docTools.createDoc('test.grist');
     await doc.applyUserActions(session, [
-      ["AddTable", tableId, [{id: "A"}, {id: "B"}, {id: "C"}]],
+      ["AddTable", table1Id, [{id: "A"}, {id: "B"}, {id: "C"}]],
+      ["AddTable", table2Id, [{id: "A"}, {id: "B"}, {id: "C"}]],
     ]);
   });
 
@@ -36,7 +38,7 @@ describe('Assistance', function () {
   function checkSendForCompletion(state?: AssistanceState) {
     return sendForCompletion(session, doc, {
       conversationId: 'conversationId',
-      context: {type: 'formula', tableId, colId},
+      context: {type: 'formula', tableId: table1Id, colId},
       state,
       text: userMessageContent,
     });
@@ -108,7 +110,7 @@ describe('Assistance', function () {
       + "\n\nLet me know if there's anything else I can help with.";
     assert.deepEqual(result, {
       suggestedActions: [
-        ["ModifyColumn", tableId, colId, {formula: suggestedFormula}]
+        ["ModifyColumn", table1Id, colId, {formula: suggestedFormula}]
      ],
       suggestedFormula,
       reply: replyWithSuggestedFormula,
@@ -203,9 +205,40 @@ describe('Assistance', function () {
     checkModels([
       OpenAIAssistant.DEFAULT_MODEL,
       OpenAIAssistant.DEFAULT_LONGER_CONTEXT_MODEL,
+      OpenAIAssistant.DEFAULT_LONGER_CONTEXT_MODEL,
     ]);
   });
 
+  it('switches to a shorter prompt if the longer model exceeds its token limit', async function () {
+    fakeResponse = () => ({
+      error: {
+        code: "context_length_exceeded",
+      },
+      status: 400,
+    });
+    await assert.isRejected(
+      checkSendForCompletion(),
+      /You'll need to either shorten your message or delete some columns/
+    );
+    fakeFetch.getCalls().map((callInfo, i) => {
+      const [, request] = callInfo.args;
+      const {messages} = JSON.parse(request.body);
+      const systemMessageContent = messages[0].content;
+      const shortCallIndex = 2;
+      if (i === shortCallIndex) {
+        assert.match(systemMessageContent, /class Table1/);
+        assert.notMatch(systemMessageContent, /class Table2/);
+        assert.notMatch(systemMessageContent, /def lookupOne/);
+        assert.lengthOf(systemMessageContent, 1001);
+      } else {
+        assert.match(systemMessageContent, /class Table1/);
+        assert.match(systemMessageContent, /class Table2/);
+        assert.match(systemMessageContent, /def lookupOne/);
+        assert.lengthOf(systemMessageContent, 1982);
+      }
+    });
+  });
+
   it('switches to a longer model with no retries if the model runs out of tokens while responding', async function () {
     fakeResponse = () => ({
       "choices": [{
@@ -222,6 +255,7 @@ describe('Assistance', function () {
     checkModels([
       OpenAIAssistant.DEFAULT_MODEL,
       OpenAIAssistant.DEFAULT_LONGER_CONTEXT_MODEL,
+      OpenAIAssistant.DEFAULT_LONGER_CONTEXT_MODEL,
     ]);
   });
 
@@ -245,6 +279,7 @@ describe('Assistance', function () {
     checkModels([
       OpenAIAssistant.DEFAULT_MODEL,
       OpenAIAssistant.DEFAULT_LONGER_CONTEXT_MODEL,
+      OpenAIAssistant.DEFAULT_LONGER_CONTEXT_MODEL,
     ]);
   });
 
@@ -279,7 +314,7 @@ describe('Assistance', function () {
       OpenAIAssistant.DEFAULT_LONGER_CONTEXT_MODEL,
     ]);
     assert.deepEqual(result.suggestedActions, [
-      ["ModifyColumn", tableId, colId, {formula: "123"}]
+      ["ModifyColumn", table1Id, colId, {formula: "123"}]
     ]);
   });
 });