add support for conversational state to assistance endpoint (#506)

* add support for conversational state to assistance endpoint

This refactors the assistance code somewhat, to allow carrying
along some conversational state. It extends the OpenAI-flavored
assistant to make use of that state to have a conversation.
The front-end is tweaked a little bit to allow for replies that
don't have any code in them (though I didn't get into formatting
such replies nicely).

Currently tested primarily through the runCompletion script,
which has been extended a bit to allow testing simulated
conversations (where an error is pasted in follow-up, or
an expected-vs-actual comparison).

Co-authored-by: George Gevoian <85144792+georgegevoian@users.noreply.github.com>
This commit is contained in:
Paul Fitzpatrick 2023-05-08 11:15:22 -07:00 committed by GitHub
parent 68fbeb4d7b
commit 51a195bd94
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 587 additions and 185 deletions

View File

@ -4,6 +4,7 @@ import {CellRec, DocModel, IRowModel, recordSet,
refRecord, TableRec, ViewFieldRec} from 'app/client/models/DocModel'; refRecord, TableRec, ViewFieldRec} from 'app/client/models/DocModel';
import {urlState} from 'app/client/models/gristUrlState'; import {urlState} from 'app/client/models/gristUrlState';
import {jsonObservable, ObjObservable} from 'app/client/models/modelUtil'; import {jsonObservable, ObjObservable} from 'app/client/models/modelUtil';
import {AssistanceState} from 'app/common/AssistancePrompts';
import * as gristTypes from 'app/common/gristTypes'; import * as gristTypes from 'app/common/gristTypes';
import {getReferencedTableId} from 'app/common/gristTypes'; import {getReferencedTableId} from 'app/common/gristTypes';
import { import {
@ -83,7 +84,7 @@ export interface ColumnRec extends IRowModel<"_grist_Tables_column"> {
/** /**
* Current history of chat. This is a temporary array used only in the ui. * Current history of chat. This is a temporary array used only in the ui.
*/ */
chatHistory: ko.PureComputed<Observable<ChatMessage[]>>; chatHistory: ko.PureComputed<Observable<ChatHistory>>;
// Helper which adds/removes/updates column's displayCol to match the formula. // Helper which adds/removes/updates column's displayCol to match the formula.
saveDisplayFormula(formula: string): Promise<void>|undefined; saveDisplayFormula(formula: string): Promise<void>|undefined;
@ -162,8 +163,9 @@ export function createColumnRec(this: ColumnRec, docModel: DocModel): void {
this.chatHistory = this.autoDispose(ko.computed(() => { this.chatHistory = this.autoDispose(ko.computed(() => {
const docId = urlState().state.get().doc ?? ''; const docId = urlState().state.get().doc ?? '';
const key = `formula-assistant-history-${docId}-${this.table().tableId()}-${this.colId()}`; // Changed key name from history to history-v2 when ChatHistory changed in incompatible way.
return localStorageJsonObs(key, [] as ChatMessage[]); const key = `formula-assistant-history-v2-${docId}-${this.table().tableId()}-${this.colId()}`;
return localStorageJsonObs(key, {messages: []} as ChatHistory);
})); }));
} }
@ -196,8 +198,20 @@ export interface ChatMessage {
*/ */
sender: 'user' | 'ai'; sender: 'user' | 'ai';
/** /**
* The formula returned from the AI. It is only set when the sender is the AI. For now it is the same * The formula returned from the AI. It is only set when the sender is the AI.
* value as the message, but it might change in the future when we use more conversational AI.
*/ */
formula?: string; formula?: string;
} }
/**
* The state of assistance for a particular column.
* ChatMessages are what are shown in the UI, whereas state is
* how the back-end represents the conversation. The two are
* similar but not the same because of post-processing.
* It may be possible to reconcile them when things settle down
* a bit?
*/
export interface ChatHistory {
messages: ChatMessage[];
state?: AssistanceState;
}

View File

@ -8,7 +8,7 @@ import {basicButton, primaryButton, textButton} from 'app/client/ui2018/buttons'
import {theme} from 'app/client/ui2018/cssVars'; import {theme} from 'app/client/ui2018/cssVars';
import {cssTextInput, rawTextInput} from 'app/client/ui2018/editableLabel'; import {cssTextInput, rawTextInput} from 'app/client/ui2018/editableLabel';
import {icon} from 'app/client/ui2018/icons'; import {icon} from 'app/client/ui2018/icons';
import {Suggestion} from 'app/common/AssistancePrompts'; import {AssistanceResponse, AssistanceState} from 'app/common/AssistancePrompts';
import {Disposable, dom, makeTestId, MultiHolder, obsArray, Observable, styled} from 'grainjs'; import {Disposable, dom, makeTestId, MultiHolder, obsArray, Observable, styled} from 'grainjs';
import noop from 'lodash/noop'; import noop from 'lodash/noop';
@ -78,7 +78,7 @@ function buildControls(
} }
) { ) {
const hasHistory = props.column.chatHistory.peek().get().length > 0; const hasHistory = props.column.chatHistory.peek().get().messages.length > 0;
// State variables, to show various parts of the UI. // State variables, to show various parts of the UI.
const saveButtonVisible = Observable.create(owner, true); const saveButtonVisible = Observable.create(owner, true);
@ -153,36 +153,46 @@ function buildControls(
}; };
} }
function buildChat(owner: Disposable, context: Context & { formulaClicked: (formula: string) => void }) { function buildChat(owner: Disposable, context: Context & { formulaClicked: (formula?: string) => void }) {
const { grist, column } = context; const { grist, column } = context;
const history = owner.autoDispose(obsArray(column.chatHistory.peek().get())); const history = owner.autoDispose(obsArray(column.chatHistory.peek().get().messages));
const hasHistory = history.get().length > 0; const hasHistory = history.get().length > 0;
const enabled = Observable.create(owner, hasHistory); const enabled = Observable.create(owner, hasHistory);
const introVisible = Observable.create(owner, !hasHistory); const introVisible = Observable.create(owner, !hasHistory);
owner.autoDispose(history.addListener((cur) => { owner.autoDispose(history.addListener((cur) => {
column.chatHistory.peek().set([...cur]); const chatHistory = column.chatHistory.peek();
chatHistory.set({...chatHistory.get(), messages: [...cur]});
})); }));
const submit = async () => { const submit = async (regenerate: boolean = false) => {
// Ask about suggestion, and send the whole history. Currently the chat is implemented by just sending // Send most recent question, and send back any conversation
// all previous user prompts back to the AI. This is subject to change (and probably should be done in the backend). // state we have been asked to track.
const prompt = history.get().filter(x => x.sender === 'user') const chatHistory = column.chatHistory.peek().get();
.map(entry => entry.message) const messages = chatHistory.messages.filter(msg => msg.sender === 'user');
.filter(Boolean) const description = messages[messages.length - 1]?.message || '';
.join("\n"); console.debug('description', {description});
console.debug('prompt', prompt); const {reply, suggestedActions, state} = await askAI(grist, {
const { suggestedActions } = await askAI(grist, column, prompt); column, description, state: chatHistory.state,
console.debug('suggestedActions', suggestedActions); regenerate,
});
console.debug('suggestedActions', {suggestedActions, reply});
const firstAction = suggestedActions[0] as any; const firstAction = suggestedActions[0] as any;
// Add the formula to the history. // Add the formula to the history.
const formula = firstAction[3].formula as string; const formula = firstAction ? firstAction[3].formula as string : undefined;
// Add to history // Add to history
history.push({ history.push({
message: formula, message: formula || reply || '(no reply)',
sender: 'ai', sender: 'ai',
formula formula
}); });
// If back-end is capable of conversation, keep its state.
if (state) {
const chatHistoryNew = column.chatHistory.peek();
const value = chatHistoryNew.get();
value.state = state;
chatHistoryNew.set(value);
}
return formula; return formula;
}; };
@ -203,12 +213,13 @@ function buildChat(owner: Disposable, context: Context & { formulaClicked: (form
// Remove the last AI response from the history. // Remove the last AI response from the history.
history.pop(); history.pop();
// And submit again. // And submit again.
context.formulaClicked(await submit()); context.formulaClicked(await submit(true));
}; };
const newChat = () => { const newChat = () => {
// Clear the history. // Clear the history.
history.set([]); history.set([]);
column.chatHistory.peek().set({messages: []});
// Show intro. // Show intro.
introVisible.set(true); introVisible.set(true);
}; };
@ -371,9 +382,11 @@ function openAIAssistant(grist: GristDoc, column: ColumnRec) {
const chat = buildChat(owner, {...props, const chat = buildChat(owner, {...props,
// When a formula is clicked (or just was returned from the AI), we set it in the formula editor and hit // When a formula is clicked (or just was returned from the AI), we set it in the formula editor and hit
// the preview button. // the preview button.
formulaClicked: (formula: string) => { formulaClicked: (formula?: string) => {
if (formula) {
formulaEditor.set(formula); formulaEditor.set(formula);
controls.preview().catch(reportError); controls.preview().catch(reportError);
}
}, },
}); });
@ -397,11 +410,22 @@ function openAIAssistant(grist: GristDoc, column: ColumnRec) {
grist.formulaPopup.autoDispose(popup); grist.formulaPopup.autoDispose(popup);
} }
async function askAI(grist: GristDoc, column: ColumnRec, description: string): Promise<Suggestion> { async function askAI(grist: GristDoc, options: {
column: ColumnRec,
description: string,
regenerate?: boolean,
state?: AssistanceState
}): Promise<AssistanceResponse> {
const {column, description, state, regenerate} = options;
const tableId = column.table.peek().tableId.peek(); const tableId = column.table.peek().tableId.peek();
const colId = column.colId.peek(); const colId = column.colId.peek();
try { try {
const result = await grist.docComm.getAssistance({tableId, colId, description}); const result = await grist.docComm.getAssistance({
context: {type: 'formula', tableId, colId},
text: description,
state,
regenerate,
});
return result; return result;
} catch (error) { } catch (error) {
reportError(error); reportError(error);

View File

@ -1,5 +1,5 @@
import {ActionGroup} from 'app/common/ActionGroup'; import {ActionGroup} from 'app/common/ActionGroup';
import {Prompt, Suggestion} from 'app/common/AssistancePrompts'; import {AssistanceRequest, AssistanceResponse} from 'app/common/AssistancePrompts';
import {BulkAddRecord, CellValue, TableDataAction, UserAction} from 'app/common/DocActions'; import {BulkAddRecord, CellValue, TableDataAction, UserAction} from 'app/common/DocActions';
import {FormulaProperties} from 'app/common/GranularAccessClause'; import {FormulaProperties} from 'app/common/GranularAccessClause';
import {UIRowId} from 'app/common/UIRowId'; import {UIRowId} from 'app/common/UIRowId';
@ -323,7 +323,7 @@ export interface ActiveDocAPI {
/** /**
* Generates a formula code based on the AI suggestions, it also modifies the column and sets it type to a formula. * Generates a formula code based on the AI suggestions, it also modifies the column and sets it type to a formula.
*/ */
getAssistance(userPrompt: Prompt): Promise<Suggestion>; getAssistance(request: AssistanceRequest): Promise<AssistanceResponse>;
/** /**
* Fetch content at a url. * Fetch content at a url.

View File

@ -1,11 +1,56 @@
import {DocAction} from 'app/common/DocActions'; import {DocAction} from 'app/common/DocActions';
export interface Prompt { /**
tableId: string; * State related to a request for assistance.
colId: string *
description: string; * If an AssistanceResponse contains state, that state can be
* echoed back in an AssistanceRequest to continue a "conversation."
*
* Ideally, the state should not be modified or relied upon
* by the client, so as not to commit too hard to a particular
* model at this time (it is a bit early for that).
*/
export interface AssistanceState {
messages?: Array<{
role: string;
content: string;
}>;
} }
export interface Suggestion { /**
suggestedActions: DocAction[]; * Currently, requests for assistance always happen in the context
* of the column of a particular table.
*/
export interface FormulaAssistanceContext {
type: 'formula';
tableId: string;
colId: string;
}
export type AssistanceContext = FormulaAssistanceContext;
/**
* A request for assistance.
*/
export interface AssistanceRequest {
context: AssistanceContext;
state?: AssistanceState;
text: string;
regenerate?: boolean; // Set if there was a previous request
// and response that should be omitted
// from history, or (if available) an
// alternative response generated.
}
/**
* A response to a request for assistance.
* The client should preserve the state and include it in
* any follow-up requests.
*/
export interface AssistanceResponse {
suggestedActions: DocAction[];
state?: AssistanceState;
// If the model can be trusted to issue a self-contained
// markdown-friendly string, it can be included here.
reply?: string;
} }

View File

@ -14,7 +14,7 @@ import {
} from 'app/common/ActionBundle'; } from 'app/common/ActionBundle';
import {ActionGroup, MinimalActionGroup} from 'app/common/ActionGroup'; import {ActionGroup, MinimalActionGroup} from 'app/common/ActionGroup';
import {ActionSummary} from "app/common/ActionSummary"; import {ActionSummary} from "app/common/ActionSummary";
import {Prompt, Suggestion} from "app/common/AssistancePrompts"; import {AssistanceRequest, AssistanceResponse} from "app/common/AssistancePrompts";
import { import {
AclResources, AclResources,
AclTableDescription, AclTableDescription,
@ -85,7 +85,7 @@ import {Document} from 'app/gen-server/entity/Document';
import {ParseOptions} from 'app/plugin/FileParserAPI'; import {ParseOptions} from 'app/plugin/FileParserAPI';
import {AccessTokenOptions, AccessTokenResult, GristDocAPI} from 'app/plugin/GristAPI'; import {AccessTokenOptions, AccessTokenResult, GristDocAPI} from 'app/plugin/GristAPI';
import {compileAclFormula} from 'app/server/lib/ACLFormula'; import {compileAclFormula} from 'app/server/lib/ACLFormula';
import {sendForCompletion} from 'app/server/lib/Assistance'; import {AssistanceDoc, AssistanceSchemaPromptV1Context, sendForCompletion} from 'app/server/lib/Assistance';
import {Authorizer} from 'app/server/lib/Authorizer'; import {Authorizer} from 'app/server/lib/Authorizer';
import {checksumFile} from 'app/server/lib/checksumFile'; import {checksumFile} from 'app/server/lib/checksumFile';
import {Client} from 'app/server/lib/Client'; import {Client} from 'app/server/lib/Client';
@ -180,7 +180,7 @@ interface UpdateUsageOptions {
* either .loadDoc() or .createEmptyDoc() is called. * either .loadDoc() or .createEmptyDoc() is called.
* @param {String} docName - The document's filename, without the '.grist' extension. * @param {String} docName - The document's filename, without the '.grist' extension.
*/ */
export class ActiveDoc extends EventEmitter { export class ActiveDoc extends EventEmitter implements AssistanceDoc {
/** /**
* Decorator for ActiveDoc methods that prevents shutdown while the method is running, i.e. * Decorator for ActiveDoc methods that prevents shutdown while the method is running, i.e.
* until the returned promise is resolved. * until the returned promise is resolved.
@ -1112,7 +1112,7 @@ export class ActiveDoc extends EventEmitter {
* @param {Integer} rowId - Row number * @param {Integer} rowId - Row number
* @returns {Promise} Promise for a error message * @returns {Promise} Promise for a error message
*/ */
public async getFormulaError(docSession: DocSession, tableId: string, colId: string, public async getFormulaError(docSession: OptDocSession, tableId: string, colId: string,
rowId: number): Promise<CellValue> { rowId: number): Promise<CellValue> {
// Throw an error if the user doesn't have access to read this cell. // Throw an error if the user doesn't have access to read this cell.
await this._granularAccess.getCellValue(docSession, {tableId, colId, rowId}); await this._granularAccess.getCellValue(docSession, {tableId, colId, rowId});
@ -1260,22 +1260,28 @@ export class ActiveDoc extends EventEmitter {
return this._pyCall('autocomplete', txt, tableId, columnId, rowId, user.toJSON()); return this._pyCall('autocomplete', txt, tableId, columnId, rowId, user.toJSON());
} }
public async getAssistance(docSession: DocSession, userPrompt: Prompt): Promise<Suggestion> { public async getAssistance(docSession: DocSession, request: AssistanceRequest): Promise<AssistanceResponse> {
// Making a prompt can leak names of tables and columns. return this.getAssistanceWithOptions(docSession, request);
}
public async getAssistanceWithOptions(docSession: DocSession,
request: AssistanceRequest): Promise<AssistanceResponse> {
// Making a prompt leaks names of tables and columns etc.
if (!await this._granularAccess.canScanData(docSession)) { if (!await this._granularAccess.canScanData(docSession)) {
throw new Error("Permission denied"); throw new Error("Permission denied");
} }
await this.waitForInitialization(); await this.waitForInitialization();
const { tableId, colId, description } = userPrompt; return sendForCompletion(this, request);
const prompt = await this._pyCall('get_formula_prompt', tableId, colId, description); }
this._log.debug(docSession, 'getAssistance prompt', {prompt});
const completion = await sendForCompletion(prompt); // Callback to make a data-engine formula tweak for assistance.
this._log.debug(docSession, 'getAssistance completion', {completion}); public assistanceFormulaTweak(txt: string) {
const formula = await this._pyCall('convert_formula_completion', completion); return this._pyCall('convert_formula_completion', txt);
const action: DocAction = ["ModifyColumn", tableId, colId, {formula}]; }
return {
suggestedActions: [action], // Callback to generate a prompt containing schema info for assistance.
}; public assistanceSchemaPromptV1(options: AssistanceSchemaPromptV1Context): Promise<string> {
return this._pyCall('get_formula_prompt', options.tableId, options.colId, options.docString);
} }
public fetchURL(docSession: DocSession, url: string, options?: FetchUrlOptions): Promise<UploadResult> { public fetchURL(docSession: DocSession, url: string, options?: FetchUrlOptions): Promise<UploadResult> {

View File

@ -2,76 +2,187 @@
* Module with functions used for AI formula assistance. * Module with functions used for AI formula assistance.
*/ */
import {AssistanceRequest, AssistanceResponse} from 'app/common/AssistancePrompts';
import {delay} from 'app/common/delay'; import {delay} from 'app/common/delay';
import {DocAction} from 'app/common/DocActions';
import log from 'app/server/lib/log'; import log from 'app/server/lib/log';
import fetch from 'node-fetch'; import fetch from 'node-fetch';
export const DEPS = { fetch }; export const DEPS = { fetch };
export async function sendForCompletion(prompt: string): Promise<string> { /**
let completion: string|null = null; * An assistant can help a user do things with their document,
let retries: number = 0; * by interfacing with an external LLM endpoint.
const openApiKey = process.env.OPENAI_API_KEY; */
const model = process.env.COMPLETION_MODEL || "text-davinci-002"; export interface Assistant {
apply(doc: AssistanceDoc, request: AssistanceRequest): Promise<AssistanceResponse>;
while(retries++ < 3) {
try {
if (openApiKey) {
completion = await sendForCompletionOpenAI(prompt, openApiKey, model);
}
if (process.env.HUGGINGFACE_API_KEY) {
completion = await sendForCompletionHuggingFace(prompt);
}
break;
} catch(e) {
await delay(1000);
}
}
if (completion === null) {
throw new Error("Please set OPENAI_API_KEY or HUGGINGFACE_API_KEY (and optionally COMPLETION_MODEL)");
}
log.debug(`Received completion:`, {completion});
completion = completion.split(/\n {4}[^ ]/)[0];
return completion;
} }
/**
* Document-related methods for use in the implementation of assistants.
* Somewhat ad-hoc currently.
*/
export interface AssistanceDoc {
/**
* Generate a particular prompt coded in the data engine for some reason.
* It makes python code for some tables, and starts a function body with
* the given docstring.
* Marked "V1" to suggest that it is a particular prompt and it would
* be great to try variants.
*/
assistanceSchemaPromptV1(options: AssistanceSchemaPromptV1Context): Promise<string>;
async function sendForCompletionOpenAI(prompt: string, apiKey: string, model = "text-davinci-002") { /**
* Some tweaks to a formula after it has been generated.
*/
assistanceFormulaTweak(txt: string): Promise<string>;
}
export interface AssistanceSchemaPromptV1Context {
tableId: string,
colId: string,
docString: string,
}
/**
* A flavor of assistant for use with the OpenAI API.
* Tested primarily with text-davinci-002 and gpt-3.5-turbo.
*/
export class OpenAIAssistant implements Assistant {
private _apiKey: string;
private _model: string;
private _chatMode: boolean;
private _endpoint: string;
public constructor() {
const apiKey = process.env.OPENAI_API_KEY;
if (!apiKey) { if (!apiKey) {
throw new Error("OPENAI_API_KEY not set"); throw new Error('OPENAI_API_KEY not set');
} }
const response = await DEPS.fetch( this._apiKey = apiKey;
"https://api.openai.com/v1/completions", this._model = process.env.COMPLETION_MODEL || "text-davinci-002";
this._chatMode = this._model.includes('turbo');
this._endpoint = `https://api.openai.com/v1/${this._chatMode ? 'chat/' : ''}completions`;
}
public async apply(doc: AssistanceDoc, request: AssistanceRequest): Promise<AssistanceResponse> {
const messages = request.state?.messages || [];
const chatMode = this._chatMode;
if (chatMode) {
if (messages.length === 0) {
messages.push({
role: 'system',
content: 'The user gives you one or more Python classes, ' +
'with one last method that needs completing. Write the ' +
'method body as a single code block, ' +
'including the docstring the user gave. ' +
'Just give the Python code as a markdown block, ' +
'do not give any introduction, that will just be ' +
'awkward for the user when copying and pasting. ' +
'You are working with Grist, an environment very like ' +
'regular Python except `rec` (like record) is used ' +
'instead of `self`. ' +
'Include at least one `return` statement or the method ' +
'will fail, disappointing the user. ' +
'Your answer should be the body of a single method, ' +
'not a class, and should not include `dataclass` or ' +
'`class` since the user is counting on you to provide ' +
'a single method. Thanks!'
});
messages.push({
role: 'user', content: await makeSchemaPromptV1(doc, request),
});
} else {
if (request.regenerate) {
if (messages[messages.length - 1].role !== 'user') {
messages.pop();
}
}
messages.push({
role: 'user', content: request.text,
});
}
} else {
messages.length = 0;
messages.push({
role: 'user', content: await makeSchemaPromptV1(doc, request),
});
}
const apiResponse = await DEPS.fetch(
this._endpoint,
{ {
method: "POST", method: "POST",
headers: { headers: {
"Authorization": `Bearer ${apiKey}`, "Authorization": `Bearer ${this._apiKey}`,
"Content-Type": "application/json", "Content-Type": "application/json",
}, },
body: JSON.stringify({ body: JSON.stringify({
prompt, ...(!this._chatMode ? {
max_tokens: 150, prompt: messages[messages.length - 1].content,
} : { messages }),
max_tokens: 1500,
temperature: 0, temperature: 0,
// COMPLETION_MODEL of `code-davinci-002` may be better if you have access to it. model: this._model,
model, stop: this._chatMode ? undefined : ["\n\n"],
stop: ["\n\n"],
}), }),
}, },
); );
if (response.status !== 200) { if (apiResponse.status !== 200) {
log.error(`OpenAI API returned ${response.status}: ${await response.text()}`); log.error(`OpenAI API returned ${apiResponse.status}: ${await apiResponse.text()}`);
throw new Error(`OpenAI API returned status ${response.status}`); throw new Error(`OpenAI API returned status ${apiResponse.status}`);
}
const result = await apiResponse.json();
let completion: string = String(chatMode ? result.choices[0].message.content : result.choices[0].text);
const reply = completion;
const history = { messages };
if (chatMode) {
history.messages.push(result.choices[0].message);
// This model likes returning markdown. Code will typically
// be in a code block with ``` delimiters.
let lines = completion.split('\n');
if (lines[0].startsWith('```')) {
lines.shift();
completion = lines.join('\n');
const parts = completion.split('```');
if (parts.length > 1) {
completion = parts[0];
}
lines = completion.split('\n');
}
// This model likes repeating the function signature and
// docstring, so we try to strip that out.
completion = lines.join('\n');
while (completion.includes('"""')) {
const parts = completion.split('"""');
completion = parts[parts.length - 1];
}
// If there's no code block, don't treat the answer as a formula.
if (!reply.includes('```')) {
completion = '';
}
}
const response = await completionToResponse(doc, request, completion, reply);
if (chatMode) {
response.state = history;
}
return response;
} }
const result = await response.json();
const completion = result.choices[0].text;
return completion;
} }
async function sendForCompletionHuggingFace(prompt: string) { export class HuggingFaceAssistant implements Assistant {
private _apiKey: string;
private _completionUrl: string;
public constructor() {
const apiKey = process.env.HUGGINGFACE_API_KEY; const apiKey = process.env.HUGGINGFACE_API_KEY;
if (!apiKey) { if (!apiKey) {
throw new Error("HUGGINGFACE_API_KEY not set"); throw new Error('HUGGINGFACE_API_KEY not set');
} }
this._apiKey = apiKey;
// COMPLETION_MODEL values I've tried: // COMPLETION_MODEL values I've tried:
// - codeparrot/codeparrot // - codeparrot/codeparrot
// - NinedayWang/PolyCoder-2.7B // - NinedayWang/PolyCoder-2.7B
@ -84,13 +195,21 @@ async function sendForCompletionHuggingFace(prompt: string) {
completionUrl = 'https://api-inference.huggingface.co/models/NovelAI/genji-python-6B'; completionUrl = 'https://api-inference.huggingface.co/models/NovelAI/genji-python-6B';
} }
} }
this._completionUrl = completionUrl;
}
public async apply(doc: AssistanceDoc, request: AssistanceRequest): Promise<AssistanceResponse> {
if (request.state) {
throw new Error("HuggingFaceAssistant does not support state");
}
const prompt = await makeSchemaPromptV1(doc, request);
const response = await DEPS.fetch( const response = await DEPS.fetch(
completionUrl, this._completionUrl,
{ {
method: "POST", method: "POST",
headers: { headers: {
"Authorization": `Bearer ${apiKey}`, "Authorization": `Bearer ${this._apiKey}`,
"Content-Type": "application/json", "Content-Type": "application/json",
}, },
body: JSON.stringify({ body: JSON.stringify({
@ -112,6 +231,91 @@ async function sendForCompletionHuggingFace(prompt: string) {
throw new Error(`HuggingFace API returned status ${response.status}: ${text}`); throw new Error(`HuggingFace API returned status ${response.status}: ${text}`);
} }
const result = await response.json(); const result = await response.json();
const completion = result[0].generated_text; let completion = result[0].generated_text;
return completion.split('\n\n')[0]; completion = completion.split('\n\n')[0];
return completionToResponse(doc, request, completion);
}
}
/**
* Instantiate an assistant, based on environment variables.
*/
function getAssistant() {
if (process.env.OPENAI_API_KEY) {
return new OpenAIAssistant();
}
if (process.env.HUGGINGFACE_API_KEY) {
return new HuggingFaceAssistant();
}
throw new Error('Please set OPENAI_API_KEY or HUGGINGFACE_API_KEY');
}
/**
* Service a request for assistance, with a little retry logic
* since these endpoints can be a bit flakey.
*/
export async function sendForCompletion(doc: AssistanceDoc,
request: AssistanceRequest): Promise<AssistanceResponse> {
const assistant = getAssistant();
let retries: number = 0;
let response: AssistanceResponse|null = null;
while(retries++ < 3) {
try {
response = await assistant.apply(doc, request);
break;
} catch(e) {
log.error(`Completion error: ${e}`);
await delay(1000);
}
}
if (!response) {
throw new Error('Failed to get response from assistant');
}
return response;
}
async function makeSchemaPromptV1(doc: AssistanceDoc, request: AssistanceRequest) {
if (request.context.type !== 'formula') {
throw new Error('makeSchemaPromptV1 only works for formulas');
}
return doc.assistanceSchemaPromptV1({
tableId: request.context.tableId,
colId: request.context.colId,
docString: request.text,
});
}
async function completionToResponse(doc: AssistanceDoc, request: AssistanceRequest,
completion: string, reply?: string): Promise<AssistanceResponse> {
if (request.context.type !== 'formula') {
throw new Error('completionToResponse only works for formulas');
}
completion = await doc.assistanceFormulaTweak(completion);
// A leading newline is common.
if (completion.charAt(0) === '\n') {
completion = completion.slice(1);
}
// If all non-empty lines have four spaces, remove those spaces.
// They are common for GPT-3.5, which matches the prompt carefully.
const lines = completion.split('\n');
const ok = lines.every(line => line === '\n' || line.startsWith(' '));
if (ok) {
completion = lines.map(line => line === '\n' ? line : line.slice(4)).join('\n');
}
// Suggest an action only if the completion is non-empty (that is,
// it actually looked like code).
const suggestedActions: DocAction[] = completion ? [[
"ModifyColumn",
request.context.tableId,
request.context.colId, {
formula: completion,
}
]] : [];
return {
suggestedActions,
reply,
};
} }

View File

@ -376,7 +376,8 @@ export class GranularAccess implements GranularAccessForBundle {
function fail(): never { function fail(): never {
throw new ErrorWithCode('ACL_DENY', 'Cannot access cell'); throw new ErrorWithCode('ACL_DENY', 'Cannot access cell');
} }
if (!await this.hasTableAccess(docSession, cell.tableId)) { fail(); } const hasExceptionalAccess = this._hasExceptionalFullAccess(docSession);
if (!hasExceptionalAccess && !await this.hasTableAccess(docSession, cell.tableId)) { fail(); }
let rows: TableDataAction|null = null; let rows: TableDataAction|null = null;
if (docData) { if (docData) {
const record = docData.getTable(cell.tableId)?.getRecord(cell.rowId); const record = docData.getTable(cell.tableId)?.getRecord(cell.rowId);
@ -393,6 +394,7 @@ export class GranularAccess implements GranularAccessForBundle {
return fail(); return fail();
} }
const rec = new RecordView(rows, 0); const rec = new RecordView(rows, 0);
if (!hasExceptionalAccess) {
const input: AclMatchInput = {...await this.inputs(docSession), rec, newRec: rec}; const input: AclMatchInput = {...await this.inputs(docSession), rec, newRec: rec};
const rowPermInfo = new PermissionInfo(this._ruler.ruleCollection, input); const rowPermInfo = new PermissionInfo(this._ruler.ruleCollection, input);
const rowAccess = rowPermInfo.getTableAccess(cell.tableId).perms.read; const rowAccess = rowPermInfo.getTableAccess(cell.tableId).perms.read;
@ -403,6 +405,7 @@ export class GranularAccess implements GranularAccessForBundle {
} }
const colValues = rows[3]; const colValues = rows[3];
if (!(cell.colId in colValues)) { fail(); } if (!(cell.colId in colValues)) { fail(); }
}
return rec.get(cell.colId); return rec.get(cell.colId);
} }

View File

@ -37,7 +37,6 @@ def column_type(engine, table_id, col_id):
Attachments="Any", Attachments="Any",
)[parts[0]] )[parts[0]]
def choices(col_rec): def choices(col_rec):
try: try:
widget_options = json.loads(col_rec.widgetOptions) widget_options = json.loads(col_rec.widgetOptions)

View File

@ -1,12 +1,17 @@
#!/usr/bin/env node #!/usr/bin/env node
"use strict"; "use strict";
const fs = require('fs');
const path = require('path'); const path = require('path');
const codeRoot = path.dirname(path.dirname(path.dirname(__dirname)));
let codeRoot = path.dirname(path.dirname(__dirname));
if (!fs.existsSync(path.join(codeRoot, '_build'))) {
codeRoot = path.dirname(codeRoot);
}
process.env.DATA_PATH = path.join(__dirname, 'data'); process.env.DATA_PATH = path.join(__dirname, 'data');
require('app-module-path').addPath(path.join(codeRoot, '_build')); require('app-module-path').addPath(path.join(codeRoot, '_build'));
require('app-module-path').addPath(path.join(codeRoot, '_build', 'core')); require('app-module-path').addPath(path.join(codeRoot, '_build', 'core'));
require('app-module-path').addPath(path.join(codeRoot, '_build', 'ext')); require('app-module-path').addPath(path.join(codeRoot, '_build', 'ext'));
require('app-module-path').addPath(path.join(codeRoot, '_build', 'stubs'));
require('test/formula-dataset/runCompletion_impl').runCompletion().catch(console.error); require('test/formula-dataset/runCompletion_impl').runCompletion().catch(console.error);

View File

@ -24,7 +24,7 @@
*/ */
import { ActiveDoc } from "app/server/lib/ActiveDoc"; import { ActiveDoc, Deps as ActiveDocDeps } from "app/server/lib/ActiveDoc";
import { DEPS } from "app/server/lib/Assistance"; import { DEPS } from "app/server/lib/Assistance";
import log from 'app/server/lib/log'; import log from 'app/server/lib/log';
import crypto from 'crypto'; import crypto from 'crypto';
@ -38,11 +38,14 @@ import * as os from 'os';
import { pipeline } from 'stream'; import { pipeline } from 'stream';
import { createDocTools } from "test/server/docTools"; import { createDocTools } from "test/server/docTools";
import { promisify } from 'util'; import { promisify } from 'util';
import { AssistanceResponse, AssistanceState } from "app/common/AssistancePrompts";
import { CellValue } from "app/plugin/GristData";
const streamPipeline = promisify(pipeline); const streamPipeline = promisify(pipeline);
const DATA_PATH = process.env.DATA_PATH || path.join(__dirname, 'data'); const DATA_PATH = process.env.DATA_PATH || path.join(__dirname, 'data');
const PATH_TO_DOC = path.join(DATA_PATH, 'templates'); const PATH_TO_DOC = path.join(DATA_PATH, 'templates');
const PATH_TO_RESULTS = path.join(DATA_PATH, 'results');
const PATH_TO_CSV = path.join(DATA_PATH, 'formula-dataset-index.csv'); const PATH_TO_CSV = path.join(DATA_PATH, 'formula-dataset-index.csv');
const PATH_TO_CACHE = path.join(DATA_PATH, 'cache'); const PATH_TO_CACHE = path.join(DATA_PATH, 'cache');
const TEMPLATE_URL = "https://grist-static.com/datasets/grist_dataset_formulai_2023_02_20.zip"; const TEMPLATE_URL = "https://grist-static.com/datasets/grist_dataset_formulai_2023_02_20.zip";
@ -60,8 +63,11 @@ const _stats = {
callCount: 0, callCount: 0,
}; };
const SEEMS_CHATTY = (process.env.COMPLETION_MODEL || '').includes('turbo');
const SIMULATE_CONVERSATION = SEEMS_CHATTY;
export async function runCompletion() { export async function runCompletion() {
ActiveDocDeps.ACTIVEDOC_TIMEOUT = 600000;
// if template directory not exists, make it // if template directory not exists, make it
if (!fs.existsSync(path.join(PATH_TO_DOC))) { if (!fs.existsSync(path.join(PATH_TO_DOC))) {
@ -107,11 +113,12 @@ export async function runCompletion() {
if (!process.env.VERBOSE) { if (!process.env.VERBOSE) {
log.transports.file.level = 'error'; // Suppress most of log output. log.transports.file.level = 'error'; // Suppress most of log output.
} }
let activeDoc: ActiveDoc|undefined;
const docTools = createDocTools(); const docTools = createDocTools();
const session = docTools.createFakeSession('owners'); const session = docTools.createFakeSession('owners');
await docTools.before(); await docTools.before();
let successCount = 0; let successCount = 0;
let caseCount = 0;
fs.mkdirSync(path.join(PATH_TO_RESULTS), {recursive: true});
console.log('Testing AI assistance: '); console.log('Testing AI assistance: ');
@ -119,8 +126,18 @@ export async function runCompletion() {
DEPS.fetch = fetchWithCache; DEPS.fetch = fetchWithCache;
let activeDoc: ActiveDoc|undefined;
for (const rec of records) { for (const rec of records) {
let success: boolean = false;
let suggestedActions: AssistanceResponse['suggestedActions'] | undefined;
let newValues: CellValue[] | undefined;
let expected: CellValue[] | undefined;
let formula: string | undefined;
let history: AssistanceState = {messages: []};
let lastFollowUp: string | undefined;
try {
async function sendMessage(followUp?: string) {
// load new document // load new document
if (!activeDoc || activeDoc.docName !== rec.doc_id) { if (!activeDoc || activeDoc.docName !== rec.doc_id) {
const docPath = path.join(PATH_TO_DOC, rec.doc_id + '.grist'); const docPath = path.join(PATH_TO_DOC, rec.doc_id + '.grist');
@ -128,28 +145,89 @@ export async function runCompletion() {
await activeDoc.waitForInitialization(); await activeDoc.waitForInitialization();
} }
if (!activeDoc) { throw new Error("No doc"); }
// get values // get values
await activeDoc.docData!.fetchTable(rec.table_id); await activeDoc.docData!.fetchTable(rec.table_id);
const expected = activeDoc.docData!.getTable(rec.table_id)!.getColValues(rec.col_id)!.slice(); expected = activeDoc.docData!.getTable(rec.table_id)!.getColValues(rec.col_id)!.slice();
// send prompt // send prompt
const tableId = rec.table_id; const tableId = rec.table_id;
const colId = rec.col_id; const colId = rec.col_id;
const description = rec.Description; const description = rec.Description;
const {suggestedActions} = await activeDoc.getAssistance(session, {tableId, colId, description}); const colInfo = await activeDoc.docStorage.get(`
select * from _grist_Tables_column as c
left join _grist_Tables as t on t.id = c.parentId
where c.colId = ? and t.tableId = ?
`, rec.col_id, rec.table_id);
formula = colInfo?.formula;
const result = await activeDoc.getAssistanceWithOptions(session, {
context: {type: 'formula', tableId, colId},
state: history,
text: followUp || description,
});
if (result.state) {
history = result.state;
}
suggestedActions = result.suggestedActions;
// apply modification // apply modification
const {actionNum} = await activeDoc.applyUserActions(session, suggestedActions); const {actionNum} = await activeDoc.applyUserActions(session, suggestedActions);
// get new values // get new values
const newValues = activeDoc.docData!.getTable(rec.table_id)!.getColValues(rec.col_id)!.slice(); newValues = activeDoc.docData!.getTable(rec.table_id)!.getColValues(rec.col_id)!.slice();
// revert modification // revert modification
const [bundle] = await activeDoc.getActions([actionNum]); const [bundle] = await activeDoc.getActions([actionNum]);
await activeDoc.applyUserActionsById(session, [bundle!.actionNum], [bundle!.actionHash!], true); await activeDoc.applyUserActionsById(session, [bundle!.actionNum], [bundle!.actionHash!], true);
// compare values // compare values
const success = isEqual(expected, newValues); success = isEqual(expected, newValues);
if (!success) {
const rowIds = activeDoc.docData!.getTable(rec.table_id)!.getRowIds();
const result = await activeDoc.getFormulaError({client: null, mode: 'system'} as any, rec.table_id,
rec.col_id, rowIds[0]);
if (Array.isArray(result) && result[0] === 'E') {
result.shift();
const txt = `I got a \`${result.shift()}\` error:\n` +
'```\n' +
result.shift() + '\n' +
result.shift() + '\n' +
'```\n' +
'Please answer with the code block you (the assistant) just gave, ' +
'revised based on this error. Your answer must include a code block. ' +
'If you have to explain anything, do it after. ' +
'It is perfectly acceptable (and may be necessary) to do ' +
'imports from within a method body.\n';
return { followUp: txt };
} else {
for (let i = 0; i < expected.length; i++) {
const e = expected[i];
const v = newValues[i];
if (String(e) !== String(v)) {
const txt = `I got \`${v}\` where I expected \`${e}\`\n` +
'Please answer with the code block you (the assistant) just gave, ' +
'revised based on this information. Your answer must include a code ' +
'block. If you have to explain anything, do it after.\n';
return { followUp: txt };
}
}
}
}
return null;
}
const result = await sendMessage();
if (result?.followUp && SIMULATE_CONVERSATION) {
// Allow one follow up message, based on error or differences.
const result2 = await sendMessage(result.followUp);
if (result2?.followUp) {
lastFollowUp = result2.followUp;
}
}
} catch (e) {
console.error(e);
}
console.log(` ${success ? 'Successfully' : 'Failed to'} complete formula ` + console.log(` ${success ? 'Successfully' : 'Failed to'} complete formula ` +
`for column ${rec.table_id}.${rec.col_id} (doc=${rec.doc_id})`); `for column ${rec.table_id}.${rec.col_id} (doc=${rec.doc_id})`);
@ -162,6 +240,23 @@ export async function runCompletion() {
// console.log('expected=', expected); // console.log('expected=', expected);
// console.log('actual=', newValues); // console.log('actual=', newValues);
} }
const suggestedFormula = suggestedActions?.length === 1 &&
suggestedActions[0][0] === 'ModifyColumn' &&
suggestedActions[0][3].formula || suggestedActions;
fs.writeFileSync(
path.join(
PATH_TO_RESULTS,
`${rec.table_id}_${rec.col_id}_` +
caseCount.toLocaleString('en', {minimumIntegerDigits: 8, useGrouping: false}) + '.json'),
JSON.stringify({
formula,
suggestedFormula, success,
expectedValues: expected,
suggestedValues: newValues,
history,
lastFollowUp,
}, null, 2));
caseCount++;
} }
} finally { } finally {
await docTools.after(); await docTools.after();
@ -171,6 +266,13 @@ export async function runCompletion() {
console.log( console.log(
`AI Assistance completed ${successCount} successful prompt on a total of ${records.length};` `AI Assistance completed ${successCount} successful prompt on a total of ${records.length};`
); );
console.log(JSON.stringify(
{
hit: successCount,
total: records.length,
percentage: (100.0 * successCount) / Math.max(records.length, 1),
}
));
} }
} }