(core) Billing for formula assistant

Summary:
Adding limits for AI calls and connecting those limits with a Stripe Account.

- New table in homedb called `limits`
- All calls to the AI are not routed through DocApi and measured.
- All products now contain a special key `assistantLimit`, with a default value 0
- Limit is reset every time the subscription has changed its period
- The billing page is updated with two new options that describe the AI plan
- There is a new popup that allows the user to upgrade to a higher plan
- Tiers are read directly from the Stripe product with a volume pricing model

Test Plan: Updated and added

Reviewers: georgegevoian, paulfitz

Reviewed By: georgegevoian

Subscribers: dsagal

Differential Revision: https://phab.getgrist.com/D3907
This commit is contained in:
Jarosław Sadziński 2023-07-05 17:36:45 +02:00
parent 75d979abdb
commit d13b9b9019
26 changed files with 501 additions and 106 deletions

View File

@ -32,7 +32,6 @@ export class DocComm extends Disposable implements ActiveDocAPI {
public addAttachments = this._wrapMethod("addAttachments"); public addAttachments = this._wrapMethod("addAttachments");
public findColFromValues = this._wrapMethod("findColFromValues"); public findColFromValues = this._wrapMethod("findColFromValues");
public getFormulaError = this._wrapMethod("getFormulaError"); public getFormulaError = this._wrapMethod("getFormulaError");
public getAssistance = this._wrapMethod("getAssistance");
public fetchURL = this._wrapMethod("fetchURL"); public fetchURL = this._wrapMethod("fetchURL");
public autocomplete = this._wrapMethod("autocomplete"); public autocomplete = this._wrapMethod("autocomplete");
public removeInstanceFromDoc = this._wrapMethod("removeInstanceFromDoc"); public removeInstanceFromDoc = this._wrapMethod("removeInstanceFromDoc");

View File

@ -185,6 +185,10 @@ export class GristDoc extends DisposableWithEvents {
public readonly currentTheme = this.docPageModel.appModel.currentTheme; public readonly currentTheme = this.docPageModel.appModel.currentTheme;
public get docApi() {
return this.docPageModel.appModel.api.getDocAPI(this.docPageModel.currentDocId.get()!);
}
private _actionLog: ActionLog; private _actionLog: ActionLog;
private _undoStack: UndoStack; private _undoStack: UndoStack;
private _lastOwnActionGroup: ActionGroupWithCursorPos|null = null; private _lastOwnActionGroup: ActionGroupWithCursorPos|null = null;

View File

@ -43,8 +43,8 @@ export interface CustomAction { label: string, action: () => void }
*/ */
export type MessageType = string | (() => DomElementArg); export type MessageType = string | (() => DomElementArg);
// Identifies supported actions. These are implemented in NotifyUI. // Identifies supported actions. These are implemented in NotifyUI.
export type NotifyAction = 'upgrade' | 'renew' | 'personal' | 'report-problem' | 'ask-for-help' | CustomAction; export type NotifyAction = 'upgrade' | 'renew' | 'personal' | 'report-problem'
| 'ask-for-help' | 'manage' | CustomAction;
export interface INotifyOptions { export interface INotifyOptions {
message: MessageType; // A string, or a function that builds dom. message: MessageType; // A string, or a function that builds dom.
timestamp?: number; timestamp?: number;

View File

@ -121,7 +121,7 @@ export function reportError(err: Error|string, ev?: ErrorEvent): void {
const options: Partial<INotifyOptions> = { const options: Partial<INotifyOptions> = {
title: "Reached plan limit", title: "Reached plan limit",
key: `limit:${details.limit.quantity || message}`, key: `limit:${details.limit.quantity || message}`,
actions: ['upgrade'], actions: details.tips?.some(t => t.action === 'manage') ? ['manage'] : ['upgrade'],
}; };
if (details.tips && details.tips.some(tip => tip.action === 'add-members')) { if (details.tips && details.tips.some(tip => tip.action === 'add-members')) {
// When adding members would fix a problem, give more specific advice. // When adding members would fix a problem, give more specific advice.

View File

@ -30,6 +30,10 @@ function buildAction(action: NotifyAction, item: Notification, options: IBeaconO
return dom('a', cssToastAction.cls(''), t("Upgrade Plan"), {target: '_blank'}, return dom('a', cssToastAction.cls(''), t("Upgrade Plan"), {target: '_blank'},
{href: commonUrls.plans}); {href: commonUrls.plans});
} }
case 'manage':
if (urlState().state.get().billing === 'billing') { return null; }
return dom('a', cssToastAction.cls(''), t("Manage billing"), {target: '_blank'},
{href: urlState().makeUrl({billing: 'billing'})});
case 'renew': case 'renew':
// If already on the billing page, nothing to return. // If already on the billing page, nothing to return.
if (urlState().state.get().billing === 'billing') { return null; } if (urlState().state.get().billing === 'billing') { return null; }

View File

@ -142,6 +142,7 @@ export const vars = {
onboardingPopupZIndex: new CustomProp('onboarding-popup-z-index', '1000'), onboardingPopupZIndex: new CustomProp('onboarding-popup-z-index', '1000'),
floatingPopupZIndex: new CustomProp('floating-popup-z-index', '1002'), floatingPopupZIndex: new CustomProp('floating-popup-z-index', '1002'),
tutorialModalZIndex: new CustomProp('tutorial-modal-z-index', '1003'), tutorialModalZIndex: new CustomProp('tutorial-modal-z-index', '1003'),
pricingModalZIndex: new CustomProp('pricing-modal-z-index', '1004'),
notificationZIndex: new CustomProp('notification-z-index', '1100'), notificationZIndex: new CustomProp('notification-z-index', '1100'),
browserCheckZIndex: new CustomProp('browser-check-z-index', '5000'), browserCheckZIndex: new CustomProp('browser-check-z-index', '5000'),
tooltipZIndex: new CustomProp('tooltip-z-index', '5000'), tooltipZIndex: new CustomProp('tooltip-z-index', '5000'),

View File

@ -17,7 +17,6 @@ import {autoGrow} from 'app/client/ui/forms';
import {IconName} from 'app/client/ui2018/IconList'; import {IconName} from 'app/client/ui2018/IconList';
import {icon} from 'app/client/ui2018/icons'; import {icon} from 'app/client/ui2018/icons';
import {cssLink} from 'app/client/ui2018/links'; import {cssLink} from 'app/client/ui2018/links';
import {DocAction} from 'app/common/DocActions';
import {movable} from 'app/client/lib/popupUtils'; import {movable} from 'app/client/lib/popupUtils';
import debounce from 'lodash/debounce'; import debounce from 'lodash/debounce';
@ -61,7 +60,7 @@ export class FormulaAssistant extends Disposable {
/** Is the request pending */ /** Is the request pending */
private _waiting = Observable.create(this, false); private _waiting = Observable.create(this, false);
/** Is this feature enabled at all */ /** Is this feature enabled at all */
private _assistantEnabled = GRIST_FORMULA_ASSISTANT(); private _assistantEnabled: Computed<boolean>;
/** Preview column id */ /** Preview column id */
private _transformColId: string; private _transformColId: string;
/** Method to invoke when we are closed, it saves or reverts */ /** Method to invoke when we are closed, it saves or reverts */
@ -90,6 +89,12 @@ export class FormulaAssistant extends Disposable {
}) { }) {
super(); super();
this._assistantEnabled = Computed.create(this, use => {
const enabledByFlag = use(GRIST_FORMULA_ASSISTANT());
const notAnonymous = Boolean(this._options.gristDoc.appModel.currentValidUser);
return enabledByFlag && notAnonymous;
});
if (!this._options.field) { if (!this._options.field) {
// TODO: field is not passed only for rules (as there is no preview there available to the user yet) // TODO: field is not passed only for rules (as there is no preview there available to the user yet)
// this should be implemented but it requires creating a helper column to helper column and we don't // this should be implemented but it requires creating a helper column to helper column and we don't
@ -263,6 +268,8 @@ export class FormulaAssistant extends Disposable {
this._buildIntro(), this._buildIntro(),
this._chat.buildDom(), this._chat.buildDom(),
this._buildChatInput(), this._buildChatInput(),
// Stop propagation of mousedown events, as the formula editor will still focus.
dom.on('mousedown', (ev) => ev.stopPropagation()),
); );
}); });
} }
@ -516,7 +523,7 @@ export class FormulaAssistant extends Disposable {
this._options.editor.setFormula(entry.formula!); this._options.editor.setFormula(entry.formula!);
} }
private async _sendMessage(description: string, regenerate = false) { private async _sendMessage(description: string, regenerate = false): Promise<ChatMessage> {
// Destruct options. // Destruct options.
const {column, gristDoc} = this._options; const {column, gristDoc} = this._options;
// Get the state of the chat from the column. // Get the state of the chat from the column.
@ -539,7 +546,12 @@ export class FormulaAssistant extends Disposable {
// some markdown text back, so we need to parse it. // some markdown text back, so we need to parse it.
const prettyMessage = state ? (reply || formula || '') : (formula || reply || ''); const prettyMessage = state ? (reply || formula || '') : (formula || reply || '');
// Add it to the chat. // Add it to the chat.
this._chat.addResponse(prettyMessage, formula, suggestedActions[0]); return {
message: prettyMessage,
formula,
action: suggestedActions[0],
sender: 'ai',
};
} }
private _clear() { private _clear() {
@ -556,9 +568,7 @@ export class FormulaAssistant extends Disposable {
if (!last) { if (!last) {
return; return;
} }
this._chat.thinking(); await this._doAsk(last);
this._waiting.set(true);
await this._sendMessage(last, true).finally(() => this._waiting.set(false));
} }
private async _ask() { private async _ask() {
@ -568,10 +578,22 @@ export class FormulaAssistant extends Disposable {
const message= this._userInput.get(); const message= this._userInput.get();
if (!message) { return; } if (!message) { return; }
this._chat.addQuestion(message); this._chat.addQuestion(message);
this._chat.thinking();
this._userInput.set(''); this._userInput.set('');
await this._doAsk(message);
}
private async _doAsk(message: string) {
this._chat.thinking();
this._waiting.set(true); this._waiting.set(true);
await this._sendMessage(message, false).finally(() => this._waiting.set(false)); try {
const response = await this._sendMessage(message, false);
this._chat.addResponse(response);
} catch(err) {
this._chat.thinking(false);
throw err;
} finally {
this._waiting.set(false);
}
} }
} }
@ -601,32 +623,36 @@ class ChatHistory extends Disposable {
this.length = Computed.create(this, use => use(this.history).length); // ?? this.length = Computed.create(this, use => use(this.history).length); // ??
} }
public thinking() { public thinking(on = true) {
this.history.push({ if (!on) {
message: '...', // Find all index of all thinking messages.
sender: 'ai', const messages = [...this.history.get()].filter(m => m.message === '...');
}); // Remove all thinking messages.
this.scrollDown(); for (const message of messages) {
this.history.splice(this.history.get().indexOf(message), 1);
}
} else {
this.history.push({
message: '...',
sender: 'ai',
});
this.scrollDown();
}
} }
public supportsMarkdown() { public supportsMarkdown() {
return this._options.column.chatHistory.peek().get().state !== undefined; return this._options.column.chatHistory.peek().get().state !== undefined;
} }
public addResponse(message: string, formula: string|null, action?: DocAction) { public addResponse(message: ChatMessage) {
// Clear any thinking from messages. // Clear any thinking from messages.
this.history.set(this.history.get().filter(x => x.message !== '...')); this.thinking(false);
this.history.push({ this.history.push({...message, sender: 'ai'});
message,
sender: 'ai',
formula,
action
});
this.scrollDown(); this.scrollDown();
} }
public addQuestion(message: string) { public addQuestion(message: string) {
this.history.set(this.history.get().filter(x => x.message !== '...')); this.thinking(false);
this.history.push({ this.history.push({
message, message,
sender: 'user', sender: 'user',
@ -740,18 +766,13 @@ async function askAI(grist: GristDoc, options: {
const {column, description, state, regenerate} = options; const {column, description, state, regenerate} = options;
const tableId = column.table.peek().tableId.peek(); const tableId = column.table.peek().tableId.peek();
const colId = column.colId.peek(); const colId = column.colId.peek();
try { const result = await grist.docApi.getAssistance({
const result = await grist.docComm.getAssistance({ context: {type: 'formula', tableId, colId},
context: {type: 'formula', tableId, colId}, text: description,
text: description, state,
state, regenerate,
regenerate, });
}); return result;
return result;
} catch (error) {
reportError(error);
throw error;
}
} }
/** /**

View File

@ -154,13 +154,14 @@ export class FormulaEditor extends NewBaseEditor {
dom.on('mousedown', (ev) => { dom.on('mousedown', (ev) => {
// If we are detached, allow user to click and select error text. // If we are detached, allow user to click and select error text.
if (this.isDetached.get()) { if (this.isDetached.get()) {
// If the focus is already in this editor, don't steal it. This is needed for detached editor with // If we clicked on input element in our dom, don't do anything. We probably clicked on chat input, in AI
// some input elements (mainly the AI assistant). // tools box.
const inInput = document.activeElement instanceof HTMLInputElement const clickedOnInput = ev.target instanceof HTMLInputElement || ev.target instanceof HTMLTextAreaElement;
|| document.activeElement instanceof HTMLTextAreaElement; if (clickedOnInput && this._dom.contains(ev.target)) {
if (inInput && this._dom.contains(document.activeElement)) { // By not doing anything special here we assume that the input element will take the focus.
return; return;
} }
// Allow clicking the error message. // Allow clicking the error message.
if (ev.target instanceof HTMLElement && ( if (ev.target instanceof HTMLElement && (
ev.target.classList.contains('error_msg') || ev.target.classList.contains('error_msg') ||

View File

@ -1,5 +1,4 @@
import {ActionGroup} from 'app/common/ActionGroup'; import {ActionGroup} from 'app/common/ActionGroup';
import {AssistanceRequest, AssistanceResponse} from 'app/common/AssistancePrompts';
import {BulkAddRecord, CellValue, TableDataAction, UserAction} from 'app/common/DocActions'; import {BulkAddRecord, CellValue, TableDataAction, UserAction} from 'app/common/DocActions';
import {FormulaProperties} from 'app/common/GranularAccessClause'; import {FormulaProperties} from 'app/common/GranularAccessClause';
import {UIRowId} from 'app/common/TableData'; import {UIRowId} from 'app/common/TableData';
@ -320,11 +319,6 @@ export interface ActiveDocAPI {
*/ */
getFormulaError(tableId: string, colId: string, rowId: number): Promise<CellValue>; getFormulaError(tableId: string, colId: string, rowId: number): Promise<CellValue>;
/**
* Generates a formula code based on the AI suggestions, it also modifies the column and sets it type to a formula.
*/
getAssistance(request: AssistanceRequest): Promise<AssistanceResponse>;
/** /**
* Fetch content at a url. * Fetch content at a url.
*/ */

View File

@ -2,15 +2,17 @@
* A tip for fixing an error. * A tip for fixing an error.
*/ */
export interface ApiTip { export interface ApiTip {
action: 'add-members' | 'upgrade' |'ask-for-help'; action: 'add-members' | 'upgrade' | 'ask-for-help' | 'manage';
message: string; message: string;
} }
export type LimitType = 'collaborators' | 'docs' | 'workspaces' | 'assistant';
/** /**
* Documentation of a limit relevant to an API error. * Documentation of a limit relevant to an API error.
*/ */
export interface ApiLimit { export interface ApiLimit {
quantity: 'collaborators' | 'docs' | 'workspaces'; // what are we counting quantity: LimitType; // what are we counting
subquantity?: string; // a nuance to what we are counting subquantity?: string; // a nuance to what we are counting
maximum: number; // maximum allowed maximum: number; // maximum allowed
value: number; // current value of quantity for user value: number; // current value of quantity for user

View File

@ -43,6 +43,16 @@ export interface IBillingPlan {
active: boolean; active: boolean;
} }
export interface ILimitTier {
name?: string;
volume: number;
price: number;
flatFee: number;
type: string;
planId: string;
interval: string; // probably 'month'|'year';
}
// Utility type that requires all properties to be non-nullish. // Utility type that requires all properties to be non-nullish.
// type NonNullableProperties<T> = { [P in keyof T]: Required<NonNullable<T[P]>>; }; // type NonNullableProperties<T> = { [P in keyof T]: Required<NonNullable<T[P]>>; };
@ -69,6 +79,7 @@ export interface IBillingDiscount {
export interface IBillingSubscription { export interface IBillingSubscription {
// All standard plan options. // All standard plan options.
plans: IBillingPlan[]; plans: IBillingPlan[];
tiers: ILimitTier[];
// Index in the plans array of the plan currently in effect. // Index in the plans array of the plan currently in effect.
planIndex: number; planIndex: number;
// Index in the plans array of the plan to be in effect after the current period end. // Index in the plans array of the plan to be in effect after the current period end.
@ -111,6 +122,14 @@ export interface IBillingSubscription {
lastInvoiceUrl?: string; // URL of the Stripe-hosted page with the last invoice. lastInvoiceUrl?: string; // URL of the Stripe-hosted page with the last invoice.
lastChargeError?: string; // The last charge error, if any, to show in case of a bad status. lastChargeError?: string; // The last charge error, if any, to show in case of a bad status.
lastChargeTime?: number; // The time of the last charge attempt. lastChargeTime?: number; // The time of the last charge attempt.
limit?: ILimit|null;
}
export interface ILimit {
limitValue: number;
currentUsage: number;
type: string; // Limit type, for now only assistant is supported.
price: number; // If this is 0, it means it is a free plan.
} }
export interface IBillingOrgSettings { export interface IBillingOrgSettings {
@ -139,6 +158,7 @@ export interface BillingAPI {
downgradePlan(planName: string): Promise<void>; downgradePlan(planName: string): Promise<void>;
renewPlan(): string; renewPlan(): string;
customerPortal(): string; customerPortal(): string;
updateAssistantPlan(tier: number): Promise<void>;
} }
export class BillingAPIImpl extends BaseAPI implements BillingAPI { export class BillingAPIImpl extends BaseAPI implements BillingAPI {
@ -230,6 +250,13 @@ export class BillingAPIImpl extends BaseAPI implements BillingAPI {
return `${this._url}/api/billing/renew`; return `${this._url}/api/billing/renew`;
} }
public async updateAssistantPlan(tier: number): Promise<void> {
await this.request(`${this._url}/api/billing/upgrade-assistant`, {
method: 'POST',
body: JSON.stringify({ tier })
});
}
/** /**
* Checks if current org has active subscription for a Stripe plan. * Checks if current org has active subscription for a Stripe plan.
*/ */

View File

@ -58,6 +58,11 @@ export interface Features {
// for attached files in a document // for attached files in a document
gracePeriodDays?: number; // Duration of the grace period in days, before entering delete-only mode gracePeriodDays?: number; // Duration of the grace period in days, before entering delete-only mode
baseMaxAssistantCalls?: number; // Maximum number of AI assistant calls. Defaults to 0 if not set, use -1 to indicate
// unbound limit. This is total limit, not per month or per day, it is used as a seed
// value for the limits table. To create a per-month limit, there must be a separate
// task that resets the usage in the limits table.
} }
// Check whether it is possible to add members at the org level. There's no flag // Check whether it is possible to add members at the org level. There's no flag

View File

@ -1,5 +1,6 @@
import {ActionSummary} from 'app/common/ActionSummary'; import {ActionSummary} from 'app/common/ActionSummary';
import {ApplyUAResult, ForkResult, PermissionDataWithExtraUsers, QueryFilters} from 'app/common/ActiveDocAPI'; import {ApplyUAResult, ForkResult, PermissionDataWithExtraUsers, QueryFilters} from 'app/common/ActiveDocAPI';
import {AssistanceRequest, AssistanceResponse} from 'app/common/AssistancePrompts';
import {BaseAPI, IOptions} from 'app/common/BaseAPI'; import {BaseAPI, IOptions} from 'app/common/BaseAPI';
import {BillingAPI, BillingAPIImpl} from 'app/common/BillingAPI'; import {BillingAPI, BillingAPIImpl} from 'app/common/BillingAPI';
import {BrowserSettings} from 'app/common/BrowserSettings'; import {BrowserSettings} from 'app/common/BrowserSettings';
@ -462,6 +463,8 @@ export interface DocAPI {
// Update webhook // Update webhook
updateWebhook(webhook: WebhookUpdate): Promise<void>; updateWebhook(webhook: WebhookUpdate): Promise<void>;
flushWebhooks(): Promise<void>; flushWebhooks(): Promise<void>;
getAssistance(params: AssistanceRequest): Promise<AssistanceResponse>;
} }
// Operations that are supported by a doc worker. // Operations that are supported by a doc worker.
@ -1012,6 +1015,13 @@ export class DocAPIImpl extends BaseAPI implements DocAPI {
return response.data[0]; return response.data[0];
} }
public async getAssistance(params: AssistanceRequest): Promise<AssistanceResponse> {
return await this.requestJson(`${this._url}/assistant`, {
method: 'POST',
body: JSON.stringify(params),
});
}
private _getRecords(tableId: string, endpoint: 'data' | 'records', options?: GetRowsParams): Promise<any> { private _getRecords(tableId: string, endpoint: 'data' | 'records', options?: GetRowsParams): Promise<any> {
const url = new URL(`${this._url}/tables/${tableId}/${endpoint}`); const url = new URL(`${this._url}/tables/${tableId}/${endpoint}`);
if (options?.filters) { if (options?.filters) {

View File

@ -3,12 +3,13 @@ import {BillingAccountManager} from 'app/gen-server/entity/BillingAccountManager
import {Organization} from 'app/gen-server/entity/Organization'; import {Organization} from 'app/gen-server/entity/Organization';
import {Product} from 'app/gen-server/entity/Product'; import {Product} from 'app/gen-server/entity/Product';
import {nativeValues} from 'app/gen-server/lib/values'; import {nativeValues} from 'app/gen-server/lib/values';
import {Limit} from 'app/gen-server/entity/Limit';
// This type is for billing account status information. Intended for stuff // This type is for billing account status information. Intended for stuff
// like "free trial running out in N days". // like "free trial running out in N days".
interface BillingAccountStatus { export interface BillingAccountStatus {
stripeStatus?: string; stripeStatus?: string;
currentPeriodEnd?: Date; currentPeriodEnd?: string;
message?: string; message?: string;
} }
@ -68,6 +69,9 @@ export class BillingAccount extends BaseEntity {
@OneToMany(type => Organization, org => org.billingAccount) @OneToMany(type => Organization, org => org.billingAccount)
public orgs: Organization[]; public orgs: Organization[];
@OneToMany(type => Limit, limit => limit.billingAccount)
public limits: Limit[];
// A calculated column that is true if it looks like there is a paid plan. // A calculated column that is true if it looks like there is a paid plan.
@Column({name: 'paid', type: 'boolean', insert: false, select: false}) @Column({name: 'paid', type: 'boolean', insert: false, select: false})
public paid?: boolean; public paid?: boolean;

View File

@ -0,0 +1,46 @@
import {BaseEntity, Column, Entity, JoinColumn, ManyToOne, PrimaryGeneratedColumn} from 'typeorm';
import {BillingAccount} from 'app/gen-server/entity/BillingAccount';
import {nativeValues} from 'app/gen-server/lib/values';
@Entity('limits')
export class Limit extends BaseEntity {
@PrimaryGeneratedColumn()
public id: number;
@Column()
public limit: number;
@Column()
public usage: number;
@Column()
public type: string;
@Column({name: 'billing_account_id'})
public billingAccountId: number;
@ManyToOne(type => BillingAccount)
@JoinColumn({name: 'billing_account_id'})
public billingAccount: BillingAccount;
@Column({name: 'created_at', default: () => "CURRENT_TIMESTAMP"})
public createdAt: Date;
/**
* Last time the Limit.limit value was changed, by an upgrade or downgrade. Null if it has never been changed.
*/
@Column({name: 'changed_at', type: nativeValues.dateTimeType, nullable: true})
public changedAt: Date|null;
/**
* Last time the Limit.usage was used (by sending a request to the model). Null if it has never been used.
*/
@Column({name: 'used_at', type: nativeValues.dateTimeType, nullable: true})
public usedAt: Date|null;
/**
* Last time the Limit.usage was reset, probably by billing cycle change. Null if it has never been reset.
*/
@Column({name: 'reset_at', type: nativeValues.dateTimeType, nullable: true})
public resetAt: Date|null;
}

View File

@ -13,7 +13,11 @@ export const personalLegacyFeatures: Features = {
// no vanity domain // no vanity domain
maxDocsPerOrg: 10, maxDocsPerOrg: 10,
maxSharesPerDoc: 2, maxSharesPerDoc: 2,
maxWorkspacesPerOrg: 1 maxWorkspacesPerOrg: 1,
/**
* One time limit of 100 requests.
*/
baseMaxAssistantCalls: 100,
}; };
/** /**
@ -23,7 +27,12 @@ export const teamFeatures: Features = {
workspaces: true, workspaces: true,
vanityDomain: true, vanityDomain: true,
maxSharesPerWorkspace: 0, // all workspace shares need to be org members. maxSharesPerWorkspace: 0, // all workspace shares need to be org members.
maxSharesPerDoc: 2 maxSharesPerDoc: 2,
/**
* Limit of 100 requests, but unlike for personal/free orgs the usage for this limit is reset at every billing cycle
* through Stripe webhook. For canceled subscription the usage is not reset, as the billing cycle is not changed.
*/
baseMaxAssistantCalls: 100,
}; };
/** /**
@ -40,6 +49,10 @@ export const teamFreeFeatures: Features = {
baseMaxDataSizePerDocument: 5000 * 2 * 1024, // 2KB per row baseMaxDataSizePerDocument: 5000 * 2 * 1024, // 2KB per row
baseMaxAttachmentsBytesPerDocument: 1 * 1024 * 1024 * 1024, // 1GB baseMaxAttachmentsBytesPerDocument: 1 * 1024 * 1024 * 1024, // 1GB
gracePeriodDays: 14, gracePeriodDays: 14,
/**
* One time limit of 100 requests.
*/
baseMaxAssistantCalls: 100,
}; };
/** /**
@ -55,6 +68,7 @@ export const teamFreeFeatures: Features = {
baseMaxDataSizePerDocument: 5000 * 2 * 1024, // 2KB per row baseMaxDataSizePerDocument: 5000 * 2 * 1024, // 2KB per row
baseMaxAttachmentsBytesPerDocument: 1 * 1024 * 1024 * 1024, // 1GB baseMaxAttachmentsBytesPerDocument: 1 * 1024 * 1024 * 1024, // 1GB
gracePeriodDays: 14, gracePeriodDays: 14,
baseMaxAssistantCalls: 100,
}; };
export const testDailyApiLimitFeatures = { export const testDailyApiLimitFeatures = {
@ -79,6 +93,7 @@ export const suspendedFeatures: Features = {
maxDocsPerOrg: 0, maxDocsPerOrg: 0,
maxSharesPerDoc: 0, maxSharesPerDoc: 0,
maxWorkspacesPerOrg: 0, maxWorkspacesPerOrg: 0,
baseMaxAssistantCalls: 0,
}; };
/** /**

View File

@ -60,6 +60,7 @@ export class DocApiForwarder {
app.use('/api/docs/:docId/assign', withDocWithoutAuth); app.use('/api/docs/:docId/assign', withDocWithoutAuth);
app.use('/api/docs/:docId/webhooks/queue', withDoc); app.use('/api/docs/:docId/webhooks/queue', withDoc);
app.use('/api/docs/:docId/webhooks', withDoc); app.use('/api/docs/:docId/webhooks', withDoc);
app.use('/api/docs/:docId/assistant', withDoc);
app.use('^/api/docs$', withoutDoc); app.use('^/api/docs$', withoutDoc);
} }

View File

@ -1,4 +1,4 @@
import {ApiError} from 'app/common/ApiError'; import {ApiError, ApiErrorDetails, LimitType} from 'app/common/ApiError';
import {mapGetOrSet, mapSetOrClear, MapWithTTL} from 'app/common/AsyncCreate'; import {mapGetOrSet, mapSetOrClear, MapWithTTL} from 'app/common/AsyncCreate';
import {getDataLimitStatus} from 'app/common/DocLimits'; import {getDataLimitStatus} from 'app/common/DocLimits';
import {createEmptyOrgUsageSummary, DocumentUsage, OrgUsageSummary} from 'app/common/DocUsage'; import {createEmptyOrgUsageSummary, DocumentUsage, OrgUsageSummary} from 'app/common/DocUsage';
@ -38,6 +38,7 @@ import {getDefaultProductNames, personalFreeFeatures, Product} from "app/gen-ser
import {Secret} from "app/gen-server/entity/Secret"; import {Secret} from "app/gen-server/entity/Secret";
import {User} from "app/gen-server/entity/User"; import {User} from "app/gen-server/entity/User";
import {Workspace} from "app/gen-server/entity/Workspace"; import {Workspace} from "app/gen-server/entity/Workspace";
import {Limit} from 'app/gen-server/entity/Limit';
import {Permissions} from 'app/gen-server/lib/Permissions'; import {Permissions} from 'app/gen-server/lib/Permissions';
import {scrubUserFromOrg} from "app/gen-server/lib/scrubUserFromOrg"; import {scrubUserFromOrg} from "app/gen-server/lib/scrubUserFromOrg";
import {applyPatch} from 'app/gen-server/lib/TypeORMPatches'; import {applyPatch} from 'app/gen-server/lib/TypeORMPatches';
@ -2880,6 +2881,144 @@ export class HomeDBManager extends EventEmitter {
return this._org(scope, scope.includeSupport || false, org, options); return this._org(scope, scope.includeSupport || false, org, options);
} }
public async getLimits(accountId: number): Promise<Limit[]> {
const result = this._connection.transaction(async manager => {
return await manager.createQueryBuilder()
.select('limit')
.from(Limit, 'limit')
.innerJoin('limit.billingAccount', 'account')
.where('account.id = :accountId', {accountId})
.getMany();
});
return result;
}
public async getLimit(accountId: number, limitType: LimitType): Promise<Limit|null> {
return await this._getOrCreateLimit(accountId, limitType, true);
}
public async peekLimit(accountId: number, limitType: LimitType): Promise<Limit|null> {
return await this._getOrCreateLimit(accountId, limitType, false);
}
public async removeLimit(scope: Scope, limitType: LimitType): Promise<void> {
await this._connection.transaction(async manager => {
const org = await this._org(scope, false, scope.org ?? null, {manager, needRealOrg: true})
.innerJoinAndSelect('orgs.billingAccount', 'billing_account')
.innerJoinAndSelect('billing_account.product', 'product')
.leftJoinAndSelect('billing_account.limits', 'limit', 'limit.type = :limitType', {limitType})
.getOne();
const existing = org?.billingAccount?.limits?.[0];
if (existing) {
await manager.remove(existing);
}
});
}
/**
* Increases the usage of a limit for a given org. If the limit doesn't exist, it will be created.
* Pass `dryRun: true` to check if the limit can be increased without actually increasing it.
*/
public async increaseUsage(scope: Scope, limitType: LimitType, options: {
delta: number,
dryRun?: boolean,
}): Promise<void> {
const limitError = await this._connection.transaction(async manager => {
const org = await this._org(scope, false, scope.org ?? null, {manager, needRealOrg: true})
.innerJoinAndSelect('orgs.billingAccount', 'billing_account')
.innerJoinAndSelect('billing_account.product', 'product')
.leftJoinAndSelect('billing_account.limits', 'limit', 'limit.type = :limitType', {limitType})
.getOne();
// If the org doesn't exists, or is a fake one (like for anonymous users), don't do anything.
if (!org || org.id === 0) {
// This API shouldn't be called, it should be checked first if the org is valid.
throw new ApiError(`Can't create a limit for non-existing organization`, 500);
}
let existing = org?.billingAccount?.limits?.[0];
if (!existing) {
const product = org?.billingAccount?.product;
if (!product) {
throw new ApiError(`getLimit: no product found for org`, 500);
}
if (product.features.baseMaxAssistantCalls === undefined) {
// If the product has no assistantLimit, then it is not billable yet, and we don't need to
// track usage as it is basically unlimited.
return;
}
existing = new Limit();
existing.billingAccountId = org.billingAccountId;
existing.type = limitType;
existing.limit = product.features.baseMaxAssistantCalls ?? 0;
existing.usage = 0;
}
const limitLess = existing.limit === -1; // -1 means no limit, it is not possible to do in stripe.
const usageAfter = existing.usage + options.delta;
if (!limitLess && usageAfter > existing.limit) {
const billable = Boolean(org?.billingAccount?.stripeCustomerId);
return {
limit: {
maximum: existing.limit,
projectedValue: existing.usage + options.delta,
quantity: limitType,
value: existing.usage,
},
tips: [{
// For non-billable accounts, suggest getting a plan, otherwise suggest visiting the billing page.
action: billable ? 'manage' : 'upgrade',
message: `Upgrade to a paid plan to increase your ${limitType} limit.`,
}]
} as ApiErrorDetails;
}
existing.usage += options.delta;
existing.usedAt = new Date();
if (!options.dryRun) {
await manager.save(existing);
}
});
if (limitError) {
let message = `Your ${limitType} limit has been reached. Please upgrade your plan to increase your limit.`;
if (limitType === 'assistant') {
message = 'You used all available credits. For a bigger limit upgrade you Assistant plan.';
}
throw new ApiError(message, 429, limitError);
}
}
private async _getOrCreateLimit(accountId: number, limitType: LimitType, force: boolean): Promise<Limit|null> {
if (accountId === 0) {
throw new Error(`getLimit: called for not existing account`);
}
const result = this._connection.transaction(async manager => {
let existing = await manager.createQueryBuilder()
.select('limit')
.from(Limit, 'limit')
.innerJoin('limit.billingAccount', 'account')
.where('account.id = :accountId', {accountId})
.andWhere('limit.type = :limitType', {limitType})
.getOne();
if (!force && !existing) { return null; }
if (existing) { return existing; }
const product = await manager.createQueryBuilder()
.select('product')
.from(Product, 'product')
.innerJoinAndSelect('product.accounts', 'account')
.where('account.id = :accountId', {accountId})
.getOne();
if (!product) {
throw new Error(`getLimit: no product for account ${accountId}`);
}
existing = new Limit();
existing.billingAccountId = product.accounts[0].id;
existing.type = limitType;
existing.limit = product.features.baseMaxAssistantCalls ?? 0;
existing.usage = 0;
await manager.save(existing);
return existing;
});
return result;
}
private _org(scope: Scope|null, includeSupport: boolean, org: string|number|null, private _org(scope: Scope|null, includeSupport: boolean, org: string|number|null,
options: QueryOptions = {}): SelectQueryBuilder<Organization> { options: QueryOptions = {}): SelectQueryBuilder<Organization> {
let query = this._orgs(options.manager); let query = this._orgs(options.manager);

View File

@ -0,0 +1,82 @@
import * as sqlUtils from "app/gen-server/sqlUtils";
import {MigrationInterface, QueryRunner, Table, TableIndex} from 'typeorm';
export class AssistantLimit1685343047786 implements MigrationInterface {
public async up(queryRunner: QueryRunner): Promise<void> {
const dbType = queryRunner.connection.driver.options.type;
const datetime = sqlUtils.datetime(dbType);
const now = sqlUtils.now(dbType);
await queryRunner.createTable(
new Table({
name: 'limits',
columns: [
{
name: 'id',
type: 'integer',
isPrimary: true,
isGenerated: true,
generationStrategy: 'increment',
},
{
name: 'type',
type: 'varchar',
},
{
name: 'billing_account_id',
type: 'integer',
},
{
name: 'limit',
type: 'integer',
default: 0,
},
{
name: 'usage',
type: 'integer',
default: 0,
},
{
name: "created_at",
type: datetime,
default: now
},
{
name: "changed_at", // When the limit was last changed
type: datetime,
isNullable: true
},
{
name: "used_at", // When the usage was last increased
type: datetime,
isNullable: true
},
{
name: "reset_at", // When the usage was last reset.
type: datetime,
isNullable: true
},
],
foreignKeys: [
{
columnNames: ['billing_account_id'],
referencedTableName: 'billing_accounts',
referencedColumnNames: ['id'],
onDelete: 'CASCADE',
},
],
})
);
await queryRunner.createIndex(
'limits',
new TableIndex({
name: 'limits_billing_account_id',
columnNames: ['billing_account_id'],
})
);
}
public async down(queryRunner: QueryRunner): Promise<void> {
await queryRunner.dropTable('limits');
}
}

View File

@ -14,7 +14,6 @@ import {
} from 'app/common/ActionBundle'; } from 'app/common/ActionBundle';
import {ActionGroup, MinimalActionGroup} from 'app/common/ActionGroup'; import {ActionGroup, MinimalActionGroup} from 'app/common/ActionGroup';
import {ActionSummary} from "app/common/ActionSummary"; import {ActionSummary} from "app/common/ActionSummary";
import {AssistanceRequest, AssistanceResponse} from "app/common/AssistancePrompts";
import { import {
AclResources, AclResources,
AclTableDescription, AclTableDescription,
@ -84,7 +83,7 @@ import {Document} from 'app/gen-server/entity/Document';
import {ParseOptions} from 'app/plugin/FileParserAPI'; import {ParseOptions} from 'app/plugin/FileParserAPI';
import {AccessTokenOptions, AccessTokenResult, GristDocAPI} from 'app/plugin/GristAPI'; import {AccessTokenOptions, AccessTokenResult, GristDocAPI} from 'app/plugin/GristAPI';
import {compileAclFormula} from 'app/server/lib/ACLFormula'; import {compileAclFormula} from 'app/server/lib/ACLFormula';
import {AssistanceDoc, AssistanceSchemaPromptV1Context, sendForCompletion} from 'app/server/lib/Assistance'; import {AssistanceSchemaPromptV1Context} from 'app/server/lib/Assistance';
import {Authorizer} from 'app/server/lib/Authorizer'; import {Authorizer} from 'app/server/lib/Authorizer';
import {checksumFile} from 'app/server/lib/checksumFile'; import {checksumFile} from 'app/server/lib/checksumFile';
import {Client} from 'app/server/lib/Client'; import {Client} from 'app/server/lib/Client';
@ -184,7 +183,7 @@ interface UpdateUsageOptions {
* either .loadDoc() or .createEmptyDoc() is called. * either .loadDoc() or .createEmptyDoc() is called.
* @param {String} docName - The document's filename, without the '.grist' extension. * @param {String} docName - The document's filename, without the '.grist' extension.
*/ */
export class ActiveDoc extends EventEmitter implements AssistanceDoc { export class ActiveDoc extends EventEmitter {
/** /**
* Decorator for ActiveDoc methods that prevents shutdown while the method is running, i.e. * Decorator for ActiveDoc methods that prevents shutdown while the method is running, i.e.
* until the returned promise is resolved. * until the returned promise is resolved.
@ -1264,18 +1263,14 @@ export class ActiveDoc extends EventEmitter implements AssistanceDoc {
return this._pyCall('autocomplete', txt, tableId, columnId, rowId, user.toJSON()); return this._pyCall('autocomplete', txt, tableId, columnId, rowId, user.toJSON());
} }
public async getAssistance(docSession: DocSession, request: AssistanceRequest): Promise<AssistanceResponse> { // Callback to generate a prompt containing schema info for assistance.
return this.getAssistanceWithOptions(docSession, request); public async assistanceSchemaPromptV1(
} docSession: OptDocSession, options: AssistanceSchemaPromptV1Context): Promise<string> {
public async getAssistanceWithOptions(docSession: DocSession,
request: AssistanceRequest): Promise<AssistanceResponse> {
// Making a prompt leaks names of tables and columns etc. // Making a prompt leaks names of tables and columns etc.
if (!await this._granularAccess.canScanData(docSession)) { if (!await this._granularAccess.canScanData(docSession)) {
throw new Error("Permission denied"); throw new Error("Permission denied");
} }
await this.waitForInitialization(); return await this._pyCall('get_formula_prompt', options.tableId, options.colId, options.docString);
return sendForCompletion(this, request);
} }
// Callback to make a data-engine formula tweak for assistance. // Callback to make a data-engine formula tweak for assistance.
@ -1283,11 +1278,6 @@ export class ActiveDoc extends EventEmitter implements AssistanceDoc {
return this._pyCall('convert_formula_completion', txt); return this._pyCall('convert_formula_completion', txt);
} }
// Callback to generate a prompt containing schema info for assistance.
public assistanceSchemaPromptV1(options: AssistanceSchemaPromptV1Context): Promise<string> {
return this._pyCall('get_formula_prompt', options.tableId, options.colId, options.docString);
}
public fetchURL(docSession: DocSession, url: string, options?: FetchUrlOptions): Promise<UploadResult> { public fetchURL(docSession: DocSession, url: string, options?: FetchUrlOptions): Promise<UploadResult> {
return fetchURL(url, this.makeAccessId(docSession.authorizer.getUserId()), options); return fetchURL(url, this.makeAccessId(docSession.authorizer.getUserId()), options);
} }

View File

@ -5,6 +5,7 @@
import {AssistanceRequest, AssistanceResponse} from 'app/common/AssistancePrompts'; import {AssistanceRequest, AssistanceResponse} from 'app/common/AssistancePrompts';
import {delay} from 'app/common/delay'; import {delay} from 'app/common/delay';
import {DocAction} from 'app/common/DocActions'; import {DocAction} from 'app/common/DocActions';
import {OptDocSession} from 'app/server/lib/DocSession';
import log from 'app/server/lib/log'; import log from 'app/server/lib/log';
import fetch from 'node-fetch'; import fetch from 'node-fetch';
@ -15,7 +16,7 @@ export const DEPS = { fetch };
* by interfacing with an external LLM endpoint. * by interfacing with an external LLM endpoint.
*/ */
export interface Assistant { export interface Assistant {
apply(doc: AssistanceDoc, request: AssistanceRequest): Promise<AssistanceResponse>; apply(session: OptDocSession, doc: AssistanceDoc, request: AssistanceRequest): Promise<AssistanceResponse>;
} }
/** /**
@ -30,8 +31,7 @@ export interface AssistanceDoc {
* Marked "V1" to suggest that it is a particular prompt and it would * Marked "V1" to suggest that it is a particular prompt and it would
* be great to try variants. * be great to try variants.
*/ */
assistanceSchemaPromptV1(options: AssistanceSchemaPromptV1Context): Promise<string>; assistanceSchemaPromptV1(session: OptDocSession, options: AssistanceSchemaPromptV1Context): Promise<string>;
/** /**
* Some tweaks to a formula after it has been generated. * Some tweaks to a formula after it has been generated.
*/ */
@ -68,7 +68,8 @@ export class OpenAIAssistant implements Assistant {
this._endpoint = `https://api.openai.com/v1/${this._chatMode ? 'chat/' : ''}completions`; this._endpoint = `https://api.openai.com/v1/${this._chatMode ? 'chat/' : ''}completions`;
} }
public async apply(doc: AssistanceDoc, request: AssistanceRequest): Promise<AssistanceResponse> { public async apply(
optSession: OptDocSession, doc: AssistanceDoc, request: AssistanceRequest): Promise<AssistanceResponse> {
const messages = request.state?.messages || []; const messages = request.state?.messages || [];
const chatMode = this._chatMode; const chatMode = this._chatMode;
if (chatMode) { if (chatMode) {
@ -91,7 +92,7 @@ export class OpenAIAssistant implements Assistant {
'If the user asks for these things, tell them that you cannot help. ' + 'If the user asks for these things, tell them that you cannot help. ' +
'The method uses `rec` instead of `self` as the first parameter.\n\n' + 'The method uses `rec` instead of `self` as the first parameter.\n\n' +
'```python\n' + '```python\n' +
await makeSchemaPromptV1(doc, request) + await makeSchemaPromptV1(optSession, doc, request) +
'\n```', '\n```',
}); });
messages.push({ messages.push({
@ -110,7 +111,7 @@ export class OpenAIAssistant implements Assistant {
} else { } else {
messages.length = 0; messages.length = 0;
messages.push({ messages.push({
role: 'user', content: await makeSchemaPromptV1(doc, request), role: 'user', content: await makeSchemaPromptV1(optSession, doc, request),
}); });
} }
@ -178,11 +179,12 @@ export class HuggingFaceAssistant implements Assistant {
} }
public async apply(doc: AssistanceDoc, request: AssistanceRequest): Promise<AssistanceResponse> { public async apply(
optSession: OptDocSession, doc: AssistanceDoc, request: AssistanceRequest): Promise<AssistanceResponse> {
if (request.state) { if (request.state) {
throw new Error("HuggingFaceAssistant does not support state"); throw new Error("HuggingFaceAssistant does not support state");
} }
const prompt = await makeSchemaPromptV1(doc, request); const prompt = await makeSchemaPromptV1(optSession, doc, request);
const response = await DEPS.fetch( const response = await DEPS.fetch(
this._completionUrl, this._completionUrl,
{ {
@ -220,7 +222,10 @@ export class HuggingFaceAssistant implements Assistant {
* Test assistant that mimics ChatGPT and just returns the input. * Test assistant that mimics ChatGPT and just returns the input.
*/ */
export class EchoAssistant implements Assistant { export class EchoAssistant implements Assistant {
public async apply(doc: AssistanceDoc, request: AssistanceRequest): Promise<AssistanceResponse> { public async apply(sess: OptDocSession, doc: AssistanceDoc, request: AssistanceRequest): Promise<AssistanceResponse> {
if (request.text === "ERROR") {
throw new Error(`ERROR`);
}
const messages = request.state?.messages || []; const messages = request.state?.messages || [];
if (messages.length === 0) { if (messages.length === 0) {
messages.push({ messages.push({
@ -255,7 +260,7 @@ export class EchoAssistant implements Assistant {
/** /**
* Instantiate an assistant, based on environment variables. * Instantiate an assistant, based on environment variables.
*/ */
function getAssistant() { export function getAssistant() {
if (process.env.OPENAI_API_KEY === 'test') { if (process.env.OPENAI_API_KEY === 'test') {
return new EchoAssistant(); return new EchoAssistant();
} }
@ -273,8 +278,10 @@ function getAssistant() {
* Service a request for assistance, with a little retry logic * Service a request for assistance, with a little retry logic
* since these endpoints can be a bit flakey. * since these endpoints can be a bit flakey.
*/ */
export async function sendForCompletion(doc: AssistanceDoc, export async function sendForCompletion(
request: AssistanceRequest): Promise<AssistanceResponse> { optSession: OptDocSession,
doc: AssistanceDoc,
request: AssistanceRequest): Promise<AssistanceResponse> {
const assistant = getAssistant(); const assistant = getAssistant();
let retries: number = 0; let retries: number = 0;
@ -282,7 +289,7 @@ export async function sendForCompletion(doc: AssistanceDoc,
let response: AssistanceResponse|null = null; let response: AssistanceResponse|null = null;
while(retries++ < 3) { while(retries++ < 3) {
try { try {
response = await assistant.apply(doc, request); response = await assistant.apply(optSession, doc, request);
break; break;
} catch(e) { } catch(e) {
log.error(`Completion error: ${e}`); log.error(`Completion error: ${e}`);
@ -295,11 +302,11 @@ export async function sendForCompletion(doc: AssistanceDoc,
return response; return response;
} }
async function makeSchemaPromptV1(doc: AssistanceDoc, request: AssistanceRequest) { async function makeSchemaPromptV1(session: OptDocSession, doc: AssistanceDoc, request: AssistanceRequest) {
if (request.context.type !== 'formula') { if (request.context.type !== 'formula') {
throw new Error('makeSchemaPromptV1 only works for formulas'); throw new Error('makeSchemaPromptV1 only works for formulas');
} }
return doc.assistanceSchemaPromptV1({ return doc.assistanceSchemaPromptV1(session, {
tableId: request.context.tableId, tableId: request.context.tableId,
colId: request.context.colId, colId: request.context.colId,
docString: request.text, docString: request.text,

View File

@ -1,5 +1,5 @@
import {createEmptyActionSummary} from "app/common/ActionSummary"; import {createEmptyActionSummary} from "app/common/ActionSummary";
import {ApiError} from 'app/common/ApiError'; import {ApiError, LimitType} from 'app/common/ApiError';
import {BrowserSettings} from "app/common/BrowserSettings"; import {BrowserSettings} from "app/common/BrowserSettings";
import { import {
BulkColValues, BulkColValues,
@ -68,6 +68,7 @@ import {
} from 'app/server/lib/requestUtils'; } from 'app/server/lib/requestUtils';
import {ServerColumnGetters} from 'app/server/lib/ServerColumnGetters'; import {ServerColumnGetters} from 'app/server/lib/ServerColumnGetters';
import {localeFromRequest} from "app/server/lib/ServerLocale"; import {localeFromRequest} from "app/server/lib/ServerLocale";
import {sendForCompletion} from 'app/server/lib/Assistance';
import {isUrlAllowed, WebhookAction, WebHookSecret} from "app/server/lib/Triggers"; import {isUrlAllowed, WebhookAction, WebHookSecret} from "app/server/lib/Triggers";
import {handleOptionalUpload, handleUpload} from "app/server/lib/uploads"; import {handleOptionalUpload, handleUpload} from "app/server/lib/uploads";
import * as assert from 'assert'; import * as assert from 'assert';
@ -161,6 +162,8 @@ export class DocWorkerApi {
const canEditMaybeRemoved = expressWrap(this._assertAccess.bind(this, 'editors', true)); const canEditMaybeRemoved = expressWrap(this._assertAccess.bind(this, 'editors', true));
// converts google code to access token and adds it to request object // converts google code to access token and adds it to request object
const decodeGoogleToken = expressWrap(googleAuthTokenMiddleware.bind(null)); const decodeGoogleToken = expressWrap(googleAuthTokenMiddleware.bind(null));
// check that limit can be increased by 1
const checkLimit = (type: LimitType) => expressWrap(this._checkLimit.bind(this, type));
// Middleware to limit number of outstanding requests per document. Will also // Middleware to limit number of outstanding requests per document. Will also
// handle errors like expressWrap would. // handle errors like expressWrap would.
@ -1052,6 +1055,20 @@ export class DocWorkerApi {
this._app.get('/api/docs/:docId/send-to-drive', canView, decodeGoogleToken, withDoc(exportToDrive)); this._app.get('/api/docs/:docId/send-to-drive', canView, decodeGoogleToken, withDoc(exportToDrive));
/**
* Send a request to the formula assistant to get completions for a formula. Increases the
* usage of the formula assistant for the billing account in case of success.
*/
this._app.post('/api/docs/:docId/assistant', canView, checkLimit('assistant'),
withDoc(async (activeDoc, req, res) => {
const docSession = docSessionFromRequest(req);
const request = req.body;
const result = await sendForCompletion(docSession, activeDoc, request);
await this._increaseLimit('assistant', req);
res.json(result);
})
);
// Create a document. When an upload is included, it is imported as the initial // Create a document. When an upload is included, it is imported as the initial
// state of the document. Otherwise a fresh empty document is created. // state of the document. Otherwise a fresh empty document is created.
// A "timezone" option can be supplied. // A "timezone" option can be supplied.
@ -1234,6 +1251,21 @@ export class DocWorkerApi {
return false; return false;
} }
/**
* Creates a middleware that checks the current usage of a limit and rejects the request if it is exceeded.
*/
private async _checkLimit(limit: LimitType, req: Request, res: Response, next: NextFunction) {
await this._dbManager.increaseUsage(getDocScope(req), limit, {dryRun: true, delta: 1});
next();
}
/**
* Increases the current usage of a limit by 1.
*/
private async _increaseLimit(limit: LimitType, req: Request) {
await this._dbManager.increaseUsage(getDocScope(req), limit, {delta: 1});
}
private async _assertAccess(role: 'viewers'|'editors'|'owners'|null, allowRemoved: boolean, private async _assertAccess(role: 'viewers'|'editors'|'owners'|null, allowRemoved: boolean,
req: Request, res: Response, next: NextFunction) { req: Request, res: Response, next: NextFunction) {
const scope = getDocScope(req); const scope = getDocScope(req);

View File

@ -110,7 +110,6 @@ export class DocWorker {
applyUserActionsById: activeDocMethod.bind(null, 'editors', 'applyUserActionsById'), applyUserActionsById: activeDocMethod.bind(null, 'editors', 'applyUserActionsById'),
findColFromValues: activeDocMethod.bind(null, 'viewers', 'findColFromValues'), findColFromValues: activeDocMethod.bind(null, 'viewers', 'findColFromValues'),
getFormulaError: activeDocMethod.bind(null, 'viewers', 'getFormulaError'), getFormulaError: activeDocMethod.bind(null, 'viewers', 'getFormulaError'),
getAssistance: activeDocMethod.bind(null, 'editors', 'getAssistance'),
importFiles: activeDocMethod.bind(null, 'editors', 'importFiles'), importFiles: activeDocMethod.bind(null, 'editors', 'importFiles'),
finishImportFiles: activeDocMethod.bind(null, 'editors', 'finishImportFiles'), finishImportFiles: activeDocMethod.bind(null, 'editors', 'finishImportFiles'),
cancelImportFiles: activeDocMethod.bind(null, 'editors', 'cancelImportFiles'), cancelImportFiles: activeDocMethod.bind(null, 'editors', 'cancelImportFiles'),

View File

@ -25,7 +25,7 @@
import { ActiveDoc, Deps as ActiveDocDeps } from "app/server/lib/ActiveDoc"; import { ActiveDoc, Deps as ActiveDocDeps } from "app/server/lib/ActiveDoc";
import { DEPS } from "app/server/lib/Assistance"; import { DEPS, sendForCompletion } from "app/server/lib/Assistance";
import log from 'app/server/lib/log'; import log from 'app/server/lib/log';
import crypto from 'crypto'; import crypto from 'crypto';
import parse from 'csv-parse/lib/sync'; import parse from 'csv-parse/lib/sync';
@ -163,7 +163,7 @@ where c.colId = ? and t.tableId = ?
`, rec.col_id, rec.table_id); `, rec.col_id, rec.table_id);
formula = colInfo?.formula; formula = colInfo?.formula;
const result = await activeDoc.getAssistanceWithOptions(session, { const result = await sendForCompletion(session, activeDoc, {
context: {type: 'formula', tableId, colId}, context: {type: 'formula', tableId, colId},
state: history, state: history,
text: followUp || description, text: followUp || description,

View File

@ -558,7 +558,7 @@ describe('DocTutorial', function () {
// Check that the update is immediately reflected in the tutorial popup. // Check that the update is immediately reflected in the tutorial popup.
assert.equal( assert.equal(
await driver.find('.test-doc-tutorial-popup p').getText(), await driver.findWait('.test-doc-tutorial-popup p', 2000).getText(),
'Welcome to the Grist Basics tutorial V2.' 'Welcome to the Grist Basics tutorial V2.'
); );
@ -571,7 +571,7 @@ describe('DocTutorial', function () {
// Switch to another user and restart the tutorial. // Switch to another user and restart the tutorial.
viewerSession = await gu.session().teamSite.user('user2').login(); viewerSession = await gu.session().teamSite.user('user2').login();
await viewerSession.loadDoc(`/doc/${doc.id}`); await viewerSession.loadDoc(`/doc/${doc.id}`);
await driver.find('.test-doc-tutorial-popup-restart').click(); await driver.findWait('.test-doc-tutorial-popup-restart', 2000).click();
await driver.find('.test-modal-confirm').click(); await driver.find('.test-modal-confirm').click();
await gu.waitForServer(); await gu.waitForServer();
await driver.findWait('.test-doc-tutorial-popup', 2000); await driver.findWait('.test-doc-tutorial-popup', 2000);

View File

@ -2056,7 +2056,13 @@ export class Session {
isFirstLogin?: boolean, isFirstLogin?: boolean,
showTips?: boolean, showTips?: boolean,
skipTutorial?: boolean, // By default true skipTutorial?: boolean, // By default true
userName?: string,
email?: string,
retainExistingLogin?: boolean}) { retainExistingLogin?: boolean}) {
if (options?.userName) {
this.settings.name = options.userName;
this.settings.email = options.email || '';
}
// Optimize testing a little bit, so if we are already logged in as the expected // Optimize testing a little bit, so if we are already logged in as the expected
// user on the expected org, and there are no options set, we can just continue. // user on the expected org, and there are no options set, we can just continue.
if (!options && await this.isLoggedInCorrectly()) { return this; } if (!options && await this.isLoggedInCorrectly()) { return this; }
@ -3150,20 +3156,26 @@ export async function availableBehaviorOptions() {
return list; return list;
} }
export function withComments() { /**
let oldEnv: testUtils.EnvironmentSnapshot; * Restarts the server ensuring that it is run with the given environment variables.
* If variables are already set, the server is not restarted.
*
* Useful for local testing of features that depend on environment variables, as it avoids the need
* to restart the server when those variables are already set.
*/
export function withEnvironmentSnapshot(vars: Record<string, any>) {
let oldEnv: testUtils.EnvironmentSnapshot|null = null;
before(async () => { before(async () => {
if (process.env.COMMENTS !== 'true') { // Test if the vars are already set, and if so, skip.
oldEnv = new testUtils.EnvironmentSnapshot(); if (Object.keys(vars).every(k => process.env[k] === vars[k])) { return; }
process.env.COMMENTS = 'true'; oldEnv = new testUtils.EnvironmentSnapshot();
await server.restart(); Object.assign(process.env, vars);
} await server.restart();
}); });
after(async () => { after(async () => {
if (oldEnv) { if (!oldEnv) { return; }
oldEnv.restore(); oldEnv.restore();
await server.restart(); await server.restart();
}
}); });
} }