gristlabs_grist-core/app/server/lib/Assistance.ts
Alex Hall 52469c5a7e (core) Improve parsing formula from completion
Summary:
The previous code for extracting a Python formula from the LLM completion involved some shaky string manipulation which this improves on.
Overall the 'test results' from `runCompletion` went from 37/47 to 45/47 for `gpt-3.5-turbo-0613`.

The biggest problem that motivated these changes was that it assumed that code was always inside a markdown code block
(i.e. triple backticks) and so if there was no block there was no code. But the completion often consists of *only* code
with no accompanying explanation or markdown. By parsing the completion in Python instead of JS,
we can easily check if the entire completion is valid Python syntax and accept it if it is.

I also noticed one failure resulting from the completion containing the full function (instead of just the body),
with the necessary imports placed before that function instead of inside it. The new parsing moves such imports inside the function.

Test Plan: Added a Python unit test

Reviewers: paulfitz

Reviewed By: paulfitz

Subscribers: paulfitz

Differential Revision: https://phab.getgrist.com/D3922
2023-06-16 13:38:20 +02:00

323 lines
10 KiB
TypeScript

/**
* Module with functions used for AI formula assistance.
*/
import {AssistanceRequest, AssistanceResponse} from 'app/common/AssistancePrompts';
import {delay} from 'app/common/delay';
import {DocAction} from 'app/common/DocActions';
import log from 'app/server/lib/log';
import fetch from 'node-fetch';
// External dependencies used by the assistants. They call `DEPS.fetch` rather than the
// imported `fetch` directly, presumably so tests can substitute a fake implementation.
export const DEPS = { fetch };
/**
* An assistant can help a user do things with their document,
* by interfacing with an external LLM endpoint.
*
* Implementations in this module: OpenAIAssistant, HuggingFaceAssistant and EchoAssistant.
*/
export interface Assistant {
/**
* Handle a single assistance request against `doc`. The response carries suggested doc
* actions and an optional reply; chat-style implementations also set `state` so the
* conversation can be continued by a follow-up request.
*/
apply(doc: AssistanceDoc, request: AssistanceRequest): Promise<AssistanceResponse>;
}
/**
* Document-related methods for use in the implementation of assistants.
* Somewhat ad-hoc currently.
*/
export interface AssistanceDoc {
/**
* Generate a particular prompt coded in the data engine.
* It makes Python code for some tables, and starts a function body with
* the given docstring (`options.docString`), leaving the body for the model to complete.
* Marked "V1" to suggest that it is a particular prompt and it would
* be great to try variants.
*/
assistanceSchemaPromptV1(options: AssistanceSchemaPromptV1Context): Promise<string>;
/**
* Some tweaks to a formula after it has been generated.
* Callers in this module pass the raw model completion and treat an empty result as
* "no usable formula", in which case no action is suggested.
* NOTE(review): the actual tweaks are implemented in the data engine; presumably this
* extracts the Python code from the completion. Confirm there.
*/
assistanceFormulaTweak(txt: string): Promise<string>;
}
/** Options accepted by `AssistanceDoc.assistanceSchemaPromptV1`. */
export interface AssistanceSchemaPromptV1Context {
  /** Id of the table holding the column whose formula is requested. */
  tableId: string;
  /** Id of the column whose formula is requested. */
  colId: string;
  /** Text used as the docstring of the function body the prompt ends with. */
  docString: string;
}
/**
 * A flavor of assistant for use with the OpenAI API.
 * Tested primarily with text-davinci-002 and gpt-3.5-turbo.
 *
 * Configured from environment variables:
 *  - OPENAI_API_KEY: required.
 *  - COMPLETION_MODEL: model name, defaulting to "text-davinci-002". A model whose name
 *    contains "turbo" is driven through the chat completions endpoint; any other model
 *    goes through the plain completions endpoint.
 */
export class OpenAIAssistant implements Assistant {
  private _apiKey: string;
  private _model: string;
  private _chatMode: boolean;
  private _endpoint: string;

  /**
   * @throws Error if OPENAI_API_KEY is not set.
   */
  public constructor() {
    const apiKey = process.env.OPENAI_API_KEY;
    if (!apiKey) {
      throw new Error('OPENAI_API_KEY not set');
    }
    this._apiKey = apiKey;
    this._model = process.env.COMPLETION_MODEL || "text-davinci-002";
    this._chatMode = this._model.includes('turbo');
    this._endpoint = `https://api.openai.com/v1/${this._chatMode ? 'chat/' : ''}completions`;
  }

  /**
   * Ask OpenAI for a formula.
   *
   * In chat mode the conversation so far is read from `request.state.messages` (if any),
   * extended with this turn and the model's reply, and returned as `response.state` so the
   * caller can continue the conversation. In completion mode only the schema prompt is sent
   * and no state is returned.
   *
   * @throws Error if the API responds with a non-200 status.
   */
  public async apply(doc: AssistanceDoc, request: AssistanceRequest): Promise<AssistanceResponse> {
    // Work on a copy of the history instead of the caller's array. sendForCompletion retries
    // `apply` with the same request after a failure; mutating in place would re-append the
    // user's message on every attempt, and in completion mode would wipe the caller's history.
    const messages = [...(request.state?.messages || [])];
    const chatMode = this._chatMode;
    if (chatMode) {
      if (messages.length === 0) {
        // First turn: system instructions, followed by the generated schema prompt.
        messages.push({
          role: 'system',
          content: 'The user gives you one or more Python classes, ' +
            'with one last method that needs completing. Write the ' +
            'method body as a single code block, ' +
            'including the docstring the user gave. ' +
            'Just give the Python code as a markdown block, ' +
            'do not give any introduction, that will just be ' +
            'awkward for the user when copying and pasting. ' +
            'You are working with Grist, an environment very like ' +
            'regular Python except `rec` (like record) is used ' +
            'instead of `self`. ' +
            'Include at least one `return` statement or the method ' +
            'will fail, disappointing the user. ' +
            'Your answer should be the body of a single method, ' +
            'not a class, and should not include `dataclass` or ' +
            '`class` since the user is counting on you to provide ' +
            'a single method. Thanks!'
        });
        messages.push({
          role: 'user', content: await makeSchemaPromptV1(doc, request),
        });
      } else {
        // Follow-up turn. When regenerating, drop the previous reply (if the history ends
        // with one) so the model answers afresh.
        if (request.regenerate && messages[messages.length - 1].role !== 'user') {
          messages.pop();
        }
        messages.push({
          role: 'user', content: request.text,
        });
      }
    } else {
      // Plain completions are stateless: only the schema prompt is sent.
      messages.length = 0;
      messages.push({
        role: 'user', content: await makeSchemaPromptV1(doc, request),
      });
    }

    const apiResponse = await DEPS.fetch(
      this._endpoint,
      {
        method: "POST",
        headers: {
          "Authorization": `Bearer ${this._apiKey}`,
          "Content-Type": "application/json",
        },
        body: JSON.stringify({
          ...(chatMode ? { messages } : {
            prompt: messages[messages.length - 1].content,
          }),
          max_tokens: 1500,
          temperature: 0,
          model: this._model,
          // Completion models are cut off at the first blank line.
          stop: chatMode ? undefined : ["\n\n"],
        }),
      },
    );
    if (apiResponse.status !== 200) {
      log.error(`OpenAI API returned ${apiResponse.status}: ${await apiResponse.text()}`);
      throw new Error(`OpenAI API returned status ${apiResponse.status}`);
    }
    const result = await apiResponse.json();
    const completion: string = String(chatMode ? result.choices[0].message.content : result.choices[0].text);
    const history = { messages };
    if (chatMode) {
      // Record the model's reply so the next turn sees the whole conversation.
      history.messages.push(result.choices[0].message);
    }
    const response = await completionToResponse(doc, request, completion, completion);
    if (chatMode) {
      response.state = history;
    }
    return response;
  }
}
/**
 * A flavor of assistant for use with the HuggingFace inference API. It is stateless, so
 * follow-up (stateful) requests are rejected.
 *
 * Configured from environment variables:
 *  - HUGGINGFACE_API_KEY: required.
 *  - COMPLETION_URL: full endpoint URL; if unset, it is built from COMPLETION_MODEL, falling
 *    back to NovelAI/genji-python-6B.
 */
export class HuggingFaceAssistant implements Assistant {
  private _apiKey: string;
  private _completionUrl: string;

  /**
   * @throws Error if HUGGINGFACE_API_KEY is not set.
   */
  public constructor() {
    const apiKey = process.env.HUGGINGFACE_API_KEY;
    if (!apiKey) {
      throw new Error('HUGGINGFACE_API_KEY not set');
    }
    this._apiKey = apiKey;
    // COMPLETION_MODEL values I've tried:
    // - codeparrot/codeparrot
    // - NinedayWang/PolyCoder-2.7B
    // - NovelAI/genji-python-6B
    let completionUrl = process.env.COMPLETION_URL;
    if (!completionUrl) {
      if (process.env.COMPLETION_MODEL) {
        completionUrl = `https://api-inference.huggingface.co/models/${process.env.COMPLETION_MODEL}`;
      } else {
        completionUrl = 'https://api-inference.huggingface.co/models/NovelAI/genji-python-6B';
      }
    }
    this._completionUrl = completionUrl;
  }

  /**
   * Send the schema prompt to the model and turn the generated text into a formula suggestion.
   *
   * @throws Error if `request.state` is set, if the API responds with a non-200 status
   *   (after a 10s pause for 503), or if the response body is not in the expected shape.
   */
  public async apply(doc: AssistanceDoc, request: AssistanceRequest): Promise<AssistanceResponse> {
    if (request.state) {
      throw new Error("HuggingFaceAssistant does not support state");
    }
    const prompt = await makeSchemaPromptV1(doc, request);
    const response = await DEPS.fetch(
      this._completionUrl,
      {
        method: "POST",
        headers: {
          "Authorization": `Bearer ${this._apiKey}`,
          "Content-Type": "application/json",
        },
        body: JSON.stringify({
          inputs: prompt,
          parameters: {
            return_full_text: false,
            max_new_tokens: 50,
          },
        }),
      },
    );
    if (response.status !== 200) {
      // Read the body exactly once: a fetch response body cannot be consumed twice, and a
      // second `text()` call would fail with "body used already" and mask the real error.
      const text = await response.text();
      if (response.status === 503) {
        // Pause before failing so the caller's retry has a better chance of succeeding
        // (presumably 503 means the model is still loading; confirm with HuggingFace docs).
        log.error(`Sleeping for 10s - HuggingFace API returned ${response.status}: ${text}`);
        await delay(10000);
      } else {
        log.error(`HuggingFace API returned ${response.status}: ${text}`);
      }
      throw new Error(`HuggingFace API returned status ${response.status}: ${text}`);
    }
    const result = await response.json();
    const generated = Array.isArray(result) ? result[0]?.generated_text : undefined;
    if (typeof generated !== 'string') {
      throw new Error('HuggingFace API returned an unexpected response');
    }
    // Keep only the text up to the first blank line.
    const completion = generated.split('\n\n')[0];
    return completionToResponse(doc, request, completion);
  }
}
/**
 * Test assistant that mimics ChatGPT and just returns the input.
 *
 * Used by getAssistant when OPENAI_API_KEY is "test". It maintains chat history the same
 * way OpenAIAssistant does, with the user's text echoed back as the assistant's reply.
 */
export class EchoAssistant implements Assistant {
  public async apply(doc: AssistanceDoc, request: AssistanceRequest): Promise<AssistanceResponse> {
    // Copy the history rather than mutating the caller's array, so a retried request
    // (see sendForCompletion) does not accumulate duplicate messages.
    const messages = [...(request.state?.messages || [])];
    if (messages.length === 0) {
      messages.push({
        role: 'system',
        content: ''
      });
      messages.push({
        role: 'user', content: request.text,
      });
    } else {
      // When regenerating, drop the previous reply (if the history ends with one).
      if (request.regenerate && messages[messages.length - 1].role !== 'user') {
        messages.pop();
      }
      messages.push({
        role: 'user', content: request.text,
      });
    }
    const completion = request.text;
    const history = { messages };
    history.messages.push({
      role: 'assistant',
      content: completion,
    });
    const response = await completionToResponse(doc, request, completion, completion);
    response.state = history;
    return response;
  }
}
/**
 * Choose an assistant implementation from environment variables:
 * OPENAI_API_KEY selects OpenAI (or the echoing test assistant when its value is "test"),
 * otherwise HUGGINGFACE_API_KEY selects HuggingFace.
 *
 * @throws Error if neither key is set.
 */
function getAssistant() {
  const openAIKey = process.env.OPENAI_API_KEY;
  if (openAIKey) {
    return openAIKey === 'test' ? new EchoAssistant() : new OpenAIAssistant();
  }
  if (process.env.HUGGINGFACE_API_KEY) {
    return new HuggingFaceAssistant();
  }
  throw new Error('Please set OPENAI_API_KEY or HUGGINGFACE_API_KEY');
}
/**
 * Service a request for assistance, with a little retry logic
 * since these endpoints can be a bit flaky.
 *
 * The request is attempted up to 3 times, pausing 1s between failed attempts. There is no
 * pause after the final attempt.
 *
 * @throws Error if no assistant is configured, or if every attempt fails.
 */
export async function sendForCompletion(doc: AssistanceDoc,
                                        request: AssistanceRequest): Promise<AssistanceResponse> {
  const assistant = getAssistant();
  const maxAttempts = 3;
  for (let attempt = 1; attempt <= maxAttempts; attempt++) {
    try {
      return await assistant.apply(doc, request);
    } catch (e) {
      log.error(`Completion error: ${e}`);
      // Only back off if another attempt follows; sleeping before giving up just adds latency.
      if (attempt < maxAttempts) {
        await delay(1000);
      }
    }
  }
  throw new Error('Failed to get response from assistant');
}
/**
 * Build the "V1" schema prompt for a formula request. The request's table and column
 * are passed to the document, and the request text is used as the docstring.
 *
 * @throws Error if the request is not about a formula.
 */
async function makeSchemaPromptV1(doc: AssistanceDoc, request: AssistanceRequest) {
  const {context, text} = request;
  if (context.type === 'formula') {
    const {tableId, colId} = context;
    return doc.assistanceSchemaPromptV1({tableId, colId, docString: text});
  }
  throw new Error('makeSchemaPromptV1 only works for formulas');
}
/**
 * Turn a model completion into an assistance response for a formula request.
 *
 * The completion is first passed through `doc.assistanceFormulaTweak`. If anything is left,
 * a single ModifyColumn action setting the column's formula is suggested; otherwise no action
 * is suggested. `reply` is passed through unchanged.
 *
 * @throws Error if the request is not about a formula.
 */
async function completionToResponse(doc: AssistanceDoc, request: AssistanceRequest,
                                    completion: string, reply?: string): Promise<AssistanceResponse> {
  const {context} = request;
  if (context.type !== 'formula') {
    throw new Error('completionToResponse only works for formulas');
  }
  const formula = await doc.assistanceFormulaTweak(completion);
  const suggestedActions: DocAction[] = [];
  // An empty result means the completion didn't look like code, so suggest nothing.
  if (formula) {
    suggestedActions.push(["ModifyColumn", context.tableId, context.colId, {formula}]);
  }
  return {suggestedActions, reply};
}