import {createDocTools} from "test/server/docTools"; import {ActiveDoc} from "app/server/lib/ActiveDoc"; import {DEPS, OpenAIAssistant, sendForCompletion} from "app/server/lib/Assistance"; import {assert} from 'chai'; import * as sinon from 'sinon'; import {Response} from 'node-fetch'; import {DocSession} from "app/server/lib/DocSession"; import {AssistanceState} from "app/common/AssistancePrompts"; // For some reason, assert.isRejected is not getting defined, // though test/chai-as-promised.js should be taking care of this. // So test/chai-as-promised.js is just repeated here. const chai = require('chai'); const chaiAsPromised = require('chai-as-promised'); chai.use(chaiAsPromised); describe('Assistance', function () { this.timeout(10000); const docTools = createDocTools({persistAcrossCases: true}); const tableId = "Table1"; let session: DocSession; let doc: ActiveDoc; before(async () => { process.env.OPENAI_API_KEY = "fake"; session = docTools.createFakeSession(); doc = await docTools.createDoc('test.grist'); await doc.applyUserActions(session, [ ["AddTable", tableId, [{id: "A"}, {id: "B"}, {id: "C"}]], ]); }); const colId = "C"; const userMessageContent = "Sum of A and B"; function checkSendForCompletion(state?: AssistanceState) { return sendForCompletion(session, doc, { conversationId: 'conversationId', context: {type: 'formula', tableId, colId}, state, text: userMessageContent, }); } let fakeResponse: () => any; let fakeFetch: sinon.SinonSpy; beforeEach(() => { fakeFetch = sinon.fake(() => { const body = fakeResponse(); return new Response( JSON.stringify(body), {status: body.status}, ); }); sinon.replace(DEPS, 'fetch', fakeFetch as any); sinon.replace(DEPS, 'delayTime', 1); }); afterEach(function () { sinon.restore(); }); function checkModels(expectedModels: string[]) { assert.deepEqual( fakeFetch.getCalls().map(call => JSON.parse(call.args[1].body).model), expectedModels, ); } it('can suggest a formula', async function () { const reply = "Here's a formula that adds columns A and B:\n\n" + "```python\na = int(rec.A)\nb=int(rec.B)\n\nreturn str(a + b)\n```" + "\n\nLet me know if there's anything else I can help with."; const replyMessage = {"role": "assistant", "content": reply}; fakeResponse = () => ({ "choices": [{ "index": 0, "message": replyMessage, "finish_reason": "stop" }], status: 200, }); const result = await checkSendForCompletion(); checkModels([OpenAIAssistant.DEFAULT_MODEL]); const callInfo = fakeFetch.getCall(0); const [url, request] = callInfo.args; assert.equal(url, 'https://api.openai.com/v1/chat/completions'); assert.equal(request.method, 'POST'); const {messages: requestMessages} = JSON.parse(request.body); const systemMessageContent = requestMessages[0].content; assert.match(systemMessageContent, /def C\(rec: Table1\)/); assert.deepEqual(requestMessages, [ { role: "system", content: systemMessageContent, }, { role: "user", content: userMessageContent, } ] ); const suggestedFormula = "a = int($A)\nb=int($B)\n\nstr(a + b)"; const replyWithSuggestedFormula = "Here's a formula that adds columns A and B:\n\n" + "```python\na = int($A)\nb=int($B)\n\nstr(a + b)\n```" + "\n\nLet me know if there's anything else I can help with."; assert.deepEqual(result, { suggestedActions: [ ["ModifyColumn", tableId, colId, {formula: suggestedFormula}] ], suggestedFormula, reply: replyWithSuggestedFormula, state: { messages: [...requestMessages, replyMessage] } } ); }); it('does not suggest anything if formula is invalid', async function () { const reply = "This isn't valid Python code:\n```python\nclass = 'foo'\n```"; const replyMessage = { "role": "assistant", "content": reply, }; fakeResponse = () => ({ "choices": [{ "index": 0, "message": replyMessage, "finish_reason": "stop" }], status: 200, }); const result = await checkSendForCompletion(); const callInfo = fakeFetch.getCall(0); const [, request] = callInfo.args; const {messages: requestMessages} = JSON.parse(request.body); const suggestedFormula = undefined; assert.deepEqual(result, { suggestedActions: [], suggestedFormula, reply, state: { messages: [...requestMessages, replyMessage], }, } ); }); it('tries 3 times in case of network errors', async function () { fakeResponse = () => { throw new Error("Network error"); }; await assert.isRejected( checkSendForCompletion(), "Sorry, the assistant is unavailable right now. " + "Try again in a few minutes. \n" + "(Error: Network error)", ); assert.equal(fakeFetch.callCount, 3); }); it('tries 3 times in case of bad status code', async function () { fakeResponse = () => ({status: 500}); await assert.isRejected( checkSendForCompletion(), "Sorry, the assistant is unavailable right now. " + "Try again in a few minutes. \n" + '(Error: OpenAI API returned status 500: {"status":500})', ); assert.equal(fakeFetch.callCount, 3); }); it('handles exceeded billing quota', async function () { fakeResponse = () => ({ error: { code: "insufficient_quota", }, status: 429, }); await assert.isRejected( checkSendForCompletion(), "Sorry, the assistant is facing some long term capacity issues. " + "Maybe try again tomorrow.", ); assert.equal(fakeFetch.callCount, 1); }); it('switches to a longer model with no retries if the prompt is too long', async function () { fakeResponse = () => ({ error: { code: "context_length_exceeded", }, status: 400, }); await assert.isRejected( checkSendForCompletion(), /You'll need to either shorten your message or delete some columns/ ); checkModels([ OpenAIAssistant.DEFAULT_MODEL, OpenAIAssistant.DEFAULT_LONGER_CONTEXT_MODEL, ]); }); it('switches to a longer model with no retries if the model runs out of tokens while responding', async function () { fakeResponse = () => ({ "choices": [{ "index": 0, "message": {}, "finish_reason": "length" }], status: 200, }); await assert.isRejected( checkSendForCompletion(), /You'll need to either shorten your message or delete some columns/ ); checkModels([ OpenAIAssistant.DEFAULT_MODEL, OpenAIAssistant.DEFAULT_LONGER_CONTEXT_MODEL, ]); }); it('suggests restarting conversation if the prompt is too long and there are past messages', async function () { fakeResponse = () => ({ error: { code: "context_length_exceeded", }, status: 400, }); await assert.isRejected( checkSendForCompletion({ messages: [ {role: "system", content: "Be good."}, {role: "user", content: "Hi."}, {role: "assistant", content: "Hi!"}, ] }), /You'll need to either shorten your message, restart the conversation, or delete some columns/ ); checkModels([ OpenAIAssistant.DEFAULT_MODEL, OpenAIAssistant.DEFAULT_LONGER_CONTEXT_MODEL, ]); }); it('can switch to a longer model, retry, and succeed', async function () { fakeResponse = () => { if (fakeFetch.callCount === 1) { return { error: { code: "context_length_exceeded", }, status: 400, }; } else if (fakeFetch.callCount === 2) { return { status: 500, }; } else { return { "choices": [{ "index": 0, "message": {role: "assistant", content: "123"}, "finish_reason": "stop" }], status: 200, }; } }; const result = await checkSendForCompletion(); checkModels([ OpenAIAssistant.DEFAULT_MODEL, OpenAIAssistant.DEFAULT_LONGER_CONTEXT_MODEL, OpenAIAssistant.DEFAULT_LONGER_CONTEXT_MODEL, ]); assert.deepEqual(result.suggestedActions, [ ["ModifyColumn", tableId, colId, {formula: "123"}] ]); }); });