mirror of
				https://github.com/gristlabs/grist-core.git
				synced 2025-06-13 20:53:59 +00:00 
			
		
		
		
	(core) Prevent the AI assistant response from including class definitions
Summary: Sometimes the model repeats the classes given in the prompt which would mess up extracting the actual formula. This diff solves this by: 1. Changes the generated Python schema so that (a) the thing that needs completing is a plain top level function instead of a property/method inside the class and (2) the classes are fully valid syntax, which makes it easier to 2. Remove classes from the parsed Python code when converting the completion to a formula. 3. Tweak the prompt wording to discourage including classes in general, especially because sometimes the model tries to solve the problem by defining extra methods/attributes/classes. While I was at it, I changed type hints to use builtins (e.g. `list` instead of `List`) to prevent `from typing import List` which was happening sometimes and would look weird in a formula. Similarly I removed `@dataclass` since that also implies an import, and this also fits with the tweaked wording that the classes are fake. Test Plan: Added a new test case to the formula dataset which triggers the unwanted behaviour. The factors that seem to trigger the problem are (1) a small schema so the classes are easier to repeat and (2) the need to import modules, which the model wants to place before all other code. The case failed before this diff and succeeded after. The tweaked wording reduces the chances of repeating the classes but didn't eliminate it, so forcibly removing the classes in Python was needed. There were also a couple of other existing cases where repeating the classes was observed before but not after. Overall the score increased from 49 to 51 out of 69 (including the new case). At one point the score was 53, but changes in whitespace were enough to make it drop again. Reviewers: georgegevoian Reviewed By: georgegevoian Differential Revision: https://phab.getgrist.com/D4000
This commit is contained in:
		
							parent
							
								
									999d723d14
								
							
						
					
					
						commit
						a1d31e41ad
					
				| @ -144,20 +144,18 @@ export class OpenAIAssistant implements Assistant { | |||||||
|       newMessages.push({ |       newMessages.push({ | ||||||
|         role: 'system', |         role: 'system', | ||||||
|         content: 'You are a helpful assistant for a user of software called Grist. ' + |         content: 'You are a helpful assistant for a user of software called Grist. ' + | ||||||
|           'Below are one or more Python classes. ' + |           "Below are one or more fake Python classes representing the structure of the user's data. " + | ||||||
|           'The last method needs completing. ' + |           'The function at the end needs completing. ' + | ||||||
|           "The user will probably give a description of what they want the method (a 'formula') to return. " + |           "The user will probably give a description of what they want the function (a 'formula') to return. " + | ||||||
|           'If so, your response should include the method body as Python code in a markdown block. ' + |           'If so, your response should include the function BODY as Python code in a markdown block. ' + | ||||||
|           'Do not include the class or method signature, just the method body. ' + |           "Your response will be automatically concatenated to the code below, so you mustn't repeat any of it. " + | ||||||
|           'If your code starts with `class`, `@dataclass`, or `def` it will fail. Only give the method body. ' + |           'You cannot change the function signature or define additional functions or classes. ' + | ||||||
|           'You can import modules inside the method body if needed. ' + |           'It should be a pure function that performs some computation and returns a result. ' + | ||||||
|           'You cannot define additional functions or methods. ' + |  | ||||||
|           'The method should be a pure function that performs some computation and returns a result. ' + |  | ||||||
|           'It CANNOT perform any side effects such as adding/removing/modifying rows/columns/cells/tables/etc. ' + |           'It CANNOT perform any side effects such as adding/removing/modifying rows/columns/cells/tables/etc. ' + | ||||||
|           'It CANNOT interact with files/databases/networks/etc. ' + |           'It CANNOT interact with files/databases/networks/etc. ' + | ||||||
|           'It CANNOT display images/charts/graphs/maps/etc. ' + |           'It CANNOT display images/charts/graphs/maps/etc. ' + | ||||||
|           'If the user asks for these things, tell them that you cannot help. ' + |           'If the user asks for these things, tell them that you cannot help. ' + | ||||||
|           'The method uses `rec` instead of `self` as the first parameter.\n\n' + |           "\n\n" + | ||||||
|           '```python\n' + |           '```python\n' + | ||||||
|           await makeSchemaPromptV1(optSession, doc, request) + |           await makeSchemaPromptV1(optSession, doc, request) + | ||||||
|           '\n```', |           '\n```', | ||||||
| @ -198,6 +196,10 @@ export class OpenAIAssistant implements Assistant { | |||||||
| 
 | 
 | ||||||
|     const userIdHash = getUserHash(optSession); |     const userIdHash = getUserHash(optSession); | ||||||
|     const completion: string = await this._getCompletion(messages, userIdHash); |     const completion: string = await this._getCompletion(messages, userIdHash); | ||||||
|  | 
 | ||||||
|  |     // It's nice to have this ready to uncomment for debugging.
 | ||||||
|  |     // console.log(completion);
 | ||||||
|  | 
 | ||||||
|     const response = await completionToResponse(doc, request, completion); |     const response = await completionToResponse(doc, request, completion); | ||||||
|     if (response.suggestedFormula) { |     if (response.suggestedFormula) { | ||||||
|       // Show the tweaked version of the suggested formula to the user (i.e. the one that's
 |       // Show the tweaked version of the suggested formula to the user (i.e. the one that's
 | ||||||
|  | |||||||
| @ -22,11 +22,11 @@ def column_type(engine, table_id, col_id): | |||||||
|   if parts[0] == "Ref": |   if parts[0] == "Ref": | ||||||
|     return parts[1] |     return parts[1] | ||||||
|   elif parts[0] == "RefList": |   elif parts[0] == "RefList": | ||||||
|     return "List[{}]".format(parts[1]) |     return "list[{}]".format(parts[1]) | ||||||
|   elif typ == "Choice": |   elif typ == "Choice": | ||||||
|     return choices(col_rec) |     return choices(col_rec) | ||||||
|   elif typ == "ChoiceList": |   elif typ == "ChoiceList": | ||||||
|     return "Tuple[{}, ...]".format(choices(col_rec)) |     return "tuple[{}, ...]".format(choices(col_rec)) | ||||||
|   elif typ == "Any": |   elif typ == "Any": | ||||||
|     table = engine.tables[table_id] |     table = engine.tables[table_id] | ||||||
|     col = table.get_column(col_id) |     col = table.get_column(col_id) | ||||||
| @ -69,15 +69,15 @@ def values_type(values): | |||||||
|   if isinstance(val, records.Record): |   if isinstance(val, records.Record): | ||||||
|     type_name = val._table.table_id |     type_name = val._table.table_id | ||||||
|   elif isinstance(val, records.RecordSet): |   elif isinstance(val, records.RecordSet): | ||||||
|     type_name = "List[{}]".format(val._table.table_id) |     type_name = "list[{}]".format(val._table.table_id) | ||||||
|   elif isinstance(val, list): |   elif isinstance(val, list): | ||||||
|     type_name = "List[{}]".format(values_type(val)) |     type_name = "list[{}]".format(values_type(val)) | ||||||
|   elif isinstance(val, set): |   elif isinstance(val, set): | ||||||
|     type_name = "Set[{}]".format(values_type(val)) |     type_name = "set[{}]".format(values_type(val)) | ||||||
|   elif isinstance(val, tuple): |   elif isinstance(val, tuple): | ||||||
|     type_name = "Tuple[{}, ...]".format(values_type(val)) |     type_name = "tuple[{}, ...]".format(values_type(val)) | ||||||
|   elif isinstance(val, dict): |   elif isinstance(val, dict): | ||||||
|     type_name = "Dict[{}, {}]".format(values_type(val.keys()), values_type(val.values())) |     type_name = "dict[{}, {}]".format(values_type(val.keys()), values_type(val.values())) | ||||||
|   else: |   else: | ||||||
|     type_name = typ.__name__ |     type_name = typ.__name__ | ||||||
| 
 | 
 | ||||||
| @ -115,7 +115,7 @@ def visible_columns(engine, table_id): | |||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def class_schema(engine, table_id, exclude_col_id=None, lookups=False): | def class_schema(engine, table_id, exclude_col_id=None, lookups=False): | ||||||
|   result = "@dataclass\nclass {}:\n".format(table_id) |   result = "class {}:\n".format(table_id) | ||||||
| 
 | 
 | ||||||
|   if lookups: |   if lookups: | ||||||
| 
 | 
 | ||||||
| @ -135,8 +135,8 @@ def class_schema(engine, table_id, exclude_col_id=None, lookups=False): | |||||||
|     result += "    def __len__(self):\n" |     result += "    def __len__(self):\n" | ||||||
|     result += "        return len(%s.lookupRecords())\n" % table_id |     result += "        return len(%s.lookupRecords())\n" % table_id | ||||||
|     result += "    @staticmethod\n" |     result += "    @staticmethod\n" | ||||||
|     result += "    def lookupRecords(%s) -> List[%s]:\n" % (lookupOne_args_line, table_id) |     result += "    def lookupRecords(%s) -> list[%s]:\n" % (lookupOne_args_line, table_id) | ||||||
|     result += "       # ...\n" |     result += "       ...\n" | ||||||
|     result += "    @staticmethod\n" |     result += "    @staticmethod\n" | ||||||
|     result += "    def lookupOne(%s) -> %s:\n" % (lookupOne_args_line, table_id) |     result += "    def lookupOne(%s) -> %s:\n" % (lookupOne_args_line, table_id) | ||||||
|     result += "       '''\n" |     result += "       '''\n" | ||||||
| @ -165,10 +165,7 @@ def get_formula_prompt(engine, table_id, col_id, _description, | |||||||
|   result += class_schema(engine, table_id, col_id, lookups) |   result += class_schema(engine, table_id, col_id, lookups) | ||||||
| 
 | 
 | ||||||
|   return_type = column_type(engine, table_id, col_id) |   return_type = column_type(engine, table_id, col_id) | ||||||
|   result += "    @property\n" |   result += "def {}(rec: {}) -> {}:\n".format(col_id, table_id, return_type) | ||||||
|   result += "    # rec is alias for self\n" |  | ||||||
|   result += "    def {}(rec) -> {}:\n".format(col_id, return_type) |  | ||||||
|   result += "        # Please fill in code only after this line, not the `def`\n" |  | ||||||
|   return result |   return result | ||||||
| 
 | 
 | ||||||
| def indent(text, prefix, predicate=None): | def indent(text, prefix, predicate=None): | ||||||
| @ -213,7 +210,10 @@ def convert_completion(completion): | |||||||
|   while stmts and isinstance(stmts[0], (ast.Import, ast.ImportFrom)): |   while stmts and isinstance(stmts[0], (ast.Import, ast.ImportFrom)): | ||||||
|     imports += atok.get_text(stmts.pop(0)) + "\n" |     imports += atok.get_text(stmts.pop(0)) + "\n" | ||||||
| 
 | 
 | ||||||
|   # If the non-import code consists only of a function definition, extract the body. |   # Sometimes the model repeats the provided classes, remove them. | ||||||
|  |   stmts = [stmt for stmt in stmts if not isinstance(stmt, ast.ClassDef)] | ||||||
|  | 
 | ||||||
|  |   # If the remaining code consists only of a function definition, extract the body. | ||||||
|   if len(stmts) == 1 and isinstance(stmts[0], ast.FunctionDef): |   if len(stmts) == 1 and isinstance(stmts[0], ast.FunctionDef): | ||||||
|     func_body_stmts = stmts[0].body |     func_body_stmts = stmts[0].body | ||||||
|     if ( |     if ( | ||||||
|  | |||||||
| @ -58,29 +58,29 @@ class TestFormulaPrompt(test_engine.EngineTestCase): | |||||||
|     self.assertEqual(values_type([ |     self.assertEqual(values_type([ | ||||||
|       fake_table.RecordSet(None), |       fake_table.RecordSet(None), | ||||||
|       fake_table.RecordSet(None), |       fake_table.RecordSet(None), | ||||||
|     ]), "List[Table1]") |     ]), "list[Table1]") | ||||||
|     self.assertEqual(values_type([ |     self.assertEqual(values_type([ | ||||||
|       fake_table.RecordSet(None), |       fake_table.RecordSet(None), | ||||||
|       fake_table.RecordSet(None), |       fake_table.RecordSet(None), | ||||||
|       None, |       None, | ||||||
|     ]), "Optional[List[Table1]]") |     ]), "Optional[list[Table1]]") | ||||||
| 
 | 
 | ||||||
|     self.assertEqual(values_type([[1, 2, 3]]), "List[int]") |     self.assertEqual(values_type([[1, 2, 3]]), "list[int]") | ||||||
|     self.assertEqual(values_type([[1, 2, 3], None]), "Optional[List[int]]") |     self.assertEqual(values_type([[1, 2, 3], None]), "Optional[list[int]]") | ||||||
|     self.assertEqual(values_type([[1, 2, None]]), "List[Optional[int]]") |     self.assertEqual(values_type([[1, 2, None]]), "list[Optional[int]]") | ||||||
|     self.assertEqual(values_type([[1, 2, None], None]), "Optional[List[Optional[int]]]") |     self.assertEqual(values_type([[1, 2, None], None]), "Optional[list[Optional[int]]]") | ||||||
|     self.assertEqual(values_type([[1, 2, "3"]]), "List[Any]") |     self.assertEqual(values_type([[1, 2, "3"]]), "list[Any]") | ||||||
| 
 | 
 | ||||||
|     self.assertEqual(values_type([{1, 2, 3}]), "Set[int]") |     self.assertEqual(values_type([{1, 2, 3}]), "set[int]") | ||||||
|     self.assertEqual(values_type([(1, 2, 3)]), "Tuple[int, ...]") |     self.assertEqual(values_type([(1, 2, 3)]), "tuple[int, ...]") | ||||||
|     self.assertEqual(values_type([{1: ["2"]}]), "Dict[int, List[str]]") |     self.assertEqual(values_type([{1: ["2"]}]), "dict[int, list[str]]") | ||||||
| 
 | 
 | ||||||
|   def assert_column_type(self, col_id, expected_type): |   def assert_column_type(self, col_id, expected_type): | ||||||
|     self.assertEqual(column_type(self.engine, "Table2", col_id), expected_type) |     self.assertEqual(column_type(self.engine, "Table2", col_id), expected_type) | ||||||
| 
 | 
 | ||||||
|   def assert_prompt(self, table_name, col_id, expected_prompt): |   def assert_prompt(self, table_name, col_id, expected_prompt, lookups=False): | ||||||
|     prompt = get_formula_prompt(self.engine, table_name, col_id, "description here", |     prompt = get_formula_prompt(self.engine, table_name, col_id, "description here", | ||||||
|                                 include_all_tables=False, lookups=False) |                                 include_all_tables=False, lookups=lookups) | ||||||
|     # print(prompt) |     # print(prompt) | ||||||
|     self.assertEqual(prompt, expected_prompt) |     self.assertEqual(prompt, expected_prompt) | ||||||
| 
 | 
 | ||||||
| @ -122,9 +122,9 @@ class TestFormulaPrompt(test_engine.EngineTestCase): | |||||||
|     self.assert_column_type("datetime", "datetime.datetime") |     self.assert_column_type("datetime", "datetime.datetime") | ||||||
|     self.assert_column_type("attachments", "Any") |     self.assert_column_type("attachments", "Any") | ||||||
|     self.assert_column_type("ref", "Table2") |     self.assert_column_type("ref", "Table2") | ||||||
|     self.assert_column_type("reflist", "List[Table2]") |     self.assert_column_type("reflist", "list[Table2]") | ||||||
|     self.assert_column_type("choice", "Literal['a', 'b', 'c']") |     self.assert_column_type("choice", "Literal['a', 'b', 'c']") | ||||||
|     self.assert_column_type("choicelist", "Tuple[Literal['x', 'y', 'z'], ...]") |     self.assert_column_type("choicelist", "tuple[Literal['x', 'y', 'z'], ...]") | ||||||
|     self.assert_column_type("ref_formula", "Optional[Table2]") |     self.assert_column_type("ref_formula", "Optional[Table2]") | ||||||
|     self.assert_column_type("numeric_formula", "float") |     self.assert_column_type("numeric_formula", "float") | ||||||
| 
 | 
 | ||||||
| @ -132,7 +132,6 @@ class TestFormulaPrompt(test_engine.EngineTestCase): | |||||||
| 
 | 
 | ||||||
|     self.assert_prompt("Table2", "new_formula", |     self.assert_prompt("Table2", "new_formula", | ||||||
|       '''\ |       '''\ | ||||||
| @dataclass |  | ||||||
| class Table2: | class Table2: | ||||||
|     text: str |     text: str | ||||||
|     numeric: float |     numeric: float | ||||||
| @ -142,16 +141,13 @@ class Table2: | |||||||
|     datetime: datetime.datetime |     datetime: datetime.datetime | ||||||
|     attachments: Any |     attachments: Any | ||||||
|     ref: Table2 |     ref: Table2 | ||||||
|     reflist: List[Table2] |     reflist: list[Table2] | ||||||
|     choice: Literal['a', 'b', 'c'] |     choice: Literal['a', 'b', 'c'] | ||||||
|     choicelist: Tuple[Literal['x', 'y', 'z'], ...] |     choicelist: tuple[Literal['x', 'y', 'z'], ...] | ||||||
|     ref_formula: Optional[Table2] |     ref_formula: Optional[Table2] | ||||||
|     numeric_formula: float |     numeric_formula: float | ||||||
| 
 | 
 | ||||||
|     @property | def new_formula(rec: Table2) -> float: | ||||||
|     # rec is alias for self |  | ||||||
|     def new_formula(rec) -> float: |  | ||||||
|         # Please fill in code only after this line, not the `def` |  | ||||||
| ''') | ''') | ||||||
| 
 | 
 | ||||||
|   def test_get_formula_prompt(self): |   def test_get_formula_prompt(self): | ||||||
| @ -175,45 +171,52 @@ class Table2: | |||||||
|     self.assertEqual(referenced_tables(self.engine, "Table1"), set()) |     self.assertEqual(referenced_tables(self.engine, "Table1"), set()) | ||||||
| 
 | 
 | ||||||
|     self.assert_prompt("Table1", "text", '''\ |     self.assert_prompt("Table1", "text", '''\ | ||||||
| @dataclass |  | ||||||
| class Table1: | class Table1: | ||||||
| 
 | 
 | ||||||
|     @property | def text(rec: Table1) -> str: | ||||||
|     # rec is alias for self |  | ||||||
|     def text(rec) -> str: |  | ||||||
|         # Please fill in code only after this line, not the `def` |  | ||||||
| ''') | ''') | ||||||
| 
 | 
 | ||||||
|  |     # Test the same thing but include the lookup methods as in a real case, | ||||||
|  |     # just to show that the table class would never actually be empty | ||||||
|  |     # (which would be invalid Python and might confuse the model). | ||||||
|  |     self.assert_prompt("Table1", "text", """\ | ||||||
|  | class Table1: | ||||||
|  |     def __len__(self): | ||||||
|  |         return len(Table1.lookupRecords()) | ||||||
|  |     @staticmethod | ||||||
|  |     def lookupRecords(sort_by=None) -> list[Table1]: | ||||||
|  |        ... | ||||||
|  |     @staticmethod | ||||||
|  |     def lookupOne(sort_by=None) -> Table1: | ||||||
|  |        ''' | ||||||
|  |        Filter for one result matching the keys provided. | ||||||
|  |        To control order, use e.g. `sort_by='Key' or `sort_by='-Key'`. | ||||||
|  |        ''' | ||||||
|  |        return Table1.lookupRecords(sort_by=sort_by)[0] | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def text(rec: Table1) -> str: | ||||||
|  | """, lookups=True) | ||||||
|  | 
 | ||||||
|     self.assert_prompt("Table2", "ref", '''\ |     self.assert_prompt("Table2", "ref", '''\ | ||||||
| @dataclass |  | ||||||
| class Table1: | class Table1: | ||||||
|     text: str |     text: str | ||||||
| 
 | 
 | ||||||
| @dataclass |  | ||||||
| class Table2: | class Table2: | ||||||
| 
 | 
 | ||||||
|     @property | def ref(rec: Table2) -> Table1: | ||||||
|     # rec is alias for self |  | ||||||
|     def ref(rec) -> Table1: |  | ||||||
|         # Please fill in code only after this line, not the `def` |  | ||||||
| ''') | ''') | ||||||
| 
 | 
 | ||||||
|     self.assert_prompt("Table3", "reflist", '''\ |     self.assert_prompt("Table3", "reflist", '''\ | ||||||
| @dataclass |  | ||||||
| class Table1: | class Table1: | ||||||
|     text: str |     text: str | ||||||
| 
 | 
 | ||||||
| @dataclass |  | ||||||
| class Table2: | class Table2: | ||||||
|     ref: Table1 |     ref: Table1 | ||||||
| 
 | 
 | ||||||
| @dataclass |  | ||||||
| class Table3: | class Table3: | ||||||
| 
 | 
 | ||||||
|     @property | def reflist(rec: Table3) -> list[Table2]: | ||||||
|     # rec is alias for self |  | ||||||
|     def reflist(rec) -> List[Table2]: |  | ||||||
|         # Please fill in code only after this line, not the `def` |  | ||||||
| ''') | ''') | ||||||
| 
 | 
 | ||||||
|   def test_convert_completion(self): |   def test_convert_completion(self): | ||||||
| @ -227,6 +230,9 @@ from x import ( | |||||||
|   z, |   z, | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | class Foo: | ||||||
|  |     bar: Bar | ||||||
|  | 
 | ||||||
| @property | @property | ||||||
| def foo(rec): | def foo(rec): | ||||||
|     '''This is a docstring''' |     '''This is a docstring''' | ||||||
|  | |||||||
| @ -1,4 +1,5 @@ | |||||||
| no_formula,table_id,col_id,doc_id,Description | no_formula,table_id,col_id,doc_id,Description | ||||||
|  | 0,Encrypt,Encrypted,n2se5cBJty1GyWdougSD2T,"Encrypt with a simple Caeser cipher: Convert all letters to uppercase, then circular shift them forward by 6 positions. Leave all other characters unchanged. For example, 'abc xyz!' becomes 'GHI DEF!'. Use the `string` module." | ||||||
| 0,Contacts,Send_Email,hQHXqAQXceeQBPvRw5sSs1,"Link to compose an email, if there is one" | 0,Contacts,Send_Email,hQHXqAQXceeQBPvRw5sSs1,"Link to compose an email, if there is one" | ||||||
| 0,Contacts,No_Notes,hQHXqAQXceeQBPvRw5sSs1,"Number of notes for this contact" | 0,Contacts,No_Notes,hQHXqAQXceeQBPvRw5sSs1,"Number of notes for this contact" | ||||||
| 0,Category,Contains_archived_project_,hQHXqAQXceeQBPvRw5sSs1,"Whether any projects in this category are archived" | 0,Category,Contains_archived_project_,hQHXqAQXceeQBPvRw5sSs1,"Whether any projects in this category are archived" | ||||||
|  | |||||||
| 
 | 
		Loading…
	
		Reference in New Issue
	
	Block a user