Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 26 additions & 43 deletions sqlite_minutils/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@
from typing import ( cast, Any, Callable, Dict, Generator, Iterable, Union, Optional, List, Tuple,Iterator)
from functools import cache
import uuid
import apsw.ext
import apsw.bestpractice

apsw.bestpractice.apply(apsw.bestpractice.recommended)

try: from sqlite_dump import iterdump
except ImportError: iterdump = None
Expand Down Expand Up @@ -238,24 +242,21 @@ def __init__(
), "Either specify a filename_or_conn or pass memory=True"
if memory_name:
uri = "file:{}?mode=memory&cache=shared".format(memory_name)
self.conn = sqlite3.connect(
uri,
uri=True,
check_same_thread=False,
isolation_level=None
self.conn = sqlite3.Connection(
uri
)
elif memory or filename_or_conn == ":memory:":
self.conn = sqlite3.connect(":memory:", isolation_level=None)
self.conn = sqlite3.Connection(":memory:")
elif isinstance(filename_or_conn, (str, pathlib.Path)):
if recreate and os.path.exists(filename_or_conn):
try:
os.remove(filename_or_conn)
except OSError:
# Avoid mypy and __repr__ errors, see:
# https://github.com/simonw/sqlite-utils/issues/503
self.conn = sqlite3.connect(":memory:", isolation_level=None)
self.conn = sqlite3.Connection(":memory:")
raise
self.conn = sqlite3.connect(str(filename_or_conn), check_same_thread=False, isolation_level=None)
self.conn = sqlite3.Connection(str(filename_or_conn))
else:
assert not recreate, "recreate cannot be used with connections, only paths"
self.conn = filename_or_conn
Expand All @@ -268,6 +269,8 @@ def __init__(
self._registered_functions: set = set()
self.use_counts_table = use_counts_table
self.strict = strict
# self.execute('PRAGMA foreign_keys=on;')


def close(self):
"Close the SQLite connection, and the underlying database file"
Expand All @@ -278,25 +281,6 @@ def get_last_rowid(self):
if res is None: return None
return int(res[0])

@contextlib.contextmanager
def ensure_autocommit_off(self):
"""
Ensure autocommit is off for this database connection.

Example usage::

with db.ensure_autocommit_off():
# do stuff here

This will reset to the previous autocommit state at the end of the block.
"""
old_isolation_level = self.conn.isolation_level
try:
self.conn.isolation_level = None
yield
finally:
self.conn.isolation_level = old_isolation_level

@contextlib.contextmanager
def tracer(self, tracer: Optional[Callable] = None):
"""
Expand Down Expand Up @@ -421,10 +405,10 @@ def query(
parameters, or a dictionary for ``where id = :id``
"""
cursor = self.execute(sql, tuple(params or tuple()))
if cursor.description is None: return []
keys = [d[0] for d in cursor.description]
try: columns = [c[0] for c in cursor.description]
except apsw.ExecutionCompleteError: return []
for row in cursor:
yield dict(zip(keys, row))
yield dict(zip(columns, row))

def execute(
self, sql: str, parameters: Optional[Union[Iterable, dict]] = None
Expand Down Expand Up @@ -643,14 +627,12 @@ def enable_wal(self):
Sets ``journal_mode`` to ``'wal'`` to enable Write-Ahead Log mode.
"""
if self.journal_mode != "wal":
with self.ensure_autocommit_off():
self.execute("PRAGMA journal_mode=wal;")
self.execute("PRAGMA journal_mode=wal;")

def disable_wal(self):
"Sets ``journal_mode`` back to ``'delete'`` to disable Write-Ahead Log mode."
if self.journal_mode != "delete":
with self.ensure_autocommit_off():
self.execute("PRAGMA journal_mode=delete;")
self.execute("PRAGMA journal_mode=delete;")

def _ensure_counts_table(self):
self.execute(_COUNTS_TABLE_CREATE_SQL.format(self._counts_table_name))
Expand Down Expand Up @@ -1292,7 +1274,9 @@ def rows_where(
if offset is not None:
sql += f" offset {offset}"
cursor = self.db.execute(sql, where_args or [])
columns = [c[0] for c in cursor.description]
# If no records found, raise a NotFoundError
try: columns = [c[0] for c in cursor.description]
except apsw.ExecutionCompleteError: return []
for row in cursor:
yield dict(zip(columns, row))

Expand Down Expand Up @@ -1740,15 +1724,15 @@ def transform(
]
try:
if pragma_foreign_keys_was_on:
self.db.execute("PRAGMA foreign_keys=0;")
self.db.execute("PRAGMA foreign_keys=off;")
for sql in sqls:
self.db.execute(sql)
# Run the foreign_key_check before we commit
if pragma_foreign_keys_was_on:
self.db.execute("PRAGMA foreign_key_check;")
finally:
if pragma_foreign_keys_was_on:
self.db.execute("PRAGMA foreign_keys=1;")
self.db.execute("PRAGMA foreign_keys=on;")
return self

def transform_sql(
Expand Down Expand Up @@ -2211,7 +2195,7 @@ def drop(self, ignore: bool = False):
"""
try:
self.db.execute("DROP TABLE [{}]".format(self.name))
except sqlite3.OperationalError:
except apsw.SQLError:
if not ignore:
raise

Expand Down Expand Up @@ -2769,12 +2753,10 @@ def update(
if alter and (" column" in e.args[0]):
# Attempt to add any missing columns, then try again
self.add_missing_columns([updates])
rowcount = self.db.execute(sql, args).rowcount
self.db.execute(sql, args)
else:
raise

# TODO: Test this works (rolls back) - use better exception:
# assert rowcount == 1
self.last_pk = pk_values[0] if len(pks) == 1 else pk_values
self.result = records
return self
Expand Down Expand Up @@ -2929,7 +2911,8 @@ def insert_chunk(
for query, params in queries_and_params:
try:
cursor = self.db.execute(query, tuple(params))
if cursor.description is None: continue
try: columns = [c[0] for c in cursor.description]
except apsw.ExecutionCompleteError: continue
columns = [d[0] for d in cursor.description]
for row in cursor:
records.append(dict(zip(columns, row)))
Expand Down Expand Up @@ -3671,7 +3654,7 @@ def drop(self, ignore=False):

try:
self.db.execute("DROP VIEW [{}]".format(self.name))
except sqlite3.OperationalError:
except apsw.SQLError:
if not ignore:
raise

Expand Down
34 changes: 18 additions & 16 deletions sqlite_minutils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,24 @@
import json
from typing import Dict, cast, BinaryIO, Iterable, Optional, Tuple, Type

try:
import pysqlite3 as sqlite3 # noqa: F401
from pysqlite3 import dbapi2 # noqa: F401

OperationalError = dbapi2.OperationalError
except ImportError:
try:
import sqlean as sqlite3 # noqa: F401
from sqlean import dbapi2 # noqa: F401

OperationalError = dbapi2.OperationalError
except ImportError:
import sqlite3 # noqa: F401
from sqlite3 import dbapi2 # noqa: F401

OperationalError = dbapi2.OperationalError
# try:
# import pysqlite3 as sqlite3 # noqa: F401
# from pysqlite3 import dbapi2 # noqa: F401

# OperationalError = dbapi2.OperationalError
# except ImportError:
# try:
# import sqlean as sqlite3 # noqa: F401
# from sqlean import dbapi2 # noqa: F401

# OperationalError = dbapi2.OperationalError
# except ImportError:
# import sqlite3 # noqa: F401
# from sqlite3 import dbapi2 # noqa: F401

# OperationalError = dbapi2.OperationalError
import apsw as sqlite3
OperationalError = sqlite3.Error


SPATIALITE_PATHS = (
Expand Down
14 changes: 7 additions & 7 deletions tests/test_create.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import pathlib
import pytest
import uuid

import apsw

try:
import pandas as pd # type: ignore
Expand Down Expand Up @@ -1077,7 +1077,7 @@ def test_drop_view(fresh_db):


def test_drop_ignore(fresh_db):
with pytest.raises(sqlite3.OperationalError):
with pytest.raises(apsw.SQLError):
fresh_db["does_not_exist"].drop()
fresh_db["does_not_exist"].drop(ignore=True)
# Testing view is harder, we need to create it in order
Expand All @@ -1086,7 +1086,7 @@ def test_drop_ignore(fresh_db):
view = fresh_db["foo_view"]
assert isinstance(view, View)
view.drop()
with pytest.raises(sqlite3.OperationalError):
with pytest.raises(apsw.SQLError):
view.drop()
view.drop(ignore=True)

Expand Down Expand Up @@ -1198,7 +1198,7 @@ def test_create(fresh_db):
def test_create_if_not_exists(fresh_db):
fresh_db["t"].create({"id": int})
# This should error
with pytest.raises(sqlite3.OperationalError):
with pytest.raises(apsw.SQLError):
fresh_db["t"].create({"id": int})
# This should not
fresh_db["t"].create({"id": int}, if_not_exists=True)
Expand All @@ -1213,7 +1213,7 @@ def test_create_if_no_columns(fresh_db):
def test_create_ignore(fresh_db):
fresh_db["t"].create({"id": int})
# This should error
with pytest.raises(sqlite3.OperationalError):
with pytest.raises(apsw.SQLError):
fresh_db["t"].create({"id": int})
# This should not
fresh_db["t"].create({"id": int}, ignore=True)
Expand All @@ -1222,7 +1222,7 @@ def test_create_ignore(fresh_db):
def test_create_replace(fresh_db):
fresh_db["t"].create({"id": int})
# This should error
with pytest.raises(sqlite3.OperationalError):
with pytest.raises(apsw.SQLError):
fresh_db["t"].create({"id": int})
# This should not
fresh_db["t"].create({"name": str}, replace=True)
Expand Down Expand Up @@ -1312,7 +1312,7 @@ def test_rename_table(fresh_db):
assert ["renamed"] == fresh_db.table_names()
assert [{"foo": "bar"}] == list(fresh_db["renamed"].rows)
# Should error if table does not exist:
with pytest.raises(sqlite3.OperationalError):
with pytest.raises(sqlite3.SQLError):
fresh_db.rename_table("does_not_exist", "renamed")


Expand Down
13 changes: 12 additions & 1 deletion tests/test_get.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,12 @@ def test_get_primary_key(fresh_db):

@pytest.mark.parametrize(
"argument,expected_msg",
[(100, None), (None, None), ((1, 2), "Need 1 primary key value"), ("2", None)],
[
(100, None),
(None, None),
((1, 2), "Need 1 primary key value"),
("2", None)
],
)
def test_get_not_found(argument, expected_msg, fresh_db):
fresh_db["dogs"].insert(
Expand All @@ -29,3 +34,9 @@ def test_get_not_found(argument, expected_msg, fresh_db):
fresh_db["dogs"].get(argument)
if expected_msg is not None:
assert expected_msg == excinfo.value.args[0]

def test_get_success(fresh_db):
fresh_db["dogs"].insert(
{"id": 1, "name": "Cleo", "age": 4, "is_good": True}, pk="id"
)
assert fresh_db["dogs"].get(1)['name'] == 'Cleo'
28 changes: 11 additions & 17 deletions tests/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from sqlite_minutils.utils import OperationalError
from sqlite3 import IntegrityError
import pytest
import apsw


@pytest.mark.parametrize(
Expand Down Expand Up @@ -119,12 +120,12 @@ def tracer(sql, params):
dogs.transform(**params)
# If use_pragma_foreign_keys, check that we did the right thing
if use_pragma_foreign_keys:
assert ("PRAGMA foreign_keys=0;", None) in captured
assert ("PRAGMA foreign_keys=off;", None) in captured
assert captured[-2] == ("PRAGMA foreign_key_check;", None)
assert captured[-1] == ("PRAGMA foreign_keys=1;", None)
assert captured[-1] == ("PRAGMA foreign_keys=on;", None)
else:
assert ("PRAGMA foreign_keys=0;", None) not in captured
assert ("PRAGMA foreign_keys=1;", None) not in captured
assert ("PRAGMA foreign_keys=off;", None) not in captured
assert ("PRAGMA foreign_keys=on;", None) not in captured


@pytest.mark.parametrize(
Expand Down Expand Up @@ -172,32 +173,25 @@ def tracer(sql, params):
),
],
)
@pytest.mark.parametrize("use_pragma_foreign_keys", [False, True])
def test_transform_sql_table_with_no_primary_key(
fresh_db, params, expected_sql, use_pragma_foreign_keys
fresh_db, params, expected_sql
):
captured = []

def tracer(sql, params):
return captured.append((sql, params))

dogs = fresh_db["dogs"]
if use_pragma_foreign_keys:
fresh_db.conn.execute("PRAGMA foreign_keys=ON")
dogs.insert({"id": 1, "name": "Cleo", "age": "5"})
sql = dogs.transform_sql(**{**params, **{"tmp_suffix": "suffix"}})
assert sql == expected_sql
# Check that .transform() runs without exceptions:
with fresh_db.tracer(tracer):
dogs.transform(**params)
# If use_pragma_foreign_keys, check that we did the right thing
if use_pragma_foreign_keys:
assert ("PRAGMA foreign_keys=0;", None) in captured
assert captured[-2] == ("PRAGMA foreign_key_check;", None)
assert captured[-1] == ("PRAGMA foreign_keys=1;", None)
else:
assert ("PRAGMA foreign_keys=0;", None) not in captured
assert ("PRAGMA foreign_keys=1;", None) not in captured
# We always use foreign keys
assert ("PRAGMA foreign_keys=off;", None) in captured
assert captured[-2] == ("PRAGMA foreign_key_check;", None)
assert captured[-1] == ("PRAGMA foreign_keys=on;", None)


def test_transform_sql_with_no_primary_key_to_primary_key_of_id(fresh_db):
Expand Down Expand Up @@ -392,7 +386,7 @@ def test_transform_verify_foreign_keys(fresh_db):
try:
fresh_db["authors"].transform(rename={"id": "id2"})
fresh_db.commit()
except IntegrityError:
except apsw.ConstraintError:
fresh_db.rollback()

# This should have rolled us back
Expand Down