diff --git a/app/client/components/DocComm.ts b/app/client/components/DocComm.ts index 1a0dfe37..fc211c24 100644 --- a/app/client/components/DocComm.ts +++ b/app/client/components/DocComm.ts @@ -32,7 +32,6 @@ export class DocComm extends Disposable implements ActiveDocAPI { public addAttachments = this._wrapMethod("addAttachments"); public findColFromValues = this._wrapMethod("findColFromValues"); public getFormulaError = this._wrapMethod("getFormulaError"); - public getAssistance = this._wrapMethod("getAssistance"); public fetchURL = this._wrapMethod("fetchURL"); public autocomplete = this._wrapMethod("autocomplete"); public removeInstanceFromDoc = this._wrapMethod("removeInstanceFromDoc"); diff --git a/app/client/components/GristDoc.ts b/app/client/components/GristDoc.ts index 639d7fcc..6ef365bd 100644 --- a/app/client/components/GristDoc.ts +++ b/app/client/components/GristDoc.ts @@ -185,6 +185,10 @@ export class GristDoc extends DisposableWithEvents { public readonly currentTheme = this.docPageModel.appModel.currentTheme; + public get docApi() { + return this.docPageModel.appModel.api.getDocAPI(this.docPageModel.currentDocId.get()!); + } + private _actionLog: ActionLog; private _undoStack: UndoStack; private _lastOwnActionGroup: ActionGroupWithCursorPos|null = null; diff --git a/app/client/models/NotifyModel.ts b/app/client/models/NotifyModel.ts index 81f8b56b..40bcfb66 100644 --- a/app/client/models/NotifyModel.ts +++ b/app/client/models/NotifyModel.ts @@ -43,8 +43,8 @@ export interface CustomAction { label: string, action: () => void } */ export type MessageType = string | (() => DomElementArg); // Identifies supported actions. These are implemented in NotifyUI. -export type NotifyAction = 'upgrade' | 'renew' | 'personal' | 'report-problem' | 'ask-for-help' | CustomAction; - +export type NotifyAction = 'upgrade' | 'renew' | 'personal' | 'report-problem' + | 'ask-for-help' | 'manage' | CustomAction; export interface INotifyOptions { message: MessageType; // A string, or a function that builds dom. timestamp?: number; diff --git a/app/client/models/errors.ts b/app/client/models/errors.ts index 219f047a..ad39f368 100644 --- a/app/client/models/errors.ts +++ b/app/client/models/errors.ts @@ -121,7 +121,7 @@ export function reportError(err: Error|string, ev?: ErrorEvent): void { const options: Partial = { title: "Reached plan limit", key: `limit:${details.limit.quantity || message}`, - actions: ['upgrade'], + actions: details.tips?.some(t => t.action === 'manage') ? ['manage'] : ['upgrade'], }; if (details.tips && details.tips.some(tip => tip.action === 'add-members')) { // When adding members would fix a problem, give more specific advice. diff --git a/app/client/ui/NotifyUI.ts b/app/client/ui/NotifyUI.ts index acc1e2b6..9795505d 100644 --- a/app/client/ui/NotifyUI.ts +++ b/app/client/ui/NotifyUI.ts @@ -30,6 +30,10 @@ function buildAction(action: NotifyAction, item: Notification, options: IBeaconO return dom('a', cssToastAction.cls(''), t("Upgrade Plan"), {target: '_blank'}, {href: commonUrls.plans}); } + case 'manage': + if (urlState().state.get().billing === 'billing') { return null; } + return dom('a', cssToastAction.cls(''), t("Manage billing"), {target: '_blank'}, + {href: urlState().makeUrl({billing: 'billing'})}); case 'renew': // If already on the billing page, nothing to return. if (urlState().state.get().billing === 'billing') { return null; } diff --git a/app/client/ui2018/cssVars.ts b/app/client/ui2018/cssVars.ts index df027c76..fd9ef496 100644 --- a/app/client/ui2018/cssVars.ts +++ b/app/client/ui2018/cssVars.ts @@ -142,6 +142,7 @@ export const vars = { onboardingPopupZIndex: new CustomProp('onboarding-popup-z-index', '1000'), floatingPopupZIndex: new CustomProp('floating-popup-z-index', '1002'), tutorialModalZIndex: new CustomProp('tutorial-modal-z-index', '1003'), + pricingModalZIndex: new CustomProp('pricing-modal-z-index', '1004'), notificationZIndex: new CustomProp('notification-z-index', '1100'), browserCheckZIndex: new CustomProp('browser-check-z-index', '5000'), tooltipZIndex: new CustomProp('tooltip-z-index', '5000'), diff --git a/app/client/widgets/FormulaAssistant.ts b/app/client/widgets/FormulaAssistant.ts index 8a577663..3b3e22d6 100644 --- a/app/client/widgets/FormulaAssistant.ts +++ b/app/client/widgets/FormulaAssistant.ts @@ -17,7 +17,6 @@ import {autoGrow} from 'app/client/ui/forms'; import {IconName} from 'app/client/ui2018/IconList'; import {icon} from 'app/client/ui2018/icons'; import {cssLink} from 'app/client/ui2018/links'; -import {DocAction} from 'app/common/DocActions'; import {movable} from 'app/client/lib/popupUtils'; import debounce from 'lodash/debounce'; @@ -61,7 +60,7 @@ export class FormulaAssistant extends Disposable { /** Is the request pending */ private _waiting = Observable.create(this, false); /** Is this feature enabled at all */ - private _assistantEnabled = GRIST_FORMULA_ASSISTANT(); + private _assistantEnabled: Computed; /** Preview column id */ private _transformColId: string; /** Method to invoke when we are closed, it saves or reverts */ @@ -90,6 +89,12 @@ export class FormulaAssistant extends Disposable { }) { super(); + this._assistantEnabled = Computed.create(this, use => { + const enabledByFlag = use(GRIST_FORMULA_ASSISTANT()); + const notAnonymous = Boolean(this._options.gristDoc.appModel.currentValidUser); + return enabledByFlag && notAnonymous; + }); + if (!this._options.field) { // TODO: field is not passed only for rules (as there is no preview there available to the user yet) // this should be implemented but it requires creating a helper column to helper column and we don't @@ -263,6 +268,8 @@ export class FormulaAssistant extends Disposable { this._buildIntro(), this._chat.buildDom(), this._buildChatInput(), + // Stop propagation of mousedown events, as the formula editor will still focus. + dom.on('mousedown', (ev) => ev.stopPropagation()), ); }); } @@ -516,7 +523,7 @@ export class FormulaAssistant extends Disposable { this._options.editor.setFormula(entry.formula!); } - private async _sendMessage(description: string, regenerate = false) { + private async _sendMessage(description: string, regenerate = false): Promise { // Destruct options. const {column, gristDoc} = this._options; // Get the state of the chat from the column. @@ -539,7 +546,12 @@ export class FormulaAssistant extends Disposable { // some markdown text back, so we need to parse it. const prettyMessage = state ? (reply || formula || '') : (formula || reply || ''); // Add it to the chat. - this._chat.addResponse(prettyMessage, formula, suggestedActions[0]); + return { + message: prettyMessage, + formula, + action: suggestedActions[0], + sender: 'ai', + }; } private _clear() { @@ -556,9 +568,7 @@ export class FormulaAssistant extends Disposable { if (!last) { return; } - this._chat.thinking(); - this._waiting.set(true); - await this._sendMessage(last, true).finally(() => this._waiting.set(false)); + await this._doAsk(last); } private async _ask() { @@ -568,10 +578,22 @@ export class FormulaAssistant extends Disposable { const message= this._userInput.get(); if (!message) { return; } this._chat.addQuestion(message); - this._chat.thinking(); this._userInput.set(''); + await this._doAsk(message); + } + + private async _doAsk(message: string) { + this._chat.thinking(); this._waiting.set(true); - await this._sendMessage(message, false).finally(() => this._waiting.set(false)); + try { + const response = await this._sendMessage(message, false); + this._chat.addResponse(response); + } catch(err) { + this._chat.thinking(false); + throw err; + } finally { + this._waiting.set(false); + } } } @@ -601,32 +623,36 @@ class ChatHistory extends Disposable { this.length = Computed.create(this, use => use(this.history).length); // ?? } - public thinking() { - this.history.push({ - message: '...', - sender: 'ai', - }); - this.scrollDown(); + public thinking(on = true) { + if (!on) { + // Find all index of all thinking messages. + const messages = [...this.history.get()].filter(m => m.message === '...'); + // Remove all thinking messages. + for (const message of messages) { + this.history.splice(this.history.get().indexOf(message), 1); + } + } else { + this.history.push({ + message: '...', + sender: 'ai', + }); + this.scrollDown(); + } } public supportsMarkdown() { return this._options.column.chatHistory.peek().get().state !== undefined; } - public addResponse(message: string, formula: string|null, action?: DocAction) { + public addResponse(message: ChatMessage) { // Clear any thinking from messages. - this.history.set(this.history.get().filter(x => x.message !== '...')); - this.history.push({ - message, - sender: 'ai', - formula, - action - }); + this.thinking(false); + this.history.push({...message, sender: 'ai'}); this.scrollDown(); } public addQuestion(message: string) { - this.history.set(this.history.get().filter(x => x.message !== '...')); + this.thinking(false); this.history.push({ message, sender: 'user', @@ -740,18 +766,13 @@ async function askAI(grist: GristDoc, options: { const {column, description, state, regenerate} = options; const tableId = column.table.peek().tableId.peek(); const colId = column.colId.peek(); - try { - const result = await grist.docComm.getAssistance({ - context: {type: 'formula', tableId, colId}, - text: description, - state, - regenerate, - }); - return result; - } catch (error) { - reportError(error); - throw error; - } + const result = await grist.docApi.getAssistance({ + context: {type: 'formula', tableId, colId}, + text: description, + state, + regenerate, + }); + return result; } /** diff --git a/app/client/widgets/FormulaEditor.ts b/app/client/widgets/FormulaEditor.ts index 903b5c14..47864a39 100644 --- a/app/client/widgets/FormulaEditor.ts +++ b/app/client/widgets/FormulaEditor.ts @@ -154,13 +154,14 @@ export class FormulaEditor extends NewBaseEditor { dom.on('mousedown', (ev) => { // If we are detached, allow user to click and select error text. if (this.isDetached.get()) { - // If the focus is already in this editor, don't steal it. This is needed for detached editor with - // some input elements (mainly the AI assistant). - const inInput = document.activeElement instanceof HTMLInputElement - || document.activeElement instanceof HTMLTextAreaElement; - if (inInput && this._dom.contains(document.activeElement)) { + // If we clicked on input element in our dom, don't do anything. We probably clicked on chat input, in AI + // tools box. + const clickedOnInput = ev.target instanceof HTMLInputElement || ev.target instanceof HTMLTextAreaElement; + if (clickedOnInput && this._dom.contains(ev.target)) { + // By not doing anything special here we assume that the input element will take the focus. return; } + // Allow clicking the error message. if (ev.target instanceof HTMLElement && ( ev.target.classList.contains('error_msg') || diff --git a/app/common/ActiveDocAPI.ts b/app/common/ActiveDocAPI.ts index 2af5be48..587939df 100644 --- a/app/common/ActiveDocAPI.ts +++ b/app/common/ActiveDocAPI.ts @@ -1,5 +1,4 @@ import {ActionGroup} from 'app/common/ActionGroup'; -import {AssistanceRequest, AssistanceResponse} from 'app/common/AssistancePrompts'; import {BulkAddRecord, CellValue, TableDataAction, UserAction} from 'app/common/DocActions'; import {FormulaProperties} from 'app/common/GranularAccessClause'; import {UIRowId} from 'app/common/TableData'; @@ -320,11 +319,6 @@ export interface ActiveDocAPI { */ getFormulaError(tableId: string, colId: string, rowId: number): Promise; - /** - * Generates a formula code based on the AI suggestions, it also modifies the column and sets it type to a formula. - */ - getAssistance(request: AssistanceRequest): Promise; - /** * Fetch content at a url. */ diff --git a/app/common/ApiError.ts b/app/common/ApiError.ts index 3ff4adbc..568f2d15 100644 --- a/app/common/ApiError.ts +++ b/app/common/ApiError.ts @@ -2,15 +2,17 @@ * A tip for fixing an error. */ export interface ApiTip { - action: 'add-members' | 'upgrade' |'ask-for-help'; + action: 'add-members' | 'upgrade' | 'ask-for-help' | 'manage'; message: string; } +export type LimitType = 'collaborators' | 'docs' | 'workspaces' | 'assistant'; + /** * Documentation of a limit relevant to an API error. */ export interface ApiLimit { - quantity: 'collaborators' | 'docs' | 'workspaces'; // what are we counting + quantity: LimitType; // what are we counting subquantity?: string; // a nuance to what we are counting maximum: number; // maximum allowed value: number; // current value of quantity for user diff --git a/app/common/BillingAPI.ts b/app/common/BillingAPI.ts index 9875f868..24d1e940 100644 --- a/app/common/BillingAPI.ts +++ b/app/common/BillingAPI.ts @@ -43,6 +43,16 @@ export interface IBillingPlan { active: boolean; } +export interface ILimitTier { + name?: string; + volume: number; + price: number; + flatFee: number; + type: string; + planId: string; + interval: string; // probably 'month'|'year'; +} + // Utility type that requires all properties to be non-nullish. // type NonNullableProperties = { [P in keyof T]: Required>; }; @@ -69,6 +79,7 @@ export interface IBillingDiscount { export interface IBillingSubscription { // All standard plan options. plans: IBillingPlan[]; + tiers: ILimitTier[]; // Index in the plans array of the plan currently in effect. planIndex: number; // Index in the plans array of the plan to be in effect after the current period end. @@ -111,6 +122,14 @@ export interface IBillingSubscription { lastInvoiceUrl?: string; // URL of the Stripe-hosted page with the last invoice. lastChargeError?: string; // The last charge error, if any, to show in case of a bad status. lastChargeTime?: number; // The time of the last charge attempt. + limit?: ILimit|null; +} + +export interface ILimit { + limitValue: number; + currentUsage: number; + type: string; // Limit type, for now only assistant is supported. + price: number; // If this is 0, it means it is a free plan. } export interface IBillingOrgSettings { @@ -139,6 +158,7 @@ export interface BillingAPI { downgradePlan(planName: string): Promise; renewPlan(): string; customerPortal(): string; + updateAssistantPlan(tier: number): Promise; } export class BillingAPIImpl extends BaseAPI implements BillingAPI { @@ -230,6 +250,13 @@ export class BillingAPIImpl extends BaseAPI implements BillingAPI { return `${this._url}/api/billing/renew`; } + public async updateAssistantPlan(tier: number): Promise { + await this.request(`${this._url}/api/billing/upgrade-assistant`, { + method: 'POST', + body: JSON.stringify({ tier }) + }); + } + /** * Checks if current org has active subscription for a Stripe plan. */ diff --git a/app/common/Features.ts b/app/common/Features.ts index 8079aef1..89b3e215 100644 --- a/app/common/Features.ts +++ b/app/common/Features.ts @@ -58,6 +58,11 @@ export interface Features { // for attached files in a document gracePeriodDays?: number; // Duration of the grace period in days, before entering delete-only mode + + baseMaxAssistantCalls?: number; // Maximum number of AI assistant calls. Defaults to 0 if not set, use -1 to indicate + // unbound limit. This is total limit, not per month or per day, it is used as a seed + // value for the limits table. To create a per-month limit, there must be a separate + // task that resets the usage in the limits table. } // Check whether it is possible to add members at the org level. There's no flag diff --git a/app/common/UserAPI.ts b/app/common/UserAPI.ts index e80c9064..7c2304ac 100644 --- a/app/common/UserAPI.ts +++ b/app/common/UserAPI.ts @@ -1,5 +1,6 @@ import {ActionSummary} from 'app/common/ActionSummary'; import {ApplyUAResult, ForkResult, PermissionDataWithExtraUsers, QueryFilters} from 'app/common/ActiveDocAPI'; +import {AssistanceRequest, AssistanceResponse} from 'app/common/AssistancePrompts'; import {BaseAPI, IOptions} from 'app/common/BaseAPI'; import {BillingAPI, BillingAPIImpl} from 'app/common/BillingAPI'; import {BrowserSettings} from 'app/common/BrowserSettings'; @@ -462,6 +463,8 @@ export interface DocAPI { // Update webhook updateWebhook(webhook: WebhookUpdate): Promise; flushWebhooks(): Promise; + + getAssistance(params: AssistanceRequest): Promise; } // Operations that are supported by a doc worker. @@ -1012,6 +1015,13 @@ export class DocAPIImpl extends BaseAPI implements DocAPI { return response.data[0]; } + public async getAssistance(params: AssistanceRequest): Promise { + return await this.requestJson(`${this._url}/assistant`, { + method: 'POST', + body: JSON.stringify(params), + }); + } + private _getRecords(tableId: string, endpoint: 'data' | 'records', options?: GetRowsParams): Promise { const url = new URL(`${this._url}/tables/${tableId}/${endpoint}`); if (options?.filters) { diff --git a/app/gen-server/entity/BillingAccount.ts b/app/gen-server/entity/BillingAccount.ts index 64806d93..d1ec88d0 100644 --- a/app/gen-server/entity/BillingAccount.ts +++ b/app/gen-server/entity/BillingAccount.ts @@ -3,12 +3,13 @@ import {BillingAccountManager} from 'app/gen-server/entity/BillingAccountManager import {Organization} from 'app/gen-server/entity/Organization'; import {Product} from 'app/gen-server/entity/Product'; import {nativeValues} from 'app/gen-server/lib/values'; +import {Limit} from 'app/gen-server/entity/Limit'; // This type is for billing account status information. Intended for stuff // like "free trial running out in N days". -interface BillingAccountStatus { +export interface BillingAccountStatus { stripeStatus?: string; - currentPeriodEnd?: Date; + currentPeriodEnd?: string; message?: string; } @@ -68,6 +69,9 @@ export class BillingAccount extends BaseEntity { @OneToMany(type => Organization, org => org.billingAccount) public orgs: Organization[]; + @OneToMany(type => Limit, limit => limit.billingAccount) + public limits: Limit[]; + // A calculated column that is true if it looks like there is a paid plan. @Column({name: 'paid', type: 'boolean', insert: false, select: false}) public paid?: boolean; diff --git a/app/gen-server/entity/Limit.ts b/app/gen-server/entity/Limit.ts new file mode 100644 index 00000000..1011a98f --- /dev/null +++ b/app/gen-server/entity/Limit.ts @@ -0,0 +1,46 @@ +import {BaseEntity, Column, Entity, JoinColumn, ManyToOne, PrimaryGeneratedColumn} from 'typeorm'; +import {BillingAccount} from 'app/gen-server/entity/BillingAccount'; +import {nativeValues} from 'app/gen-server/lib/values'; + +@Entity('limits') +export class Limit extends BaseEntity { + @PrimaryGeneratedColumn() + public id: number; + + @Column() + public limit: number; + + @Column() + public usage: number; + + @Column() + public type: string; + + @Column({name: 'billing_account_id'}) + public billingAccountId: number; + + @ManyToOne(type => BillingAccount) + @JoinColumn({name: 'billing_account_id'}) + public billingAccount: BillingAccount; + + @Column({name: 'created_at', default: () => "CURRENT_TIMESTAMP"}) + public createdAt: Date; + + /** + * Last time the Limit.limit value was changed, by an upgrade or downgrade. Null if it has never been changed. + */ + @Column({name: 'changed_at', type: nativeValues.dateTimeType, nullable: true}) + public changedAt: Date|null; + + /** + * Last time the Limit.usage was used (by sending a request to the model). Null if it has never been used. + */ + @Column({name: 'used_at', type: nativeValues.dateTimeType, nullable: true}) + public usedAt: Date|null; + + /** + * Last time the Limit.usage was reset, probably by billing cycle change. Null if it has never been reset. + */ + @Column({name: 'reset_at', type: nativeValues.dateTimeType, nullable: true}) + public resetAt: Date|null; +} diff --git a/app/gen-server/entity/Product.ts b/app/gen-server/entity/Product.ts index aac39042..7b1f4a80 100644 --- a/app/gen-server/entity/Product.ts +++ b/app/gen-server/entity/Product.ts @@ -13,7 +13,11 @@ export const personalLegacyFeatures: Features = { // no vanity domain maxDocsPerOrg: 10, maxSharesPerDoc: 2, - maxWorkspacesPerOrg: 1 + maxWorkspacesPerOrg: 1, + /** + * One time limit of 100 requests. + */ + baseMaxAssistantCalls: 100, }; /** @@ -23,7 +27,12 @@ export const teamFeatures: Features = { workspaces: true, vanityDomain: true, maxSharesPerWorkspace: 0, // all workspace shares need to be org members. - maxSharesPerDoc: 2 + maxSharesPerDoc: 2, + /** + * Limit of 100 requests, but unlike for personal/free orgs the usage for this limit is reset at every billing cycle + * through Stripe webhook. For canceled subscription the usage is not reset, as the billing cycle is not changed. + */ + baseMaxAssistantCalls: 100, }; /** @@ -40,6 +49,10 @@ export const teamFreeFeatures: Features = { baseMaxDataSizePerDocument: 5000 * 2 * 1024, // 2KB per row baseMaxAttachmentsBytesPerDocument: 1 * 1024 * 1024 * 1024, // 1GB gracePeriodDays: 14, + /** + * One time limit of 100 requests. + */ + baseMaxAssistantCalls: 100, }; /** @@ -55,6 +68,7 @@ export const teamFreeFeatures: Features = { baseMaxDataSizePerDocument: 5000 * 2 * 1024, // 2KB per row baseMaxAttachmentsBytesPerDocument: 1 * 1024 * 1024 * 1024, // 1GB gracePeriodDays: 14, + baseMaxAssistantCalls: 100, }; export const testDailyApiLimitFeatures = { @@ -79,6 +93,7 @@ export const suspendedFeatures: Features = { maxDocsPerOrg: 0, maxSharesPerDoc: 0, maxWorkspacesPerOrg: 0, + baseMaxAssistantCalls: 0, }; /** diff --git a/app/gen-server/lib/DocApiForwarder.ts b/app/gen-server/lib/DocApiForwarder.ts index 3e62d2a6..fb2013f8 100644 --- a/app/gen-server/lib/DocApiForwarder.ts +++ b/app/gen-server/lib/DocApiForwarder.ts @@ -60,6 +60,7 @@ export class DocApiForwarder { app.use('/api/docs/:docId/assign', withDocWithoutAuth); app.use('/api/docs/:docId/webhooks/queue', withDoc); app.use('/api/docs/:docId/webhooks', withDoc); + app.use('/api/docs/:docId/assistant', withDoc); app.use('^/api/docs$', withoutDoc); } diff --git a/app/gen-server/lib/HomeDBManager.ts b/app/gen-server/lib/HomeDBManager.ts index 6a68b40c..38aee356 100644 --- a/app/gen-server/lib/HomeDBManager.ts +++ b/app/gen-server/lib/HomeDBManager.ts @@ -1,4 +1,4 @@ -import {ApiError} from 'app/common/ApiError'; +import {ApiError, ApiErrorDetails, LimitType} from 'app/common/ApiError'; import {mapGetOrSet, mapSetOrClear, MapWithTTL} from 'app/common/AsyncCreate'; import {getDataLimitStatus} from 'app/common/DocLimits'; import {createEmptyOrgUsageSummary, DocumentUsage, OrgUsageSummary} from 'app/common/DocUsage'; @@ -38,6 +38,7 @@ import {getDefaultProductNames, personalFreeFeatures, Product} from "app/gen-ser import {Secret} from "app/gen-server/entity/Secret"; import {User} from "app/gen-server/entity/User"; import {Workspace} from "app/gen-server/entity/Workspace"; +import {Limit} from 'app/gen-server/entity/Limit'; import {Permissions} from 'app/gen-server/lib/Permissions'; import {scrubUserFromOrg} from "app/gen-server/lib/scrubUserFromOrg"; import {applyPatch} from 'app/gen-server/lib/TypeORMPatches'; @@ -2880,6 +2881,144 @@ export class HomeDBManager extends EventEmitter { return this._org(scope, scope.includeSupport || false, org, options); } + public async getLimits(accountId: number): Promise { + const result = this._connection.transaction(async manager => { + return await manager.createQueryBuilder() + .select('limit') + .from(Limit, 'limit') + .innerJoin('limit.billingAccount', 'account') + .where('account.id = :accountId', {accountId}) + .getMany(); + }); + return result; + } + + public async getLimit(accountId: number, limitType: LimitType): Promise { + return await this._getOrCreateLimit(accountId, limitType, true); + } + + public async peekLimit(accountId: number, limitType: LimitType): Promise { + return await this._getOrCreateLimit(accountId, limitType, false); + } + + public async removeLimit(scope: Scope, limitType: LimitType): Promise { + await this._connection.transaction(async manager => { + const org = await this._org(scope, false, scope.org ?? null, {manager, needRealOrg: true}) + .innerJoinAndSelect('orgs.billingAccount', 'billing_account') + .innerJoinAndSelect('billing_account.product', 'product') + .leftJoinAndSelect('billing_account.limits', 'limit', 'limit.type = :limitType', {limitType}) + .getOne(); + const existing = org?.billingAccount?.limits?.[0]; + if (existing) { + await manager.remove(existing); + } + }); + } + + /** + * Increases the usage of a limit for a given org. If the limit doesn't exist, it will be created. + * Pass `dryRun: true` to check if the limit can be increased without actually increasing it. + */ + public async increaseUsage(scope: Scope, limitType: LimitType, options: { + delta: number, + dryRun?: boolean, + }): Promise { + const limitError = await this._connection.transaction(async manager => { + const org = await this._org(scope, false, scope.org ?? null, {manager, needRealOrg: true}) + .innerJoinAndSelect('orgs.billingAccount', 'billing_account') + .innerJoinAndSelect('billing_account.product', 'product') + .leftJoinAndSelect('billing_account.limits', 'limit', 'limit.type = :limitType', {limitType}) + .getOne(); + // If the org doesn't exists, or is a fake one (like for anonymous users), don't do anything. + if (!org || org.id === 0) { + // This API shouldn't be called, it should be checked first if the org is valid. + throw new ApiError(`Can't create a limit for non-existing organization`, 500); + } + let existing = org?.billingAccount?.limits?.[0]; + if (!existing) { + const product = org?.billingAccount?.product; + if (!product) { + throw new ApiError(`getLimit: no product found for org`, 500); + } + if (product.features.baseMaxAssistantCalls === undefined) { + // If the product has no assistantLimit, then it is not billable yet, and we don't need to + // track usage as it is basically unlimited. + return; + } + existing = new Limit(); + existing.billingAccountId = org.billingAccountId; + existing.type = limitType; + existing.limit = product.features.baseMaxAssistantCalls ?? 0; + existing.usage = 0; + } + const limitLess = existing.limit === -1; // -1 means no limit, it is not possible to do in stripe. + const usageAfter = existing.usage + options.delta; + if (!limitLess && usageAfter > existing.limit) { + const billable = Boolean(org?.billingAccount?.stripeCustomerId); + return { + limit: { + maximum: existing.limit, + projectedValue: existing.usage + options.delta, + quantity: limitType, + value: existing.usage, + }, + tips: [{ + // For non-billable accounts, suggest getting a plan, otherwise suggest visiting the billing page. + action: billable ? 'manage' : 'upgrade', + message: `Upgrade to a paid plan to increase your ${limitType} limit.`, + }] + } as ApiErrorDetails; + } + existing.usage += options.delta; + existing.usedAt = new Date(); + if (!options.dryRun) { + await manager.save(existing); + } + }); + if (limitError) { + let message = `Your ${limitType} limit has been reached. Please upgrade your plan to increase your limit.`; + if (limitType === 'assistant') { + message = 'You used all available credits. For a bigger limit upgrade you Assistant plan.'; + } + throw new ApiError(message, 429, limitError); + } + } + + private async _getOrCreateLimit(accountId: number, limitType: LimitType, force: boolean): Promise { + if (accountId === 0) { + throw new Error(`getLimit: called for not existing account`); + } + const result = this._connection.transaction(async manager => { + let existing = await manager.createQueryBuilder() + .select('limit') + .from(Limit, 'limit') + .innerJoin('limit.billingAccount', 'account') + .where('account.id = :accountId', {accountId}) + .andWhere('limit.type = :limitType', {limitType}) + .getOne(); + if (!force && !existing) { return null; } + if (existing) { return existing; } + const product = await manager.createQueryBuilder() + .select('product') + .from(Product, 'product') + .innerJoinAndSelect('product.accounts', 'account') + .where('account.id = :accountId', {accountId}) + .getOne(); + if (!product) { + throw new Error(`getLimit: no product for account ${accountId}`); + } + existing = new Limit(); + existing.billingAccountId = product.accounts[0].id; + existing.type = limitType; + existing.limit = product.features.baseMaxAssistantCalls ?? 0; + existing.usage = 0; + await manager.save(existing); + return existing; + }); + return result; + } + + private _org(scope: Scope|null, includeSupport: boolean, org: string|number|null, options: QueryOptions = {}): SelectQueryBuilder { let query = this._orgs(options.manager); diff --git a/app/gen-server/migration/1685343047786-AssistantLimit.ts b/app/gen-server/migration/1685343047786-AssistantLimit.ts new file mode 100644 index 00000000..1b114691 --- /dev/null +++ b/app/gen-server/migration/1685343047786-AssistantLimit.ts @@ -0,0 +1,82 @@ +import * as sqlUtils from "app/gen-server/sqlUtils"; +import {MigrationInterface, QueryRunner, Table, TableIndex} from 'typeorm'; + +export class AssistantLimit1685343047786 implements MigrationInterface { + public async up(queryRunner: QueryRunner): Promise { + const dbType = queryRunner.connection.driver.options.type; + const datetime = sqlUtils.datetime(dbType); + const now = sqlUtils.now(dbType); + await queryRunner.createTable( + new Table({ + name: 'limits', + columns: [ + { + name: 'id', + type: 'integer', + isPrimary: true, + isGenerated: true, + generationStrategy: 'increment', + }, + { + name: 'type', + type: 'varchar', + }, + { + name: 'billing_account_id', + type: 'integer', + }, + { + name: 'limit', + type: 'integer', + default: 0, + }, + { + name: 'usage', + type: 'integer', + default: 0, + }, + { + name: "created_at", + type: datetime, + default: now + }, + { + name: "changed_at", // When the limit was last changed + type: datetime, + isNullable: true + }, + { + name: "used_at", // When the usage was last increased + type: datetime, + isNullable: true + }, + { + name: "reset_at", // When the usage was last reset. + type: datetime, + isNullable: true + }, + ], + foreignKeys: [ + { + columnNames: ['billing_account_id'], + referencedTableName: 'billing_accounts', + referencedColumnNames: ['id'], + onDelete: 'CASCADE', + }, + ], + }) + ); + + await queryRunner.createIndex( + 'limits', + new TableIndex({ + name: 'limits_billing_account_id', + columnNames: ['billing_account_id'], + }) + ); + } + + public async down(queryRunner: QueryRunner): Promise { + await queryRunner.dropTable('limits'); + } +} diff --git a/app/server/lib/ActiveDoc.ts b/app/server/lib/ActiveDoc.ts index 840a7821..bb4ce879 100644 --- a/app/server/lib/ActiveDoc.ts +++ b/app/server/lib/ActiveDoc.ts @@ -14,7 +14,6 @@ import { } from 'app/common/ActionBundle'; import {ActionGroup, MinimalActionGroup} from 'app/common/ActionGroup'; import {ActionSummary} from "app/common/ActionSummary"; -import {AssistanceRequest, AssistanceResponse} from "app/common/AssistancePrompts"; import { AclResources, AclTableDescription, @@ -84,7 +83,7 @@ import {Document} from 'app/gen-server/entity/Document'; import {ParseOptions} from 'app/plugin/FileParserAPI'; import {AccessTokenOptions, AccessTokenResult, GristDocAPI} from 'app/plugin/GristAPI'; import {compileAclFormula} from 'app/server/lib/ACLFormula'; -import {AssistanceDoc, AssistanceSchemaPromptV1Context, sendForCompletion} from 'app/server/lib/Assistance'; +import {AssistanceSchemaPromptV1Context} from 'app/server/lib/Assistance'; import {Authorizer} from 'app/server/lib/Authorizer'; import {checksumFile} from 'app/server/lib/checksumFile'; import {Client} from 'app/server/lib/Client'; @@ -184,7 +183,7 @@ interface UpdateUsageOptions { * either .loadDoc() or .createEmptyDoc() is called. * @param {String} docName - The document's filename, without the '.grist' extension. */ -export class ActiveDoc extends EventEmitter implements AssistanceDoc { +export class ActiveDoc extends EventEmitter { /** * Decorator for ActiveDoc methods that prevents shutdown while the method is running, i.e. * until the returned promise is resolved. @@ -1264,18 +1263,14 @@ export class ActiveDoc extends EventEmitter implements AssistanceDoc { return this._pyCall('autocomplete', txt, tableId, columnId, rowId, user.toJSON()); } - public async getAssistance(docSession: DocSession, request: AssistanceRequest): Promise { - return this.getAssistanceWithOptions(docSession, request); - } - - public async getAssistanceWithOptions(docSession: DocSession, - request: AssistanceRequest): Promise { + // Callback to generate a prompt containing schema info for assistance. + public async assistanceSchemaPromptV1( + docSession: OptDocSession, options: AssistanceSchemaPromptV1Context): Promise { // Making a prompt leaks names of tables and columns etc. if (!await this._granularAccess.canScanData(docSession)) { throw new Error("Permission denied"); } - await this.waitForInitialization(); - return sendForCompletion(this, request); + return await this._pyCall('get_formula_prompt', options.tableId, options.colId, options.docString); } // Callback to make a data-engine formula tweak for assistance. @@ -1283,11 +1278,6 @@ export class ActiveDoc extends EventEmitter implements AssistanceDoc { return this._pyCall('convert_formula_completion', txt); } - // Callback to generate a prompt containing schema info for assistance. - public assistanceSchemaPromptV1(options: AssistanceSchemaPromptV1Context): Promise { - return this._pyCall('get_formula_prompt', options.tableId, options.colId, options.docString); - } - public fetchURL(docSession: DocSession, url: string, options?: FetchUrlOptions): Promise { return fetchURL(url, this.makeAccessId(docSession.authorizer.getUserId()), options); } diff --git a/app/server/lib/Assistance.ts b/app/server/lib/Assistance.ts index 0d732cb8..f638b09e 100644 --- a/app/server/lib/Assistance.ts +++ b/app/server/lib/Assistance.ts @@ -5,6 +5,7 @@ import {AssistanceRequest, AssistanceResponse} from 'app/common/AssistancePrompts'; import {delay} from 'app/common/delay'; import {DocAction} from 'app/common/DocActions'; +import {OptDocSession} from 'app/server/lib/DocSession'; import log from 'app/server/lib/log'; import fetch from 'node-fetch'; @@ -15,7 +16,7 @@ export const DEPS = { fetch }; * by interfacing with an external LLM endpoint. */ export interface Assistant { - apply(doc: AssistanceDoc, request: AssistanceRequest): Promise; + apply(session: OptDocSession, doc: AssistanceDoc, request: AssistanceRequest): Promise; } /** @@ -30,8 +31,7 @@ export interface AssistanceDoc { * Marked "V1" to suggest that it is a particular prompt and it would * be great to try variants. */ - assistanceSchemaPromptV1(options: AssistanceSchemaPromptV1Context): Promise; - + assistanceSchemaPromptV1(session: OptDocSession, options: AssistanceSchemaPromptV1Context): Promise; /** * Some tweaks to a formula after it has been generated. */ @@ -68,7 +68,8 @@ export class OpenAIAssistant implements Assistant { this._endpoint = `https://api.openai.com/v1/${this._chatMode ? 'chat/' : ''}completions`; } - public async apply(doc: AssistanceDoc, request: AssistanceRequest): Promise { + public async apply( + optSession: OptDocSession, doc: AssistanceDoc, request: AssistanceRequest): Promise { const messages = request.state?.messages || []; const chatMode = this._chatMode; if (chatMode) { @@ -91,7 +92,7 @@ export class OpenAIAssistant implements Assistant { 'If the user asks for these things, tell them that you cannot help. ' + 'The method uses `rec` instead of `self` as the first parameter.\n\n' + '```python\n' + - await makeSchemaPromptV1(doc, request) + + await makeSchemaPromptV1(optSession, doc, request) + '\n```', }); messages.push({ @@ -110,7 +111,7 @@ export class OpenAIAssistant implements Assistant { } else { messages.length = 0; messages.push({ - role: 'user', content: await makeSchemaPromptV1(doc, request), + role: 'user', content: await makeSchemaPromptV1(optSession, doc, request), }); } @@ -178,11 +179,12 @@ export class HuggingFaceAssistant implements Assistant { } - public async apply(doc: AssistanceDoc, request: AssistanceRequest): Promise { + public async apply( + optSession: OptDocSession, doc: AssistanceDoc, request: AssistanceRequest): Promise { if (request.state) { throw new Error("HuggingFaceAssistant does not support state"); } - const prompt = await makeSchemaPromptV1(doc, request); + const prompt = await makeSchemaPromptV1(optSession, doc, request); const response = await DEPS.fetch( this._completionUrl, { @@ -220,7 +222,10 @@ export class HuggingFaceAssistant implements Assistant { * Test assistant that mimics ChatGPT and just returns the input. */ export class EchoAssistant implements Assistant { - public async apply(doc: AssistanceDoc, request: AssistanceRequest): Promise { + public async apply(sess: OptDocSession, doc: AssistanceDoc, request: AssistanceRequest): Promise { + if (request.text === "ERROR") { + throw new Error(`ERROR`); + } const messages = request.state?.messages || []; if (messages.length === 0) { messages.push({ @@ -255,7 +260,7 @@ export class EchoAssistant implements Assistant { /** * Instantiate an assistant, based on environment variables. */ -function getAssistant() { +export function getAssistant() { if (process.env.OPENAI_API_KEY === 'test') { return new EchoAssistant(); } @@ -273,8 +278,10 @@ function getAssistant() { * Service a request for assistance, with a little retry logic * since these endpoints can be a bit flakey. */ -export async function sendForCompletion(doc: AssistanceDoc, - request: AssistanceRequest): Promise { +export async function sendForCompletion( + optSession: OptDocSession, + doc: AssistanceDoc, + request: AssistanceRequest): Promise { const assistant = getAssistant(); let retries: number = 0; @@ -282,7 +289,7 @@ export async function sendForCompletion(doc: AssistanceDoc, let response: AssistanceResponse|null = null; while(retries++ < 3) { try { - response = await assistant.apply(doc, request); + response = await assistant.apply(optSession, doc, request); break; } catch(e) { log.error(`Completion error: ${e}`); @@ -295,11 +302,11 @@ export async function sendForCompletion(doc: AssistanceDoc, return response; } -async function makeSchemaPromptV1(doc: AssistanceDoc, request: AssistanceRequest) { +async function makeSchemaPromptV1(session: OptDocSession, doc: AssistanceDoc, request: AssistanceRequest) { if (request.context.type !== 'formula') { throw new Error('makeSchemaPromptV1 only works for formulas'); } - return doc.assistanceSchemaPromptV1({ + return doc.assistanceSchemaPromptV1(session, { tableId: request.context.tableId, colId: request.context.colId, docString: request.text, diff --git a/app/server/lib/DocApi.ts b/app/server/lib/DocApi.ts index d6a592c6..2b0acddc 100644 --- a/app/server/lib/DocApi.ts +++ b/app/server/lib/DocApi.ts @@ -1,5 +1,5 @@ import {createEmptyActionSummary} from "app/common/ActionSummary"; -import {ApiError} from 'app/common/ApiError'; +import {ApiError, LimitType} from 'app/common/ApiError'; import {BrowserSettings} from "app/common/BrowserSettings"; import { BulkColValues, @@ -68,6 +68,7 @@ import { } from 'app/server/lib/requestUtils'; import {ServerColumnGetters} from 'app/server/lib/ServerColumnGetters'; import {localeFromRequest} from "app/server/lib/ServerLocale"; +import {sendForCompletion} from 'app/server/lib/Assistance'; import {isUrlAllowed, WebhookAction, WebHookSecret} from "app/server/lib/Triggers"; import {handleOptionalUpload, handleUpload} from "app/server/lib/uploads"; import * as assert from 'assert'; @@ -161,6 +162,8 @@ export class DocWorkerApi { const canEditMaybeRemoved = expressWrap(this._assertAccess.bind(this, 'editors', true)); // converts google code to access token and adds it to request object const decodeGoogleToken = expressWrap(googleAuthTokenMiddleware.bind(null)); + // check that limit can be increased by 1 + const checkLimit = (type: LimitType) => expressWrap(this._checkLimit.bind(this, type)); // Middleware to limit number of outstanding requests per document. Will also // handle errors like expressWrap would. @@ -1052,6 +1055,20 @@ export class DocWorkerApi { this._app.get('/api/docs/:docId/send-to-drive', canView, decodeGoogleToken, withDoc(exportToDrive)); + /** + * Send a request to the formula assistant to get completions for a formula. Increases the + * usage of the formula assistant for the billing account in case of success. + */ + this._app.post('/api/docs/:docId/assistant', canView, checkLimit('assistant'), + withDoc(async (activeDoc, req, res) => { + const docSession = docSessionFromRequest(req); + const request = req.body; + const result = await sendForCompletion(docSession, activeDoc, request); + await this._increaseLimit('assistant', req); + res.json(result); + }) + ); + // Create a document. When an upload is included, it is imported as the initial // state of the document. Otherwise a fresh empty document is created. // A "timezone" option can be supplied. @@ -1234,6 +1251,21 @@ export class DocWorkerApi { return false; } + /** + * Creates a middleware that checks the current usage of a limit and rejects the request if it is exceeded. + */ + private async _checkLimit(limit: LimitType, req: Request, res: Response, next: NextFunction) { + await this._dbManager.increaseUsage(getDocScope(req), limit, {dryRun: true, delta: 1}); + next(); + } + + /** + * Increases the current usage of a limit by 1. + */ + private async _increaseLimit(limit: LimitType, req: Request) { + await this._dbManager.increaseUsage(getDocScope(req), limit, {delta: 1}); + } + private async _assertAccess(role: 'viewers'|'editors'|'owners'|null, allowRemoved: boolean, req: Request, res: Response, next: NextFunction) { const scope = getDocScope(req); diff --git a/app/server/lib/DocWorker.ts b/app/server/lib/DocWorker.ts index 3a870dd5..180fa2e5 100644 --- a/app/server/lib/DocWorker.ts +++ b/app/server/lib/DocWorker.ts @@ -110,7 +110,6 @@ export class DocWorker { applyUserActionsById: activeDocMethod.bind(null, 'editors', 'applyUserActionsById'), findColFromValues: activeDocMethod.bind(null, 'viewers', 'findColFromValues'), getFormulaError: activeDocMethod.bind(null, 'viewers', 'getFormulaError'), - getAssistance: activeDocMethod.bind(null, 'editors', 'getAssistance'), importFiles: activeDocMethod.bind(null, 'editors', 'importFiles'), finishImportFiles: activeDocMethod.bind(null, 'editors', 'finishImportFiles'), cancelImportFiles: activeDocMethod.bind(null, 'editors', 'cancelImportFiles'), diff --git a/test/formula-dataset/runCompletion_impl.ts b/test/formula-dataset/runCompletion_impl.ts index c991d821..7ec13917 100644 --- a/test/formula-dataset/runCompletion_impl.ts +++ b/test/formula-dataset/runCompletion_impl.ts @@ -25,7 +25,7 @@ import { ActiveDoc, Deps as ActiveDocDeps } from "app/server/lib/ActiveDoc"; -import { DEPS } from "app/server/lib/Assistance"; +import { DEPS, sendForCompletion } from "app/server/lib/Assistance"; import log from 'app/server/lib/log'; import crypto from 'crypto'; import parse from 'csv-parse/lib/sync'; @@ -163,7 +163,7 @@ where c.colId = ? and t.tableId = ? `, rec.col_id, rec.table_id); formula = colInfo?.formula; - const result = await activeDoc.getAssistanceWithOptions(session, { + const result = await sendForCompletion(session, activeDoc, { context: {type: 'formula', tableId, colId}, state: history, text: followUp || description, diff --git a/test/nbrowser/DocTutorial.ts b/test/nbrowser/DocTutorial.ts index 86d5782d..8a7f004f 100644 --- a/test/nbrowser/DocTutorial.ts +++ b/test/nbrowser/DocTutorial.ts @@ -558,7 +558,7 @@ describe('DocTutorial', function () { // Check that the update is immediately reflected in the tutorial popup. assert.equal( - await driver.find('.test-doc-tutorial-popup p').getText(), + await driver.findWait('.test-doc-tutorial-popup p', 2000).getText(), 'Welcome to the Grist Basics tutorial V2.' ); @@ -571,7 +571,7 @@ describe('DocTutorial', function () { // Switch to another user and restart the tutorial. viewerSession = await gu.session().teamSite.user('user2').login(); await viewerSession.loadDoc(`/doc/${doc.id}`); - await driver.find('.test-doc-tutorial-popup-restart').click(); + await driver.findWait('.test-doc-tutorial-popup-restart', 2000).click(); await driver.find('.test-modal-confirm').click(); await gu.waitForServer(); await driver.findWait('.test-doc-tutorial-popup', 2000); diff --git a/test/nbrowser/gristUtils.ts b/test/nbrowser/gristUtils.ts index f6a191fb..9409f292 100644 --- a/test/nbrowser/gristUtils.ts +++ b/test/nbrowser/gristUtils.ts @@ -2056,7 +2056,13 @@ export class Session { isFirstLogin?: boolean, showTips?: boolean, skipTutorial?: boolean, // By default true + userName?: string, + email?: string, retainExistingLogin?: boolean}) { + if (options?.userName) { + this.settings.name = options.userName; + this.settings.email = options.email || ''; + } // Optimize testing a little bit, so if we are already logged in as the expected // user on the expected org, and there are no options set, we can just continue. if (!options && await this.isLoggedInCorrectly()) { return this; } @@ -3150,20 +3156,26 @@ export async function availableBehaviorOptions() { return list; } -export function withComments() { - let oldEnv: testUtils.EnvironmentSnapshot; +/** + * Restarts the server ensuring that it is run with the given environment variables. + * If variables are already set, the server is not restarted. + * + * Useful for local testing of features that depend on environment variables, as it avoids the need + * to restart the server when those variables are already set. + */ +export function withEnvironmentSnapshot(vars: Record) { + let oldEnv: testUtils.EnvironmentSnapshot|null = null; before(async () => { - if (process.env.COMMENTS !== 'true') { - oldEnv = new testUtils.EnvironmentSnapshot(); - process.env.COMMENTS = 'true'; - await server.restart(); - } + // Test if the vars are already set, and if so, skip. + if (Object.keys(vars).every(k => process.env[k] === vars[k])) { return; } + oldEnv = new testUtils.EnvironmentSnapshot(); + Object.assign(process.env, vars); + await server.restart(); }); after(async () => { - if (oldEnv) { - oldEnv.restore(); - await server.restart(); - } + if (!oldEnv) { return; } + oldEnv.restore(); + await server.restart(); }); }