(core) Billing for formula assistant

Summary:
Adding limits for AI calls and connecting those limits with a Stripe Account.

- New table in homedb called `limits`
- All calls to the AI are not routed through DocApi and measured.
- All products now contain a special key `assistantLimit`, with a default value 0
- Limit is reset every time the subscription has changed its period
- The billing page is updated with two new options that describe the AI plan
- There is a new popup that allows the user to upgrade to a higher plan
- Tiers are read directly from the Stripe product with a volume pricing model

Test Plan: Updated and added

Reviewers: georgegevoian, paulfitz

Reviewed By: georgegevoian

Subscribers: dsagal

Differential Revision: https://phab.getgrist.com/D3907
This commit is contained in:
Jarosław Sadziński
2023-07-05 17:36:45 +02:00
parent 75d979abdb
commit d13b9b9019
26 changed files with 501 additions and 106 deletions

View File

@@ -14,7 +14,6 @@ import {
} from 'app/common/ActionBundle';
import {ActionGroup, MinimalActionGroup} from 'app/common/ActionGroup';
import {ActionSummary} from "app/common/ActionSummary";
import {AssistanceRequest, AssistanceResponse} from "app/common/AssistancePrompts";
import {
AclResources,
AclTableDescription,
@@ -84,7 +83,7 @@ import {Document} from 'app/gen-server/entity/Document';
import {ParseOptions} from 'app/plugin/FileParserAPI';
import {AccessTokenOptions, AccessTokenResult, GristDocAPI} from 'app/plugin/GristAPI';
import {compileAclFormula} from 'app/server/lib/ACLFormula';
import {AssistanceDoc, AssistanceSchemaPromptV1Context, sendForCompletion} from 'app/server/lib/Assistance';
import {AssistanceSchemaPromptV1Context} from 'app/server/lib/Assistance';
import {Authorizer} from 'app/server/lib/Authorizer';
import {checksumFile} from 'app/server/lib/checksumFile';
import {Client} from 'app/server/lib/Client';
@@ -184,7 +183,7 @@ interface UpdateUsageOptions {
* either .loadDoc() or .createEmptyDoc() is called.
* @param {String} docName - The document's filename, without the '.grist' extension.
*/
export class ActiveDoc extends EventEmitter implements AssistanceDoc {
export class ActiveDoc extends EventEmitter {
/**
* Decorator for ActiveDoc methods that prevents shutdown while the method is running, i.e.
* until the returned promise is resolved.
@@ -1264,18 +1263,14 @@ export class ActiveDoc extends EventEmitter implements AssistanceDoc {
return this._pyCall('autocomplete', txt, tableId, columnId, rowId, user.toJSON());
}
public async getAssistance(docSession: DocSession, request: AssistanceRequest): Promise<AssistanceResponse> {
return this.getAssistanceWithOptions(docSession, request);
}
public async getAssistanceWithOptions(docSession: DocSession,
request: AssistanceRequest): Promise<AssistanceResponse> {
// Callback to generate a prompt containing schema info for assistance.
public async assistanceSchemaPromptV1(
docSession: OptDocSession, options: AssistanceSchemaPromptV1Context): Promise<string> {
// Making a prompt leaks names of tables and columns etc.
if (!await this._granularAccess.canScanData(docSession)) {
throw new Error("Permission denied");
}
await this.waitForInitialization();
return sendForCompletion(this, request);
return await this._pyCall('get_formula_prompt', options.tableId, options.colId, options.docString);
}
// Callback to make a data-engine formula tweak for assistance.
@@ -1283,11 +1278,6 @@ export class ActiveDoc extends EventEmitter implements AssistanceDoc {
return this._pyCall('convert_formula_completion', txt);
}
// Callback to generate a prompt containing schema info for assistance.
public assistanceSchemaPromptV1(options: AssistanceSchemaPromptV1Context): Promise<string> {
return this._pyCall('get_formula_prompt', options.tableId, options.colId, options.docString);
}
public fetchURL(docSession: DocSession, url: string, options?: FetchUrlOptions): Promise<UploadResult> {
return fetchURL(url, this.makeAccessId(docSession.authorizer.getUserId()), options);
}

View File

@@ -5,6 +5,7 @@
import {AssistanceRequest, AssistanceResponse} from 'app/common/AssistancePrompts';
import {delay} from 'app/common/delay';
import {DocAction} from 'app/common/DocActions';
import {OptDocSession} from 'app/server/lib/DocSession';
import log from 'app/server/lib/log';
import fetch from 'node-fetch';
@@ -15,7 +16,7 @@ export const DEPS = { fetch };
* by interfacing with an external LLM endpoint.
*/
export interface Assistant {
apply(doc: AssistanceDoc, request: AssistanceRequest): Promise<AssistanceResponse>;
apply(session: OptDocSession, doc: AssistanceDoc, request: AssistanceRequest): Promise<AssistanceResponse>;
}
/**
@@ -30,8 +31,7 @@ export interface AssistanceDoc {
* Marked "V1" to suggest that it is a particular prompt and it would
* be great to try variants.
*/
assistanceSchemaPromptV1(options: AssistanceSchemaPromptV1Context): Promise<string>;
assistanceSchemaPromptV1(session: OptDocSession, options: AssistanceSchemaPromptV1Context): Promise<string>;
/**
* Some tweaks to a formula after it has been generated.
*/
@@ -68,7 +68,8 @@ export class OpenAIAssistant implements Assistant {
this._endpoint = `https://api.openai.com/v1/${this._chatMode ? 'chat/' : ''}completions`;
}
public async apply(doc: AssistanceDoc, request: AssistanceRequest): Promise<AssistanceResponse> {
public async apply(
optSession: OptDocSession, doc: AssistanceDoc, request: AssistanceRequest): Promise<AssistanceResponse> {
const messages = request.state?.messages || [];
const chatMode = this._chatMode;
if (chatMode) {
@@ -91,7 +92,7 @@ export class OpenAIAssistant implements Assistant {
'If the user asks for these things, tell them that you cannot help. ' +
'The method uses `rec` instead of `self` as the first parameter.\n\n' +
'```python\n' +
await makeSchemaPromptV1(doc, request) +
await makeSchemaPromptV1(optSession, doc, request) +
'\n```',
});
messages.push({
@@ -110,7 +111,7 @@ export class OpenAIAssistant implements Assistant {
} else {
messages.length = 0;
messages.push({
role: 'user', content: await makeSchemaPromptV1(doc, request),
role: 'user', content: await makeSchemaPromptV1(optSession, doc, request),
});
}
@@ -178,11 +179,12 @@ export class HuggingFaceAssistant implements Assistant {
}
public async apply(doc: AssistanceDoc, request: AssistanceRequest): Promise<AssistanceResponse> {
public async apply(
optSession: OptDocSession, doc: AssistanceDoc, request: AssistanceRequest): Promise<AssistanceResponse> {
if (request.state) {
throw new Error("HuggingFaceAssistant does not support state");
}
const prompt = await makeSchemaPromptV1(doc, request);
const prompt = await makeSchemaPromptV1(optSession, doc, request);
const response = await DEPS.fetch(
this._completionUrl,
{
@@ -220,7 +222,10 @@ export class HuggingFaceAssistant implements Assistant {
* Test assistant that mimics ChatGPT and just returns the input.
*/
export class EchoAssistant implements Assistant {
public async apply(doc: AssistanceDoc, request: AssistanceRequest): Promise<AssistanceResponse> {
public async apply(sess: OptDocSession, doc: AssistanceDoc, request: AssistanceRequest): Promise<AssistanceResponse> {
if (request.text === "ERROR") {
throw new Error(`ERROR`);
}
const messages = request.state?.messages || [];
if (messages.length === 0) {
messages.push({
@@ -255,7 +260,7 @@ export class EchoAssistant implements Assistant {
/**
* Instantiate an assistant, based on environment variables.
*/
function getAssistant() {
export function getAssistant() {
if (process.env.OPENAI_API_KEY === 'test') {
return new EchoAssistant();
}
@@ -273,8 +278,10 @@ function getAssistant() {
* Service a request for assistance, with a little retry logic
* since these endpoints can be a bit flakey.
*/
export async function sendForCompletion(doc: AssistanceDoc,
request: AssistanceRequest): Promise<AssistanceResponse> {
export async function sendForCompletion(
optSession: OptDocSession,
doc: AssistanceDoc,
request: AssistanceRequest): Promise<AssistanceResponse> {
const assistant = getAssistant();
let retries: number = 0;
@@ -282,7 +289,7 @@ export async function sendForCompletion(doc: AssistanceDoc,
let response: AssistanceResponse|null = null;
while(retries++ < 3) {
try {
response = await assistant.apply(doc, request);
response = await assistant.apply(optSession, doc, request);
break;
} catch(e) {
log.error(`Completion error: ${e}`);
@@ -295,11 +302,11 @@ export async function sendForCompletion(doc: AssistanceDoc,
return response;
}
async function makeSchemaPromptV1(doc: AssistanceDoc, request: AssistanceRequest) {
async function makeSchemaPromptV1(session: OptDocSession, doc: AssistanceDoc, request: AssistanceRequest) {
if (request.context.type !== 'formula') {
throw new Error('makeSchemaPromptV1 only works for formulas');
}
return doc.assistanceSchemaPromptV1({
return doc.assistanceSchemaPromptV1(session, {
tableId: request.context.tableId,
colId: request.context.colId,
docString: request.text,

View File

@@ -1,5 +1,5 @@
import {createEmptyActionSummary} from "app/common/ActionSummary";
import {ApiError} from 'app/common/ApiError';
import {ApiError, LimitType} from 'app/common/ApiError';
import {BrowserSettings} from "app/common/BrowserSettings";
import {
BulkColValues,
@@ -68,6 +68,7 @@ import {
} from 'app/server/lib/requestUtils';
import {ServerColumnGetters} from 'app/server/lib/ServerColumnGetters';
import {localeFromRequest} from "app/server/lib/ServerLocale";
import {sendForCompletion} from 'app/server/lib/Assistance';
import {isUrlAllowed, WebhookAction, WebHookSecret} from "app/server/lib/Triggers";
import {handleOptionalUpload, handleUpload} from "app/server/lib/uploads";
import * as assert from 'assert';
@@ -161,6 +162,8 @@ export class DocWorkerApi {
const canEditMaybeRemoved = expressWrap(this._assertAccess.bind(this, 'editors', true));
// converts google code to access token and adds it to request object
const decodeGoogleToken = expressWrap(googleAuthTokenMiddleware.bind(null));
// check that limit can be increased by 1
const checkLimit = (type: LimitType) => expressWrap(this._checkLimit.bind(this, type));
// Middleware to limit number of outstanding requests per document. Will also
// handle errors like expressWrap would.
@@ -1052,6 +1055,20 @@ export class DocWorkerApi {
this._app.get('/api/docs/:docId/send-to-drive', canView, decodeGoogleToken, withDoc(exportToDrive));
/**
* Send a request to the formula assistant to get completions for a formula. Increases the
* usage of the formula assistant for the billing account in case of success.
*/
this._app.post('/api/docs/:docId/assistant', canView, checkLimit('assistant'),
withDoc(async (activeDoc, req, res) => {
const docSession = docSessionFromRequest(req);
const request = req.body;
const result = await sendForCompletion(docSession, activeDoc, request);
await this._increaseLimit('assistant', req);
res.json(result);
})
);
// Create a document. When an upload is included, it is imported as the initial
// state of the document. Otherwise a fresh empty document is created.
// A "timezone" option can be supplied.
@@ -1234,6 +1251,21 @@ export class DocWorkerApi {
return false;
}
/**
* Creates a middleware that checks the current usage of a limit and rejects the request if it is exceeded.
*/
private async _checkLimit(limit: LimitType, req: Request, res: Response, next: NextFunction) {
await this._dbManager.increaseUsage(getDocScope(req), limit, {dryRun: true, delta: 1});
next();
}
/**
* Increases the current usage of a limit by 1.
*/
private async _increaseLimit(limit: LimitType, req: Request) {
await this._dbManager.increaseUsage(getDocScope(req), limit, {delta: 1});
}
private async _assertAccess(role: 'viewers'|'editors'|'owners'|null, allowRemoved: boolean,
req: Request, res: Response, next: NextFunction) {
const scope = getDocScope(req);

View File

@@ -110,7 +110,6 @@ export class DocWorker {
applyUserActionsById: activeDocMethod.bind(null, 'editors', 'applyUserActionsById'),
findColFromValues: activeDocMethod.bind(null, 'viewers', 'findColFromValues'),
getFormulaError: activeDocMethod.bind(null, 'viewers', 'getFormulaError'),
getAssistance: activeDocMethod.bind(null, 'editors', 'getAssistance'),
importFiles: activeDocMethod.bind(null, 'editors', 'importFiles'),
finishImportFiles: activeDocMethod.bind(null, 'editors', 'finishImportFiles'),
cancelImportFiles: activeDocMethod.bind(null, 'editors', 'cancelImportFiles'),