(core) Ensure formulas return something and don't assign to attributes of rec

Summary: To help with mistakes in formulas, forbid assigning to attributes of `rec` (e.g. `$foo = 1` which should probably be `==`) and ensure that there is at least one `return` in the formula (after maybe adding an implicit one at the end).

Test Plan: Extended Python unit test, updated tests which were missing return.

Reviewers: georgegevoian

Reviewed By: georgegevoian

Subscribers: dsagal

Differential Revision: https://phab.getgrist.com/D3439
This commit is contained in:
Alex Hall
2022-05-23 20:02:39 +02:00
parent fcbad1c887
commit fb575a8b7e
4 changed files with 74 additions and 23 deletions

View File

@@ -1,5 +1,6 @@
import ast
import contextlib
import itertools
import re
import six
@@ -48,13 +49,6 @@ def make_formula_body(formula, default_value, assoc_value=None):
except SyntaxError as e:
return textbuilder.Text(_create_syntax_error_code(tmp_formula, formula, e))
# Parse formula and generate error code on assignment to rec
with use_inferences(InferRecAssignment):
try:
astroid.parse(tmp_formula.get_text())
except SyntaxError as e:
return textbuilder.Text(_create_syntax_error_code(tmp_formula, formula, e))
# Once we have a tree, go through it and create a subset of the dollar patches that are actually
# relevant. E.g. this is where we'll skip the "$foo" patches that appear in strings or comments.
patches = []
@@ -86,6 +80,19 @@ def make_formula_body(formula, default_value, assoc_value=None):
elif last_statement is None:
# If we have an empty body (e.g. just a comment), add a 'pass' at the end.
patches.append(textbuilder.make_patch(formula, len(formula), len(formula), '\npass'))
elif not any(
# Raise an error if the user forgot to return anything. For performance:
# - Use type() instead of isinstance()
# - Check last_statement first to try avoiding walking the tree
type(node) == ast.Return # pylint: disable=unidiomatic-typecheck
for node in itertools.chain([last_statement], ast.walk(atok.tree))
):
message = "No `return` statement, and the last line isn't an expression."
if isinstance(last_statement, ast.Assign):
message += " If you want to check for equality, use `==` instead of `=`."
error = SyntaxError(message,
('<string>', 1, 1, ""))
return textbuilder.Text(_create_syntax_error_code(tmp_formula, formula, error))
# Apply the new set of patches to the original formula to get the real output.
final_formula = textbuilder.Replacer(formula_builder_text, patches)
@@ -93,10 +100,13 @@ def make_formula_body(formula, default_value, assoc_value=None):
# Try parsing again before returning it just in case we have new syntax errors. These are
# possible in cases when a single token ('DOLLARfoo') is valid but an expression ('rec.foo') is
# not, e.g. `foo($bar=1)` or `def $foo()`.
try:
atok = asttokens.ASTTokens(final_formula.get_text(), parse=True)
except SyntaxError as e:
return textbuilder.Text(_create_syntax_error_code(final_formula, formula, e))
# Also check for common mistakes: assigning to `rec` or its attributes (e.g. `$foo = 1`).
with use_inferences(InferRecAssignment, InferRecAttrAssignment):
try:
astroid.parse(final_formula.get_text())
except (astroid.AstroidSyntaxError, SyntaxError) as e:
error = getattr(e, "error", e) # extract SyntaxError from AstroidSyntaxError
return textbuilder.Text(_create_syntax_error_code(final_formula, formula, error))
# We return the text-builder object whose .get_text() is the final formula.
return final_formula
@@ -268,6 +278,23 @@ class InferRecAssignment(InferenceTip):
def infer(cls, node, context):
raise NotImplementedError()
class InferRecAttrAssignment(InferenceTip):
"""
Inference helper to raise exception on assignment to `rec`.
"""
node_class = astroid.nodes.AssignAttr
@classmethod
def filter(cls, node):
if isinstance(node.expr, astroid.nodes.Name) and node.expr.name == 'rec':
raise SyntaxError("You can't assign a value to a column with `=`. "
"If you mean to check for equality, use `==` instead.",
('<string>', node.lineno, node.col_offset, ""))
@classmethod
def infer(cls, node, context):
raise NotImplementedError()
#----------------------------------------------------------------------
def parse_grist_names(builder):