mirror of
https://github.com/gristlabs/grist-core.git
synced 2024-10-27 20:44:07 +00:00
(core) Porting back AI formula backend
Summary: This is a backend part for the formula AI. Test Plan: New tests Reviewers: paulfitz Reviewed By: paulfitz Subscribers: cyprien Differential Revision: https://phab.getgrist.com/D3786
This commit is contained in:
parent
ef0a55ced1
commit
6e3f0f2b35
@ -32,6 +32,7 @@ export class DocComm extends Disposable implements ActiveDocAPI {
|
||||
public addAttachments = this._wrapMethod("addAttachments");
|
||||
public findColFromValues = this._wrapMethod("findColFromValues");
|
||||
public getFormulaError = this._wrapMethod("getFormulaError");
|
||||
public getAssistance = this._wrapMethod("getAssistance");
|
||||
public fetchURL = this._wrapMethod("fetchURL");
|
||||
public autocomplete = this._wrapMethod("autocomplete");
|
||||
public removeInstanceFromDoc = this._wrapMethod("removeInstanceFromDoc");
|
||||
|
@ -318,6 +318,11 @@ export interface ActiveDocAPI {
|
||||
*/
|
||||
getFormulaError(tableId: string, colId: string, rowId: number): Promise<CellValue>;
|
||||
|
||||
/**
|
||||
* Generates a formula code based on the AI suggestions, it also modifies the column and sets it type to a formula.
|
||||
*/
|
||||
getAssistance(tableId: string, colId: string, description: string): Promise<void>;
|
||||
|
||||
/**
|
||||
* Fetch content at a url.
|
||||
*/
|
||||
|
11
app/common/AssistancePrompts.ts
Normal file
11
app/common/AssistancePrompts.ts
Normal file
@ -0,0 +1,11 @@
|
||||
import {DocAction} from 'app/common/DocActions';
|
||||
|
||||
export interface Prompt {
|
||||
tableId: string;
|
||||
colId: string
|
||||
description: string;
|
||||
}
|
||||
|
||||
export interface Suggestion {
|
||||
suggestedActions: DocAction[];
|
||||
}
|
@ -14,6 +14,7 @@ import {
|
||||
} from 'app/common/ActionBundle';
|
||||
import {ActionGroup, MinimalActionGroup} from 'app/common/ActionGroup';
|
||||
import {ActionSummary} from "app/common/ActionSummary";
|
||||
import {Prompt, Suggestion} from "app/common/AssistancePrompts";
|
||||
import {
|
||||
AclResources,
|
||||
AclTableDescription,
|
||||
@ -80,6 +81,7 @@ import {parseUserAction} from 'app/common/ValueParser';
|
||||
import {ParseOptions} from 'app/plugin/FileParserAPI';
|
||||
import {AccessTokenOptions, AccessTokenResult, GristDocAPI} from 'app/plugin/GristAPI';
|
||||
import {compileAclFormula} from 'app/server/lib/ACLFormula';
|
||||
import {sendForCompletion} from 'app/server/lib/Assistance';
|
||||
import {Authorizer} from 'app/server/lib/Authorizer';
|
||||
import {checksumFile} from 'app/server/lib/checksumFile';
|
||||
import {Client} from 'app/server/lib/Client';
|
||||
@ -1313,6 +1315,24 @@ export class ActiveDoc extends EventEmitter {
|
||||
return this._pyCall('autocomplete', txt, tableId, columnId, rowId, user.toJSON());
|
||||
}
|
||||
|
||||
public async getAssistance(docSession: DocSession, userPrompt: Prompt): Promise<Suggestion> {
|
||||
// Making a prompt can leak names of tables and columns.
|
||||
if (!await this._granularAccess.canScanData(docSession)) {
|
||||
throw new Error("Permission denied");
|
||||
}
|
||||
await this.waitForInitialization();
|
||||
const { tableId, colId, description } = userPrompt;
|
||||
const prompt = await this._pyCall('get_formula_prompt', tableId, colId, description);
|
||||
this._log.debug(docSession, 'getAssistance prompt', {prompt});
|
||||
const completion = await sendForCompletion(prompt);
|
||||
this._log.debug(docSession, 'getAssistance completion', {completion});
|
||||
const formula = await this._pyCall('convert_formula_completion', completion);
|
||||
const action: DocAction = ["ModifyColumn", tableId, colId, {formula}];
|
||||
return {
|
||||
suggestedActions: [action],
|
||||
};
|
||||
}
|
||||
|
||||
public fetchURL(docSession: DocSession, url: string, options?: FetchUrlOptions): Promise<UploadResult> {
|
||||
return fetchURL(url, this.makeAccessId(docSession.authorizer.getUserId()), options);
|
||||
}
|
||||
|
110
app/server/lib/Assistance.ts
Normal file
110
app/server/lib/Assistance.ts
Normal file
@ -0,0 +1,110 @@
|
||||
/**
|
||||
* Module with functions used for AI formula assistance.
|
||||
*/
|
||||
|
||||
import {delay} from 'app/common/delay';
|
||||
import log from 'app/server/lib/log';
|
||||
import fetch, { Response as FetchResponse} from 'node-fetch';
|
||||
|
||||
|
||||
export async function sendForCompletion(prompt: string): Promise<string> {
|
||||
let completion: string|null = null;
|
||||
if (process.env.OPENAI_API_KEY) {
|
||||
completion = await sendForCompletionOpenAI(prompt);
|
||||
}
|
||||
if (process.env.HUGGINGFACE_API_KEY) {
|
||||
completion = await sendForCompletionHuggingFace(prompt);
|
||||
}
|
||||
if (completion === null) {
|
||||
throw new Error("Please set OPENAI_API_KEY or HUGGINGFACE_API_KEY (and optionally COMPLETION_MODEL)");
|
||||
}
|
||||
log.debug(`Received completion:`, {completion});
|
||||
completion = completion.split(/\n {4}[^ ]/)[0];
|
||||
return completion;
|
||||
}
|
||||
|
||||
|
||||
async function sendForCompletionOpenAI(prompt: string) {
|
||||
const apiKey = process.env.OPENAI_API_KEY;
|
||||
if (!apiKey) {
|
||||
throw new Error("OPENAI_API_KEY not set");
|
||||
}
|
||||
const response = await fetch(
|
||||
"https://api.openai.com/v1/completions",
|
||||
{
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Authorization": `Bearer ${apiKey}`,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
prompt,
|
||||
max_tokens: 150,
|
||||
temperature: 0,
|
||||
// COMPLETION_MODEL of `code-davinci-002` may be better if you have access to it.
|
||||
model: process.env.COMPLETION_MODEL || "text-davinci-002",
|
||||
stop: ["\n\n"],
|
||||
}),
|
||||
},
|
||||
);
|
||||
if (response.status !== 200) {
|
||||
log.error(`OpenAI API returned ${response.status}: ${await response.text()}`);
|
||||
throw new Error(`OpenAI API returned status ${response.status}`);
|
||||
}
|
||||
const result = await response.json();
|
||||
const completion = result.choices[0].text;
|
||||
return completion;
|
||||
}
|
||||
|
||||
async function sendForCompletionHuggingFace(prompt: string) {
|
||||
const apiKey = process.env.HUGGINGFACE_API_KEY;
|
||||
if (!apiKey) {
|
||||
throw new Error("HUGGINGFACE_API_KEY not set");
|
||||
}
|
||||
// COMPLETION_MODEL values I've tried:
|
||||
// - codeparrot/codeparrot
|
||||
// - NinedayWang/PolyCoder-2.7B
|
||||
// - NovelAI/genji-python-6B
|
||||
let completionUrl = process.env.COMPLETION_URL;
|
||||
if (!completionUrl) {
|
||||
if (process.env.COMPLETION_MODEL) {
|
||||
completionUrl = `https://api-inference.huggingface.co/models/${process.env.COMPLETION_MODEL}`;
|
||||
} else {
|
||||
completionUrl = 'https://api-inference.huggingface.co/models/NovelAI/genji-python-6B';
|
||||
}
|
||||
}
|
||||
let retries: number = 0;
|
||||
let response!: FetchResponse;
|
||||
while (retries++ < 3) {
|
||||
response = await fetch(
|
||||
completionUrl,
|
||||
{
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Authorization": `Bearer ${apiKey}`,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
inputs: prompt,
|
||||
parameters: {
|
||||
return_full_text: false,
|
||||
max_new_tokens: 50,
|
||||
},
|
||||
}),
|
||||
},
|
||||
);
|
||||
if (response.status === 503) {
|
||||
log.error(`Sleeping for 10s - HuggingFace API returned ${response.status}: ${await response.text()}`);
|
||||
await delay(10000);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
if (response.status !== 200) {
|
||||
const text = await response.text();
|
||||
log.error(`HuggingFace API returned ${response.status}: ${text}`);
|
||||
throw new Error(`HuggingFace API returned status ${response.status}: ${text}`);
|
||||
}
|
||||
const result = await response.json();
|
||||
const completion = result[0].generated_text;
|
||||
return completion.split('\n\n')[0];
|
||||
}
|
@ -110,6 +110,7 @@ export class DocWorker {
|
||||
applyUserActionsById: activeDocMethod.bind(null, 'editors', 'applyUserActionsById'),
|
||||
findColFromValues: activeDocMethod.bind(null, 'viewers', 'findColFromValues'),
|
||||
getFormulaError: activeDocMethod.bind(null, 'viewers', 'getFormulaError'),
|
||||
getAssistance: activeDocMethod.bind(null, 'editors', 'getAssistance'),
|
||||
importFiles: activeDocMethod.bind(null, 'editors', 'importFiles'),
|
||||
finishImportFiles: activeDocMethod.bind(null, 'editors', 'finishImportFiles'),
|
||||
cancelImportFiles: activeDocMethod.bind(null, 'editors', 'cancelImportFiles'),
|
||||
|
35
documentation/llm.md
Normal file
35
documentation/llm.md
Normal file
@ -0,0 +1,35 @@
|
||||
# Using Large Language Models with Grist
|
||||
|
||||
In this experimental Grist feature, originally developed by Alex Hall,
|
||||
you can hook up an AI model such as OpenAI's Codex to write formulas for
|
||||
you. Here's how.
|
||||
|
||||
First, you need an API key. You'll have best results currently with an
|
||||
OpenAI model. Visit https://openai.com/api/ and prepare a key, then
|
||||
store it in an environment variable `OPENAI_API_KEY`.
|
||||
|
||||
Alternatively, there are many non-proprietary models hosted on Hugging Face.
|
||||
At the time of writing, none can compare with OpenAI for use with Grist.
|
||||
Things can change quickly in the world of AI though. So instead of OpenAI,
|
||||
you can visit https://huggingface.co/ and prepare a key, then
|
||||
store it in an environment variable `HUGGINGFACE_API_KEY`.
|
||||
|
||||
That's all the configuration needed!
|
||||
|
||||
Currently it is only a backend feature, we are still working on the UI for it.
|
||||
|
||||
## Trying other models
|
||||
|
||||
The model used will default to `text-davinci-002` for OpenAI. You can
|
||||
get better results by setting an environment variable `COMPLETION_MODEL` to
|
||||
`code-davinci-002` if you have access to that model.
|
||||
|
||||
The model used will default to `NovelAI/genji-python-6B` for
|
||||
Hugging Face. There's no particularly great model for this application,
|
||||
but you can try other models by setting an environment variable
|
||||
`COMPLETION_MODEL` to `codeparrot/codeparrot` or
|
||||
`NinedayWang/PolyCoder-2.7B` or similar.
|
||||
|
||||
If you are hosting a model yourself, host it as Hugging Face does,
|
||||
and use `COMPLETION_URL` rather than `COMPLETION_MODEL` to
|
||||
point to the model on your own server rather than Hugging Face.
|
186
sandbox/grist/formula_prompt.py
Normal file
186
sandbox/grist/formula_prompt.py
Normal file
@ -0,0 +1,186 @@
|
||||
import json
|
||||
import textwrap
|
||||
|
||||
import six
|
||||
|
||||
from column import is_visible_column, BaseReferenceColumn
|
||||
from objtypes import RaisedException
|
||||
import records
|
||||
|
||||
|
||||
def column_type(engine, table_id, col_id):
|
||||
col_rec = engine.docmodel.get_column_rec(table_id, col_id)
|
||||
typ = col_rec.type
|
||||
parts = typ.split(":")
|
||||
if parts[0] == "Ref":
|
||||
return parts[1]
|
||||
elif parts[0] == "RefList":
|
||||
return "List[{}]".format(parts[1])
|
||||
elif typ == "Choice":
|
||||
return choices(col_rec)
|
||||
elif typ == "ChoiceList":
|
||||
return "Tuple[{}, ...]".format(choices(col_rec))
|
||||
elif typ == "Any":
|
||||
table = engine.tables[table_id]
|
||||
col = table.get_column(col_id)
|
||||
values = [col.raw_get(row_id) for row_id in table.row_ids]
|
||||
return values_type(values)
|
||||
else:
|
||||
return dict(
|
||||
Text="str",
|
||||
Numeric="float",
|
||||
Int="int",
|
||||
Bool="bool",
|
||||
Date="datetime.date",
|
||||
DateTime="datetime.datetime",
|
||||
Any="Any",
|
||||
Attachments="Any",
|
||||
)[parts[0]]
|
||||
|
||||
|
||||
def choices(col_rec):
|
||||
try:
|
||||
widget_options = json.loads(col_rec.widgetOptions)
|
||||
return "Literal{}".format(widget_options["choices"])
|
||||
except (ValueError, KeyError):
|
||||
return 'str'
|
||||
|
||||
|
||||
def values_type(values):
|
||||
types = set(type(v) for v in values) - {RaisedException}
|
||||
optional = type(None) in types # pylint: disable=unidiomatic-typecheck
|
||||
types.discard(type(None))
|
||||
|
||||
if types == {int, float}:
|
||||
types = {float}
|
||||
|
||||
if len(types) != 1:
|
||||
return "Any"
|
||||
|
||||
[typ] = types
|
||||
val = next(v for v in values if isinstance(v, typ))
|
||||
|
||||
if isinstance(val, records.Record):
|
||||
type_name = val._table.table_id
|
||||
elif isinstance(val, records.RecordSet):
|
||||
type_name = "List[{}]".format(val._table.table_id)
|
||||
elif isinstance(val, list):
|
||||
type_name = "List[{}]".format(values_type(val))
|
||||
elif isinstance(val, set):
|
||||
type_name = "Set[{}]".format(values_type(val))
|
||||
elif isinstance(val, tuple):
|
||||
type_name = "Tuple[{}, ...]".format(values_type(val))
|
||||
elif isinstance(val, dict):
|
||||
type_name = "Dict[{}, {}]".format(values_type(val.keys()), values_type(val.values()))
|
||||
else:
|
||||
type_name = typ.__name__
|
||||
|
||||
if optional:
|
||||
type_name = "Optional[{}]".format(type_name)
|
||||
|
||||
return type_name
|
||||
|
||||
|
||||
def referenced_tables(engine, table_id):
|
||||
result = set()
|
||||
queue = [table_id]
|
||||
while queue:
|
||||
cur_table_id = queue.pop()
|
||||
if cur_table_id in result:
|
||||
continue
|
||||
result.add(cur_table_id)
|
||||
for col_id, col in visible_columns(engine, cur_table_id):
|
||||
if isinstance(col, BaseReferenceColumn):
|
||||
target_table_id = col._target_table.table_id
|
||||
if not target_table_id.startswith("_"):
|
||||
queue.append(target_table_id)
|
||||
return result - {table_id}
|
||||
|
||||
def all_other_tables(engine, table_id):
|
||||
result = set(t for t in engine.tables.keys() if not t.startswith('_grist'))
|
||||
return result - {table_id} - {'GristDocTour'}
|
||||
|
||||
def visible_columns(engine, table_id):
|
||||
return [
|
||||
(col_id, col)
|
||||
for col_id, col in engine.tables[table_id].all_columns.items()
|
||||
if is_visible_column(col_id)
|
||||
]
|
||||
|
||||
|
||||
def class_schema(engine, table_id, exclude_col_id=None, lookups=False):
|
||||
result = "@dataclass\nclass {}:\n".format(table_id)
|
||||
|
||||
if lookups:
|
||||
|
||||
# Build a lookupRecords and lookupOne method for each table, providing some arguments hints
|
||||
# for the columns that are visible.
|
||||
lookupRecords_args = []
|
||||
lookupOne_args = []
|
||||
for col_id, col in visible_columns(engine, table_id):
|
||||
if col_id != exclude_col_id:
|
||||
lookupOne_args.append(col_id + '=None')
|
||||
lookupRecords_args.append('%s=%s' % (col_id, col_id))
|
||||
lookupOne_args.append('sort_by=None')
|
||||
lookupRecords_args.append('sort_by=sort_by')
|
||||
lookupOne_args_line = ', '.join(lookupOne_args)
|
||||
lookupRecords_args_line = ', '.join(lookupRecords_args)
|
||||
|
||||
result += " def __len__(self):\n"
|
||||
result += " return len(%s.lookupRecords())\n" % table_id
|
||||
result += " @staticmethod\n"
|
||||
result += " def lookupRecords(%s) -> List[%s]:\n" % (lookupOne_args_line, table_id)
|
||||
result += " # ...\n"
|
||||
result += " @staticmethod\n"
|
||||
result += " def lookupOne(%s) -> %s:\n" % (lookupOne_args_line, table_id)
|
||||
result += " '''\n"
|
||||
result += " Filter for one result matching the keys provided.\n"
|
||||
result += " To control order, use e.g. `sort_by='Key' or `sort_by='-Key'`.\n"
|
||||
result += " '''\n"
|
||||
result += " return %s.lookupRecords(%s)[0]\n" % (table_id, lookupRecords_args_line)
|
||||
result += "\n"
|
||||
|
||||
for col_id, col in visible_columns(engine, table_id):
|
||||
if col_id != exclude_col_id:
|
||||
result += " {}: {}\n".format(col_id, column_type(engine, table_id, col_id))
|
||||
result += "\n"
|
||||
return result
|
||||
|
||||
|
||||
def get_formula_prompt(engine, table_id, col_id, description,
|
||||
include_all_tables=True,
|
||||
lookups=True):
|
||||
result = ""
|
||||
other_tables = (all_other_tables(engine, table_id)
|
||||
if include_all_tables else referenced_tables(engine, table_id))
|
||||
for other_table_id in sorted(other_tables):
|
||||
result += class_schema(engine, other_table_id, lookups)
|
||||
|
||||
result += class_schema(engine, table_id, col_id, lookups)
|
||||
|
||||
return_type = column_type(engine, table_id, col_id)
|
||||
result += " @property\n"
|
||||
result += " # rec is alias for self\n"
|
||||
result += " def {}(rec) -> {}:\n".format(col_id, return_type)
|
||||
result += ' """\n'
|
||||
result += '{}\n'.format(indent(description, " "))
|
||||
result += ' """\n'
|
||||
return result
|
||||
|
||||
def indent(text, prefix, predicate=None):
|
||||
"""
|
||||
Copied from https://github.com/python/cpython/blob/main/Lib/textwrap.py for python2 compatibility.
|
||||
"""
|
||||
if six.PY3:
|
||||
return textwrap.indent(text, prefix, predicate) # pylint: disable = no-member
|
||||
if predicate is None:
|
||||
def predicate(line):
|
||||
return line.strip()
|
||||
def prefixed_lines():
|
||||
for line in text.splitlines(True):
|
||||
yield (prefix + line if predicate(line) else line)
|
||||
return ''.join(prefixed_lines())
|
||||
|
||||
def convert_completion(completion):
|
||||
result = textwrap.dedent(completion)
|
||||
return result
|
@ -16,6 +16,7 @@ import six
|
||||
|
||||
import actions
|
||||
import engine
|
||||
import formula_prompt
|
||||
import migrations
|
||||
import schema
|
||||
import useractions
|
||||
@ -135,6 +136,14 @@ def run(sandbox):
|
||||
def get_formula_error(table_id, col_id, row_id):
|
||||
return objtypes.encode_object(eng.get_formula_error(table_id, col_id, row_id))
|
||||
|
||||
@export
|
||||
def get_formula_prompt(table_id, col_id, description):
|
||||
return formula_prompt.get_formula_prompt(eng, table_id, col_id, description)
|
||||
|
||||
@export
|
||||
def convert_formula_completion(completion):
|
||||
return formula_prompt.convert_completion(completion)
|
||||
|
||||
export(parse_acl_formula)
|
||||
export(eng.load_empty)
|
||||
export(eng.load_done)
|
||||
|
217
sandbox/grist/test_formula_prompt.py
Normal file
217
sandbox/grist/test_formula_prompt.py
Normal file
@ -0,0 +1,217 @@
|
||||
import unittest
|
||||
import six
|
||||
|
||||
import test_engine
|
||||
import testutil
|
||||
|
||||
from formula_prompt import (
|
||||
values_type, column_type, referenced_tables, get_formula_prompt,
|
||||
)
|
||||
from objtypes import RaisedException
|
||||
from records import Record, RecordSet
|
||||
|
||||
|
||||
class FakeTable(object):
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
table_id = "Table1"
|
||||
_identity_relation = None
|
||||
|
||||
|
||||
@unittest.skipUnless(six.PY3, "Python 3 only")
|
||||
class TestFormulaPrompt(test_engine.EngineTestCase):
|
||||
def test_values_type(self):
|
||||
self.assertEqual(values_type([1, 2, 3]), "int")
|
||||
self.assertEqual(values_type([1.0, 2.0, 3.0]), "float")
|
||||
self.assertEqual(values_type([1, 2, 3.0]), "float")
|
||||
|
||||
self.assertEqual(values_type([1, 2, None]), "Optional[int]")
|
||||
self.assertEqual(values_type([1, 2, 3.0, None]), "Optional[float]")
|
||||
|
||||
self.assertEqual(values_type([1, RaisedException(None), 3]), "int")
|
||||
self.assertEqual(values_type([1, RaisedException(None), None]), "Optional[int]")
|
||||
|
||||
self.assertEqual(values_type(["1", "2", "3"]), "str")
|
||||
self.assertEqual(values_type([1, 2, "3"]), "Any")
|
||||
self.assertEqual(values_type([1, 2, "3", None]), "Any")
|
||||
|
||||
self.assertEqual(values_type([
|
||||
Record(FakeTable(), None),
|
||||
Record(FakeTable(), None),
|
||||
]), "Table1")
|
||||
self.assertEqual(values_type([
|
||||
Record(FakeTable(), None),
|
||||
Record(FakeTable(), None),
|
||||
None,
|
||||
]), "Optional[Table1]")
|
||||
|
||||
self.assertEqual(values_type([
|
||||
RecordSet(FakeTable(), None),
|
||||
RecordSet(FakeTable(), None),
|
||||
]), "List[Table1]")
|
||||
self.assertEqual(values_type([
|
||||
RecordSet(FakeTable(), None),
|
||||
RecordSet(FakeTable(), None),
|
||||
None,
|
||||
]), "Optional[List[Table1]]")
|
||||
|
||||
self.assertEqual(values_type([[1, 2, 3]]), "List[int]")
|
||||
self.assertEqual(values_type([[1, 2, 3], None]), "Optional[List[int]]")
|
||||
self.assertEqual(values_type([[1, 2, None]]), "List[Optional[int]]")
|
||||
self.assertEqual(values_type([[1, 2, None], None]), "Optional[List[Optional[int]]]")
|
||||
self.assertEqual(values_type([[1, 2, "3"]]), "List[Any]")
|
||||
|
||||
self.assertEqual(values_type([{1, 2, 3}]), "Set[int]")
|
||||
self.assertEqual(values_type([(1, 2, 3)]), "Tuple[int, ...]")
|
||||
self.assertEqual(values_type([{1: ["2"]}]), "Dict[int, List[str]]")
|
||||
|
||||
def assert_column_type(self, col_id, expected_type):
|
||||
self.assertEqual(column_type(self.engine, "Table2", col_id), expected_type)
|
||||
|
||||
def assert_prompt(self, table_name, col_id, expected_prompt):
|
||||
prompt = get_formula_prompt(self.engine, table_name, col_id, "description here",
|
||||
include_all_tables=False, lookups=False)
|
||||
# print(prompt)
|
||||
self.assertEqual(prompt, expected_prompt)
|
||||
|
||||
def test_column_type(self):
|
||||
sample = testutil.parse_test_sample({
|
||||
"SCHEMA": [
|
||||
[1, "Table2", [
|
||||
[1, "text", "Text", False, "", "", ""],
|
||||
[2, "numeric", "Numeric", False, "", "", ""],
|
||||
[3, "int", "Int", False, "", "", ""],
|
||||
[4, "bool", "Bool", False, "", "", ""],
|
||||
[5, "date", "Date", False, "", "", ""],
|
||||
[6, "datetime", "DateTime", False, "", "", ""],
|
||||
[7, "attachments", "Attachments", False, "", "", ""],
|
||||
[8, "ref", "Ref:Table2", False, "", "", ""],
|
||||
[9, "reflist", "RefList:Table2", False, "", "", ""],
|
||||
[10, "choice", "Choice", False, "", "", '{"choices": ["a", "b", "c"]}'],
|
||||
[11, "choicelist", "ChoiceList", False, "", "", '{"choices": ["x", "y", "z"]}'],
|
||||
[12, "ref_formula", "Any", True, "$ref or None", "", ""],
|
||||
[13, "numeric_formula", "Any", True, "1 / $numeric", "", ""],
|
||||
[14, "new_formula", "Numeric", True, "'to be generated...'", "", ""],
|
||||
]],
|
||||
],
|
||||
"DATA": {
|
||||
"Table2": [
|
||||
["id", "numeric", "ref"],
|
||||
[1, 0, 0],
|
||||
[2, 1, 1],
|
||||
],
|
||||
},
|
||||
})
|
||||
self.load_sample(sample)
|
||||
|
||||
self.assert_column_type("text", "str")
|
||||
self.assert_column_type("numeric", "float")
|
||||
self.assert_column_type("int", "int")
|
||||
self.assert_column_type("bool", "bool")
|
||||
self.assert_column_type("date", "datetime.date")
|
||||
self.assert_column_type("datetime", "datetime.datetime")
|
||||
self.assert_column_type("attachments", "Any")
|
||||
self.assert_column_type("ref", "Table2")
|
||||
self.assert_column_type("reflist", "List[Table2]")
|
||||
self.assert_column_type("choice", "Literal['a', 'b', 'c']")
|
||||
self.assert_column_type("choicelist", "Tuple[Literal['x', 'y', 'z'], ...]")
|
||||
self.assert_column_type("ref_formula", "Optional[Table2]")
|
||||
self.assert_column_type("numeric_formula", "float")
|
||||
|
||||
self.assertEqual(referenced_tables(self.engine, "Table2"), set())
|
||||
|
||||
self.assert_prompt("Table2", "new_formula",
|
||||
'''\
|
||||
@dataclass
|
||||
class Table2:
|
||||
text: str
|
||||
numeric: float
|
||||
int: int
|
||||
bool: bool
|
||||
date: datetime.date
|
||||
datetime: datetime.datetime
|
||||
attachments: Any
|
||||
ref: Table2
|
||||
reflist: List[Table2]
|
||||
choice: Literal['a', 'b', 'c']
|
||||
choicelist: Tuple[Literal['x', 'y', 'z'], ...]
|
||||
ref_formula: Optional[Table2]
|
||||
numeric_formula: float
|
||||
|
||||
@property
|
||||
# rec is alias for self
|
||||
def new_formula(rec) -> float:
|
||||
"""
|
||||
description here
|
||||
"""
|
||||
''')
|
||||
|
||||
def test_get_formula_prompt(self):
|
||||
sample = testutil.parse_test_sample({
|
||||
"SCHEMA": [
|
||||
[1, "Table1", [
|
||||
[1, "text", "Text", False, "", "", ""],
|
||||
]],
|
||||
[2, "Table2", [
|
||||
[2, "ref", "Ref:Table1", False, "", "", ""],
|
||||
]],
|
||||
[3, "Table3", [
|
||||
[3, "reflist", "RefList:Table2", False, "", "", ""],
|
||||
]],
|
||||
],
|
||||
"DATA": {},
|
||||
})
|
||||
self.load_sample(sample)
|
||||
self.assertEqual(referenced_tables(self.engine, "Table3"), {"Table1", "Table2"})
|
||||
self.assertEqual(referenced_tables(self.engine, "Table2"), {"Table1"})
|
||||
self.assertEqual(referenced_tables(self.engine, "Table1"), set())
|
||||
|
||||
self.assert_prompt("Table1", "text", '''\
|
||||
@dataclass
|
||||
class Table1:
|
||||
|
||||
@property
|
||||
# rec is alias for self
|
||||
def text(rec) -> str:
|
||||
"""
|
||||
description here
|
||||
"""
|
||||
''')
|
||||
|
||||
self.assert_prompt("Table2", "ref", '''\
|
||||
@dataclass
|
||||
class Table1:
|
||||
text: str
|
||||
|
||||
@dataclass
|
||||
class Table2:
|
||||
|
||||
@property
|
||||
# rec is alias for self
|
||||
def ref(rec) -> Table1:
|
||||
"""
|
||||
description here
|
||||
"""
|
||||
''')
|
||||
|
||||
self.assert_prompt("Table3", "reflist", '''\
|
||||
@dataclass
|
||||
class Table1:
|
||||
text: str
|
||||
|
||||
@dataclass
|
||||
class Table2:
|
||||
ref: Table1
|
||||
|
||||
@dataclass
|
||||
class Table3:
|
||||
|
||||
@property
|
||||
# rec is alias for self
|
||||
def reflist(rec) -> List[Table2]:
|
||||
"""
|
||||
description here
|
||||
"""
|
||||
''')
|
Loading…
Reference in New Issue
Block a user