mirror of
https://github.com/gristlabs/grist-core.git
synced 2026-03-02 04:09:24 +00:00
allow AI Assistance to run against any chat-completion-style endpoint (#630)
This adds an ASSISTANT_CHAT_COMPLETION_ENDPOINT which can be used to enable AI Assistance instead of an OpenAI API key. The assistant then works against compatible endpoints, in the mechanical sense. Quality of course will depend on the model. I found some tweaks to the prompt that work well both for Llama-2 and for OpenAI's models, but I'm not including them here because they would conflict with some prompt changes that are already in the works. Co-authored-by: Alex Hall <alex.mojaki@gmail.com>
This commit is contained in:
@@ -20,3 +20,7 @@ export function COMMENTS(): Observable<boolean> {
|
||||
export function HAS_FORMULA_ASSISTANT() {
|
||||
return Boolean(getGristConfig().featureFormulaAssistant);
|
||||
}
|
||||
|
||||
export function WHICH_FORMULA_ASSISTANT() {
|
||||
return getGristConfig().assistantService;
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ import {movable} from 'app/client/lib/popupUtils';
|
||||
import {logTelemetryEvent} from 'app/client/lib/telemetry';
|
||||
import {ColumnRec, ViewFieldRec} from 'app/client/models/DocModel';
|
||||
import {ChatMessage} from 'app/client/models/entities/ColumnRec';
|
||||
import {HAS_FORMULA_ASSISTANT} from 'app/client/models/features';
|
||||
import {HAS_FORMULA_ASSISTANT, WHICH_FORMULA_ASSISTANT} from 'app/client/models/features';
|
||||
import {getLoginOrSignupUrl, urlState} from 'app/client/models/gristUrlState';
|
||||
import {buildHighlightedCode} from 'app/client/ui/CodeHighlight';
|
||||
import {autoGrow} from 'app/client/ui/forms';
|
||||
@@ -879,7 +879,7 @@ class ChatHistory extends Disposable {
|
||||
'"Please calculate the total invoice amount."'
|
||||
),
|
||||
),
|
||||
cssAiMessageBullet(
|
||||
(WHICH_FORMULA_ASSISTANT() === 'OpenAI') ? cssAiMessageBullet(
|
||||
cssTickIcon('Tick'),
|
||||
dom('div',
|
||||
t(
|
||||
@@ -891,7 +891,7 @@ class ChatHistory extends Disposable {
|
||||
}
|
||||
),
|
||||
),
|
||||
),
|
||||
) : null,
|
||||
),
|
||||
cssAiMessageParagraph(
|
||||
t(
|
||||
|
||||
@@ -666,6 +666,10 @@ export interface GristLoadConfig {
|
||||
// TODO: remove once released.
|
||||
featureFormulaAssistant?: boolean;
|
||||
|
||||
// Used to determine which disclosure links should be provided to user of
|
||||
// formula assistance.
|
||||
assistantService?: 'OpenAI' | undefined;
|
||||
|
||||
// Email address of the support user.
|
||||
supportEmail?: string;
|
||||
|
||||
|
||||
@@ -117,23 +117,54 @@ class RetryableError extends Error {
|
||||
}
|
||||
|
||||
/**
|
||||
* A flavor of assistant for use with the OpenAI API.
|
||||
* A flavor of assistant for use with the OpenAI chat completion endpoint
|
||||
* and tools with a compatible endpoint (e.g. llama-cpp-python).
|
||||
* Tested primarily with gpt-3.5-turbo.
|
||||
*
|
||||
* Uses the ASSISTANT_CHAT_COMPLETION_ENDPOINT endpoint if set, else
|
||||
* an OpenAI endpoint. Passes ASSISTANT_API_KEY or OPENAI_API_KEY in
|
||||
* a header if set. An api key is required for the default OpenAI
|
||||
* endpoint.
|
||||
*
|
||||
* If a model string is set in ASSISTANT_MODEL, this will be passed
|
||||
* along. For the default OpenAI endpoint, a gpt-3.5-turbo variant
|
||||
* will be set by default.
|
||||
*
|
||||
* If a request fails because of context length limitation, and the
|
||||
* default OpenAI endpoint is in use, the request will be retried
|
||||
* with ASSISTANT_LONGER_CONTEXT_MODEL (another gpt-3.5
|
||||
* variant by default). Set this variable to "" if this behavior is
|
||||
* not desired for the default OpenAI endpoint. If a custom endpoint was
|
||||
* provided, this behavior will only happen if
|
||||
* ASSISTANT_LONGER_CONTEXT_MODEL is explicitly set.
|
||||
*
|
||||
* An optional ASSISTANT_MAX_TOKENS can be specified.
|
||||
*/
|
||||
export class OpenAIAssistant implements Assistant {
|
||||
public static DEFAULT_MODEL = "gpt-3.5-turbo-0613";
|
||||
public static LONGER_CONTEXT_MODEL = "gpt-3.5-turbo-16k-0613";
|
||||
public static DEFAULT_LONGER_CONTEXT_MODEL = "gpt-3.5-turbo-16k-0613";
|
||||
|
||||
private _apiKey: string;
|
||||
private _apiKey?: string;
|
||||
private _model?: string;
|
||||
private _longerContextModel?: string;
|
||||
private _endpoint: string;
|
||||
private _maxTokens = process.env.ASSISTANT_MAX_TOKENS ?
|
||||
parseInt(process.env.ASSISTANT_MAX_TOKENS, 10) : undefined;
|
||||
|
||||
public constructor() {
|
||||
const apiKey = process.env.OPENAI_API_KEY;
|
||||
if (!apiKey) {
|
||||
throw new Error('OPENAI_API_KEY not set');
|
||||
const apiKey = process.env.ASSISTANT_API_KEY || process.env.OPENAI_API_KEY;
|
||||
const endpoint = process.env.ASSISTANT_CHAT_COMPLETION_ENDPOINT;
|
||||
if (!apiKey && !endpoint) {
|
||||
throw new Error('Please set either OPENAI_API_KEY or ASSISTANT_CHAT_COMPLETION_ENDPOINT');
|
||||
}
|
||||
this._apiKey = apiKey;
|
||||
this._endpoint = `https://api.openai.com/v1/chat/completions`;
|
||||
this._model = process.env.ASSISTANT_MODEL;
|
||||
this._longerContextModel = process.env.ASSISTANT_LONGER_CONTEXT_MODEL;
|
||||
if (!endpoint) {
|
||||
this._model = this._model ?? OpenAIAssistant.DEFAULT_MODEL;
|
||||
this._longerContextModel = this._longerContextModel ?? OpenAIAssistant.DEFAULT_LONGER_CONTEXT_MODEL;
|
||||
}
|
||||
this._endpoint = endpoint || `https://api.openai.com/v1/chat/completions`;
|
||||
}
|
||||
|
||||
public async apply(
|
||||
@@ -222,19 +253,25 @@ export class OpenAIAssistant implements Assistant {
|
||||
}
|
||||
|
||||
private async _fetchCompletion(messages: AssistanceMessage[], userIdHash: string, longerContext: boolean) {
|
||||
const model = longerContext ? this._longerContextModel : this._model;
|
||||
const apiResponse = await DEPS.fetch(
|
||||
this._endpoint,
|
||||
{
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Authorization": `Bearer ${this._apiKey}`,
|
||||
...(this._apiKey ? {
|
||||
"Authorization": `Bearer ${this._apiKey}`,
|
||||
} : undefined),
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
messages,
|
||||
temperature: 0,
|
||||
model: longerContext ? OpenAIAssistant.LONGER_CONTEXT_MODEL : OpenAIAssistant.DEFAULT_MODEL,
|
||||
...(model ? { model } : undefined),
|
||||
user: userIdHash,
|
||||
...(this._maxTokens ? {
|
||||
max_tokens: this._maxTokens,
|
||||
} : undefined),
|
||||
}),
|
||||
},
|
||||
);
|
||||
@@ -242,7 +279,7 @@ export class OpenAIAssistant implements Assistant {
|
||||
const result = JSON.parse(resultText);
|
||||
const errorCode = result.error?.code;
|
||||
if (errorCode === "context_length_exceeded" || result.choices?.[0].finish_reason === "length") {
|
||||
if (!longerContext) {
|
||||
if (!longerContext && this._longerContextModel) {
|
||||
log.info("Switching to longer context model...");
|
||||
throw new SwitchToLongerContext();
|
||||
} else if (messages.length <= 2) {
|
||||
@@ -392,14 +429,10 @@ export function getAssistant() {
|
||||
if (process.env.OPENAI_API_KEY === 'test') {
|
||||
return new EchoAssistant();
|
||||
}
|
||||
if (process.env.OPENAI_API_KEY) {
|
||||
if (process.env.OPENAI_API_KEY || process.env.ASSISTANT_CHAT_COMPLETION_ENDPOINT) {
|
||||
return new OpenAIAssistant();
|
||||
}
|
||||
// Maintaining this is too much of a burden for now.
|
||||
// if (process.env.HUGGINGFACE_API_KEY) {
|
||||
// return new HuggingFaceAssistant();
|
||||
// }
|
||||
throw new Error('Please set OPENAI_API_KEY');
|
||||
throw new Error('Please set OPENAI_API_KEY or ASSISTANT_CHAT_COMPLETION_ENDPOINT');
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -74,7 +74,8 @@ export function makeGristConfig(options: MakeGristConfigOptons): GristLoadConfig
|
||||
supportedLngs: readLoadedLngs(req?.i18n),
|
||||
namespaces: readLoadedNamespaces(req?.i18n),
|
||||
featureComments: isAffirmative(process.env.COMMENTS),
|
||||
featureFormulaAssistant: Boolean(process.env.OPENAI_API_KEY),
|
||||
featureFormulaAssistant: Boolean(process.env.OPENAI_API_KEY || process.env.ASSISTANT_CHAT_COMPLETION_ENDPOINT),
|
||||
assistantService: process.env.OPENAI_API_KEY ? 'OpenAI' : undefined,
|
||||
supportEmail: SUPPORT_EMAIL,
|
||||
userLocale: (req as RequestWithLogin | undefined)?.user?.options?.locale,
|
||||
telemetry: server?.getTelemetry().getTelemetryConfig(),
|
||||
|
||||
Reference in New Issue
Block a user