From 35c9c92193943e9c82fc0841ef6256330050c7d7 Mon Sep 17 00:00:00 2001 From: Daniel Roy Greenfeld Date: Fri, 22 Nov 2024 13:41:59 +0000 Subject: [PATCH 1/5] Early pass at apsw integration Co-authored-by: Audrey Roy Greenfeld --- sqlite_minutils/db.py | 18 +++++++++--------- sqlite_minutils/utils.py | 34 ++++++++++++++++++---------------- tests/test_get.py | 13 ++++++++++++- 3 files changed, 39 insertions(+), 26 deletions(-) diff --git a/sqlite_minutils/db.py b/sqlite_minutils/db.py index e262d07..ca33731 100644 --- a/sqlite_minutils/db.py +++ b/sqlite_minutils/db.py @@ -9,6 +9,7 @@ from typing import ( cast, Any, Callable, Dict, Generator, Iterable, Union, Optional, List, Tuple,Iterator) from functools import cache import uuid +import apsw.ext try: from sqlite_dump import iterdump except ImportError: iterdump = None @@ -238,14 +239,11 @@ def __init__( ), "Either specify a filename_or_conn or pass memory=True" if memory_name: uri = "file:{}?mode=memory&cache=shared".format(memory_name) - self.conn = sqlite3.connect( - uri, - uri=True, - check_same_thread=False, - isolation_level=None + self.conn = sqlite3.Connection( + uri ) elif memory or filename_or_conn == ":memory:": - self.conn = sqlite3.connect(":memory:", isolation_level=None) + self.conn = sqlite3.Connection(":memory:") elif isinstance(filename_or_conn, (str, pathlib.Path)): if recreate and os.path.exists(filename_or_conn): try: @@ -253,9 +251,9 @@ def __init__( except OSError: # Avoid mypy and __repr__ errors, see: # https://github.com/simonw/sqlite-utils/issues/503 - self.conn = sqlite3.connect(":memory:", isolation_level=None) + self.conn = sqlite3.Connection(":memory:") raise - self.conn = sqlite3.connect(str(filename_or_conn), check_same_thread=False, isolation_level=None) + self.conn = sqlite3.Connection(str(filename_or_conn)) else: assert not recreate, "recreate cannot be used with connections, only paths" self.conn = filename_or_conn @@ -1292,7 +1290,9 @@ def rows_where( if offset is not None: sql += f" offset {offset}" cursor = self.db.execute(sql, where_args or []) - columns = [c[0] for c in cursor.description] + # If no records found, raise a NotFoundError + try: columns = [c[0] for c in cursor.description] + except apsw.ExecutionCompleteError: raise NotFoundError for row in cursor: yield dict(zip(columns, row)) diff --git a/sqlite_minutils/utils.py b/sqlite_minutils/utils.py index d6a016d..842fb66 100644 --- a/sqlite_minutils/utils.py +++ b/sqlite_minutils/utils.py @@ -12,22 +12,24 @@ import json from typing import Dict, cast, BinaryIO, Iterable, Optional, Tuple, Type -try: - import pysqlite3 as sqlite3 # noqa: F401 - from pysqlite3 import dbapi2 # noqa: F401 - - OperationalError = dbapi2.OperationalError -except ImportError: - try: - import sqlean as sqlite3 # noqa: F401 - from sqlean import dbapi2 # noqa: F401 - - OperationalError = dbapi2.OperationalError - except ImportError: - import sqlite3 # noqa: F401 - from sqlite3 import dbapi2 # noqa: F401 - - OperationalError = dbapi2.OperationalError +# try: +# import pysqlite3 as sqlite3 # noqa: F401 +# from pysqlite3 import dbapi2 # noqa: F401 + +# OperationalError = dbapi2.OperationalError +# except ImportError: +# try: +# import sqlean as sqlite3 # noqa: F401 +# from sqlean import dbapi2 # noqa: F401 + +# OperationalError = dbapi2.OperationalError +# except ImportError: +# import sqlite3 # noqa: F401 +# from sqlite3 import dbapi2 # noqa: F401 + +# OperationalError = dbapi2.OperationalError +import apsw as sqlite3 +OperationalError = sqlite3.Error SPATIALITE_PATHS = ( diff --git a/tests/test_get.py b/tests/test_get.py index 304e37d..23dec42 100644 --- a/tests/test_get.py +++ b/tests/test_get.py @@ -19,7 +19,12 @@ def test_get_primary_key(fresh_db): @pytest.mark.parametrize( "argument,expected_msg", - [(100, None), (None, None), ((1, 2), "Need 1 primary key value"), ("2", None)], + [ + (100, None), + (None, None), + ((1, 2), "Need 1 primary key value"), + ("2", None) + ], ) def test_get_not_found(argument, expected_msg, fresh_db): fresh_db["dogs"].insert( @@ -29,3 +34,9 @@ def test_get_not_found(argument, expected_msg, fresh_db): fresh_db["dogs"].get(argument) if expected_msg is not None: assert expected_msg == excinfo.value.args[0] + +def test_get_success(fresh_db): + fresh_db["dogs"].insert( + {"id": 1, "name": "Cleo", "age": 4, "is_good": True}, pk="id" + ) + assert fresh_db["dogs"].get(1)['name'] == 'Cleo' From 7d59e65f9c668b62593be7682ced56ebe63b18f5 Mon Sep 17 00:00:00 2001 From: Daniel Roy Greenfeld Date: Fri, 22 Nov 2024 15:23:27 +0000 Subject: [PATCH 2/5] More tests for apsw Co-authored-by: Audrey Roy Greenfeld --- sqlite_minutils/db.py | 17 +++++++++++------ tests/test_create.py | 14 +++++++------- 2 files changed, 18 insertions(+), 13 deletions(-) diff --git a/sqlite_minutils/db.py b/sqlite_minutils/db.py index ca33731..8c83081 100644 --- a/sqlite_minutils/db.py +++ b/sqlite_minutils/db.py @@ -10,6 +10,9 @@ from functools import cache import uuid import apsw.ext +import apsw.bestpractice + +# apsw.bestpractice.apply(apsw.bestpractice.connection_enable_foreign_keys) try: from sqlite_dump import iterdump except ImportError: iterdump = None @@ -267,6 +270,7 @@ def __init__( self.use_counts_table = use_counts_table self.strict = strict + def close(self): "Close the SQLite connection, and the underlying database file" self.conn.close() @@ -419,10 +423,10 @@ def query( parameters, or a dictionary for ``where id = :id`` """ cursor = self.execute(sql, tuple(params or tuple())) - if cursor.description is None: return [] - keys = [d[0] for d in cursor.description] + try: columns = [c[0] for c in cursor.description] + except apsw.ExecutionCompleteError: return [] for row in cursor: - yield dict(zip(keys, row)) + yield dict(zip(columns, row)) def execute( self, sql: str, parameters: Optional[Union[Iterable, dict]] = None @@ -2211,7 +2215,7 @@ def drop(self, ignore: bool = False): """ try: self.db.execute("DROP TABLE [{}]".format(self.name)) - except sqlite3.OperationalError: + except apsw.SQLError: if not ignore: raise @@ -2929,7 +2933,8 @@ def insert_chunk( for query, params in queries_and_params: try: cursor = self.db.execute(query, tuple(params)) - if cursor.description is None: continue + try: columns = [c[0] for c in cursor.description] + except apsw.ExecutionCompleteError: continue columns = [d[0] for d in cursor.description] for row in cursor: records.append(dict(zip(columns, row))) @@ -3671,7 +3676,7 @@ def drop(self, ignore=False): try: self.db.execute("DROP VIEW [{}]".format(self.name)) - except sqlite3.OperationalError: + except apsw.SQLError: if not ignore: raise diff --git a/tests/test_create.py b/tests/test_create.py index 11b7d46..690b02d 100644 --- a/tests/test_create.py +++ b/tests/test_create.py @@ -17,7 +17,7 @@ import pathlib import pytest import uuid - +import apsw try: import pandas as pd # type: ignore @@ -1077,7 +1077,7 @@ def test_drop_view(fresh_db): def test_drop_ignore(fresh_db): - with pytest.raises(sqlite3.OperationalError): + with pytest.raises(apsw.SQLError): fresh_db["does_not_exist"].drop() fresh_db["does_not_exist"].drop(ignore=True) # Testing view is harder, we need to create it in order @@ -1086,7 +1086,7 @@ def test_drop_ignore(fresh_db): view = fresh_db["foo_view"] assert isinstance(view, View) view.drop() - with pytest.raises(sqlite3.OperationalError): + with pytest.raises(apsw.SQLError): view.drop() view.drop(ignore=True) @@ -1198,7 +1198,7 @@ def test_create(fresh_db): def test_create_if_not_exists(fresh_db): fresh_db["t"].create({"id": int}) # This should error - with pytest.raises(sqlite3.OperationalError): + with pytest.raises(apsw.SQLError): fresh_db["t"].create({"id": int}) # This should not fresh_db["t"].create({"id": int}, if_not_exists=True) @@ -1213,7 +1213,7 @@ def test_create_if_no_columns(fresh_db): def test_create_ignore(fresh_db): fresh_db["t"].create({"id": int}) # This should error - with pytest.raises(sqlite3.OperationalError): + with pytest.raises(apsw.SQLError): fresh_db["t"].create({"id": int}) # This should not fresh_db["t"].create({"id": int}, ignore=True) @@ -1222,7 +1222,7 @@ def test_create_ignore(fresh_db): def test_create_replace(fresh_db): fresh_db["t"].create({"id": int}) # This should error - with pytest.raises(sqlite3.OperationalError): + with pytest.raises(apsw.SQLError): fresh_db["t"].create({"id": int}) # This should not fresh_db["t"].create({"name": str}, replace=True) @@ -1312,7 +1312,7 @@ def test_rename_table(fresh_db): assert ["renamed"] == fresh_db.table_names() assert [{"foo": "bar"}] == list(fresh_db["renamed"].rows) # Should error if table does not exist: - with pytest.raises(sqlite3.OperationalError): + with pytest.raises(sqlite3.SQLError): fresh_db.rename_table("does_not_exist", "renamed") From c7ce1ce54fd21b78412351466eb52566e8e46fef Mon Sep 17 00:00:00 2001 From: Daniel Roy Greenfeld Date: Fri, 22 Nov 2024 15:39:12 +0000 Subject: [PATCH 3/5] Remove unnecessary rowcount in update method Co-authored-by: Audrey Roy Greenfeld --- sqlite_minutils/db.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sqlite_minutils/db.py b/sqlite_minutils/db.py index 8c83081..d0f89fa 100644 --- a/sqlite_minutils/db.py +++ b/sqlite_minutils/db.py @@ -2773,12 +2773,10 @@ def update( if alter and (" column" in e.args[0]): # Attempt to add any missing columns, then try again self.add_missing_columns([updates]) - rowcount = self.db.execute(sql, args).rowcount + self.db.execute(sql, args) else: raise - # TODO: Test this works (rolls back) - use better exception: - # assert rowcount == 1 self.last_pk = pk_values[0] if len(pks) == 1 else pk_values self.result = records return self From f3a60451ed0efe3d9d46295677aefbe1998e2bc6 Mon Sep 17 00:00:00 2001 From: Daniel Roy Greenfeld Date: Fri, 22 Nov 2024 15:47:27 +0000 Subject: [PATCH 4/5] Remove last vestige of autocommit off Co-authored-by: Audrey Roy Greenfeld --- sqlite_minutils/db.py | 27 +++------------------------ 1 file changed, 3 insertions(+), 24 deletions(-) diff --git a/sqlite_minutils/db.py b/sqlite_minutils/db.py index d0f89fa..499590c 100644 --- a/sqlite_minutils/db.py +++ b/sqlite_minutils/db.py @@ -280,25 +280,6 @@ def get_last_rowid(self): if res is None: return None return int(res[0]) - @contextlib.contextmanager - def ensure_autocommit_off(self): - """ - Ensure autocommit is off for this database connection. - - Example usage:: - - with db.ensure_autocommit_off(): - # do stuff here - - This will reset to the previous autocommit state at the end of the block. - """ - old_isolation_level = self.conn.isolation_level - try: - self.conn.isolation_level = None - yield - finally: - self.conn.isolation_level = old_isolation_level - @contextlib.contextmanager def tracer(self, tracer: Optional[Callable] = None): """ @@ -645,14 +626,12 @@ def enable_wal(self): Sets ``journal_mode`` to ``'wal'`` to enable Write-Ahead Log mode. """ if self.journal_mode != "wal": - with self.ensure_autocommit_off(): - self.execute("PRAGMA journal_mode=wal;") + self.execute("PRAGMA journal_mode=wal;") def disable_wal(self): "Sets ``journal_mode`` back to ``'delete'`` to disable Write-Ahead Log mode." if self.journal_mode != "delete": - with self.ensure_autocommit_off(): - self.execute("PRAGMA journal_mode=delete;") + self.execute("PRAGMA journal_mode=delete;") def _ensure_counts_table(self): self.execute(_COUNTS_TABLE_CREATE_SQL.format(self._counts_table_name)) @@ -1296,7 +1275,7 @@ def rows_where( cursor = self.db.execute(sql, where_args or []) # If no records found, raise a NotFoundError try: columns = [c[0] for c in cursor.description] - except apsw.ExecutionCompleteError: raise NotFoundError + except apsw.ExecutionCompleteError: return [] for row in cursor: yield dict(zip(columns, row)) From c31c3bad5f1ab4405dd2070b93824bbf43a4e3c9 Mon Sep 17 00:00:00 2001 From: Daniel Roy Greenfeld Date: Fri, 22 Nov 2024 16:18:07 +0000 Subject: [PATCH 5/5] Move to using apsw best practices Co-authored-by: Audrey Roy Greenfeld --- sqlite_minutils/db.py | 7 ++++--- tests/test_transform.py | 28 +++++++++++----------------- 2 files changed, 15 insertions(+), 20 deletions(-) diff --git a/sqlite_minutils/db.py b/sqlite_minutils/db.py index 499590c..9091625 100644 --- a/sqlite_minutils/db.py +++ b/sqlite_minutils/db.py @@ -12,7 +12,7 @@ import apsw.ext import apsw.bestpractice -# apsw.bestpractice.apply(apsw.bestpractice.connection_enable_foreign_keys) +apsw.bestpractice.apply(apsw.bestpractice.recommended) try: from sqlite_dump import iterdump except ImportError: iterdump = None @@ -269,6 +269,7 @@ def __init__( self._registered_functions: set = set() self.use_counts_table = use_counts_table self.strict = strict + # self.execute('PRAGMA foreign_keys=on;') def close(self): @@ -1723,7 +1724,7 @@ def transform( ] try: if pragma_foreign_keys_was_on: - self.db.execute("PRAGMA foreign_keys=0;") + self.db.execute("PRAGMA foreign_keys=off;") for sql in sqls: self.db.execute(sql) # Run the foreign_key_check before we commit @@ -1731,7 +1732,7 @@ def transform( self.db.execute("PRAGMA foreign_key_check;") finally: if pragma_foreign_keys_was_on: - self.db.execute("PRAGMA foreign_keys=1;") + self.db.execute("PRAGMA foreign_keys=on;") return self def transform_sql( diff --git a/tests/test_transform.py b/tests/test_transform.py index 93bed9c..d2fa8da 100644 --- a/tests/test_transform.py +++ b/tests/test_transform.py @@ -2,6 +2,7 @@ from sqlite_minutils.utils import OperationalError from sqlite3 import IntegrityError import pytest +import apsw @pytest.mark.parametrize( @@ -119,12 +120,12 @@ def tracer(sql, params): dogs.transform(**params) # If use_pragma_foreign_keys, check that we did the right thing if use_pragma_foreign_keys: - assert ("PRAGMA foreign_keys=0;", None) in captured + assert ("PRAGMA foreign_keys=off;", None) in captured assert captured[-2] == ("PRAGMA foreign_key_check;", None) - assert captured[-1] == ("PRAGMA foreign_keys=1;", None) + assert captured[-1] == ("PRAGMA foreign_keys=on;", None) else: - assert ("PRAGMA foreign_keys=0;", None) not in captured - assert ("PRAGMA foreign_keys=1;", None) not in captured + assert ("PRAGMA foreign_keys=off;", None) not in captured + assert ("PRAGMA foreign_keys=on;", None) not in captured @pytest.mark.parametrize( @@ -172,9 +173,8 @@ def tracer(sql, params): ), ], ) -@pytest.mark.parametrize("use_pragma_foreign_keys", [False, True]) def test_transform_sql_table_with_no_primary_key( - fresh_db, params, expected_sql, use_pragma_foreign_keys + fresh_db, params, expected_sql ): captured = [] @@ -182,22 +182,16 @@ def tracer(sql, params): return captured.append((sql, params)) dogs = fresh_db["dogs"] - if use_pragma_foreign_keys: - fresh_db.conn.execute("PRAGMA foreign_keys=ON") dogs.insert({"id": 1, "name": "Cleo", "age": "5"}) sql = dogs.transform_sql(**{**params, **{"tmp_suffix": "suffix"}}) assert sql == expected_sql # Check that .transform() runs without exceptions: with fresh_db.tracer(tracer): dogs.transform(**params) - # If use_pragma_foreign_keys, check that we did the right thing - if use_pragma_foreign_keys: - assert ("PRAGMA foreign_keys=0;", None) in captured - assert captured[-2] == ("PRAGMA foreign_key_check;", None) - assert captured[-1] == ("PRAGMA foreign_keys=1;", None) - else: - assert ("PRAGMA foreign_keys=0;", None) not in captured - assert ("PRAGMA foreign_keys=1;", None) not in captured + # We always use foreign keys + assert ("PRAGMA foreign_keys=off;", None) in captured + assert captured[-2] == ("PRAGMA foreign_key_check;", None) + assert captured[-1] == ("PRAGMA foreign_keys=on;", None) def test_transform_sql_with_no_primary_key_to_primary_key_of_id(fresh_db): @@ -392,7 +386,7 @@ def test_transform_verify_foreign_keys(fresh_db): try: fresh_db["authors"].transform(rename={"id": "id2"}) fresh_db.commit() - except IntegrityError: + except apsw.ConstraintError: fresh_db.rollback() # This should have rolled us back