mirror of
https://github.com/gristlabs/grist-core.git
synced 2024-10-27 20:44:07 +00:00
add support for conversational state to assistance endpoint (#506)
* add support for conversational state to assistance endpoint This refactors the assistance code somewhat, to allow carrying along some conversational state. It extends the OpenAI-flavored assistant to make use of that state to have a conversation. The front-end is tweaked a little bit to allow for replies that don't have any code in them (though I didn't get into formatting such replies nicely). Currently tested primarily through the runCompletion script, which has been extended a bit to allow testing simulated conversations (where an error is pasted in follow-up, or an expected-vs-actual comparison). Co-authored-by: George Gevoian <85144792+georgegevoian@users.noreply.github.com>
This commit is contained in:
parent
68fbeb4d7b
commit
51a195bd94
@ -4,6 +4,7 @@ import {CellRec, DocModel, IRowModel, recordSet,
|
|||||||
refRecord, TableRec, ViewFieldRec} from 'app/client/models/DocModel';
|
refRecord, TableRec, ViewFieldRec} from 'app/client/models/DocModel';
|
||||||
import {urlState} from 'app/client/models/gristUrlState';
|
import {urlState} from 'app/client/models/gristUrlState';
|
||||||
import {jsonObservable, ObjObservable} from 'app/client/models/modelUtil';
|
import {jsonObservable, ObjObservable} from 'app/client/models/modelUtil';
|
||||||
|
import {AssistanceState} from 'app/common/AssistancePrompts';
|
||||||
import * as gristTypes from 'app/common/gristTypes';
|
import * as gristTypes from 'app/common/gristTypes';
|
||||||
import {getReferencedTableId} from 'app/common/gristTypes';
|
import {getReferencedTableId} from 'app/common/gristTypes';
|
||||||
import {
|
import {
|
||||||
@ -83,7 +84,7 @@ export interface ColumnRec extends IRowModel<"_grist_Tables_column"> {
|
|||||||
/**
|
/**
|
||||||
* Current history of chat. This is a temporary array used only in the ui.
|
* Current history of chat. This is a temporary array used only in the ui.
|
||||||
*/
|
*/
|
||||||
chatHistory: ko.PureComputed<Observable<ChatMessage[]>>;
|
chatHistory: ko.PureComputed<Observable<ChatHistory>>;
|
||||||
|
|
||||||
// Helper which adds/removes/updates column's displayCol to match the formula.
|
// Helper which adds/removes/updates column's displayCol to match the formula.
|
||||||
saveDisplayFormula(formula: string): Promise<void>|undefined;
|
saveDisplayFormula(formula: string): Promise<void>|undefined;
|
||||||
@ -162,8 +163,9 @@ export function createColumnRec(this: ColumnRec, docModel: DocModel): void {
|
|||||||
|
|
||||||
this.chatHistory = this.autoDispose(ko.computed(() => {
|
this.chatHistory = this.autoDispose(ko.computed(() => {
|
||||||
const docId = urlState().state.get().doc ?? '';
|
const docId = urlState().state.get().doc ?? '';
|
||||||
const key = `formula-assistant-history-${docId}-${this.table().tableId()}-${this.colId()}`;
|
// Changed key name from history to history-v2 when ChatHistory changed in incompatible way.
|
||||||
return localStorageJsonObs(key, [] as ChatMessage[]);
|
const key = `formula-assistant-history-v2-${docId}-${this.table().tableId()}-${this.colId()}`;
|
||||||
|
return localStorageJsonObs(key, {messages: []} as ChatHistory);
|
||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -196,8 +198,20 @@ export interface ChatMessage {
|
|||||||
*/
|
*/
|
||||||
sender: 'user' | 'ai';
|
sender: 'user' | 'ai';
|
||||||
/**
|
/**
|
||||||
* The formula returned from the AI. It is only set when the sender is the AI. For now it is the same
|
* The formula returned from the AI. It is only set when the sender is the AI.
|
||||||
* value as the message, but it might change in the future when we use more conversational AI.
|
|
||||||
*/
|
*/
|
||||||
formula?: string;
|
formula?: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * The state of assistance for a particular column.
 *
 * `messages` are the ChatMessages shown in the UI, whereas `state` is how the
 * back-end represents the conversation. The two are similar but not the same
 * because of post-processing; it may be possible to reconcile them once things
 * settle down a bit.
 */
export interface ChatHistory {
  /** The messages displayed in the chat UI. */
  messages: ChatMessage[];
  /** The back-end's representation of the conversation, when it provides one. */
  state?: AssistanceState;
}
|
||||||
|
@ -8,7 +8,7 @@ import {basicButton, primaryButton, textButton} from 'app/client/ui2018/buttons'
|
|||||||
import {theme} from 'app/client/ui2018/cssVars';
|
import {theme} from 'app/client/ui2018/cssVars';
|
||||||
import {cssTextInput, rawTextInput} from 'app/client/ui2018/editableLabel';
|
import {cssTextInput, rawTextInput} from 'app/client/ui2018/editableLabel';
|
||||||
import {icon} from 'app/client/ui2018/icons';
|
import {icon} from 'app/client/ui2018/icons';
|
||||||
import {Suggestion} from 'app/common/AssistancePrompts';
|
import {AssistanceResponse, AssistanceState} from 'app/common/AssistancePrompts';
|
||||||
import {Disposable, dom, makeTestId, MultiHolder, obsArray, Observable, styled} from 'grainjs';
|
import {Disposable, dom, makeTestId, MultiHolder, obsArray, Observable, styled} from 'grainjs';
|
||||||
import noop from 'lodash/noop';
|
import noop from 'lodash/noop';
|
||||||
|
|
||||||
@ -78,7 +78,7 @@ function buildControls(
|
|||||||
}
|
}
|
||||||
) {
|
) {
|
||||||
|
|
||||||
const hasHistory = props.column.chatHistory.peek().get().length > 0;
|
const hasHistory = props.column.chatHistory.peek().get().messages.length > 0;
|
||||||
|
|
||||||
// State variables, to show various parts of the UI.
|
// State variables, to show various parts of the UI.
|
||||||
const saveButtonVisible = Observable.create(owner, true);
|
const saveButtonVisible = Observable.create(owner, true);
|
||||||
@ -153,36 +153,46 @@ function buildControls(
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
function buildChat(owner: Disposable, context: Context & { formulaClicked: (formula: string) => void }) {
|
function buildChat(owner: Disposable, context: Context & { formulaClicked: (formula?: string) => void }) {
|
||||||
const { grist, column } = context;
|
const { grist, column } = context;
|
||||||
|
|
||||||
const history = owner.autoDispose(obsArray(column.chatHistory.peek().get()));
|
const history = owner.autoDispose(obsArray(column.chatHistory.peek().get().messages));
|
||||||
const hasHistory = history.get().length > 0;
|
const hasHistory = history.get().length > 0;
|
||||||
const enabled = Observable.create(owner, hasHistory);
|
const enabled = Observable.create(owner, hasHistory);
|
||||||
const introVisible = Observable.create(owner, !hasHistory);
|
const introVisible = Observable.create(owner, !hasHistory);
|
||||||
owner.autoDispose(history.addListener((cur) => {
|
owner.autoDispose(history.addListener((cur) => {
|
||||||
column.chatHistory.peek().set([...cur]);
|
const chatHistory = column.chatHistory.peek();
|
||||||
|
chatHistory.set({...chatHistory.get(), messages: [...cur]});
|
||||||
}));
|
}));
|
||||||
|
|
||||||
const submit = async () => {
|
const submit = async (regenerate: boolean = false) => {
|
||||||
// Ask about suggestion, and send the whole history. Currently the chat is implemented by just sending
|
// Send most recent question, and send back any conversation
|
||||||
// all previous user prompts back to the AI. This is subject to change (and probably should be done in the backend).
|
// state we have been asked to track.
|
||||||
const prompt = history.get().filter(x => x.sender === 'user')
|
const chatHistory = column.chatHistory.peek().get();
|
||||||
.map(entry => entry.message)
|
const messages = chatHistory.messages.filter(msg => msg.sender === 'user');
|
||||||
.filter(Boolean)
|
const description = messages[messages.length - 1]?.message || '';
|
||||||
.join("\n");
|
console.debug('description', {description});
|
||||||
console.debug('prompt', prompt);
|
const {reply, suggestedActions, state} = await askAI(grist, {
|
||||||
const { suggestedActions } = await askAI(grist, column, prompt);
|
column, description, state: chatHistory.state,
|
||||||
console.debug('suggestedActions', suggestedActions);
|
regenerate,
|
||||||
|
});
|
||||||
|
console.debug('suggestedActions', {suggestedActions, reply});
|
||||||
const firstAction = suggestedActions[0] as any;
|
const firstAction = suggestedActions[0] as any;
|
||||||
// Add the formula to the history.
|
// Add the formula to the history.
|
||||||
const formula = firstAction[3].formula as string;
|
const formula = firstAction ? firstAction[3].formula as string : undefined;
|
||||||
// Add to history
|
// Add to history
|
||||||
history.push({
|
history.push({
|
||||||
message: formula,
|
message: formula || reply || '(no reply)',
|
||||||
sender: 'ai',
|
sender: 'ai',
|
||||||
formula
|
formula
|
||||||
});
|
});
|
||||||
|
// If back-end is capable of conversation, keep its state.
|
||||||
|
if (state) {
|
||||||
|
const chatHistoryNew = column.chatHistory.peek();
|
||||||
|
const value = chatHistoryNew.get();
|
||||||
|
value.state = state;
|
||||||
|
chatHistoryNew.set(value);
|
||||||
|
}
|
||||||
return formula;
|
return formula;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -203,12 +213,13 @@ function buildChat(owner: Disposable, context: Context & { formulaClicked: (form
|
|||||||
// Remove the last AI response from the history.
|
// Remove the last AI response from the history.
|
||||||
history.pop();
|
history.pop();
|
||||||
// And submit again.
|
// And submit again.
|
||||||
context.formulaClicked(await submit());
|
context.formulaClicked(await submit(true));
|
||||||
};
|
};
|
||||||
|
|
||||||
const newChat = () => {
|
const newChat = () => {
|
||||||
// Clear the history.
|
// Clear the history.
|
||||||
history.set([]);
|
history.set([]);
|
||||||
|
column.chatHistory.peek().set({messages: []});
|
||||||
// Show intro.
|
// Show intro.
|
||||||
introVisible.set(true);
|
introVisible.set(true);
|
||||||
};
|
};
|
||||||
@ -371,9 +382,11 @@ function openAIAssistant(grist: GristDoc, column: ColumnRec) {
|
|||||||
const chat = buildChat(owner, {...props,
|
const chat = buildChat(owner, {...props,
|
||||||
// When a formula is clicked (or just was returned from the AI), we set it in the formula editor and hit
|
// When a formula is clicked (or just was returned from the AI), we set it in the formula editor and hit
|
||||||
// the preview button.
|
// the preview button.
|
||||||
formulaClicked: (formula: string) => {
|
formulaClicked: (formula?: string) => {
|
||||||
formulaEditor.set(formula);
|
if (formula) {
|
||||||
controls.preview().catch(reportError);
|
formulaEditor.set(formula);
|
||||||
|
controls.preview().catch(reportError);
|
||||||
|
}
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -397,11 +410,22 @@ function openAIAssistant(grist: GristDoc, column: ColumnRec) {
|
|||||||
grist.formulaPopup.autoDispose(popup);
|
grist.formulaPopup.autoDispose(popup);
|
||||||
}
|
}
|
||||||
|
|
||||||
async function askAI(grist: GristDoc, column: ColumnRec, description: string): Promise<Suggestion> {
|
async function askAI(grist: GristDoc, options: {
|
||||||
|
column: ColumnRec,
|
||||||
|
description: string,
|
||||||
|
regenerate?: boolean,
|
||||||
|
state?: AssistanceState
|
||||||
|
}): Promise<AssistanceResponse> {
|
||||||
|
const {column, description, state, regenerate} = options;
|
||||||
const tableId = column.table.peek().tableId.peek();
|
const tableId = column.table.peek().tableId.peek();
|
||||||
const colId = column.colId.peek();
|
const colId = column.colId.peek();
|
||||||
try {
|
try {
|
||||||
const result = await grist.docComm.getAssistance({tableId, colId, description});
|
const result = await grist.docComm.getAssistance({
|
||||||
|
context: {type: 'formula', tableId, colId},
|
||||||
|
text: description,
|
||||||
|
state,
|
||||||
|
regenerate,
|
||||||
|
});
|
||||||
return result;
|
return result;
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
reportError(error);
|
reportError(error);
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
import {ActionGroup} from 'app/common/ActionGroup';
|
import {ActionGroup} from 'app/common/ActionGroup';
|
||||||
import {Prompt, Suggestion} from 'app/common/AssistancePrompts';
|
import {AssistanceRequest, AssistanceResponse} from 'app/common/AssistancePrompts';
|
||||||
import {BulkAddRecord, CellValue, TableDataAction, UserAction} from 'app/common/DocActions';
|
import {BulkAddRecord, CellValue, TableDataAction, UserAction} from 'app/common/DocActions';
|
||||||
import {FormulaProperties} from 'app/common/GranularAccessClause';
|
import {FormulaProperties} from 'app/common/GranularAccessClause';
|
||||||
import {UIRowId} from 'app/common/UIRowId';
|
import {UIRowId} from 'app/common/UIRowId';
|
||||||
@ -323,7 +323,7 @@ export interface ActiveDocAPI {
|
|||||||
/**
|
/**
|
||||||
* Generates a formula code based on the AI suggestions, it also modifies the column and sets it type to a formula.
|
* Generates a formula code based on the AI suggestions, it also modifies the column and sets it type to a formula.
|
||||||
*/
|
*/
|
||||||
getAssistance(userPrompt: Prompt): Promise<Suggestion>;
|
getAssistance(request: AssistanceRequest): Promise<AssistanceResponse>;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Fetch content at a url.
|
* Fetch content at a url.
|
||||||
|
@ -1,11 +1,56 @@
|
|||||||
import {DocAction} from 'app/common/DocActions';
|
import {DocAction} from 'app/common/DocActions';
|
||||||
|
|
||||||
export interface Prompt {
|
/**
|
||||||
tableId: string;
|
* State related to a request for assistance.
|
||||||
colId: string
|
*
|
||||||
description: string;
|
* If an AssistanceResponse contains state, that state can be
|
||||||
|
* echoed back in an AssistanceRequest to continue a "conversation."
|
||||||
|
*
|
||||||
|
* Ideally, the state should not be modified or relied upon
|
||||||
|
* by the client, so as not to commit too hard to a particular
|
||||||
|
* model at this time (it is a bit early for that).
|
||||||
|
*/
|
||||||
|
export interface AssistanceState {
|
||||||
|
messages?: Array<{
|
||||||
|
role: string;
|
||||||
|
content: string;
|
||||||
|
}>;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface Suggestion {
|
/**
|
||||||
suggestedActions: DocAction[];
|
* Currently, requests for assistance always happen in the context
|
||||||
|
* of the column of a particular table.
|
||||||
|
*/
|
||||||
|
export interface FormulaAssistanceContext {
|
||||||
|
type: 'formula';
|
||||||
|
tableId: string;
|
||||||
|
colId: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export type AssistanceContext = FormulaAssistanceContext;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A request for assistance.
|
||||||
|
*/
|
||||||
|
export interface AssistanceRequest {
|
||||||
|
context: AssistanceContext;
|
||||||
|
state?: AssistanceState;
|
||||||
|
text: string;
|
||||||
|
regenerate?: boolean; // Set if there was a previous request
|
||||||
|
// and response that should be omitted
|
||||||
|
// from history, or (if available) an
|
||||||
|
// alternative response generated.
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A response to a request for assistance.
|
||||||
|
* The client should preserve the state and include it in
|
||||||
|
* any follow-up requests.
|
||||||
|
*/
|
||||||
|
export interface AssistanceResponse {
|
||||||
|
suggestedActions: DocAction[];
|
||||||
|
state?: AssistanceState;
|
||||||
|
// If the model can be trusted to issue a self-contained
|
||||||
|
// markdown-friendly string, it can be included here.
|
||||||
|
reply?: string;
|
||||||
}
|
}
|
||||||
|
@ -14,7 +14,7 @@ import {
|
|||||||
} from 'app/common/ActionBundle';
|
} from 'app/common/ActionBundle';
|
||||||
import {ActionGroup, MinimalActionGroup} from 'app/common/ActionGroup';
|
import {ActionGroup, MinimalActionGroup} from 'app/common/ActionGroup';
|
||||||
import {ActionSummary} from "app/common/ActionSummary";
|
import {ActionSummary} from "app/common/ActionSummary";
|
||||||
import {Prompt, Suggestion} from "app/common/AssistancePrompts";
|
import {AssistanceRequest, AssistanceResponse} from "app/common/AssistancePrompts";
|
||||||
import {
|
import {
|
||||||
AclResources,
|
AclResources,
|
||||||
AclTableDescription,
|
AclTableDescription,
|
||||||
@ -85,7 +85,7 @@ import {Document} from 'app/gen-server/entity/Document';
|
|||||||
import {ParseOptions} from 'app/plugin/FileParserAPI';
|
import {ParseOptions} from 'app/plugin/FileParserAPI';
|
||||||
import {AccessTokenOptions, AccessTokenResult, GristDocAPI} from 'app/plugin/GristAPI';
|
import {AccessTokenOptions, AccessTokenResult, GristDocAPI} from 'app/plugin/GristAPI';
|
||||||
import {compileAclFormula} from 'app/server/lib/ACLFormula';
|
import {compileAclFormula} from 'app/server/lib/ACLFormula';
|
||||||
import {sendForCompletion} from 'app/server/lib/Assistance';
|
import {AssistanceDoc, AssistanceSchemaPromptV1Context, sendForCompletion} from 'app/server/lib/Assistance';
|
||||||
import {Authorizer} from 'app/server/lib/Authorizer';
|
import {Authorizer} from 'app/server/lib/Authorizer';
|
||||||
import {checksumFile} from 'app/server/lib/checksumFile';
|
import {checksumFile} from 'app/server/lib/checksumFile';
|
||||||
import {Client} from 'app/server/lib/Client';
|
import {Client} from 'app/server/lib/Client';
|
||||||
@ -180,7 +180,7 @@ interface UpdateUsageOptions {
|
|||||||
* either .loadDoc() or .createEmptyDoc() is called.
|
* either .loadDoc() or .createEmptyDoc() is called.
|
||||||
* @param {String} docName - The document's filename, without the '.grist' extension.
|
* @param {String} docName - The document's filename, without the '.grist' extension.
|
||||||
*/
|
*/
|
||||||
export class ActiveDoc extends EventEmitter {
|
export class ActiveDoc extends EventEmitter implements AssistanceDoc {
|
||||||
/**
|
/**
|
||||||
* Decorator for ActiveDoc methods that prevents shutdown while the method is running, i.e.
|
* Decorator for ActiveDoc methods that prevents shutdown while the method is running, i.e.
|
||||||
* until the returned promise is resolved.
|
* until the returned promise is resolved.
|
||||||
@ -1112,7 +1112,7 @@ export class ActiveDoc extends EventEmitter {
|
|||||||
* @param {Integer} rowId - Row number
|
* @param {Integer} rowId - Row number
|
||||||
* @returns {Promise} Promise for a error message
|
* @returns {Promise} Promise for a error message
|
||||||
*/
|
*/
|
||||||
public async getFormulaError(docSession: DocSession, tableId: string, colId: string,
|
public async getFormulaError(docSession: OptDocSession, tableId: string, colId: string,
|
||||||
rowId: number): Promise<CellValue> {
|
rowId: number): Promise<CellValue> {
|
||||||
// Throw an error if the user doesn't have access to read this cell.
|
// Throw an error if the user doesn't have access to read this cell.
|
||||||
await this._granularAccess.getCellValue(docSession, {tableId, colId, rowId});
|
await this._granularAccess.getCellValue(docSession, {tableId, colId, rowId});
|
||||||
@ -1260,22 +1260,28 @@ export class ActiveDoc extends EventEmitter {
|
|||||||
return this._pyCall('autocomplete', txt, tableId, columnId, rowId, user.toJSON());
|
return this._pyCall('autocomplete', txt, tableId, columnId, rowId, user.toJSON());
|
||||||
}
|
}
|
||||||
|
|
||||||
public async getAssistance(docSession: DocSession, userPrompt: Prompt): Promise<Suggestion> {
|
public async getAssistance(docSession: DocSession, request: AssistanceRequest): Promise<AssistanceResponse> {
|
||||||
// Making a prompt can leak names of tables and columns.
|
return this.getAssistanceWithOptions(docSession, request);
|
||||||
|
}
|
||||||
|
|
||||||
|
public async getAssistanceWithOptions(docSession: DocSession,
|
||||||
|
request: AssistanceRequest): Promise<AssistanceResponse> {
|
||||||
|
// Making a prompt leaks names of tables and columns etc.
|
||||||
if (!await this._granularAccess.canScanData(docSession)) {
|
if (!await this._granularAccess.canScanData(docSession)) {
|
||||||
throw new Error("Permission denied");
|
throw new Error("Permission denied");
|
||||||
}
|
}
|
||||||
await this.waitForInitialization();
|
await this.waitForInitialization();
|
||||||
const { tableId, colId, description } = userPrompt;
|
return sendForCompletion(this, request);
|
||||||
const prompt = await this._pyCall('get_formula_prompt', tableId, colId, description);
|
}
|
||||||
this._log.debug(docSession, 'getAssistance prompt', {prompt});
|
|
||||||
const completion = await sendForCompletion(prompt);
|
// Callback to make a data-engine formula tweak for assistance.
|
||||||
this._log.debug(docSession, 'getAssistance completion', {completion});
|
public assistanceFormulaTweak(txt: string) {
|
||||||
const formula = await this._pyCall('convert_formula_completion', completion);
|
return this._pyCall('convert_formula_completion', txt);
|
||||||
const action: DocAction = ["ModifyColumn", tableId, colId, {formula}];
|
}
|
||||||
return {
|
|
||||||
suggestedActions: [action],
|
// Callback to generate a prompt containing schema info for assistance.
|
||||||
};
|
public assistanceSchemaPromptV1(options: AssistanceSchemaPromptV1Context): Promise<string> {
|
||||||
|
return this._pyCall('get_formula_prompt', options.tableId, options.colId, options.docString);
|
||||||
}
|
}
|
||||||
|
|
||||||
public fetchURL(docSession: DocSession, url: string, options?: FetchUrlOptions): Promise<UploadResult> {
|
public fetchURL(docSession: DocSession, url: string, options?: FetchUrlOptions): Promise<UploadResult> {
|
||||||
|
@ -2,116 +2,320 @@
|
|||||||
* Module with functions used for AI formula assistance.
|
* Module with functions used for AI formula assistance.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
import {AssistanceRequest, AssistanceResponse} from 'app/common/AssistancePrompts';
|
||||||
import {delay} from 'app/common/delay';
|
import {delay} from 'app/common/delay';
|
||||||
|
import {DocAction} from 'app/common/DocActions';
|
||||||
import log from 'app/server/lib/log';
|
import log from 'app/server/lib/log';
|
||||||
import fetch from 'node-fetch';
|
import fetch from 'node-fetch';
|
||||||
|
|
||||||
export const DEPS = { fetch };
|
export const DEPS = { fetch };
|
||||||
|
|
||||||
export async function sendForCompletion(prompt: string): Promise<string> {
|
/**
|
||||||
let completion: string|null = null;
|
* An assistant can help a user do things with their document,
|
||||||
let retries: number = 0;
|
* by interfacing with an external LLM endpoint.
|
||||||
const openApiKey = process.env.OPENAI_API_KEY;
|
*/
|
||||||
const model = process.env.COMPLETION_MODEL || "text-davinci-002";
|
export interface Assistant {
  // Service a single assistance request, using `doc` for any document-related work.
  apply(doc: AssistanceDoc, request: AssistanceRequest): Promise<AssistanceResponse>;
}
|
||||||
|
|
||||||
|
/**
 * Document-related methods that assistant implementations rely on.
 * Currently somewhat ad hoc.
 */
export interface AssistanceDoc {
  /**
   * Produce one particular prompt that is coded in the data engine: Python code
   * for some tables, followed by the start of a function body carrying the given
   * docstring. The "V1" suffix marks this as one specific prompt; trying variants
   * would be worthwhile.
   */
  assistanceSchemaPromptV1(options: AssistanceSchemaPromptV1Context): Promise<string>;

  /**
   * Apply some tweaks to a formula after it has been generated.
   */
  assistanceFormulaTweak(txt: string): Promise<string>;
}
|
||||||
|
|
||||||
|
/**
 * Options accepted by AssistanceDoc.assistanceSchemaPromptV1.
 */
export interface AssistanceSchemaPromptV1Context {
  tableId: string;
  colId: string;
  // Docstring with which the generated function body starts.
  docString: string;
}
|
||||||
|
|
||||||
|
/**
 * A flavor of assistant for use with the OpenAI API.
 * Tested primarily with text-davinci-002 and gpt-3.5-turbo.
 *
 * Configuration comes from the environment when the assistant is constructed:
 *   - OPENAI_API_KEY (required): the key sent as a bearer token.
 *   - COMPLETION_MODEL (optional): the model name, "text-davinci-002" by default.
 * Models whose name includes "turbo" are addressed through the chat completions
 * endpoint and can carry conversational state; other models get a single prompt.
 */
export class OpenAIAssistant implements Assistant {
  private _apiKey: string;
  private _model: string;
  private _chatMode: boolean;
  private _endpoint: string;

  /**
   * @throws Error if OPENAI_API_KEY is not set.
   */
  public constructor() {
    const apiKey = process.env.OPENAI_API_KEY;
    if (!apiKey) {
      throw new Error('OPENAI_API_KEY not set');
    }
    this._apiKey = apiKey;
    this._model = process.env.COMPLETION_MODEL || "text-davinci-002";
    this._chatMode = this._model.includes('turbo');
    this._endpoint = `https://api.openai.com/v1/${this._chatMode ? 'chat/' : ''}completions`;
  }

  /**
   * Send the request to OpenAI and convert the answer into an AssistanceResponse.
   *
   * In chat mode the conversation in `request.state` (if any) is continued, and the
   * returned response carries the updated conversation in its `state`. Outside chat
   * mode any incoming state is ignored and no state is returned.
   *
   * The incoming `request` (including `request.state.messages`) is never modified,
   * so the same request can safely be submitted again after a failure.
   *
   * @throws Error if the API answers with a non-200 status, or with no completion.
   */
  public async apply(doc: AssistanceDoc, request: AssistanceRequest): Promise<AssistanceResponse> {
    // Copy the history rather than aliasing it: callers may retry with the very same
    // request object, and mutating their array would duplicate messages on each retry.
    const messages = [...(request.state?.messages || [])];
    const chatMode = this._chatMode;
    if (chatMode) {
      if (messages.length === 0) {
        // Start of a conversation: give instructions, then the schema-based prompt.
        messages.push({
          role: 'system',
          content: 'The user gives you one or more Python classes, ' +
            'with one last method that needs completing. Write the ' +
            'method body as a single code block, ' +
            'including the docstring the user gave. ' +
            'Just give the Python code as a markdown block, ' +
            'do not give any introduction, that will just be ' +
            'awkward for the user when copying and pasting. ' +
            'You are working with Grist, an environment very like ' +
            'regular Python except `rec` (like record) is used ' +
            'instead of `self`. ' +
            'Include at least one `return` statement or the method ' +
            'will fail, disappointing the user. ' +
            'Your answer should be the body of a single method, ' +
            'not a class, and should not include `dataclass` or ' +
            '`class` since the user is counting on you to provide ' +
            'a single method. Thanks!'
        });
        messages.push({
          role: 'user', content: await makeSchemaPromptV1(doc, request),
        });
      } else {
        // When regenerating, drop the reply that is being replaced.
        // NOTE(review): only the previous reply is dropped; the earlier user message
        // stays in the history next to the re-sent text - confirm that is intended.
        if (request.regenerate && messages[messages.length - 1].role !== 'user') {
          messages.pop();
        }
        messages.push({
          role: 'user', content: request.text,
        });
      }
    } else {
      // Plain completion models have no conversation: send just the one prompt.
      messages.length = 0;
      messages.push({
        role: 'user', content: await makeSchemaPromptV1(doc, request),
      });
    }

    const apiResponse = await DEPS.fetch(
      this._endpoint,
      {
        method: "POST",
        headers: {
          "Authorization": `Bearer ${this._apiKey}`,
          "Content-Type": "application/json",
        },
        body: JSON.stringify({
          ...(!this._chatMode ? {
            prompt: messages[messages.length - 1].content,
          } : { messages }),
          max_tokens: 1500,
          temperature: 0,
          model: this._model,
          // An undefined value is omitted by JSON.stringify, so chat mode sends no stop.
          stop: this._chatMode ? undefined : ["\n\n"],
        }),
      },
    );
    if (apiResponse.status !== 200) {
      log.error(`OpenAI API returned ${apiResponse.status}: ${await apiResponse.text()}`);
      throw new Error(`OpenAI API returned status ${apiResponse.status}`);
    }
    const result = await apiResponse.json();
    // Fail with a clear message, rather than a TypeError, when nothing usable came back.
    const choice = result?.choices?.[0];
    const rawText = chatMode ? choice?.message?.content : choice?.text;
    if (rawText === undefined || rawText === null) {
      throw new Error('OpenAI API returned no completion');
    }
    let completion: string = String(rawText);
    const reply = completion;
    const history = { messages };
    if (chatMode) {
      history.messages.push(choice.message);
      // This model likes returning markdown. Code will typically
      // be in a code block with ``` delimiters.
      const lines = completion.split('\n');
      if (lines[0].startsWith('```')) {
        // Drop the opening fence line, and everything from the closing fence on.
        lines.shift();
        completion = lines.join('\n').split('```')[0];
      }

      // This model likes repeating the function signature and
      // docstring, so we try to strip that out by keeping only
      // what follows the last """ delimiter.
      const docstringEnd = completion.lastIndexOf('"""');
      if (docstringEnd !== -1) {
        completion = completion.slice(docstringEnd + 3);
      }

      // If there's no code block, don't treat the answer as a formula.
      if (!reply.includes('```')) {
        completion = '';
      }
    }

    const response = await completionToResponse(doc, request, completion, reply);
    if (chatMode) {
      response.state = history;
    }
    return response;
  }
}
|
||||||
|
|
||||||
|
/**
 * A flavor of assistant for use with the HuggingFace inference API.
 * It has no notion of conversation, so requests carrying state are rejected.
 *
 * Configuration comes from the environment when the assistant is constructed:
 *   - HUGGINGFACE_API_KEY (required): the key sent as a bearer token.
 *   - COMPLETION_URL (optional): full URL of the completion endpoint.
 *   - COMPLETION_MODEL (optional): model name used to build the endpoint URL when
 *     COMPLETION_URL is not set; NovelAI/genji-python-6B is used if neither is set.
 */
export class HuggingFaceAssistant implements Assistant {
  private _apiKey: string;
  private _completionUrl: string;

  /**
   * @throws Error if HUGGINGFACE_API_KEY is not set.
   */
  public constructor() {
    const apiKey = process.env.HUGGINGFACE_API_KEY;
    if (!apiKey) {
      throw new Error('HUGGINGFACE_API_KEY not set');
    }
    this._apiKey = apiKey;
    // COMPLETION_MODEL values I've tried:
    //   - codeparrot/codeparrot
    //   - NinedayWang/PolyCoder-2.7B
    //   - NovelAI/genji-python-6B
    let completionUrl = process.env.COMPLETION_URL;
    if (!completionUrl) {
      if (process.env.COMPLETION_MODEL) {
        completionUrl = `https://api-inference.huggingface.co/models/${process.env.COMPLETION_MODEL}`;
      } else {
        completionUrl = 'https://api-inference.huggingface.co/models/NovelAI/genji-python-6B';
      }
    }
    this._completionUrl = completionUrl;
  }

  /**
   * Send the schema-based prompt to the completion endpoint and convert the
   * generated text (up to its first blank line) into an AssistanceResponse.
   *
   * @throws Error if the request carries state, if the API answers with a
   *   non-200 status, or if the answer contains no generated text.
   */
  public async apply(doc: AssistanceDoc, request: AssistanceRequest): Promise<AssistanceResponse> {
    if (request.state) {
      throw new Error("HuggingFaceAssistant does not support state");
    }
    const prompt = await makeSchemaPromptV1(doc, request);
    const response = await DEPS.fetch(
      this._completionUrl,
      {
        method: "POST",
        headers: {
          "Authorization": `Bearer ${this._apiKey}`,
          "Content-Type": "application/json",
        },
        body: JSON.stringify({
          inputs: prompt,
          parameters: {
            return_full_text: false,
            max_new_tokens: 50,
          },
        }),
      },
    );
    if (response.status !== 200) {
      // The body can be consumed only once, so read it a single time and reuse it.
      const text = await response.text();
      if (response.status === 503) {
        // Pause before failing, so that a caller's retry does not hit the API right away.
        log.error(`Sleeping for 10s - HuggingFace API returned ${response.status}: ${text}`);
        await delay(10000);
      }
      log.error(`HuggingFace API returned ${response.status}: ${text}`);
      throw new Error(`HuggingFace API returned status ${response.status}: ${text}`);
    }
    const result = await response.json();
    const generated = result?.[0]?.generated_text;
    if (typeof generated !== 'string') {
      throw new Error('HuggingFace API returned no generated text');
    }
    // Keep only the text before the first blank line.
    const completion = generated.split('\n\n')[0];
    return completionToResponse(doc, request, completion);
  }
}
|
||||||
|
|
||||||
|
/**
 * Pick and construct an assistant according to the environment: OpenAI when
 * OPENAI_API_KEY is set, otherwise HuggingFace when HUGGINGFACE_API_KEY is set.
 * Throws if neither key is available.
 */
function getAssistant() {
  const {OPENAI_API_KEY: openAIKey, HUGGINGFACE_API_KEY: huggingFaceKey} = process.env;
  if (!openAIKey && !huggingFaceKey) {
    throw new Error('Please set OPENAI_API_KEY or HUGGINGFACE_API_KEY');
  }
  // OpenAI takes precedence when both keys are present.
  return openAIKey ? new OpenAIAssistant() : new HuggingFaceAssistant();
}
|
||||||
|
|
||||||
|
/**
 * Service a request for assistance, with a little retry logic
 * since these endpoints can be a bit flakey.
 *
 * The assistant is tried up to three times, pausing one second between
 * attempts. Each failure is logged. No pause follows the last attempt,
 * since nothing would run after it.
 *
 * @param doc - document the assistance is requested for.
 * @param request - the assistance request to forward to the assistant.
 * @returns the first response the assistant produces.
 * @throws Error if every attempt fails or the assistant yields no response,
 *   or if no assistant is configured (from getAssistant).
 */
export async function sendForCompletion(doc: AssistanceDoc,
                                        request: AssistanceRequest): Promise<AssistanceResponse> {
  const assistant = getAssistant();
  const maxAttempts = 3;

  let response: AssistanceResponse|null = null;
  for (let attempt = 1; attempt <= maxAttempts; attempt++) {
    try {
      response = await assistant.apply(doc, request);
      break;
    } catch (e) {
      log.error(`Completion error: ${e}`);
      if (attempt < maxAttempts) {
        await delay(1000);
      }
    }
  }
  if (!response) {
    throw new Error('Failed to get response from assistant');
  }
  return response;
}
|
||||||
|
|
||||||
|
/**
 * Build the V1 schema prompt for a formula request by delegating to
 * doc.assistanceSchemaPromptV1, using the request text as the docstring.
 * Only requests with a 'formula' context are supported; anything else throws.
 */
async function makeSchemaPromptV1(doc: AssistanceDoc, request: AssistanceRequest) {
  const {context} = request;
  if (context.type === 'formula') {
    const {tableId, colId} = context;
    return doc.assistanceSchemaPromptV1({tableId, colId, docString: request.text});
  }
  throw new Error('makeSchemaPromptV1 only works for formulas');
}
|
||||||
|
|
||||||
/**
 * Turn a raw model completion into an AssistanceResponse for a formula request.
 *
 * The completion is passed through doc.assistanceFormulaTweak, a single
 * leading newline is dropped, and a uniform four-space indent is removed.
 * A non-empty result becomes one ModifyColumn action that sets the column's
 * formula; an empty result produces no actions.
 *
 * @param doc - supplies assistanceFormulaTweak for post-processing.
 * @param request - must have a 'formula' context; provides tableId and colId.
 * @param completion - raw text produced by the model.
 * @param reply - optional free-text reply, passed through unchanged.
 * @throws Error if the request context is not a formula.
 */
async function completionToResponse(doc: AssistanceDoc, request: AssistanceRequest,
                                    completion: string, reply?: string): Promise<AssistanceResponse> {
  if (request.context.type !== 'formula') {
    throw new Error('completionToResponse only works for formulas');
  }
  completion = await doc.assistanceFormulaTweak(completion);
  // A leading newline is common.
  if (completion.charAt(0) === '\n') {
    completion = completion.slice(1);
  }
  // If all non-empty lines have four spaces, remove those spaces.
  // They are common for GPT-3.5, which matches the prompt carefully.
  // After split('\n') a blank line is '' (never '\n'), so blank lines are
  // detected by content and left untouched instead of blocking the dedent.
  const indent = '    ';
  const lines = completion.split('\n');
  const uniformlyIndented = lines.every(line => line.trim() === '' || line.startsWith(indent));
  if (uniformlyIndented) {
    completion = lines
      .map(line => line.startsWith(indent) ? line.slice(indent.length) : line)
      .join('\n');
  }

  // Suggest an action only if the completion is non-empty (that is,
  // it actually looked like code).
  const suggestedActions: DocAction[] = completion ? [[
    "ModifyColumn",
    request.context.tableId,
    request.context.colId, {
      formula: completion,
    }
  ]] : [];
  return {
    suggestedActions,
    reply,
  };
}
|
||||||
|
@ -376,7 +376,8 @@ export class GranularAccess implements GranularAccessForBundle {
|
|||||||
function fail(): never {
|
function fail(): never {
|
||||||
throw new ErrorWithCode('ACL_DENY', 'Cannot access cell');
|
throw new ErrorWithCode('ACL_DENY', 'Cannot access cell');
|
||||||
}
|
}
|
||||||
if (!await this.hasTableAccess(docSession, cell.tableId)) { fail(); }
|
const hasExceptionalAccess = this._hasExceptionalFullAccess(docSession);
|
||||||
|
if (!hasExceptionalAccess && !await this.hasTableAccess(docSession, cell.tableId)) { fail(); }
|
||||||
let rows: TableDataAction|null = null;
|
let rows: TableDataAction|null = null;
|
||||||
if (docData) {
|
if (docData) {
|
||||||
const record = docData.getTable(cell.tableId)?.getRecord(cell.rowId);
|
const record = docData.getTable(cell.tableId)?.getRecord(cell.rowId);
|
||||||
@ -393,16 +394,18 @@ export class GranularAccess implements GranularAccessForBundle {
|
|||||||
return fail();
|
return fail();
|
||||||
}
|
}
|
||||||
const rec = new RecordView(rows, 0);
|
const rec = new RecordView(rows, 0);
|
||||||
const input: AclMatchInput = {...await this.inputs(docSession), rec, newRec: rec};
|
if (!hasExceptionalAccess) {
|
||||||
const rowPermInfo = new PermissionInfo(this._ruler.ruleCollection, input);
|
const input: AclMatchInput = {...await this.inputs(docSession), rec, newRec: rec};
|
||||||
const rowAccess = rowPermInfo.getTableAccess(cell.tableId).perms.read;
|
const rowPermInfo = new PermissionInfo(this._ruler.ruleCollection, input);
|
||||||
if (rowAccess === 'deny') { fail(); }
|
const rowAccess = rowPermInfo.getTableAccess(cell.tableId).perms.read;
|
||||||
if (rowAccess !== 'allow') {
|
if (rowAccess === 'deny') { fail(); }
|
||||||
const colAccess = rowPermInfo.getColumnAccess(cell.tableId, cell.colId).perms.read;
|
if (rowAccess !== 'allow') {
|
||||||
if (colAccess === 'deny') { fail(); }
|
const colAccess = rowPermInfo.getColumnAccess(cell.tableId, cell.colId).perms.read;
|
||||||
|
if (colAccess === 'deny') { fail(); }
|
||||||
|
}
|
||||||
|
const colValues = rows[3];
|
||||||
|
if (!(cell.colId in colValues)) { fail(); }
|
||||||
}
|
}
|
||||||
const colValues = rows[3];
|
|
||||||
if (!(cell.colId in colValues)) { fail(); }
|
|
||||||
return rec.get(cell.colId);
|
return rec.get(cell.colId);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -37,7 +37,6 @@ def column_type(engine, table_id, col_id):
|
|||||||
Attachments="Any",
|
Attachments="Any",
|
||||||
)[parts[0]]
|
)[parts[0]]
|
||||||
|
|
||||||
|
|
||||||
def choices(col_rec):
|
def choices(col_rec):
|
||||||
try:
|
try:
|
||||||
widget_options = json.loads(col_rec.widgetOptions)
|
widget_options = json.loads(col_rec.widgetOptions)
|
||||||
|
@ -1,12 +1,17 @@
|
|||||||
#!/usr/bin/env node
|
#!/usr/bin/env node
|
||||||
"use strict";
|
"use strict";
|
||||||
|
const fs = require('fs');
|
||||||
const path = require('path');
|
const path = require('path');
|
||||||
const codeRoot = path.dirname(path.dirname(path.dirname(__dirname)));
|
|
||||||
|
let codeRoot = path.dirname(path.dirname(__dirname));
|
||||||
|
if (!fs.existsSync(path.join(codeRoot, '_build'))) {
|
||||||
|
codeRoot = path.dirname(codeRoot);
|
||||||
|
}
|
||||||
|
|
||||||
process.env.DATA_PATH = path.join(__dirname, 'data');
|
process.env.DATA_PATH = path.join(__dirname, 'data');
|
||||||
|
|
||||||
|
|
||||||
require('app-module-path').addPath(path.join(codeRoot, '_build'));
|
require('app-module-path').addPath(path.join(codeRoot, '_build'));
|
||||||
require('app-module-path').addPath(path.join(codeRoot, '_build', 'core'));
|
require('app-module-path').addPath(path.join(codeRoot, '_build', 'core'));
|
||||||
require('app-module-path').addPath(path.join(codeRoot, '_build', 'ext'));
|
require('app-module-path').addPath(path.join(codeRoot, '_build', 'ext'));
|
||||||
|
require('app-module-path').addPath(path.join(codeRoot, '_build', 'stubs'));
|
||||||
require('test/formula-dataset/runCompletion_impl').runCompletion().catch(console.error);
|
require('test/formula-dataset/runCompletion_impl').runCompletion().catch(console.error);
|
||||||
|
@ -24,7 +24,7 @@
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
|
|
||||||
import { ActiveDoc } from "app/server/lib/ActiveDoc";
|
import { ActiveDoc, Deps as ActiveDocDeps } from "app/server/lib/ActiveDoc";
|
||||||
import { DEPS } from "app/server/lib/Assistance";
|
import { DEPS } from "app/server/lib/Assistance";
|
||||||
import log from 'app/server/lib/log';
|
import log from 'app/server/lib/log';
|
||||||
import crypto from 'crypto';
|
import crypto from 'crypto';
|
||||||
@ -38,11 +38,14 @@ import * as os from 'os';
|
|||||||
import { pipeline } from 'stream';
|
import { pipeline } from 'stream';
|
||||||
import { createDocTools } from "test/server/docTools";
|
import { createDocTools } from "test/server/docTools";
|
||||||
import { promisify } from 'util';
|
import { promisify } from 'util';
|
||||||
|
import { AssistanceResponse, AssistanceState } from "app/common/AssistancePrompts";
|
||||||
|
import { CellValue } from "app/plugin/GristData";
|
||||||
|
|
||||||
const streamPipeline = promisify(pipeline);
|
const streamPipeline = promisify(pipeline);
|
||||||
|
|
||||||
const DATA_PATH = process.env.DATA_PATH || path.join(__dirname, 'data');
|
const DATA_PATH = process.env.DATA_PATH || path.join(__dirname, 'data');
|
||||||
const PATH_TO_DOC = path.join(DATA_PATH, 'templates');
|
const PATH_TO_DOC = path.join(DATA_PATH, 'templates');
|
||||||
|
const PATH_TO_RESULTS = path.join(DATA_PATH, 'results');
|
||||||
const PATH_TO_CSV = path.join(DATA_PATH, 'formula-dataset-index.csv');
|
const PATH_TO_CSV = path.join(DATA_PATH, 'formula-dataset-index.csv');
|
||||||
const PATH_TO_CACHE = path.join(DATA_PATH, 'cache');
|
const PATH_TO_CACHE = path.join(DATA_PATH, 'cache');
|
||||||
const TEMPLATE_URL = "https://grist-static.com/datasets/grist_dataset_formulai_2023_02_20.zip";
|
const TEMPLATE_URL = "https://grist-static.com/datasets/grist_dataset_formulai_2023_02_20.zip";
|
||||||
@ -60,8 +63,11 @@ const _stats = {
|
|||||||
callCount: 0,
|
callCount: 0,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const SEEMS_CHATTY = (process.env.COMPLETION_MODEL || '').includes('turbo');
|
||||||
|
const SIMULATE_CONVERSATION = SEEMS_CHATTY;
|
||||||
|
|
||||||
export async function runCompletion() {
|
export async function runCompletion() {
|
||||||
|
ActiveDocDeps.ACTIVEDOC_TIMEOUT = 600000;
|
||||||
|
|
||||||
// if template directory not exists, make it
|
// if template directory not exists, make it
|
||||||
if (!fs.existsSync(path.join(PATH_TO_DOC))) {
|
if (!fs.existsSync(path.join(PATH_TO_DOC))) {
|
||||||
@ -107,11 +113,12 @@ export async function runCompletion() {
|
|||||||
if (!process.env.VERBOSE) {
|
if (!process.env.VERBOSE) {
|
||||||
log.transports.file.level = 'error'; // Suppress most of log output.
|
log.transports.file.level = 'error'; // Suppress most of log output.
|
||||||
}
|
}
|
||||||
let activeDoc: ActiveDoc|undefined;
|
|
||||||
const docTools = createDocTools();
|
const docTools = createDocTools();
|
||||||
const session = docTools.createFakeSession('owners');
|
const session = docTools.createFakeSession('owners');
|
||||||
await docTools.before();
|
await docTools.before();
|
||||||
let successCount = 0;
|
let successCount = 0;
|
||||||
|
let caseCount = 0;
|
||||||
|
fs.mkdirSync(path.join(PATH_TO_RESULTS), {recursive: true});
|
||||||
|
|
||||||
console.log('Testing AI assistance: ');
|
console.log('Testing AI assistance: ');
|
||||||
|
|
||||||
@ -119,38 +126,109 @@ export async function runCompletion() {
|
|||||||
|
|
||||||
DEPS.fetch = fetchWithCache;
|
DEPS.fetch = fetchWithCache;
|
||||||
|
|
||||||
|
let activeDoc: ActiveDoc|undefined;
|
||||||
for (const rec of records) {
|
for (const rec of records) {
|
||||||
|
let success: boolean = false;
|
||||||
|
let suggestedActions: AssistanceResponse['suggestedActions'] | undefined;
|
||||||
|
let newValues: CellValue[] | undefined;
|
||||||
|
let expected: CellValue[] | undefined;
|
||||||
|
let formula: string | undefined;
|
||||||
|
let history: AssistanceState = {messages: []};
|
||||||
|
let lastFollowUp: string | undefined;
|
||||||
|
|
||||||
// load new document
|
try {
|
||||||
if (!activeDoc || activeDoc.docName !== rec.doc_id) {
|
async function sendMessage(followUp?: string) {
|
||||||
const docPath = path.join(PATH_TO_DOC, rec.doc_id + '.grist');
|
// load new document
|
||||||
activeDoc = await docTools.loadLocalDoc(docPath);
|
if (!activeDoc || activeDoc.docName !== rec.doc_id) {
|
||||||
await activeDoc.waitForInitialization();
|
const docPath = path.join(PATH_TO_DOC, rec.doc_id + '.grist');
|
||||||
|
activeDoc = await docTools.loadLocalDoc(docPath);
|
||||||
|
await activeDoc.waitForInitialization();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!activeDoc) { throw new Error("No doc"); }
|
||||||
|
|
||||||
|
// get values
|
||||||
|
await activeDoc.docData!.fetchTable(rec.table_id);
|
||||||
|
expected = activeDoc.docData!.getTable(rec.table_id)!.getColValues(rec.col_id)!.slice();
|
||||||
|
|
||||||
|
// send prompt
|
||||||
|
const tableId = rec.table_id;
|
||||||
|
const colId = rec.col_id;
|
||||||
|
const description = rec.Description;
|
||||||
|
const colInfo = await activeDoc.docStorage.get(`
|
||||||
|
select * from _grist_Tables_column as c
|
||||||
|
left join _grist_Tables as t on t.id = c.parentId
|
||||||
|
where c.colId = ? and t.tableId = ?
|
||||||
|
`, rec.col_id, rec.table_id);
|
||||||
|
formula = colInfo?.formula;
|
||||||
|
|
||||||
|
const result = await activeDoc.getAssistanceWithOptions(session, {
|
||||||
|
context: {type: 'formula', tableId, colId},
|
||||||
|
state: history,
|
||||||
|
text: followUp || description,
|
||||||
|
});
|
||||||
|
if (result.state) {
|
||||||
|
history = result.state;
|
||||||
|
}
|
||||||
|
suggestedActions = result.suggestedActions;
|
||||||
|
// apply modification
|
||||||
|
const {actionNum} = await activeDoc.applyUserActions(session, suggestedActions);
|
||||||
|
|
||||||
|
// get new values
|
||||||
|
newValues = activeDoc.docData!.getTable(rec.table_id)!.getColValues(rec.col_id)!.slice();
|
||||||
|
|
||||||
|
// revert modification
|
||||||
|
const [bundle] = await activeDoc.getActions([actionNum]);
|
||||||
|
await activeDoc.applyUserActionsById(session, [bundle!.actionNum], [bundle!.actionHash!], true);
|
||||||
|
|
||||||
|
// compare values
|
||||||
|
success = isEqual(expected, newValues);
|
||||||
|
|
||||||
|
if (!success) {
|
||||||
|
const rowIds = activeDoc.docData!.getTable(rec.table_id)!.getRowIds();
|
||||||
|
const result = await activeDoc.getFormulaError({client: null, mode: 'system'} as any, rec.table_id,
|
||||||
|
rec.col_id, rowIds[0]);
|
||||||
|
if (Array.isArray(result) && result[0] === 'E') {
|
||||||
|
result.shift();
|
||||||
|
const txt = `I got a \`${result.shift()}\` error:\n` +
|
||||||
|
'```\n' +
|
||||||
|
result.shift() + '\n' +
|
||||||
|
result.shift() + '\n' +
|
||||||
|
'```\n' +
|
||||||
|
'Please answer with the code block you (the assistant) just gave, ' +
|
||||||
|
'revised based on this error. Your answer must include a code block. ' +
|
||||||
|
'If you have to explain anything, do it after. ' +
|
||||||
|
'It is perfectly acceptable (and may be necessary) to do ' +
|
||||||
|
'imports from within a method body.\n';
|
||||||
|
return { followUp: txt };
|
||||||
|
} else {
|
||||||
|
for (let i = 0; i < expected.length; i++) {
|
||||||
|
const e = expected[i];
|
||||||
|
const v = newValues[i];
|
||||||
|
if (String(e) !== String(v)) {
|
||||||
|
const txt = `I got \`${v}\` where I expected \`${e}\`\n` +
|
||||||
|
'Please answer with the code block you (the assistant) just gave, ' +
|
||||||
|
'revised based on this information. Your answer must include a code ' +
|
||||||
|
'block. If you have to explain anything, do it after.\n';
|
||||||
|
return { followUp: txt };
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
const result = await sendMessage();
|
||||||
|
if (result?.followUp && SIMULATE_CONVERSATION) {
|
||||||
|
// Allow one follow up message, based on error or differences.
|
||||||
|
const result2 = await sendMessage(result.followUp);
|
||||||
|
if (result2?.followUp) {
|
||||||
|
lastFollowUp = result2.followUp;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (e) {
|
||||||
|
console.error(e);
|
||||||
}
|
}
|
||||||
|
|
||||||
// get values
|
|
||||||
await activeDoc.docData!.fetchTable(rec.table_id);
|
|
||||||
const expected = activeDoc.docData!.getTable(rec.table_id)!.getColValues(rec.col_id)!.slice();
|
|
||||||
|
|
||||||
// send prompt
|
|
||||||
const tableId = rec.table_id;
|
|
||||||
const colId = rec.col_id;
|
|
||||||
const description = rec.Description;
|
|
||||||
const {suggestedActions} = await activeDoc.getAssistance(session, {tableId, colId, description});
|
|
||||||
|
|
||||||
// apply modification
|
|
||||||
const {actionNum} = await activeDoc.applyUserActions(session, suggestedActions);
|
|
||||||
|
|
||||||
// get new values
|
|
||||||
const newValues = activeDoc.docData!.getTable(rec.table_id)!.getColValues(rec.col_id)!.slice();
|
|
||||||
|
|
||||||
// revert modification
|
|
||||||
const [bundle] = await activeDoc.getActions([actionNum]);
|
|
||||||
await activeDoc.applyUserActionsById(session, [bundle!.actionNum], [bundle!.actionHash!], true);
|
|
||||||
|
|
||||||
// compare values
|
|
||||||
const success = isEqual(expected, newValues);
|
|
||||||
|
|
||||||
console.log(` ${success ? 'Successfully' : 'Failed to'} complete formula ` +
|
console.log(` ${success ? 'Successfully' : 'Failed to'} complete formula ` +
|
||||||
`for column ${rec.table_id}.${rec.col_id} (doc=${rec.doc_id})`);
|
`for column ${rec.table_id}.${rec.col_id} (doc=${rec.doc_id})`);
|
||||||
|
|
||||||
@ -162,6 +240,23 @@ export async function runCompletion() {
|
|||||||
// console.log('expected=', expected);
|
// console.log('expected=', expected);
|
||||||
// console.log('actual=', newValues);
|
// console.log('actual=', newValues);
|
||||||
}
|
}
|
||||||
|
const suggestedFormula = suggestedActions?.length === 1 &&
|
||||||
|
suggestedActions[0][0] === 'ModifyColumn' &&
|
||||||
|
suggestedActions[0][3].formula || suggestedActions;
|
||||||
|
fs.writeFileSync(
|
||||||
|
path.join(
|
||||||
|
PATH_TO_RESULTS,
|
||||||
|
`${rec.table_id}_${rec.col_id}_` +
|
||||||
|
caseCount.toLocaleString('en', {minimumIntegerDigits: 8, useGrouping: false}) + '.json'),
|
||||||
|
JSON.stringify({
|
||||||
|
formula,
|
||||||
|
suggestedFormula, success,
|
||||||
|
expectedValues: expected,
|
||||||
|
suggestedValues: newValues,
|
||||||
|
history,
|
||||||
|
lastFollowUp,
|
||||||
|
}, null, 2));
|
||||||
|
caseCount++;
|
||||||
}
|
}
|
||||||
} finally {
|
} finally {
|
||||||
await docTools.after();
|
await docTools.after();
|
||||||
@ -171,6 +266,13 @@ export async function runCompletion() {
|
|||||||
console.log(
|
console.log(
|
||||||
`AI Assistance completed ${successCount} successful prompt on a total of ${records.length};`
|
`AI Assistance completed ${successCount} successful prompt on a total of ${records.length};`
|
||||||
);
|
);
|
||||||
|
console.log(JSON.stringify(
|
||||||
|
{
|
||||||
|
hit: successCount,
|
||||||
|
total: records.length,
|
||||||
|
percentage: (100.0 * successCount) / Math.max(records.length, 1),
|
||||||
|
}
|
||||||
|
));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user