diff --git a/docs/python-api.rst b/docs/python-api.rst index f9371d59..267591ac 100644 --- a/docs/python-api.rst +++ b/docs/python-api.rst @@ -123,6 +123,27 @@ You can pass ``strict=True`` to enable `SQLite STRICT mode 3' """ db = sqlite_utils.Database(path) + _register_db_for_cleanup(db) _load_extensions(db, load_extension) # Does view already exist? if view in db.view_names(): @@ -1741,6 +1789,7 @@ def drop_view(path, view, ignore, load_extension): sqlite-utils drop-view chickens.db heavy_chickens """ db = sqlite_utils.Database(path) + _register_db_for_cleanup(db) _load_extensions(db, load_extension) try: db[view].drop(ignore=ignore) @@ -1805,6 +1854,7 @@ def query( -p age 1 """ db = sqlite_utils.Database(path) + _register_db_for_cleanup(db) for alias, attach_path in attach: db.attach(alias, attach_path) _load_extensions(db, load_extension) @@ -1939,6 +1989,8 @@ def memory( sqlite-utils memory animals.csv --schema """ db = sqlite_utils.Database(memory=True) + if not return_db: + _register_db_for_cleanup(db) # If --dump or --save or --analyze used but no paths detected, assume SQL query is a path: if (dump or save or schema or analyze) and not paths: @@ -1948,6 +2000,7 @@ def memory( for i, path in enumerate(paths): # Path may have a :format suffix fp = None + should_close_fp = False if ":" in path and path.rsplit(":", 1)[-1].upper() in Format.__members__: path, suffix = path.rsplit(":", 1) format = Format[suffix.upper()] @@ -1965,29 +2018,32 @@ def memory( file_table = stem stem_counts[stem] = stem_counts.get(stem, 1) + 1 fp = file_path.open("rb") - rows, format_used = rows_from_file(fp, format=format, encoding=encoding) - tracker = None - if format_used in (Format.CSV, Format.TSV) and not no_detect_types: - tracker = TypeTracker() - rows = tracker.wrap(rows) - if flatten: - rows = (_flatten(row) for row in rows) - - db[file_table].insert_all(rows, alter=True) - if tracker is not None: - db[file_table].transform(types=tracker.types) - # Add convenient t / t1 / t2 views - view_names = ["t{}".format(i + 1)] - if i == 0: - view_names.append("t") - for view_name in view_names: - if not db[view_name].exists(): - db.create_view( - view_name, "select * from {}".format(quote_identifier(file_table)) - ) - - if fp: - fp.close() + should_close_fp = True + try: + rows, format_used = rows_from_file(fp, format=format, encoding=encoding) + tracker = None + if format_used in (Format.CSV, Format.TSV) and not no_detect_types: + tracker = TypeTracker() + rows = tracker.wrap(rows) + if flatten: + rows = (_flatten(row) for row in rows) + + db[file_table].insert_all(rows, alter=True) + if tracker is not None: + db[file_table].transform(types=tracker.types) + # Add convenient t / t1 / t2 views + view_names = ["t{}".format(i + 1)] + if i == 0: + view_names.append("t") + for view_name in view_names: + if not db[view_name].exists(): + db.create_view( + view_name, + "select * from {}".format(quote_identifier(file_table)), + ) + finally: + if should_close_fp and fp: + fp.close() if analyze: _analyze(db, tables=None, columns=None, save=False) @@ -2004,6 +2060,7 @@ def memory( if save: db2 = sqlite_utils.Database(save) + _register_db_for_cleanup(db2) for line in db.iterdump(): db2.execute(line) return @@ -2140,6 +2197,7 @@ def search( sqlite-utils search data.db chickens lila """ db = sqlite_utils.Database(path) + _register_db_for_cleanup(db) _load_extensions(db, load_extension) # Check table exists table_obj = db[dbtable] @@ -2307,7 +2365,9 @@ def triggers( """ sql = "select name, tbl_name as \"table\", sql from sqlite_master where type = 'trigger'" if tables: - quote = sqlite_utils.Database(memory=True).quote + _quote_db = sqlite_utils.Database(memory=True) + _register_db_for_cleanup(_quote_db) + quote = _quote_db.quote sql += ' and "table" in ({})'.format( ", ".join(quote(table) for table in tables) ) @@ -2372,7 +2432,9 @@ def indexes( sqlite_master.type = 'table' """ if tables: - quote = sqlite_utils.Database(memory=True).quote + _quote_db = sqlite_utils.Database(memory=True) + _register_db_for_cleanup(_quote_db) + quote = _quote_db.quote sql += " and sqlite_master.name in ({})".format( ", ".join(quote(table) for table in tables) ) @@ -2415,6 +2477,7 @@ def schema( sqlite-utils schema trees.db """ db = sqlite_utils.Database(path) + _register_db_for_cleanup(db) _load_extensions(db, load_extension) if tables: for table in tables: @@ -2507,6 +2570,7 @@ def transform( --rename column2 column_renamed """ db = sqlite_utils.Database(path) + _register_db_for_cleanup(db) _load_extensions(db, load_extension) types = {} kwargs = {} @@ -2590,6 +2654,7 @@ def extract( sqlite-utils extract trees.db Street_Trees species """ db = sqlite_utils.Database(path) + _register_db_for_cleanup(db) _load_extensions(db, load_extension) kwargs = dict( columns=columns, @@ -2734,6 +2799,7 @@ def _content_text(p): yield row db = sqlite_utils.Database(path) + _register_db_for_cleanup(db) _load_extensions(db, load_extension) try: with db.conn: @@ -2792,6 +2858,7 @@ def analyze_tables( sqlite-utils analyze-tables data.db trees """ db = sqlite_utils.Database(path) + _register_db_for_cleanup(db) _load_extensions(db, load_extension) _analyze(db, tables, columns, save, common_limit, no_most, no_least) @@ -2991,6 +3058,7 @@ def convert( ): sqlite3.enable_callback_tracebacks(True) db = sqlite_utils.Database(db_path) + _register_db_for_cleanup(db) if output is not None and len(columns) > 1: raise click.ClickException("Cannot use --output with more than one column") if multi and len(columns) > 1: @@ -3133,6 +3201,7 @@ def add_geometry_column( By default, this command will try to load the SpatiaLite extension from usual paths. To load it from a specific path, use --load-extension.""" db = sqlite_utils.Database(db_path) + _register_db_for_cleanup(db) if not db[table].exists(): raise click.ClickException( "You must create a table before adding a geometry column" @@ -3165,6 +3234,7 @@ def create_spatial_index(db_path, table, column_name, load_extension): By default, this command will try to load the SpatiaLite extension from usual paths. To load it from a specific path, use --load-extension.""" db = sqlite_utils.Database(db_path) + _register_db_for_cleanup(db) if not db[table].exists(): raise click.ClickException( "You must create a table and add a geometry column before creating a spatial index" diff --git a/sqlite_utils/db.py b/sqlite_utils/db.py index c057eabe..f6fe5edb 100644 --- a/sqlite_utils/db.py +++ b/sqlite_utils/db.py @@ -382,6 +382,12 @@ def __init__( pm.hook.prepare_connection(conn=self.conn) self.strict = strict + def __enter__(self) -> "Database": + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + self.close() + def close(self) -> None: "Close the SQLite connection, and the underlying database file" self.conn.close() diff --git a/sqlite_utils/utils.py b/sqlite_utils/utils.py index 87cf7645..62826b76 100644 --- a/sqlite_utils/utils.py +++ b/sqlite_utils/utils.py @@ -8,11 +8,12 @@ import json import os import sys -from . import recipes -from typing import Dict, cast, BinaryIO, Iterable, Optional, Tuple, Type +from typing import Dict, cast, BinaryIO, Iterable, Iterator, Optional, Tuple, Type import click +from . import recipes + try: import pysqlite3 as sqlite3 # noqa: F401 from pysqlite3 import dbapi2 # noqa: F401 @@ -43,6 +44,27 @@ ORIGINAL_CSV_FIELD_SIZE_LIMIT = csv.field_size_limit() +class _CloseableIterator(Iterator[dict]): + """Iterator wrapper that closes a file when iteration is complete.""" + + def __init__(self, iterator: Iterator[dict], closeable: io.IOBase): + self._iterator = iterator + self._closeable = closeable + + def __iter__(self) -> "_CloseableIterator": + return self + + def __next__(self) -> dict: + try: + return next(self._iterator) + except StopIteration: + self._closeable.close() + raise + + def close(self) -> None: + self._closeable.close() + + def maximize_csv_field_size_limit(): """ Increase the CSV field size limit to the maximum possible. @@ -299,7 +321,8 @@ class Format(enum.Enum): reader = csv.DictReader(decoded_fp, dialect=dialect) else: reader = csv.DictReader(decoded_fp) - return _extra_key_strategy(reader, ignore_extras, extras_key), Format.CSV + rows = _extra_key_strategy(reader, ignore_extras, extras_key) + return _CloseableIterator(iter(rows), decoded_fp), Format.CSV elif format == Format.TSV: rows = rows_from_file( fp, format=Format.CSV, dialect=csv.excel_tab, encoding=encoding diff --git a/tests/conftest.py b/tests/conftest.py index 3932f05c..4a43dd54 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,6 +14,26 @@ def pytest_configure(config): sys._called_from_test = True +@pytest.fixture(autouse=True) +def close_all_databases(): + """Automatically close all Database objects created during a test.""" + databases = [] + original_init = Database.__init__ + + def tracking_init(self, *args, **kwargs): + original_init(self, *args, **kwargs) + databases.append(self) + + Database.__init__ = tracking_init + yield + Database.__init__ = original_init + for db in databases: + try: + db.close() + except Exception: + pass + + @pytest.fixture def fresh_db(): return Database(memory=True) @@ -38,4 +58,5 @@ def db_path(tmpdir): path = str(tmpdir / "test.db") db = sqlite3.connect(path) db.executescript(CREATE_TABLES) + db.close() return path diff --git a/tests/test_analyze_tables.py b/tests/test_analyze_tables.py index 9634cfcd..4618eff1 100644 --- a/tests/test_analyze_tables.py +++ b/tests/test_analyze_tables.py @@ -137,6 +137,7 @@ def db_to_analyze_path(db_to_analyze, tmpdir): db = sqlite3.connect(path) sql = "\n".join(db_to_analyze.iterdump()) db.executescript(sql) + db.close() return path diff --git a/tests/test_cli.py b/tests/test_cli.py index 4198727e..40c3595c 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -19,9 +19,11 @@ def _supports_pragma_function_list(): db = Database(memory=True) try: db.execute("select * from pragma_function_list()") + return True except Exception: return False - return True + finally: + db.close() def _has_compiled_ext(): diff --git a/tests/test_cli_memory.py b/tests/test_cli_memory.py index 8483963b..c8be35fd 100644 --- a/tests/test_cli_memory.py +++ b/tests/test_cli_memory.py @@ -328,7 +328,8 @@ def test_memory_return_db(tmpdir): from sqlite_utils.cli import cli path = str(tmpdir / "dogs.csv") - open(path, "w").write("id,name\n1,Cleo") + with open(path, "w") as f: + f.write("id,name\n1,Cleo") with click.Context(cli) as ctx: db = ctx.invoke(cli.commands["memory"], paths=(path,), return_db=True) diff --git a/tests/test_introspect.py b/tests/test_introspect.py index 3df169d8..ab61c158 100644 --- a/tests/test_introspect.py +++ b/tests/test_introspect.py @@ -2,6 +2,14 @@ import pytest +def _check_supports_strict(): + """Check if SQLite supports strict tables without leaking the database.""" + db = Database(memory=True) + result = db.supports_strict + db.close() + return result + + def test_table_names(existing_db): assert ["foo"] == existing_db.table_names() @@ -282,7 +290,7 @@ def test_use_rowid(fresh_db): @pytest.mark.skipif( - not Database(memory=True).supports_strict, + not _check_supports_strict(), reason="Needs SQLite version that supports strict", ) @pytest.mark.parametrize( diff --git a/tests/test_plugins.py b/tests/test_plugins.py index c6c80594..1d459c99 100644 --- a/tests/test_plugins.py +++ b/tests/test_plugins.py @@ -9,9 +9,11 @@ def _supports_pragma_function_list(): db = Database(memory=True) try: db.execute("select * from pragma_function_list()") + return True except Exception: return False - return True + finally: + db.close() def test_register_commands(): diff --git a/tests/test_recreate.py b/tests/test_recreate.py index 49dba2bf..224d1824 100644 --- a/tests/test_recreate.py +++ b/tests/test_recreate.py @@ -14,8 +14,11 @@ def test_recreate_ignored_for_in_memory(): def test_recreate_not_allowed_for_connection(): conn = sqlite3.connect(":memory:") - with pytest.raises(AssertionError): - Database(conn, recreate=True) + try: + with pytest.raises(AssertionError): + Database(conn, recreate=True) + finally: + conn.close() @pytest.mark.parametrize( diff --git a/tests/test_register_function.py b/tests/test_register_function.py index e2591f4e..618bf1e5 100644 --- a/tests/test_register_function.py +++ b/tests/test_register_function.py @@ -42,39 +42,46 @@ def to_lower(s): def test_register_function_deterministic_tries_again_if_exception_raised(fresh_db): + # Save the original connection so we can close it later + original_conn = fresh_db.conn fresh_db.conn = MagicMock() fresh_db.conn.create_function = MagicMock() - @fresh_db.register_function(deterministic=True) - def to_lower_2(s): - return s.lower() - - fresh_db.conn.create_function.assert_called_with( - "to_lower_2", 1, to_lower_2, deterministic=True - ) - - first = True - - def side_effect(*args, **kwargs): - # Raise exception only first time this is called - nonlocal first - if first: - first = False - raise sqlite3.NotSupportedError() - - # But if sqlite3.NotSupportedError is raised, it tries again - fresh_db.conn.create_function.reset_mock() - fresh_db.conn.create_function.side_effect = side_effect - - @fresh_db.register_function(deterministic=True) - def to_lower_3(s): - return s.lower() - - # Should have been called once with deterministic=True and once without - assert fresh_db.conn.create_function.call_args_list == [ - call("to_lower_3", 1, to_lower_3, deterministic=True), - call("to_lower_3", 1, to_lower_3), - ] + try: + + @fresh_db.register_function(deterministic=True) + def to_lower_2(s): + return s.lower() + + fresh_db.conn.create_function.assert_called_with( + "to_lower_2", 1, to_lower_2, deterministic=True + ) + + first = True + + def side_effect(*args, **kwargs): + # Raise exception only first time this is called + nonlocal first + if first: + first = False + raise sqlite3.NotSupportedError() + + # But if sqlite3.NotSupportedError is raised, it tries again + fresh_db.conn.create_function.reset_mock() + fresh_db.conn.create_function.side_effect = side_effect + + @fresh_db.register_function(deterministic=True) + def to_lower_3(s): + return s.lower() + + # Should have been called once with deterministic=True and once without + assert fresh_db.conn.create_function.call_args_list == [ + call("to_lower_3", 1, to_lower_3, deterministic=True), + call("to_lower_3", 1, to_lower_3), + ] + finally: + # Close the original connection that was replaced with the mock + original_conn.close() def test_register_function_replace(fresh_db):