(core) Billing for formula assistant

Summary:
Adding limits for AI calls and connecting those limits with a Stripe Account.

- New table in homedb called `limits`
- All calls to the AI are not routed through DocApi and measured.
- All products now contain a special key `assistantLimit`, with a default value 0
- Limit is reset every time the subscription has changed its period
- The billing page is updated with two new options that describe the AI plan
- There is a new popup that allows the user to upgrade to a higher plan
- Tiers are read directly from the Stripe product with a volume pricing model

Test Plan: Updated and added

Reviewers: georgegevoian, paulfitz

Reviewed By: georgegevoian

Subscribers: dsagal

Differential Revision: https://phab.getgrist.com/D3907
pull/563/head
Jarosław Sadziński 11 months ago
parent 75d979abdb
commit d13b9b9019

@ -32,7 +32,6 @@ export class DocComm extends Disposable implements ActiveDocAPI {
public addAttachments = this._wrapMethod("addAttachments");
public findColFromValues = this._wrapMethod("findColFromValues");
public getFormulaError = this._wrapMethod("getFormulaError");
public getAssistance = this._wrapMethod("getAssistance");
public fetchURL = this._wrapMethod("fetchURL");
public autocomplete = this._wrapMethod("autocomplete");
public removeInstanceFromDoc = this._wrapMethod("removeInstanceFromDoc");

@ -185,6 +185,10 @@ export class GristDoc extends DisposableWithEvents {
public readonly currentTheme = this.docPageModel.appModel.currentTheme;
public get docApi() {
return this.docPageModel.appModel.api.getDocAPI(this.docPageModel.currentDocId.get()!);
}
private _actionLog: ActionLog;
private _undoStack: UndoStack;
private _lastOwnActionGroup: ActionGroupWithCursorPos|null = null;

@ -43,8 +43,8 @@ export interface CustomAction { label: string, action: () => void }
*/
export type MessageType = string | (() => DomElementArg);
// Identifies supported actions. These are implemented in NotifyUI.
export type NotifyAction = 'upgrade' | 'renew' | 'personal' | 'report-problem' | 'ask-for-help' | CustomAction;
export type NotifyAction = 'upgrade' | 'renew' | 'personal' | 'report-problem'
| 'ask-for-help' | 'manage' | CustomAction;
export interface INotifyOptions {
message: MessageType; // A string, or a function that builds dom.
timestamp?: number;

@ -121,7 +121,7 @@ export function reportError(err: Error|string, ev?: ErrorEvent): void {
const options: Partial<INotifyOptions> = {
title: "Reached plan limit",
key: `limit:${details.limit.quantity || message}`,
actions: ['upgrade'],
actions: details.tips?.some(t => t.action === 'manage') ? ['manage'] : ['upgrade'],
};
if (details.tips && details.tips.some(tip => tip.action === 'add-members')) {
// When adding members would fix a problem, give more specific advice.

@ -30,6 +30,10 @@ function buildAction(action: NotifyAction, item: Notification, options: IBeaconO
return dom('a', cssToastAction.cls(''), t("Upgrade Plan"), {target: '_blank'},
{href: commonUrls.plans});
}
case 'manage':
if (urlState().state.get().billing === 'billing') { return null; }
return dom('a', cssToastAction.cls(''), t("Manage billing"), {target: '_blank'},
{href: urlState().makeUrl({billing: 'billing'})});
case 'renew':
// If already on the billing page, nothing to return.
if (urlState().state.get().billing === 'billing') { return null; }

@ -142,6 +142,7 @@ export const vars = {
onboardingPopupZIndex: new CustomProp('onboarding-popup-z-index', '1000'),
floatingPopupZIndex: new CustomProp('floating-popup-z-index', '1002'),
tutorialModalZIndex: new CustomProp('tutorial-modal-z-index', '1003'),
pricingModalZIndex: new CustomProp('pricing-modal-z-index', '1004'),
notificationZIndex: new CustomProp('notification-z-index', '1100'),
browserCheckZIndex: new CustomProp('browser-check-z-index', '5000'),
tooltipZIndex: new CustomProp('tooltip-z-index', '5000'),

@ -17,7 +17,6 @@ import {autoGrow} from 'app/client/ui/forms';
import {IconName} from 'app/client/ui2018/IconList';
import {icon} from 'app/client/ui2018/icons';
import {cssLink} from 'app/client/ui2018/links';
import {DocAction} from 'app/common/DocActions';
import {movable} from 'app/client/lib/popupUtils';
import debounce from 'lodash/debounce';
@ -61,7 +60,7 @@ export class FormulaAssistant extends Disposable {
/** Is the request pending */
private _waiting = Observable.create(this, false);
/** Is this feature enabled at all */
private _assistantEnabled = GRIST_FORMULA_ASSISTANT();
private _assistantEnabled: Computed<boolean>;
/** Preview column id */
private _transformColId: string;
/** Method to invoke when we are closed, it saves or reverts */
@ -90,6 +89,12 @@ export class FormulaAssistant extends Disposable {
}) {
super();
this._assistantEnabled = Computed.create(this, use => {
const enabledByFlag = use(GRIST_FORMULA_ASSISTANT());
const notAnonymous = Boolean(this._options.gristDoc.appModel.currentValidUser);
return enabledByFlag && notAnonymous;
});
if (!this._options.field) {
// TODO: field is not passed only for rules (as there is no preview there available to the user yet)
// this should be implemented but it requires creating a helper column to helper column and we don't
@ -263,6 +268,8 @@ export class FormulaAssistant extends Disposable {
this._buildIntro(),
this._chat.buildDom(),
this._buildChatInput(),
// Stop propagation of mousedown events, as the formula editor will still focus.
dom.on('mousedown', (ev) => ev.stopPropagation()),
);
});
}
@ -516,7 +523,7 @@ export class FormulaAssistant extends Disposable {
this._options.editor.setFormula(entry.formula!);
}
private async _sendMessage(description: string, regenerate = false) {
private async _sendMessage(description: string, regenerate = false): Promise<ChatMessage> {
// Destruct options.
const {column, gristDoc} = this._options;
// Get the state of the chat from the column.
@ -539,7 +546,12 @@ export class FormulaAssistant extends Disposable {
// some markdown text back, so we need to parse it.
const prettyMessage = state ? (reply || formula || '') : (formula || reply || '');
// Add it to the chat.
this._chat.addResponse(prettyMessage, formula, suggestedActions[0]);
return {
message: prettyMessage,
formula,
action: suggestedActions[0],
sender: 'ai',
};
}
private _clear() {
@ -556,9 +568,7 @@ export class FormulaAssistant extends Disposable {
if (!last) {
return;
}
this._chat.thinking();
this._waiting.set(true);
await this._sendMessage(last, true).finally(() => this._waiting.set(false));
await this._doAsk(last);
}
private async _ask() {
@ -568,10 +578,22 @@ export class FormulaAssistant extends Disposable {
const message= this._userInput.get();
if (!message) { return; }
this._chat.addQuestion(message);
this._chat.thinking();
this._userInput.set('');
await this._doAsk(message);
}
private async _doAsk(message: string) {
this._chat.thinking();
this._waiting.set(true);
await this._sendMessage(message, false).finally(() => this._waiting.set(false));
try {
const response = await this._sendMessage(message, false);
this._chat.addResponse(response);
} catch(err) {
this._chat.thinking(false);
throw err;
} finally {
this._waiting.set(false);
}
}
}
@ -601,32 +623,36 @@ class ChatHistory extends Disposable {
this.length = Computed.create(this, use => use(this.history).length); // ??
}
public thinking() {
this.history.push({
message: '...',
sender: 'ai',
});
this.scrollDown();
public thinking(on = true) {
if (!on) {
// Find all index of all thinking messages.
const messages = [...this.history.get()].filter(m => m.message === '...');
// Remove all thinking messages.
for (const message of messages) {
this.history.splice(this.history.get().indexOf(message), 1);
}
} else {
this.history.push({
message: '...',
sender: 'ai',
});
this.scrollDown();
}
}
public supportsMarkdown() {
return this._options.column.chatHistory.peek().get().state !== undefined;
}
public addResponse(message: string, formula: string|null, action?: DocAction) {
public addResponse(message: ChatMessage) {
// Clear any thinking from messages.
this.history.set(this.history.get().filter(x => x.message !== '...'));
this.history.push({
message,
sender: 'ai',
formula,
action
});
this.thinking(false);
this.history.push({...message, sender: 'ai'});
this.scrollDown();
}
public addQuestion(message: string) {
this.history.set(this.history.get().filter(x => x.message !== '...'));
this.thinking(false);
this.history.push({
message,
sender: 'user',
@ -740,18 +766,13 @@ async function askAI(grist: GristDoc, options: {
const {column, description, state, regenerate} = options;
const tableId = column.table.peek().tableId.peek();
const colId = column.colId.peek();
try {
const result = await grist.docComm.getAssistance({
context: {type: 'formula', tableId, colId},
text: description,
state,
regenerate,
});
return result;
} catch (error) {
reportError(error);
throw error;
}
const result = await grist.docApi.getAssistance({
context: {type: 'formula', tableId, colId},
text: description,
state,
regenerate,
});
return result;
}
/**

@ -154,13 +154,14 @@ export class FormulaEditor extends NewBaseEditor {
dom.on('mousedown', (ev) => {
// If we are detached, allow user to click and select error text.
if (this.isDetached.get()) {
// If the focus is already in this editor, don't steal it. This is needed for detached editor with
// some input elements (mainly the AI assistant).
const inInput = document.activeElement instanceof HTMLInputElement
|| document.activeElement instanceof HTMLTextAreaElement;
if (inInput && this._dom.contains(document.activeElement)) {
// If we clicked on input element in our dom, don't do anything. We probably clicked on chat input, in AI
// tools box.
const clickedOnInput = ev.target instanceof HTMLInputElement || ev.target instanceof HTMLTextAreaElement;
if (clickedOnInput && this._dom.contains(ev.target)) {
// By not doing anything special here we assume that the input element will take the focus.
return;
}
// Allow clicking the error message.
if (ev.target instanceof HTMLElement && (
ev.target.classList.contains('error_msg') ||

@ -1,5 +1,4 @@
import {ActionGroup} from 'app/common/ActionGroup';
import {AssistanceRequest, AssistanceResponse} from 'app/common/AssistancePrompts';
import {BulkAddRecord, CellValue, TableDataAction, UserAction} from 'app/common/DocActions';
import {FormulaProperties} from 'app/common/GranularAccessClause';
import {UIRowId} from 'app/common/TableData';
@ -320,11 +319,6 @@ export interface ActiveDocAPI {
*/
getFormulaError(tableId: string, colId: string, rowId: number): Promise<CellValue>;
/**
* Generates a formula code based on the AI suggestions, it also modifies the column and sets it type to a formula.
*/
getAssistance(request: AssistanceRequest): Promise<AssistanceResponse>;
/**
* Fetch content at a url.
*/

@ -2,15 +2,17 @@
* A tip for fixing an error.
*/
export interface ApiTip {
action: 'add-members' | 'upgrade' |'ask-for-help';
action: 'add-members' | 'upgrade' | 'ask-for-help' | 'manage';
message: string;
}
export type LimitType = 'collaborators' | 'docs' | 'workspaces' | 'assistant';
/**
* Documentation of a limit relevant to an API error.
*/
export interface ApiLimit {
quantity: 'collaborators' | 'docs' | 'workspaces'; // what are we counting
quantity: LimitType; // what are we counting
subquantity?: string; // a nuance to what we are counting
maximum: number; // maximum allowed
value: number; // current value of quantity for user

@ -43,6 +43,16 @@ export interface IBillingPlan {
active: boolean;
}
export interface ILimitTier {
name?: string;
volume: number;
price: number;
flatFee: number;
type: string;
planId: string;
interval: string; // probably 'month'|'year';
}
// Utility type that requires all properties to be non-nullish.
// type NonNullableProperties<T> = { [P in keyof T]: Required<NonNullable<T[P]>>; };
@ -69,6 +79,7 @@ export interface IBillingDiscount {
export interface IBillingSubscription {
// All standard plan options.
plans: IBillingPlan[];
tiers: ILimitTier[];
// Index in the plans array of the plan currently in effect.
planIndex: number;
// Index in the plans array of the plan to be in effect after the current period end.
@ -111,6 +122,14 @@ export interface IBillingSubscription {
lastInvoiceUrl?: string; // URL of the Stripe-hosted page with the last invoice.
lastChargeError?: string; // The last charge error, if any, to show in case of a bad status.
lastChargeTime?: number; // The time of the last charge attempt.
limit?: ILimit|null;
}
export interface ILimit {
limitValue: number;
currentUsage: number;
type: string; // Limit type, for now only assistant is supported.
price: number; // If this is 0, it means it is a free plan.
}
export interface IBillingOrgSettings {
@ -139,6 +158,7 @@ export interface BillingAPI {
downgradePlan(planName: string): Promise<void>;
renewPlan(): string;
customerPortal(): string;
updateAssistantPlan(tier: number): Promise<void>;
}
export class BillingAPIImpl extends BaseAPI implements BillingAPI {
@ -230,6 +250,13 @@ export class BillingAPIImpl extends BaseAPI implements BillingAPI {
return `${this._url}/api/billing/renew`;
}
public async updateAssistantPlan(tier: number): Promise<void> {
await this.request(`${this._url}/api/billing/upgrade-assistant`, {
method: 'POST',
body: JSON.stringify({ tier })
});
}
/**
* Checks if current org has active subscription for a Stripe plan.
*/

@ -58,6 +58,11 @@ export interface Features {
// for attached files in a document
gracePeriodDays?: number; // Duration of the grace period in days, before entering delete-only mode
baseMaxAssistantCalls?: number; // Maximum number of AI assistant calls. Defaults to 0 if not set, use -1 to indicate
// unbound limit. This is total limit, not per month or per day, it is used as a seed
// value for the limits table. To create a per-month limit, there must be a separate
// task that resets the usage in the limits table.
}
// Check whether it is possible to add members at the org level. There's no flag

@ -1,5 +1,6 @@
import {ActionSummary} from 'app/common/ActionSummary';
import {ApplyUAResult, ForkResult, PermissionDataWithExtraUsers, QueryFilters} from 'app/common/ActiveDocAPI';
import {AssistanceRequest, AssistanceResponse} from 'app/common/AssistancePrompts';
import {BaseAPI, IOptions} from 'app/common/BaseAPI';
import {BillingAPI, BillingAPIImpl} from 'app/common/BillingAPI';
import {BrowserSettings} from 'app/common/BrowserSettings';
@ -462,6 +463,8 @@ export interface DocAPI {
// Update webhook
updateWebhook(webhook: WebhookUpdate): Promise<void>;
flushWebhooks(): Promise<void>;
getAssistance(params: AssistanceRequest): Promise<AssistanceResponse>;
}
// Operations that are supported by a doc worker.
@ -1012,6 +1015,13 @@ export class DocAPIImpl extends BaseAPI implements DocAPI {
return response.data[0];
}
public async getAssistance(params: AssistanceRequest): Promise<AssistanceResponse> {
return await this.requestJson(`${this._url}/assistant`, {
method: 'POST',
body: JSON.stringify(params),
});
}
private _getRecords(tableId: string, endpoint: 'data' | 'records', options?: GetRowsParams): Promise<any> {
const url = new URL(`${this._url}/tables/${tableId}/${endpoint}`);
if (options?.filters) {

@ -3,12 +3,13 @@ import {BillingAccountManager} from 'app/gen-server/entity/BillingAccountManager
import {Organization} from 'app/gen-server/entity/Organization';
import {Product} from 'app/gen-server/entity/Product';
import {nativeValues} from 'app/gen-server/lib/values';
import {Limit} from 'app/gen-server/entity/Limit';
// This type is for billing account status information. Intended for stuff
// like "free trial running out in N days".
interface BillingAccountStatus {
export interface BillingAccountStatus {
stripeStatus?: string;
currentPeriodEnd?: Date;
currentPeriodEnd?: string;
message?: string;
}
@ -68,6 +69,9 @@ export class BillingAccount extends BaseEntity {
@OneToMany(type => Organization, org => org.billingAccount)
public orgs: Organization[];
@OneToMany(type => Limit, limit => limit.billingAccount)
public limits: Limit[];
// A calculated column that is true if it looks like there is a paid plan.
@Column({name: 'paid', type: 'boolean', insert: false, select: false})
public paid?: boolean;

@ -0,0 +1,46 @@
import {BaseEntity, Column, Entity, JoinColumn, ManyToOne, PrimaryGeneratedColumn} from 'typeorm';
import {BillingAccount} from 'app/gen-server/entity/BillingAccount';
import {nativeValues} from 'app/gen-server/lib/values';
@Entity('limits')
export class Limit extends BaseEntity {
@PrimaryGeneratedColumn()
public id: number;
@Column()
public limit: number;
@Column()
public usage: number;
@Column()
public type: string;
@Column({name: 'billing_account_id'})
public billingAccountId: number;
@ManyToOne(type => BillingAccount)
@JoinColumn({name: 'billing_account_id'})
public billingAccount: BillingAccount;
@Column({name: 'created_at', default: () => "CURRENT_TIMESTAMP"})
public createdAt: Date;
/**
* Last time the Limit.limit value was changed, by an upgrade or downgrade. Null if it has never been changed.
*/
@Column({name: 'changed_at', type: nativeValues.dateTimeType, nullable: true})
public changedAt: Date|null;
/**
* Last time the Limit.usage was used (by sending a request to the model). Null if it has never been used.
*/
@Column({name: 'used_at', type: nativeValues.dateTimeType, nullable: true})
public usedAt: Date|null;
/**
* Last time the Limit.usage was reset, probably by billing cycle change. Null if it has never been reset.
*/
@Column({name: 'reset_at', type: nativeValues.dateTimeType, nullable: true})
public resetAt: Date|null;
}

@ -13,7 +13,11 @@ export const personalLegacyFeatures: Features = {
// no vanity domain
maxDocsPerOrg: 10,
maxSharesPerDoc: 2,
maxWorkspacesPerOrg: 1
maxWorkspacesPerOrg: 1,
/**
* One time limit of 100 requests.
*/
baseMaxAssistantCalls: 100,
};
/**
@ -23,7 +27,12 @@ export const teamFeatures: Features = {
workspaces: true,
vanityDomain: true,
maxSharesPerWorkspace: 0, // all workspace shares need to be org members.
maxSharesPerDoc: 2
maxSharesPerDoc: 2,
/**
* Limit of 100 requests, but unlike for personal/free orgs the usage for this limit is reset at every billing cycle
* through Stripe webhook. For canceled subscription the usage is not reset, as the billing cycle is not changed.
*/
baseMaxAssistantCalls: 100,
};
/**
@ -40,6 +49,10 @@ export const teamFreeFeatures: Features = {
baseMaxDataSizePerDocument: 5000 * 2 * 1024, // 2KB per row
baseMaxAttachmentsBytesPerDocument: 1 * 1024 * 1024 * 1024, // 1GB
gracePeriodDays: 14,
/**
* One time limit of 100 requests.
*/
baseMaxAssistantCalls: 100,
};
/**
@ -55,6 +68,7 @@ export const teamFreeFeatures: Features = {
baseMaxDataSizePerDocument: 5000 * 2 * 1024, // 2KB per row
baseMaxAttachmentsBytesPerDocument: 1 * 1024 * 1024 * 1024, // 1GB
gracePeriodDays: 14,
baseMaxAssistantCalls: 100,
};
export const testDailyApiLimitFeatures = {
@ -79,6 +93,7 @@ export const suspendedFeatures: Features = {
maxDocsPerOrg: 0,
maxSharesPerDoc: 0,
maxWorkspacesPerOrg: 0,
baseMaxAssistantCalls: 0,
};
/**

@ -60,6 +60,7 @@ export class DocApiForwarder {
app.use('/api/docs/:docId/assign', withDocWithoutAuth);
app.use('/api/docs/:docId/webhooks/queue', withDoc);
app.use('/api/docs/:docId/webhooks', withDoc);
app.use('/api/docs/:docId/assistant', withDoc);
app.use('^/api/docs$', withoutDoc);
}

@ -1,4 +1,4 @@
import {ApiError} from 'app/common/ApiError';
import {ApiError, ApiErrorDetails, LimitType} from 'app/common/ApiError';
import {mapGetOrSet, mapSetOrClear, MapWithTTL} from 'app/common/AsyncCreate';
import {getDataLimitStatus} from 'app/common/DocLimits';
import {createEmptyOrgUsageSummary, DocumentUsage, OrgUsageSummary} from 'app/common/DocUsage';
@ -38,6 +38,7 @@ import {getDefaultProductNames, personalFreeFeatures, Product} from "app/gen-ser
import {Secret} from "app/gen-server/entity/Secret";
import {User} from "app/gen-server/entity/User";
import {Workspace} from "app/gen-server/entity/Workspace";
import {Limit} from 'app/gen-server/entity/Limit';
import {Permissions} from 'app/gen-server/lib/Permissions';
import {scrubUserFromOrg} from "app/gen-server/lib/scrubUserFromOrg";
import {applyPatch} from 'app/gen-server/lib/TypeORMPatches';
@ -2880,6 +2881,144 @@ export class HomeDBManager extends EventEmitter {
return this._org(scope, scope.includeSupport || false, org, options);
}
public async getLimits(accountId: number): Promise<Limit[]> {
const result = this._connection.transaction(async manager => {
return await manager.createQueryBuilder()
.select('limit')
.from(Limit, 'limit')
.innerJoin('limit.billingAccount', 'account')
.where('account.id = :accountId', {accountId})
.getMany();
});
return result;
}
public async getLimit(accountId: number, limitType: LimitType): Promise<Limit|null> {
return await this._getOrCreateLimit(accountId, limitType, true);
}
public async peekLimit(accountId: number, limitType: LimitType): Promise<Limit|null> {
return await this._getOrCreateLimit(accountId, limitType, false);
}
public async removeLimit(scope: Scope, limitType: LimitType): Promise<void> {
await this._connection.transaction(async manager => {
const org = await this._org(scope, false, scope.org ?? null, {manager, needRealOrg: true})
.innerJoinAndSelect('orgs.billingAccount', 'billing_account')
.innerJoinAndSelect('billing_account.product', 'product')
.leftJoinAndSelect('billing_account.limits', 'limit', 'limit.type = :limitType', {limitType})
.getOne();
const existing = org?.billingAccount?.limits?.[0];
if (existing) {
await manager.remove(existing);
}
});
}
/**
* Increases the usage of a limit for a given org. If the limit doesn't exist, it will be created.
* Pass `dryRun: true` to check if the limit can be increased without actually increasing it.
*/
public async increaseUsage(scope: Scope, limitType: LimitType, options: {
delta: number,
dryRun?: boolean,
}): Promise<void> {
const limitError = await this._connection.transaction(async manager => {
const org = await this._org(scope, false, scope.org ?? null, {manager, needRealOrg: true})
.innerJoinAndSelect('orgs.billingAccount', 'billing_account')
.innerJoinAndSelect('billing_account.product', 'product')
.leftJoinAndSelect('billing_account.limits', 'limit', 'limit.type = :limitType', {limitType})
.getOne();
// If the org doesn't exists, or is a fake one (like for anonymous users), don't do anything.
if (!org || org.id === 0) {
// This API shouldn't be called, it should be checked first if the org is valid.
throw new ApiError(`Can't create a limit for non-existing organization`, 500);
}
let existing = org?.billingAccount?.limits?.[0];
if (!existing) {
const product = org?.billingAccount?.product;
if (!product) {
throw new ApiError(`getLimit: no product found for org`, 500);
}
if (product.features.baseMaxAssistantCalls === undefined) {
// If the product has no assistantLimit, then it is not billable yet, and we don't need to
// track usage as it is basically unlimited.
return;
}
existing = new Limit();
existing.billingAccountId = org.billingAccountId;
existing.type = limitType;
existing.limit = product.features.baseMaxAssistantCalls ?? 0;
existing.usage = 0;
}
const limitLess = existing.limit === -1; // -1 means no limit, it is not possible to do in stripe.
const usageAfter = existing.usage + options.delta;
if (!limitLess && usageAfter > existing.limit) {
const billable = Boolean(org?.billingAccount?.stripeCustomerId);
return {
limit: {
maximum: existing.limit,
projectedValue: existing.usage + options.delta,
quantity: limitType,
value: existing.usage,
},
tips: [{
// For non-billable accounts, suggest getting a plan, otherwise suggest visiting the billing page.
action: billable ? 'manage' : 'upgrade',
message: `Upgrade to a paid plan to increase your ${limitType} limit.`,
}]
} as ApiErrorDetails;
}
existing.usage += options.delta;
existing.usedAt = new Date();
if (!options.dryRun) {
await manager.save(existing);
}
});
if (limitError) {
let message = `Your ${limitType} limit has been reached. Please upgrade your plan to increase your limit.`;
if (limitType === 'assistant') {
message = 'You used all available credits. For a bigger limit upgrade you Assistant plan.';
}
throw new ApiError(message, 429, limitError);
}
}
private async _getOrCreateLimit(accountId: number, limitType: LimitType, force: boolean): Promise<Limit|null> {
if (accountId === 0) {
throw new Error(`getLimit: called for not existing account`);
}
const result = this._connection.transaction(async manager => {
let existing = await manager.createQueryBuilder()
.select('limit')
.from(Limit, 'limit')
.innerJoin('limit.billingAccount', 'account')
.where('account.id = :accountId', {accountId})
.andWhere('limit.type = :limitType', {limitType})
.getOne();
if (!force && !existing) { return null; }
if (existing) { return existing; }
const product = await manager.createQueryBuilder()
.select('product')
.from(Product, 'product')
.innerJoinAndSelect('product.accounts', 'account')
.where('account.id = :accountId', {accountId})
.getOne();
if (!product) {
throw new Error(`getLimit: no product for account ${accountId}`);
}
existing = new Limit();
existing.billingAccountId = product.accounts[0].id;
existing.type = limitType;
existing.limit = product.features.baseMaxAssistantCalls ?? 0;
existing.usage = 0;
await manager.save(existing);
return existing;
});
return result;
}
private _org(scope: Scope|null, includeSupport: boolean, org: string|number|null,
options: QueryOptions = {}): SelectQueryBuilder<Organization> {
let query = this._orgs(options.manager);

@ -0,0 +1,82 @@
import * as sqlUtils from "app/gen-server/sqlUtils";
import {MigrationInterface, QueryRunner, Table, TableIndex} from 'typeorm';
export class AssistantLimit1685343047786 implements MigrationInterface {
public async up(queryRunner: QueryRunner): Promise<void> {
const dbType = queryRunner.connection.driver.options.type;
const datetime = sqlUtils.datetime(dbType);
const now = sqlUtils.now(dbType);
await queryRunner.createTable(
new Table({
name: 'limits',
columns: [
{
name: 'id',
type: 'integer',
isPrimary: true,
isGenerated: true,
generationStrategy: 'increment',
},
{
name: 'type',
type: 'varchar',
},
{
name: 'billing_account_id',
type: 'integer',
},
{
name: 'limit',
type: 'integer',
default: 0,
},
{
name: 'usage',
type: 'integer',
default: 0,
},
{
name: "created_at",
type: datetime,
default: now
},
{
name: "changed_at", // When the limit was last changed
type: datetime,
isNullable: true
},
{
name: "used_at", // When the usage was last increased
type: datetime,
isNullable: true
},
{
name: "reset_at", // When the usage was last reset.
type: datetime,
isNullable: true
},
],
foreignKeys: [
{
columnNames: ['billing_account_id'],
referencedTableName: 'billing_accounts',
referencedColumnNames: ['id'],
onDelete: 'CASCADE',
},
],
})
);
await queryRunner.createIndex(
'limits',
new TableIndex({
name: 'limits_billing_account_id',
columnNames: ['billing_account_id'],
})
);
}
public async down(queryRunner: QueryRunner): Promise<void> {
await queryRunner.dropTable('limits');
}
}

@ -14,7 +14,6 @@ import {
} from 'app/common/ActionBundle';
import {ActionGroup, MinimalActionGroup} from 'app/common/ActionGroup';
import {ActionSummary} from "app/common/ActionSummary";
import {AssistanceRequest, AssistanceResponse} from "app/common/AssistancePrompts";
import {
AclResources,
AclTableDescription,
@ -84,7 +83,7 @@ import {Document} from 'app/gen-server/entity/Document';
import {ParseOptions} from 'app/plugin/FileParserAPI';
import {AccessTokenOptions, AccessTokenResult, GristDocAPI} from 'app/plugin/GristAPI';
import {compileAclFormula} from 'app/server/lib/ACLFormula';
import {AssistanceDoc, AssistanceSchemaPromptV1Context, sendForCompletion} from 'app/server/lib/Assistance';
import {AssistanceSchemaPromptV1Context} from 'app/server/lib/Assistance';
import {Authorizer} from 'app/server/lib/Authorizer';
import {checksumFile} from 'app/server/lib/checksumFile';
import {Client} from 'app/server/lib/Client';
@ -184,7 +183,7 @@ interface UpdateUsageOptions {
* either .loadDoc() or .createEmptyDoc() is called.
* @param {String} docName - The document's filename, without the '.grist' extension.
*/
export class ActiveDoc extends EventEmitter implements AssistanceDoc {
export class ActiveDoc extends EventEmitter {
/**
* Decorator for ActiveDoc methods that prevents shutdown while the method is running, i.e.
* until the returned promise is resolved.
@ -1264,18 +1263,14 @@ export class ActiveDoc extends EventEmitter implements AssistanceDoc {
return this._pyCall('autocomplete', txt, tableId, columnId, rowId, user.toJSON());
}
public async getAssistance(docSession: DocSession, request: AssistanceRequest): Promise<AssistanceResponse> {
return this.getAssistanceWithOptions(docSession, request);
}
public async getAssistanceWithOptions(docSession: DocSession,
request: AssistanceRequest): Promise<AssistanceResponse> {
// Callback to generate a prompt containing schema info for assistance.
public async assistanceSchemaPromptV1(
docSession: OptDocSession, options: AssistanceSchemaPromptV1Context): Promise<string> {
// Making a prompt leaks names of tables and columns etc.
if (!await this._granularAccess.canScanData(docSession)) {
throw new Error("Permission denied");
}
await this.waitForInitialization();
return sendForCompletion(this, request);
return await this._pyCall('get_formula_prompt', options.tableId, options.colId, options.docString);
}
// Callback to make a data-engine formula tweak for assistance.
@ -1283,11 +1278,6 @@ export class ActiveDoc extends EventEmitter implements AssistanceDoc {
return this._pyCall('convert_formula_completion', txt);
}
// Callback to generate a prompt containing schema info for assistance.
public assistanceSchemaPromptV1(options: AssistanceSchemaPromptV1Context): Promise<string> {
return this._pyCall('get_formula_prompt', options.tableId, options.colId, options.docString);
}
public fetchURL(docSession: DocSession, url: string, options?: FetchUrlOptions): Promise<UploadResult> {
return fetchURL(url, this.makeAccessId(docSession.authorizer.getUserId()), options);
}

@ -5,6 +5,7 @@
import {AssistanceRequest, AssistanceResponse} from 'app/common/AssistancePrompts';
import {delay} from 'app/common/delay';
import {DocAction} from 'app/common/DocActions';
import {OptDocSession} from 'app/server/lib/DocSession';
import log from 'app/server/lib/log';
import fetch from 'node-fetch';
@ -15,7 +16,7 @@ export const DEPS = { fetch };
* by interfacing with an external LLM endpoint.
*/
export interface Assistant {
apply(doc: AssistanceDoc, request: AssistanceRequest): Promise<AssistanceResponse>;
apply(session: OptDocSession, doc: AssistanceDoc, request: AssistanceRequest): Promise<AssistanceResponse>;
}
/**
@ -30,8 +31,7 @@ export interface AssistanceDoc {
* Marked "V1" to suggest that it is a particular prompt and it would
* be great to try variants.
*/
assistanceSchemaPromptV1(options: AssistanceSchemaPromptV1Context): Promise<string>;
assistanceSchemaPromptV1(session: OptDocSession, options: AssistanceSchemaPromptV1Context): Promise<string>;
/**
* Some tweaks to a formula after it has been generated.
*/
@ -68,7 +68,8 @@ export class OpenAIAssistant implements Assistant {
this._endpoint = `https://api.openai.com/v1/${this._chatMode ? 'chat/' : ''}completions`;
}
public async apply(doc: AssistanceDoc, request: AssistanceRequest): Promise<AssistanceResponse> {
public async apply(
optSession: OptDocSession, doc: AssistanceDoc, request: AssistanceRequest): Promise<AssistanceResponse> {
const messages = request.state?.messages || [];
const chatMode = this._chatMode;
if (chatMode) {
@ -91,7 +92,7 @@ export class OpenAIAssistant implements Assistant {
'If the user asks for these things, tell them that you cannot help. ' +
'The method uses `rec` instead of `self` as the first parameter.\n\n' +
'```python\n' +
await makeSchemaPromptV1(doc, request) +
await makeSchemaPromptV1(optSession, doc, request) +
'\n```',
});
messages.push({
@ -110,7 +111,7 @@ export class OpenAIAssistant implements Assistant {
} else {
messages.length = 0;
messages.push({
role: 'user', content: await makeSchemaPromptV1(doc, request),
role: 'user', content: await makeSchemaPromptV1(optSession, doc, request),
});
}
@ -178,11 +179,12 @@ export class HuggingFaceAssistant implements Assistant {
}
public async apply(doc: AssistanceDoc, request: AssistanceRequest): Promise<AssistanceResponse> {
public async apply(
optSession: OptDocSession, doc: AssistanceDoc, request: AssistanceRequest): Promise<AssistanceResponse> {
if (request.state) {
throw new Error("HuggingFaceAssistant does not support state");
}
const prompt = await makeSchemaPromptV1(doc, request);
const prompt = await makeSchemaPromptV1(optSession, doc, request);
const response = await DEPS.fetch(
this._completionUrl,
{
@ -220,7 +222,10 @@ export class HuggingFaceAssistant implements Assistant {
* Test assistant that mimics ChatGPT and just returns the input.
*/
export class EchoAssistant implements Assistant {
public async apply(doc: AssistanceDoc, request: AssistanceRequest): Promise<AssistanceResponse> {
public async apply(sess: OptDocSession, doc: AssistanceDoc, request: AssistanceRequest): Promise<AssistanceResponse> {
if (request.text === "ERROR") {
throw new Error(`ERROR`);
}
const messages = request.state?.messages || [];
if (messages.length === 0) {
messages.push({
@ -255,7 +260,7 @@ export class EchoAssistant implements Assistant {
/**
* Instantiate an assistant, based on environment variables.
*/
function getAssistant() {
export function getAssistant() {
if (process.env.OPENAI_API_KEY === 'test') {
return new EchoAssistant();
}
@ -273,8 +278,10 @@ function getAssistant() {
* Service a request for assistance, with a little retry logic
* since these endpoints can be a bit flakey.
*/
export async function sendForCompletion(doc: AssistanceDoc,
request: AssistanceRequest): Promise<AssistanceResponse> {
export async function sendForCompletion(
optSession: OptDocSession,
doc: AssistanceDoc,
request: AssistanceRequest): Promise<AssistanceResponse> {
const assistant = getAssistant();
let retries: number = 0;
@ -282,7 +289,7 @@ export async function sendForCompletion(doc: AssistanceDoc,
let response: AssistanceResponse|null = null;
while(retries++ < 3) {
try {
response = await assistant.apply(doc, request);
response = await assistant.apply(optSession, doc, request);
break;
} catch(e) {
log.error(`Completion error: ${e}`);
@ -295,11 +302,11 @@ export async function sendForCompletion(doc: AssistanceDoc,
return response;
}
async function makeSchemaPromptV1(doc: AssistanceDoc, request: AssistanceRequest) {
async function makeSchemaPromptV1(session: OptDocSession, doc: AssistanceDoc, request: AssistanceRequest) {
if (request.context.type !== 'formula') {
throw new Error('makeSchemaPromptV1 only works for formulas');
}
return doc.assistanceSchemaPromptV1({
return doc.assistanceSchemaPromptV1(session, {
tableId: request.context.tableId,
colId: request.context.colId,
docString: request.text,

@ -1,5 +1,5 @@
import {createEmptyActionSummary} from "app/common/ActionSummary";
import {ApiError} from 'app/common/ApiError';
import {ApiError, LimitType} from 'app/common/ApiError';
import {BrowserSettings} from "app/common/BrowserSettings";
import {
BulkColValues,
@ -68,6 +68,7 @@ import {
} from 'app/server/lib/requestUtils';
import {ServerColumnGetters} from 'app/server/lib/ServerColumnGetters';
import {localeFromRequest} from "app/server/lib/ServerLocale";
import {sendForCompletion} from 'app/server/lib/Assistance';
import {isUrlAllowed, WebhookAction, WebHookSecret} from "app/server/lib/Triggers";
import {handleOptionalUpload, handleUpload} from "app/server/lib/uploads";
import * as assert from 'assert';
@ -161,6 +162,8 @@ export class DocWorkerApi {
const canEditMaybeRemoved = expressWrap(this._assertAccess.bind(this, 'editors', true));
// converts google code to access token and adds it to request object
const decodeGoogleToken = expressWrap(googleAuthTokenMiddleware.bind(null));
// check that limit can be increased by 1
const checkLimit = (type: LimitType) => expressWrap(this._checkLimit.bind(this, type));
// Middleware to limit number of outstanding requests per document. Will also
// handle errors like expressWrap would.
@ -1052,6 +1055,20 @@ export class DocWorkerApi {
this._app.get('/api/docs/:docId/send-to-drive', canView, decodeGoogleToken, withDoc(exportToDrive));
/**
* Send a request to the formula assistant to get completions for a formula. Increases the
* usage of the formula assistant for the billing account in case of success.
*/
this._app.post('/api/docs/:docId/assistant', canView, checkLimit('assistant'),
withDoc(async (activeDoc, req, res) => {
const docSession = docSessionFromRequest(req);
const request = req.body;
const result = await sendForCompletion(docSession, activeDoc, request);
await this._increaseLimit('assistant', req);
res.json(result);
})
);
// Create a document. When an upload is included, it is imported as the initial
// state of the document. Otherwise a fresh empty document is created.
// A "timezone" option can be supplied.
@ -1234,6 +1251,21 @@ export class DocWorkerApi {
return false;
}
/**
* Creates a middleware that checks the current usage of a limit and rejects the request if it is exceeded.
*/
private async _checkLimit(limit: LimitType, req: Request, res: Response, next: NextFunction) {
await this._dbManager.increaseUsage(getDocScope(req), limit, {dryRun: true, delta: 1});
next();
}
/**
* Increases the current usage of a limit by 1.
*/
private async _increaseLimit(limit: LimitType, req: Request) {
await this._dbManager.increaseUsage(getDocScope(req), limit, {delta: 1});
}
private async _assertAccess(role: 'viewers'|'editors'|'owners'|null, allowRemoved: boolean,
req: Request, res: Response, next: NextFunction) {
const scope = getDocScope(req);

@ -110,7 +110,6 @@ export class DocWorker {
applyUserActionsById: activeDocMethod.bind(null, 'editors', 'applyUserActionsById'),
findColFromValues: activeDocMethod.bind(null, 'viewers', 'findColFromValues'),
getFormulaError: activeDocMethod.bind(null, 'viewers', 'getFormulaError'),
getAssistance: activeDocMethod.bind(null, 'editors', 'getAssistance'),
importFiles: activeDocMethod.bind(null, 'editors', 'importFiles'),
finishImportFiles: activeDocMethod.bind(null, 'editors', 'finishImportFiles'),
cancelImportFiles: activeDocMethod.bind(null, 'editors', 'cancelImportFiles'),

@ -25,7 +25,7 @@
import { ActiveDoc, Deps as ActiveDocDeps } from "app/server/lib/ActiveDoc";
import { DEPS } from "app/server/lib/Assistance";
import { DEPS, sendForCompletion } from "app/server/lib/Assistance";
import log from 'app/server/lib/log';
import crypto from 'crypto';
import parse from 'csv-parse/lib/sync';
@ -163,7 +163,7 @@ where c.colId = ? and t.tableId = ?
`, rec.col_id, rec.table_id);
formula = colInfo?.formula;
const result = await activeDoc.getAssistanceWithOptions(session, {
const result = await sendForCompletion(session, activeDoc, {
context: {type: 'formula', tableId, colId},
state: history,
text: followUp || description,

@ -558,7 +558,7 @@ describe('DocTutorial', function () {
// Check that the update is immediately reflected in the tutorial popup.
assert.equal(
await driver.find('.test-doc-tutorial-popup p').getText(),
await driver.findWait('.test-doc-tutorial-popup p', 2000).getText(),
'Welcome to the Grist Basics tutorial V2.'
);
@ -571,7 +571,7 @@ describe('DocTutorial', function () {
// Switch to another user and restart the tutorial.
viewerSession = await gu.session().teamSite.user('user2').login();
await viewerSession.loadDoc(`/doc/${doc.id}`);
await driver.find('.test-doc-tutorial-popup-restart').click();
await driver.findWait('.test-doc-tutorial-popup-restart', 2000).click();
await driver.find('.test-modal-confirm').click();
await gu.waitForServer();
await driver.findWait('.test-doc-tutorial-popup', 2000);

@ -2056,7 +2056,13 @@ export class Session {
isFirstLogin?: boolean,
showTips?: boolean,
skipTutorial?: boolean, // By default true
userName?: string,
email?: string,
retainExistingLogin?: boolean}) {
if (options?.userName) {
this.settings.name = options.userName;
this.settings.email = options.email || '';
}
// Optimize testing a little bit, so if we are already logged in as the expected
// user on the expected org, and there are no options set, we can just continue.
if (!options && await this.isLoggedInCorrectly()) { return this; }
@ -3150,20 +3156,26 @@ export async function availableBehaviorOptions() {
return list;
}
export function withComments() {
let oldEnv: testUtils.EnvironmentSnapshot;
/**
* Restarts the server ensuring that it is run with the given environment variables.
* If variables are already set, the server is not restarted.
*
* Useful for local testing of features that depend on environment variables, as it avoids the need
* to restart the server when those variables are already set.
*/
export function withEnvironmentSnapshot(vars: Record<string, any>) {
let oldEnv: testUtils.EnvironmentSnapshot|null = null;
before(async () => {
if (process.env.COMMENTS !== 'true') {
oldEnv = new testUtils.EnvironmentSnapshot();
process.env.COMMENTS = 'true';
await server.restart();
}
// Test if the vars are already set, and if so, skip.
if (Object.keys(vars).every(k => process.env[k] === vars[k])) { return; }
oldEnv = new testUtils.EnvironmentSnapshot();
Object.assign(process.env, vars);
await server.restart();
});
after(async () => {
if (oldEnv) {
oldEnv.restore();
await server.restart();
}
if (!oldEnv) { return; }
oldEnv.restore();
await server.restart();
});
}

Loading…
Cancel
Save