mirror of
https://github.com/gristlabs/grist-core.git
synced 2024-10-27 20:44:07 +00:00
(core) Improve parsing formula from completion
Summary: The previous code for extracting a Python formula from the LLM completion involved some shaky string manipulation which this improves on. Overall the 'test results' from `runCompletion` went from 37/47 to 45/47 for `gpt-3.5-turbo-0613`. The biggest problem that motivated these changes was that it assumed that code was always inside a markdown code block (i.e. triple backticks) and so if there was no block there was no code. But the completion often consists of *only* code with no accompanying explanation or markdown. By parsing the completion in Python instead of JS, we can easily check if the entire completion is valid Python syntax and accept it if it is. I also noticed one failure resulting from the completion containing the full function (instead of just the body) and necessary imports before that function instead of inside. The new parsing moves import inside. Test Plan: Added a Python unit test Reviewers: paulfitz Reviewed By: paulfitz Subscribers: paulfitz Differential Revision: https://phab.getgrist.com/D3922
This commit is contained in:
parent
0b64e408b0
commit
52469c5a7e
@ -133,39 +133,13 @@ export class OpenAIAssistant implements Assistant {
|
|||||||
throw new Error(`OpenAI API returned status ${apiResponse.status}`);
|
throw new Error(`OpenAI API returned status ${apiResponse.status}`);
|
||||||
}
|
}
|
||||||
const result = await apiResponse.json();
|
const result = await apiResponse.json();
|
||||||
let completion: string = String(chatMode ? result.choices[0].message.content : result.choices[0].text);
|
const completion: string = String(chatMode ? result.choices[0].message.content : result.choices[0].text);
|
||||||
const reply = completion;
|
|
||||||
const history = { messages };
|
const history = { messages };
|
||||||
if (chatMode) {
|
if (chatMode) {
|
||||||
history.messages.push(result.choices[0].message);
|
history.messages.push(result.choices[0].message);
|
||||||
// This model likes returning markdown. Code will typically
|
|
||||||
// be in a code block with ``` delimiters.
|
|
||||||
let lines = completion.split('\n');
|
|
||||||
if (lines[0].startsWith('```')) {
|
|
||||||
lines.shift();
|
|
||||||
completion = lines.join('\n');
|
|
||||||
const parts = completion.split('```');
|
|
||||||
if (parts.length > 1) {
|
|
||||||
completion = parts[0];
|
|
||||||
}
|
|
||||||
lines = completion.split('\n');
|
|
||||||
}
|
|
||||||
|
|
||||||
// This model likes repeating the function signature and
|
|
||||||
// docstring, so we try to strip that out.
|
|
||||||
completion = lines.join('\n');
|
|
||||||
while (completion.includes('"""')) {
|
|
||||||
const parts = completion.split('"""');
|
|
||||||
completion = parts[parts.length - 1];
|
|
||||||
}
|
|
||||||
|
|
||||||
// If there's no code block, don't treat the answer as a formula.
|
|
||||||
if (!reply.includes('```')) {
|
|
||||||
completion = '';
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const response = await completionToResponse(doc, request, completion, reply);
|
const response = await completionToResponse(doc, request, completion, completion);
|
||||||
if (chatMode) {
|
if (chatMode) {
|
||||||
response.state = history;
|
response.state = history;
|
||||||
}
|
}
|
||||||
@ -261,38 +235,13 @@ export class EchoAssistant implements Assistant {
|
|||||||
role: 'user', content: request.text,
|
role: 'user', content: request.text,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
let completion = request.text;
|
const completion = request.text;
|
||||||
const reply = completion;
|
|
||||||
const history = { messages };
|
const history = { messages };
|
||||||
history.messages.push({
|
history.messages.push({
|
||||||
role: 'assistant',
|
role: 'assistant',
|
||||||
content: completion,
|
content: completion,
|
||||||
});
|
});
|
||||||
// This model likes returning markdown. Code will typically
|
const response = await completionToResponse(doc, request, completion, completion);
|
||||||
// be in a code block with ``` delimiters.
|
|
||||||
let lines = completion.split('\n');
|
|
||||||
if (lines[0].startsWith('```')) {
|
|
||||||
lines.shift();
|
|
||||||
completion = lines.join('\n');
|
|
||||||
const parts = completion.split('```');
|
|
||||||
if (parts.length > 1) {
|
|
||||||
completion = parts[0];
|
|
||||||
}
|
|
||||||
lines = completion.split('\n');
|
|
||||||
}
|
|
||||||
// This model likes repeating the function signature and
|
|
||||||
// docstring, so we try to strip that out.
|
|
||||||
completion = lines.join('\n');
|
|
||||||
while (completion.includes('"""')) {
|
|
||||||
const parts = completion.split('"""');
|
|
||||||
completion = parts[parts.length - 1];
|
|
||||||
}
|
|
||||||
|
|
||||||
// If there's no code block, don't treat the answer as a formula.
|
|
||||||
if (!reply.includes('```')) {
|
|
||||||
completion = '';
|
|
||||||
}
|
|
||||||
const response = await completionToResponse(doc, request, completion, reply);
|
|
||||||
response.state = history;
|
response.state = history;
|
||||||
return response;
|
return response;
|
||||||
}
|
}
|
||||||
@ -357,18 +306,6 @@ async function completionToResponse(doc: AssistanceDoc, request: AssistanceReque
|
|||||||
throw new Error('completionToResponse only works for formulas');
|
throw new Error('completionToResponse only works for formulas');
|
||||||
}
|
}
|
||||||
completion = await doc.assistanceFormulaTweak(completion);
|
completion = await doc.assistanceFormulaTweak(completion);
|
||||||
// A leading newline is common.
|
|
||||||
if (completion.charAt(0) === '\n') {
|
|
||||||
completion = completion.slice(1);
|
|
||||||
}
|
|
||||||
// If all non-empty lines have four spaces, remove those spaces.
|
|
||||||
// They are common for GPT-3.5, which matches the prompt carefully.
|
|
||||||
const lines = completion.split('\n');
|
|
||||||
const ok = lines.every(line => line === '\n' || line.startsWith(' '));
|
|
||||||
if (ok) {
|
|
||||||
completion = lines.map(line => line === '\n' ? line : line.slice(4)).join('\n');
|
|
||||||
}
|
|
||||||
|
|
||||||
// Suggest an action only if the completion is non-empty (that is,
|
// Suggest an action only if the completion is non-empty (that is,
|
||||||
// it actually looked like code).
|
// it actually looked like code).
|
||||||
const suggestedActions: DocAction[] = completion ? [[
|
const suggestedActions: DocAction[] = completion ? [[
|
||||||
|
@ -1,6 +1,9 @@
|
|||||||
|
import ast
|
||||||
import json
|
import json
|
||||||
|
import re
|
||||||
import textwrap
|
import textwrap
|
||||||
|
|
||||||
|
import asttokens
|
||||||
import six
|
import six
|
||||||
|
|
||||||
from column import is_visible_column, BaseReferenceColumn
|
from column import is_visible_column, BaseReferenceColumn
|
||||||
@ -180,6 +183,57 @@ def indent(text, prefix, predicate=None):
|
|||||||
yield (prefix + line if predicate(line) else line)
|
yield (prefix + line if predicate(line) else line)
|
||||||
return ''.join(prefixed_lines())
|
return ''.join(prefixed_lines())
|
||||||
|
|
||||||
|
|
||||||
def convert_completion(completion):
|
def convert_completion(completion):
|
||||||
|
# Extract code from a markdown code block if needed.
|
||||||
|
match = re.search(r"```\w*\n(.*)```", completion, re.DOTALL)
|
||||||
|
if match:
|
||||||
|
completion = match.group(1)
|
||||||
|
|
||||||
result = textwrap.dedent(completion)
|
result = textwrap.dedent(completion)
|
||||||
return result
|
|
||||||
|
try:
|
||||||
|
atok = asttokens.ASTTokens(result, parse=True)
|
||||||
|
except SyntaxError:
|
||||||
|
# If we don't have valid Python code, don't suggest a formula at all
|
||||||
|
return ""
|
||||||
|
|
||||||
|
stmts = atok.tree.body
|
||||||
|
|
||||||
|
# If the code starts with imports, save them for later.
|
||||||
|
# In particular, the model may return something like:
|
||||||
|
# from datetime import date
|
||||||
|
# def my_column():
|
||||||
|
# ...
|
||||||
|
# We want to return just the function body, but we need to keep the import,
|
||||||
|
# i.e. move it 'inside the function'.
|
||||||
|
imports = ""
|
||||||
|
while stmts and isinstance(stmts[0], (ast.Import, ast.ImportFrom)):
|
||||||
|
imports += atok.get_text(stmts.pop(0)) + "\n"
|
||||||
|
|
||||||
|
# If the non-import code consists only of a function definition, extract the body.
|
||||||
|
if len(stmts) == 1 and isinstance(stmts[0], ast.FunctionDef):
|
||||||
|
func_body_stmts = stmts[0].body
|
||||||
|
if (
|
||||||
|
len(func_body_stmts) > 1 and
|
||||||
|
isinstance(func_body_stmts[0], ast.Expr) and
|
||||||
|
isinstance(func_body_stmts[0].value, ast.Str)
|
||||||
|
):
|
||||||
|
# Skip the docstring.
|
||||||
|
first_stmt = func_body_stmts[1]
|
||||||
|
else:
|
||||||
|
first_stmt = func_body_stmts[0]
|
||||||
|
result_lines = result.splitlines()[first_stmt.lineno - 1:]
|
||||||
|
result = "\n".join(result_lines)
|
||||||
|
result = textwrap.dedent(result)
|
||||||
|
|
||||||
|
if imports:
|
||||||
|
result = imports + "\n" + result
|
||||||
|
|
||||||
|
# Check that we still have valid code.
|
||||||
|
try:
|
||||||
|
ast.parse(result)
|
||||||
|
except SyntaxError:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
return result.strip()
|
||||||
|
@ -5,7 +5,7 @@ import test_engine
|
|||||||
import testutil
|
import testutil
|
||||||
|
|
||||||
from formula_prompt import (
|
from formula_prompt import (
|
||||||
values_type, column_type, referenced_tables, get_formula_prompt,
|
values_type, column_type, referenced_tables, get_formula_prompt, convert_completion,
|
||||||
)
|
)
|
||||||
from objtypes import RaisedException
|
from objtypes import RaisedException
|
||||||
from records import Record as BaseRecord, RecordSet as BaseRecordSet
|
from records import Record as BaseRecord, RecordSet as BaseRecordSet
|
||||||
@ -223,3 +223,33 @@ class Table3:
|
|||||||
description here
|
description here
|
||||||
"""
|
"""
|
||||||
''')
|
''')
|
||||||
|
|
||||||
|
def test_convert_completion(self):
|
||||||
|
completion = """
|
||||||
|
Here's some code:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import os
|
||||||
|
from x import (
|
||||||
|
y,
|
||||||
|
z,
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def foo():
|
||||||
|
'''This is a docstring'''
|
||||||
|
x = 5
|
||||||
|
return 1
|
||||||
|
```
|
||||||
|
|
||||||
|
Hope you like it!
|
||||||
|
"""
|
||||||
|
self.assertEqual(convert_completion(completion), """\
|
||||||
|
import os
|
||||||
|
from x import (
|
||||||
|
y,
|
||||||
|
z,
|
||||||
|
)
|
||||||
|
|
||||||
|
x = 5
|
||||||
|
return 1""")
|
||||||
|
1
test/formula-dataset/.gitignore
vendored
1
test/formula-dataset/.gitignore
vendored
@ -1,2 +1,3 @@
|
|||||||
data/templates
|
data/templates
|
||||||
data/cache
|
data/cache
|
||||||
|
data/results
|
||||||
|
@ -67,7 +67,7 @@ const SEEMS_CHATTY = (process.env.COMPLETION_MODEL || '').includes('turbo');
|
|||||||
const SIMULATE_CONVERSATION = SEEMS_CHATTY;
|
const SIMULATE_CONVERSATION = SEEMS_CHATTY;
|
||||||
|
|
||||||
export async function runCompletion() {
|
export async function runCompletion() {
|
||||||
ActiveDocDeps.ACTIVEDOC_TIMEOUT = 600000;
|
ActiveDocDeps.ACTIVEDOC_TIMEOUT = 600;
|
||||||
|
|
||||||
// if template directory not exists, make it
|
// if template directory not exists, make it
|
||||||
if (!fs.existsSync(path.join(PATH_TO_DOC))) {
|
if (!fs.existsSync(path.join(PATH_TO_DOC))) {
|
||||||
|
Loading…
Reference in New Issue
Block a user