/**
 * Module with functions used for AI formula assistance.
 */

import {
  AssistanceContext,
  AssistanceMessage,
  AssistanceRequest,
  AssistanceResponse
} from 'app/common/AssistancePrompts';
import {delay} from 'app/common/delay';
import {DocAction} from 'app/common/DocActions';
import {ActiveDoc} from 'app/server/lib/ActiveDoc';
import {getDocSessionUser, OptDocSession} from 'app/server/lib/DocSession';
import log from 'app/server/lib/log';
import fetch from 'node-fetch';
import {createHash} from "crypto";
import {getLogMetaFromDocSession} from "./serverUtils";

// These are mocked/replaced in tests.
// fetch is also replaced in the runCompletion script to add caching.
export const DEPS = { fetch, delayTime: 1000 };

/**
 * An assistant can help a user do things with their document,
 * by interfacing with an external LLM endpoint.
 */
interface Assistant {
  /**
   * Handle one assistance request (one user turn) and produce the response to send back.
   */
  apply(session: OptDocSession, doc: AssistanceDoc, request: AssistanceRequest): Promise<AssistanceResponse>;
}

/**
 * Document-related methods for use in the implementation of assistants.
 * Somewhat ad-hoc currently.
 */
interface AssistanceDoc extends ActiveDoc {
  /**
   * Generate a particular prompt coded in the data engine for some reason.
   * It makes python code for some tables, and starts a function body with
   * the given docstring.
   * Marked "V1" to suggest that it is a particular prompt and it would
   * be great to try variants.
   */
  assistanceSchemaPromptV1(session: OptDocSession, options: AssistanceSchemaPromptV1Context): Promise<string>;

  /**
   * Some tweaks to a formula after it has been generated.
   * An empty result means the text did not look like usable formula code.
   */
  assistanceFormulaTweak(txt: string): Promise<string>;

  /**
   * Compute the existing formula and return the result along with recorded values
   * of (possibly nested) attributes of `rec`.
   * Used by AI assistance to fix an incorrect formula.
   */
  assistanceEvaluateFormula(options: AssistanceContext): Promise<AssistanceFormulaEvaluationResult>;
}

/**
 * Outcome of evaluating a column's current formula, as reported back to the assistant.
 */
export interface AssistanceFormulaEvaluationResult {
  error: boolean;  // true if an exception was raised
  result: string;  // repr of the return value OR exception message

  // Recorded attributes of `rec` at the time of evaluation.
  // Keys may be e.g. "rec.foo.bar" for nested attributes.
  attributes: Record<string, string>;

  formula: string;  // the code that was evaluated, without special grist syntax
}

/**
 * Options for `AssistanceDoc.assistanceSchemaPromptV1`: the target column and the
 * text to place in the docstring of the method to be completed.
 */
export interface AssistanceSchemaPromptV1Context {
  tableId: string,
  colId: string,
  docString: string,
}

/**
 * Thrown internally when the default model's context is exceeded, to request a retry
 * with the longer-context model.
 */
class SwitchToLongerContext extends Error {
}

/**
 * Base class for errors that are surfaced to the user immediately, without retrying.
 */
class NonRetryableError extends Error {
}

class TokensExceededFirstMessage extends NonRetryableError {
  constructor() {
    super(
      "Sorry, there's too much information for the AI to process. " +
      "You'll need to either shorten your message or delete some columns."
    );
  }
}

class TokensExceededLaterMessage extends NonRetryableError {
  constructor() {
    super(
      "Sorry, there's too much information for the AI to process. " +
      "You'll need to either shorten your message, restart the conversation, or delete some columns."
    );
  }
}

class QuotaExceededError extends NonRetryableError {
  constructor() {
    super(
      "Sorry, the assistant is facing some long term capacity issues. " +
      "Maybe try again tomorrow."
    );
  }
}

/**
 * Wraps the last error seen after all retry attempts have failed, with a user-facing prefix.
 */
class RetryableError extends Error {
  constructor(message: string) {
    super(
      "Sorry, the assistant is unavailable right now. " +
      "Try again in a few minutes. \n" +
      `(${message})`
    );
  }
}

/**
 * A flavor of assistant for use with the OpenAI API.
 * Falls back to LONGER_CONTEXT_MODEL when the default model's context length is exceeded.
 * Tested primarily with gpt-3.5-turbo.
*/
export class OpenAIAssistant implements Assistant {
  public static DEFAULT_MODEL = "gpt-3.5-turbo-0613";
  public static LONGER_CONTEXT_MODEL = "gpt-3.5-turbo-16k-0613";

  private readonly _apiKey: string;
  private readonly _endpoint: string;

  public constructor() {
    const apiKey = process.env.OPENAI_API_KEY;
    if (!apiKey) {
      throw new Error('OPENAI_API_KEY not set');
    }
    this._apiKey = apiKey;
    this._endpoint = `https://api.openai.com/v1/chat/completions`;
  }

  /**
   * Sends the conversation so far, plus the user's new message, to OpenAI and builds a response.
   *
   * On the first turn (no messages in `request.state`), a system prompt containing the schema
   * prompt is prepended. If `request.context.evaluateCurrentFormula` is set, the current formula
   * is evaluated and the outcome is added as an extra system message.
   *
   * Note: when `request.state.messages` is present, that array is extended in place and is
   * returned as `response.state.messages`.
   */
  public async apply(
    optSession: OptDocSession, doc: AssistanceDoc, request: AssistanceRequest): Promise<AssistanceResponse> {
    const messages = request.state?.messages || [];
    const newMessages: AssistanceMessage[] = [];
    if (messages.length === 0) {
      newMessages.push({
        role: 'system',
        content: 'You are a helpful assistant for a user of software called Grist. ' +
          'Below are one or more Python classes. ' +
          'The last method needs completing. ' +
          "The user will probably give a description of what they want the method (a 'formula') to return. " +
          'If so, your response should include the method body as Python code in a markdown block. ' +
          'Do not include the class or method signature, just the method body. ' +
          'If your code starts with `class`, `@dataclass`, or `def` it will fail. Only give the method body. ' +
          'You can import modules inside the method body if needed. ' +
          'You cannot define additional functions or methods. ' +
          'The method should be a pure function that performs some computation and returns a result. ' +
          'It CANNOT perform any side effects such as adding/removing/modifying rows/columns/cells/tables/etc. ' +
          'It CANNOT interact with files/databases/networks/etc. ' +
          'It CANNOT display images/charts/graphs/maps/etc. ' +
          'If the user asks for these things, tell them that you cannot help. ' +
          'The method uses `rec` instead of `self` as the first parameter.\n\n' +
          '```python\n' +
          await makeSchemaPromptV1(optSession, doc, request) +
          '\n```',
      });
    }
    if (request.context.evaluateCurrentFormula) {
      const result = await doc.assistanceEvaluateFormula(request.context);
      let message = "Evaluating this code:\n\n```python\n" + result.formula + "\n```\n\n";
      if (Object.keys(result.attributes).length > 0) {
        const attributes = Object.entries(result.attributes).map(([k, v]) => `${k} = ${v}`).join('\n');
        message += `where:\n\n${attributes}\n\n`;
      }
      message += `${result.error ? 'raises an exception' : 'returns'}: ${result.result}`;
      newMessages.push({
        role: 'system',
        content: message,
      });
    }
    newMessages.push({
      role: 'user', content: request.text,
    });
    messages.push(...newMessages);

    // Log each newly added message with its position in the full conversation.
    const newMessagesStartIndex = messages.length - newMessages.length;
    for (const [index, {role, content}] of newMessages.entries()) {
      doc.logTelemetryEvent(optSession, 'assistantSend', {
        full: {
          conversationId: request.conversationId,
          context: request.context,
          prompt: {
            index: newMessagesStartIndex + index,
            role,
            content,
          },
        },
      });
    }

    const userIdHash = getUserHash(optSession);
    // _getCompletion also appends the assistant's reply to `messages`.
    const completion: string = await this._getCompletion(messages, userIdHash);
    const response = await completionToResponse(doc, request, completion);
    if (response.suggestedFormula) {
      // Show the tweaked version of the suggested formula to the user (i.e. the one that's
      // copied when the Apply button is clicked).
      response.reply = replaceMarkdownCode(completion, response.suggestedFormula);
    } else {
      response.reply = completion;
    }

    response.state = {messages};

    doc.logTelemetryEvent(optSession, 'assistantReceive', {
      full: {
        conversationId: request.conversationId,
        context: request.context,
        message: {
          index: messages.length - 1,
          content: completion,
        },
        suggestedFormula: response.suggestedFormula,
      },
    });
    return response;
  }

  /**
   * Makes a single chat-completions request and returns the first choice's message.
   *
   * @param longerContext - when true, uses LONGER_CONTEXT_MODEL instead of DEFAULT_MODEL.
   * @throws SwitchToLongerContext if the context was exceeded with the default model.
   * @throws TokensExceededFirstMessage / TokensExceededLaterMessage if the context was exceeded
   *   even with the longer-context model.
   * @throws QuotaExceededError if the API reports `insufficient_quota`.
   * @throws Error for any other failure (non-200 status, non-JSON body, no choices); the
   *   caller treats these as retryable.
   */
  private async _fetchCompletion(
    messages: AssistanceMessage[], userIdHash: string, longerContext: boolean
  ): Promise<AssistanceMessage> {
    const apiResponse = await DEPS.fetch(
      this._endpoint,
      {
        method: "POST",
        headers: {
          "Authorization": `Bearer ${this._apiKey}`,
          "Content-Type": "application/json",
        },
        body: JSON.stringify({
          messages,
          temperature: 0,
          model: longerContext ? OpenAIAssistant.LONGER_CONTEXT_MODEL : OpenAIAssistant.DEFAULT_MODEL,
          user: userIdHash,
        }),
      },
    );
    const resultText = await apiResponse.text();
    let result: ChatCompletionBody;
    try {
      result = JSON.parse(resultText);
    } catch (e) {
      // Keep the status and raw body so the failure is diagnosable (a bare SyntaxError is not).
      throw new Error(`OpenAI API returned status ${apiResponse.status} with a non-JSON body: ${resultText}`);
    }
    const errorCode = result.error?.code;
    // Optional-chain the element too: `choices` may be present but empty.
    if (errorCode === "context_length_exceeded" || result.choices?.[0]?.finish_reason === "length") {
      if (!longerContext) {
        log.info("Switching to longer context model...");
        throw new SwitchToLongerContext();
      } else if (messages.filter(m => m.role === 'user').length <= 1) {
        // Count user turns rather than all messages: the first turn may carry one or two
        // system messages (schema prompt, formula evaluation) before the user's message.
        throw new TokensExceededFirstMessage();
      } else {
        throw new TokensExceededLaterMessage();
      }
    }
    if (errorCode === "insufficient_quota") {
      log.error("OpenAI billing quota exceeded!!!");
      throw new QuotaExceededError();
    }
    if (apiResponse.status !== 200) {
      throw new Error(`OpenAI API returned status ${apiResponse.status}: ${resultText}`);
    }
    const message = result.choices?.[0]?.message;
    if (!message) {
      throw new Error(`OpenAI API returned no completion: ${resultText}`);
    }
    return message;
  }

  /**
   * Calls _fetchCompletion up to 3 times, waiting DEPS.delayTime ms between attempts.
   * Switches to the longer-context model (with a fresh set of attempts) when asked to,
   * rethrows NonRetryableErrors immediately, and wraps the last failure in a RetryableError.
   */
  private async _fetchCompletionWithRetries(
    messages: AssistanceMessage[], userIdHash: string, longerContext: boolean
  ): Promise<AssistanceMessage> {
    const maxAttempts = 3;
    for (let attempt = 1; ; attempt++) {
      try {
        return await this._fetchCompletion(messages, userIdHash, longerContext);
      } catch (e) {
        if (e instanceof SwitchToLongerContext) {
          return await this._fetchCompletionWithRetries(messages, userIdHash, true);
        } else if (e instanceof NonRetryableError) {
          throw e;
        } else if (attempt === maxAttempts) {
          throw new RetryableError(String(e));
        }
        log.warn(`Waiting and then retrying after error: ${e}`);
        await delay(DEPS.delayTime);
      }
    }
  }

  /**
   * Fetches the assistant's reply, appends it to `messages`, and returns its text.
   */
  private async _getCompletion(messages: AssistanceMessage[], userIdHash: string): Promise<string> {
    const message = await this._fetchCompletionWithRetries(messages, userIdHash, false);
    messages.push(message);
    return message.content;
  }
}

/**
 * The parts of an OpenAI chat-completions response body that are read above.
 * The body is parsed from JSON without validation, so every field is optional.
 */
interface ChatCompletionBody {
  error?: {code?: string};
  choices?: Array<{finish_reason?: string, message?: AssistanceMessage}>;
}

/**
 * A flavor of assistant for use with a HuggingFace inference endpoint.
 * Stateless: each request is a single completion of the schema prompt.
 * Currently not returned by getAssistant().
 */
export class HuggingFaceAssistant implements Assistant {
  private readonly _apiKey: string;
  private readonly _completionUrl: string;

  public constructor() {
    const apiKey = process.env.HUGGINGFACE_API_KEY;
    if (!apiKey) {
      throw new Error('HUGGINGFACE_API_KEY not set');
    }
    this._apiKey = apiKey;
    // COMPLETION_MODEL values I've tried:
    //   - codeparrot/codeparrot
    //   - NinedayWang/PolyCoder-2.7B
    //   - NovelAI/genji-python-6B
    let completionUrl = process.env.COMPLETION_URL;
    if (!completionUrl) {
      if (process.env.COMPLETION_MODEL) {
        completionUrl = `https://api-inference.huggingface.co/models/${process.env.COMPLETION_MODEL}`;
      } else {
        completionUrl = 'https://api-inference.huggingface.co/models/NovelAI/genji-python-6B';
      }
    }
    this._completionUrl = completionUrl;
  }

  /**
   * Completes the schema prompt for the requested column. Conversation state is not supported.
   * A 503 is treated as transient: after a 10s pause the request is sent once more.
   */
  public async apply(
    optSession: OptDocSession, doc: AssistanceDoc, request: AssistanceRequest): Promise<AssistanceResponse> {
    if (request.state) {
      throw new Error("HuggingFaceAssistant does not support state");
    }
    const prompt = await makeSchemaPromptV1(optSession, doc, request);
    let response = await this._post(prompt);
    if (response.status === 503) {
      log.error(`Sleeping for 10s - HuggingFace API returned ${response.status}: ${await response.text()}`);
      await delay(10000);
      // The first response's body has been consumed above, so a fresh response is required.
      response = await this._post(prompt);
    }
    if (response.status !== 200) {
      const text = await response.text();
      log.error(`HuggingFace API returned ${response.status}: ${text}`);
      throw new Error(`HuggingFace API returned status ${response.status}: ${text}`);
    }
    const result = await response.json();
    let completion = result[0].generated_text;
    // Keep only the first paragraph of generated text.
    completion = completion.split('\n\n')[0];
    return completionToResponse(doc, request, completion);
  }

  /** Sends `prompt` to the completion endpoint. */
  private _post(prompt: string) {
    return DEPS.fetch(
      this._completionUrl,
      {
        method: "POST",
        headers: {
          "Authorization": `Bearer ${this._apiKey}`,
          "Content-Type": "application/json",
        },
        body: JSON.stringify({
          inputs: prompt,
          parameters: {
            return_full_text: false,
            max_new_tokens: 50,
          },
        }),
      },
    );
  }
}

/**
 * Test assistant that mimics ChatGPT and just returns the input.
 */
class EchoAssistant implements Assistant {
  public async apply(sess: OptDocSession, doc: AssistanceDoc, request: AssistanceRequest): Promise<AssistanceResponse> {
    if (request.text === "ERROR") {
      throw new Error(`ERROR`);
    }
    const messages = request.state?.messages || [];
    if (messages.length === 0) {
      messages.push({
        role: 'system',
        content: ''
      });
    }
    messages.push({
      role: 'user', content: request.text,
    });
    const completion = request.text;
    const history = { messages };
    history.messages.push({
      role: 'assistant',
      content: completion,
    });
    const response = await completionToResponse(doc, request, completion, completion);
    response.state = history;
    return response;
  }
}

/**
 * Instantiate an assistant, based on environment variables.
 * OPENAI_API_KEY=test selects the EchoAssistant; any other non-empty value selects OpenAI.
 */
export function getAssistant() {
  if (process.env.OPENAI_API_KEY === 'test') {
    return new EchoAssistant();
  }
  if (process.env.OPENAI_API_KEY) {
    return new OpenAIAssistant();
  }
  // Maintaining this is too much of a burden for now.
  // if (process.env.HUGGINGFACE_API_KEY) {
  //   return new HuggingFaceAssistant();
  // }
  throw new Error('Please set OPENAI_API_KEY');
}

/**
 * Service a request for assistance.
 */
export async function sendForCompletion(
  optSession: OptDocSession,
  doc: AssistanceDoc,
  request: AssistanceRequest,
): Promise<AssistanceResponse> {
  const assistant = getAssistant();
  return await assistant.apply(optSession, doc, request);
}

/**
 * Returns a new Markdown string with the contents of its first multi-line code block
 * replaced with `replaceValue`, and that block's language tag set to `python`.
 * Later code blocks and surrounding text are left intact. Returns `markdown` unchanged
 * if it contains no such block.
 */
export function replaceMarkdownCode(markdown: string, replaceValue: string) {
  // The lazy `.*?` stops at the first closing fence; a greedy match would extend to the
  // last fence in the string and swallow any text and blocks in between.
  // A replacer function inserts `replaceValue` literally; as a plain string, sequences
  // like `$'` or `$&` in the formula would be expanded as replacement patterns.
  return markdown.replace(/```\w*\n(.*?)```/s, () => '```python\n' + replaceValue + '\n```');
}

/**
 * Builds the schema prompt for the request's formula column, using the user's text
 * as the docstring. Only formula contexts are supported.
 */
async function makeSchemaPromptV1(
  session: OptDocSession, doc: AssistanceDoc, request: AssistanceRequest
): Promise<string> {
  if (request.context.type !== 'formula') {
    throw new Error('makeSchemaPromptV1 only works for formulas');
  }
  return doc.assistanceSchemaPromptV1(session, {
    tableId: request.context.tableId,
    colId: request.context.colId,
    docString: request.text,
  });
}

/**
 * Turns raw completion text into an AssistanceResponse for a formula column.
 * `reply`, if given, is passed through as the response's reply.
 */
async function completionToResponse(
  doc: AssistanceDoc, request: AssistanceRequest, completion: string, reply?: string
): Promise<AssistanceResponse> {
  if (request.context.type !== 'formula') {
    throw new Error('completionToResponse only works for formulas');
  }
  const suggestedFormula = await doc.assistanceFormulaTweak(completion) || undefined;
  // Suggest an action only if the completion is non-empty (that is,
  // it actually looked like code).
  const suggestedActions: DocAction[] = suggestedFormula ? [[
    "ModifyColumn",
    request.context.tableId,
    request.context.colId, {
      formula: suggestedFormula,
    }
  ]] : [];
  return {
    suggestedActions,
    suggestedFormula,
    reply,
  };
}

/**
 * Returns a salted SHA-256 hash (base64) of the session user's id and ref, and logs the
 * mapping so that a hash reported back to us can be traced to the user.
 */
function getUserHash(session: OptDocSession): string {
  const user = getDocSessionUser(session);
  // Make it a bit harder to guess the user ID.
  const salt = "7a8sb6987asdb678asd687sad6boas7f8b6aso7fd";
  const hashSource = `${user?.id} ${user?.ref} ${salt}`;
  const hash = createHash('sha256').update(hashSource).digest('base64');
  // So that if we get feedback about a user ID hash, we can
  // search for the hash in the logs to find the original user ID.
  log.rawInfo("getUserHash", {...getLogMetaFromDocSession(session), userRef: user?.ref, hash});
  return hash;
}