mirror of
https://github.com/gristlabs/grist-core.git
synced 2024-10-27 20:44:07 +00:00
add support for conversational state to assistance endpoint (#506)
* add support for conversational state to assistance endpoint This refactors the assistance code somewhat, to allow carrying along some conversational state. It extends the OpenAI-flavored assistant to make use of that state to have a conversation. The front-end is tweaked a little bit to allow for replies that don't have any code in them (though I didn't get into formatting such replies nicely). Currently tested primarily through the runCompletion script, which has been extended a bit to allow testing simulated conversations (where an error is pasted in follow-up, or an expected-vs-actual comparison). Co-authored-by: George Gevoian <85144792+georgegevoian@users.noreply.github.com>
This commit is contained in:
parent
68fbeb4d7b
commit
51a195bd94
@ -4,6 +4,7 @@ import {CellRec, DocModel, IRowModel, recordSet,
|
||||
refRecord, TableRec, ViewFieldRec} from 'app/client/models/DocModel';
|
||||
import {urlState} from 'app/client/models/gristUrlState';
|
||||
import {jsonObservable, ObjObservable} from 'app/client/models/modelUtil';
|
||||
import {AssistanceState} from 'app/common/AssistancePrompts';
|
||||
import * as gristTypes from 'app/common/gristTypes';
|
||||
import {getReferencedTableId} from 'app/common/gristTypes';
|
||||
import {
|
||||
@ -83,7 +84,7 @@ export interface ColumnRec extends IRowModel<"_grist_Tables_column"> {
|
||||
/**
|
||||
* Current history of chat. This is a temporary array used only in the ui.
|
||||
*/
|
||||
chatHistory: ko.PureComputed<Observable<ChatMessage[]>>;
|
||||
chatHistory: ko.PureComputed<Observable<ChatHistory>>;
|
||||
|
||||
// Helper which adds/removes/updates column's displayCol to match the formula.
|
||||
saveDisplayFormula(formula: string): Promise<void>|undefined;
|
||||
@ -162,8 +163,9 @@ export function createColumnRec(this: ColumnRec, docModel: DocModel): void {
|
||||
|
||||
this.chatHistory = this.autoDispose(ko.computed(() => {
|
||||
const docId = urlState().state.get().doc ?? '';
|
||||
const key = `formula-assistant-history-${docId}-${this.table().tableId()}-${this.colId()}`;
|
||||
return localStorageJsonObs(key, [] as ChatMessage[]);
|
||||
// Changed key name from history to history-v2 when ChatHistory changed in incompatible way.
|
||||
const key = `formula-assistant-history-v2-${docId}-${this.table().tableId()}-${this.colId()}`;
|
||||
return localStorageJsonObs(key, {messages: []} as ChatHistory);
|
||||
}));
|
||||
}
|
||||
|
||||
@ -196,8 +198,20 @@ export interface ChatMessage {
|
||||
*/
|
||||
sender: 'user' | 'ai';
|
||||
/**
|
||||
* The formula returned from the AI. It is only set when the sender is the AI. For now it is the same
|
||||
* value as the message, but it might change in the future when we use more conversational AI.
|
||||
* The formula returned from the AI. It is only set when the sender is the AI.
|
||||
*/
|
||||
formula?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* The state of assistance for a particular column.
|
||||
* ChatMessages are what are shown in the UI, whereas state is
|
||||
* how the back-end represents the conversation. The two are
|
||||
* similar but not the same because of post-processing.
|
||||
* It may be possible to reconcile them when things settle down
|
||||
* a bit?
|
||||
*/
|
||||
export interface ChatHistory {
|
||||
messages: ChatMessage[];
|
||||
state?: AssistanceState;
|
||||
}
|
||||
|
@ -8,7 +8,7 @@ import {basicButton, primaryButton, textButton} from 'app/client/ui2018/buttons'
|
||||
import {theme} from 'app/client/ui2018/cssVars';
|
||||
import {cssTextInput, rawTextInput} from 'app/client/ui2018/editableLabel';
|
||||
import {icon} from 'app/client/ui2018/icons';
|
||||
import {Suggestion} from 'app/common/AssistancePrompts';
|
||||
import {AssistanceResponse, AssistanceState} from 'app/common/AssistancePrompts';
|
||||
import {Disposable, dom, makeTestId, MultiHolder, obsArray, Observable, styled} from 'grainjs';
|
||||
import noop from 'lodash/noop';
|
||||
|
||||
@ -78,7 +78,7 @@ function buildControls(
|
||||
}
|
||||
) {
|
||||
|
||||
const hasHistory = props.column.chatHistory.peek().get().length > 0;
|
||||
const hasHistory = props.column.chatHistory.peek().get().messages.length > 0;
|
||||
|
||||
// State variables, to show various parts of the UI.
|
||||
const saveButtonVisible = Observable.create(owner, true);
|
||||
@ -153,36 +153,46 @@ function buildControls(
|
||||
};
|
||||
}
|
||||
|
||||
function buildChat(owner: Disposable, context: Context & { formulaClicked: (formula: string) => void }) {
|
||||
function buildChat(owner: Disposable, context: Context & { formulaClicked: (formula?: string) => void }) {
|
||||
const { grist, column } = context;
|
||||
|
||||
const history = owner.autoDispose(obsArray(column.chatHistory.peek().get()));
|
||||
const history = owner.autoDispose(obsArray(column.chatHistory.peek().get().messages));
|
||||
const hasHistory = history.get().length > 0;
|
||||
const enabled = Observable.create(owner, hasHistory);
|
||||
const introVisible = Observable.create(owner, !hasHistory);
|
||||
owner.autoDispose(history.addListener((cur) => {
|
||||
column.chatHistory.peek().set([...cur]);
|
||||
const chatHistory = column.chatHistory.peek();
|
||||
chatHistory.set({...chatHistory.get(), messages: [...cur]});
|
||||
}));
|
||||
|
||||
const submit = async () => {
|
||||
// Ask about suggestion, and send the whole history. Currently the chat is implemented by just sending
|
||||
// all previous user prompts back to the AI. This is subject to change (and probably should be done in the backend).
|
||||
const prompt = history.get().filter(x => x.sender === 'user')
|
||||
.map(entry => entry.message)
|
||||
.filter(Boolean)
|
||||
.join("\n");
|
||||
console.debug('prompt', prompt);
|
||||
const { suggestedActions } = await askAI(grist, column, prompt);
|
||||
console.debug('suggestedActions', suggestedActions);
|
||||
const submit = async (regenerate: boolean = false) => {
|
||||
// Send most recent question, and send back any conversation
|
||||
// state we have been asked to track.
|
||||
const chatHistory = column.chatHistory.peek().get();
|
||||
const messages = chatHistory.messages.filter(msg => msg.sender === 'user');
|
||||
const description = messages[messages.length - 1]?.message || '';
|
||||
console.debug('description', {description});
|
||||
const {reply, suggestedActions, state} = await askAI(grist, {
|
||||
column, description, state: chatHistory.state,
|
||||
regenerate,
|
||||
});
|
||||
console.debug('suggestedActions', {suggestedActions, reply});
|
||||
const firstAction = suggestedActions[0] as any;
|
||||
// Add the formula to the history.
|
||||
const formula = firstAction[3].formula as string;
|
||||
const formula = firstAction ? firstAction[3].formula as string : undefined;
|
||||
// Add to history
|
||||
history.push({
|
||||
message: formula,
|
||||
message: formula || reply || '(no reply)',
|
||||
sender: 'ai',
|
||||
formula
|
||||
});
|
||||
// If back-end is capable of conversation, keep its state.
|
||||
if (state) {
|
||||
const chatHistoryNew = column.chatHistory.peek();
|
||||
const value = chatHistoryNew.get();
|
||||
value.state = state;
|
||||
chatHistoryNew.set(value);
|
||||
}
|
||||
return formula;
|
||||
};
|
||||
|
||||
@ -203,12 +213,13 @@ function buildChat(owner: Disposable, context: Context & { formulaClicked: (form
|
||||
// Remove the last AI response from the history.
|
||||
history.pop();
|
||||
// And submit again.
|
||||
context.formulaClicked(await submit());
|
||||
context.formulaClicked(await submit(true));
|
||||
};
|
||||
|
||||
const newChat = () => {
|
||||
// Clear the history.
|
||||
history.set([]);
|
||||
column.chatHistory.peek().set({messages: []});
|
||||
// Show intro.
|
||||
introVisible.set(true);
|
||||
};
|
||||
@ -371,9 +382,11 @@ function openAIAssistant(grist: GristDoc, column: ColumnRec) {
|
||||
const chat = buildChat(owner, {...props,
|
||||
// When a formula is clicked (or just was returned from the AI), we set it in the formula editor and hit
|
||||
// the preview button.
|
||||
formulaClicked: (formula: string) => {
|
||||
formulaClicked: (formula?: string) => {
|
||||
if (formula) {
|
||||
formulaEditor.set(formula);
|
||||
controls.preview().catch(reportError);
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
@ -397,11 +410,22 @@ function openAIAssistant(grist: GristDoc, column: ColumnRec) {
|
||||
grist.formulaPopup.autoDispose(popup);
|
||||
}
|
||||
|
||||
async function askAI(grist: GristDoc, column: ColumnRec, description: string): Promise<Suggestion> {
|
||||
async function askAI(grist: GristDoc, options: {
|
||||
column: ColumnRec,
|
||||
description: string,
|
||||
regenerate?: boolean,
|
||||
state?: AssistanceState
|
||||
}): Promise<AssistanceResponse> {
|
||||
const {column, description, state, regenerate} = options;
|
||||
const tableId = column.table.peek().tableId.peek();
|
||||
const colId = column.colId.peek();
|
||||
try {
|
||||
const result = await grist.docComm.getAssistance({tableId, colId, description});
|
||||
const result = await grist.docComm.getAssistance({
|
||||
context: {type: 'formula', tableId, colId},
|
||||
text: description,
|
||||
state,
|
||||
regenerate,
|
||||
});
|
||||
return result;
|
||||
} catch (error) {
|
||||
reportError(error);
|
||||
|
@ -1,5 +1,5 @@
|
||||
import {ActionGroup} from 'app/common/ActionGroup';
|
||||
import {Prompt, Suggestion} from 'app/common/AssistancePrompts';
|
||||
import {AssistanceRequest, AssistanceResponse} from 'app/common/AssistancePrompts';
|
||||
import {BulkAddRecord, CellValue, TableDataAction, UserAction} from 'app/common/DocActions';
|
||||
import {FormulaProperties} from 'app/common/GranularAccessClause';
|
||||
import {UIRowId} from 'app/common/UIRowId';
|
||||
@ -323,7 +323,7 @@ export interface ActiveDocAPI {
|
||||
/**
|
||||
* Generates a formula code based on the AI suggestions, it also modifies the column and sets it type to a formula.
|
||||
*/
|
||||
getAssistance(userPrompt: Prompt): Promise<Suggestion>;
|
||||
getAssistance(request: AssistanceRequest): Promise<AssistanceResponse>;
|
||||
|
||||
/**
|
||||
* Fetch content at a url.
|
||||
|
@ -1,11 +1,56 @@
|
||||
import {DocAction} from 'app/common/DocActions';
|
||||
|
||||
export interface Prompt {
|
||||
tableId: string;
|
||||
colId: string
|
||||
description: string;
|
||||
/**
|
||||
* State related to a request for assistance.
|
||||
*
|
||||
* If an AssistanceResponse contains state, that state can be
|
||||
* echoed back in an AssistanceRequest to continue a "conversation."
|
||||
*
|
||||
* Ideally, the state should not be modified or relied upon
|
||||
* by the client, so as not to commit too hard to a particular
|
||||
* model at this time (it is a bit early for that).
|
||||
*/
|
||||
export interface AssistanceState {
|
||||
messages?: Array<{
|
||||
role: string;
|
||||
content: string;
|
||||
}>;
|
||||
}
|
||||
|
||||
export interface Suggestion {
|
||||
suggestedActions: DocAction[];
|
||||
/**
|
||||
* Currently, requests for assistance always happen in the context
|
||||
* of the column of a particular table.
|
||||
*/
|
||||
export interface FormulaAssistanceContext {
|
||||
type: 'formula';
|
||||
tableId: string;
|
||||
colId: string;
|
||||
}
|
||||
|
||||
export type AssistanceContext = FormulaAssistanceContext;
|
||||
|
||||
/**
|
||||
* A request for assistance.
|
||||
*/
|
||||
export interface AssistanceRequest {
|
||||
context: AssistanceContext;
|
||||
state?: AssistanceState;
|
||||
text: string;
|
||||
regenerate?: boolean; // Set if there was a previous request
|
||||
// and response that should be omitted
|
||||
// from history, or (if available) an
|
||||
// alternative response generated.
|
||||
}
|
||||
|
||||
/**
|
||||
* A response to a request for assistance.
|
||||
* The client should preserve the state and include it in
|
||||
* any follow-up requests.
|
||||
*/
|
||||
export interface AssistanceResponse {
|
||||
suggestedActions: DocAction[];
|
||||
state?: AssistanceState;
|
||||
// If the model can be trusted to issue a self-contained
|
||||
// markdown-friendly string, it can be included here.
|
||||
reply?: string;
|
||||
}
|
||||
|
@ -14,7 +14,7 @@ import {
|
||||
} from 'app/common/ActionBundle';
|
||||
import {ActionGroup, MinimalActionGroup} from 'app/common/ActionGroup';
|
||||
import {ActionSummary} from "app/common/ActionSummary";
|
||||
import {Prompt, Suggestion} from "app/common/AssistancePrompts";
|
||||
import {AssistanceRequest, AssistanceResponse} from "app/common/AssistancePrompts";
|
||||
import {
|
||||
AclResources,
|
||||
AclTableDescription,
|
||||
@ -85,7 +85,7 @@ import {Document} from 'app/gen-server/entity/Document';
|
||||
import {ParseOptions} from 'app/plugin/FileParserAPI';
|
||||
import {AccessTokenOptions, AccessTokenResult, GristDocAPI} from 'app/plugin/GristAPI';
|
||||
import {compileAclFormula} from 'app/server/lib/ACLFormula';
|
||||
import {sendForCompletion} from 'app/server/lib/Assistance';
|
||||
import {AssistanceDoc, AssistanceSchemaPromptV1Context, sendForCompletion} from 'app/server/lib/Assistance';
|
||||
import {Authorizer} from 'app/server/lib/Authorizer';
|
||||
import {checksumFile} from 'app/server/lib/checksumFile';
|
||||
import {Client} from 'app/server/lib/Client';
|
||||
@ -180,7 +180,7 @@ interface UpdateUsageOptions {
|
||||
* either .loadDoc() or .createEmptyDoc() is called.
|
||||
* @param {String} docName - The document's filename, without the '.grist' extension.
|
||||
*/
|
||||
export class ActiveDoc extends EventEmitter {
|
||||
export class ActiveDoc extends EventEmitter implements AssistanceDoc {
|
||||
/**
|
||||
* Decorator for ActiveDoc methods that prevents shutdown while the method is running, i.e.
|
||||
* until the returned promise is resolved.
|
||||
@ -1112,7 +1112,7 @@ export class ActiveDoc extends EventEmitter {
|
||||
* @param {Integer} rowId - Row number
|
||||
* @returns {Promise} Promise for a error message
|
||||
*/
|
||||
public async getFormulaError(docSession: DocSession, tableId: string, colId: string,
|
||||
public async getFormulaError(docSession: OptDocSession, tableId: string, colId: string,
|
||||
rowId: number): Promise<CellValue> {
|
||||
// Throw an error if the user doesn't have access to read this cell.
|
||||
await this._granularAccess.getCellValue(docSession, {tableId, colId, rowId});
|
||||
@ -1260,22 +1260,28 @@ export class ActiveDoc extends EventEmitter {
|
||||
return this._pyCall('autocomplete', txt, tableId, columnId, rowId, user.toJSON());
|
||||
}
|
||||
|
||||
public async getAssistance(docSession: DocSession, userPrompt: Prompt): Promise<Suggestion> {
|
||||
// Making a prompt can leak names of tables and columns.
|
||||
public async getAssistance(docSession: DocSession, request: AssistanceRequest): Promise<AssistanceResponse> {
|
||||
return this.getAssistanceWithOptions(docSession, request);
|
||||
}
|
||||
|
||||
public async getAssistanceWithOptions(docSession: DocSession,
|
||||
request: AssistanceRequest): Promise<AssistanceResponse> {
|
||||
// Making a prompt leaks names of tables and columns etc.
|
||||
if (!await this._granularAccess.canScanData(docSession)) {
|
||||
throw new Error("Permission denied");
|
||||
}
|
||||
await this.waitForInitialization();
|
||||
const { tableId, colId, description } = userPrompt;
|
||||
const prompt = await this._pyCall('get_formula_prompt', tableId, colId, description);
|
||||
this._log.debug(docSession, 'getAssistance prompt', {prompt});
|
||||
const completion = await sendForCompletion(prompt);
|
||||
this._log.debug(docSession, 'getAssistance completion', {completion});
|
||||
const formula = await this._pyCall('convert_formula_completion', completion);
|
||||
const action: DocAction = ["ModifyColumn", tableId, colId, {formula}];
|
||||
return {
|
||||
suggestedActions: [action],
|
||||
};
|
||||
return sendForCompletion(this, request);
|
||||
}
|
||||
|
||||
// Callback to make a data-engine formula tweak for assistance.
|
||||
public assistanceFormulaTweak(txt: string) {
|
||||
return this._pyCall('convert_formula_completion', txt);
|
||||
}
|
||||
|
||||
// Callback to generate a prompt containing schema info for assistance.
|
||||
public assistanceSchemaPromptV1(options: AssistanceSchemaPromptV1Context): Promise<string> {
|
||||
return this._pyCall('get_formula_prompt', options.tableId, options.colId, options.docString);
|
||||
}
|
||||
|
||||
public fetchURL(docSession: DocSession, url: string, options?: FetchUrlOptions): Promise<UploadResult> {
|
||||
|
@ -2,76 +2,187 @@
|
||||
* Module with functions used for AI formula assistance.
|
||||
*/
|
||||
|
||||
import {AssistanceRequest, AssistanceResponse} from 'app/common/AssistancePrompts';
|
||||
import {delay} from 'app/common/delay';
|
||||
import {DocAction} from 'app/common/DocActions';
|
||||
import log from 'app/server/lib/log';
|
||||
import fetch from 'node-fetch';
|
||||
|
||||
export const DEPS = { fetch };
|
||||
|
||||
export async function sendForCompletion(prompt: string): Promise<string> {
|
||||
let completion: string|null = null;
|
||||
let retries: number = 0;
|
||||
const openApiKey = process.env.OPENAI_API_KEY;
|
||||
const model = process.env.COMPLETION_MODEL || "text-davinci-002";
|
||||
|
||||
while(retries++ < 3) {
|
||||
try {
|
||||
if (openApiKey) {
|
||||
completion = await sendForCompletionOpenAI(prompt, openApiKey, model);
|
||||
}
|
||||
if (process.env.HUGGINGFACE_API_KEY) {
|
||||
completion = await sendForCompletionHuggingFace(prompt);
|
||||
}
|
||||
break;
|
||||
} catch(e) {
|
||||
await delay(1000);
|
||||
}
|
||||
}
|
||||
if (completion === null) {
|
||||
throw new Error("Please set OPENAI_API_KEY or HUGGINGFACE_API_KEY (and optionally COMPLETION_MODEL)");
|
||||
}
|
||||
log.debug(`Received completion:`, {completion});
|
||||
completion = completion.split(/\n {4}[^ ]/)[0];
|
||||
return completion;
|
||||
/**
|
||||
* An assistant can help a user do things with their document,
|
||||
* by interfacing with an external LLM endpoint.
|
||||
*/
|
||||
export interface Assistant {
|
||||
apply(doc: AssistanceDoc, request: AssistanceRequest): Promise<AssistanceResponse>;
|
||||
}
|
||||
|
||||
/**
|
||||
* Document-related methods for use in the implementation of assistants.
|
||||
* Somewhat ad-hoc currently.
|
||||
*/
|
||||
export interface AssistanceDoc {
|
||||
/**
|
||||
* Generate a particular prompt coded in the data engine for some reason.
|
||||
* It makes python code for some tables, and starts a function body with
|
||||
* the given docstring.
|
||||
* Marked "V1" to suggest that it is a particular prompt and it would
|
||||
* be great to try variants.
|
||||
*/
|
||||
assistanceSchemaPromptV1(options: AssistanceSchemaPromptV1Context): Promise<string>;
|
||||
|
||||
async function sendForCompletionOpenAI(prompt: string, apiKey: string, model = "text-davinci-002") {
|
||||
/**
|
||||
* Some tweaks to a formula after it has been generated.
|
||||
*/
|
||||
assistanceFormulaTweak(txt: string): Promise<string>;
|
||||
}
|
||||
|
||||
export interface AssistanceSchemaPromptV1Context {
|
||||
tableId: string,
|
||||
colId: string,
|
||||
docString: string,
|
||||
}
|
||||
|
||||
/**
|
||||
* A flavor of assistant for use with the OpenAI API.
|
||||
* Tested primarily with text-davinci-002 and gpt-3.5-turbo.
|
||||
*/
|
||||
export class OpenAIAssistant implements Assistant {
|
||||
private _apiKey: string;
|
||||
private _model: string;
|
||||
private _chatMode: boolean;
|
||||
private _endpoint: string;
|
||||
|
||||
public constructor() {
|
||||
const apiKey = process.env.OPENAI_API_KEY;
|
||||
if (!apiKey) {
|
||||
throw new Error("OPENAI_API_KEY not set");
|
||||
throw new Error('OPENAI_API_KEY not set');
|
||||
}
|
||||
const response = await DEPS.fetch(
|
||||
"https://api.openai.com/v1/completions",
|
||||
this._apiKey = apiKey;
|
||||
this._model = process.env.COMPLETION_MODEL || "text-davinci-002";
|
||||
this._chatMode = this._model.includes('turbo');
|
||||
this._endpoint = `https://api.openai.com/v1/${this._chatMode ? 'chat/' : ''}completions`;
|
||||
}
|
||||
|
||||
public async apply(doc: AssistanceDoc, request: AssistanceRequest): Promise<AssistanceResponse> {
|
||||
const messages = request.state?.messages || [];
|
||||
const chatMode = this._chatMode;
|
||||
if (chatMode) {
|
||||
if (messages.length === 0) {
|
||||
messages.push({
|
||||
role: 'system',
|
||||
content: 'The user gives you one or more Python classes, ' +
|
||||
'with one last method that needs completing. Write the ' +
|
||||
'method body as a single code block, ' +
|
||||
'including the docstring the user gave. ' +
|
||||
'Just give the Python code as a markdown block, ' +
|
||||
'do not give any introduction, that will just be ' +
|
||||
'awkward for the user when copying and pasting. ' +
|
||||
'You are working with Grist, an environment very like ' +
|
||||
'regular Python except `rec` (like record) is used ' +
|
||||
'instead of `self`. ' +
|
||||
'Include at least one `return` statement or the method ' +
|
||||
'will fail, disappointing the user. ' +
|
||||
'Your answer should be the body of a single method, ' +
|
||||
'not a class, and should not include `dataclass` or ' +
|
||||
'`class` since the user is counting on you to provide ' +
|
||||
'a single method. Thanks!'
|
||||
});
|
||||
messages.push({
|
||||
role: 'user', content: await makeSchemaPromptV1(doc, request),
|
||||
});
|
||||
} else {
|
||||
if (request.regenerate) {
|
||||
if (messages[messages.length - 1].role !== 'user') {
|
||||
messages.pop();
|
||||
}
|
||||
}
|
||||
messages.push({
|
||||
role: 'user', content: request.text,
|
||||
});
|
||||
}
|
||||
} else {
|
||||
messages.length = 0;
|
||||
messages.push({
|
||||
role: 'user', content: await makeSchemaPromptV1(doc, request),
|
||||
});
|
||||
}
|
||||
|
||||
const apiResponse = await DEPS.fetch(
|
||||
this._endpoint,
|
||||
{
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Authorization": `Bearer ${apiKey}`,
|
||||
"Authorization": `Bearer ${this._apiKey}`,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
prompt,
|
||||
max_tokens: 150,
|
||||
...(!this._chatMode ? {
|
||||
prompt: messages[messages.length - 1].content,
|
||||
} : { messages }),
|
||||
max_tokens: 1500,
|
||||
temperature: 0,
|
||||
// COMPLETION_MODEL of `code-davinci-002` may be better if you have access to it.
|
||||
model,
|
||||
stop: ["\n\n"],
|
||||
model: this._model,
|
||||
stop: this._chatMode ? undefined : ["\n\n"],
|
||||
}),
|
||||
},
|
||||
);
|
||||
if (response.status !== 200) {
|
||||
log.error(`OpenAI API returned ${response.status}: ${await response.text()}`);
|
||||
throw new Error(`OpenAI API returned status ${response.status}`);
|
||||
if (apiResponse.status !== 200) {
|
||||
log.error(`OpenAI API returned ${apiResponse.status}: ${await apiResponse.text()}`);
|
||||
throw new Error(`OpenAI API returned status ${apiResponse.status}`);
|
||||
}
|
||||
const result = await response.json();
|
||||
const completion = result.choices[0].text;
|
||||
return completion;
|
||||
const result = await apiResponse.json();
|
||||
let completion: string = String(chatMode ? result.choices[0].message.content : result.choices[0].text);
|
||||
const reply = completion;
|
||||
const history = { messages };
|
||||
if (chatMode) {
|
||||
history.messages.push(result.choices[0].message);
|
||||
// This model likes returning markdown. Code will typically
|
||||
// be in a code block with ``` delimiters.
|
||||
let lines = completion.split('\n');
|
||||
if (lines[0].startsWith('```')) {
|
||||
lines.shift();
|
||||
completion = lines.join('\n');
|
||||
const parts = completion.split('```');
|
||||
if (parts.length > 1) {
|
||||
completion = parts[0];
|
||||
}
|
||||
lines = completion.split('\n');
|
||||
}
|
||||
|
||||
async function sendForCompletionHuggingFace(prompt: string) {
|
||||
// This model likes repeating the function signature and
|
||||
// docstring, so we try to strip that out.
|
||||
completion = lines.join('\n');
|
||||
while (completion.includes('"""')) {
|
||||
const parts = completion.split('"""');
|
||||
completion = parts[parts.length - 1];
|
||||
}
|
||||
|
||||
// If there's no code block, don't treat the answer as a formula.
|
||||
if (!reply.includes('```')) {
|
||||
completion = '';
|
||||
}
|
||||
}
|
||||
|
||||
const response = await completionToResponse(doc, request, completion, reply);
|
||||
if (chatMode) {
|
||||
response.state = history;
|
||||
}
|
||||
return response;
|
||||
}
|
||||
}
|
||||
|
||||
export class HuggingFaceAssistant implements Assistant {
|
||||
private _apiKey: string;
|
||||
private _completionUrl: string;
|
||||
|
||||
public constructor() {
|
||||
const apiKey = process.env.HUGGINGFACE_API_KEY;
|
||||
if (!apiKey) {
|
||||
throw new Error("HUGGINGFACE_API_KEY not set");
|
||||
throw new Error('HUGGINGFACE_API_KEY not set');
|
||||
}
|
||||
this._apiKey = apiKey;
|
||||
// COMPLETION_MODEL values I've tried:
|
||||
// - codeparrot/codeparrot
|
||||
// - NinedayWang/PolyCoder-2.7B
|
||||
@ -84,13 +195,21 @@ async function sendForCompletionHuggingFace(prompt: string) {
|
||||
completionUrl = 'https://api-inference.huggingface.co/models/NovelAI/genji-python-6B';
|
||||
}
|
||||
}
|
||||
this._completionUrl = completionUrl;
|
||||
|
||||
}
|
||||
|
||||
public async apply(doc: AssistanceDoc, request: AssistanceRequest): Promise<AssistanceResponse> {
|
||||
if (request.state) {
|
||||
throw new Error("HuggingFaceAssistant does not support state");
|
||||
}
|
||||
const prompt = await makeSchemaPromptV1(doc, request);
|
||||
const response = await DEPS.fetch(
|
||||
completionUrl,
|
||||
this._completionUrl,
|
||||
{
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Authorization": `Bearer ${apiKey}`,
|
||||
"Authorization": `Bearer ${this._apiKey}`,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
@ -112,6 +231,91 @@ async function sendForCompletionHuggingFace(prompt: string) {
|
||||
throw new Error(`HuggingFace API returned status ${response.status}: ${text}`);
|
||||
}
|
||||
const result = await response.json();
|
||||
const completion = result[0].generated_text;
|
||||
return completion.split('\n\n')[0];
|
||||
let completion = result[0].generated_text;
|
||||
completion = completion.split('\n\n')[0];
|
||||
return completionToResponse(doc, request, completion);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Instantiate an assistant, based on environment variables.
|
||||
*/
|
||||
function getAssistant() {
|
||||
if (process.env.OPENAI_API_KEY) {
|
||||
return new OpenAIAssistant();
|
||||
}
|
||||
if (process.env.HUGGINGFACE_API_KEY) {
|
||||
return new HuggingFaceAssistant();
|
||||
}
|
||||
throw new Error('Please set OPENAI_API_KEY or HUGGINGFACE_API_KEY');
|
||||
}
|
||||
|
||||
/**
|
||||
* Service a request for assistance, with a little retry logic
|
||||
* since these endpoints can be a bit flakey.
|
||||
*/
|
||||
export async function sendForCompletion(doc: AssistanceDoc,
|
||||
request: AssistanceRequest): Promise<AssistanceResponse> {
|
||||
const assistant = getAssistant();
|
||||
|
||||
let retries: number = 0;
|
||||
|
||||
let response: AssistanceResponse|null = null;
|
||||
while(retries++ < 3) {
|
||||
try {
|
||||
response = await assistant.apply(doc, request);
|
||||
break;
|
||||
} catch(e) {
|
||||
log.error(`Completion error: ${e}`);
|
||||
await delay(1000);
|
||||
}
|
||||
}
|
||||
if (!response) {
|
||||
throw new Error('Failed to get response from assistant');
|
||||
}
|
||||
return response;
|
||||
}
|
||||
|
||||
async function makeSchemaPromptV1(doc: AssistanceDoc, request: AssistanceRequest) {
|
||||
if (request.context.type !== 'formula') {
|
||||
throw new Error('makeSchemaPromptV1 only works for formulas');
|
||||
}
|
||||
return doc.assistanceSchemaPromptV1({
|
||||
tableId: request.context.tableId,
|
||||
colId: request.context.colId,
|
||||
docString: request.text,
|
||||
});
|
||||
}
|
||||
|
||||
async function completionToResponse(doc: AssistanceDoc, request: AssistanceRequest,
|
||||
completion: string, reply?: string): Promise<AssistanceResponse> {
|
||||
if (request.context.type !== 'formula') {
|
||||
throw new Error('completionToResponse only works for formulas');
|
||||
}
|
||||
completion = await doc.assistanceFormulaTweak(completion);
|
||||
// A leading newline is common.
|
||||
if (completion.charAt(0) === '\n') {
|
||||
completion = completion.slice(1);
|
||||
}
|
||||
// If all non-empty lines have four spaces, remove those spaces.
|
||||
// They are common for GPT-3.5, which matches the prompt carefully.
|
||||
const lines = completion.split('\n');
|
||||
const ok = lines.every(line => line === '\n' || line.startsWith(' '));
|
||||
if (ok) {
|
||||
completion = lines.map(line => line === '\n' ? line : line.slice(4)).join('\n');
|
||||
}
|
||||
|
||||
// Suggest an action only if the completion is non-empty (that is,
|
||||
// it actually looked like code).
|
||||
const suggestedActions: DocAction[] = completion ? [[
|
||||
"ModifyColumn",
|
||||
request.context.tableId,
|
||||
request.context.colId, {
|
||||
formula: completion,
|
||||
}
|
||||
]] : [];
|
||||
return {
|
||||
suggestedActions,
|
||||
reply,
|
||||
};
|
||||
}
|
||||
|
@ -376,7 +376,8 @@ export class GranularAccess implements GranularAccessForBundle {
|
||||
function fail(): never {
|
||||
throw new ErrorWithCode('ACL_DENY', 'Cannot access cell');
|
||||
}
|
||||
if (!await this.hasTableAccess(docSession, cell.tableId)) { fail(); }
|
||||
const hasExceptionalAccess = this._hasExceptionalFullAccess(docSession);
|
||||
if (!hasExceptionalAccess && !await this.hasTableAccess(docSession, cell.tableId)) { fail(); }
|
||||
let rows: TableDataAction|null = null;
|
||||
if (docData) {
|
||||
const record = docData.getTable(cell.tableId)?.getRecord(cell.rowId);
|
||||
@ -393,6 +394,7 @@ export class GranularAccess implements GranularAccessForBundle {
|
||||
return fail();
|
||||
}
|
||||
const rec = new RecordView(rows, 0);
|
||||
if (!hasExceptionalAccess) {
|
||||
const input: AclMatchInput = {...await this.inputs(docSession), rec, newRec: rec};
|
||||
const rowPermInfo = new PermissionInfo(this._ruler.ruleCollection, input);
|
||||
const rowAccess = rowPermInfo.getTableAccess(cell.tableId).perms.read;
|
||||
@ -403,6 +405,7 @@ export class GranularAccess implements GranularAccessForBundle {
|
||||
}
|
||||
const colValues = rows[3];
|
||||
if (!(cell.colId in colValues)) { fail(); }
|
||||
}
|
||||
return rec.get(cell.colId);
|
||||
}
|
||||
|
||||
|
@ -37,7 +37,6 @@ def column_type(engine, table_id, col_id):
|
||||
Attachments="Any",
|
||||
)[parts[0]]
|
||||
|
||||
|
||||
def choices(col_rec):
|
||||
try:
|
||||
widget_options = json.loads(col_rec.widgetOptions)
|
||||
|
@ -1,12 +1,17 @@
|
||||
#!/usr/bin/env node
|
||||
"use strict";
|
||||
const fs = require('fs');
|
||||
const path = require('path');
|
||||
const codeRoot = path.dirname(path.dirname(path.dirname(__dirname)));
|
||||
|
||||
let codeRoot = path.dirname(path.dirname(__dirname));
|
||||
if (!fs.existsSync(path.join(codeRoot, '_build'))) {
|
||||
codeRoot = path.dirname(codeRoot);
|
||||
}
|
||||
|
||||
process.env.DATA_PATH = path.join(__dirname, 'data');
|
||||
|
||||
|
||||
require('app-module-path').addPath(path.join(codeRoot, '_build'));
|
||||
require('app-module-path').addPath(path.join(codeRoot, '_build', 'core'));
|
||||
require('app-module-path').addPath(path.join(codeRoot, '_build', 'ext'));
|
||||
require('app-module-path').addPath(path.join(codeRoot, '_build', 'stubs'));
|
||||
require('test/formula-dataset/runCompletion_impl').runCompletion().catch(console.error);
|
||||
|
@ -24,7 +24,7 @@
|
||||
*/
|
||||
|
||||
|
||||
import { ActiveDoc } from "app/server/lib/ActiveDoc";
|
||||
import { ActiveDoc, Deps as ActiveDocDeps } from "app/server/lib/ActiveDoc";
|
||||
import { DEPS } from "app/server/lib/Assistance";
|
||||
import log from 'app/server/lib/log';
|
||||
import crypto from 'crypto';
|
||||
@ -38,11 +38,14 @@ import * as os from 'os';
|
||||
import { pipeline } from 'stream';
|
||||
import { createDocTools } from "test/server/docTools";
|
||||
import { promisify } from 'util';
|
||||
import { AssistanceResponse, AssistanceState } from "app/common/AssistancePrompts";
|
||||
import { CellValue } from "app/plugin/GristData";
|
||||
|
||||
const streamPipeline = promisify(pipeline);
|
||||
|
||||
const DATA_PATH = process.env.DATA_PATH || path.join(__dirname, 'data');
|
||||
const PATH_TO_DOC = path.join(DATA_PATH, 'templates');
|
||||
const PATH_TO_RESULTS = path.join(DATA_PATH, 'results');
|
||||
const PATH_TO_CSV = path.join(DATA_PATH, 'formula-dataset-index.csv');
|
||||
const PATH_TO_CACHE = path.join(DATA_PATH, 'cache');
|
||||
const TEMPLATE_URL = "https://grist-static.com/datasets/grist_dataset_formulai_2023_02_20.zip";
|
||||
@ -60,8 +63,11 @@ const _stats = {
|
||||
callCount: 0,
|
||||
};
|
||||
|
||||
const SEEMS_CHATTY = (process.env.COMPLETION_MODEL || '').includes('turbo');
|
||||
const SIMULATE_CONVERSATION = SEEMS_CHATTY;
|
||||
|
||||
export async function runCompletion() {
|
||||
ActiveDocDeps.ACTIVEDOC_TIMEOUT = 600000;
|
||||
|
||||
// if template directory not exists, make it
|
||||
if (!fs.existsSync(path.join(PATH_TO_DOC))) {
|
||||
@ -107,11 +113,12 @@ export async function runCompletion() {
|
||||
if (!process.env.VERBOSE) {
|
||||
log.transports.file.level = 'error'; // Suppress most of log output.
|
||||
}
|
||||
let activeDoc: ActiveDoc|undefined;
|
||||
const docTools = createDocTools();
|
||||
const session = docTools.createFakeSession('owners');
|
||||
await docTools.before();
|
||||
let successCount = 0;
|
||||
let caseCount = 0;
|
||||
fs.mkdirSync(path.join(PATH_TO_RESULTS), {recursive: true});
|
||||
|
||||
console.log('Testing AI assistance: ');
|
||||
|
||||
@ -119,8 +126,18 @@ export async function runCompletion() {
|
||||
|
||||
DEPS.fetch = fetchWithCache;
|
||||
|
||||
let activeDoc: ActiveDoc|undefined;
|
||||
for (const rec of records) {
|
||||
let success: boolean = false;
|
||||
let suggestedActions: AssistanceResponse['suggestedActions'] | undefined;
|
||||
let newValues: CellValue[] | undefined;
|
||||
let expected: CellValue[] | undefined;
|
||||
let formula: string | undefined;
|
||||
let history: AssistanceState = {messages: []};
|
||||
let lastFollowUp: string | undefined;
|
||||
|
||||
try {
|
||||
async function sendMessage(followUp?: string) {
|
||||
// load new document
|
||||
if (!activeDoc || activeDoc.docName !== rec.doc_id) {
|
||||
const docPath = path.join(PATH_TO_DOC, rec.doc_id + '.grist');
|
||||
@ -128,28 +145,89 @@ export async function runCompletion() {
|
||||
await activeDoc.waitForInitialization();
|
||||
}
|
||||
|
||||
if (!activeDoc) { throw new Error("No doc"); }
|
||||
|
||||
// get values
|
||||
await activeDoc.docData!.fetchTable(rec.table_id);
|
||||
const expected = activeDoc.docData!.getTable(rec.table_id)!.getColValues(rec.col_id)!.slice();
|
||||
expected = activeDoc.docData!.getTable(rec.table_id)!.getColValues(rec.col_id)!.slice();
|
||||
|
||||
// send prompt
|
||||
const tableId = rec.table_id;
|
||||
const colId = rec.col_id;
|
||||
const description = rec.Description;
|
||||
const {suggestedActions} = await activeDoc.getAssistance(session, {tableId, colId, description});
|
||||
const colInfo = await activeDoc.docStorage.get(`
|
||||
select * from _grist_Tables_column as c
|
||||
left join _grist_Tables as t on t.id = c.parentId
|
||||
where c.colId = ? and t.tableId = ?
|
||||
`, rec.col_id, rec.table_id);
|
||||
formula = colInfo?.formula;
|
||||
|
||||
const result = await activeDoc.getAssistanceWithOptions(session, {
|
||||
context: {type: 'formula', tableId, colId},
|
||||
state: history,
|
||||
text: followUp || description,
|
||||
});
|
||||
if (result.state) {
|
||||
history = result.state;
|
||||
}
|
||||
suggestedActions = result.suggestedActions;
|
||||
// apply modification
|
||||
const {actionNum} = await activeDoc.applyUserActions(session, suggestedActions);
|
||||
|
||||
// get new values
|
||||
const newValues = activeDoc.docData!.getTable(rec.table_id)!.getColValues(rec.col_id)!.slice();
|
||||
newValues = activeDoc.docData!.getTable(rec.table_id)!.getColValues(rec.col_id)!.slice();
|
||||
|
||||
// revert modification
|
||||
const [bundle] = await activeDoc.getActions([actionNum]);
|
||||
await activeDoc.applyUserActionsById(session, [bundle!.actionNum], [bundle!.actionHash!], true);
|
||||
|
||||
// compare values
|
||||
const success = isEqual(expected, newValues);
|
||||
success = isEqual(expected, newValues);
|
||||
|
||||
if (!success) {
|
||||
const rowIds = activeDoc.docData!.getTable(rec.table_id)!.getRowIds();
|
||||
const result = await activeDoc.getFormulaError({client: null, mode: 'system'} as any, rec.table_id,
|
||||
rec.col_id, rowIds[0]);
|
||||
if (Array.isArray(result) && result[0] === 'E') {
|
||||
result.shift();
|
||||
const txt = `I got a \`${result.shift()}\` error:\n` +
|
||||
'```\n' +
|
||||
result.shift() + '\n' +
|
||||
result.shift() + '\n' +
|
||||
'```\n' +
|
||||
'Please answer with the code block you (the assistant) just gave, ' +
|
||||
'revised based on this error. Your answer must include a code block. ' +
|
||||
'If you have to explain anything, do it after. ' +
|
||||
'It is perfectly acceptable (and may be necessary) to do ' +
|
||||
'imports from within a method body.\n';
|
||||
return { followUp: txt };
|
||||
} else {
|
||||
for (let i = 0; i < expected.length; i++) {
|
||||
const e = expected[i];
|
||||
const v = newValues[i];
|
||||
if (String(e) !== String(v)) {
|
||||
const txt = `I got \`${v}\` where I expected \`${e}\`\n` +
|
||||
'Please answer with the code block you (the assistant) just gave, ' +
|
||||
'revised based on this information. Your answer must include a code ' +
|
||||
'block. If you have to explain anything, do it after.\n';
|
||||
return { followUp: txt };
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
const result = await sendMessage();
|
||||
if (result?.followUp && SIMULATE_CONVERSATION) {
|
||||
// Allow one follow up message, based on error or differences.
|
||||
const result2 = await sendMessage(result.followUp);
|
||||
if (result2?.followUp) {
|
||||
lastFollowUp = result2.followUp;
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
}
|
||||
|
||||
console.log(` ${success ? 'Successfully' : 'Failed to'} complete formula ` +
|
||||
`for column ${rec.table_id}.${rec.col_id} (doc=${rec.doc_id})`);
|
||||
@ -162,6 +240,23 @@ export async function runCompletion() {
|
||||
// console.log('expected=', expected);
|
||||
// console.log('actual=', newValues);
|
||||
}
|
||||
const suggestedFormula = suggestedActions?.length === 1 &&
|
||||
suggestedActions[0][0] === 'ModifyColumn' &&
|
||||
suggestedActions[0][3].formula || suggestedActions;
|
||||
fs.writeFileSync(
|
||||
path.join(
|
||||
PATH_TO_RESULTS,
|
||||
`${rec.table_id}_${rec.col_id}_` +
|
||||
caseCount.toLocaleString('en', {minimumIntegerDigits: 8, useGrouping: false}) + '.json'),
|
||||
JSON.stringify({
|
||||
formula,
|
||||
suggestedFormula, success,
|
||||
expectedValues: expected,
|
||||
suggestedValues: newValues,
|
||||
history,
|
||||
lastFollowUp,
|
||||
}, null, 2));
|
||||
caseCount++;
|
||||
}
|
||||
} finally {
|
||||
await docTools.after();
|
||||
@ -171,6 +266,13 @@ export async function runCompletion() {
|
||||
console.log(
|
||||
`AI Assistance completed ${successCount} successful prompt on a total of ${records.length};`
|
||||
);
|
||||
console.log(JSON.stringify(
|
||||
{
|
||||
hit: successCount,
|
||||
total: records.length,
|
||||
percentage: (100.0 * successCount) / Math.max(records.length, 1),
|
||||
}
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user