diff --git a/app/client/components/DocComm.ts b/app/client/components/DocComm.ts index fc211c24..1a0dfe37 100644 --- a/app/client/components/DocComm.ts +++ b/app/client/components/DocComm.ts @@ -32,6 +32,7 @@ export class DocComm extends Disposable implements ActiveDocAPI { public addAttachments = this._wrapMethod("addAttachments"); public findColFromValues = this._wrapMethod("findColFromValues"); public getFormulaError = this._wrapMethod("getFormulaError"); + public getAssistance = this._wrapMethod("getAssistance"); public fetchURL = this._wrapMethod("fetchURL"); public autocomplete = this._wrapMethod("autocomplete"); public removeInstanceFromDoc = this._wrapMethod("removeInstanceFromDoc"); diff --git a/app/common/ActiveDocAPI.ts b/app/common/ActiveDocAPI.ts index 94ad3cd6..e2ae0dcc 100644 --- a/app/common/ActiveDocAPI.ts +++ b/app/common/ActiveDocAPI.ts @@ -318,6 +318,11 @@ export interface ActiveDocAPI { */ getFormulaError(tableId: string, colId: string, rowId: number): Promise<any>; + /** + * Generates formula code based on the AI suggestions; it also modifies the column and sets its type to a formula. + */ + getAssistance(tableId: string, colId: string, description: string): Promise<Suggestion>; + /** + * Fetch content at a url. 
*/ diff --git a/app/common/AssistancePrompts.ts b/app/common/AssistancePrompts.ts new file mode 100644 index 00000000..c0ee269e --- /dev/null +++ b/app/common/AssistancePrompts.ts @@ -0,0 +1,11 @@ +import {DocAction} from 'app/common/DocActions'; + +export interface Prompt { + tableId: string; + colId: string; + description: string; +} + +export interface Suggestion { + suggestedActions: DocAction[]; +} diff --git a/app/server/lib/ActiveDoc.ts b/app/server/lib/ActiveDoc.ts index d8beb8f3..03b8a61c 100644 --- a/app/server/lib/ActiveDoc.ts +++ b/app/server/lib/ActiveDoc.ts @@ -14,6 +14,7 @@ import { } from 'app/common/ActionBundle'; import {ActionGroup, MinimalActionGroup} from 'app/common/ActionGroup'; import {ActionSummary} from "app/common/ActionSummary"; +import {Prompt, Suggestion} from "app/common/AssistancePrompts"; import { AclResources, AclTableDescription, @@ -80,6 +81,7 @@ import {parseUserAction} from 'app/common/ValueParser'; import {ParseOptions} from 'app/plugin/FileParserAPI'; import {AccessTokenOptions, AccessTokenResult, GristDocAPI} from 'app/plugin/GristAPI'; import {compileAclFormula} from 'app/server/lib/ACLFormula'; +import {sendForCompletion} from 'app/server/lib/Assistance'; import {Authorizer} from 'app/server/lib/Authorizer'; import {checksumFile} from 'app/server/lib/checksumFile'; import {Client} from 'app/server/lib/Client'; @@ -1313,6 +1315,24 @@ export class ActiveDoc extends EventEmitter { return this._pyCall('autocomplete', txt, tableId, columnId, rowId, user.toJSON()); } + public async getAssistance(docSession: DocSession, userPrompt: Prompt): Promise<Suggestion> { + // Making a prompt can leak names of tables and columns. 
+ if (!await this._granularAccess.canScanData(docSession)) { + throw new Error("Permission denied"); + } + await this.waitForInitialization(); + const { tableId, colId, description } = userPrompt; + const prompt = await this._pyCall('get_formula_prompt', tableId, colId, description); + this._log.debug(docSession, 'getAssistance prompt', {prompt}); + const completion = await sendForCompletion(prompt); + this._log.debug(docSession, 'getAssistance completion', {completion}); + const formula = await this._pyCall('convert_formula_completion', completion); + const action: DocAction = ["ModifyColumn", tableId, colId, {formula}]; + return { + suggestedActions: [action], + }; + } + public fetchURL(docSession: DocSession, url: string, options?: FetchUrlOptions): Promise { return fetchURL(url, this.makeAccessId(docSession.authorizer.getUserId()), options); } diff --git a/app/server/lib/Assistance.ts b/app/server/lib/Assistance.ts new file mode 100644 index 00000000..ce994326 --- /dev/null +++ b/app/server/lib/Assistance.ts @@ -0,0 +1,110 @@ +/** + * Module with functions used for AI formula assistance. 
+ */ + +import {delay} from 'app/common/delay'; +import log from 'app/server/lib/log'; +import fetch, { Response as FetchResponse} from 'node-fetch'; + + +export async function sendForCompletion(prompt: string): Promise<string> { + let completion: string|null = null; + if (process.env.OPENAI_API_KEY) { + completion = await sendForCompletionOpenAI(prompt); + } + if (process.env.HUGGINGFACE_API_KEY) { + completion = await sendForCompletionHuggingFace(prompt); + } + if (completion === null) { + throw new Error("Please set OPENAI_API_KEY or HUGGINGFACE_API_KEY (and optionally COMPLETION_MODEL)"); + } + log.debug(`Received completion:`, {completion}); + completion = completion.split(/\n {4}[^ ]/)[0]; + return completion; +} + + +async function sendForCompletionOpenAI(prompt: string) { + const apiKey = process.env.OPENAI_API_KEY; + if (!apiKey) { + throw new Error("OPENAI_API_KEY not set"); + } + const response = await fetch( + "https://api.openai.com/v1/completions", + { + method: "POST", + headers: { + "Authorization": `Bearer ${apiKey}`, + "Content-Type": "application/json", + }, + body: JSON.stringify({ + prompt, + max_tokens: 150, + temperature: 0, + // COMPLETION_MODEL of `code-davinci-002` may be better if you have access to it. 
+ model: process.env.COMPLETION_MODEL || "text-davinci-002", + stop: ["\n\n"], + }), + }, + ); + if (response.status !== 200) { + log.error(`OpenAI API returned ${response.status}: ${await response.text()}`); + throw new Error(`OpenAI API returned status ${response.status}`); + } + const result = await response.json(); + const completion = result.choices[0].text; + return completion; +} + +async function sendForCompletionHuggingFace(prompt: string) { + const apiKey = process.env.HUGGINGFACE_API_KEY; + if (!apiKey) { + throw new Error("HUGGINGFACE_API_KEY not set"); + } + // COMPLETION_MODEL values I've tried: + // - codeparrot/codeparrot + // - NinedayWang/PolyCoder-2.7B + // - NovelAI/genji-python-6B + let completionUrl = process.env.COMPLETION_URL; + if (!completionUrl) { + if (process.env.COMPLETION_MODEL) { + completionUrl = `https://api-inference.huggingface.co/models/${process.env.COMPLETION_MODEL}`; + } else { + completionUrl = 'https://api-inference.huggingface.co/models/NovelAI/genji-python-6B'; + } + } + let retries: number = 0; + let response!: FetchResponse; + while (retries++ < 3) { + response = await fetch( + completionUrl, + { + method: "POST", + headers: { + "Authorization": `Bearer ${apiKey}`, + "Content-Type": "application/json", + }, + body: JSON.stringify({ + inputs: prompt, + parameters: { + return_full_text: false, + max_new_tokens: 50, + }, + }), + }, + ); + if (response.status !== 503) { + break; + } + log.error(`Sleeping for 10s - HuggingFace API returned ${response.status}: ${await response.text()}`); + await delay(10000); + } + if (response.status !== 200) { + const text = await response.text(); + log.error(`HuggingFace API returned ${response.status}: ${text}`); + throw new Error(`HuggingFace API returned status ${response.status}: ${text}`); + } + const result = await response.json(); + const completion = result[0].generated_text; + return completion.split('\n\n')[0]; +} diff --git a/app/server/lib/DocWorker.ts 
b/app/server/lib/DocWorker.ts index 180fa2e5..3a870dd5 100644 --- a/app/server/lib/DocWorker.ts +++ b/app/server/lib/DocWorker.ts @@ -110,6 +110,7 @@ export class DocWorker { applyUserActionsById: activeDocMethod.bind(null, 'editors', 'applyUserActionsById'), findColFromValues: activeDocMethod.bind(null, 'viewers', 'findColFromValues'), getFormulaError: activeDocMethod.bind(null, 'viewers', 'getFormulaError'), + getAssistance: activeDocMethod.bind(null, 'editors', 'getAssistance'), importFiles: activeDocMethod.bind(null, 'editors', 'importFiles'), finishImportFiles: activeDocMethod.bind(null, 'editors', 'finishImportFiles'), cancelImportFiles: activeDocMethod.bind(null, 'editors', 'cancelImportFiles'), diff --git a/documentation/llm.md b/documentation/llm.md new file mode 100644 index 00000000..3a6b880d --- /dev/null +++ b/documentation/llm.md @@ -0,0 +1,35 @@ +# Using Large Language Models with Grist + +In this experimental Grist feature, originally developed by Alex Hall, +you can hook up an AI model such as OpenAI's Codex to write formulas for +you. Here's how. + +First, you need an API key. You'll have best results currently with an +OpenAI model. Visit https://openai.com/api/ and prepare a key, then +store it in an environment variable `OPENAI_API_KEY`. + +Alternatively, there are many non-proprietary models hosted on Hugging Face. +At the time of writing, none can compare with OpenAI for use with Grist. +Things can change quickly in the world of AI though. So instead of OpenAI, +you can visit https://huggingface.co/ and prepare a key, then +store it in an environment variable `HUGGINGFACE_API_KEY`. + +That's all the configuration needed! + +Currently it is only a backend feature, we are still working on the UI for it. + +## Trying other models + +The model used will default to `text-davinci-002` for OpenAI. You can +get better results by setting an environment variable `COMPLETION_MODEL` to +`code-davinci-002` if you have access to that model. 
+ +The model used will default to `NovelAI/genji-python-6B` for +Hugging Face. There's no particularly great model for this application, +but you can try other models by setting an environment variable +`COMPLETION_MODEL` to `codeparrot/codeparrot` or +`NinedayWang/PolyCoder-2.7B` or similar. + +If you are hosting a model yourself, host it as Hugging Face does, +and use `COMPLETION_URL` rather than `COMPLETION_MODEL` to +point to the model on your own server rather than Hugging Face. diff --git a/sandbox/grist/formula_prompt.py b/sandbox/grist/formula_prompt.py new file mode 100644 index 00000000..045013db --- /dev/null +++ b/sandbox/grist/formula_prompt.py @@ -0,0 +1,186 @@ +import json +import textwrap + +import six + +from column import is_visible_column, BaseReferenceColumn +from objtypes import RaisedException +import records + + +def column_type(engine, table_id, col_id): + col_rec = engine.docmodel.get_column_rec(table_id, col_id) + typ = col_rec.type + parts = typ.split(":") + if parts[0] == "Ref": + return parts[1] + elif parts[0] == "RefList": + return "List[{}]".format(parts[1]) + elif typ == "Choice": + return choices(col_rec) + elif typ == "ChoiceList": + return "Tuple[{}, ...]".format(choices(col_rec)) + elif typ == "Any": + table = engine.tables[table_id] + col = table.get_column(col_id) + values = [col.raw_get(row_id) for row_id in table.row_ids] + return values_type(values) + else: + return dict( + Text="str", + Numeric="float", + Int="int", + Bool="bool", + Date="datetime.date", + DateTime="datetime.datetime", + Any="Any", + Attachments="Any", + )[parts[0]] + + +def choices(col_rec): + try: + widget_options = json.loads(col_rec.widgetOptions) + return "Literal{}".format(widget_options["choices"]) + except (ValueError, KeyError): + return 'str' + + +def values_type(values): + types = set(type(v) for v in values) - {RaisedException} + optional = type(None) in types # pylint: disable=unidiomatic-typecheck + types.discard(type(None)) + + if types == 
{int, float}: + types = {float} + + if len(types) != 1: + return "Any" + + [typ] = types + val = next(v for v in values if isinstance(v, typ)) + + if isinstance(val, records.Record): + type_name = val._table.table_id + elif isinstance(val, records.RecordSet): + type_name = "List[{}]".format(val._table.table_id) + elif isinstance(val, list): + type_name = "List[{}]".format(values_type(val)) + elif isinstance(val, set): + type_name = "Set[{}]".format(values_type(val)) + elif isinstance(val, tuple): + type_name = "Tuple[{}, ...]".format(values_type(val)) + elif isinstance(val, dict): + type_name = "Dict[{}, {}]".format(values_type(val.keys()), values_type(val.values())) + else: + type_name = typ.__name__ + + if optional: + type_name = "Optional[{}]".format(type_name) + + return type_name + + +def referenced_tables(engine, table_id): + result = set() + queue = [table_id] + while queue: + cur_table_id = queue.pop() + if cur_table_id in result: + continue + result.add(cur_table_id) + for col_id, col in visible_columns(engine, cur_table_id): + if isinstance(col, BaseReferenceColumn): + target_table_id = col._target_table.table_id + if not target_table_id.startswith("_"): + queue.append(target_table_id) + return result - {table_id} + +def all_other_tables(engine, table_id): + result = set(t for t in engine.tables.keys() if not t.startswith('_grist')) + return result - {table_id} - {'GristDocTour'} + +def visible_columns(engine, table_id): + return [ + (col_id, col) + for col_id, col in engine.tables[table_id].all_columns.items() + if is_visible_column(col_id) + ] + + +def class_schema(engine, table_id, exclude_col_id=None, lookups=False): + result = "@dataclass\nclass {}:\n".format(table_id) + + if lookups: + + # Build a lookupRecords and lookupOne method for each table, providing some arguments hints + # for the columns that are visible. 
+ lookupRecords_args = [] + lookupOne_args = [] + for col_id, col in visible_columns(engine, table_id): + if col_id != exclude_col_id: + lookupOne_args.append(col_id + '=None') + lookupRecords_args.append('%s=%s' % (col_id, col_id)) + lookupOne_args.append('sort_by=None') + lookupRecords_args.append('sort_by=sort_by') + lookupOne_args_line = ', '.join(lookupOne_args) + lookupRecords_args_line = ', '.join(lookupRecords_args) + + result += " def __len__(self):\n" + result += " return len(%s.lookupRecords())\n" % table_id + result += " @staticmethod\n" + result += " def lookupRecords(%s) -> List[%s]:\n" % (lookupOne_args_line, table_id) + result += " # ...\n" + result += " @staticmethod\n" + result += " def lookupOne(%s) -> %s:\n" % (lookupOne_args_line, table_id) + result += " '''\n" + result += " Filter for one result matching the keys provided.\n" + result += " To control order, use e.g. `sort_by='Key'` or `sort_by='-Key'`.\n" + result += " '''\n" + result += " return %s.lookupRecords(%s)[0]\n" % (table_id, lookupRecords_args_line) + result += "\n" + + for col_id, col in visible_columns(engine, table_id): + if col_id != exclude_col_id: + result += " {}: {}\n".format(col_id, column_type(engine, table_id, col_id)) + result += "\n" + return result + + +def get_formula_prompt(engine, table_id, col_id, description, + include_all_tables=True, + lookups=True): + result = "" + other_tables = (all_other_tables(engine, table_id) + if include_all_tables else referenced_tables(engine, table_id)) + for other_table_id in sorted(other_tables): + result += class_schema(engine, other_table_id, lookups=lookups) + + result += class_schema(engine, table_id, col_id, lookups) + + return_type = column_type(engine, table_id, col_id) + result += " @property\n" + result += " # rec is alias for self\n" + result += " def {}(rec) -> {}:\n".format(col_id, return_type) + result += ' """\n' + result += '{}\n'.format(indent(description, " ")) + result += ' """\n' + return result + +def indent(text, prefix, 
predicate=None): + """ + Copied from https://github.com/python/cpython/blob/main/Lib/textwrap.py for python2 compatibility. + """ + if six.PY3: + return textwrap.indent(text, prefix, predicate) # pylint: disable = no-member + if predicate is None: + def predicate(line): + return line.strip() + def prefixed_lines(): + for line in text.splitlines(True): + yield (prefix + line if predicate(line) else line) + return ''.join(prefixed_lines()) + +def convert_completion(completion): + result = textwrap.dedent(completion) + return result diff --git a/sandbox/grist/main.py b/sandbox/grist/main.py index 89b2f499..952efd66 100644 --- a/sandbox/grist/main.py +++ b/sandbox/grist/main.py @@ -16,6 +16,7 @@ import six import actions import engine +import formula_prompt import migrations import schema import useractions @@ -135,6 +136,14 @@ def run(sandbox): def get_formula_error(table_id, col_id, row_id): return objtypes.encode_object(eng.get_formula_error(table_id, col_id, row_id)) + @export + def get_formula_prompt(table_id, col_id, description): + return formula_prompt.get_formula_prompt(eng, table_id, col_id, description) + + @export + def convert_formula_completion(completion): + return formula_prompt.convert_completion(completion) + export(parse_acl_formula) export(eng.load_empty) export(eng.load_done) diff --git a/sandbox/grist/test_formula_prompt.py b/sandbox/grist/test_formula_prompt.py new file mode 100644 index 00000000..6cceb1db --- /dev/null +++ b/sandbox/grist/test_formula_prompt.py @@ -0,0 +1,217 @@ +import unittest +import six + +import test_engine +import testutil + +from formula_prompt import ( + values_type, column_type, referenced_tables, get_formula_prompt, +) +from objtypes import RaisedException +from records import Record, RecordSet + + +class FakeTable(object): + + def __init__(self): + pass + + table_id = "Table1" + _identity_relation = None + + +@unittest.skipUnless(six.PY3, "Python 3 only") +class TestFormulaPrompt(test_engine.EngineTestCase): + def 
test_values_type(self): + self.assertEqual(values_type([1, 2, 3]), "int") + self.assertEqual(values_type([1.0, 2.0, 3.0]), "float") + self.assertEqual(values_type([1, 2, 3.0]), "float") + + self.assertEqual(values_type([1, 2, None]), "Optional[int]") + self.assertEqual(values_type([1, 2, 3.0, None]), "Optional[float]") + + self.assertEqual(values_type([1, RaisedException(None), 3]), "int") + self.assertEqual(values_type([1, RaisedException(None), None]), "Optional[int]") + + self.assertEqual(values_type(["1", "2", "3"]), "str") + self.assertEqual(values_type([1, 2, "3"]), "Any") + self.assertEqual(values_type([1, 2, "3", None]), "Any") + + self.assertEqual(values_type([ + Record(FakeTable(), None), + Record(FakeTable(), None), + ]), "Table1") + self.assertEqual(values_type([ + Record(FakeTable(), None), + Record(FakeTable(), None), + None, + ]), "Optional[Table1]") + + self.assertEqual(values_type([ + RecordSet(FakeTable(), None), + RecordSet(FakeTable(), None), + ]), "List[Table1]") + self.assertEqual(values_type([ + RecordSet(FakeTable(), None), + RecordSet(FakeTable(), None), + None, + ]), "Optional[List[Table1]]") + + self.assertEqual(values_type([[1, 2, 3]]), "List[int]") + self.assertEqual(values_type([[1, 2, 3], None]), "Optional[List[int]]") + self.assertEqual(values_type([[1, 2, None]]), "List[Optional[int]]") + self.assertEqual(values_type([[1, 2, None], None]), "Optional[List[Optional[int]]]") + self.assertEqual(values_type([[1, 2, "3"]]), "List[Any]") + + self.assertEqual(values_type([{1, 2, 3}]), "Set[int]") + self.assertEqual(values_type([(1, 2, 3)]), "Tuple[int, ...]") + self.assertEqual(values_type([{1: ["2"]}]), "Dict[int, List[str]]") + + def assert_column_type(self, col_id, expected_type): + self.assertEqual(column_type(self.engine, "Table2", col_id), expected_type) + + def assert_prompt(self, table_name, col_id, expected_prompt): + prompt = get_formula_prompt(self.engine, table_name, col_id, "description here", + include_all_tables=False, 
lookups=False) + # print(prompt) + self.assertEqual(prompt, expected_prompt) + + def test_column_type(self): + sample = testutil.parse_test_sample({ + "SCHEMA": [ + [1, "Table2", [ + [1, "text", "Text", False, "", "", ""], + [2, "numeric", "Numeric", False, "", "", ""], + [3, "int", "Int", False, "", "", ""], + [4, "bool", "Bool", False, "", "", ""], + [5, "date", "Date", False, "", "", ""], + [6, "datetime", "DateTime", False, "", "", ""], + [7, "attachments", "Attachments", False, "", "", ""], + [8, "ref", "Ref:Table2", False, "", "", ""], + [9, "reflist", "RefList:Table2", False, "", "", ""], + [10, "choice", "Choice", False, "", "", '{"choices": ["a", "b", "c"]}'], + [11, "choicelist", "ChoiceList", False, "", "", '{"choices": ["x", "y", "z"]}'], + [12, "ref_formula", "Any", True, "$ref or None", "", ""], + [13, "numeric_formula", "Any", True, "1 / $numeric", "", ""], + [14, "new_formula", "Numeric", True, "'to be generated...'", "", ""], + ]], + ], + "DATA": { + "Table2": [ + ["id", "numeric", "ref"], + [1, 0, 0], + [2, 1, 1], + ], + }, + }) + self.load_sample(sample) + + self.assert_column_type("text", "str") + self.assert_column_type("numeric", "float") + self.assert_column_type("int", "int") + self.assert_column_type("bool", "bool") + self.assert_column_type("date", "datetime.date") + self.assert_column_type("datetime", "datetime.datetime") + self.assert_column_type("attachments", "Any") + self.assert_column_type("ref", "Table2") + self.assert_column_type("reflist", "List[Table2]") + self.assert_column_type("choice", "Literal['a', 'b', 'c']") + self.assert_column_type("choicelist", "Tuple[Literal['x', 'y', 'z'], ...]") + self.assert_column_type("ref_formula", "Optional[Table2]") + self.assert_column_type("numeric_formula", "float") + + self.assertEqual(referenced_tables(self.engine, "Table2"), set()) + + self.assert_prompt("Table2", "new_formula", + '''\ +@dataclass +class Table2: + text: str + numeric: float + int: int + bool: bool + date: datetime.date + 
datetime: datetime.datetime + attachments: Any + ref: Table2 + reflist: List[Table2] + choice: Literal['a', 'b', 'c'] + choicelist: Tuple[Literal['x', 'y', 'z'], ...] + ref_formula: Optional[Table2] + numeric_formula: float + + @property + # rec is alias for self + def new_formula(rec) -> float: + """ + description here + """ +''') + + def test_get_formula_prompt(self): + sample = testutil.parse_test_sample({ + "SCHEMA": [ + [1, "Table1", [ + [1, "text", "Text", False, "", "", ""], + ]], + [2, "Table2", [ + [2, "ref", "Ref:Table1", False, "", "", ""], + ]], + [3, "Table3", [ + [3, "reflist", "RefList:Table2", False, "", "", ""], + ]], + ], + "DATA": {}, + }) + self.load_sample(sample) + self.assertEqual(referenced_tables(self.engine, "Table3"), {"Table1", "Table2"}) + self.assertEqual(referenced_tables(self.engine, "Table2"), {"Table1"}) + self.assertEqual(referenced_tables(self.engine, "Table1"), set()) + + self.assert_prompt("Table1", "text", '''\ +@dataclass +class Table1: + + @property + # rec is alias for self + def text(rec) -> str: + """ + description here + """ +''') + + self.assert_prompt("Table2", "ref", '''\ +@dataclass +class Table1: + text: str + +@dataclass +class Table2: + + @property + # rec is alias for self + def ref(rec) -> Table1: + """ + description here + """ +''') + + self.assert_prompt("Table3", "reflist", '''\ +@dataclass +class Table1: + text: str + +@dataclass +class Table2: + ref: Table1 + +@dataclass +class Table3: + + @property + # rec is alias for self + def reflist(rec) -> List[Table2]: + """ + description here + """ +''')