diff --git a/app/client/models/entities/ColumnRec.ts b/app/client/models/entities/ColumnRec.ts
index 5f5f9fcc..245b1193 100644
--- a/app/client/models/entities/ColumnRec.ts
+++ b/app/client/models/entities/ColumnRec.ts
@@ -4,6 +4,7 @@ import {CellRec, DocModel, IRowModel, recordSet, refRecord, TableRec, ViewFieldRec} from 'app/client/models/DocModel';
 import {urlState} from 'app/client/models/gristUrlState';
 import {jsonObservable, ObjObservable} from 'app/client/models/modelUtil';
+import {AssistanceState} from 'app/common/AssistancePrompts';
 import * as gristTypes from 'app/common/gristTypes';
 import {getReferencedTableId} from 'app/common/gristTypes';
 import {
@@ -83,7 +84,7 @@ export interface ColumnRec extends IRowModel<"_grist_Tables_column"> {
   /**
    * Current history of chat. This is a temporary array used only in the ui.
    */
-  chatHistory: ko.PureComputed<Observable<ChatMessage[]>>;
+  chatHistory: ko.PureComputed<Observable<ChatHistory>>;

   // Helper which adds/removes/updates column's displayCol to match the formula.
   saveDisplayFormula(formula: string): Promise<void>|undefined;
@@ -162,8 +163,9 @@ export function createColumnRec(this: ColumnRec, docModel: DocModel): void {

   this.chatHistory = this.autoDispose(ko.computed(() => {
     const docId = urlState().state.get().doc ?? '';
-    const key = `formula-assistant-history-${docId}-${this.table().tableId()}-${this.colId()}`;
-    return localStorageJsonObs(key, [] as ChatMessage[]);
+    // Changed key name from history to history-v2 when ChatHistory changed in incompatible way.
+    const key = `formula-assistant-history-v2-${docId}-${this.table().tableId()}-${this.colId()}`;
+    return localStorageJsonObs(key, {messages: []} as ChatHistory);
   }));
 }

@@ -196,8 +198,20 @@ export interface ChatMessage {
    */
   sender: 'user' | 'ai';
   /**
-   * The formula returned from the AI. It is only set when the sender is the AI. For now it is the same
-   * value as the message, but it might change in the future when we use more conversational AI.
+   * The formula returned from the AI. It is only set when the sender is the AI.
    */
   formula?: string;
 }
+
+/**
+ * The state of assistance for a particular column.
+ * ChatMessages are what are shown in the UI, whereas state is
+ * how the back-end represents the conversation. The two are
+ * similar but not the same because of post-processing.
+ * It may be possible to reconcile them when things settle down
+ * a bit?
+ */
+export interface ChatHistory {
+  messages: ChatMessage[];
+  state?: AssistanceState;
+}
diff --git a/app/client/ui/FormulaAssistance.ts b/app/client/ui/FormulaAssistance.ts
index 1d3efe32..2f5b45e3 100644
--- a/app/client/ui/FormulaAssistance.ts
+++ b/app/client/ui/FormulaAssistance.ts
@@ -8,7 +8,7 @@ import {basicButton, primaryButton, textButton} from 'app/client/ui2018/buttons';
 import {theme} from 'app/client/ui2018/cssVars';
 import {cssTextInput, rawTextInput} from 'app/client/ui2018/editableLabel';
 import {icon} from 'app/client/ui2018/icons';
-import {Suggestion} from 'app/common/AssistancePrompts';
+import {AssistanceResponse, AssistanceState} from 'app/common/AssistancePrompts';
 import {Disposable, dom, makeTestId, MultiHolder, obsArray, Observable, styled} from 'grainjs';
 import noop from 'lodash/noop';
@@ -78,7 +78,7 @@ function buildControls(
   }
 ) {
-  const hasHistory = props.column.chatHistory.peek().get().length > 0;
+  const hasHistory = props.column.chatHistory.peek().get().messages.length > 0;

   // State variables, to show various parts of the UI.
 const saveButtonVisible = Observable.create(owner, true);
@@ -153,36 +153,46 @@ function buildControls(
   };
 }

-function buildChat(owner: Disposable, context: Context & { formulaClicked: (formula: string) => void }) {
+function buildChat(owner: Disposable, context: Context & { formulaClicked: (formula?: string) => void }) {
   const { grist, column } = context;
-  const history = owner.autoDispose(obsArray(column.chatHistory.peek().get()));
+  const history = owner.autoDispose(obsArray(column.chatHistory.peek().get().messages));
   const hasHistory = history.get().length > 0;
   const enabled = Observable.create(owner, hasHistory);
   const introVisible = Observable.create(owner, !hasHistory);
   owner.autoDispose(history.addListener((cur) => {
-    column.chatHistory.peek().set([...cur]);
+    const chatHistory = column.chatHistory.peek();
+    chatHistory.set({...chatHistory.get(), messages: [...cur]});
   }));

-  const submit = async () => {
-    // Ask about suggestion, and send the whole history. Currently the chat is implemented by just sending
-    // all previous user prompts back to the AI. This is subject to change (and probably should be done in the backend).
-    const prompt = history.get().filter(x => x.sender === 'user')
-      .map(entry => entry.message)
-      .filter(Boolean)
-      .join("\n");
-    console.debug('prompt', prompt);
-    const { suggestedActions } = await askAI(grist, column, prompt);
-    console.debug('suggestedActions', suggestedActions);
+  const submit = async (regenerate: boolean = false) => {
+    // Send most recent question, and send back any conversation
+    // state we have been asked to track.
+    const chatHistory = column.chatHistory.peek().get();
+    const messages = chatHistory.messages.filter(msg => msg.sender === 'user');
+    const description = messages[messages.length - 1]?.message || '';
+    console.debug('description', {description});
+    const {reply, suggestedActions, state} = await askAI(grist, {
+      column, description, state: chatHistory.state,
+      regenerate,
+    });
+    console.debug('suggestedActions', {suggestedActions, reply});
     const firstAction = suggestedActions[0] as any;
     // Add the formula to the history.
-    const formula = firstAction[3].formula as string;
+    const formula = firstAction ? firstAction[3].formula as string : undefined;
     // Add to history
     history.push({
-      message: formula,
+      message: formula || reply || '(no reply)',
      sender: 'ai',
      formula
     });
+    // If back-end is capable of conversation, keep its state.
+    if (state) {
+      const chatHistoryNew = column.chatHistory.peek();
+      const value = chatHistoryNew.get();
+      value.state = state;
+      chatHistoryNew.set(value);
+    }
     return formula;
   };
@@ -203,12 +213,13 @@ function buildChat(owner: Disposable, context: Context & { formulaClicked: (form
     // Remove the last AI response from the history.
     history.pop();
     // And submit again.
-    context.formulaClicked(await submit());
+    context.formulaClicked(await submit(true));
   };

   const newChat = () => {
     // Clear the history.
     history.set([]);
+    column.chatHistory.peek().set({messages: []});
     // Show intro.
     introVisible.set(true);
   };
@@ -371,9 +382,11 @@ function openAIAssistant(grist: GristDoc, column: ColumnRec) {
   const chat = buildChat(owner, {...props,
     // When a formula is clicked (or just was returned from the AI), we set it in the formula editor and hit
     // the preview button.
-    formulaClicked: (formula: string) => {
-      formulaEditor.set(formula);
-      controls.preview().catch(reportError);
+    formulaClicked: (formula?: string) => {
+      if (formula) {
+        formulaEditor.set(formula);
+        controls.preview().catch(reportError);
+      }
     },
   });
@@ -397,11 +410,22 @@ function openAIAssistant(grist: GristDoc, column: ColumnRec) {
   grist.formulaPopup.autoDispose(popup);
 }

-async function askAI(grist: GristDoc, column: ColumnRec, description: string): Promise<Suggestion> {
+async function askAI(grist: GristDoc, options: {
+  column: ColumnRec,
+  description: string,
+  regenerate?: boolean,
+  state?: AssistanceState
+}): Promise<AssistanceResponse> {
+  const {column, description, state, regenerate} = options;
   const tableId = column.table.peek().tableId.peek();
   const colId = column.colId.peek();
   try {
-    const result = await grist.docComm.getAssistance({tableId, colId, description});
+    const result = await grist.docComm.getAssistance({
+      context: {type: 'formula', tableId, colId},
+      text: description,
+      state,
+      regenerate,
+    });
     return result;
   } catch (error) {
     reportError(error);
diff --git a/app/common/ActiveDocAPI.ts b/app/common/ActiveDocAPI.ts
index d2f88ee8..9b333418 100644
--- a/app/common/ActiveDocAPI.ts
+++ b/app/common/ActiveDocAPI.ts
@@ -1,5 +1,5 @@
 import {ActionGroup} from 'app/common/ActionGroup';
-import {Prompt, Suggestion} from 'app/common/AssistancePrompts';
+import {AssistanceRequest, AssistanceResponse} from 'app/common/AssistancePrompts';
 import {BulkAddRecord, CellValue, TableDataAction, UserAction} from 'app/common/DocActions';
 import {FormulaProperties} from 'app/common/GranularAccessClause';
 import {UIRowId} from 'app/common/UIRowId';
@@ -323,7 +323,7 @@ export interface ActiveDocAPI {
   /**
    * Generates a formula code based on the AI suggestions, it also modifies the column and sets it type to a formula.
    */
-  getAssistance(userPrompt: Prompt): Promise<Suggestion>;
+  getAssistance(request: AssistanceRequest): Promise<AssistanceResponse>;

   /**
    * Fetch content at a url.
diff --git a/app/common/AssistancePrompts.ts b/app/common/AssistancePrompts.ts
index c0ee269e..95215060 100644
--- a/app/common/AssistancePrompts.ts
+++ b/app/common/AssistancePrompts.ts
@@ -1,11 +1,56 @@
 import {DocAction} from 'app/common/DocActions';

-export interface Prompt {
+/**
+ * State related to a request for assistance.
+ *
+ * If an AssistanceResponse contains state, that state can be
+ * echoed back in an AssistanceRequest to continue a "conversation."
+ *
+ * Ideally, the state should not be modified or relied upon
+ * by the client, so as not to commit too hard to a particular
+ * model at this time (it is a bit early for that).
+ */
+export interface AssistanceState {
+  messages?: Array<{
+    role: string;
+    content: string;
+  }>;
+}
+
+/**
+ * Currently, requests for assistance always happen in the context
+ * of the column of a particular table.
+ */
+export interface FormulaAssistanceContext {
+  type: 'formula';
   tableId: string;
-  colId: string
-  description: string;
+  colId: string;
+}
+
+export type AssistanceContext = FormulaAssistanceContext;
+
+/**
+ * A request for assistance.
+ */
+export interface AssistanceRequest {
+  context: AssistanceContext;
+  state?: AssistanceState;
+  text: string;
+  regenerate?: boolean;  // Set if there was a previous request
+                         // and response that should be omitted
+                         // from history, or (if available) an
+                         // alternative response generated.
 }

-export interface Suggestion {
+/**
+ * A response to a request for assistance.
+ * The client should preserve the state and include it in
+ * any follow-up requests.
+ */
+export interface AssistanceResponse {
   suggestedActions: DocAction[];
+  state?: AssistanceState;
+  // If the model can be trusted to issue a self-contained
+  // markdown-friendly string, it can be included here.
+  reply?: string;
 }
diff --git a/app/server/lib/ActiveDoc.ts b/app/server/lib/ActiveDoc.ts
index 4c69c762..4037fbb6 100644
--- a/app/server/lib/ActiveDoc.ts
+++ b/app/server/lib/ActiveDoc.ts
@@ -14,7 +14,7 @@
 } from 'app/common/ActionBundle';
 import {ActionGroup, MinimalActionGroup} from 'app/common/ActionGroup';
 import {ActionSummary} from "app/common/ActionSummary";
-import {Prompt, Suggestion} from "app/common/AssistancePrompts";
+import {AssistanceRequest, AssistanceResponse} from "app/common/AssistancePrompts";
 import {
   AclResources,
   AclTableDescription,
@@ -85,7 +85,7 @@ import {Document} from 'app/gen-server/entity/Document';
 import {ParseOptions} from 'app/plugin/FileParserAPI';
 import {AccessTokenOptions, AccessTokenResult, GristDocAPI} from 'app/plugin/GristAPI';
 import {compileAclFormula} from 'app/server/lib/ACLFormula';
-import {sendForCompletion} from 'app/server/lib/Assistance';
+import {AssistanceDoc, AssistanceSchemaPromptV1Context, sendForCompletion} from 'app/server/lib/Assistance';
 import {Authorizer} from 'app/server/lib/Authorizer';
 import {checksumFile} from 'app/server/lib/checksumFile';
 import {Client} from 'app/server/lib/Client';
@@ -180,7 +180,7 @@ interface UpdateUsageOptions {
  * either .loadDoc() or .createEmptyDoc() is called.
  * @param {String} docName - The document's filename, without the '.grist' extension.
  */
-export class ActiveDoc extends EventEmitter {
+export class ActiveDoc extends EventEmitter implements AssistanceDoc {
   /**
    * Decorator for ActiveDoc methods that prevents shutdown while the method is running, i.e.
    * until the returned promise is resolved.
@@ -1112,7 +1112,7 @@
    * @param {Integer} rowId - Row number
    * @returns {Promise} Promise for a error message
    */
-  public async getFormulaError(docSession: DocSession, tableId: string, colId: string,
+  public async getFormulaError(docSession: OptDocSession, tableId: string, colId: string,
                                rowId: number): Promise<CellValue> {
     // Throw an error if the user doesn't have access to read this cell.
     await this._granularAccess.getCellValue(docSession, {tableId, colId, rowId});
@@ -1260,22 +1260,28 @@
     return this._pyCall('autocomplete', txt, tableId, columnId, rowId, user.toJSON());
   }

-  public async getAssistance(docSession: DocSession, userPrompt: Prompt): Promise<Suggestion> {
-    // Making a prompt can leak names of tables and columns.
+  public async getAssistance(docSession: DocSession, request: AssistanceRequest): Promise<AssistanceResponse> {
+    return this.getAssistanceWithOptions(docSession, request);
+  }
+
+  public async getAssistanceWithOptions(docSession: DocSession,
+                                        request: AssistanceRequest): Promise<AssistanceResponse> {
+    // Making a prompt leaks names of tables and columns etc.
     if (!await this._granularAccess.canScanData(docSession)) {
       throw new Error("Permission denied");
     }
     await this.waitForInitialization();
-    const { tableId, colId, description } = userPrompt;
-    const prompt = await this._pyCall('get_formula_prompt', tableId, colId, description);
-    this._log.debug(docSession, 'getAssistance prompt', {prompt});
-    const completion = await sendForCompletion(prompt);
-    this._log.debug(docSession, 'getAssistance completion', {completion});
-    const formula = await this._pyCall('convert_formula_completion', completion);
-    const action: DocAction = ["ModifyColumn", tableId, colId, {formula}];
-    return {
-      suggestedActions: [action],
-    };
+    return sendForCompletion(this, request);
+  }
+
+  // Callback to make a data-engine formula tweak for assistance.
+  public assistanceFormulaTweak(txt: string) {
+    return this._pyCall('convert_formula_completion', txt);
+  }
+
+  // Callback to generate a prompt containing schema info for assistance.
+  public assistanceSchemaPromptV1(options: AssistanceSchemaPromptV1Context): Promise<string> {
+    return this._pyCall('get_formula_prompt', options.tableId, options.colId, options.docString);
   }

   public fetchURL(docSession: DocSession, url: string, options?: FetchUrlOptions): Promise<UploadResult> {
diff --git a/app/server/lib/Assistance.ts b/app/server/lib/Assistance.ts
index 6a3e3af7..9d55dd7f 100644
--- a/app/server/lib/Assistance.ts
+++ b/app/server/lib/Assistance.ts
@@ -2,116 +2,320 @@
 /**
  * Module with functions used for AI formula assistance.
  */

+import {AssistanceRequest, AssistanceResponse} from 'app/common/AssistancePrompts';
 import {delay} from 'app/common/delay';
+import {DocAction} from 'app/common/DocActions';
 import log from 'app/server/lib/log';
 import fetch from 'node-fetch';

 export const DEPS = { fetch };

-export async function sendForCompletion(prompt: string): Promise<string> {
-  let completion: string|null = null;
-  let retries: number = 0;
-  const openApiKey = process.env.OPENAI_API_KEY;
-  const model = process.env.COMPLETION_MODEL || "text-davinci-002";
+/**
+ * An assistant can help a user do things with their document,
+ * by interfacing with an external LLM endpoint.
+ */
+export interface Assistant {
+  apply(doc: AssistanceDoc, request: AssistanceRequest): Promise<AssistanceResponse>;
+}

-  while(retries++ < 3) {
-    try {
-      if (openApiKey) {
-        completion = await sendForCompletionOpenAI(prompt, openApiKey, model);
+/**
+ * Document-related methods for use in the implementation of assistants.
+ * Somewhat ad-hoc currently.
+ */
+export interface AssistanceDoc {
+  /**
+   * Generate a particular prompt coded in the data engine for some reason.
+   * It makes python code for some tables, and starts a function body with
+   * the given docstring.
+   * Marked "V1" to suggest that it is a particular prompt and it would
+   * be great to try variants.
+   */
+  assistanceSchemaPromptV1(options: AssistanceSchemaPromptV1Context): Promise<string>;
+
+  /**
+   * Some tweaks to a formula after it has been generated.
+   */
+  assistanceFormulaTweak(txt: string): Promise<string>;
+}
+
+export interface AssistanceSchemaPromptV1Context {
+  tableId: string,
+  colId: string,
+  docString: string,
+}
+
+/**
+ * A flavor of assistant for use with the OpenAI API.
+ * Tested primarily with text-davinci-002 and gpt-3.5-turbo.
+ */
+export class OpenAIAssistant implements Assistant {
+  private _apiKey: string;
+  private _model: string;
+  private _chatMode: boolean;
+  private _endpoint: string;
+
+  public constructor() {
+    const apiKey = process.env.OPENAI_API_KEY;
+    if (!apiKey) {
+      throw new Error('OPENAI_API_KEY not set');
+    }
+    this._apiKey = apiKey;
+    this._model = process.env.COMPLETION_MODEL || "text-davinci-002";
+    this._chatMode = this._model.includes('turbo');
+    this._endpoint = `https://api.openai.com/v1/${this._chatMode ? 'chat/' : ''}completions`;
+  }
+
+  public async apply(doc: AssistanceDoc, request: AssistanceRequest): Promise<AssistanceResponse> {
+    const messages = request.state?.messages || [];
+    const chatMode = this._chatMode;
+    if (chatMode) {
+      if (messages.length === 0) {
+        messages.push({
+          role: 'system',
+          content: 'The user gives you one or more Python classes, ' +
+            'with one last method that needs completing. Write the ' +
+            'method body as a single code block, ' +
+            'including the docstring the user gave. ' +
+            'Just give the Python code as a markdown block, ' +
+            'do not give any introduction, that will just be ' +
+            'awkward for the user when copying and pasting. ' +
+            'You are working with Grist, an environment very like ' +
+            'regular Python except `rec` (like record) is used ' +
+            'instead of `self`. ' +
+            'Include at least one `return` statement or the method ' +
+            'will fail, disappointing the user. ' +
+            'Your answer should be the body of a single method, ' +
+            'not a class, and should not include `dataclass` or ' +
+            '`class` since the user is counting on you to provide ' +
+            'a single method. Thanks!'
+        });
+        messages.push({
+          role: 'user', content: await makeSchemaPromptV1(doc, request),
+        });
+      } else {
+        if (request.regenerate) {
+          if (messages[messages.length - 1].role !== 'user') {
+            messages.pop();
+          }
+        }
+        messages.push({
+          role: 'user', content: request.text,
+        });
      }
-      if (process.env.HUGGINGFACE_API_KEY) {
-        completion = await sendForCompletionHuggingFace(prompt);
+    } else {
+      messages.length = 0;
+      messages.push({
+        role: 'user', content: await makeSchemaPromptV1(doc, request),
+      });
+    }
+
+    const apiResponse = await DEPS.fetch(
+      this._endpoint,
+      {
+        method: "POST",
+        headers: {
+          "Authorization": `Bearer ${this._apiKey}`,
+          "Content-Type": "application/json",
+        },
+        body: JSON.stringify({
+          ...(!this._chatMode ? {
+            prompt: messages[messages.length - 1].content,
+          } : { messages }),
+          max_tokens: 1500,
+          temperature: 0,
+          model: this._model,
+          stop: this._chatMode ? undefined : ["\n\n"],
+        }),
+      },
+    );
+    if (apiResponse.status !== 200) {
+      log.error(`OpenAI API returned ${apiResponse.status}: ${await apiResponse.text()}`);
+      throw new Error(`OpenAI API returned status ${apiResponse.status}`);
+    }
+    const result = await apiResponse.json();
+    let completion: string = String(chatMode ? result.choices[0].message.content : result.choices[0].text);
+    const reply = completion;
+    const history = { messages };
+    if (chatMode) {
+      history.messages.push(result.choices[0].message);
+      // This model likes returning markdown. Code will typically
+      // be in a code block with ``` delimiters.
+      let lines = completion.split('\n');
+      if (lines[0].startsWith('```')) {
+        lines.shift();
+        completion = lines.join('\n');
+        const parts = completion.split('```');
+        if (parts.length > 1) {
+          completion = parts[0];
+        }
+        lines = completion.split('\n');
+      }
+
+      // This model likes repeating the function signature and
+      // docstring, so we try to strip that out.
+      completion = lines.join('\n');
+      while (completion.includes('"""')) {
+        const parts = completion.split('"""');
+        completion = parts[parts.length - 1];
+      }
+
+      // If there's no code block, don't treat the answer as a formula.
+      if (!reply.includes('```')) {
+        completion = '';
       }
-      break;
-    } catch(e) {
-      await delay(1000);
     }
+
+    const response = await completionToResponse(doc, request, completion, reply);
+    if (chatMode) {
+      response.state = history;
+    }
+    return response;
   }
-  if (completion === null) {
-    throw new Error("Please set OPENAI_API_KEY or HUGGINGFACE_API_KEY (and optionally COMPLETION_MODEL)");
-  }
-  log.debug(`Received completion:`, {completion});
-  completion = completion.split(/\n {4}[^ ]/)[0];
-  return completion;
 }

+export class HuggingFaceAssistant implements Assistant {
+  private _apiKey: string;
+  private _completionUrl: string;
+
+  public constructor() {
+    const apiKey = process.env.HUGGINGFACE_API_KEY;
+    if (!apiKey) {
+      throw new Error('HUGGINGFACE_API_KEY not set');
+    }
+    this._apiKey = apiKey;
+    // COMPLETION_MODEL values I've tried:
+    //   - codeparrot/codeparrot
+    //   - NinedayWang/PolyCoder-2.7B
+    //   - NovelAI/genji-python-6B
+    let completionUrl = process.env.COMPLETION_URL;
+    if (!completionUrl) {
+      if (process.env.COMPLETION_MODEL) {
+        completionUrl = `https://api-inference.huggingface.co/models/${process.env.COMPLETION_MODEL}`;
+      } else {
+        completionUrl = 'https://api-inference.huggingface.co/models/NovelAI/genji-python-6B';
+      }
+    }
+    this._completionUrl = completionUrl;

-async function sendForCompletionOpenAI(prompt: string, apiKey: string, model = "text-davinci-002") {
-  if (!apiKey) {
-    throw new Error("OPENAI_API_KEY not set");
   }
-  const response = await DEPS.fetch(
-    "https://api.openai.com/v1/completions",
-    {
-      method: "POST",
-      headers: {
-        "Authorization": `Bearer ${apiKey}`,
-        "Content-Type": "application/json",
+
+  public async apply(doc: AssistanceDoc, request: AssistanceRequest): Promise<AssistanceResponse> {
+    if (request.state) {
+      throw new Error("HuggingFaceAssistant does not support state");
+    }
+    const prompt = await makeSchemaPromptV1(doc, request);
+    const response = await DEPS.fetch(
+      this._completionUrl,
+      {
+        method: "POST",
+        headers: {
+          "Authorization": `Bearer ${this._apiKey}`,
+          "Content-Type": "application/json",
+        },
+        body: JSON.stringify({
+          inputs: prompt,
+          parameters: {
+            return_full_text: false,
+            max_new_tokens: 50,
+          },
+        }),
+      },
-      body: JSON.stringify({
-        prompt,
-        max_tokens: 150,
-        temperature: 0,
-        // COMPLETION_MODEL of `code-davinci-002` may be better if you have access to it.
-        model,
-        stop: ["\n\n"],
-      }),
-    },
-  );
-  if (response.status !== 200) {
-    log.error(`OpenAI API returned ${response.status}: ${await response.text()}`);
-    throw new Error(`OpenAI API returned status ${response.status}`);
+    );
+    if (response.status === 503) {
+      log.error(`Sleeping for 10s - HuggingFace API returned ${response.status}: ${await response.text()}`);
+      await delay(10000);
+    }
+    if (response.status !== 200) {
+      const text = await response.text();
+      log.error(`HuggingFace API returned ${response.status}: ${text}`);
+      throw new Error(`HuggingFace API returned status ${response.status}: ${text}`);
+    }
+    const result = await response.json();
+    let completion = result[0].generated_text;
+    completion = completion.split('\n\n')[0];
+    return completionToResponse(doc, request, completion);
   }
-  const result = await response.json();
-  const completion = result.choices[0].text;
-  return completion;
 }

-async function sendForCompletionHuggingFace(prompt: string) {
-  const apiKey = process.env.HUGGINGFACE_API_KEY;
-  if (!apiKey) {
-    throw new Error("HUGGINGFACE_API_KEY not set");
+/**
+ * Instantiate an assistant, based on environment variables.
+ */
+function getAssistant() {
+  if (process.env.OPENAI_API_KEY) {
+    return new OpenAIAssistant();
   }
-  // COMPLETION_MODEL values I've tried:
-  //   - codeparrot/codeparrot
-  //   - NinedayWang/PolyCoder-2.7B
-  //   - NovelAI/genji-python-6B
-  let completionUrl = process.env.COMPLETION_URL;
-  if (!completionUrl) {
-    if (process.env.COMPLETION_MODEL) {
-      completionUrl = `https://api-inference.huggingface.co/models/${process.env.COMPLETION_MODEL}`;
-    } else {
-      completionUrl = 'https://api-inference.huggingface.co/models/NovelAI/genji-python-6B';
+  if (process.env.HUGGINGFACE_API_KEY) {
+    return new HuggingFaceAssistant();
+  }
+  throw new Error('Please set OPENAI_API_KEY or HUGGINGFACE_API_KEY');
+}
+
+/**
+ * Service a request for assistance, with a little retry logic
+ * since these endpoints can be a bit flakey.
+ */
+export async function sendForCompletion(doc: AssistanceDoc,
+                                        request: AssistanceRequest): Promise<AssistanceResponse> {
+  const assistant = getAssistant();
+
+  let retries: number = 0;
+
+  let response: AssistanceResponse|null = null;
+  while(retries++ < 3) {
+    try {
+      response = await assistant.apply(doc, request);
+      break;
+    } catch(e) {
+      log.error(`Completion error: ${e}`);
+      await delay(1000);
+    }
+  }
+  if (!response) {
+    throw new Error('Failed to get response from assistant');
+  }
+  return response;
+}
-  const response = await DEPS.fetch(
-    completionUrl,
-    {
-      method: "POST",
-      headers: {
-        "Authorization": `Bearer ${apiKey}`,
-        "Content-Type": "application/json",
-      },
-      body: JSON.stringify({
-        inputs: prompt,
-        parameters: {
-          return_full_text: false,
-          max_new_tokens: 50,
-        },
-      }),
-    },
-  );
-  if (response.status === 503) {
-    log.error(`Sleeping for 10s - HuggingFace API returned ${response.status}: ${await response.text()}`);
-    await delay(10000);
+async function makeSchemaPromptV1(doc: AssistanceDoc, request: AssistanceRequest) {
+  if (request.context.type !== 'formula') {
+    throw new Error('makeSchemaPromptV1 only works for formulas');
+  }
+  return doc.assistanceSchemaPromptV1({
+    tableId: request.context.tableId,
+    colId: request.context.colId,
+    docString: request.text,
+  });
+}
+
+async function completionToResponse(doc: AssistanceDoc, request: AssistanceRequest,
+                                    completion: string, reply?: string): Promise<AssistanceResponse> {
+  if (request.context.type !== 'formula') {
+    throw new Error('completionToResponse only works for formulas');
+  }
+  completion = await doc.assistanceFormulaTweak(completion);
+  // A leading newline is common.
+  if (completion.charAt(0) === '\n') {
+    completion = completion.slice(1);
   }
-  if (response.status !== 200) {
-    const text = await response.text();
-    log.error(`HuggingFace API returned ${response.status}: ${text}`);
-    throw new Error(`HuggingFace API returned status ${response.status}: ${text}`);
+  // If all non-empty lines have four spaces, remove those spaces.
+  // They are common for GPT-3.5, which matches the prompt carefully.
+  const lines = completion.split('\n');
+  const ok = lines.every(line => line === '\n' || line.startsWith('    '));
+  if (ok) {
+    completion = lines.map(line => line === '\n' ? line : line.slice(4)).join('\n');
   }
-  const result = await response.json();
-  const completion = result[0].generated_text;
-  return completion.split('\n\n')[0];
+
+  // Suggest an action only if the completion is non-empty (that is,
+  // it actually looked like code).
+  const suggestedActions: DocAction[] = completion ? [[
+    "ModifyColumn",
+    request.context.tableId,
+    request.context.colId, {
+      formula: completion,
+    }
+  ]] : [];
+  return {
+    suggestedActions,
+    reply,
+  };
 }
diff --git a/app/server/lib/GranularAccess.ts b/app/server/lib/GranularAccess.ts
index 68f14bdf..9bbafc96 100644
--- a/app/server/lib/GranularAccess.ts
+++ b/app/server/lib/GranularAccess.ts
@@ -376,7 +376,8 @@ export class GranularAccess implements GranularAccessForBundle {
     function fail(): never {
       throw new ErrorWithCode('ACL_DENY', 'Cannot access cell');
     }
-    if (!await this.hasTableAccess(docSession, cell.tableId)) { fail(); }
+    const hasExceptionalAccess = this._hasExceptionalFullAccess(docSession);
+    if (!hasExceptionalAccess && !await this.hasTableAccess(docSession, cell.tableId)) { fail(); }
     let rows: TableDataAction|null = null;
     if (docData) {
       const record = docData.getTable(cell.tableId)?.getRecord(cell.rowId);
@@ -393,16 +394,18 @@
       return fail();
     }
     const rec = new RecordView(rows, 0);
-    const input: AclMatchInput = {...await this.inputs(docSession), rec, newRec: rec};
-    const rowPermInfo = new PermissionInfo(this._ruler.ruleCollection, input);
-    const rowAccess = rowPermInfo.getTableAccess(cell.tableId).perms.read;
-    if (rowAccess === 'deny') { fail(); }
-    if (rowAccess !== 'allow') {
-      const colAccess = rowPermInfo.getColumnAccess(cell.tableId, cell.colId).perms.read;
-      if (colAccess === 'deny') { fail(); }
+    if (!hasExceptionalAccess) {
+      const input: AclMatchInput = {...await this.inputs(docSession), rec, newRec: rec};
+      const rowPermInfo = new PermissionInfo(this._ruler.ruleCollection, input);
+      const rowAccess = rowPermInfo.getTableAccess(cell.tableId).perms.read;
+      if (rowAccess === 'deny') { fail(); }
+      if (rowAccess !== 'allow') {
+        const colAccess = rowPermInfo.getColumnAccess(cell.tableId, cell.colId).perms.read;
+        if (colAccess === 'deny') { fail(); }
+      }
+      const colValues = rows[3];
+      if (!(cell.colId in colValues)) { fail(); }
     }
-    const colValues = rows[3];
-    if (!(cell.colId in colValues)) { fail(); }
     return rec.get(cell.colId);
   }
diff --git a/sandbox/grist/formula_prompt.py b/sandbox/grist/formula_prompt.py
index 045013db..3aab410c 100644
--- a/sandbox/grist/formula_prompt.py
+++ b/sandbox/grist/formula_prompt.py
@@ -37,7 +37,6 @@ def column_type(engine, table_id, col_id):
     Attachments="Any",
   )[parts[0]]

-
 def choices(col_rec):
   try:
     widget_options = json.loads(col_rec.widgetOptions)
diff --git a/test/formula-dataset/runCompletion.js b/test/formula-dataset/runCompletion.js
index 25409b5b..e2495048 100644
--- a/test/formula-dataset/runCompletion.js
+++ b/test/formula-dataset/runCompletion.js
@@ -1,12 +1,17 @@
 #!/usr/bin/env node
 "use strict";

+const fs = require('fs');
 const path = require('path');
-const codeRoot = path.dirname(path.dirname(path.dirname(__dirname)));
-process.env.DATA_PATH = path.join(__dirname, 'data');
+let codeRoot = path.dirname(path.dirname(__dirname));
+if (!fs.existsSync(path.join(codeRoot, '_build'))) {
+  codeRoot = path.dirname(codeRoot);
+}
+process.env.DATA_PATH = path.join(__dirname, 'data');

 require('app-module-path').addPath(path.join(codeRoot, '_build'));
 require('app-module-path').addPath(path.join(codeRoot, '_build', 'core'));
 require('app-module-path').addPath(path.join(codeRoot, '_build', 'ext'));
+require('app-module-path').addPath(path.join(codeRoot, '_build', 'stubs'));

 require('test/formula-dataset/runCompletion_impl').runCompletion().catch(console.error);
diff --git a/test/formula-dataset/runCompletion_impl.ts b/test/formula-dataset/runCompletion_impl.ts
index 45c2e207..250e8a97 100644
--- a/test/formula-dataset/runCompletion_impl.ts
+++ b/test/formula-dataset/runCompletion_impl.ts
@@ -24,7 +24,7 @@
  */

-import { ActiveDoc } from "app/server/lib/ActiveDoc";
+import { ActiveDoc, Deps as ActiveDocDeps } from "app/server/lib/ActiveDoc";
 import { DEPS } from "app/server/lib/Assistance";
 import log from 'app/server/lib/log';
 import crypto from 'crypto';
@@ -38,11 +38,14 @@
 import * as os from 'os';
 import { pipeline } from 'stream';
 import { createDocTools } from "test/server/docTools";
 import { promisify } from 'util';
+import { AssistanceResponse, AssistanceState } from "app/common/AssistancePrompts";
+import { CellValue } from "app/plugin/GristData";

 const streamPipeline = promisify(pipeline);
 const DATA_PATH = process.env.DATA_PATH || path.join(__dirname, 'data');
 const PATH_TO_DOC = path.join(DATA_PATH, 'templates');
+const PATH_TO_RESULTS = path.join(DATA_PATH, 'results');
 const PATH_TO_CSV = path.join(DATA_PATH, 'formula-dataset-index.csv');
 const PATH_TO_CACHE = path.join(DATA_PATH, 'cache');
 const TEMPLATE_URL = "https://grist-static.com/datasets/grist_dataset_formulai_2023_02_20.zip";
@@ -60,8 +63,11 @@ const _stats = {
   callCount: 0,
 };

+const SEEMS_CHATTY = (process.env.COMPLETION_MODEL || '').includes('turbo');
+const SIMULATE_CONVERSATION = SEEMS_CHATTY;

 export async function runCompletion() {
+  ActiveDocDeps.ACTIVEDOC_TIMEOUT = 600000;

   // if template directory not exists, make it
   if (!fs.existsSync(path.join(PATH_TO_DOC))) {
@@ -107,11 +113,12 @@
   if (!process.env.VERBOSE) {
     log.transports.file.level = 'error';      // Suppress most of log output.
   }
-  let activeDoc: ActiveDoc|undefined;
   const docTools = createDocTools();
   const session = docTools.createFakeSession('owners');
   await docTools.before();
   let successCount = 0;
+  let caseCount = 0;
+  fs.mkdirSync(path.join(PATH_TO_RESULTS), {recursive: true});

   console.log('Testing AI assistance: ');
@@ -119,38 +126,109 @@
     DEPS.fetch = fetchWithCache;

+    let activeDoc: ActiveDoc|undefined;
     for (const rec of records) {
-
-      // load new document
-      if (!activeDoc || activeDoc.docName !== rec.doc_id) {
-        const docPath = path.join(PATH_TO_DOC, rec.doc_id + '.grist');
-        activeDoc = await docTools.loadLocalDoc(docPath);
-        await activeDoc.waitForInitialization();
      }
+      let success: boolean = false;
+      let suggestedActions: AssistanceResponse['suggestedActions'] | undefined;
+      let newValues: CellValue[] | undefined;
+      let expected: CellValue[] | undefined;
+      let formula: string | undefined;
+      let history: AssistanceState = {messages: []};
+      let lastFollowUp: string | undefined;
+
+      try {
+        async function sendMessage(followUp?: string) {
+          // load new document
+          if (!activeDoc || activeDoc.docName !== rec.doc_id) {
+            const docPath = path.join(PATH_TO_DOC, rec.doc_id + '.grist');
+            activeDoc = await docTools.loadLocalDoc(docPath);
+            await activeDoc.waitForInitialization();
+          }
+
+          if (!activeDoc) { throw new Error("No doc"); }
+
+          // get values
+          await activeDoc.docData!.fetchTable(rec.table_id);
+          expected = activeDoc.docData!.getTable(rec.table_id)!.getColValues(rec.col_id)!.slice();
+
+          // send prompt
+          const tableId = rec.table_id;
+          const colId = rec.col_id;
+          const description = rec.Description;
+          const colInfo = await activeDoc.docStorage.get(`
+select * from _grist_Tables_column as c
+left join _grist_Tables as t on t.id = c.parentId
+where c.colId = ? and t.tableId = ?
+`, rec.col_id, rec.table_id);
+          formula = colInfo?.formula;
+
+          const result = await activeDoc.getAssistanceWithOptions(session, {
+            context: {type: 'formula', tableId, colId},
+            state: history,
+            text: followUp || description,
+          });
+          if (result.state) {
+            history = result.state;
+          }
+          suggestedActions = result.suggestedActions;
+          // apply modification
+          const {actionNum} = await activeDoc.applyUserActions(session, suggestedActions);
+
+          // get new values
+          newValues = activeDoc.docData!.getTable(rec.table_id)!.getColValues(rec.col_id)!.slice();
+
+          // revert modification
+          const [bundle] = await activeDoc.getActions([actionNum]);
+          await activeDoc.applyUserActionsById(session, [bundle!.actionNum], [bundle!.actionHash!], true);
+
+          // compare values
+          success = isEqual(expected, newValues);
+
+          if (!success) {
+            const rowIds = activeDoc.docData!.getTable(rec.table_id)!.getRowIds();
+            const result = await activeDoc.getFormulaError({client: null, mode: 'system'} as any, rec.table_id,
+                                                           rec.col_id, rowIds[0]);
+            if (Array.isArray(result) && result[0] === 'E') {
+              result.shift();
+              const txt = `I got a \`${result.shift()}\` error:\n` +
+                '```\n' +
+                result.shift() + '\n' +
+                result.shift() + '\n' +
+                '```\n' +
+                'Please answer with the code block you (the assistant) just gave, ' +
+                'revised based on this error. Your answer must include a code block. ' +
+                'If you have to explain anything, do it after. ' +
+                'It is perfectly acceptable (and may be necessary) to do ' +
+                'imports from within a method body.\n';
+              return { followUp: txt };
+            } else {
+              for (let i = 0; i < expected.length; i++) {
+                const e = expected[i];
+                const v = newValues[i];
+                if (String(e) !== String(v)) {
+                  const txt = `I got \`${v}\` where I expected \`${e}\`\n` +
+                    'Please answer with the code block you (the assistant) just gave, ' +
+                    'revised based on this information. Your answer must include a code ' +
+                    'block. If you have to explain anything, do it after.\n';
+                  return { followUp: txt };
+                }
+              }
+            }
+          }
+          return null;
+        }
+        const result = await sendMessage();
+        if (result?.followUp && SIMULATE_CONVERSATION) {
+          // Allow one follow up message, based on error or differences.
+          const result2 = await sendMessage(result.followUp);
+          if (result2?.followUp) {
+            lastFollowUp = result2.followUp;
+          }
+        }
+      } catch (e) {
+        console.error(e);
       }
-      // get values
-      await activeDoc.docData!.fetchTable(rec.table_id);
-      const expected = activeDoc.docData!.getTable(rec.table_id)!.getColValues(rec.col_id)!.slice();
-
-      // send prompt
-      const tableId = rec.table_id;
-      const colId = rec.col_id;
-      const description = rec.Description;
-      const {suggestedActions} = await activeDoc.getAssistance(session, {tableId, colId, description});
-
-      // apply modification
-      const {actionNum} = await activeDoc.applyUserActions(session, suggestedActions);
-
-      // get new values
-      const newValues = activeDoc.docData!.getTable(rec.table_id)!.getColValues(rec.col_id)!.slice();
-
-      // revert modification
-      const [bundle] = await activeDoc.getActions([actionNum]);
-      await activeDoc.applyUserActionsById(session, [bundle!.actionNum], [bundle!.actionHash!], true);
-
-      // compare values
-      const success = isEqual(expected, newValues);
-
       console.log(` ${success ? 'Successfully' : 'Failed to'} complete formula ` +
         `for column ${rec.table_id}.${rec.col_id} (doc=${rec.doc_id})`);
@@ -162,6 +240,23 @@
         // console.log('expected=', expected);
         // console.log('actual=', newValues);
       }
+      const suggestedFormula = suggestedActions?.length === 1 &&
+        suggestedActions[0][0] === 'ModifyColumn' &&
+        suggestedActions[0][3].formula || suggestedActions;
+      fs.writeFileSync(
+        path.join(
+          PATH_TO_RESULTS,
+          `${rec.table_id}_${rec.col_id}_` +
+            caseCount.toLocaleString('en', {minimumIntegerDigits: 8, useGrouping: false}) + '.json'),
+        JSON.stringify({
+          formula,
+          suggestedFormula, success,
+          expectedValues: expected,
+          suggestedValues: newValues,
+          history,
+          lastFollowUp,
+        }, null, 2));
+      caseCount++;
     }
   } finally {
     await docTools.after();
@@ -171,6 +266,13 @@
     console.log(
       `AI Assistance completed ${successCount} successful prompt on a total of ${records.length};`
     );
+    console.log(JSON.stringify(
+      {
+        hit: successCount,
+        total: records.length,
+        percentage: (100.0 * successCount) / Math.max(records.length, 1),
+      }
+    ));
   }
 }