import {createDocTools} from "test/server/docTools";
import {ActiveDoc} from "app/server/lib/ActiveDoc";
import {DEPS, OpenAIAssistant, sendForCompletion} from "app/server/lib/Assistance";
import {assert} from 'chai';
import * as sinon from 'sinon';
import {Response} from 'node-fetch';
import {DocSession} from "app/server/lib/DocSession";
import {AssistanceState} from "app/common/AssistancePrompts";

// For some reason, assert.isRejected is not getting defined,
// though test/chai-as-promised.js should be taking care of this.
// So test/chai-as-promised.js is just repeated here.
const chai = require('chai');
const chaiAsPromised = require('chai-as-promised');
chai.use(chaiAsPromised);

/**
 * Tests for sendForCompletion() with the OpenAI assistant. No real network
 * traffic happens: DEPS.fetch is replaced by a sinon fake whose response body
 * is produced by the per-test `fakeResponse` callback.
 */
describe('Assistance', function () {
  this.timeout(10000);

  const docTools = createDocTools({persistAcrossCases: true});
  const tableId = "Table1";
  let session: DocSession;
  let doc: ActiveDoc;

  // Value of OPENAI_API_KEY before this suite overrides it, so that the fake
  // key does not leak into suites that run later in the same process.
  let savedApiKey: string | undefined;

  before(async () => {
    savedApiKey = process.env.OPENAI_API_KEY;
    process.env.OPENAI_API_KEY = "fake";
    session = docTools.createFakeSession();
    doc = await docTools.createDoc('test.grist');
    await doc.applyUserActions(session, [
      ["AddTable", tableId, [{id: "A"}, {id: "B"}, {id: "C"}]],
    ]);
  });

  after(() => {
    // Assigning undefined to a process.env entry would store the string
    // "undefined", so an originally-unset variable has to be deleted instead.
    if (savedApiKey === undefined) {
      delete process.env.OPENAI_API_KEY;
    } else {
      process.env.OPENAI_API_KEY = savedApiKey;
    }
  });

  const colId = "C";
  const userMessageContent = "Sum of A and B";

  /**
   * Sends the standard user message for a formula in column C of Table1,
   * optionally continuing from a previous conversation state.
   */
  function checkSendForCompletion(state?: AssistanceState) {
    return sendForCompletion(session, doc, {
      conversationId: 'conversationId',
      context: {type: 'formula', tableId, colId},
      state,
      text: userMessageContent,
    });
  }

  // Each test assigns fakeResponse to control what the fake fetch returns.
  // The returned object is used both as the JSON body and, via its `status`
  // field, as the HTTP status of the response.
  let fakeResponse: () => any;
  let fakeFetch: sinon.SinonSpy;

  beforeEach(() => {
    fakeFetch = sinon.fake(() => {
      const body = fakeResponse();
      return new Response(
        JSON.stringify(body),
        {status: body.status},
      );
    });
    sinon.replace(DEPS, 'fetch', fakeFetch as any);
    // Replace the delay setting with 1 so the retry tests stay fast.
    sinon.replace(DEPS, 'delayTime', 1);
  });

  afterEach(function () {
    sinon.restore();
  });

  /**
   * Asserts that the sequence of fetch calls so far requested exactly the
   * given models, in order.
   */
  function checkModels(expectedModels: string[]) {
    assert.deepEqual(
      fakeFetch.getCalls().map(call => JSON.parse(call.args[1].body).model),
      expectedModels,
    );
  }

  it('can suggest a formula', async function () {
    const reply = "Here's a formula that adds columns A and B:\n\n" +
      "```python\na = int(rec.A)\nb=int(rec.B)\n\nreturn str(a + b)\n```" +
      "\n\nLet me know if there's anything else I can help with.";
    const replyMessage = {"role": "assistant", "content": reply};

    fakeResponse = () => ({
      "choices": [{
        "index": 0,
        "message": replyMessage,
        "finish_reason": "stop"
      }],
      status: 200,
    });
    const result = await checkSendForCompletion();
    checkModels([OpenAIAssistant.DEFAULT_MODEL]);
    const callInfo = fakeFetch.getCall(0);
    const [url, request] = callInfo.args;
    assert.equal(url, 'https://api.openai.com/v1/chat/completions');
    assert.equal(request.method, 'POST');
    const {messages: requestMessages} = JSON.parse(request.body);
    const systemMessageContent = requestMessages[0].content;
    assert.match(systemMessageContent, /def C\(rec: Table1\)/);
    assert.deepEqual(requestMessages, [
        {
          role: "system",
          content: systemMessageContent,
        },
        {
          role: "user",
          content: userMessageContent,
        }
      ]
    );
    // The expected result has the `rec.X` references as `$X` and no `return`.
    const suggestedFormula = "a = int($A)\nb=int($B)\n\nstr(a + b)";
    const replyWithSuggestedFormula = "Here's a formula that adds columns A and B:\n\n" +
      "```python\na = int($A)\nb=int($B)\n\nstr(a + b)\n```" +
      "\n\nLet me know if there's anything else I can help with.";
    assert.deepEqual(result, {
        suggestedActions: [
          ["ModifyColumn", tableId, colId, {formula: suggestedFormula}]
        ],
        suggestedFormula,
        reply: replyWithSuggestedFormula,
        state: {
          messages: [...requestMessages, replyMessage]
        }
      }
    );
  });

  it('does not suggest anything if formula is invalid', async function () {
    const reply = "This isn't valid Python code:\n```python\nclass = 'foo'\n```";
    const replyMessage = {
      "role": "assistant",
      "content": reply,
    };

    fakeResponse = () => ({
      "choices": [{
        "index": 0,
        "message": replyMessage,
        "finish_reason": "stop"
      }],
      status: 200,
    });
    const result = await checkSendForCompletion();
    const callInfo = fakeFetch.getCall(0);
    const [, request] = callInfo.args;
    const {messages: requestMessages} = JSON.parse(request.body);
    const suggestedFormula = undefined;
    assert.deepEqual(result, {
        suggestedActions: [],
        suggestedFormula,
        reply,
        state: {
          messages: [...requestMessages, replyMessage],
        },
      }
    );
  });

  it('tries 3 times in case of network errors', async function () {
    fakeResponse = () => {
      throw new Error("Network error");
    };
    await assert.isRejected(
      checkSendForCompletion(),
      "Sorry, the assistant is unavailable right now. " +
      "Try again in a few minutes. \n" +
      "(Error: Network error)",
    );
    assert.equal(fakeFetch.callCount, 3);
  });

  it('tries 3 times in case of bad status code', async function () {
    fakeResponse = () => ({status: 500});
    await assert.isRejected(
      checkSendForCompletion(),
      "Sorry, the assistant is unavailable right now. " +
      "Try again in a few minutes. \n" +
      '(Error: OpenAI API returned status 500: {"status":500})',
    );
    assert.equal(fakeFetch.callCount, 3);
  });

  it('handles exceeded billing quota', async function () {
    fakeResponse = () => ({
      error: {
        code: "insufficient_quota",
      },
      status: 429,
    });
    await assert.isRejected(
      checkSendForCompletion(),
      "Sorry, the assistant is facing some long term capacity issues. " +
      "Maybe try again tomorrow.",
    );
    // A quota error is expected to fail immediately, without retries.
    assert.equal(fakeFetch.callCount, 1);
  });

  it('switches to a longer model with no retries if the prompt is too long', async function () {
    fakeResponse = () => ({
      error: {
        code: "context_length_exceeded",
      },
      status: 400,
    });
    await assert.isRejected(
      checkSendForCompletion(),
      /You'll need to either shorten your message or delete some columns/
    );
    checkModels([
      OpenAIAssistant.DEFAULT_MODEL,
      OpenAIAssistant.DEFAULT_LONGER_CONTEXT_MODEL,
    ]);
  });

  it('switches to a longer model with no retries if the model runs out of tokens while responding', async function () {
    fakeResponse = () => ({
      "choices": [{
        "index": 0,
        "message": {},
        "finish_reason": "length"
      }],
      status: 200,
    });
    await assert.isRejected(
      checkSendForCompletion(),
      /You'll need to either shorten your message or delete some columns/
    );
    checkModels([
      OpenAIAssistant.DEFAULT_MODEL,
      OpenAIAssistant.DEFAULT_LONGER_CONTEXT_MODEL,
    ]);
  });

  it('suggests restarting conversation if the prompt is too long and there are past messages', async function () {
    fakeResponse = () => ({
      error: {
        code: "context_length_exceeded",
      },
      status: 400,
    });
    await assert.isRejected(
      checkSendForCompletion({
        messages: [
          {role: "system", content: "Be good."},
          {role: "user", content: "Hi."},
          {role: "assistant", content: "Hi!"},
        ]
      }),
      /You'll need to either shorten your message, restart the conversation, or delete some columns/
    );
    checkModels([
      OpenAIAssistant.DEFAULT_MODEL,
      OpenAIAssistant.DEFAULT_LONGER_CONTEXT_MODEL,
    ]);
  });

  it('can switch to a longer model, retry, and succeed', async function () {
    // Responses are keyed on the fake's call count: the first call reports a
    // too-long prompt, the second a server error, and later calls succeed.
    fakeResponse = () => {
      if (fakeFetch.callCount === 1) {
        return {
          error: {
            code: "context_length_exceeded",
          },
          status: 400,
        };
      } else if (fakeFetch.callCount === 2) {
        return {
          status: 500,
        };
      } else {
        return {
          "choices": [{
            "index": 0,
            "message": {role: "assistant", content: "123"},
            "finish_reason": "stop"
          }],
          status: 200,
        };
      }
    };
    const result = await checkSendForCompletion();
    checkModels([
      OpenAIAssistant.DEFAULT_MODEL,
      OpenAIAssistant.DEFAULT_LONGER_CONTEXT_MODEL,
      OpenAIAssistant.DEFAULT_LONGER_CONTEXT_MODEL,
    ]);
    assert.deepEqual(result.suggestedActions, [
      ["ModifyColumn", tableId, colId, {formula: "123"}]
    ]);
  });
});