mirror of
				https://github.com/gristlabs/grist-core.git
				synced 2025-06-13 20:53:59 +00:00 
			
		
		
		
	(core) Porting back AI formula backend
Summary: This is a backend part for the formula AI. Test Plan: New tests Reviewers: paulfitz Reviewed By: paulfitz Subscribers: cyprien Differential Revision: https://phab.getgrist.com/D3786
This commit is contained in:
		
							parent
							
								
									ef0a55ced1
								
							
						
					
					
						commit
						6e3f0f2b35
					
				@ -32,6 +32,7 @@ export class DocComm extends Disposable implements ActiveDocAPI {
 | 
			
		||||
  public addAttachments = this._wrapMethod("addAttachments");
 | 
			
		||||
  public findColFromValues = this._wrapMethod("findColFromValues");
 | 
			
		||||
  public getFormulaError = this._wrapMethod("getFormulaError");
 | 
			
		||||
  public getAssistance = this._wrapMethod("getAssistance");
 | 
			
		||||
  public fetchURL = this._wrapMethod("fetchURL");
 | 
			
		||||
  public autocomplete = this._wrapMethod("autocomplete");
 | 
			
		||||
  public removeInstanceFromDoc = this._wrapMethod("removeInstanceFromDoc");
 | 
			
		||||
 | 
			
		||||
@ -318,6 +318,11 @@ export interface ActiveDocAPI {
 | 
			
		||||
   */
 | 
			
		||||
  getFormulaError(tableId: string, colId: string, rowId: number): Promise<CellValue>;
 | 
			
		||||
 | 
			
		||||
  /**
 | 
			
		||||
   * Generates a formula code based on the AI suggestions, it also modifies the column and sets it type to a formula.
 | 
			
		||||
   */
 | 
			
		||||
  getAssistance(tableId: string, colId: string, description: string): Promise<void>;
 | 
			
		||||
 | 
			
		||||
  /**
 | 
			
		||||
   * Fetch content at a url.
 | 
			
		||||
   */
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										11
									
								
								app/common/AssistancePrompts.ts
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										11
									
								
								app/common/AssistancePrompts.ts
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,11 @@
 | 
			
		||||
import {DocAction} from 'app/common/DocActions';
 | 
			
		||||
 | 
			
		||||
export interface Prompt {
 | 
			
		||||
  tableId: string;
 | 
			
		||||
  colId: string
 | 
			
		||||
  description: string;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
export interface Suggestion {
 | 
			
		||||
  suggestedActions: DocAction[];
 | 
			
		||||
}
 | 
			
		||||
@ -14,6 +14,7 @@ import {
 | 
			
		||||
} from 'app/common/ActionBundle';
 | 
			
		||||
import {ActionGroup, MinimalActionGroup} from 'app/common/ActionGroup';
 | 
			
		||||
import {ActionSummary} from "app/common/ActionSummary";
 | 
			
		||||
import {Prompt, Suggestion} from "app/common/AssistancePrompts";
 | 
			
		||||
import {
 | 
			
		||||
  AclResources,
 | 
			
		||||
  AclTableDescription,
 | 
			
		||||
@ -80,6 +81,7 @@ import {parseUserAction} from 'app/common/ValueParser';
 | 
			
		||||
import {ParseOptions} from 'app/plugin/FileParserAPI';
 | 
			
		||||
import {AccessTokenOptions, AccessTokenResult, GristDocAPI} from 'app/plugin/GristAPI';
 | 
			
		||||
import {compileAclFormula} from 'app/server/lib/ACLFormula';
 | 
			
		||||
import {sendForCompletion} from 'app/server/lib/Assistance';
 | 
			
		||||
import {Authorizer} from 'app/server/lib/Authorizer';
 | 
			
		||||
import {checksumFile} from 'app/server/lib/checksumFile';
 | 
			
		||||
import {Client} from 'app/server/lib/Client';
 | 
			
		||||
@ -1313,6 +1315,24 @@ export class ActiveDoc extends EventEmitter {
 | 
			
		||||
    return this._pyCall('autocomplete', txt, tableId, columnId, rowId, user.toJSON());
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /**
   * Generates an AI-suggested formula for the given column and returns it as a
   * suggested ModifyColumn doc action; the action is NOT applied here.
   *
   * Requires full data-scan access, since the generated prompt embeds table and
   * column names from the document.
   */
  public async getAssistance(docSession: DocSession, userPrompt: Prompt): Promise<Suggestion> {
    // Making a prompt can leak names of tables and columns.
    if (!await this._granularAccess.canScanData(docSession)) {
      throw new Error("Permission denied");
    }
    await this.waitForInitialization();
    const { tableId, colId, description } = userPrompt;
    // Ask the sandbox to build a code-completion prompt describing the document schema.
    const prompt = await this._pyCall('get_formula_prompt', tableId, colId, description);
    this._log.debug(docSession, 'getAssistance prompt', {prompt});
    // Send the prompt to the configured completion service (OpenAI or Hugging Face).
    const completion = await sendForCompletion(prompt);
    this._log.debug(docSession, 'getAssistance completion', {completion});
    // Let the sandbox post-process the raw completion into formula source code.
    const formula = await this._pyCall('convert_formula_completion', completion);
    const action: DocAction = ["ModifyColumn", tableId, colId, {formula}];
    return {
      suggestedActions: [action],
    };
  }
 | 
			
		||||
 | 
			
		||||
  /**
   * Fetches the content at the given URL as an upload, using an access id derived
   * from the requesting user so the resulting upload belongs to them.
   */
  public fetchURL(docSession: DocSession, url: string, options?: FetchUrlOptions): Promise<UploadResult> {
    return fetchURL(url, this.makeAccessId(docSession.authorizer.getUserId()), options);
  }
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										110
									
								
								app/server/lib/Assistance.ts
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										110
									
								
								app/server/lib/Assistance.ts
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,110 @@
 | 
			
		||||
/**
 | 
			
		||||
 * Module with functions used for AI formula assistance.
 | 
			
		||||
 */
 | 
			
		||||
 | 
			
		||||
import {delay} from 'app/common/delay';
 | 
			
		||||
import log from 'app/server/lib/log';
 | 
			
		||||
import fetch, { Response as FetchResponse} from 'node-fetch';
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
export async function sendForCompletion(prompt: string): Promise<string> {
 | 
			
		||||
  let completion: string|null = null;
 | 
			
		||||
  if (process.env.OPENAI_API_KEY) {
 | 
			
		||||
    completion = await sendForCompletionOpenAI(prompt);
 | 
			
		||||
  }
 | 
			
		||||
  if (process.env.HUGGINGFACE_API_KEY) {
 | 
			
		||||
    completion = await sendForCompletionHuggingFace(prompt);
 | 
			
		||||
  }
 | 
			
		||||
  if (completion === null) {
 | 
			
		||||
    throw new Error("Please set OPENAI_API_KEY or HUGGINGFACE_API_KEY (and optionally COMPLETION_MODEL)");
 | 
			
		||||
  }
 | 
			
		||||
  log.debug(`Received completion:`, {completion});
 | 
			
		||||
  completion = completion.split(/\n {4}[^ ]/)[0];
 | 
			
		||||
  return completion;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
/**
 * Sends the prompt to the OpenAI completions API and returns the raw completion text.
 * Uses COMPLETION_MODEL if set, otherwise "text-davinci-002".
 * Throws if OPENAI_API_KEY is unset or the API responds with a non-200 status.
 */
async function sendForCompletionOpenAI(prompt: string) {
  const apiKey = process.env.OPENAI_API_KEY;
  if (!apiKey) {
    throw new Error("OPENAI_API_KEY not set");
  }
  const response = await fetch(
    "https://api.openai.com/v1/completions",
    {
      method: "POST",
      headers: {
        "Authorization": `Bearer ${apiKey}`,
        "Content-Type": "application/json",
      },
      body: JSON.stringify({
        prompt,
        max_tokens: 150,
        // Deterministic output: always pick the most likely token.
        temperature: 0,
        // COMPLETION_MODEL of `code-davinci-002` may be better if you have access to it.
        model: process.env.COMPLETION_MODEL || "text-davinci-002",
        // Stop generating at the first blank line (end of the function body).
        stop: ["\n\n"],
      }),
    },
  );
  if (response.status !== 200) {
    log.error(`OpenAI API returned ${response.status}: ${await response.text()}`);
    throw new Error(`OpenAI API returned status ${response.status}`);
  }
  const result = await response.json();
  const completion = result.choices[0].text;
  return completion;
}
 | 
			
		||||
 | 
			
		||||
async function sendForCompletionHuggingFace(prompt: string) {
 | 
			
		||||
  const apiKey = process.env.HUGGINGFACE_API_KEY;
 | 
			
		||||
  if (!apiKey) {
 | 
			
		||||
    throw new Error("HUGGINGFACE_API_KEY not set");
 | 
			
		||||
  }
 | 
			
		||||
  // COMPLETION_MODEL values I've tried:
 | 
			
		||||
  //   - codeparrot/codeparrot
 | 
			
		||||
  //   - NinedayWang/PolyCoder-2.7B
 | 
			
		||||
  //   - NovelAI/genji-python-6B
 | 
			
		||||
  let completionUrl = process.env.COMPLETION_URL;
 | 
			
		||||
  if (!completionUrl) {
 | 
			
		||||
    if (process.env.COMPLETION_MODEL) {
 | 
			
		||||
      completionUrl = `https://api-inference.huggingface.co/models/${process.env.COMPLETION_MODEL}`;
 | 
			
		||||
    } else {
 | 
			
		||||
      completionUrl = 'https://api-inference.huggingface.co/models/NovelAI/genji-python-6B';
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  let retries: number = 0;
 | 
			
		||||
  let response!: FetchResponse;
 | 
			
		||||
  while (retries++ < 3) {
 | 
			
		||||
    response = await fetch(
 | 
			
		||||
      completionUrl,
 | 
			
		||||
      {
 | 
			
		||||
        method: "POST",
 | 
			
		||||
        headers: {
 | 
			
		||||
          "Authorization": `Bearer ${apiKey}`,
 | 
			
		||||
          "Content-Type": "application/json",
 | 
			
		||||
        },
 | 
			
		||||
        body: JSON.stringify({
 | 
			
		||||
          inputs: prompt,
 | 
			
		||||
          parameters: {
 | 
			
		||||
            return_full_text: false,
 | 
			
		||||
            max_new_tokens: 50,
 | 
			
		||||
          },
 | 
			
		||||
        }),
 | 
			
		||||
      },
 | 
			
		||||
    );
 | 
			
		||||
    if (response.status === 503) {
 | 
			
		||||
      log.error(`Sleeping for 10s - HuggingFace API returned ${response.status}: ${await response.text()}`);
 | 
			
		||||
      await delay(10000);
 | 
			
		||||
      continue;
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  if (response.status !== 200) {
 | 
			
		||||
    const text = await response.text();
 | 
			
		||||
    log.error(`HuggingFace API returned ${response.status}: ${text}`);
 | 
			
		||||
    throw new Error(`HuggingFace API returned status ${response.status}: ${text}`);
 | 
			
		||||
  }
 | 
			
		||||
  const result = await response.json();
 | 
			
		||||
  const completion = result[0].generated_text;
 | 
			
		||||
  return completion.split('\n\n')[0];
 | 
			
		||||
}
 | 
			
		||||
@ -110,6 +110,7 @@ export class DocWorker {
 | 
			
		||||
      applyUserActionsById:     activeDocMethod.bind(null, 'editors', 'applyUserActionsById'),
 | 
			
		||||
      findColFromValues:        activeDocMethod.bind(null, 'viewers', 'findColFromValues'),
 | 
			
		||||
      getFormulaError:          activeDocMethod.bind(null, 'viewers', 'getFormulaError'),
 | 
			
		||||
      getAssistance:            activeDocMethod.bind(null, 'editors', 'getAssistance'),
 | 
			
		||||
      importFiles:              activeDocMethod.bind(null, 'editors', 'importFiles'),
 | 
			
		||||
      finishImportFiles:        activeDocMethod.bind(null, 'editors', 'finishImportFiles'),
 | 
			
		||||
      cancelImportFiles:        activeDocMethod.bind(null, 'editors', 'cancelImportFiles'),
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										35
									
								
								documentation/llm.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										35
									
								
								documentation/llm.md
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,35 @@
 | 
			
		||||
# Using Large Language Models with Grist
 | 
			
		||||
 | 
			
		||||
In this experimental Grist feature, originally developed by Alex Hall,
 | 
			
		||||
you can hook up an AI model such as OpenAI's Codex to write formulas for
 | 
			
		||||
you. Here's how.
 | 
			
		||||
 | 
			
		||||
First, you need an API key. You'll have best results currently with an
 | 
			
		||||
OpenAI model. Visit https://openai.com/api/ and prepare a key, then
 | 
			
		||||
store it in an environment variable `OPENAI_API_KEY`.
 | 
			
		||||
 | 
			
		||||
Alternatively, there are many non-proprietary models hosted on Hugging Face.
 | 
			
		||||
At the time of writing, none can compare with OpenAI for use with Grist.
 | 
			
		||||
Things can change quickly in the world of AI though. So instead of OpenAI,
 | 
			
		||||
you can visit https://huggingface.co/ and prepare a key, then
 | 
			
		||||
store it in an environment variable `HUGGINGFACE_API_KEY`.
 | 
			
		||||
 | 
			
		||||
That's all the configuration needed!
 | 
			
		||||
 | 
			
		||||
Currently this is only a backend feature; we are still working on the UI for it.
 | 
			
		||||
 | 
			
		||||
## Trying other models
 | 
			
		||||
 | 
			
		||||
The model used will default to `text-davinci-002` for OpenAI. You can
 | 
			
		||||
get better results by setting an environment variable `COMPLETION_MODEL` to
 | 
			
		||||
`code-davinci-002` if you have access to that model.
 | 
			
		||||
 | 
			
		||||
The model used will default to `NovelAI/genji-python-6B` for
 | 
			
		||||
Hugging Face. There's no particularly great model for this application,
 | 
			
		||||
but you can try other models by setting an environment variable
 | 
			
		||||
`COMPLETION_MODEL` to `codeparrot/codeparrot` or
 | 
			
		||||
`NinedayWang/PolyCoder-2.7B` or similar.
 | 
			
		||||
 | 
			
		||||
If you are hosting a model yourself, host it as Hugging Face does,
 | 
			
		||||
and use `COMPLETION_URL` rather than `COMPLETION_MODEL` to
 | 
			
		||||
point to the model on your own server rather than Hugging Face.
 | 
			
		||||
							
								
								
									
										186
									
								
								sandbox/grist/formula_prompt.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										186
									
								
								sandbox/grist/formula_prompt.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,186 @@
 | 
			
		||||
import json
 | 
			
		||||
import textwrap
 | 
			
		||||
 | 
			
		||||
import six
 | 
			
		||||
 | 
			
		||||
from column import is_visible_column, BaseReferenceColumn
 | 
			
		||||
from objtypes import RaisedException
 | 
			
		||||
import records
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def column_type(engine, table_id, col_id):
  """
  Returns a Python type annotation (as a string) describing the values of the
  given column, for use in the generated prompt's dataclass schema.
  """
  col_rec = engine.docmodel.get_column_rec(table_id, col_id)
  typ = col_rec.type
  parts = typ.split(":")
  if parts[0] == "Ref":
    # Reference columns are annotated with the target table's class name.
    return parts[1]
  elif parts[0] == "RefList":
    return "List[{}]".format(parts[1])
  elif typ == "Choice":
    return choices(col_rec)
  elif typ == "ChoiceList":
    return "Tuple[{}, ...]".format(choices(col_rec))
  elif typ == "Any":
    # For untyped (e.g. formula) columns, infer a type from the actual cell values.
    table = engine.tables[table_id]
    col = table.get_column(col_id)
    values = [col.raw_get(row_id) for row_id in table.row_ids]
    return values_type(values)
  else:
    # Simple column types map directly to standard Python type names.
    return dict(
      Text="str",
      Numeric="float",
      Int="int",
      Bool="bool",
      Date="datetime.date",
      DateTime="datetime.datetime",
      Any="Any",
      Attachments="Any",
    )[parts[0]]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def choices(col_rec):
  """
  Returns a Literal[...] annotation listing the column's configured choices,
  falling back to 'str' when widgetOptions is malformed or has no "choices" key.
  """
  try:
    options = json.loads(col_rec.widgetOptions)
    choice_list = options["choices"]
  except (ValueError, KeyError):
    return 'str'
  return "Literal{}".format(choice_list)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def values_type(values):
  """
  Infers a type annotation (as a string) covering the given cell values, e.g.
  "int", "Optional[float]", "List[Table1]", or "Any" when values are mixed.
  Containers are inspected recursively.
  """
  # Error values don't contribute to the inferred type.
  types = set(type(v) for v in values) - {RaisedException}
  # A None value makes the whole annotation Optional[...].
  optional = type(None) in types # pylint: disable=unidiomatic-typecheck
  types.discard(type(None))

  # A mix of ints and floats is simply float.
  if types == {int, float}:
    types = {float}

  # Any other mixture of types is unconstrained.
  if len(types) != 1:
    return "Any"

  [typ] = types
  # Pick a representative value to inspect containers / records.
  val = next(v for v in values if isinstance(v, typ))

  if isinstance(val, records.Record):
    type_name = val._table.table_id
  elif isinstance(val, records.RecordSet):
    type_name = "List[{}]".format(val._table.table_id)
  elif isinstance(val, list):
    type_name = "List[{}]".format(values_type(val))
  elif isinstance(val, set):
    type_name = "Set[{}]".format(values_type(val))
  elif isinstance(val, tuple):
    type_name = "Tuple[{}, ...]".format(values_type(val))
  elif isinstance(val, dict):
    type_name = "Dict[{}, {}]".format(values_type(val.keys()), values_type(val.values()))
  else:
    type_name = typ.__name__

  if optional:
    type_name = "Optional[{}]".format(type_name)

  return type_name
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def referenced_tables(engine, table_id):
  """
  Returns the set of table ids reachable from the given table by following
  visible reference columns, transitively, excluding the starting table itself.
  """
  result = set()
  queue = [table_id]
  while queue:
    cur_table_id = queue.pop()
    if cur_table_id in result:
      continue
    result.add(cur_table_id)
    for col_id, col in visible_columns(engine, cur_table_id):
      if isinstance(col, BaseReferenceColumn):
        target_table_id = col._target_table.table_id
        # Skip references into hidden/metadata tables.
        if not target_table_id.startswith("_"):
          queue.append(target_table_id)
  return result - {table_id}
 | 
			
		||||
 | 
			
		||||
def all_other_tables(engine, table_id):
  """
  Returns the ids of all user tables in the document other than the given one,
  excluding metadata tables (prefixed '_grist') and the GristDocTour table.
  """
  excluded = {table_id, 'GristDocTour'}
  return {
    t for t in engine.tables.keys()
    if not t.startswith('_grist') and t not in excluded
  }
 | 
			
		||||
 | 
			
		||||
def visible_columns(engine, table_id):
  """
  Returns (col_id, col) pairs for the columns of the given table whose ids pass
  is_visible_column, i.e. user-visible columns.
  """
  return [
    (col_id, col)
    for col_id, col in engine.tables[table_id].all_columns.items()
    if is_visible_column(col_id)
  ]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def class_schema(engine, table_id, exclude_col_id=None, lookups=False):
  """
  Generates Python source for a dataclass describing the given table's schema,
  for inclusion in the AI prompt. The column named by exclude_col_id (the one
  being generated) is omitted. When lookups is true, also emits stub
  lookupRecords/lookupOne methods with argument hints.
  """
  result = "@dataclass\nclass {}:\n".format(table_id)

  if lookups:

    # Build a lookupRecords and lookupOne method for each table, providing some arguments hints
    # for the columns that are visible.
    lookupRecords_args = []
    lookupOne_args = []
    for col_id, col in visible_columns(engine, table_id):
      if col_id != exclude_col_id:
        lookupOne_args.append(col_id + '=None')
        lookupRecords_args.append('%s=%s' % (col_id, col_id))
    lookupOne_args.append('sort_by=None')
    lookupRecords_args.append('sort_by=sort_by')
    lookupOne_args_line = ', '.join(lookupOne_args)
    lookupRecords_args_line = ', '.join(lookupRecords_args)

    # NOTE(review): this generated line is indented with 5 spaces (vs 4 elsewhere),
    # which looks unintentional — confirm before relying on the exact output.
    result += "     def __len__(self):\n"
    result += "        return len(%s.lookupRecords())\n" % table_id
    result += "    @staticmethod\n"
    result += "    def lookupRecords(%s) -> List[%s]:\n" % (lookupOne_args_line, table_id)
    result += "       # ...\n"
    result += "    @staticmethod\n"
    result += "    def lookupOne(%s) -> %s:\n" % (lookupOne_args_line, table_id)
    result += "       '''\n"
    result += "       Filter for one result matching the keys provided.\n"
    result += "       To control order, use e.g. `sort_by='Key' or `sort_by='-Key'`.\n"
    result += "       '''\n"
    result += "       return %s.lookupRecords(%s)[0]\n" % (table_id, lookupRecords_args_line)
    result += "\n"

  # One annotated field per visible column (except the column being generated).
  for col_id, col in visible_columns(engine, table_id):
    if col_id != exclude_col_id:
      result += "    {}: {}\n".format(col_id, column_type(engine, table_id, col_id))
  result += "\n"
  return result
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_formula_prompt(engine, table_id, col_id, description,
                       include_all_tables=True,
                       lookups=True):
  """
  Builds the code-completion prompt for generating a formula for the given
  column: dataclass schemas for the other tables, then the target table's
  schema, followed by an unfinished property whose docstring is the user's
  description of the desired formula.
  """
  result = ""
  other_tables = (all_other_tables(engine, table_id)
    if include_all_tables else referenced_tables(engine, table_id))
  for other_table_id in sorted(other_tables):
    # Fix: pass `lookups` by keyword. It was previously passed positionally
    # into the exclude_col_id slot, so lookup stubs were never generated for
    # other tables, and a truthy `lookups` was misused as a column exclusion.
    result += class_schema(engine, other_table_id, lookups=lookups)

  result += class_schema(engine, table_id, col_id, lookups)

  return_type = column_type(engine, table_id, col_id)
  result += "    @property\n"
  result += "    # rec is alias for self\n"
  result += "    def {}(rec) -> {}:\n".format(col_id, return_type)
  result += '        """\n'
  result += '{}\n'.format(indent(description, "        "))
  result += '        """\n'
  return result
 | 
			
		||||
 | 
			
		||||
def indent(text, prefix, predicate=None):
  """
  Copied from https://github.com/python/cpython/blob/main/Lib/textwrap.py for python2 compatibility.

  Prepends `prefix` to selected lines of `text`; by default only to lines
  containing non-whitespace. On Python 3 it simply delegates to textwrap.indent.
  """
  if six.PY3:
    return textwrap.indent(text, prefix, predicate) # pylint: disable = no-member
  if predicate is None:
    def predicate(line):
      # Default: indent only lines that have non-whitespace content.
      return line.strip()
  def prefixed_lines():
    # splitlines(True) keeps line endings, so joining reconstructs the text exactly.
    for line in text.splitlines(True):
      yield (prefix + line if predicate(line) else line)
  return ''.join(prefixed_lines())
 | 
			
		||||
 | 
			
		||||
def convert_completion(completion):
  """
  Post-processes a raw model completion into formula source code by stripping
  any common leading indentation.
  """
  return textwrap.dedent(completion)
 | 
			
		||||
@ -16,6 +16,7 @@ import six
 | 
			
		||||
 | 
			
		||||
import actions
 | 
			
		||||
import engine
 | 
			
		||||
import formula_prompt
 | 
			
		||||
import migrations
 | 
			
		||||
import schema
 | 
			
		||||
import useractions
 | 
			
		||||
@ -135,6 +136,14 @@ def run(sandbox):
 | 
			
		||||
  def get_formula_error(table_id, col_id, row_id):
 | 
			
		||||
    return objtypes.encode_object(eng.get_formula_error(table_id, col_id, row_id))
 | 
			
		||||
 | 
			
		||||
  @export
  def get_formula_prompt(table_id, col_id, description):
    # Builds the AI code-completion prompt for generating a formula for this column.
    return formula_prompt.get_formula_prompt(eng, table_id, col_id, description)

  @export
  def convert_formula_completion(completion):
    # Converts a raw model completion into usable formula source code.
    return formula_prompt.convert_completion(completion)
 | 
			
		||||
 | 
			
		||||
  export(parse_acl_formula)
 | 
			
		||||
  export(eng.load_empty)
 | 
			
		||||
  export(eng.load_done)
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										217
									
								
								sandbox/grist/test_formula_prompt.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										217
									
								
								sandbox/grist/test_formula_prompt.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,217 @@
 | 
			
		||||
import unittest
 | 
			
		||||
import six
 | 
			
		||||
 | 
			
		||||
import test_engine
 | 
			
		||||
import testutil
 | 
			
		||||
 | 
			
		||||
from formula_prompt import (
 | 
			
		||||
  values_type, column_type, referenced_tables, get_formula_prompt,
 | 
			
		||||
)
 | 
			
		||||
from objtypes import RaisedException
 | 
			
		||||
from records import Record, RecordSet
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class FakeTable(object):
  """
  Minimal stand-in for a table object: provides just the attributes that
  Record/RecordSet instances read in these tests.
  """

  def __init__(self):
    pass

  # Used by values_type to name the inferred Record/RecordSet type.
  table_id = "Table1"
  _identity_relation = None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@unittest.skipUnless(six.PY3, "Python 3 only")
class TestFormulaPrompt(test_engine.EngineTestCase):
  """
  Tests for the prompt-generation helpers in formula_prompt: value-based type
  inference, column type annotations, reference traversal, and the full
  generated prompt text.
  """
  def test_values_type(self):
    # Scalars, with int+float collapsing to float.
    self.assertEqual(values_type([1, 2, 3]), "int")
    self.assertEqual(values_type([1.0, 2.0, 3.0]), "float")
    self.assertEqual(values_type([1, 2, 3.0]), "float")

    # None values wrap the type in Optional[...].
    self.assertEqual(values_type([1, 2, None]), "Optional[int]")
    self.assertEqual(values_type([1, 2, 3.0, None]), "Optional[float]")

    # Error values are ignored for type inference.
    self.assertEqual(values_type([1, RaisedException(None), 3]), "int")
    self.assertEqual(values_type([1, RaisedException(None), None]), "Optional[int]")

    # Mixed incompatible types fall back to Any.
    self.assertEqual(values_type(["1", "2", "3"]), "str")
    self.assertEqual(values_type([1, 2, "3"]), "Any")
    self.assertEqual(values_type([1, 2, "3", None]), "Any")

    # Records are annotated with their table's name.
    self.assertEqual(values_type([
      Record(FakeTable(), None),
      Record(FakeTable(), None),
    ]), "Table1")
    self.assertEqual(values_type([
      Record(FakeTable(), None),
      Record(FakeTable(), None),
      None,
    ]), "Optional[Table1]")

    self.assertEqual(values_type([
      RecordSet(FakeTable(), None),
      RecordSet(FakeTable(), None),
    ]), "List[Table1]")
    self.assertEqual(values_type([
      RecordSet(FakeTable(), None),
      RecordSet(FakeTable(), None),
      None,
    ]), "Optional[List[Table1]]")

    # Containers are inspected recursively.
    self.assertEqual(values_type([[1, 2, 3]]), "List[int]")
    self.assertEqual(values_type([[1, 2, 3], None]), "Optional[List[int]]")
    self.assertEqual(values_type([[1, 2, None]]), "List[Optional[int]]")
    self.assertEqual(values_type([[1, 2, None], None]), "Optional[List[Optional[int]]]")
    self.assertEqual(values_type([[1, 2, "3"]]), "List[Any]")

    self.assertEqual(values_type([{1, 2, 3}]), "Set[int]")
    self.assertEqual(values_type([(1, 2, 3)]), "Tuple[int, ...]")
    self.assertEqual(values_type([{1: ["2"]}]), "Dict[int, List[str]]")

  def assert_column_type(self, col_id, expected_type):
    # Checks the inferred annotation for a column of the Table2 fixture.
    self.assertEqual(column_type(self.engine, "Table2", col_id), expected_type)

  def assert_prompt(self, table_name, col_id, expected_prompt):
    # Checks the full generated prompt (referenced tables only, no lookup stubs).
    prompt = get_formula_prompt(self.engine, table_name, col_id, "description here",
                                include_all_tables=False, lookups=False)
    # print(prompt)
    self.assertEqual(prompt, expected_prompt)

  def test_column_type(self):
    # A single table exercising every column type, including formula columns
    # whose type must be inferred from their computed values.
    sample = testutil.parse_test_sample({
      "SCHEMA": [
        [1, "Table2", [
          [1, "text", "Text", False, "", "", ""],
          [2, "numeric", "Numeric", False, "", "", ""],
          [3, "int", "Int", False, "", "", ""],
          [4, "bool", "Bool", False, "", "", ""],
          [5, "date", "Date", False, "", "", ""],
          [6, "datetime", "DateTime", False, "", "", ""],
          [7, "attachments", "Attachments", False, "", "", ""],
          [8, "ref", "Ref:Table2", False, "", "", ""],
          [9, "reflist", "RefList:Table2", False, "", "", ""],
          [10, "choice", "Choice", False, "", "", '{"choices": ["a", "b", "c"]}'],
          [11, "choicelist", "ChoiceList", False, "", "", '{"choices": ["x", "y", "z"]}'],
          [12, "ref_formula", "Any", True, "$ref or None", "", ""],
          [13, "numeric_formula", "Any", True, "1 / $numeric", "", ""],
          [14, "new_formula", "Numeric", True, "'to be generated...'", "", ""],
        ]],
      ],
      "DATA": {
        # Row 1 divides by zero (error value ignored) and has a null ref;
        # row 2 produces a float and a real reference.
        "Table2": [
          ["id", "numeric", "ref"],
          [1, 0, 0],
          [2, 1, 1],
        ],
      },
    })
    self.load_sample(sample)

    self.assert_column_type("text", "str")
    self.assert_column_type("numeric", "float")
    self.assert_column_type("int", "int")
    self.assert_column_type("bool", "bool")
    self.assert_column_type("date", "datetime.date")
    self.assert_column_type("datetime", "datetime.datetime")
    self.assert_column_type("attachments", "Any")
    self.assert_column_type("ref", "Table2")
    self.assert_column_type("reflist", "List[Table2]")
    self.assert_column_type("choice", "Literal['a', 'b', 'c']")
    self.assert_column_type("choicelist", "Tuple[Literal['x', 'y', 'z'], ...]")
    self.assert_column_type("ref_formula", "Optional[Table2]")
    self.assert_column_type("numeric_formula", "float")

    # Self-references don't count as referenced tables.
    self.assertEqual(referenced_tables(self.engine, "Table2"), set())

    self.assert_prompt("Table2", "new_formula",
      '''\
@dataclass
class Table2:
    text: str
    numeric: float
    int: int
    bool: bool
    date: datetime.date
    datetime: datetime.datetime
    attachments: Any
    ref: Table2
    reflist: List[Table2]
    choice: Literal['a', 'b', 'c']
    choicelist: Tuple[Literal['x', 'y', 'z'], ...]
    ref_formula: Optional[Table2]
    numeric_formula: float

    @property
    # rec is alias for self
    def new_formula(rec) -> float:
        """
        description here
        """
''')

  def test_get_formula_prompt(self):
    # A chain of references: Table3 -> Table2 -> Table1.
    sample = testutil.parse_test_sample({
      "SCHEMA": [
        [1, "Table1", [
          [1, "text", "Text", False, "", "", ""],
        ]],
        [2, "Table2", [
          [2, "ref", "Ref:Table1", False, "", "", ""],
        ]],
        [3, "Table3", [
          [3, "reflist", "RefList:Table2", False, "", "", ""],
        ]],
      ],
      "DATA": {},
    })
    self.load_sample(sample)
    # referenced_tables follows references transitively.
    self.assertEqual(referenced_tables(self.engine, "Table3"), {"Table1", "Table2"})
    self.assertEqual(referenced_tables(self.engine, "Table2"), {"Table1"})
    self.assertEqual(referenced_tables(self.engine, "Table1"), set())

    # The column being generated is excluded from its own table's schema.
    self.assert_prompt("Table1", "text", '''\
@dataclass
class Table1:

    @property
    # rec is alias for self
    def text(rec) -> str:
        """
        description here
        """
''')

    self.assert_prompt("Table2", "ref", '''\
@dataclass
class Table1:
    text: str

@dataclass
class Table2:

    @property
    # rec is alias for self
    def ref(rec) -> Table1:
        """
        description here
        """
''')

    self.assert_prompt("Table3", "reflist", '''\
@dataclass
class Table1:
    text: str

@dataclass
class Table2:
    ref: Table1

@dataclass
class Table3:

    @property
    # rec is alias for self
    def reflist(rec) -> List[Table2]:
        """
        description here
        """
''')
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user