mirror of
https://github.com/gristlabs/grist-core.git
synced 2024-10-27 20:44:07 +00:00
(core) Porting back AI formula backend
Summary: This is a backend part for the formula AI. Test Plan: New tests Reviewers: paulfitz Reviewed By: paulfitz Subscribers: cyprien Differential Revision: https://phab.getgrist.com/D3786
This commit is contained in:
parent
ef0a55ced1
commit
6e3f0f2b35
@ -32,6 +32,7 @@ export class DocComm extends Disposable implements ActiveDocAPI {
|
|||||||
public addAttachments = this._wrapMethod("addAttachments");
|
public addAttachments = this._wrapMethod("addAttachments");
|
||||||
public findColFromValues = this._wrapMethod("findColFromValues");
|
public findColFromValues = this._wrapMethod("findColFromValues");
|
||||||
public getFormulaError = this._wrapMethod("getFormulaError");
|
public getFormulaError = this._wrapMethod("getFormulaError");
|
||||||
|
public getAssistance = this._wrapMethod("getAssistance");
|
||||||
public fetchURL = this._wrapMethod("fetchURL");
|
public fetchURL = this._wrapMethod("fetchURL");
|
||||||
public autocomplete = this._wrapMethod("autocomplete");
|
public autocomplete = this._wrapMethod("autocomplete");
|
||||||
public removeInstanceFromDoc = this._wrapMethod("removeInstanceFromDoc");
|
public removeInstanceFromDoc = this._wrapMethod("removeInstanceFromDoc");
|
||||||
|
@ -318,6 +318,11 @@ export interface ActiveDocAPI {
|
|||||||
*/
|
*/
|
||||||
getFormulaError(tableId: string, colId: string, rowId: number): Promise<CellValue>;
|
getFormulaError(tableId: string, colId: string, rowId: number): Promise<CellValue>;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generates a formula code based on the AI suggestions, it also modifies the column and sets it type to a formula.
|
||||||
|
*/
|
||||||
|
getAssistance(tableId: string, colId: string, description: string): Promise<void>;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Fetch content at a url.
|
* Fetch content at a url.
|
||||||
*/
|
*/
|
||||||
|
11
app/common/AssistancePrompts.ts
Normal file
11
app/common/AssistancePrompts.ts
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
import {DocAction} from 'app/common/DocActions';
|
||||||
|
|
||||||
|
export interface Prompt {
|
||||||
|
tableId: string;
|
||||||
|
colId: string
|
||||||
|
description: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface Suggestion {
|
||||||
|
suggestedActions: DocAction[];
|
||||||
|
}
|
@ -14,6 +14,7 @@ import {
|
|||||||
} from 'app/common/ActionBundle';
|
} from 'app/common/ActionBundle';
|
||||||
import {ActionGroup, MinimalActionGroup} from 'app/common/ActionGroup';
|
import {ActionGroup, MinimalActionGroup} from 'app/common/ActionGroup';
|
||||||
import {ActionSummary} from "app/common/ActionSummary";
|
import {ActionSummary} from "app/common/ActionSummary";
|
||||||
|
import {Prompt, Suggestion} from "app/common/AssistancePrompts";
|
||||||
import {
|
import {
|
||||||
AclResources,
|
AclResources,
|
||||||
AclTableDescription,
|
AclTableDescription,
|
||||||
@ -80,6 +81,7 @@ import {parseUserAction} from 'app/common/ValueParser';
|
|||||||
import {ParseOptions} from 'app/plugin/FileParserAPI';
|
import {ParseOptions} from 'app/plugin/FileParserAPI';
|
||||||
import {AccessTokenOptions, AccessTokenResult, GristDocAPI} from 'app/plugin/GristAPI';
|
import {AccessTokenOptions, AccessTokenResult, GristDocAPI} from 'app/plugin/GristAPI';
|
||||||
import {compileAclFormula} from 'app/server/lib/ACLFormula';
|
import {compileAclFormula} from 'app/server/lib/ACLFormula';
|
||||||
|
import {sendForCompletion} from 'app/server/lib/Assistance';
|
||||||
import {Authorizer} from 'app/server/lib/Authorizer';
|
import {Authorizer} from 'app/server/lib/Authorizer';
|
||||||
import {checksumFile} from 'app/server/lib/checksumFile';
|
import {checksumFile} from 'app/server/lib/checksumFile';
|
||||||
import {Client} from 'app/server/lib/Client';
|
import {Client} from 'app/server/lib/Client';
|
||||||
@ -1313,6 +1315,24 @@ export class ActiveDoc extends EventEmitter {
|
|||||||
return this._pyCall('autocomplete', txt, tableId, columnId, rowId, user.toJSON());
|
return this._pyCall('autocomplete', txt, tableId, columnId, rowId, user.toJSON());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public async getAssistance(docSession: DocSession, userPrompt: Prompt): Promise<Suggestion> {
|
||||||
|
// Making a prompt can leak names of tables and columns.
|
||||||
|
if (!await this._granularAccess.canScanData(docSession)) {
|
||||||
|
throw new Error("Permission denied");
|
||||||
|
}
|
||||||
|
await this.waitForInitialization();
|
||||||
|
const { tableId, colId, description } = userPrompt;
|
||||||
|
const prompt = await this._pyCall('get_formula_prompt', tableId, colId, description);
|
||||||
|
this._log.debug(docSession, 'getAssistance prompt', {prompt});
|
||||||
|
const completion = await sendForCompletion(prompt);
|
||||||
|
this._log.debug(docSession, 'getAssistance completion', {completion});
|
||||||
|
const formula = await this._pyCall('convert_formula_completion', completion);
|
||||||
|
const action: DocAction = ["ModifyColumn", tableId, colId, {formula}];
|
||||||
|
return {
|
||||||
|
suggestedActions: [action],
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
public fetchURL(docSession: DocSession, url: string, options?: FetchUrlOptions): Promise<UploadResult> {
|
public fetchURL(docSession: DocSession, url: string, options?: FetchUrlOptions): Promise<UploadResult> {
|
||||||
return fetchURL(url, this.makeAccessId(docSession.authorizer.getUserId()), options);
|
return fetchURL(url, this.makeAccessId(docSession.authorizer.getUserId()), options);
|
||||||
}
|
}
|
||||||
|
110
app/server/lib/Assistance.ts
Normal file
110
app/server/lib/Assistance.ts
Normal file
@ -0,0 +1,110 @@
|
|||||||
|
/**
|
||||||
|
* Module with functions used for AI formula assistance.
|
||||||
|
*/
|
||||||
|
|
||||||
|
import {delay} from 'app/common/delay';
|
||||||
|
import log from 'app/server/lib/log';
|
||||||
|
import fetch, { Response as FetchResponse} from 'node-fetch';
|
||||||
|
|
||||||
|
|
||||||
|
export async function sendForCompletion(prompt: string): Promise<string> {
|
||||||
|
let completion: string|null = null;
|
||||||
|
if (process.env.OPENAI_API_KEY) {
|
||||||
|
completion = await sendForCompletionOpenAI(prompt);
|
||||||
|
}
|
||||||
|
if (process.env.HUGGINGFACE_API_KEY) {
|
||||||
|
completion = await sendForCompletionHuggingFace(prompt);
|
||||||
|
}
|
||||||
|
if (completion === null) {
|
||||||
|
throw new Error("Please set OPENAI_API_KEY or HUGGINGFACE_API_KEY (and optionally COMPLETION_MODEL)");
|
||||||
|
}
|
||||||
|
log.debug(`Received completion:`, {completion});
|
||||||
|
completion = completion.split(/\n {4}[^ ]/)[0];
|
||||||
|
return completion;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async function sendForCompletionOpenAI(prompt: string) {
|
||||||
|
const apiKey = process.env.OPENAI_API_KEY;
|
||||||
|
if (!apiKey) {
|
||||||
|
throw new Error("OPENAI_API_KEY not set");
|
||||||
|
}
|
||||||
|
const response = await fetch(
|
||||||
|
"https://api.openai.com/v1/completions",
|
||||||
|
{
|
||||||
|
method: "POST",
|
||||||
|
headers: {
|
||||||
|
"Authorization": `Bearer ${apiKey}`,
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
body: JSON.stringify({
|
||||||
|
prompt,
|
||||||
|
max_tokens: 150,
|
||||||
|
temperature: 0,
|
||||||
|
// COMPLETION_MODEL of `code-davinci-002` may be better if you have access to it.
|
||||||
|
model: process.env.COMPLETION_MODEL || "text-davinci-002",
|
||||||
|
stop: ["\n\n"],
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
);
|
||||||
|
if (response.status !== 200) {
|
||||||
|
log.error(`OpenAI API returned ${response.status}: ${await response.text()}`);
|
||||||
|
throw new Error(`OpenAI API returned status ${response.status}`);
|
||||||
|
}
|
||||||
|
const result = await response.json();
|
||||||
|
const completion = result.choices[0].text;
|
||||||
|
return completion;
|
||||||
|
}
|
||||||
|
|
||||||
|
async function sendForCompletionHuggingFace(prompt: string) {
|
||||||
|
const apiKey = process.env.HUGGINGFACE_API_KEY;
|
||||||
|
if (!apiKey) {
|
||||||
|
throw new Error("HUGGINGFACE_API_KEY not set");
|
||||||
|
}
|
||||||
|
// COMPLETION_MODEL values I've tried:
|
||||||
|
// - codeparrot/codeparrot
|
||||||
|
// - NinedayWang/PolyCoder-2.7B
|
||||||
|
// - NovelAI/genji-python-6B
|
||||||
|
let completionUrl = process.env.COMPLETION_URL;
|
||||||
|
if (!completionUrl) {
|
||||||
|
if (process.env.COMPLETION_MODEL) {
|
||||||
|
completionUrl = `https://api-inference.huggingface.co/models/${process.env.COMPLETION_MODEL}`;
|
||||||
|
} else {
|
||||||
|
completionUrl = 'https://api-inference.huggingface.co/models/NovelAI/genji-python-6B';
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let retries: number = 0;
|
||||||
|
let response!: FetchResponse;
|
||||||
|
while (retries++ < 3) {
|
||||||
|
response = await fetch(
|
||||||
|
completionUrl,
|
||||||
|
{
|
||||||
|
method: "POST",
|
||||||
|
headers: {
|
||||||
|
"Authorization": `Bearer ${apiKey}`,
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
body: JSON.stringify({
|
||||||
|
inputs: prompt,
|
||||||
|
parameters: {
|
||||||
|
return_full_text: false,
|
||||||
|
max_new_tokens: 50,
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
);
|
||||||
|
if (response.status === 503) {
|
||||||
|
log.error(`Sleeping for 10s - HuggingFace API returned ${response.status}: ${await response.text()}`);
|
||||||
|
await delay(10000);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (response.status !== 200) {
|
||||||
|
const text = await response.text();
|
||||||
|
log.error(`HuggingFace API returned ${response.status}: ${text}`);
|
||||||
|
throw new Error(`HuggingFace API returned status ${response.status}: ${text}`);
|
||||||
|
}
|
||||||
|
const result = await response.json();
|
||||||
|
const completion = result[0].generated_text;
|
||||||
|
return completion.split('\n\n')[0];
|
||||||
|
}
|
@ -110,6 +110,7 @@ export class DocWorker {
|
|||||||
applyUserActionsById: activeDocMethod.bind(null, 'editors', 'applyUserActionsById'),
|
applyUserActionsById: activeDocMethod.bind(null, 'editors', 'applyUserActionsById'),
|
||||||
findColFromValues: activeDocMethod.bind(null, 'viewers', 'findColFromValues'),
|
findColFromValues: activeDocMethod.bind(null, 'viewers', 'findColFromValues'),
|
||||||
getFormulaError: activeDocMethod.bind(null, 'viewers', 'getFormulaError'),
|
getFormulaError: activeDocMethod.bind(null, 'viewers', 'getFormulaError'),
|
||||||
|
getAssistance: activeDocMethod.bind(null, 'editors', 'getAssistance'),
|
||||||
importFiles: activeDocMethod.bind(null, 'editors', 'importFiles'),
|
importFiles: activeDocMethod.bind(null, 'editors', 'importFiles'),
|
||||||
finishImportFiles: activeDocMethod.bind(null, 'editors', 'finishImportFiles'),
|
finishImportFiles: activeDocMethod.bind(null, 'editors', 'finishImportFiles'),
|
||||||
cancelImportFiles: activeDocMethod.bind(null, 'editors', 'cancelImportFiles'),
|
cancelImportFiles: activeDocMethod.bind(null, 'editors', 'cancelImportFiles'),
|
||||||
|
35
documentation/llm.md
Normal file
35
documentation/llm.md
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
# Using Large Language Models with Grist
|
||||||
|
|
||||||
|
In this experimental Grist feature, originally developed by Alex Hall,
|
||||||
|
you can hook up an AI model such as OpenAI's Codex to write formulas for
|
||||||
|
you. Here's how.
|
||||||
|
|
||||||
|
First, you need an API key. You'll have best results currently with an
|
||||||
|
OpenAI model. Visit https://openai.com/api/ and prepare a key, then
|
||||||
|
store it in an environment variable `OPENAI_API_KEY`.
|
||||||
|
|
||||||
|
Alternatively, there are many non-proprietary models hosted on Hugging Face.
|
||||||
|
At the time of writing, none can compare with OpenAI for use with Grist.
|
||||||
|
Things can change quickly in the world of AI though. So instead of OpenAI,
|
||||||
|
you can visit https://huggingface.co/ and prepare a key, then
|
||||||
|
store it in an environment variable `HUGGINGFACE_API_KEY`.
|
||||||
|
|
||||||
|
That's all the configuration needed!
|
||||||
|
|
||||||
|
Currently it is only a backend feature, we are still working on the UI for it.
|
||||||
|
|
||||||
|
## Trying other models
|
||||||
|
|
||||||
|
The model used will default to `text-davinci-002` for OpenAI. You can
|
||||||
|
get better results by setting an environment variable `COMPLETION_MODEL` to
|
||||||
|
`code-davinci-002` if you have access to that model.
|
||||||
|
|
||||||
|
The model used will default to `NovelAI/genji-python-6B` for
|
||||||
|
Hugging Face. There's no particularly great model for this application,
|
||||||
|
but you can try other models by setting an environment variable
|
||||||
|
`COMPLETION_MODEL` to `codeparrot/codeparrot` or
|
||||||
|
`NinedayWang/PolyCoder-2.7B` or similar.
|
||||||
|
|
||||||
|
If you are hosting a model yourself, host it as Hugging Face does,
|
||||||
|
and use `COMPLETION_URL` rather than `COMPLETION_MODEL` to
|
||||||
|
point to the model on your own server rather than Hugging Face.
|
186
sandbox/grist/formula_prompt.py
Normal file
186
sandbox/grist/formula_prompt.py
Normal file
@ -0,0 +1,186 @@
|
|||||||
|
import json
|
||||||
|
import textwrap
|
||||||
|
|
||||||
|
import six
|
||||||
|
|
||||||
|
from column import is_visible_column, BaseReferenceColumn
|
||||||
|
from objtypes import RaisedException
|
||||||
|
import records
|
||||||
|
|
||||||
|
|
||||||
|
def column_type(engine, table_id, col_id):
|
||||||
|
col_rec = engine.docmodel.get_column_rec(table_id, col_id)
|
||||||
|
typ = col_rec.type
|
||||||
|
parts = typ.split(":")
|
||||||
|
if parts[0] == "Ref":
|
||||||
|
return parts[1]
|
||||||
|
elif parts[0] == "RefList":
|
||||||
|
return "List[{}]".format(parts[1])
|
||||||
|
elif typ == "Choice":
|
||||||
|
return choices(col_rec)
|
||||||
|
elif typ == "ChoiceList":
|
||||||
|
return "Tuple[{}, ...]".format(choices(col_rec))
|
||||||
|
elif typ == "Any":
|
||||||
|
table = engine.tables[table_id]
|
||||||
|
col = table.get_column(col_id)
|
||||||
|
values = [col.raw_get(row_id) for row_id in table.row_ids]
|
||||||
|
return values_type(values)
|
||||||
|
else:
|
||||||
|
return dict(
|
||||||
|
Text="str",
|
||||||
|
Numeric="float",
|
||||||
|
Int="int",
|
||||||
|
Bool="bool",
|
||||||
|
Date="datetime.date",
|
||||||
|
DateTime="datetime.datetime",
|
||||||
|
Any="Any",
|
||||||
|
Attachments="Any",
|
||||||
|
)[parts[0]]
|
||||||
|
|
||||||
|
|
||||||
|
def choices(col_rec):
|
||||||
|
try:
|
||||||
|
widget_options = json.loads(col_rec.widgetOptions)
|
||||||
|
return "Literal{}".format(widget_options["choices"])
|
||||||
|
except (ValueError, KeyError):
|
||||||
|
return 'str'
|
||||||
|
|
||||||
|
|
||||||
|
def values_type(values):
|
||||||
|
types = set(type(v) for v in values) - {RaisedException}
|
||||||
|
optional = type(None) in types # pylint: disable=unidiomatic-typecheck
|
||||||
|
types.discard(type(None))
|
||||||
|
|
||||||
|
if types == {int, float}:
|
||||||
|
types = {float}
|
||||||
|
|
||||||
|
if len(types) != 1:
|
||||||
|
return "Any"
|
||||||
|
|
||||||
|
[typ] = types
|
||||||
|
val = next(v for v in values if isinstance(v, typ))
|
||||||
|
|
||||||
|
if isinstance(val, records.Record):
|
||||||
|
type_name = val._table.table_id
|
||||||
|
elif isinstance(val, records.RecordSet):
|
||||||
|
type_name = "List[{}]".format(val._table.table_id)
|
||||||
|
elif isinstance(val, list):
|
||||||
|
type_name = "List[{}]".format(values_type(val))
|
||||||
|
elif isinstance(val, set):
|
||||||
|
type_name = "Set[{}]".format(values_type(val))
|
||||||
|
elif isinstance(val, tuple):
|
||||||
|
type_name = "Tuple[{}, ...]".format(values_type(val))
|
||||||
|
elif isinstance(val, dict):
|
||||||
|
type_name = "Dict[{}, {}]".format(values_type(val.keys()), values_type(val.values()))
|
||||||
|
else:
|
||||||
|
type_name = typ.__name__
|
||||||
|
|
||||||
|
if optional:
|
||||||
|
type_name = "Optional[{}]".format(type_name)
|
||||||
|
|
||||||
|
return type_name
|
||||||
|
|
||||||
|
|
||||||
|
def referenced_tables(engine, table_id):
|
||||||
|
result = set()
|
||||||
|
queue = [table_id]
|
||||||
|
while queue:
|
||||||
|
cur_table_id = queue.pop()
|
||||||
|
if cur_table_id in result:
|
||||||
|
continue
|
||||||
|
result.add(cur_table_id)
|
||||||
|
for col_id, col in visible_columns(engine, cur_table_id):
|
||||||
|
if isinstance(col, BaseReferenceColumn):
|
||||||
|
target_table_id = col._target_table.table_id
|
||||||
|
if not target_table_id.startswith("_"):
|
||||||
|
queue.append(target_table_id)
|
||||||
|
return result - {table_id}
|
||||||
|
|
||||||
|
def all_other_tables(engine, table_id):
|
||||||
|
result = set(t for t in engine.tables.keys() if not t.startswith('_grist'))
|
||||||
|
return result - {table_id} - {'GristDocTour'}
|
||||||
|
|
||||||
|
def visible_columns(engine, table_id):
|
||||||
|
return [
|
||||||
|
(col_id, col)
|
||||||
|
for col_id, col in engine.tables[table_id].all_columns.items()
|
||||||
|
if is_visible_column(col_id)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def class_schema(engine, table_id, exclude_col_id=None, lookups=False):
|
||||||
|
result = "@dataclass\nclass {}:\n".format(table_id)
|
||||||
|
|
||||||
|
if lookups:
|
||||||
|
|
||||||
|
# Build a lookupRecords and lookupOne method for each table, providing some arguments hints
|
||||||
|
# for the columns that are visible.
|
||||||
|
lookupRecords_args = []
|
||||||
|
lookupOne_args = []
|
||||||
|
for col_id, col in visible_columns(engine, table_id):
|
||||||
|
if col_id != exclude_col_id:
|
||||||
|
lookupOne_args.append(col_id + '=None')
|
||||||
|
lookupRecords_args.append('%s=%s' % (col_id, col_id))
|
||||||
|
lookupOne_args.append('sort_by=None')
|
||||||
|
lookupRecords_args.append('sort_by=sort_by')
|
||||||
|
lookupOne_args_line = ', '.join(lookupOne_args)
|
||||||
|
lookupRecords_args_line = ', '.join(lookupRecords_args)
|
||||||
|
|
||||||
|
result += " def __len__(self):\n"
|
||||||
|
result += " return len(%s.lookupRecords())\n" % table_id
|
||||||
|
result += " @staticmethod\n"
|
||||||
|
result += " def lookupRecords(%s) -> List[%s]:\n" % (lookupOne_args_line, table_id)
|
||||||
|
result += " # ...\n"
|
||||||
|
result += " @staticmethod\n"
|
||||||
|
result += " def lookupOne(%s) -> %s:\n" % (lookupOne_args_line, table_id)
|
||||||
|
result += " '''\n"
|
||||||
|
result += " Filter for one result matching the keys provided.\n"
|
||||||
|
result += " To control order, use e.g. `sort_by='Key' or `sort_by='-Key'`.\n"
|
||||||
|
result += " '''\n"
|
||||||
|
result += " return %s.lookupRecords(%s)[0]\n" % (table_id, lookupRecords_args_line)
|
||||||
|
result += "\n"
|
||||||
|
|
||||||
|
for col_id, col in visible_columns(engine, table_id):
|
||||||
|
if col_id != exclude_col_id:
|
||||||
|
result += " {}: {}\n".format(col_id, column_type(engine, table_id, col_id))
|
||||||
|
result += "\n"
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def get_formula_prompt(engine, table_id, col_id, description,
|
||||||
|
include_all_tables=True,
|
||||||
|
lookups=True):
|
||||||
|
result = ""
|
||||||
|
other_tables = (all_other_tables(engine, table_id)
|
||||||
|
if include_all_tables else referenced_tables(engine, table_id))
|
||||||
|
for other_table_id in sorted(other_tables):
|
||||||
|
result += class_schema(engine, other_table_id, lookups)
|
||||||
|
|
||||||
|
result += class_schema(engine, table_id, col_id, lookups)
|
||||||
|
|
||||||
|
return_type = column_type(engine, table_id, col_id)
|
||||||
|
result += " @property\n"
|
||||||
|
result += " # rec is alias for self\n"
|
||||||
|
result += " def {}(rec) -> {}:\n".format(col_id, return_type)
|
||||||
|
result += ' """\n'
|
||||||
|
result += '{}\n'.format(indent(description, " "))
|
||||||
|
result += ' """\n'
|
||||||
|
return result
|
||||||
|
|
||||||
|
def indent(text, prefix, predicate=None):
|
||||||
|
"""
|
||||||
|
Copied from https://github.com/python/cpython/blob/main/Lib/textwrap.py for python2 compatibility.
|
||||||
|
"""
|
||||||
|
if six.PY3:
|
||||||
|
return textwrap.indent(text, prefix, predicate) # pylint: disable = no-member
|
||||||
|
if predicate is None:
|
||||||
|
def predicate(line):
|
||||||
|
return line.strip()
|
||||||
|
def prefixed_lines():
|
||||||
|
for line in text.splitlines(True):
|
||||||
|
yield (prefix + line if predicate(line) else line)
|
||||||
|
return ''.join(prefixed_lines())
|
||||||
|
|
||||||
|
def convert_completion(completion):
|
||||||
|
result = textwrap.dedent(completion)
|
||||||
|
return result
|
@ -16,6 +16,7 @@ import six
|
|||||||
|
|
||||||
import actions
|
import actions
|
||||||
import engine
|
import engine
|
||||||
|
import formula_prompt
|
||||||
import migrations
|
import migrations
|
||||||
import schema
|
import schema
|
||||||
import useractions
|
import useractions
|
||||||
@ -135,6 +136,14 @@ def run(sandbox):
|
|||||||
def get_formula_error(table_id, col_id, row_id):
|
def get_formula_error(table_id, col_id, row_id):
|
||||||
return objtypes.encode_object(eng.get_formula_error(table_id, col_id, row_id))
|
return objtypes.encode_object(eng.get_formula_error(table_id, col_id, row_id))
|
||||||
|
|
||||||
|
@export
|
||||||
|
def get_formula_prompt(table_id, col_id, description):
|
||||||
|
return formula_prompt.get_formula_prompt(eng, table_id, col_id, description)
|
||||||
|
|
||||||
|
@export
|
||||||
|
def convert_formula_completion(completion):
|
||||||
|
return formula_prompt.convert_completion(completion)
|
||||||
|
|
||||||
export(parse_acl_formula)
|
export(parse_acl_formula)
|
||||||
export(eng.load_empty)
|
export(eng.load_empty)
|
||||||
export(eng.load_done)
|
export(eng.load_done)
|
||||||
|
217
sandbox/grist/test_formula_prompt.py
Normal file
217
sandbox/grist/test_formula_prompt.py
Normal file
@ -0,0 +1,217 @@
|
|||||||
|
import unittest
|
||||||
|
import six
|
||||||
|
|
||||||
|
import test_engine
|
||||||
|
import testutil
|
||||||
|
|
||||||
|
from formula_prompt import (
|
||||||
|
values_type, column_type, referenced_tables, get_formula_prompt,
|
||||||
|
)
|
||||||
|
from objtypes import RaisedException
|
||||||
|
from records import Record, RecordSet
|
||||||
|
|
||||||
|
|
||||||
|
class FakeTable(object):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
table_id = "Table1"
|
||||||
|
_identity_relation = None
|
||||||
|
|
||||||
|
|
||||||
|
@unittest.skipUnless(six.PY3, "Python 3 only")
|
||||||
|
class TestFormulaPrompt(test_engine.EngineTestCase):
|
||||||
|
def test_values_type(self):
|
||||||
|
self.assertEqual(values_type([1, 2, 3]), "int")
|
||||||
|
self.assertEqual(values_type([1.0, 2.0, 3.0]), "float")
|
||||||
|
self.assertEqual(values_type([1, 2, 3.0]), "float")
|
||||||
|
|
||||||
|
self.assertEqual(values_type([1, 2, None]), "Optional[int]")
|
||||||
|
self.assertEqual(values_type([1, 2, 3.0, None]), "Optional[float]")
|
||||||
|
|
||||||
|
self.assertEqual(values_type([1, RaisedException(None), 3]), "int")
|
||||||
|
self.assertEqual(values_type([1, RaisedException(None), None]), "Optional[int]")
|
||||||
|
|
||||||
|
self.assertEqual(values_type(["1", "2", "3"]), "str")
|
||||||
|
self.assertEqual(values_type([1, 2, "3"]), "Any")
|
||||||
|
self.assertEqual(values_type([1, 2, "3", None]), "Any")
|
||||||
|
|
||||||
|
self.assertEqual(values_type([
|
||||||
|
Record(FakeTable(), None),
|
||||||
|
Record(FakeTable(), None),
|
||||||
|
]), "Table1")
|
||||||
|
self.assertEqual(values_type([
|
||||||
|
Record(FakeTable(), None),
|
||||||
|
Record(FakeTable(), None),
|
||||||
|
None,
|
||||||
|
]), "Optional[Table1]")
|
||||||
|
|
||||||
|
self.assertEqual(values_type([
|
||||||
|
RecordSet(FakeTable(), None),
|
||||||
|
RecordSet(FakeTable(), None),
|
||||||
|
]), "List[Table1]")
|
||||||
|
self.assertEqual(values_type([
|
||||||
|
RecordSet(FakeTable(), None),
|
||||||
|
RecordSet(FakeTable(), None),
|
||||||
|
None,
|
||||||
|
]), "Optional[List[Table1]]")
|
||||||
|
|
||||||
|
self.assertEqual(values_type([[1, 2, 3]]), "List[int]")
|
||||||
|
self.assertEqual(values_type([[1, 2, 3], None]), "Optional[List[int]]")
|
||||||
|
self.assertEqual(values_type([[1, 2, None]]), "List[Optional[int]]")
|
||||||
|
self.assertEqual(values_type([[1, 2, None], None]), "Optional[List[Optional[int]]]")
|
||||||
|
self.assertEqual(values_type([[1, 2, "3"]]), "List[Any]")
|
||||||
|
|
||||||
|
self.assertEqual(values_type([{1, 2, 3}]), "Set[int]")
|
||||||
|
self.assertEqual(values_type([(1, 2, 3)]), "Tuple[int, ...]")
|
||||||
|
self.assertEqual(values_type([{1: ["2"]}]), "Dict[int, List[str]]")
|
||||||
|
|
||||||
|
def assert_column_type(self, col_id, expected_type):
|
||||||
|
self.assertEqual(column_type(self.engine, "Table2", col_id), expected_type)
|
||||||
|
|
||||||
|
def assert_prompt(self, table_name, col_id, expected_prompt):
|
||||||
|
prompt = get_formula_prompt(self.engine, table_name, col_id, "description here",
|
||||||
|
include_all_tables=False, lookups=False)
|
||||||
|
# print(prompt)
|
||||||
|
self.assertEqual(prompt, expected_prompt)
|
||||||
|
|
||||||
|
def test_column_type(self):
|
||||||
|
sample = testutil.parse_test_sample({
|
||||||
|
"SCHEMA": [
|
||||||
|
[1, "Table2", [
|
||||||
|
[1, "text", "Text", False, "", "", ""],
|
||||||
|
[2, "numeric", "Numeric", False, "", "", ""],
|
||||||
|
[3, "int", "Int", False, "", "", ""],
|
||||||
|
[4, "bool", "Bool", False, "", "", ""],
|
||||||
|
[5, "date", "Date", False, "", "", ""],
|
||||||
|
[6, "datetime", "DateTime", False, "", "", ""],
|
||||||
|
[7, "attachments", "Attachments", False, "", "", ""],
|
||||||
|
[8, "ref", "Ref:Table2", False, "", "", ""],
|
||||||
|
[9, "reflist", "RefList:Table2", False, "", "", ""],
|
||||||
|
[10, "choice", "Choice", False, "", "", '{"choices": ["a", "b", "c"]}'],
|
||||||
|
[11, "choicelist", "ChoiceList", False, "", "", '{"choices": ["x", "y", "z"]}'],
|
||||||
|
[12, "ref_formula", "Any", True, "$ref or None", "", ""],
|
||||||
|
[13, "numeric_formula", "Any", True, "1 / $numeric", "", ""],
|
||||||
|
[14, "new_formula", "Numeric", True, "'to be generated...'", "", ""],
|
||||||
|
]],
|
||||||
|
],
|
||||||
|
"DATA": {
|
||||||
|
"Table2": [
|
||||||
|
["id", "numeric", "ref"],
|
||||||
|
[1, 0, 0],
|
||||||
|
[2, 1, 1],
|
||||||
|
],
|
||||||
|
},
|
||||||
|
})
|
||||||
|
self.load_sample(sample)
|
||||||
|
|
||||||
|
self.assert_column_type("text", "str")
|
||||||
|
self.assert_column_type("numeric", "float")
|
||||||
|
self.assert_column_type("int", "int")
|
||||||
|
self.assert_column_type("bool", "bool")
|
||||||
|
self.assert_column_type("date", "datetime.date")
|
||||||
|
self.assert_column_type("datetime", "datetime.datetime")
|
||||||
|
self.assert_column_type("attachments", "Any")
|
||||||
|
self.assert_column_type("ref", "Table2")
|
||||||
|
self.assert_column_type("reflist", "List[Table2]")
|
||||||
|
self.assert_column_type("choice", "Literal['a', 'b', 'c']")
|
||||||
|
self.assert_column_type("choicelist", "Tuple[Literal['x', 'y', 'z'], ...]")
|
||||||
|
self.assert_column_type("ref_formula", "Optional[Table2]")
|
||||||
|
self.assert_column_type("numeric_formula", "float")
|
||||||
|
|
||||||
|
self.assertEqual(referenced_tables(self.engine, "Table2"), set())
|
||||||
|
|
||||||
|
self.assert_prompt("Table2", "new_formula",
|
||||||
|
'''\
|
||||||
|
@dataclass
|
||||||
|
class Table2:
|
||||||
|
text: str
|
||||||
|
numeric: float
|
||||||
|
int: int
|
||||||
|
bool: bool
|
||||||
|
date: datetime.date
|
||||||
|
datetime: datetime.datetime
|
||||||
|
attachments: Any
|
||||||
|
ref: Table2
|
||||||
|
reflist: List[Table2]
|
||||||
|
choice: Literal['a', 'b', 'c']
|
||||||
|
choicelist: Tuple[Literal['x', 'y', 'z'], ...]
|
||||||
|
ref_formula: Optional[Table2]
|
||||||
|
numeric_formula: float
|
||||||
|
|
||||||
|
@property
|
||||||
|
# rec is alias for self
|
||||||
|
def new_formula(rec) -> float:
|
||||||
|
"""
|
||||||
|
description here
|
||||||
|
"""
|
||||||
|
''')
|
||||||
|
|
||||||
|
def test_get_formula_prompt(self):
|
||||||
|
sample = testutil.parse_test_sample({
|
||||||
|
"SCHEMA": [
|
||||||
|
[1, "Table1", [
|
||||||
|
[1, "text", "Text", False, "", "", ""],
|
||||||
|
]],
|
||||||
|
[2, "Table2", [
|
||||||
|
[2, "ref", "Ref:Table1", False, "", "", ""],
|
||||||
|
]],
|
||||||
|
[3, "Table3", [
|
||||||
|
[3, "reflist", "RefList:Table2", False, "", "", ""],
|
||||||
|
]],
|
||||||
|
],
|
||||||
|
"DATA": {},
|
||||||
|
})
|
||||||
|
self.load_sample(sample)
|
||||||
|
self.assertEqual(referenced_tables(self.engine, "Table3"), {"Table1", "Table2"})
|
||||||
|
self.assertEqual(referenced_tables(self.engine, "Table2"), {"Table1"})
|
||||||
|
self.assertEqual(referenced_tables(self.engine, "Table1"), set())
|
||||||
|
|
||||||
|
self.assert_prompt("Table1", "text", '''\
|
||||||
|
@dataclass
|
||||||
|
class Table1:
|
||||||
|
|
||||||
|
@property
|
||||||
|
# rec is alias for self
|
||||||
|
def text(rec) -> str:
|
||||||
|
"""
|
||||||
|
description here
|
||||||
|
"""
|
||||||
|
''')
|
||||||
|
|
||||||
|
self.assert_prompt("Table2", "ref", '''\
|
||||||
|
@dataclass
|
||||||
|
class Table1:
|
||||||
|
text: str
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Table2:
|
||||||
|
|
||||||
|
@property
|
||||||
|
# rec is alias for self
|
||||||
|
def ref(rec) -> Table1:
|
||||||
|
"""
|
||||||
|
description here
|
||||||
|
"""
|
||||||
|
''')
|
||||||
|
|
||||||
|
self.assert_prompt("Table3", "reflist", '''\
|
||||||
|
@dataclass
|
||||||
|
class Table1:
|
||||||
|
text: str
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Table2:
|
||||||
|
ref: Table1
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Table3:
|
||||||
|
|
||||||
|
@property
|
||||||
|
# rec is alias for self
|
||||||
|
def reflist(rec) -> List[Table2]:
|
||||||
|
"""
|
||||||
|
description here
|
||||||
|
"""
|
||||||
|
''')
|
Loading…
Reference in New Issue
Block a user