(core) Porting back AI formula backend

Summary: This is a backend part for the formula AI.

Test Plan: New tests

Reviewers: paulfitz

Reviewed By: paulfitz

Subscribers: cyprien

Differential Revision: https://phab.getgrist.com/D3786
This commit is contained in:
Jarosław Sadziński 2023-02-08 16:46:34 +01:00
parent ef0a55ced1
commit 6e3f0f2b35
10 changed files with 595 additions and 0 deletions

View File

@ -32,6 +32,7 @@ export class DocComm extends Disposable implements ActiveDocAPI {
public addAttachments = this._wrapMethod("addAttachments"); public addAttachments = this._wrapMethod("addAttachments");
public findColFromValues = this._wrapMethod("findColFromValues"); public findColFromValues = this._wrapMethod("findColFromValues");
public getFormulaError = this._wrapMethod("getFormulaError"); public getFormulaError = this._wrapMethod("getFormulaError");
public getAssistance = this._wrapMethod("getAssistance");
public fetchURL = this._wrapMethod("fetchURL"); public fetchURL = this._wrapMethod("fetchURL");
public autocomplete = this._wrapMethod("autocomplete"); public autocomplete = this._wrapMethod("autocomplete");
public removeInstanceFromDoc = this._wrapMethod("removeInstanceFromDoc"); public removeInstanceFromDoc = this._wrapMethod("removeInstanceFromDoc");

View File

@ -318,6 +318,11 @@ export interface ActiveDocAPI {
*/ */
getFormulaError(tableId: string, colId: string, rowId: number): Promise<CellValue>; getFormulaError(tableId: string, colId: string, rowId: number): Promise<CellValue>;
/**
* Generates a formula code based on the AI suggestions, it also modifies the column and sets it type to a formula.
*/
getAssistance(tableId: string, colId: string, description: string): Promise<void>;
/** /**
* Fetch content at a url. * Fetch content at a url.
*/ */

View File

@ -0,0 +1,11 @@
import {DocAction} from 'app/common/DocActions';
export interface Prompt {
tableId: string;
colId: string
description: string;
}
export interface Suggestion {
suggestedActions: DocAction[];
}

View File

@ -14,6 +14,7 @@ import {
} from 'app/common/ActionBundle'; } from 'app/common/ActionBundle';
import {ActionGroup, MinimalActionGroup} from 'app/common/ActionGroup'; import {ActionGroup, MinimalActionGroup} from 'app/common/ActionGroup';
import {ActionSummary} from "app/common/ActionSummary"; import {ActionSummary} from "app/common/ActionSummary";
import {Prompt, Suggestion} from "app/common/AssistancePrompts";
import { import {
AclResources, AclResources,
AclTableDescription, AclTableDescription,
@ -80,6 +81,7 @@ import {parseUserAction} from 'app/common/ValueParser';
import {ParseOptions} from 'app/plugin/FileParserAPI'; import {ParseOptions} from 'app/plugin/FileParserAPI';
import {AccessTokenOptions, AccessTokenResult, GristDocAPI} from 'app/plugin/GristAPI'; import {AccessTokenOptions, AccessTokenResult, GristDocAPI} from 'app/plugin/GristAPI';
import {compileAclFormula} from 'app/server/lib/ACLFormula'; import {compileAclFormula} from 'app/server/lib/ACLFormula';
import {sendForCompletion} from 'app/server/lib/Assistance';
import {Authorizer} from 'app/server/lib/Authorizer'; import {Authorizer} from 'app/server/lib/Authorizer';
import {checksumFile} from 'app/server/lib/checksumFile'; import {checksumFile} from 'app/server/lib/checksumFile';
import {Client} from 'app/server/lib/Client'; import {Client} from 'app/server/lib/Client';
@ -1313,6 +1315,24 @@ export class ActiveDoc extends EventEmitter {
return this._pyCall('autocomplete', txt, tableId, columnId, rowId, user.toJSON()); return this._pyCall('autocomplete', txt, tableId, columnId, rowId, user.toJSON());
} }
public async getAssistance(docSession: DocSession, userPrompt: Prompt): Promise<Suggestion> {
// Making a prompt can leak names of tables and columns.
if (!await this._granularAccess.canScanData(docSession)) {
throw new Error("Permission denied");
}
await this.waitForInitialization();
const { tableId, colId, description } = userPrompt;
const prompt = await this._pyCall('get_formula_prompt', tableId, colId, description);
this._log.debug(docSession, 'getAssistance prompt', {prompt});
const completion = await sendForCompletion(prompt);
this._log.debug(docSession, 'getAssistance completion', {completion});
const formula = await this._pyCall('convert_formula_completion', completion);
const action: DocAction = ["ModifyColumn", tableId, colId, {formula}];
return {
suggestedActions: [action],
};
}
public fetchURL(docSession: DocSession, url: string, options?: FetchUrlOptions): Promise<UploadResult> { public fetchURL(docSession: DocSession, url: string, options?: FetchUrlOptions): Promise<UploadResult> {
return fetchURL(url, this.makeAccessId(docSession.authorizer.getUserId()), options); return fetchURL(url, this.makeAccessId(docSession.authorizer.getUserId()), options);
} }

View File

@ -0,0 +1,110 @@
/**
* Module with functions used for AI formula assistance.
*/
import {delay} from 'app/common/delay';
import log from 'app/server/lib/log';
import fetch, { Response as FetchResponse} from 'node-fetch';
export async function sendForCompletion(prompt: string): Promise<string> {
let completion: string|null = null;
if (process.env.OPENAI_API_KEY) {
completion = await sendForCompletionOpenAI(prompt);
}
if (process.env.HUGGINGFACE_API_KEY) {
completion = await sendForCompletionHuggingFace(prompt);
}
if (completion === null) {
throw new Error("Please set OPENAI_API_KEY or HUGGINGFACE_API_KEY (and optionally COMPLETION_MODEL)");
}
log.debug(`Received completion:`, {completion});
completion = completion.split(/\n {4}[^ ]/)[0];
return completion;
}
async function sendForCompletionOpenAI(prompt: string) {
const apiKey = process.env.OPENAI_API_KEY;
if (!apiKey) {
throw new Error("OPENAI_API_KEY not set");
}
const response = await fetch(
"https://api.openai.com/v1/completions",
{
method: "POST",
headers: {
"Authorization": `Bearer ${apiKey}`,
"Content-Type": "application/json",
},
body: JSON.stringify({
prompt,
max_tokens: 150,
temperature: 0,
// COMPLETION_MODEL of `code-davinci-002` may be better if you have access to it.
model: process.env.COMPLETION_MODEL || "text-davinci-002",
stop: ["\n\n"],
}),
},
);
if (response.status !== 200) {
log.error(`OpenAI API returned ${response.status}: ${await response.text()}`);
throw new Error(`OpenAI API returned status ${response.status}`);
}
const result = await response.json();
const completion = result.choices[0].text;
return completion;
}
async function sendForCompletionHuggingFace(prompt: string) {
const apiKey = process.env.HUGGINGFACE_API_KEY;
if (!apiKey) {
throw new Error("HUGGINGFACE_API_KEY not set");
}
// COMPLETION_MODEL values I've tried:
// - codeparrot/codeparrot
// - NinedayWang/PolyCoder-2.7B
// - NovelAI/genji-python-6B
let completionUrl = process.env.COMPLETION_URL;
if (!completionUrl) {
if (process.env.COMPLETION_MODEL) {
completionUrl = `https://api-inference.huggingface.co/models/${process.env.COMPLETION_MODEL}`;
} else {
completionUrl = 'https://api-inference.huggingface.co/models/NovelAI/genji-python-6B';
}
}
let retries: number = 0;
let response!: FetchResponse;
while (retries++ < 3) {
response = await fetch(
completionUrl,
{
method: "POST",
headers: {
"Authorization": `Bearer ${apiKey}`,
"Content-Type": "application/json",
},
body: JSON.stringify({
inputs: prompt,
parameters: {
return_full_text: false,
max_new_tokens: 50,
},
}),
},
);
if (response.status === 503) {
log.error(`Sleeping for 10s - HuggingFace API returned ${response.status}: ${await response.text()}`);
await delay(10000);
continue;
}
}
if (response.status !== 200) {
const text = await response.text();
log.error(`HuggingFace API returned ${response.status}: ${text}`);
throw new Error(`HuggingFace API returned status ${response.status}: ${text}`);
}
const result = await response.json();
const completion = result[0].generated_text;
return completion.split('\n\n')[0];
}

View File

@ -110,6 +110,7 @@ export class DocWorker {
applyUserActionsById: activeDocMethod.bind(null, 'editors', 'applyUserActionsById'), applyUserActionsById: activeDocMethod.bind(null, 'editors', 'applyUserActionsById'),
findColFromValues: activeDocMethod.bind(null, 'viewers', 'findColFromValues'), findColFromValues: activeDocMethod.bind(null, 'viewers', 'findColFromValues'),
getFormulaError: activeDocMethod.bind(null, 'viewers', 'getFormulaError'), getFormulaError: activeDocMethod.bind(null, 'viewers', 'getFormulaError'),
getAssistance: activeDocMethod.bind(null, 'editors', 'getAssistance'),
importFiles: activeDocMethod.bind(null, 'editors', 'importFiles'), importFiles: activeDocMethod.bind(null, 'editors', 'importFiles'),
finishImportFiles: activeDocMethod.bind(null, 'editors', 'finishImportFiles'), finishImportFiles: activeDocMethod.bind(null, 'editors', 'finishImportFiles'),
cancelImportFiles: activeDocMethod.bind(null, 'editors', 'cancelImportFiles'), cancelImportFiles: activeDocMethod.bind(null, 'editors', 'cancelImportFiles'),

35
documentation/llm.md Normal file
View File

@ -0,0 +1,35 @@
# Using Large Language Models with Grist
In this experimental Grist feature, originally developed by Alex Hall,
you can hook up an AI model such as OpenAI's Codex to write formulas for
you. Here's how.
First, you need an API key. You'll have best results currently with an
OpenAI model. Visit https://openai.com/api/ and prepare a key, then
store it in an environment variable `OPENAI_API_KEY`.
Alternatively, there are many non-proprietary models hosted on Hugging Face.
At the time of writing, none can compare with OpenAI for use with Grist.
Things can change quickly in the world of AI though. So instead of OpenAI,
you can visit https://huggingface.co/ and prepare a key, then
store it in an environment variable `HUGGINGFACE_API_KEY`.
That's all the configuration needed!
Currently it is only a backend feature, we are still working on the UI for it.
## Trying other models
The model used will default to `text-davinci-002` for OpenAI. You can
get better results by setting an environment variable `COMPLETION_MODEL` to
`code-davinci-002` if you have access to that model.
The model used will default to `NovelAI/genji-python-6B` for
Hugging Face. There's no particularly great model for this application,
but you can try other models by setting an environment variable
`COMPLETION_MODEL` to `codeparrot/codeparrot` or
`NinedayWang/PolyCoder-2.7B` or similar.
If you are hosting a model yourself, host it as Hugging Face does,
and use `COMPLETION_URL` rather than `COMPLETION_MODEL` to
point to the model on your own server rather than Hugging Face.

View File

@ -0,0 +1,186 @@
import json
import textwrap
import six
from column import is_visible_column, BaseReferenceColumn
from objtypes import RaisedException
import records
def column_type(engine, table_id, col_id):
col_rec = engine.docmodel.get_column_rec(table_id, col_id)
typ = col_rec.type
parts = typ.split(":")
if parts[0] == "Ref":
return parts[1]
elif parts[0] == "RefList":
return "List[{}]".format(parts[1])
elif typ == "Choice":
return choices(col_rec)
elif typ == "ChoiceList":
return "Tuple[{}, ...]".format(choices(col_rec))
elif typ == "Any":
table = engine.tables[table_id]
col = table.get_column(col_id)
values = [col.raw_get(row_id) for row_id in table.row_ids]
return values_type(values)
else:
return dict(
Text="str",
Numeric="float",
Int="int",
Bool="bool",
Date="datetime.date",
DateTime="datetime.datetime",
Any="Any",
Attachments="Any",
)[parts[0]]
def choices(col_rec):
try:
widget_options = json.loads(col_rec.widgetOptions)
return "Literal{}".format(widget_options["choices"])
except (ValueError, KeyError):
return 'str'
def values_type(values):
types = set(type(v) for v in values) - {RaisedException}
optional = type(None) in types # pylint: disable=unidiomatic-typecheck
types.discard(type(None))
if types == {int, float}:
types = {float}
if len(types) != 1:
return "Any"
[typ] = types
val = next(v for v in values if isinstance(v, typ))
if isinstance(val, records.Record):
type_name = val._table.table_id
elif isinstance(val, records.RecordSet):
type_name = "List[{}]".format(val._table.table_id)
elif isinstance(val, list):
type_name = "List[{}]".format(values_type(val))
elif isinstance(val, set):
type_name = "Set[{}]".format(values_type(val))
elif isinstance(val, tuple):
type_name = "Tuple[{}, ...]".format(values_type(val))
elif isinstance(val, dict):
type_name = "Dict[{}, {}]".format(values_type(val.keys()), values_type(val.values()))
else:
type_name = typ.__name__
if optional:
type_name = "Optional[{}]".format(type_name)
return type_name
def referenced_tables(engine, table_id):
result = set()
queue = [table_id]
while queue:
cur_table_id = queue.pop()
if cur_table_id in result:
continue
result.add(cur_table_id)
for col_id, col in visible_columns(engine, cur_table_id):
if isinstance(col, BaseReferenceColumn):
target_table_id = col._target_table.table_id
if not target_table_id.startswith("_"):
queue.append(target_table_id)
return result - {table_id}
def all_other_tables(engine, table_id):
result = set(t for t in engine.tables.keys() if not t.startswith('_grist'))
return result - {table_id} - {'GristDocTour'}
def visible_columns(engine, table_id):
return [
(col_id, col)
for col_id, col in engine.tables[table_id].all_columns.items()
if is_visible_column(col_id)
]
def class_schema(engine, table_id, exclude_col_id=None, lookups=False):
result = "@dataclass\nclass {}:\n".format(table_id)
if lookups:
# Build a lookupRecords and lookupOne method for each table, providing some arguments hints
# for the columns that are visible.
lookupRecords_args = []
lookupOne_args = []
for col_id, col in visible_columns(engine, table_id):
if col_id != exclude_col_id:
lookupOne_args.append(col_id + '=None')
lookupRecords_args.append('%s=%s' % (col_id, col_id))
lookupOne_args.append('sort_by=None')
lookupRecords_args.append('sort_by=sort_by')
lookupOne_args_line = ', '.join(lookupOne_args)
lookupRecords_args_line = ', '.join(lookupRecords_args)
result += " def __len__(self):\n"
result += " return len(%s.lookupRecords())\n" % table_id
result += " @staticmethod\n"
result += " def lookupRecords(%s) -> List[%s]:\n" % (lookupOne_args_line, table_id)
result += " # ...\n"
result += " @staticmethod\n"
result += " def lookupOne(%s) -> %s:\n" % (lookupOne_args_line, table_id)
result += " '''\n"
result += " Filter for one result matching the keys provided.\n"
result += " To control order, use e.g. `sort_by='Key' or `sort_by='-Key'`.\n"
result += " '''\n"
result += " return %s.lookupRecords(%s)[0]\n" % (table_id, lookupRecords_args_line)
result += "\n"
for col_id, col in visible_columns(engine, table_id):
if col_id != exclude_col_id:
result += " {}: {}\n".format(col_id, column_type(engine, table_id, col_id))
result += "\n"
return result
def get_formula_prompt(engine, table_id, col_id, description,
include_all_tables=True,
lookups=True):
result = ""
other_tables = (all_other_tables(engine, table_id)
if include_all_tables else referenced_tables(engine, table_id))
for other_table_id in sorted(other_tables):
result += class_schema(engine, other_table_id, lookups)
result += class_schema(engine, table_id, col_id, lookups)
return_type = column_type(engine, table_id, col_id)
result += " @property\n"
result += " # rec is alias for self\n"
result += " def {}(rec) -> {}:\n".format(col_id, return_type)
result += ' """\n'
result += '{}\n'.format(indent(description, " "))
result += ' """\n'
return result
def indent(text, prefix, predicate=None):
"""
Copied from https://github.com/python/cpython/blob/main/Lib/textwrap.py for python2 compatibility.
"""
if six.PY3:
return textwrap.indent(text, prefix, predicate) # pylint: disable = no-member
if predicate is None:
def predicate(line):
return line.strip()
def prefixed_lines():
for line in text.splitlines(True):
yield (prefix + line if predicate(line) else line)
return ''.join(prefixed_lines())
def convert_completion(completion):
result = textwrap.dedent(completion)
return result

View File

@ -16,6 +16,7 @@ import six
import actions import actions
import engine import engine
import formula_prompt
import migrations import migrations
import schema import schema
import useractions import useractions
@ -135,6 +136,14 @@ def run(sandbox):
def get_formula_error(table_id, col_id, row_id): def get_formula_error(table_id, col_id, row_id):
return objtypes.encode_object(eng.get_formula_error(table_id, col_id, row_id)) return objtypes.encode_object(eng.get_formula_error(table_id, col_id, row_id))
@export
def get_formula_prompt(table_id, col_id, description):
return formula_prompt.get_formula_prompt(eng, table_id, col_id, description)
@export
def convert_formula_completion(completion):
return formula_prompt.convert_completion(completion)
export(parse_acl_formula) export(parse_acl_formula)
export(eng.load_empty) export(eng.load_empty)
export(eng.load_done) export(eng.load_done)

View File

@ -0,0 +1,217 @@
import unittest
import six
import test_engine
import testutil
from formula_prompt import (
values_type, column_type, referenced_tables, get_formula_prompt,
)
from objtypes import RaisedException
from records import Record, RecordSet
class FakeTable(object):
def __init__(self):
pass
table_id = "Table1"
_identity_relation = None
@unittest.skipUnless(six.PY3, "Python 3 only")
class TestFormulaPrompt(test_engine.EngineTestCase):
def test_values_type(self):
self.assertEqual(values_type([1, 2, 3]), "int")
self.assertEqual(values_type([1.0, 2.0, 3.0]), "float")
self.assertEqual(values_type([1, 2, 3.0]), "float")
self.assertEqual(values_type([1, 2, None]), "Optional[int]")
self.assertEqual(values_type([1, 2, 3.0, None]), "Optional[float]")
self.assertEqual(values_type([1, RaisedException(None), 3]), "int")
self.assertEqual(values_type([1, RaisedException(None), None]), "Optional[int]")
self.assertEqual(values_type(["1", "2", "3"]), "str")
self.assertEqual(values_type([1, 2, "3"]), "Any")
self.assertEqual(values_type([1, 2, "3", None]), "Any")
self.assertEqual(values_type([
Record(FakeTable(), None),
Record(FakeTable(), None),
]), "Table1")
self.assertEqual(values_type([
Record(FakeTable(), None),
Record(FakeTable(), None),
None,
]), "Optional[Table1]")
self.assertEqual(values_type([
RecordSet(FakeTable(), None),
RecordSet(FakeTable(), None),
]), "List[Table1]")
self.assertEqual(values_type([
RecordSet(FakeTable(), None),
RecordSet(FakeTable(), None),
None,
]), "Optional[List[Table1]]")
self.assertEqual(values_type([[1, 2, 3]]), "List[int]")
self.assertEqual(values_type([[1, 2, 3], None]), "Optional[List[int]]")
self.assertEqual(values_type([[1, 2, None]]), "List[Optional[int]]")
self.assertEqual(values_type([[1, 2, None], None]), "Optional[List[Optional[int]]]")
self.assertEqual(values_type([[1, 2, "3"]]), "List[Any]")
self.assertEqual(values_type([{1, 2, 3}]), "Set[int]")
self.assertEqual(values_type([(1, 2, 3)]), "Tuple[int, ...]")
self.assertEqual(values_type([{1: ["2"]}]), "Dict[int, List[str]]")
def assert_column_type(self, col_id, expected_type):
self.assertEqual(column_type(self.engine, "Table2", col_id), expected_type)
def assert_prompt(self, table_name, col_id, expected_prompt):
prompt = get_formula_prompt(self.engine, table_name, col_id, "description here",
include_all_tables=False, lookups=False)
# print(prompt)
self.assertEqual(prompt, expected_prompt)
def test_column_type(self):
sample = testutil.parse_test_sample({
"SCHEMA": [
[1, "Table2", [
[1, "text", "Text", False, "", "", ""],
[2, "numeric", "Numeric", False, "", "", ""],
[3, "int", "Int", False, "", "", ""],
[4, "bool", "Bool", False, "", "", ""],
[5, "date", "Date", False, "", "", ""],
[6, "datetime", "DateTime", False, "", "", ""],
[7, "attachments", "Attachments", False, "", "", ""],
[8, "ref", "Ref:Table2", False, "", "", ""],
[9, "reflist", "RefList:Table2", False, "", "", ""],
[10, "choice", "Choice", False, "", "", '{"choices": ["a", "b", "c"]}'],
[11, "choicelist", "ChoiceList", False, "", "", '{"choices": ["x", "y", "z"]}'],
[12, "ref_formula", "Any", True, "$ref or None", "", ""],
[13, "numeric_formula", "Any", True, "1 / $numeric", "", ""],
[14, "new_formula", "Numeric", True, "'to be generated...'", "", ""],
]],
],
"DATA": {
"Table2": [
["id", "numeric", "ref"],
[1, 0, 0],
[2, 1, 1],
],
},
})
self.load_sample(sample)
self.assert_column_type("text", "str")
self.assert_column_type("numeric", "float")
self.assert_column_type("int", "int")
self.assert_column_type("bool", "bool")
self.assert_column_type("date", "datetime.date")
self.assert_column_type("datetime", "datetime.datetime")
self.assert_column_type("attachments", "Any")
self.assert_column_type("ref", "Table2")
self.assert_column_type("reflist", "List[Table2]")
self.assert_column_type("choice", "Literal['a', 'b', 'c']")
self.assert_column_type("choicelist", "Tuple[Literal['x', 'y', 'z'], ...]")
self.assert_column_type("ref_formula", "Optional[Table2]")
self.assert_column_type("numeric_formula", "float")
self.assertEqual(referenced_tables(self.engine, "Table2"), set())
self.assert_prompt("Table2", "new_formula",
'''\
@dataclass
class Table2:
text: str
numeric: float
int: int
bool: bool
date: datetime.date
datetime: datetime.datetime
attachments: Any
ref: Table2
reflist: List[Table2]
choice: Literal['a', 'b', 'c']
choicelist: Tuple[Literal['x', 'y', 'z'], ...]
ref_formula: Optional[Table2]
numeric_formula: float
@property
# rec is alias for self
def new_formula(rec) -> float:
"""
description here
"""
''')
def test_get_formula_prompt(self):
sample = testutil.parse_test_sample({
"SCHEMA": [
[1, "Table1", [
[1, "text", "Text", False, "", "", ""],
]],
[2, "Table2", [
[2, "ref", "Ref:Table1", False, "", "", ""],
]],
[3, "Table3", [
[3, "reflist", "RefList:Table2", False, "", "", ""],
]],
],
"DATA": {},
})
self.load_sample(sample)
self.assertEqual(referenced_tables(self.engine, "Table3"), {"Table1", "Table2"})
self.assertEqual(referenced_tables(self.engine, "Table2"), {"Table1"})
self.assertEqual(referenced_tables(self.engine, "Table1"), set())
self.assert_prompt("Table1", "text", '''\
@dataclass
class Table1:
@property
# rec is alias for self
def text(rec) -> str:
"""
description here
"""
''')
self.assert_prompt("Table2", "ref", '''\
@dataclass
class Table1:
text: str
@dataclass
class Table2:
@property
# rec is alias for self
def ref(rec) -> Table1:
"""
description here
"""
''')
self.assert_prompt("Table3", "reflist", '''\
@dataclass
class Table1:
text: str
@dataclass
class Table2:
ref: Table1
@dataclass
class Table3:
@property
# rec is alias for self
def reflist(rec) -> List[Table2]:
"""
description here
"""
''')