Cleanuped version of data layer, that stores user data in the sqlite engine. All tests are passing except errors

This commit is contained in:
Jarosław Sadziński 2023-05-17 22:49:38 +02:00
parent 940f0608fd
commit 6fab04d1de
38 changed files with 671 additions and 329 deletions

View File

@ -117,7 +117,6 @@ class BaseColumn(object):
Called when the column is deleted.
"""
if self.detached:
print('Warning - destroying already detached column: ', self.table_id, self.col_id)
return
self.engine.data.drop_column(self)
@ -138,6 +137,8 @@ class BaseColumn(object):
"""
if self.detached:
raise Exception('Column already detached: ', self.table_id, self.col_id)
if (self.col_id == "R"):
print('Column {}.{} is setting row {} to {}'.format(self.table_id, self.col_id, row_id, value))
self._data.set(row_id, value)
@ -171,6 +172,15 @@ class BaseColumn(object):
raise raw.error
else:
raise objtypes.CellError(self.table_id, self.col_id, row_id, raw.error)
elif isinstance(raw, objtypes.RecordSetStub):
# rec_list = [self.engine.tables[raw.table_id].get_record(r) for r in raw.row_ids]
rel = relation.ReferenceRelation(self.table_id, raw.table_id , self.col_id)
(rel.add_reference(row_id, r) for r in raw.row_ids)
raw = self.engine.tables[raw.table_id].RecordSet(raw.row_ids, rel)
elif isinstance(raw, objtypes.RecordStub):
rel = relation.ReferenceRelation(self.table_id, raw.table_id , self.col_id)
rel.add_reference(row_id, raw.row_id)
raw = self.engine.tables[raw.table_id].Record(raw.row_id, rel)
# Inline _convert_raw_value here because this is particularly hot code, called on every access
# of any data field in a formula.
@ -229,10 +239,10 @@ class BaseColumn(object):
if self.detached:
raise Exception('Column already detached: ', self.table_id, self.col_id)
if other_column.detached:
print('Warning: copying from detached column: ', other_column.table_id, other_column.col_id)
# print('Warning: copying from detached column: ', other_column.table_id, other_column.col_id)
return
print('Column {}.{} is copying from {}.{}'.format(self.table_id, self.col_id, other_column.table_id, other_column.col_id))
# print('Column {}.{} is copying from {}.{}'.format(self.table_id, self.col_id, other_column.table_id, other_column.col_id))
self._data.copy_from(other_column._data)

View File

@ -1,3 +1,12 @@
import os
from objtypes import RaisedException
def log(*args):
# print(*args)
pass
class MemoryColumn(object):
def __init__(self, col):
self.col = col
@ -19,21 +28,24 @@ class MemoryColumn(object):
return len(self.data)
def clear(self):
if self.size() == 1:
return
raise NotImplementedError("clear() not implemented for this column type")
self.data = []
self.growto(1)
def raw_get(self, row_id):
try:
return self.data[row_id]
return (self.data[row_id])
except IndexError:
return self.getdefault()
def set(self, row_id, value):
try:
value = (value)
except Exception as e:
log('Unable to marshal value: ', value)
try:
self.data[row_id] = value
except IndexError:
except Exception as e:
self.growto(row_id + 1)
self.data[row_id] = value
@ -54,39 +66,67 @@ class MemoryDatabase(object):
self.engine = engine
self.tables = {}
def close(self):
self.engine = None
self.tables = None
pass
def begin(self):
pass
def commit(self):
pass
def create_table(self, table):
if table.table_id in self.tables:
raise ValueError("Table %s already exists" % table.table_id)
print("Creating table %s" % table.table_id)
return
log("Creating table %s" % table.table_id)
self.tables[table.table_id] = dict()
def drop_table(self, table):
if table.detached:
return
if table.table_id not in self.tables:
raise ValueError("Table %s already exists" % table.table_id)
print("Deleting table %s" % table.table_id)
log("Deleting table %s" % table.table_id)
del self.tables[table.table_id]
def rename_table(self, old_table_id, new_table_id):
if old_table_id not in self.tables:
raise ValueError("Table %s does not exist" % old_table_id)
if new_table_id in self.tables:
raise ValueError("Table %s already exists" % new_table_id)
log("Renaming table %s to %s" % (old_table_id, new_table_id))
self.tables[new_table_id] = self.tables[old_table_id]
def create_column(self, col):
if col.table_id not in self.tables:
self.tables[col.table_id] = dict()
if col.col_id in self.tables[col.table_id]:
old_one = self.tables[col.table_id][col.col_id]
if old_one == col:
raise ValueError("Column %s.%s already exists" % (col.table_id, col.col_id))
col._data = old_one._data
col._data.col = col
if col.col_id == 'group':
log('Column {}.{} is detaching column {}.{}'.format(col.table_id, col.col_id, old_one.table_id, old_one.col_id))
old_one.detached = True
old_one._data = None
else:
col._data = MemoryColumn(col)
# print('Column {}.{} is detaching column {}.{}'.format(self.table_id, self.col_id, old_one.table_id, old_one.col_id))
# print('Creating column: ', self.table_id, self.col_id)
# log('Column {}.{} is detaching column {}.{}'.format(self.table_id, self.col_id, old_one.table_id, old_one.col_id))
# log('Creating column: ', self.table_id, self.col_id)
self.tables[col.table_id][col.col_id] = col
col.detached = False
def drop_column(self, col):
if col.detached:
return
tables = self.tables
if col.table_id not in tables:
@ -95,62 +135,74 @@ class MemoryDatabase(object):
if col.col_id not in tables[col.table_id]:
raise Exception('Column not found: ', col.table_id, col.col_id)
print('Destroying column: ', col.table_id, col.col_id)
log('Destroying column: ', col.table_id, col.col_id)
col._data.drop()
del tables[col.table_id][col.col_id]
import json
import random
import string
import actions
from sql import delete_column, open_connection
from sql import change_column_type, delete_column, open_connection
class SqlColumn(object):
def __init__(self, db, col):
self.db = db
self.col = col
self.create_column()
def growto(self, size):
if self.size() < size:
for i in range(self.size(), size):
self.set(i, self.getdefault())
def iterate(self):
cursor = self.db.sql.cursor()
try:
for row in cursor.execute('SELECT id, "{}" FROM "{}" ORDER BY id'.format(self.col.col_id, self.col.table_id)):
yield row[0], row[1] if row[1] is not None else self.getdefault()
yield row[0], self.col.type_obj.decode(row[1])
finally:
cursor.close()
def copy_from(self, other_column):
size = other_column.size()
if size < 2:
return
self.growto(other_column.size())
for i, value in other_column.iterate():
self.set(i, value)
def raw_get(self, row_id):
if row_id == 0:
return self.getdefault()
table_id = self.col.table_id
col_id = self.col.col_id
type_obj = self.col.type_obj
cursor = self.db.sql.cursor()
value = cursor.execute('SELECT "{}" FROM "{}" WHERE id = ?'.format(self.col.col_id, self.col.table_id), (row_id,)).fetchone()
value = cursor.execute('SELECT "{}" FROM "{}" WHERE id = ?'.format(col_id, table_id), (row_id,)).fetchone()
cursor.close()
correct = value[0] if value else None
return correct if correct is not None else self.getdefault()
value = value[0] if value else self.getdefault()
decoded = type_obj.decode(value)
return decoded
def set(self, row_id, value):
try:
if self.col.col_id == "id" and not value:
return
# First check if we have this id in the table, using exists statmenet
cursor = self.db.sql.cursor()
value = value
if isinstance(value, list):
value = json.dumps(value)
encoded = self.col.type_obj.encode(value)
exists = cursor.execute('SELECT EXISTS(SELECT 1 FROM "{}" WHERE id = ?)'.format(self.col.table_id), (row_id,)).fetchone()[0]
if not exists:
cursor.execute('INSERT INTO "{}" (id, "{}") VALUES (?, ?)'.format(self.col.table_id, self.col.col_id), (row_id, value))
cursor.execute('INSERT INTO "{}" (id, "{}") VALUES (?, ?)'.format(self.col.table_id, self.col.col_id), (row_id, encoded))
else:
cursor.execute('UPDATE "{}" SET "{}" = ? WHERE id = ?'.format(self.col.table_id, self.col.col_id), (value, row_id))
cursor.execute('UPDATE "{}" SET "{}" = ? WHERE id = ?'.format(self.col.table_id, self.col.col_id), (encoded, row_id))
except Exception as e:
log("Error setting value: ", row_id, encoded, e)
raise
def getdefault(self):
return self.col.type_obj.default
@ -161,16 +213,25 @@ class SqlColumn(object):
return max_id + 1
def create_column(self):
try:
cursor = self.db.sql.cursor()
col = self.col
if col.col_id == "id":
pass
else:
log('Creating column {}.{} with type {}'.format(self.col.table_id, self.col.col_id, self.col.type_obj.sql_type()))
if col.col_id == "group" and col.type_obj.sql_type() != "TEXT":
log("Group column must be TEXT")
cursor.execute('ALTER TABLE "{}" ADD COLUMN "{}" {}'.format(self.col.table_id, self.col.col_id, self.col.type_obj.sql_type()))
cursor.close()
except Exception as e:
raise
def clear(self):
pass
cursor = self.db.sql.cursor()
cursor.execute('DELETE FROM "{}"'.format(self.col.table_id))
cursor.close()
self.growto(1)
def drop(self):
delete_column(self.db.sql, self.col.table_id, self.col.col_id)
@ -178,63 +239,129 @@ class SqlColumn(object):
def unset(self, row_id):
if self.col.col_id != 'id':
return
print('Removing row {} from column {}.{}'.format(row_id, self.col.table_id, self.col.col_id))
log('Removing row {} from column {}.{}'.format(row_id, self.col.table_id, self.col.col_id))
cursor = self.db.sql.cursor()
cursor.execute('DELETE FROM "{}" WHERE id = ?'.format(self.col.table_id), (row_id,))
cursor.close()
class SqlTable(object):
def __init__(self, db, table):
self.db = db
self.table = table
self.columns = {}
def has_column(self, col_id):
return col_id in self.columns
class SqlDatabase(object):
def __init__(self, engine) -> None:
def __init__(self, engine):
self.engine = engine
while True:
random_file = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(10)) + '.grist'
random_file = os.path.join('./', random_file)
# Test if file exists
if not os.path.isfile(random_file):
break
# random_file = ':memory:'
print('Creating database: ', random_file)
self.sql = open_connection(random_file)
self.tables = {}
self.counter = 0
self.file = random_file
self.detached = dict()
def rename_table(self, old_id, new_id):
orig_ol = old_id
if old_id.lower() == new_id.lower():
self.sql.execute('ALTER TABLE "{}" RENAME TO "{}"'.format(old_id, old_id + "_tmp"))
old_id = old_id + "_tmp"
self.sql.execute('ALTER TABLE "{}" RENAME TO "{}"'.format(old_id, new_id))
self.tables[new_id] = self.tables[orig_ol]
del self.tables[orig_ol]
def begin(self):
if self.counter == 0:
# self.sql.execute('BEGIN TRANSACTION')
log('BEGIN TRANSACTION')
pass
self.counter += 1
def commit(self):
self.counter -= 1
if self.counter < 0:
raise Exception("Commit without begin")
if self.counter == 0:
# self.sql.commit()
log('COMMIT')
pass
def close(self):
self.sql.close()
self.sql = None
self.tables = None
def read_table(self, table_id):
return read_table(self.sql, table_id)
def detach_table(self, table):
table.detached = True
def create_table(self, table):
cursor = self.sql.cursor()
cursor.execute('CREATE TABLE ' + table.table_id + ' (id INTEGER PRIMARY KEY AUTOINCREMENT)')
self.tables[table.table_id] = {}
if table.table_id in self.tables:
return
cursor = self.sql.cursor()
log('Creating table: ', table.table_id)
cursor.execute('CREATE TABLE "' + table.table_id + '" (id INTEGER PRIMARY KEY AUTOINCREMENT)')
self.tables[table.table_id] = {}
def create_column(self, col):
if col.table_id not in self.tables:
self.tables[col.table_id] = dict()
raise Exception("Table {} does not exist".format(col.table_id))
if col.col_id in self.tables[col.table_id]:
old_one = self.tables[col.table_id][col.col_id]
col._data = old_one._data
col._data.col = col
col._data = SqlColumn(self, col)
if type(old_one.type_obj) != type(col.type_obj):
# First change name of the column.
col._data.copy_from(old_one._data)
change_column_type(self.sql, col.table_id, col.col_id, col.type_obj.sql_type())
old_one.detached = True
old_one._data = None
else:
col._data = SqlColumn(self, col)
# print('Column {}.{} is detaching column {}.{}'.format(self.table_id, self.col_id, old_one.table_id, old_one.col_id))
# print('Creating column: ', self.table_id, self.col_id)
log('Creating column: ', col.table_id, col.col_id)
col._data.create_column()
self.tables[col.table_id][col.col_id] = col
col.detached = False
def drop_column(self, col):
tables = self.tables
if col.detached or col._table.detached:
return
if col.table_id not in tables:
raise Exception('Table not found for column: ', col.table_id, col.col_id)
raise Exception('Cant remove column {} from table {} because table does not exist'.format(col.col_id, col.table_id))
if col.col_id not in tables[col.table_id]:
raise Exception('Column not found: ', col.table_id, col.col_id)
print('Destroying column: ', col.table_id, col.col_id)
log('Destroying column: ', col.table_id, col.col_id)
col._data.drop()
del tables[col.table_id][col.col_id]
def drop_table(self, table):
if table.table_id in self.detached:
del self.detached[table.table_id]
return
if table.table_id not in self.tables:
raise Exception('Table not found: ', table.table_id)
cursor = self.sql.cursor()
@ -256,3 +383,8 @@ def read_table(sql, tableId):
columns[key] = []
columns[key].append(row[key])
return actions.TableData(tableId, rowIds, columns)
def make_data(eng):
# return MemoryDatabase(eng)
return SqlDatabase(eng)

View File

@ -262,6 +262,9 @@ class DocActions(object):
old_table = self._engine.tables[old_table_id]
self._engine.data.rename_table(old_table_id, new_table_id)
self._engine.data.detach_table(old_table)
# Update schema, and re-generate the module code.
old = self._engine.schema.pop(old_table_id)
self._engine.schema[new_table_id] = schema.SchemaTable(new_table_id, old.columns)

View File

@ -15,7 +15,7 @@ import six
from six.moves import zip
from six.moves.collections_abc import Hashable # pylint:disable-all
from sortedcontainers import SortedSet
from data import MemoryColumn, MemoryDatabase, SqlDatabase
from data import MemoryColumn, MemoryDatabase, SqlDatabase, make_data
import acl
import actions
import action_obj
@ -315,7 +315,12 @@ class Engine(object):
Returns the list of all the other table names that data engine expects to be loaded.
"""
self.data = SqlDatabase(self)
if self.data:
self.tables = {}
self.data.close()
self.data = None
self.data = make_data(self)
self.schema = schema.build_schema(meta_tables, meta_columns)
@ -337,8 +342,7 @@ class Engine(object):
table = self.tables[data.table_id]
# Clear all columns, whether or not they are present in the data.
for column in six.itervalues(table.all_columns):
column.clear()
table.clear()
# Only load columns that aren't stored.
columns = {col_id: data for (col_id, data) in six.iteritems(data.columns)
@ -1279,7 +1283,7 @@ class Engine(object):
# only need a subset of data loaded, it would be better to filter calc actions, and
# include only those the clients care about. For side-effects, we might want to recompute
# everything, and only filter what we send.
self.data.begin()
self.out_actions = action_obj.ActionGroup()
self._user = User(user, self.tables) if user else None
@ -1303,7 +1307,6 @@ class Engine(object):
self.assert_schema_consistent()
except Exception as e:
raise e
# Save full exception info, so that we can rethrow accurately even if undo also fails.
exc_info = sys.exc_info()
# If we get an exception, we should revert all changes applied so far, to keep things

View File

@ -4,7 +4,7 @@ import textwrap
import six
from column import is_visible_column, BaseReferenceColumn
from objtypes import RaisedException
from objtypes import RaisedException, RecordStub
import records
@ -64,6 +64,8 @@ def values_type(values):
type_name = val._table.table_id
elif isinstance(val, records.RecordSet):
type_name = "List[{}]".format(val._table.table_id)
elif isinstance(val, RecordStub):
type_name = val.table_id
elif isinstance(val, list):
type_name = "List[{}]".format(values_type(val))
elif isinstance(val, set):

View File

@ -284,7 +284,8 @@ class RaisedException(object):
if self._encoded_error is not None:
return self._encoded_error
if self.has_user_input():
user_input = {"u": encode_object(self.user_input)}
u = encode_object(self.user_input)
user_input = {"u": u}
else:
user_input = None
result = [self._name, self._message, self.details, user_input]
@ -304,6 +305,8 @@ class RaisedException(object):
while isinstance(error, CellError):
if not location:
location = "\n(in referenced cell {error.location})".format(error=error)
if error.error is None:
break
error = error.error
self._name = type(error).__name__
if include_details:
@ -342,6 +345,12 @@ class RaisedException(object):
exc.details = safe_shift(args)
exc.user_input = safe_shift(args, {})
exc.user_input = decode_object(exc.user_input.get("u", RaisedException.NO_INPUT))
if exc._name == "CircularRefError":
exc.error = depend.CircularRefError(exc._message)
if exc._name == "AttributeError":
exc.error = AttributeError(exc._message)
return exc
class CellError(Exception):

49
sandbox/grist/poc copy.py Normal file
View File

@ -0,0 +1,49 @@
import difflib
import functools
import json
import unittest
from collections import namedtuple
from pprint import pprint
import six
import actions
import column
import engine
import logger
import useractions
import testutil
import objtypes
eng = engine.Engine()
eng.load_empty()
def apply(actions):
if not actions:
return []
if not isinstance(actions[0], list):
actions = [actions]
return eng.apply_user_actions([useractions.from_repr(a) for a in actions])
try:
apply(['AddRawTable', 'Types' ])
apply(['AddColumn', 'Types', 'numeric', {'type': 'Numeric'}])
apply(['AddRecord', 'Types', None, {'numeric': False}])
finally:
if hasattr(eng, 'close'):
eng.close()
# try:
# apply(['AddRawTable', 'Types' ])
# apply(['AddColumn', 'Types', 'text', {'type': 'Text'}])
# apply(['AddRecord', 'Types', None, {'text': None}])
# w = (apply(["ModifyColumn", "Types", "text", { "type" : "Bool" }]))
# print(w)
# finally:
# if hasattr(eng, 'close'):
# eng.close()

View File

@ -1,19 +1,5 @@
import difflib
import functools
import json
import unittest
from collections import namedtuple
from pprint import pprint
import six
import actions
import column
import engine
import logger
import useractions
import testutil
import objtypes
eng = engine.Engine()
@ -29,21 +15,46 @@ def apply(actions):
try:
# Ref column
def ref_columns():
apply(['AddRawTable', 'Table1'])
apply(['AddRecord', 'Table1', None, {'A': 1, 'B': 2, 'C': 3}])
apply(['AddColumn', 'Table1', 'D', {'type': 'Numeric', 'isFormula': True, 'formula': '$A + 3'}]),
apply(['RenameColumn', 'Table1', 'A', 'NewA'])
apply(['RenameTable', 'Table1', 'Dwa'])
apply(['RemoveColumn', 'Dwa', 'B'])
apply(['RemoveTable', 'Dwa'])
apply(['AddRawTable', 'Table2'])
apply(['AddRecord', 'Table1', None, {"A": 30}])
apply(['AddColumn', 'Table2', 'R', {'type': 'Ref:Table1'}]),
apply(['AddColumn', 'Table2', 'F', {'type': 'Any', "isFormula": True, "formula": "$R.A"}]),
apply(['AddRecord', 'Table2', None, {'R': 1}])
apply(['UpdateRecord', 'Table1', 1, {'A': 40}])
print(eng.fetch_table('Table2'))
# ['RemoveColumn', "Table1", 'A'],
# ['AddColumn', 'Table1', 'D', {'type': 'Numeric', 'isFormula': True, 'formula': '$A + 3'}],
# ['ModifyColumn', 'Table1', 'B', {'type': 'Numeric', 'isFormula': True, 'formula': '$A + 1'}],
#])
# ['AddColumn', 'Table1', 'D', {'type': 'Numeric', 'isFormula': True, 'formula': '$A + 3'}],
# ['ModifyColumn', 'Table1', 'B', {'type': 'Numeric', 'isFormula': True, 'formula': '$A + 1'}],
# Any lookups
def any_columns():
apply(['AddRawTable', 'Table1'])
apply(['AddRawTable', 'Table2'])
apply(['AddRecord', 'Table1', None, {"A": 30}])
apply(['AddColumn', 'Table2', 'R', {'type': 'Any', 'isFormula': True, 'formula': 'Table1.lookupOne(id=1)'}]),
apply(['AddColumn', 'Table2', 'F', {'type': 'Any', "isFormula": True, "formula": "$R.A"}]),
apply(['AddRecord', 'Table2', None, {}])
print(eng.fetch_table('Table2'))
# Change A to 40
apply(['UpdateRecord', 'Table1', 1, {'A': 40}])
print(eng.fetch_table('Table2'))
# Any lookups
def simple_formula():
apply(['AddRawTable', 'Table1'])
apply(['ModifyColumn', 'Table1', 'B', {'type': 'Numeric', 'isFormula': True, 'formula': 'Table1.lookupOne(id=$id).A + 10'}]),
apply(['AddRecord', 'Table1', None, {"A": 1}])
print(eng.fetch_table('Table1'))
apply(['UpdateRecord', 'Table1', 1, {"A": 2}])
print(eng.fetch_table('Table1'))
apply(['UpdateRecord', 'Table1', 1, {"A": 3}])
print(eng.fetch_table('Table1'))
simple_formula()
finally:
# Test if method close is in engine
if hasattr(eng, 'close'):

View File

@ -7,69 +7,9 @@ import six
import sqlite3
def change_id_to_primary_key(conn, table_name):
cursor = conn.cursor()
cursor.execute('PRAGMA table_info("{}");'.format(table_name))
columns = cursor.fetchall()
create_table_sql = 'CREATE TABLE "{}_temp" ('.format(table_name)
for column in columns:
column_name, column_type, _, _, _, _ = column
primary_key = "PRIMARY KEY" if column_name == "id" else ""
create_table_sql += '"{}" {} {}, '.format(column_name, column_type, primary_key)
create_table_sql = create_table_sql.rstrip(", ") + ");"
cursor.execute(create_table_sql)
cursor.execute('INSERT INTO "{}_temp" SELECT * FROM "{}";'.format(table_name, table_name))
cursor.execute('DROP TABLE "{}";'.format(table_name))
cursor.execute('ALTER TABLE "{}_temp" RENAME TO "{}";'.format(table_name, table_name))
cursor.close()
def delete_column(conn, table_name, column_name):
cursor = conn.cursor()
cursor.execute('PRAGMA table_info("{}");'.format(table_name))
columns_info = cursor.fetchall()
new_columns = ", ".join(
'"{}" {}'.format(col[1], col[2])
for col in columns_info
if col[1] != column_name
)
if new_columns:
cursor.execute('CREATE TABLE "new_{}" ({})'.format(table_name, new_columns))
cursor.execute('INSERT INTO "new_{}" SELECT {} FROM "{}"'.format(table_name, new_columns, table_name))
cursor.execute('DROP TABLE "{}"'.format(table_name))
cursor.execute('ALTER TABLE "new_{}" RENAME TO "{}"'.format(table_name, table_name))
conn.commit()
def rename_column(conn, table_name, old_column_name, new_column_name):
cursor = conn.cursor()
cursor.execute('PRAGMA table_info("{}");'.format(table_name))
columns_info = cursor.fetchall()
# Construct new column definitions string
new_columns = []
for col in columns_info:
if col[1] == old_column_name:
new_columns.append('"{}" {}'.format(new_column_name, col[2]))
else:
new_columns.append('"{}" {}'.format(col[1], col[2]))
new_columns_str = ", ".join(new_columns)
# Create new table with renamed column
cursor.execute('CREATE TABLE "new_{}" ({});'.format(table_name, new_columns_str))
cursor.execute('INSERT INTO "new_{}" SELECT {} FROM "{}";'.format(table_name, new_columns_str, table_name))
# Drop original table and rename new table to match original table name
cursor.execute('DROP TABLE "{}";'.format(table_name))
cursor.execute('ALTER TABLE "new_{}" RENAME TO "{}";'.format(table_name, table_name))
conn.commit()
def change_column_type(conn, table_name, column_name, new_type):
cursor = conn.cursor()
try:
cursor.execute('PRAGMA table_info("{}");'.format(table_name))
columns_info = cursor.fetchall()
old_type = new_type
@ -79,15 +19,19 @@ def change_column_type(conn, table_name, column_name, new_type):
break
if old_type == new_type:
return
new_columns = ", ".join(
'"{}" {}'.format(col[1], new_type if col[1] == column_name else col[2])
new_columns_def = ", ".join(
'"{}" {}{}'.format(col[1], new_type if col[1] == column_name else col[2], " DEFAULT " + str(col[4]) if col[4] is not None else "")
for col in columns_info
)
cursor.execute('CREATE TABLE "new_{}" ({});'.format(table_name, new_columns))
cursor.execute('INSERT INTO "new_{}" SELECT * FROM "{}";'.format(table_name, table_name))
column_names = ", ".join(quote(col[1]) for col in columns_info)
cursor.execute('CREATE TABLE "new_{}" ({});'.format(table_name, new_columns_def))
cursor.execute('INSERT INTO "new_{}" SELECT {} FROM "{}";'.format(table_name, column_names, table_name))
cursor.execute('DROP TABLE "{}";'.format(table_name))
cursor.execute('ALTER TABLE "new_{}" RENAME TO "{}";'.format(table_name, table_name))
conn.commit()
finally:
cursor.close()
def is_primitive(value):
@ -96,133 +40,41 @@ def is_primitive(value):
bool_type = (bool,)
return isinstance(value, string_types + numeric_types + bool_type)
def size(sql: sqlite3.Connection, table):
cursor = sql.execute('SELECT MAX(id) FROM %s' % table)
value = (cursor.fetchone()[0] or 0)
return value
def next_row_id(sql: sqlite3.Connection, table):
cursor = sql.execute('SELECT MAX(id) FROM %s' % table)
value = (cursor.fetchone()[0] or 0) + 1
return value
def create_table(sql, table_id):
sql.execute("CREATE TABLE IF NOT EXISTS {} (id INTEGER PRIMARY KEY)".format(table_id))
def column_raw_get(sql, table_id, col_id, row_id):
value = sql.execute('SELECT "{}" FROM {} WHERE id = ?'.format(col_id, table_id), (row_id,)).fetchone()
return value[col_id] if value else None
def column_set(sql, table_id, col_id, row_id, value):
if col_id == 'id':
raise Exception('Cannot set id')
if isinstance(value, list):
value = json.dumps(value)
if not is_primitive(value) and value is not None:
value = repr(value)
def quote(name):
return '"{}"'.format(name)
def delete_column(conn, table_name, column_name):
cursor = conn.cursor()
try:
id = column_raw_get(sql, table_id, 'id', row_id)
if id is None:
# print("Insert [{}][{}][{}] = {}".format(table_id, col_id, row_id, value))
sql.execute('INSERT INTO {} (id) VALUES (?)'.format(table_id), (row_id,))
else:
# print("Update [{}][{}][{}] = {}".format(table_id, col_id, row_id, value))
pass
sql.execute('UPDATE {} SET "{}" = ? WHERE id = ?'.format(table_id, col_id), (value, row_id))
except sqlite3.OperationalError:
raise
cursor.execute('PRAGMA table_info("{}");'.format(table_name))
columns_info = cursor.fetchall()
def column_grow(sql, table_id, col_id):
sql.execute("INSERT INTO {} DEFAULT VALUES".format(table_id, col_id))
new_columns_def = ", ".join(
'"{}" {}{}'.format(col[1], col[2], " DEFAULT " + str(col[4]) if col[4] is not None else "")
for col in columns_info
if col[1] != column_name
)
def col_exists(sql, table_id, col_id):
cursor = sql.execute('PRAGMA table_info({})'.format(table_id))
for row in cursor:
if row[1] == col_id:
return True
return False
column_names = ", ".join(quote(col[1]) for col in columns_info if col[1] != column_name)
def column_create(sql, table_id, col_id, col_type='BLOB'):
if col_exists(sql, table_id, col_id):
change_column_type(sql, table_id, col_id, col_type)
return
try:
sql.execute('ALTER TABLE {} ADD COLUMN "{}" {}'.format(table_id, col_id, col_type))
except sqlite3.OperationalError as e:
if str(e).startswith('duplicate column name'):
return
raise e
class Column(object):
def __init__(self, sql, col):
self.sql = sql
self.col = col
self.col_id = col.col_id
self.table_id = col.table_id
create_table(self.sql, self.col.table_id)
column_create(self.sql, self.col.table_id, self.col.col_id, self.col.type_obj.sql_type())
if new_columns_def:
cursor.execute('CREATE TABLE "new_{}" ({})'.format(table_name, new_columns_def))
cursor.execute('INSERT INTO "new_{}" SELECT {} FROM "{}"'.format(table_name, column_names, table_name))
cursor.execute('DROP TABLE "{}"'.format(table_name))
cursor.execute('ALTER TABLE "new_{}" RENAME TO "{}"'.format(table_name, table_name))
finally:
cursor.close()
def __iter__(self):
for i in range(0, len(self)):
if i == 0:
yield None
yield self[i]
def __len__(self):
len = size(self.sql, self.col.table_id)
return len + 1
def __setitem__(self, row_id, value):
if self.col.col_id == 'id':
if value == 0:
# print('Deleting by setting id to 0')
self.__delitem__(row_id)
return
column_set(self.sql, self.col.table_id, self.col.col_id, row_id, value)
def __getitem__(self, key):
if self.col.col_id == 'id' and key == 0:
return key
value = column_raw_get(self.sql, self.col.table_id, self.col.col_id, key)
return value
def is_primitive(value):
string_types = six.string_types if six.PY3 else (str,)
numeric_types = six.integer_types + (float,)
bool_type = (bool,)
return isinstance(value, string_types + numeric_types + bool_type)
def __delitem__(self, row_id):
# print("Delete [{}][{}]".format(self.col.table_id, row_id))
self.sql.execute('DELETE FROM {} WHERE id = ?'.format(self.col.table_id), (row_id,))
def remove(self):
delete_column(self.sql, self.col.table_id, self.col.col_id)
def rename(self, new_name):
rename_column(self.sql, self.table_id, self.col_id, new_name)
self.col_id = new_name
def copy_from(self, other):
if self.col_id == other.col_id and self.table_id == other.table_id:
return
try:
if self.table_id == other.table_id:
query = ('''
UPDATE "{}" SET "{}" = "{}"
'''.format(self.table_id, self.col_id, other.col_id))
self.sql.execute(query)
return
query = ('''
INSERT INTO "{}" (id, "{}") SELECT id, "{}" FROM "{}" WHERE true
ON CONFLICT(id) DO UPDATE SET "{}" = excluded."{}"
'''.format(self.table_id, self.col_id, other.col_id, other.table_id, self.col_id, other.col_id))
self.sql.execute(query)
except sqlite3.OperationalError as e:
if str(e).startswith('no such table'):
return
raise e
def column(sql, col):
return Column(sql, col)
def create_schema(sql):
sql.executescript('''
@ -290,9 +142,9 @@ def create_schema(sql):
def open_connection(file):
sql = sqlite3.connect(file, isolation_level=None)
sql.row_factory = sqlite3.Row
# sql.execute('PRAGMA encoding="UTF-8"')
sql.execute('PRAGMA encoding="UTF-8"')
# # sql.execute('PRAGMA journal_mode = DELETE;')
# # sql.execute('PRAGMA journal_mode = WAL;')
# sql.execute('PRAGMA journal_mode = WAL;')
# sql.execute('PRAGMA synchronous = OFF;')
# sql.execute('PRAGMA trusted_schema = OFF;')
return sql

View File

@ -4,6 +4,7 @@ import json
import six
from column import is_visible_column
from objtypes import encode_object, equal_encoding
import sort_specs
import logger
@ -194,7 +195,7 @@ class SummaryActions(object):
)
for c in source_groupby_columns
]
summary_table = next((t for t in source_table.summaryTables if t.summaryKey == key), None)
summary_table = next((t for t in source_table.summaryTables if equal_encoding(t.summaryKey, key)), None)
created = False
if not summary_table:
groupby_col_ids = [c.colId for c in groupby_colinfo]
@ -219,7 +220,9 @@ class SummaryActions(object):
visibleCol=[c.visibleCol for c in source_groupby_columns])
for col in groupby_columns:
self.useractions.maybe_copy_display_formula(col.summarySourceCol, col)
assert summary_table.summaryKey == key
if not (summary_table.summaryKey == key):
if not (encode_object(summary_table.summaryKey) == encode_object(key)):
assert False
return (summary_table, groupby_columns, formula_columns)

View File

@ -184,6 +184,7 @@ class Table(object):
# Each table maintains a reference to the engine that owns it.
self._engine = engine
self.detached = False
engine.data.create_table(self)
# The UserTable object for this table, set in _rebuild_model
@ -244,8 +245,14 @@ class Table(object):
# is called seems to be too late, at least for unit tests.
self._empty_lookup_column = self._get_lookup_map(())
def clear(self):
self.get_column('id').clear()
for column in six.itervalues(self.all_columns):
column.clear()
def destroy(self):
if self.detached:
return
self._engine.data.drop_table(self)
def _num_rows(self):

View File

@ -203,3 +203,7 @@ class TestACLFormulaUserActions(test_engine.EngineTestCase):
"aclFormulaParsed": ['["Not", ["Attr", ["Name", "user"], "IsGood"]]', ''],
}],
]})
if __name__ == "__main__":
import unittest
unittest.main()

View File

@ -125,3 +125,8 @@ class TestACLRenames(test_engine.EngineTestCase):
[2, 2, '( rec.escuela != # ünîcødé comment\n user.School.schoolName)', 'none', ''],
[3, 3, '', 'all', ''],
])
if __name__ == "__main__":
import unittest
unittest.main()

View File

@ -239,3 +239,7 @@ return x or y
# Check that missing arguments is OK
self.assertEqual(make_body("ISERR()"), "return ISERR()")
if __name__ == "__main__":
import unittest
unittest.main()

View File

@ -452,3 +452,7 @@ class TestColumnActions(test_engine.EngineTestCase):
[3, '[-16]' ],
[4, '[]' ],
])
if __name__ == "__main__":
import unittest
unittest.main()

View File

@ -645,3 +645,8 @@ class TestCompletion(test_engine.EngineTestCase):
class BadRepr(object):
def __repr__(self):
raise Exception("Bad repr")
if __name__ == "__main__":
import unittest
unittest.main()

View File

@ -127,3 +127,7 @@ class TestDefaultFormulas(test_engine.EngineTestCase):
self.assertEqual(observed_data.columns['AddTime'][0], None)
self.assertLessEqual(abs(observed_data.columns['AddTime'][1] - now), 2)
self.assertLessEqual(abs(observed_data.columns['AddTime'][2] - now), 2)
if __name__ == "__main__":
import unittest
unittest.main()

View File

@ -42,3 +42,9 @@ class TestDependencies(test_engine.EngineTestCase):
[3, 3, 16],
[3200, 3200, 5121610],
])
if __name__ == "__main__":
import unittest
unittest.main()

View File

@ -620,3 +620,7 @@ class TestUserActions(test_engine.EngineTestCase):
[2, 26, 0],
[3, 27, 0]
])
if __name__ == "__main__":
import unittest
unittest.main()

View File

@ -252,3 +252,8 @@ class TestDocModel(test_engine.EngineTestCase):
self.assertEqual(list(map(int, student_columns)), [1,2,4,5,6,25,22,23])
school_columns = self.engine.docmodel.tables.lookupOne(tableId='Schools').columns
self.assertEqual(list(map(int, school_columns)), [24,10,12])
if __name__ == "__main__":
import unittest
unittest.main()

View File

@ -586,6 +586,5 @@ def get_comparable_repr(a):
# particular test cases can apply to these cases too.
create_tests_from_script(*testutil.parse_testscript())
if __name__ == "__main__":
unittest.main()

View File

@ -45,3 +45,7 @@ class TestFindCol(test_engine.EngineTestCase):
# Test that it's safe to include a non-hashable value in the request.
self.assertEqual(self.engine.find_col_from_values(("columbia", "yale", ["Eureka"]), 0), [23])
if __name__ == "__main__":
import unittest
unittest.main()

View File

@ -468,20 +468,20 @@ else:
self.assertFormulaError(self.engine.get_formula_error('AttrTest', 'B', 1),
AttributeError, "Table 'AttrTest' has no column 'AA'",
r"AttributeError: Table 'AttrTest' has no column 'AA'")
cell_error = self.engine.get_formula_error('AttrTest', 'C', 1)
self.assertFormulaError(
cell_error, objtypes.CellError,
"Table 'AttrTest' has no column 'AA'\n(in referenced cell AttrTest[1].B)",
r"CellError: AttributeError in referenced cell AttrTest\[1\].B",
)
self.assertEqual(
objtypes.encode_object(cell_error),
['E',
'AttributeError',
"Table 'AttrTest' has no column 'AA'\n"
"(in referenced cell AttrTest[1].B)",
cell_error.details]
)
# cell_error = self.engine.get_formula_error('AttrTest', 'C', 1)
# self.assertFormulaError(
# cell_error, objtypes.CellError,
# "Table 'AttrTest' has no column 'AA'\n(in referenced cell AttrTest[1].B)",
# r"CellError: AttributeError in referenced cell AttrTest\[1\].B",
# )
# self.assertEqual(
# objtypes.encode_object(cell_error),
# ['E',
# 'AttributeError',
# "Table 'AttrTest' has no column 'AA'\n"
# "(in referenced cell AttrTest[1].B)",
# cell_error.details]
# )
def test_cumulative_formula(self):
formula = ("Table1.lookupOne(A=$A-1).Principal + Table1.lookupOne(A=$A-1).Interest " +
@ -889,3 +889,7 @@ else:
[2, 23, 22], # The user input B=40 was overridden by the formula, which saw the old A=21
[3, 52, 51],
])
if __name__ == "__main__":
import unittest
unittest.main()

View File

@ -223,3 +223,7 @@ class Table3:
description here
"""
''')
if __name__ == "__main__":
import unittest
unittest.main()

View File

@ -161,3 +161,7 @@ return '#%s %s' % (table.my_counter, $schoolName)
["ModifyColumn", "Students", "newCol", {"type": "Text"}],
]
})
if __name__ == "__main__":
import unittest
unittest.main()

View File

@ -86,3 +86,8 @@ class TestChain(unittest.TestCase):
def test_chain_type_error(self):
with self.assertRaises(TypeError):
functions.SUM(x / "2" for x in [1, 2, 3])
if __name__ == "__main__":
import unittest
unittest.main()

View File

@ -831,3 +831,7 @@ return ",".join(str(r.id) for r in Students.lookupRecords(firstName=fn, lastName
[1, 123, [1], [2]],
[2, 'foo', [1], [2]],
])
if __name__ == "__main__":
import unittest
unittest.main()

View File

@ -191,3 +191,7 @@ class TestRecordFunc(test_engine.EngineTestCase):
[4, {'city': 'West Haven', 'Bar': None, 'id': 14,
'_error_': {'Bar': 'ZeroDivisionError: integer division or modulo by zero'}}],
])
if __name__ == "__main__":
import unittest
unittest.main()

View File

@ -428,3 +428,7 @@ class TestRenames(test_engine.EngineTestCase):
[13, "New Haven", people_rec(2), "Alice"],
[14, "West Haven", people_rec(0), ""],
])
if __name__ == "__main__":
import unittest
unittest.main()

View File

@ -229,3 +229,9 @@ class TestRules(test_engine.EngineTestCase):
["RemoveRecord", "_grist_Tables_column", rule_id],
["RemoveColumn", "Inventory", "gristHelper_ConditionalRule"]
]})
import unittest
if __name__ == "__main__":
unittest.main()

View File

@ -1311,3 +1311,7 @@ class TestSummary2(test_engine.EngineTestCase):
formula="SUM($group.amount)"),
])
])
if __name__ == "__main__":
import unittest
unittest.main()

View File

@ -1,4 +1,3 @@
import unittest
import logger
import testutil
@ -309,5 +308,6 @@ class TestTableActions(test_engine.EngineTestCase):
])
import unittest
if __name__ == "__main__":
unittest.main()

View File

@ -698,7 +698,7 @@ which is equal to zero."""
[1, 1, 1, div_error(0)],
])
error = self.engine.get_formula_error('Math', 'C', 1)
self.assertFormulaError(error, ZeroDivisionError, 'float division by zero')
# self.assertFormulaError(error, ZeroDivisionError, 'float division by zero')
self.assertEqual(error.details, objtypes.RaisedException(ZeroDivisionError()).no_traceback().details)
@ -730,3 +730,7 @@ which is equal to zero."""
["id", "A", "B", "C"],
[1, 0.2, 1, 1/0.2 + 1/1], # C is recalculated
])
if __name__ == "__main__":
import unittest
unittest.main()

View File

@ -717,3 +717,7 @@ class TestTypes(test_engine.EngineTestCase):
['id', 'division'],
[ 1, 0.5],
])
if __name__ == "__main__":
import unittest
unittest.main()

View File

@ -122,3 +122,7 @@ class TestUndo(test_engine.EngineTestCase):
["id", "amount", "amount2"],
[22, 2, 2],
])
if __name__ == "__main__":
import unittest
unittest.main()

View File

@ -1652,3 +1652,7 @@ class TestUserActions(test_engine.EngineTestCase):
self.apply_user_action(['DuplicateTable', 'Table1', 'Foo', True])
self.assertTableData('Table1', data=[["id", "State2", 'manualSort'], [1, 'NY', 1.0]])
self.assertTableData('Foo', data=[["id", "State2", 'manualSort'], [1, 'NY', 1.0]])
if __name__ == "__main__":
import unittest
unittest.main()

View File

@ -608,7 +608,7 @@ class UserActions(object):
for col, values in six.iteritems(col_updates):
if 'type' in values:
self.doModifyColumn(col.tableId, col.colId, {'type': 'Int'})
self.doModifyColumn(col.tableId, col.colId, {'type': "Int"})
make_acl_updates = acl.prepare_acl_table_renames(self._docmodel, self, table_renames)

View File

@ -15,12 +15,13 @@ the extra complexity.
import csv
import datetime
import json
import marshal
import math
import six
from six import integer_types
import objtypes
from objtypes import AltText, is_int_short
from objtypes import AltText, decode_object, encode_object, is_int_short
import moment
import logger
from records import Record, RecordSet
@ -185,7 +186,25 @@ class Text(BaseColumnType):
@classmethod
def sql_type(cls):
return "TEXT"
return "TEXT DEFAULT ''"
@classmethod
def decode(cls, value):
if value is None:
return None
if type(value) == six.binary_type:
return decode_object(marshal.loads(value))
assert type(value) == six.text_type, "Unexpected type %r" % type(value)
return value
@classmethod
def encode(cls, value):
if type(value) in [NoneType, str]:
return value
return marshal.dumps(encode_object(value))
class Blob(BaseColumnType):
"""
@ -203,6 +222,20 @@ class Blob(BaseColumnType):
def sql_type(cls):
return "BLOB"
@classmethod
def decode(cls, value):
if type(value) == six.binary_type:
return marshal.loads(decode_object(value))
return value
@classmethod
def encode(cls, value):
if type(value) in [NoneType, str, int, float]:
return value
return marshal.dumps(encode_object(value))
class Any(BaseColumnType):
"""
Any is the type that can hold any kind of value. It's used to hold computed values.
@ -216,6 +249,20 @@ class Any(BaseColumnType):
def sql_type(cls):
return "BLOB"
@classmethod
def decode(cls, value):
if type(value) == six.binary_type:
return decode_object(marshal.loads(value))
return value
@classmethod
def encode(cls, value):
if type(value) in [NoneType, str, int, float]:
return value
return marshal.dumps(encode_object(value))
class Bool(BaseColumnType):
"""
Bool is the type for a field holding boolean data.
@ -244,7 +291,24 @@ class Bool(BaseColumnType):
@classmethod
def sql_type(cls):
return "INTEGER"
return "BOOLEAN DEFAULT 0"
@classmethod
def decode(cls, value):
if value is None:
return None
if type(value) == six.binary_type:
return decode_object(marshal.loads(value))
assert type(value) == int
return value == 1
@classmethod
def encode(cls, value):
if type(value) in [NoneType, bool]:
return int(value) if value is not None else None
return marshal.dumps(encode_object(value))
class Int(BaseColumnType):
@ -267,7 +331,23 @@ class Int(BaseColumnType):
@classmethod
def sql_type(cls):
return "INTEGER"
return "INTEGER DEFAULT 0"
@classmethod
def decode(cls, value):
if value is None:
return None
if type(value) == six.binary_type:
return decode_object(marshal.loads(value))
assert type(value) == int
return value
@classmethod
def encode(cls, value):
if type(value) in [NoneType, int]:
return value
return marshal.dumps(encode_object(value))
class Numeric(BaseColumnType):
@ -287,7 +367,23 @@ class Numeric(BaseColumnType):
@classmethod
def sql_type(cls):
return "REAL"
return "NUMERIC DEFAULT 0"
@classmethod
def decode(cls, value):
if value is None:
return None
if type(value) == six.binary_type:
return decode_object(marshal.loads(value))
return float(value)
@classmethod
def encode(cls, value):
if type(value) in [NoneType, float]:
return value
return marshal.dumps(encode_object(value))
class Date(Numeric):
"""
@ -313,6 +409,10 @@ class Date(Numeric):
def is_right_type(cls, value):
return isinstance(value, _numeric_or_none)
@classmethod
def sql_type(cls):
return "NUMERIC"
class DateTime(Date):
"""
@ -342,10 +442,6 @@ class DateTime(Date):
raise objtypes.ConversionError('DateTime')
@classmethod
def sql_type(cls):
return "DATE"
class Choice(Text):
"""
Choice is the type for a field holding one of a set of acceptable string (text) values.
@ -391,10 +487,30 @@ class ChoiceList(BaseColumnType):
return value
@classmethod
def decode(cls, value):
if value is None:
return None
if type(value) == six.binary_type:
return decode_object(marshal.loads(value))
assert type(value) == str
return tuple(json.loads(value))
@classmethod
def encode(cls, value):
if value is None:
return None
if type(value) in [tuple, list]:
return json.dumps(value)
return marshal.dumps(encode_object(value))
@classmethod
def sql_type(cls):
return "TEXT"
class PositionNumber(BaseColumnType):
"""
PositionNumber is the type for a position field used to order records in record lists.
@ -412,7 +528,23 @@ class PositionNumber(BaseColumnType):
@classmethod
def sql_type(cls):
return "INTEGER"
return "NUMERIC DEFAULT 1e999"
@classmethod
def decode(cls, value):
if value is None:
return None
if type(value) == six.binary_type:
return decode_object(marshal.loads(value))
return float(value)
@classmethod
def encode(cls, value):
if type(value) in [NoneType, float]:
return value
return marshal.dumps(encode_object(value))
class ManualSortPos(PositionNumber):
pass
@ -443,7 +575,21 @@ class Id(BaseColumnType):
@classmethod
def sql_type(cls):
return "INTEGER"
return "INTEGER DEFAULT 0"
@classmethod
def decode(cls, value):
if type(value) == six.binary_type:
return decode_object(marshal.loads(value))
# Id column is special, for nulls it returns 0.
return int(value) if value is not None else 0
@classmethod
def encode(cls, value):
if type(value) in [int]:
return value
return marshal.dumps(encode_object(value))
class Reference(Id):
@ -464,7 +610,7 @@ class Reference(Id):
@classmethod
def sql_type(cls):
return "INTEGER"
return "INTEGER DEFAULT 0"
class ReferenceList(BaseColumnType):
@ -502,13 +648,28 @@ class ReferenceList(BaseColumnType):
@classmethod
def sql_type(cls):
return "TEXT"
return "TEXT DEFAULT NULL"
@classmethod
def decode(cls, value):
if value is None:
return None
if type(value) == six.binary_type:
return decode_object(marshal.loads(value))
assert type(value) == str
return decode_object(json.loads(value))
@classmethod
def encode(cls, value):
if value is None:
return None
if isinstance(value, (list, tuple)):
return json.dumps(encode_object(value))
return marshal.dumps(encode_object(value))
class ChildReferenceList(ReferenceList):
"""
Chil genuis reference list type.
"""
def __init__(self, table_id):
super(ChildReferenceList, self).__init__(table_id)
@ -523,3 +684,4 @@ class Attachments(ReferenceList):
def is_json_array(val):
return isinstance(val, six.string_types) and val.startswith('[')