From 0be858c19d79f8b41a16c8700bd69911d9ff5cbb Mon Sep 17 00:00:00 2001 From: Paul Fitzpatrick Date: Fri, 18 Aug 2023 16:14:42 -0400 Subject: [PATCH] allow AI Assistance to run against any chat-completion-style endpoint (#630) This adds an ASSISTANT_CHAT_COMPLETION_ENDPOINT which can be used to enable AI Assistance instead of an OpenAI API key. The assistant then works against compatible endpoints, in the mechanical sense. Quality of course will depend on the model. I found some tweaks to the prompt that work well both for Llama-2 and for OpenAI's models, but I'm not including them here because they would conflict with some prompt changes that are already in the works. Co-authored-by: Alex Hall --- README.md | 16 +++++- app/client/models/features.ts | 4 ++ app/client/widgets/FormulaAssistant.ts | 6 +- app/common/gristUrls.ts | 4 ++ app/server/lib/Assistance.ts | 65 ++++++++++++++++------ app/server/lib/sendAppPage.ts | 3 +- test/formula-dataset/runCompletion_impl.ts | 6 +- 7 files changed, 81 insertions(+), 23 deletions(-) diff --git a/README.md b/README.md index 7bbe7ff0..25441eee 100644 --- a/README.md +++ b/README.md @@ -81,7 +81,7 @@ Here are some specific feature highlights of Grist: - On OSX, you can use native sandboxing. - On any OS, including Windows, you can use a wasm-based sandbox. * Translated to many languages. - * Support for an AI Formula Assistant (using OpenAI gpt-3.5-turbo). + * Support for an AI Formula Assistant (using OpenAI gpt-3.5-turbo or comparable models). * `F1` key brings up some quick help. This used to go without saying. In general Grist has good keyboard support. * We post progress on [𝕏 or Twitter or whatever](https://twitter.com/getgrist). @@ -302,7 +302,19 @@ PORT | port number to listen on for Grist server REDIS_URL | optional redis server for browser sessions and db query caching GRIST_SNAPSHOT_TIME_CAP | optional. Define the caps for tracking buckets. Usage: {"hour": 25, "day": 32, "isoWeek": 12, "month": 96, "year": 1000} GRIST_SNAPSHOT_KEEP | optional. Number of recent snapshots to retain unconditionally for a document, regardless of when they were made -OPENAI_API_KEY | optional. Used for the AI formula assistant. Sign up for an account on OpenAI and then generate a secret key [here](https://platform.openai.com/account/api-keys). + +AI Formula Assistant related variables (all optional): + +Variable | Purpose +-------- | ------- +ASSISTANT_API_KEY | optional. An API key to pass when making requests to an external AI conversational endpoint. +ASSISTANT_CHAT_COMPLETION_ENDPOINT | optional. A chat-completion style endpoint to call. Not needed if OpenAI is being used. +ASSISTANT_MODEL | optional. If set, this string is passed along in calls to the AI conversational endpoint. +ASSISTANT_LONGER_CONTEXT_MODEL | optional. If set, requests that fail because of a context length limitation will be retried with this model set. +OPENAI_API_KEY | optional. Synonym for ASSISTANT_API_KEY that assumes an OpenAI endpoint is being used. Sign up for an account on OpenAI and then generate a secret key [here](https://platform.openai.com/account/api-keys). + +At the time of writing, the AI Assistant is known to function against OpenAI chat completion endpoints for gpt-3.5-turbo and gpt-4. +It can also function against the chat completion endpoint provided by llama-cpp-python. Sandbox related variables: diff --git a/app/client/models/features.ts b/app/client/models/features.ts index aa883e52..d19a8d28 100644 --- a/app/client/models/features.ts +++ b/app/client/models/features.ts @@ -20,3 +20,7 @@ export function COMMENTS(): Observable { export function HAS_FORMULA_ASSISTANT() { return Boolean(getGristConfig().featureFormulaAssistant); } + +export function WHICH_FORMULA_ASSISTANT() { + return getGristConfig().assistantService; +} diff --git a/app/client/widgets/FormulaAssistant.ts b/app/client/widgets/FormulaAssistant.ts index aaeae3be..8fc9c827 100644 --- a/app/client/widgets/FormulaAssistant.ts +++ b/app/client/widgets/FormulaAssistant.ts @@ -6,7 +6,7 @@ import {movable} from 'app/client/lib/popupUtils'; import {logTelemetryEvent} from 'app/client/lib/telemetry'; import {ColumnRec, ViewFieldRec} from 'app/client/models/DocModel'; import {ChatMessage} from 'app/client/models/entities/ColumnRec'; -import {HAS_FORMULA_ASSISTANT} from 'app/client/models/features'; +import {HAS_FORMULA_ASSISTANT, WHICH_FORMULA_ASSISTANT} from 'app/client/models/features'; import {getLoginOrSignupUrl, urlState} from 'app/client/models/gristUrlState'; import {buildHighlightedCode} from 'app/client/ui/CodeHighlight'; import {autoGrow} from 'app/client/ui/forms'; @@ -879,7 +879,7 @@ class ChatHistory extends Disposable { '"Please calculate the total invoice amount."' ), ), - cssAiMessageBullet( + (WHICH_FORMULA_ASSISTANT() === 'OpenAI') ? cssAiMessageBullet( cssTickIcon('Tick'), dom('div', t( @@ -891,7 +891,7 @@ class ChatHistory extends Disposable { } ), ), - ), + ) : null, ), cssAiMessageParagraph( t( diff --git a/app/common/gristUrls.ts b/app/common/gristUrls.ts index e7541322..c0202e89 100644 --- a/app/common/gristUrls.ts +++ b/app/common/gristUrls.ts @@ -666,6 +666,10 @@ export interface GristLoadConfig { // TODO: remove once released. featureFormulaAssistant?: boolean; + // Used to determine which disclosure links should be provided to user of + // formula assistance. + assistantService?: 'OpenAI' | undefined; + // Email address of the support user. supportEmail?: string; diff --git a/app/server/lib/Assistance.ts b/app/server/lib/Assistance.ts index 0ece1d65..943eb1ea 100644 --- a/app/server/lib/Assistance.ts +++ b/app/server/lib/Assistance.ts @@ -117,23 +117,54 @@ class RetryableError extends Error { } /** - * A flavor of assistant for use with the OpenAI API. + * A flavor of assistant for use with the OpenAI chat completion endpoint + * and tools with a compatible endpoint (e.g. llama-cpp-python). * Tested primarily with gpt-3.5-turbo. + * + * Uses the ASSISTANT_CHAT_COMPLETION_ENDPOINT endpoint if set, else + * an OpenAI endpoint. Passes ASSISTANT_API_KEY or OPENAI_API_KEY in + * a header if set. An api key is required for the default OpenAI + * endpoint. + * + * If a model string is set in ASSISTANT_MODEL, this will be passed + * along. For the default OpenAI endpoint, a gpt-3.5-turbo variant + * will be set by default. + * + * If a request fails because of context length limitation, and the + * default OpenAI endpoint is in use, the request will be retried + * with ASSISTANT_LONGER_CONTEXT_MODEL (another gpt-3.5 + * variant by default). Set this variable to "" if this behavior is + * not desired for the default OpenAI endpoint. If a custom endpoint was + * provided, this behavior will only happen if + * ASSISTANT_LONGER_CONTEXT_MODEL is explicitly set. + * + * An optional ASSISTANT_MAX_TOKENS can be specified. */ export class OpenAIAssistant implements Assistant { public static DEFAULT_MODEL = "gpt-3.5-turbo-0613"; - public static LONGER_CONTEXT_MODEL = "gpt-3.5-turbo-16k-0613"; + public static DEFAULT_LONGER_CONTEXT_MODEL = "gpt-3.5-turbo-16k-0613"; - private _apiKey: string; + private _apiKey?: string; + private _model?: string; + private _longerContextModel?: string; private _endpoint: string; + private _maxTokens = process.env.ASSISTANT_MAX_TOKENS ? + parseInt(process.env.ASSISTANT_MAX_TOKENS, 10) : undefined; public constructor() { - const apiKey = process.env.OPENAI_API_KEY; - if (!apiKey) { - throw new Error('OPENAI_API_KEY not set'); + const apiKey = process.env.ASSISTANT_API_KEY || process.env.OPENAI_API_KEY; + const endpoint = process.env.ASSISTANT_CHAT_COMPLETION_ENDPOINT; + if (!apiKey && !endpoint) { + throw new Error('Please set either OPENAI_API_KEY or ASSISTANT_CHAT_COMPLETION_ENDPOINT'); } this._apiKey = apiKey; - this._endpoint = `https://api.openai.com/v1/chat/completions`; + this._model = process.env.ASSISTANT_MODEL; + this._longerContextModel = process.env.ASSISTANT_LONGER_CONTEXT_MODEL; + if (!endpoint) { + this._model = this._model ?? OpenAIAssistant.DEFAULT_MODEL; + this._longerContextModel = this._longerContextModel ?? OpenAIAssistant.DEFAULT_LONGER_CONTEXT_MODEL; + } + this._endpoint = endpoint || `https://api.openai.com/v1/chat/completions`; } public async apply( @@ -222,19 +253,25 @@ export class OpenAIAssistant implements Assistant { } private async _fetchCompletion(messages: AssistanceMessage[], userIdHash: string, longerContext: boolean) { + const model = longerContext ? this._longerContextModel : this._model; const apiResponse = await DEPS.fetch( this._endpoint, { method: "POST", headers: { - "Authorization": `Bearer ${this._apiKey}`, + ...(this._apiKey ? { + "Authorization": `Bearer ${this._apiKey}`, + } : undefined), "Content-Type": "application/json", }, body: JSON.stringify({ messages, temperature: 0, - model: longerContext ? OpenAIAssistant.LONGER_CONTEXT_MODEL : OpenAIAssistant.DEFAULT_MODEL, + ...(model ? { model } : undefined), user: userIdHash, + ...(this._maxTokens ? { + max_tokens: this._maxTokens, + } : undefined), }), }, ); @@ -242,7 +279,7 @@ export class OpenAIAssistant implements Assistant { const result = JSON.parse(resultText); const errorCode = result.error?.code; if (errorCode === "context_length_exceeded" || result.choices?.[0].finish_reason === "length") { - if (!longerContext) { + if (!longerContext && this._longerContextModel) { log.info("Switching to longer context model..."); throw new SwitchToLongerContext(); } else if (messages.length <= 2) { @@ -392,14 +429,10 @@ export function getAssistant() { if (process.env.OPENAI_API_KEY === 'test') { return new EchoAssistant(); } - if (process.env.OPENAI_API_KEY) { + if (process.env.OPENAI_API_KEY || process.env.ASSISTANT_CHAT_COMPLETION_ENDPOINT) { return new OpenAIAssistant(); } - // Maintaining this is too much of a burden for now. - // if (process.env.HUGGINGFACE_API_KEY) { - // return new HuggingFaceAssistant(); - // } - throw new Error('Please set OPENAI_API_KEY'); + throw new Error('Please set OPENAI_API_KEY or ASSISTANT_CHAT_COMPLETION_ENDPOINT'); } /** diff --git a/app/server/lib/sendAppPage.ts b/app/server/lib/sendAppPage.ts index 03cbf7b0..ca5b31ff 100644 --- a/app/server/lib/sendAppPage.ts +++ b/app/server/lib/sendAppPage.ts @@ -74,7 +74,8 @@ export function makeGristConfig(options: MakeGristConfigOptons): GristLoadConfig supportedLngs: readLoadedLngs(req?.i18n), namespaces: readLoadedNamespaces(req?.i18n), featureComments: isAffirmative(process.env.COMMENTS), - featureFormulaAssistant: Boolean(process.env.OPENAI_API_KEY), + featureFormulaAssistant: Boolean(process.env.OPENAI_API_KEY || process.env.ASSISTANT_CHAT_COMPLETION_ENDPOINT), + assistantService: process.env.OPENAI_API_KEY ? 'OpenAI' : undefined, supportEmail: SUPPORT_EMAIL, userLocale: (req as RequestWithLogin | undefined)?.user?.options?.locale, telemetry: server?.getTelemetry().getTelemetryConfig(), diff --git a/test/formula-dataset/runCompletion_impl.ts b/test/formula-dataset/runCompletion_impl.ts index 2eb92594..5582faf2 100644 --- a/test/formula-dataset/runCompletion_impl.ts +++ b/test/formula-dataset/runCompletion_impl.ts @@ -15,6 +15,9 @@ * * USAGE: * OPENAI_API_KEY= node core/test/formula-dataset/runCompletion.js + * or + * ASSISTANT_CHAT_COMPLETION_ENDPOINT=http.... node core/test/formula-dataset/runCompletion.js + * (see Assistance.ts for more options). * * # WITH VERBOSE: * VERBOSE=1 OPENAI_API_KEY= node core/test/formula-dataset/runCompletion.js @@ -68,7 +71,8 @@ const SIMULATE_CONVERSATION = true; const FOLLOWUP_EVALUATE = false; export async function runCompletion() { - ActiveDocDeps.ACTIVEDOC_TIMEOUT = 600; + // This could take a long time for LLMs running on underpowered hardware >:) + ActiveDocDeps.ACTIVEDOC_TIMEOUT = 500000; // if template directory not exists, make it if (!fs.existsSync(path.join(PATH_TO_DOC))) {