mirror of
https://github.com/gristlabs/grist-core.git
synced 2024-10-27 20:44:07 +00:00
0be858c19d
This adds an ASSISTANT_CHAT_COMPLETION_ENDPOINT which can be used to enable AI Assistance instead of an OpenAI API key. The assistant then works against compatible endpoints, in the mechanical sense. Quality of course will depend on the model. I found some tweaks to the prompt that work well both for Llama-2 and for OpenAI's models, but I'm not including them here because they would conflict with some prompt changes that are already in the works. Co-authored-by: Alex Hall <alex.mojaki@gmail.com>
506 lines
18 KiB
TypeScript
506 lines
18 KiB
TypeScript
/**
|
|
* Module with functions used for AI formula assistance.
|
|
*/
|
|
|
|
import {
|
|
AssistanceContext,
|
|
AssistanceMessage,
|
|
AssistanceRequest,
|
|
AssistanceResponse
|
|
} from 'app/common/AssistancePrompts';
|
|
import {delay} from 'app/common/delay';
|
|
import {DocAction} from 'app/common/DocActions';
|
|
import {ActiveDoc} from 'app/server/lib/ActiveDoc';
|
|
import {getDocSessionUser, OptDocSession} from 'app/server/lib/DocSession';
|
|
import log from 'app/server/lib/log';
|
|
import fetch from 'node-fetch';
|
|
import {createHash} from "crypto";
|
|
import {getLogMetaFromDocSession} from "./serverUtils";
|
|
|
|
// These are mocked/replaced in tests.
|
|
// fetch is also replaced in the runCompletion script to add caching.
|
|
export const DEPS = {
  // HTTP client used for all completion requests.
  fetch,
  // Passed to delay() between retry attempts (presumably milliseconds -- confirm against delay()).
  delayTime: 1000,
};
|
|
|
|
/**
|
|
* An assistant can help a user do things with their document,
|
|
* by interfacing with an external LLM endpoint.
|
|
*/
|
|
interface Assistant {
  // Handles a single assistance request made against `doc` within `session`,
  // resolving to the response for that request.
  apply(
    session: OptDocSession,
    doc: AssistanceDoc,
    request: AssistanceRequest,
  ): Promise<AssistanceResponse>;
}
|
|
|
|
/**
|
|
* Document-related methods for use in the implementation of assistants.
|
|
* Somewhat ad-hoc currently.
|
|
*/
|
|
interface AssistanceDoc extends ActiveDoc {
  /**
   * Produces the "V1" schema prompt, which is implemented in the data engine.
   * The prompt consists of Python code for some tables followed by the start
   * of a function body carrying the supplied docstring. The "V1" label marks
   * it as one specific prompt; other variants would be worth trying.
   */
  assistanceSchemaPromptV1(
    session: OptDocSession,
    options: AssistanceSchemaPromptV1Context,
  ): Promise<string>;

  /**
   * Applies some adjustments to a formula once it has been generated.
   */
  assistanceFormulaTweak(txt: string): Promise<string>;

  /**
   * Evaluates the existing formula, returning its result together with the
   * recorded values of (possibly nested) attributes of `rec`.
   * AI assistance uses this when fixing an incorrect formula.
   */
  assistanceEvaluateFormula(
    options: AssistanceContext,
  ): Promise<AssistanceFormulaEvaluationResult>;
}
|
|
|
|
/**
 * Outcome of evaluating a column's current formula for the assistant.
 */
export interface AssistanceFormulaEvaluationResult {
  /** True when evaluation raised an exception. */
  error: boolean;

  /** Repr of the returned value, or the exception message when `error` is set. */
  result: string;

  /**
   * Attributes of `rec` recorded while evaluating.
   * A key may name a nested attribute, e.g. "rec.foo.bar".
   */
  attributes: Record<string, string>;

  /** The evaluated code, with special grist syntax removed. */
  formula: string;
}
|
|
|
|
/**
 * Options handed to `assistanceSchemaPromptV1`: the target column and the
 * docstring to place at the start of the generated function body.
 */
export interface AssistanceSchemaPromptV1Context {
  tableId: string;
  colId: string;
  docString: string;
}
|
|
|
|
/**
 * Internal signal raised when a completion ran out of context and a
 * longer-context model is configured; the retry loop reacts by re-issuing the
 * request with that model.
 */
class SwitchToLongerContext extends Error {}
|
|
|
|
/**
 * Base class for failures that the retry loop rethrows immediately instead of
 * attempting the request again.
 */
class NonRetryableError extends Error {}
|
|
|
|
/**
 * The context limit was hit on the opening message of a conversation.
 */
class TokensExceededFirstMessage extends NonRetryableError {
  constructor() {
    const sentences = [
      "Sorry, there's too much information for the AI to process.",
      "You'll need to either shorten your message or delete some columns.",
    ];
    super(sentences.join(" "));
  }
}
|
|
|
|
/**
 * The context limit was hit after the conversation was already under way, so
 * restarting the conversation is offered as an additional remedy.
 */
class TokensExceededLaterMessage extends NonRetryableError {
  constructor() {
    const sentences = [
      "Sorry, there's too much information for the AI to process.",
      "You'll need to either shorten your message, restart the conversation, or delete some columns.",
    ];
    super(sentences.join(" "));
  }
}
|
|
|
|
/**
 * The completion endpoint reported an `insufficient_quota` error.
 */
class QuotaExceededError extends NonRetryableError {
  constructor() {
    const sentences = [
      "Sorry, the assistant is facing some long term capacity issues.",
      "Maybe try again tomorrow.",
    ];
    super(sentences.join(" "));
  }
}
|
|
|
|
/**
 * Raised once every retry attempt has failed. Wraps the underlying error text
 * in a user-facing "try again later" message.
 */
class RetryableError extends Error {
  constructor(message: string) {
    const advice = "Sorry, the assistant is unavailable right now. Try again in a few minutes. ";
    super(`${advice}\n(${message})`);
  }
}
|
|
|
|
/**
|
|
* A flavor of assistant for use with the OpenAI chat completion endpoint
|
|
* and tools with a compatible endpoint (e.g. llama-cpp-python).
|
|
* Tested primarily with gpt-3.5-turbo.
|
|
*
|
|
* Uses the ASSISTANT_CHAT_COMPLETION_ENDPOINT endpoint if set, else
|
|
* an OpenAI endpoint. Passes ASSISTANT_API_KEY or OPENAI_API_KEY in
|
|
* a header if set. An api key is required for the default OpenAI
|
|
* endpoint.
|
|
*
|
|
* If a model string is set in ASSISTANT_MODEL, this will be passed
|
|
* along. For the default OpenAI endpoint, a gpt-3.5-turbo variant
|
|
* will be set by default.
|
|
*
|
|
* If a request fails because of context length limitation, and the
|
|
* default OpenAI endpoint is in use, the request will be retried
|
|
* with ASSISTANT_LONGER_CONTEXT_MODEL (another gpt-3.5
|
|
* variant by default). Set this variable to "" if this behavior is
|
|
* not desired for the default OpenAI endpoint. If a custom endpoint was
|
|
* provided, this behavior will only happen if
|
|
* ASSISTANT_LONGER_CONTEXT_MODEL is explicitly set.
|
|
*
|
|
* An optional ASSISTANT_MAX_TOKENS can be specified.
|
|
*/
|
|
export class OpenAIAssistant implements Assistant {
  public static DEFAULT_MODEL = "gpt-3.5-turbo-0613";
  public static DEFAULT_LONGER_CONTEXT_MODEL = "gpt-3.5-turbo-16k-0613";

  private _apiKey?: string;
  private _model?: string;
  private _longerContextModel?: string;
  private _endpoint: string;
  // Sent as `max_tokens` when set; undefined when ASSISTANT_MAX_TOKENS is not provided.
  private _maxTokens = process.env.ASSISTANT_MAX_TOKENS ?
    parseInt(process.env.ASSISTANT_MAX_TOKENS, 10) : undefined;

  /**
   * Reads configuration from the environment.
   * Throws if neither an API key nor a custom endpoint is configured.
   */
  public constructor() {
    const apiKey = process.env.ASSISTANT_API_KEY || process.env.OPENAI_API_KEY;
    const endpoint = process.env.ASSISTANT_CHAT_COMPLETION_ENDPOINT;
    if (!apiKey && !endpoint) {
      throw new Error('Please set either OPENAI_API_KEY or ASSISTANT_CHAT_COMPLETION_ENDPOINT');
    }
    this._apiKey = apiKey;
    this._model = process.env.ASSISTANT_MODEL;
    this._longerContextModel = process.env.ASSISTANT_LONGER_CONTEXT_MODEL;
    if (!endpoint) {
      // Defaults apply only to the default OpenAI endpoint. `??` rather than `||` so that an
      // explicitly empty ASSISTANT_LONGER_CONTEXT_MODEL keeps the longer-context retry disabled.
      this._model = this._model ?? OpenAIAssistant.DEFAULT_MODEL;
      this._longerContextModel = this._longerContextModel ?? OpenAIAssistant.DEFAULT_LONGER_CONTEXT_MODEL;
    }
    this._endpoint = endpoint || `https://api.openai.com/v1/chat/completions`;
  }

  /**
   * Extends the conversation in `request.state` with the new prompt messages, requests a
   * completion, and turns it into an AssistanceResponse. Telemetry events are logged for
   * every message sent and for the completion received.
   *
   * Note: the message list taken from `request.state` is appended to in place and returned
   * as `response.state`.
   */
  public async apply(
    optSession: OptDocSession, doc: AssistanceDoc, request: AssistanceRequest): Promise<AssistanceResponse> {
    const messages = request.state?.messages || [];
    const newMessages = [];
    if (messages.length === 0) {
      // Start of a conversation: prime the model with instructions and the schema prompt.
      newMessages.push({
        role: 'system',
        content: 'You are a helpful assistant for a user of software called Grist. ' +
          'Below are one or more Python classes. ' +
          'The last method needs completing. ' +
          "The user will probably give a description of what they want the method (a 'formula') to return. " +
          'If so, your response should include the method body as Python code in a markdown block. ' +
          'Do not include the class or method signature, just the method body. ' +
          'If your code starts with `class`, `@dataclass`, or `def` it will fail. Only give the method body. ' +
          'You can import modules inside the method body if needed. ' +
          'You cannot define additional functions or methods. ' +
          'The method should be a pure function that performs some computation and returns a result. ' +
          'It CANNOT perform any side effects such as adding/removing/modifying rows/columns/cells/tables/etc. ' +
          'It CANNOT interact with files/databases/networks/etc. ' +
          'It CANNOT display images/charts/graphs/maps/etc. ' +
          'If the user asks for these things, tell them that you cannot help. ' +
          'The method uses `rec` instead of `self` as the first parameter.\n\n' +
          '```python\n' +
          await makeSchemaPromptV1(optSession, doc, request) +
          '\n```',
      });
    }
    if (request.context.evaluateCurrentFormula) {
      // Describe what the current formula evaluates to, so the model can help fix it.
      const result = await doc.assistanceEvaluateFormula(request.context);
      let message = "Evaluating this code:\n\n```python\n" + result.formula + "\n```\n\n";
      if (Object.keys(result.attributes).length > 0) {
        const attributes = Object.entries(result.attributes).map(([k, v]) => `${k} = ${v}`).join('\n');
        message += `where:\n\n${attributes}\n\n`;
      }
      message += `${result.error ? 'raises an exception' : 'returns'}: ${result.result}`;
      newMessages.push({
        role: 'system',
        content: message,
      });
    }
    newMessages.push({
      role: 'user', content: request.text,
    });
    messages.push(...newMessages);

    const newMessagesStartIndex = messages.length - newMessages.length;
    for (const [index, {role, content}] of newMessages.entries()) {
      doc.logTelemetryEvent(optSession, 'assistantSend', {
        full: {
          conversationId: request.conversationId,
          context: request.context,
          prompt: {
            index: newMessagesStartIndex + index,
            role,
            content,
          },
        },
      });
    }

    const userIdHash = getUserHash(optSession);
    const completion: string = await this._getCompletion(messages, userIdHash);
    const response = await completionToResponse(doc, request, completion);
    if (response.suggestedFormula) {
      // Show the tweaked version of the suggested formula to the user (i.e. the one that's
      // copied when the Apply button is clicked).
      response.reply = replaceMarkdownCode(completion, response.suggestedFormula);
    } else {
      response.reply = completion;
    }
    response.state = {messages};
    doc.logTelemetryEvent(optSession, 'assistantReceive', {
      full: {
        conversationId: request.conversationId,
        context: request.context,
        message: {
          index: messages.length - 1,
          content: completion,
        },
        suggestedFormula: response.suggestedFormula,
      },
    });
    return response;
  }

  /**
   * Performs one request against the completion endpoint and returns the parsed JSON body.
   *
   * Throws SwitchToLongerContext, TokensExceeded*Message or QuotaExceededError for the
   * corresponding API conditions, and a plain Error (which the caller retries) for a
   * non-200 status, an unparseable body, or a body without a completion.
   */
  private async _fetchCompletion(messages: AssistanceMessage[], userIdHash: string, longerContext: boolean) {
    const model = longerContext ? this._longerContextModel : this._model;
    const apiResponse = await DEPS.fetch(
      this._endpoint,
      {
        method: "POST",
        headers: {
          ...(this._apiKey ? {
            "Authorization": `Bearer ${this._apiKey}`,
          } : undefined),
          "Content-Type": "application/json",
        },
        body: JSON.stringify({
          messages,
          temperature: 0,
          ...(model ? { model } : undefined),
          user: userIdHash,
          ...(this._maxTokens ? {
            max_tokens: this._maxTokens,
          } : undefined),
        }),
      },
    );
    const resultText = await apiResponse.text();
    let result: any;
    try {
      result = JSON.parse(resultText);
    } catch (parseError) {
      // E.g. an HTML error page from a proxy. Report the HTTP status rather than a bare
      // SyntaxError so that the failure is diagnosable.
      throw new Error(
        `OpenAI API returned status ${apiResponse.status} with invalid JSON (${parseError}): ${resultText}`
      );
    }
    const errorCode = result?.error?.code;
    // `choices` may be absent or empty, so every step of the lookup is optional.
    const firstChoice = result?.choices?.[0];
    if (errorCode === "context_length_exceeded" || firstChoice?.finish_reason === "length") {
      // A first turn may carry two system messages (schema prompt and formula evaluation)
      // before the user message, so identify it by the number of user messages rather than
      // by the total message count.
      const isFirstTurn = messages.filter(m => m.role === 'user').length <= 1;
      if (!longerContext && this._longerContextModel) {
        log.info("Switching to longer context model...");
        throw new SwitchToLongerContext();
      } else if (isFirstTurn) {
        throw new TokensExceededFirstMessage();
      } else {
        throw new TokensExceededLaterMessage();
      }
    }
    if (errorCode === "insufficient_quota") {
      log.error("OpenAI billing quota exceeded!!!");
      throw new QuotaExceededError();
    }
    if (apiResponse.status !== 200) {
      throw new Error(`OpenAI API returned status ${apiResponse.status}: ${resultText}`);
    }
    if (!firstChoice?.message) {
      // Fail here, inside the retry loop, rather than with a TypeError in _getCompletion.
      throw new Error(`OpenAI API returned no completion: ${resultText}`);
    }
    return result;
  }

  /**
   * Calls _fetchCompletion, retrying up to three attempts with a pause of DEPS.delayTime
   * between them. NonRetryableErrors are rethrown at once; SwitchToLongerContext restarts
   * the whole procedure with the longer-context model. When attempts run out, the last
   * error is wrapped in a RetryableError.
   */
  private async _fetchCompletionWithRetries(
    messages: AssistanceMessage[], userIdHash: string, longerContext: boolean
  ): Promise<any> {
    const maxAttempts = 3;
    for (let attempt = 1; ; attempt++) {
      try {
        return await this._fetchCompletion(messages, userIdHash, longerContext);
      } catch (e) {
        if (e instanceof SwitchToLongerContext) {
          return await this._fetchCompletionWithRetries(messages, userIdHash, true);
        } else if (e instanceof NonRetryableError) {
          throw e;
        } else if (attempt === maxAttempts) {
          // String() gives the same text as toString() for Errors and is safe for any
          // other thrown value.
          throw new RetryableError(String(e));
        }
        log.warn(`Waiting and then retrying after error: ${e}`);
        await delay(DEPS.delayTime);
      }
    }
  }

  /**
   * Fetches a completion for `messages`, appends the assistant's message to `messages`,
   * and returns its content.
   */
  private async _getCompletion(messages: AssistanceMessage[], userIdHash: string) {
    const result = await this._fetchCompletionWithRetries(messages, userIdHash, false);
    const {message} = result.choices[0];
    messages.push(message);
    return message.content;
  }
}
|
|
|
|
/**
 * An assistant backed by a HuggingFace inference endpoint. It sends the schema prompt
 * as a plain text-completion input and does not support conversation state.
 *
 * Requires HUGGINGFACE_API_KEY. The endpoint is COMPLETION_URL if set, otherwise it is
 * derived from COMPLETION_MODEL, otherwise a default model URL is used.
 */
export class HuggingFaceAssistant implements Assistant {
  private _apiKey: string;
  private _completionUrl: string;

  /**
   * Reads configuration from the environment.
   * Throws if HUGGINGFACE_API_KEY is not set.
   */
  public constructor() {
    const apiKey = process.env.HUGGINGFACE_API_KEY;
    if (!apiKey) {
      throw new Error('HUGGINGFACE_API_KEY not set');
    }
    this._apiKey = apiKey;
    // COMPLETION_MODEL values I've tried:
    //   - codeparrot/codeparrot
    //   - NinedayWang/PolyCoder-2.7B
    //   - NovelAI/genji-python-6B
    let completionUrl = process.env.COMPLETION_URL;
    if (!completionUrl) {
      if (process.env.COMPLETION_MODEL) {
        completionUrl = `https://api-inference.huggingface.co/models/${process.env.COMPLETION_MODEL}`;
      } else {
        completionUrl = 'https://api-inference.huggingface.co/models/NovelAI/genji-python-6B';
      }
    }
    this._completionUrl = completionUrl;
  }

  /**
   * Requests a completion for the schema prompt and converts it into a response.
   * Throws if `request.state` is set, or if the API answers with a non-200 status
   * (after a 10s pause when the status is 503).
   */
  public async apply(
    optSession: OptDocSession, doc: AssistanceDoc, request: AssistanceRequest): Promise<AssistanceResponse> {
    if (request.state) {
      throw new Error("HuggingFaceAssistant does not support state");
    }
    const prompt = await makeSchemaPromptV1(optSession, doc, request);
    const response = await DEPS.fetch(
      this._completionUrl,
      {
        method: "POST",
        headers: {
          "Authorization": `Bearer ${this._apiKey}`,
          "Content-Type": "application/json",
        },
        body: JSON.stringify({
          inputs: prompt,
          parameters: {
            return_full_text: false,
            max_new_tokens: 50,
          },
        }),
      },
    );
    if (response.status !== 200) {
      // Read the body exactly once and reuse the text: a response body can only be
      // consumed a single time, so a second text() call would reject and hide the
      // actual API error.
      const text = await response.text();
      if (response.status === 503) {
        log.error(`Sleeping for 10s - HuggingFace API returned ${response.status}: ${text}`);
        await delay(10000);
      }
      log.error(`HuggingFace API returned ${response.status}: ${text}`);
      throw new Error(`HuggingFace API returned status ${response.status}: ${text}`);
    }
    const result = await response.json();
    let completion = result[0].generated_text;
    // Keep only the text up to the first blank line.
    completion = completion.split('\n\n')[0];
    return completionToResponse(doc, request, completion);
  }
}
|
|
|
|
/**
|
|
* Test assistant that mimics ChatGPT and just returns the input.
|
|
*/
|
|
class EchoAssistant implements Assistant {
  public async apply(sess: OptDocSession, doc: AssistanceDoc, request: AssistanceRequest): Promise<AssistanceResponse> {
    const echoed = request.text;
    // The literal text "ERROR" triggers a failure.
    if (echoed === "ERROR") {
      throw new Error(`ERROR`);
    }
    const messages = request.state?.messages || [];
    if (!messages.length) {
      // A new conversation starts with an empty system message.
      messages.push({role: 'system', content: ''});
    }
    // Record the user's message, then the "reply", which is the very same text.
    messages.push({role: 'user', content: echoed});
    messages.push({role: 'assistant', content: echoed});
    const response = await completionToResponse(doc, request, echoed, echoed);
    response.state = {messages};
    return response;
  }
}
|
|
|
|
/**
|
|
* Instantiate an assistant, based on environment variables.
|
|
*/
|
|
export function getAssistant() {
  const {
    OPENAI_API_KEY: openAIKey,
    ASSISTANT_CHAT_COMPLETION_ENDPOINT: customEndpoint,
  } = process.env;
  // The special key "test" selects the echoing test assistant.
  if (openAIKey === 'test') {
    return new EchoAssistant();
  }
  if (!openAIKey && !customEndpoint) {
    throw new Error('Please set OPENAI_API_KEY or ASSISTANT_CHAT_COMPLETION_ENDPOINT');
  }
  return new OpenAIAssistant();
}
|
|
|
|
/**
|
|
* Service a request for assistance.
|
|
*/
|
|
export async function sendForCompletion(
  optSession: OptDocSession,
  doc: AssistanceDoc,
  request: AssistanceRequest,
): Promise<AssistanceResponse> {
  // The assistant is chosen anew from the environment for each request.
  return await getAssistant().apply(optSession, doc, request);
}
|
|
|
|
/**
|
|
* Returns a new Markdown string with the contents of its first multi-line code block
|
|
* replaced with `replaceValue`.
|
|
*/
|
|
export function replaceMarkdownCode(markdown: string, replaceValue: string) {
  const replacement = '```python\n' + replaceValue + '\n```';
  // Use a replacer function so that `replaceValue` is inserted verbatim. With a plain
  // replacement string, sequences such as `$&`, `$'`, `$$` or `$1` occurring in the code
  // would be expanded as replacement patterns and corrupt the result.
  //
  // NOTE(review): `.*` is greedy and the `s` flag lets it span newlines, so when the
  // markdown holds several fenced blocks the match runs from the first opening fence to
  // the last closing fence. Kept as is; confirm that this is the intended extent.
  return markdown.replace(/```\w*\n(.*)```/s, () => replacement);
}
|
|
|
|
/**
 * Asks `doc` for the V1 schema prompt of the column named in the request, using the
 * request text as the docstring. Only formula contexts are supported; any other
 * context type results in an error.
 */
async function makeSchemaPromptV1(session: OptDocSession, doc: AssistanceDoc, request: AssistanceRequest) {
  const {context, text} = request;
  if (context.type !== 'formula') {
    throw new Error('makeSchemaPromptV1 only works for formulas');
  }
  const options = {
    tableId: context.tableId,
    colId: context.colId,
    docString: text,
  };
  return doc.assistanceSchemaPromptV1(session, options);
}
|
|
|
|
/**
 * Converts a raw completion into an AssistanceResponse. The completion is passed
 * through `doc.assistanceFormulaTweak`; if the outcome is non-empty it becomes the
 * suggested formula along with a ModifyColumn action for the requested column.
 * `reply` is passed through unchanged. Only formula contexts are supported.
 */
async function completionToResponse(
  doc: AssistanceDoc,
  request: AssistanceRequest,
  completion: string,
  reply?: string
): Promise<AssistanceResponse> {
  const {context} = request;
  if (context.type !== 'formula') {
    throw new Error('completionToResponse only works for formulas');
  }
  const tweaked = await doc.assistanceFormulaTweak(completion);
  const suggestedFormula = tweaked || undefined;
  const suggestedActions: DocAction[] = [];
  // An action is suggested only for a non-empty result (that is, when the completion
  // actually looked like code).
  if (suggestedFormula) {
    suggestedActions.push([
      "ModifyColumn",
      context.tableId,
      context.colId,
      {formula: suggestedFormula},
    ]);
  }
  return {suggestedActions, suggestedFormula, reply};
}
|
|
|
|
/**
 * Returns a salted SHA-256 hash (base64) built from the session user's id and ref.
 * It is used as the `user` value sent to the completion endpoint.
 * The hash is also logged together with the user ref.
 */
function getUserHash(session: OptDocSession): string {
  const user = getDocSessionUser(session);
  // Make it a bit harder to guess the user ID.
  const salt = "7a8sb6987asdb678asd687sad6boas7f8b6aso7fd";
  const hasher = createHash('sha256');
  hasher.update(`${user?.id} ${user?.ref} ${salt}`);
  const hash = hasher.digest('base64');
  // Logged so that feedback mentioning a user ID hash can be traced back to the
  // original user by searching the logs for the hash.
  const logMeta = {...getLogMetaFromDocSession(session), userRef: user?.ref, hash};
  log.rawInfo("getUserHash", logMeta);
  return hash;
}
|