allow AI Assistance to run against any chat-completion-style endpoint (#630)

This adds an ASSISTANT_CHAT_COMPLETION_ENDPOINT which can be used
to enable AI Assistance instead of an OpenAI API key. The assistant
then works against compatible endpoints, in the mechanical sense.
Quality of course will depend on the model. I found some tweaks
to the prompt that work well both for Llama-2 and for OpenAI's models,
but I'm not including them here because they would conflict with some
prompt changes that are already in the works.

Co-authored-by: Alex Hall <alex.mojaki@gmail.com>
This commit is contained in:
Paul Fitzpatrick 2023-08-18 16:14:42 -04:00 committed by GitHub
parent 5705c37d02
commit 0be858c19d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 81 additions and 23 deletions

View File

@ -81,7 +81,7 @@ Here are some specific feature highlights of Grist:
- On OSX, you can use native sandboxing. - On OSX, you can use native sandboxing.
- On any OS, including Windows, you can use a wasm-based sandbox. - On any OS, including Windows, you can use a wasm-based sandbox.
* Translated to many languages. * Translated to many languages.
* Support for an AI Formula Assistant (using OpenAI gpt-3.5-turbo). * Support for an AI Formula Assistant (using OpenAI gpt-3.5-turbo or comparable models).
* `F1` key brings up some quick help. This used to go without saying. In general Grist has good keyboard support. * `F1` key brings up some quick help. This used to go without saying. In general Grist has good keyboard support.
* We post progress on [𝕏 or Twitter or whatever](https://twitter.com/getgrist). * We post progress on [𝕏 or Twitter or whatever](https://twitter.com/getgrist).
@ -302,7 +302,19 @@ PORT | port number to listen on for Grist server
REDIS_URL | optional redis server for browser sessions and db query caching REDIS_URL | optional redis server for browser sessions and db query caching
GRIST_SNAPSHOT_TIME_CAP | optional. Define the caps for tracking buckets. Usage: {"hour": 25, "day": 32, "isoWeek": 12, "month": 96, "year": 1000} GRIST_SNAPSHOT_TIME_CAP | optional. Define the caps for tracking buckets. Usage: {"hour": 25, "day": 32, "isoWeek": 12, "month": 96, "year": 1000}
GRIST_SNAPSHOT_KEEP | optional. Number of recent snapshots to retain unconditionally for a document, regardless of when they were made GRIST_SNAPSHOT_KEEP | optional. Number of recent snapshots to retain unconditionally for a document, regardless of when they were made
OPENAI_API_KEY | optional. Used for the AI formula assistant. Sign up for an account on OpenAI and then generate a secret key [here](https://platform.openai.com/account/api-keys).
AI Formula Assistant related variables (all optional):
Variable | Purpose
-------- | -------
ASSISTANT_API_KEY | optional. An API key to pass when making requests to an external AI conversational endpoint.
ASSISTANT_CHAT_COMPLETION_ENDPOINT | optional. A chat-completion style endpoint to call. Not needed if OpenAI is being used.
ASSISTANT_MODEL | optional. If set, this string is passed along in calls to the AI conversational endpoint.
ASSISTANT_LONGER_CONTEXT_MODEL | optional. If set, requests that fail because of a context length limitation will be retried with this model set.
OPENAI_API_KEY | optional. Synonym for ASSISTANT_API_KEY that assumes an OpenAI endpoint is being used. Sign up for an account on OpenAI and then generate a secret key [here](https://platform.openai.com/account/api-keys).
At the time of writing, the AI Assistant is known to function against OpenAI chat completion endpoints for gpt-3.5-turbo and gpt-4.
It can also function against the chat completion endpoint provided by <a href="https://github.com/abetlen/llama-cpp-python">llama-cpp-python</a>.
Sandbox related variables: Sandbox related variables:

View File

@ -20,3 +20,7 @@ export function COMMENTS(): Observable<boolean> {
export function HAS_FORMULA_ASSISTANT() { export function HAS_FORMULA_ASSISTANT() {
return Boolean(getGristConfig().featureFormulaAssistant); return Boolean(getGristConfig().featureFormulaAssistant);
} }
export function WHICH_FORMULA_ASSISTANT() {
return getGristConfig().assistantService;
}

View File

@ -6,7 +6,7 @@ import {movable} from 'app/client/lib/popupUtils';
import {logTelemetryEvent} from 'app/client/lib/telemetry'; import {logTelemetryEvent} from 'app/client/lib/telemetry';
import {ColumnRec, ViewFieldRec} from 'app/client/models/DocModel'; import {ColumnRec, ViewFieldRec} from 'app/client/models/DocModel';
import {ChatMessage} from 'app/client/models/entities/ColumnRec'; import {ChatMessage} from 'app/client/models/entities/ColumnRec';
import {HAS_FORMULA_ASSISTANT} from 'app/client/models/features'; import {HAS_FORMULA_ASSISTANT, WHICH_FORMULA_ASSISTANT} from 'app/client/models/features';
import {getLoginOrSignupUrl, urlState} from 'app/client/models/gristUrlState'; import {getLoginOrSignupUrl, urlState} from 'app/client/models/gristUrlState';
import {buildHighlightedCode} from 'app/client/ui/CodeHighlight'; import {buildHighlightedCode} from 'app/client/ui/CodeHighlight';
import {autoGrow} from 'app/client/ui/forms'; import {autoGrow} from 'app/client/ui/forms';
@ -879,7 +879,7 @@ class ChatHistory extends Disposable {
'"Please calculate the total invoice amount."' '"Please calculate the total invoice amount."'
), ),
), ),
cssAiMessageBullet( (WHICH_FORMULA_ASSISTANT() === 'OpenAI') ? cssAiMessageBullet(
cssTickIcon('Tick'), cssTickIcon('Tick'),
dom('div', dom('div',
t( t(
@ -891,7 +891,7 @@ class ChatHistory extends Disposable {
} }
), ),
), ),
), ) : null,
), ),
cssAiMessageParagraph( cssAiMessageParagraph(
t( t(

View File

@ -666,6 +666,10 @@ export interface GristLoadConfig {
// TODO: remove once released. // TODO: remove once released.
featureFormulaAssistant?: boolean; featureFormulaAssistant?: boolean;
// Used to determine which disclosure links should be provided to user of
// formula assistance.
assistantService?: 'OpenAI' | undefined;
// Email address of the support user. // Email address of the support user.
supportEmail?: string; supportEmail?: string;

View File

@ -117,23 +117,54 @@ class RetryableError extends Error {
} }
/** /**
* A flavor of assistant for use with the OpenAI API. * A flavor of assistant for use with the OpenAI chat completion endpoint
* and tools with a compatible endpoint (e.g. llama-cpp-python).
* Tested primarily with gpt-3.5-turbo. * Tested primarily with gpt-3.5-turbo.
*
* Uses the ASSISTANT_CHAT_COMPLETION_ENDPOINT endpoint if set, else
* an OpenAI endpoint. Passes ASSISTANT_API_KEY or OPENAI_API_KEY in
* a header if set. An api key is required for the default OpenAI
* endpoint.
*
* If a model string is set in ASSISTANT_MODEL, this will be passed
* along. For the default OpenAI endpoint, a gpt-3.5-turbo variant
* will be set by default.
*
* If a request fails because of context length limitation, and the
* default OpenAI endpoint is in use, the request will be retried
* with ASSISTANT_LONGER_CONTEXT_MODEL (another gpt-3.5
* variant by default). Set this variable to "" if this behavior is
* not desired for the default OpenAI endpoint. If a custom endpoint was
* provided, this behavior will only happen if
* ASSISTANT_LONGER_CONTEXT_MODEL is explicitly set.
*
* An optional ASSISTANT_MAX_TOKENS can be specified.
*/ */
export class OpenAIAssistant implements Assistant { export class OpenAIAssistant implements Assistant {
public static DEFAULT_MODEL = "gpt-3.5-turbo-0613"; public static DEFAULT_MODEL = "gpt-3.5-turbo-0613";
public static LONGER_CONTEXT_MODEL = "gpt-3.5-turbo-16k-0613"; public static DEFAULT_LONGER_CONTEXT_MODEL = "gpt-3.5-turbo-16k-0613";
private _apiKey: string; private _apiKey?: string;
private _model?: string;
private _longerContextModel?: string;
private _endpoint: string; private _endpoint: string;
private _maxTokens = process.env.ASSISTANT_MAX_TOKENS ?
parseInt(process.env.ASSISTANT_MAX_TOKENS, 10) : undefined;
public constructor() { public constructor() {
const apiKey = process.env.OPENAI_API_KEY; const apiKey = process.env.ASSISTANT_API_KEY || process.env.OPENAI_API_KEY;
if (!apiKey) { const endpoint = process.env.ASSISTANT_CHAT_COMPLETION_ENDPOINT;
throw new Error('OPENAI_API_KEY not set'); if (!apiKey && !endpoint) {
throw new Error('Please set either OPENAI_API_KEY or ASSISTANT_CHAT_COMPLETION_ENDPOINT');
} }
this._apiKey = apiKey; this._apiKey = apiKey;
this._endpoint = `https://api.openai.com/v1/chat/completions`; this._model = process.env.ASSISTANT_MODEL;
this._longerContextModel = process.env.ASSISTANT_LONGER_CONTEXT_MODEL;
if (!endpoint) {
this._model = this._model ?? OpenAIAssistant.DEFAULT_MODEL;
this._longerContextModel = this._longerContextModel ?? OpenAIAssistant.DEFAULT_LONGER_CONTEXT_MODEL;
}
this._endpoint = endpoint || `https://api.openai.com/v1/chat/completions`;
} }
public async apply( public async apply(
@ -222,19 +253,25 @@ export class OpenAIAssistant implements Assistant {
} }
private async _fetchCompletion(messages: AssistanceMessage[], userIdHash: string, longerContext: boolean) { private async _fetchCompletion(messages: AssistanceMessage[], userIdHash: string, longerContext: boolean) {
const model = longerContext ? this._longerContextModel : this._model;
const apiResponse = await DEPS.fetch( const apiResponse = await DEPS.fetch(
this._endpoint, this._endpoint,
{ {
method: "POST", method: "POST",
headers: { headers: {
"Authorization": `Bearer ${this._apiKey}`, ...(this._apiKey ? {
"Authorization": `Bearer ${this._apiKey}`,
} : undefined),
"Content-Type": "application/json", "Content-Type": "application/json",
}, },
body: JSON.stringify({ body: JSON.stringify({
messages, messages,
temperature: 0, temperature: 0,
model: longerContext ? OpenAIAssistant.LONGER_CONTEXT_MODEL : OpenAIAssistant.DEFAULT_MODEL, ...(model ? { model } : undefined),
user: userIdHash, user: userIdHash,
...(this._maxTokens ? {
max_tokens: this._maxTokens,
} : undefined),
}), }),
}, },
); );
@ -242,7 +279,7 @@ export class OpenAIAssistant implements Assistant {
const result = JSON.parse(resultText); const result = JSON.parse(resultText);
const errorCode = result.error?.code; const errorCode = result.error?.code;
if (errorCode === "context_length_exceeded" || result.choices?.[0].finish_reason === "length") { if (errorCode === "context_length_exceeded" || result.choices?.[0].finish_reason === "length") {
if (!longerContext) { if (!longerContext && this._longerContextModel) {
log.info("Switching to longer context model..."); log.info("Switching to longer context model...");
throw new SwitchToLongerContext(); throw new SwitchToLongerContext();
} else if (messages.length <= 2) { } else if (messages.length <= 2) {
@ -392,14 +429,10 @@ export function getAssistant() {
if (process.env.OPENAI_API_KEY === 'test') { if (process.env.OPENAI_API_KEY === 'test') {
return new EchoAssistant(); return new EchoAssistant();
} }
if (process.env.OPENAI_API_KEY) { if (process.env.OPENAI_API_KEY || process.env.ASSISTANT_CHAT_COMPLETION_ENDPOINT) {
return new OpenAIAssistant(); return new OpenAIAssistant();
} }
// Maintaining this is too much of a burden for now. throw new Error('Please set OPENAI_API_KEY or ASSISTANT_CHAT_COMPLETION_ENDPOINT');
// if (process.env.HUGGINGFACE_API_KEY) {
// return new HuggingFaceAssistant();
// }
throw new Error('Please set OPENAI_API_KEY');
} }
/** /**

View File

@ -74,7 +74,8 @@ export function makeGristConfig(options: MakeGristConfigOptons): GristLoadConfig
supportedLngs: readLoadedLngs(req?.i18n), supportedLngs: readLoadedLngs(req?.i18n),
namespaces: readLoadedNamespaces(req?.i18n), namespaces: readLoadedNamespaces(req?.i18n),
featureComments: isAffirmative(process.env.COMMENTS), featureComments: isAffirmative(process.env.COMMENTS),
featureFormulaAssistant: Boolean(process.env.OPENAI_API_KEY), featureFormulaAssistant: Boolean(process.env.OPENAI_API_KEY || process.env.ASSISTANT_CHAT_COMPLETION_ENDPOINT),
assistantService: process.env.OPENAI_API_KEY ? 'OpenAI' : undefined,
supportEmail: SUPPORT_EMAIL, supportEmail: SUPPORT_EMAIL,
userLocale: (req as RequestWithLogin | undefined)?.user?.options?.locale, userLocale: (req as RequestWithLogin | undefined)?.user?.options?.locale,
telemetry: server?.getTelemetry().getTelemetryConfig(), telemetry: server?.getTelemetry().getTelemetryConfig(),

View File

@ -15,6 +15,9 @@
* *
* USAGE: * USAGE:
* OPENAI_API_KEY=<my_openai_api_key> node core/test/formula-dataset/runCompletion.js * OPENAI_API_KEY=<my_openai_api_key> node core/test/formula-dataset/runCompletion.js
* or
* ASSISTANT_CHAT_COMPLETION_ENDPOINT=http.... node core/test/formula-dataset/runCompletion.js
* (see Assistance.ts for more options).
* *
* # WITH VERBOSE: * # WITH VERBOSE:
* VERBOSE=1 OPENAI_API_KEY=<my_openai_api_key> node core/test/formula-dataset/runCompletion.js * VERBOSE=1 OPENAI_API_KEY=<my_openai_api_key> node core/test/formula-dataset/runCompletion.js
@ -68,7 +71,8 @@ const SIMULATE_CONVERSATION = true;
const FOLLOWUP_EVALUATE = false; const FOLLOWUP_EVALUATE = false;
export async function runCompletion() { export async function runCompletion() {
ActiveDocDeps.ACTIVEDOC_TIMEOUT = 600; // This could take a long time for LLMs running on underpowered hardware >:)
ActiveDocDeps.ACTIVEDOC_TIMEOUT = 500000;
// if template directory not exists, make it // if template directory not exists, make it
if (!fs.existsSync(path.join(PATH_TO_DOC))) { if (!fs.existsSync(path.join(PATH_TO_DOC))) {