From f77ca0ec0db12a6f8c730132ea3b1a10be539902 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Thu, 11 Dec 2025 16:22:22 -0800 Subject: [PATCH 1/5] Database can now work as a context manager, refs #692 --- docs/python-api.rst | 21 +++++++++++++++++++++ sqlite_utils/db.py | 6 ++++++ 2 files changed, 27 insertions(+) diff --git a/docs/python-api.rst b/docs/python-api.rst index f9371d594..267591acb 100644 --- a/docs/python-api.rst +++ b/docs/python-api.rst @@ -123,6 +123,27 @@ You can pass ``strict=True`` to enable `SQLite STRICT mode "Database": + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + self.close() + def close(self) -> None: "Close the SQLite connection, and the underlying database file" self.conn.close() From 81b0599078b5231b90eddbc2f9e5f6e977161e98 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Thu, 11 Dec 2025 16:23:38 -0800 Subject: [PATCH 2/5] Claude Code helped fix a ton of .close() warnings, refs #692 https://gistpreview.github.io/?730f0c5dc38528a1dd0615f330bd5481 --- sqlite_utils/cli.py | 72 ++++++++++++++++++++++++++++++++++-- tests/conftest.py | 8 +++- tests/test_analyze_tables.py | 3 ++ tests/test_cli.py | 14 +++++++ tests/test_cli_memory.py | 4 +- tests/test_hypothesis.py | 48 ++++++++++++------------ 6 files changed, 119 insertions(+), 30 deletions(-) diff --git a/sqlite_utils/cli.py b/sqlite_utils/cli.py index a9244394d..9fe8d0a18 100644 --- a/sqlite_utils/cli.py +++ b/sqlite_utils/cli.py @@ -44,6 +44,28 @@ CONTEXT_SETTINGS = dict(help_option_names=["-h", "--help"]) + +def _register_db_for_cleanup(db): + """Register a database to be closed when the Click context is cleaned up.""" + ctx = click.get_current_context(silent=True) + if ctx is None: + return + if not hasattr(ctx, "_databases_to_close"): + ctx._databases_to_close = [] + ctx.call_on_close(lambda: _close_databases(ctx)) + ctx._databases_to_close.append(db) + + +def _close_databases(ctx): + """Close all databases registered for cleanup.""" + if hasattr(ctx, "_databases_to_close"): + for db in ctx._databases_to_close: + try: + db.close() + except Exception: + pass + + VALID_COLUMN_TYPES = ("INTEGER", "TEXT", "FLOAT", "REAL", "BLOB") UNICODE_ERROR = """ @@ -183,6 +205,7 @@ def tables( sqlite-utils tables trees.db """ db = sqlite_utils.Database(path) + _register_db_for_cleanup(db) _load_extensions(db, load_extension) headers = ["view" if views else "table"] if counts: @@ -309,6 +332,7 @@ def optimize(path, tables, no_vacuum, load_extension): sqlite-utils optimize chickens.db """ db = sqlite_utils.Database(path) + _register_db_for_cleanup(db) _load_extensions(db, load_extension) if not tables: tables = db.table_names(fts4=True) + db.table_names(fts5=True) @@ -336,6 +360,7 @@ def rebuild_fts(path, tables, load_extension): sqlite-utils rebuild-fts chickens.db chickens """ db = sqlite_utils.Database(path) + _register_db_for_cleanup(db) _load_extensions(db, load_extension) if not tables: tables = db.table_names(fts4=True) + db.table_names(fts5=True) @@ -360,6 +385,7 @@ def analyze(path, names): sqlite-utils analyze chickens.db """ db = sqlite_utils.Database(path) + _register_db_for_cleanup(db) try: if names: for name in names: @@ -384,7 +410,9 @@ def vacuum(path): \b sqlite-utils vacuum chickens.db """ - sqlite_utils.Database(path).vacuum() + db = sqlite_utils.Database(path) + _register_db_for_cleanup(db) + db.vacuum() @cli.command() @@ -403,6 +431,7 @@ def dump(path, load_extension): sqlite-utils dump chickens.db """ db = sqlite_utils.Database(path) + _register_db_for_cleanup(db) _load_extensions(db, load_extension) for line in db.iterdump(): click.echo(line) @@ -464,6 +493,7 @@ def add_column( sqlite-utils add-column chickens.db chickens weight float """ db = sqlite_utils.Database(path) + _register_db_for_cleanup(db) _load_extensions(db, load_extension) try: db[table].add_column( @@ -501,6 +531,7 @@ def add_foreign_key( sqlite-utils add-foreign-key my.db books author_id authors id """ db = sqlite_utils.Database(path) + _register_db_for_cleanup(db) _load_extensions(db, load_extension) try: db[table].add_foreign_key(column, other_table, other_column, ignore=ignore) @@ -528,6 +559,7 @@ def add_foreign_keys(path, foreign_key, load_extension): authors country_id countries id """ db = sqlite_utils.Database(path) + _register_db_for_cleanup(db) _load_extensions(db, load_extension) if len(foreign_key) % 4 != 0: raise click.ClickException( @@ -559,6 +591,7 @@ def index_foreign_keys(path, load_extension): sqlite-utils index-foreign-keys chickens.db """ db = sqlite_utils.Database(path) + _register_db_for_cleanup(db) _load_extensions(db, load_extension) db.index_foreign_keys() @@ -603,6 +636,7 @@ def create_index( sqlite-utils create-index chickens.db chickens -- -name """ db = sqlite_utils.Database(path) + _register_db_for_cleanup(db) _load_extensions(db, load_extension) # Treat -prefix as descending for columns columns = [] @@ -660,6 +694,7 @@ def enable_fts( fts_version = "FTS4" db = sqlite_utils.Database(path) + _register_db_for_cleanup(db) _load_extensions(db, load_extension) try: db[table].enable_fts( @@ -691,6 +726,7 @@ def populate_fts(path, table, column, load_extension): sqlite-utils populate-fts chickens.db chickens name """ db = sqlite_utils.Database(path) + _register_db_for_cleanup(db) _load_extensions(db, load_extension) db[table].populate_fts(column) @@ -712,6 +748,7 @@ def disable_fts(path, table, load_extension): sqlite-utils disable-fts chickens.db chickens """ db = sqlite_utils.Database(path) + _register_db_for_cleanup(db) _load_extensions(db, load_extension) db[table].disable_fts() @@ -734,6 +771,7 @@ def enable_wal(path, load_extension): """ for path_ in path: db = sqlite_utils.Database(path_) + _register_db_for_cleanup(db) _load_extensions(db, load_extension) db.enable_wal() @@ -756,6 +794,7 @@ def disable_wal(path, load_extension): """ for path_ in path: db = sqlite_utils.Database(path_) + _register_db_for_cleanup(db) _load_extensions(db, load_extension) db.disable_wal() @@ -777,6 +816,7 @@ def enable_counts(path, tables, load_extension): sqlite-utils enable-counts chickens.db """ db = sqlite_utils.Database(path) + _register_db_for_cleanup(db) _load_extensions(db, load_extension) if not tables: db.enable_counts() @@ -805,6 +845,7 @@ def reset_counts(path, load_extension): sqlite-utils reset-counts chickens.db """ db = sqlite_utils.Database(path) + _register_db_for_cleanup(db) _load_extensions(db, load_extension) db.reset_counts() @@ -964,6 +1005,7 @@ def insert_upsert_implementation( strict=False, ): db = sqlite_utils.Database(path) + _register_db_for_cleanup(db) _load_extensions(db, load_extension) _maybe_register_functions(db, functions) if (delimiter or quotechar or sniff or no_headers) and not tsv: @@ -1480,6 +1522,7 @@ def create_database(path, enable_wal, init_spatialite, load_extension): sqlite-utils create-database trees.db """ db = sqlite_utils.Database(path) + _register_db_for_cleanup(db) if enable_wal: db.enable_wal() @@ -1569,6 +1612,7 @@ def create_table( Valid column types are text, integer, float and blob. """ db = sqlite_utils.Database(path) + _register_db_for_cleanup(db) _load_extensions(db, load_extension) if len(columns) % 2 == 1: raise click.ClickException( @@ -1620,6 +1664,7 @@ def duplicate(path, table, new_table, ignore, load_extension): Create a duplicate of this table, copying across the schema and all row data. """ db = sqlite_utils.Database(path) + _register_db_for_cleanup(db) _load_extensions(db, load_extension) try: db[table].duplicate(new_table) @@ -1643,6 +1688,7 @@ def rename_table(path, table, new_name, ignore, load_extension): Rename this table. """ db = sqlite_utils.Database(path) + _register_db_for_cleanup(db) _load_extensions(db, load_extension) try: db.rename_table(table, new_name) @@ -1671,6 +1717,7 @@ def drop_table(path, table, ignore, load_extension): sqlite-utils drop-table chickens.db chickens """ db = sqlite_utils.Database(path) + _register_db_for_cleanup(db) _load_extensions(db, load_extension) try: db[table].drop(ignore=ignore) @@ -1707,6 +1754,7 @@ def create_view(path, view, select, ignore, replace, load_extension): 'select * from chickens where weight > 3' """ db = sqlite_utils.Database(path) + _register_db_for_cleanup(db) _load_extensions(db, load_extension) # Does view already exist? if view in db.view_names(): @@ -1741,6 +1789,7 @@ def drop_view(path, view, ignore, load_extension): sqlite-utils drop-view chickens.db heavy_chickens """ db = sqlite_utils.Database(path) + _register_db_for_cleanup(db) _load_extensions(db, load_extension) try: db[view].drop(ignore=ignore) @@ -1805,6 +1854,7 @@ def query( -p age 1 """ db = sqlite_utils.Database(path) + _register_db_for_cleanup(db) for alias, attach_path in attach: db.attach(alias, attach_path) _load_extensions(db, load_extension) @@ -1939,6 +1989,8 @@ def memory( sqlite-utils memory animals.csv --schema """ db = sqlite_utils.Database(memory=True) + if not return_db: + _register_db_for_cleanup(db) # If --dump or --save or --analyze used but no paths detected, assume SQL query is a path: if (dump or save or schema or analyze) and not paths: @@ -2004,6 +2056,7 @@ def memory( if save: db2 = sqlite_utils.Database(save) + _register_db_for_cleanup(db2) for line in db.iterdump(): db2.execute(line) return @@ -2140,6 +2193,7 @@ def search( sqlite-utils search data.db chickens lila """ db = sqlite_utils.Database(path) + _register_db_for_cleanup(db) _load_extensions(db, load_extension) # Check table exists table_obj = db[dbtable] @@ -2307,7 +2361,9 @@ def triggers( """ sql = "select name, tbl_name as \"table\", sql from sqlite_master where type = 'trigger'" if tables: - quote = sqlite_utils.Database(memory=True).quote + _quote_db = sqlite_utils.Database(memory=True) + _register_db_for_cleanup(_quote_db) + quote = _quote_db.quote sql += ' and "table" in ({})'.format( ", ".join(quote(table) for table in tables) ) @@ -2372,7 +2428,9 @@ def indexes( sqlite_master.type = 'table' """ if tables: - quote = sqlite_utils.Database(memory=True).quote + _quote_db = sqlite_utils.Database(memory=True) + _register_db_for_cleanup(_quote_db) + quote = _quote_db.quote sql += " and sqlite_master.name in ({})".format( ", ".join(quote(table) for table in tables) ) @@ -2415,6 +2473,7 @@ def schema( sqlite-utils schema trees.db """ db = sqlite_utils.Database(path) + _register_db_for_cleanup(db) _load_extensions(db, load_extension) if tables: for table in tables: @@ -2507,6 +2566,7 @@ def transform( --rename column2 column_renamed """ db = sqlite_utils.Database(path) + _register_db_for_cleanup(db) _load_extensions(db, load_extension) types = {} kwargs = {} @@ -2590,6 +2650,7 @@ def extract( sqlite-utils extract trees.db Street_Trees species """ db = sqlite_utils.Database(path) + _register_db_for_cleanup(db) _load_extensions(db, load_extension) kwargs = dict( columns=columns, @@ -2734,6 +2795,7 @@ def _content_text(p): yield row db = sqlite_utils.Database(path) + _register_db_for_cleanup(db) _load_extensions(db, load_extension) try: with db.conn: @@ -2792,6 +2854,7 @@ def analyze_tables( sqlite-utils analyze-tables data.db trees """ db = sqlite_utils.Database(path) + _register_db_for_cleanup(db) _load_extensions(db, load_extension) _analyze(db, tables, columns, save, common_limit, no_most, no_least) @@ -2991,6 +3054,7 @@ def convert( ): sqlite3.enable_callback_tracebacks(True) db = sqlite_utils.Database(db_path) + _register_db_for_cleanup(db) if output is not None and len(columns) > 1: raise click.ClickException("Cannot use --output with more than one column") if multi and len(columns) > 1: @@ -3133,6 +3197,7 @@ def add_geometry_column( By default, this command will try to load the SpatiaLite extension from usual paths. To load it from a specific path, use --load-extension.""" db = sqlite_utils.Database(db_path) + _register_db_for_cleanup(db) if not db[table].exists(): raise click.ClickException( "You must create a table before adding a geometry column" @@ -3165,6 +3230,7 @@ def create_spatial_index(db_path, table, column_name, load_extension): By default, this command will try to load the SpatiaLite extension from usual paths. To load it from a specific path, use --load-extension.""" db = sqlite_utils.Database(db_path) + _register_db_for_cleanup(db) if not db[table].exists(): raise click.ClickException( "You must create a table and add a geometry column before creating a spatial index" diff --git a/tests/conftest.py b/tests/conftest.py index 3932f05c3..a9a47bc5a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,7 +16,9 @@ def pytest_configure(config): @pytest.fixture def fresh_db(): - return Database(memory=True) + db = Database(memory=True) + yield db + db.close() @pytest.fixture @@ -30,7 +32,8 @@ def existing_db(): INSERT INTO foo (text) values ("three"); """ ) - return database + yield database + database.close() @pytest.fixture @@ -38,4 +41,5 @@ def db_path(tmpdir): path = str(tmpdir / "test.db") db = sqlite3.connect(path) db.executescript(CREATE_TABLES) + db.close() return path diff --git a/tests/test_analyze_tables.py b/tests/test_analyze_tables.py index 9634cfcd3..2867bb5f0 100644 --- a/tests/test_analyze_tables.py +++ b/tests/test_analyze_tables.py @@ -44,6 +44,7 @@ def big_db_to_analyze_path(tmpdir): } ) db["stuff"].insert_all(to_insert) + db.close() return path @@ -137,6 +138,7 @@ def db_to_analyze_path(db_to_analyze, tmpdir): db = sqlite3.connect(path) sql = "\n".join(db_to_analyze.iterdump()) db.executescript(sql) + db.close() return path @@ -311,6 +313,7 @@ def test_analyze_table_validate_columns(tmpdir, args, expected_error): "age": 5, } ) + db.close() result = CliRunner().invoke( cli.cli, ["analyze-tables", path] + args, diff --git a/tests/test_cli.py b/tests/test_cli.py index 4198727e2..822748abb 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -82,6 +82,7 @@ def test_tables_counts_and_columns(db_path): db = Database(db_path) with db.conn: db["lots"].insert_all([{"id": i, "age": i + 1} for i in range(30)]) + db.close() result = CliRunner().invoke(cli.cli, ["tables", "--counts", "--columns", db_path]) assert ( '[{"table": "Gosh", "count": 0, "columns": ["c1", "c2", "c3"]},\n' @@ -117,6 +118,7 @@ def test_tables_counts_and_columns_csv(db_path, format, expected): db = Database(db_path) with db.conn: db["lots"].insert_all([{"id": i, "age": i + 1} for i in range(30)]) + db.close() result = CliRunner().invoke( cli.cli, ["tables", "--counts", "--columns", format, db_path] ) @@ -127,6 +129,7 @@ def test_tables_schema(db_path): db = Database(db_path) with db.conn: db["lots"].insert_all([{"id": i, "age": i + 1} for i in range(30)]) + db.close() result = CliRunner().invoke(cli.cli, ["tables", "--schema", db_path]) assert ( '[{"table": "Gosh", "schema": "CREATE TABLE Gosh (c1 text, c2 text, c3 text)"},\n' @@ -188,6 +191,7 @@ def test_output_table(db_path, options, expected): for i in range(4) ] ) + db.close() result = CliRunner().invoke(cli.cli, ["rows", db_path, "rows"] + options) assert result.exit_code == 0 assert expected == result.output.strip() @@ -243,6 +247,7 @@ def test_create_index(db_path): CliRunner().invoke(cli.cli, create_index_unique_args + [option]).exit_code == 0 ) + db.close() def test_create_index_analyze(db_path): @@ -622,6 +627,7 @@ def test_optimize(db_path, tables): ) db["Gosh"].enable_fts(["c1", "c2", "c3"], fts_version="FTS4") db["Gosh2"].enable_fts(["c1", "c2", "c3"], fts_version="FTS5") + db.close() size_before_optimize = os.stat(db_path).st_size result = CliRunner().invoke(cli.cli, ["optimize", db_path] + tables) assert result.exit_code == 0 @@ -1450,6 +1456,7 @@ def test_drop_table_error(): with runner.isolated_filesystem(): db = Database("test.db") db["t"].create({"pk": int}, pk="pk") + db.close() result = runner.invoke( cli.cli, [ @@ -1474,6 +1481,7 @@ def test_drop_view(): db = Database("test.db") db.create_view("hello", "select 1") assert "hello" in db.view_names() + db.close() result = runner.invoke( cli.cli, [ @@ -1483,7 +1491,9 @@ def test_drop_view(): ], ) assert result.exit_code == 0 + db = Database("test.db") assert "hello" not in db.view_names() + db.close() def test_drop_view_error(): @@ -1491,6 +1501,7 @@ def test_drop_view_error(): with runner.isolated_filesystem(): db = Database("test.db") db["t"].create({"pk": int}, pk="pk") + db.close() result = runner.invoke( cli.cli, [ @@ -1728,10 +1739,13 @@ def test_transform(db_path, args, expected_schema): defaults={"age": 1}, pk="id", ) + db.close() result = CliRunner().invoke(cli.cli, ["transform", db_path, "dogs"] + args) print(result.output) assert result.exit_code == 0 + db = Database(db_path) schema = db["dogs"].schema + db.close() assert schema == expected_schema diff --git a/tests/test_cli_memory.py b/tests/test_cli_memory.py index 8483963bb..3d0e81291 100644 --- a/tests/test_cli_memory.py +++ b/tests/test_cli_memory.py @@ -328,9 +328,11 @@ def test_memory_return_db(tmpdir): from sqlite_utils.cli import cli path = str(tmpdir / "dogs.csv") - open(path, "w").write("id,name\n1,Cleo") + with open(path, "w") as f: + f.write("id,name\n1,Cleo") with click.Context(cli) as ctx: db = ctx.invoke(cli.commands["memory"], paths=(path,), return_db=True) assert db.table_names() == ["dogs"] + db.close() diff --git a/tests/test_hypothesis.py b/tests/test_hypothesis.py index f12f86506..54bdc9db7 100644 --- a/tests/test_hypothesis.py +++ b/tests/test_hypothesis.py @@ -6,39 +6,39 @@ # SQLite integers are -(2^63) to 2^63 - 1 @given(st.integers(-9223372036854775808, 9223372036854775807)) def test_roundtrip_integers(integer): - db = sqlite_utils.Database(memory=True) - row = { - "integer": integer, - } - db["test"].insert(row) - assert list(db["test"].rows) == [row] + with sqlite_utils.Database(memory=True) as db: + row = { + "integer": integer, + } + db["test"].insert(row) + assert list(db["test"].rows) == [row] @given(st.text()) def test_roundtrip_text(text): - db = sqlite_utils.Database(memory=True) - row = { - "text": text, - } - db["test"].insert(row) - assert list(db["test"].rows) == [row] + with sqlite_utils.Database(memory=True) as db: + row = { + "text": text, + } + db["test"].insert(row) + assert list(db["test"].rows) == [row] @given(st.binary(max_size=1024 * 1024)) def test_roundtrip_binary(binary): - db = sqlite_utils.Database(memory=True) - row = { - "binary": binary, - } - db["test"].insert(row) - assert list(db["test"].rows) == [row] + with sqlite_utils.Database(memory=True) as db: + row = { + "binary": binary, + } + db["test"].insert(row) + assert list(db["test"].rows) == [row] @given(st.floats(allow_nan=False)) def test_roundtrip_floats(floats): - db = sqlite_utils.Database(memory=True) - row = { - "floats": floats, - } - db["test"].insert(row) - assert list(db["test"].rows) == [row] + with sqlite_utils.Database(memory=True) as db: + row = { + "floats": floats, + } + db["test"].insert(row) + assert list(db["test"].rows) == [row] From 77359bea300bc1a08a95f6bf538fae463792e399 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Thu, 11 Dec 2025 16:30:07 -0800 Subject: [PATCH 3/5] New autouse fixture to help with test warnings Refs https://github.com/simonw/sqlite-utils/issues/692#issuecomment-3644371889 --- tests/conftest.py | 27 ++++++++++++++++---- tests/test_analyze_tables.py | 2 -- tests/test_cli.py | 14 ----------- tests/test_cli_memory.py | 1 - tests/test_hypothesis.py | 48 ++++++++++++++++++------------------ 5 files changed, 46 insertions(+), 46 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index a9a47bc5a..4a43dd546 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,11 +14,29 @@ def pytest_configure(config): sys._called_from_test = True +@pytest.fixture(autouse=True) +def close_all_databases(): + """Automatically close all Database objects created during a test.""" + databases = [] + original_init = Database.__init__ + + def tracking_init(self, *args, **kwargs): + original_init(self, *args, **kwargs) + databases.append(self) + + Database.__init__ = tracking_init + yield + Database.__init__ = original_init + for db in databases: + try: + db.close() + except Exception: + pass + + @pytest.fixture def fresh_db(): - db = Database(memory=True) - yield db - db.close() + return Database(memory=True) @pytest.fixture @@ -32,8 +50,7 @@ def existing_db(): INSERT INTO foo (text) values ("three"); """ ) - yield database - database.close() + return database @pytest.fixture diff --git a/tests/test_analyze_tables.py b/tests/test_analyze_tables.py index 2867bb5f0..4618eff1f 100644 --- a/tests/test_analyze_tables.py +++ b/tests/test_analyze_tables.py @@ -44,7 +44,6 @@ def big_db_to_analyze_path(tmpdir): } ) db["stuff"].insert_all(to_insert) - db.close() return path @@ -313,7 +312,6 @@ def test_analyze_table_validate_columns(tmpdir, args, expected_error): "age": 5, } ) - db.close() result = CliRunner().invoke( cli.cli, ["analyze-tables", path] + args, diff --git a/tests/test_cli.py b/tests/test_cli.py index 822748abb..4198727e2 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -82,7 +82,6 @@ def test_tables_counts_and_columns(db_path): db = Database(db_path) with db.conn: db["lots"].insert_all([{"id": i, "age": i + 1} for i in range(30)]) - db.close() result = CliRunner().invoke(cli.cli, ["tables", "--counts", "--columns", db_path]) assert ( '[{"table": "Gosh", "count": 0, "columns": ["c1", "c2", "c3"]},\n' @@ -118,7 +117,6 @@ def test_tables_counts_and_columns_csv(db_path, format, expected): db = Database(db_path) with db.conn: db["lots"].insert_all([{"id": i, "age": i + 1} for i in range(30)]) - db.close() result = CliRunner().invoke( cli.cli, ["tables", "--counts", "--columns", format, db_path] ) @@ -129,7 +127,6 @@ def test_tables_schema(db_path): db = Database(db_path) with db.conn: db["lots"].insert_all([{"id": i, "age": i + 1} for i in range(30)]) - db.close() result = CliRunner().invoke(cli.cli, ["tables", "--schema", db_path]) assert ( '[{"table": "Gosh", "schema": "CREATE TABLE Gosh (c1 text, c2 text, c3 text)"},\n' @@ -191,7 +188,6 @@ def test_output_table(db_path, options, expected): for i in range(4) ] ) - db.close() result = CliRunner().invoke(cli.cli, ["rows", db_path, "rows"] + options) assert result.exit_code == 0 assert expected == result.output.strip() @@ -247,7 +243,6 @@ def test_create_index(db_path): CliRunner().invoke(cli.cli, create_index_unique_args + [option]).exit_code == 0 ) - db.close() def test_create_index_analyze(db_path): @@ -627,7 +622,6 @@ def test_optimize(db_path, tables): ) db["Gosh"].enable_fts(["c1", "c2", "c3"], fts_version="FTS4") db["Gosh2"].enable_fts(["c1", "c2", "c3"], fts_version="FTS5") - db.close() size_before_optimize = os.stat(db_path).st_size result = CliRunner().invoke(cli.cli, ["optimize", db_path] + tables) assert result.exit_code == 0 @@ -1456,7 +1450,6 @@ def test_drop_table_error(): with runner.isolated_filesystem(): db = Database("test.db") db["t"].create({"pk": int}, pk="pk") - db.close() result = runner.invoke( cli.cli, [ @@ -1481,7 +1474,6 @@ def test_drop_view(): db = Database("test.db") db.create_view("hello", "select 1") assert "hello" in db.view_names() - db.close() result = runner.invoke( cli.cli, [ @@ -1491,9 +1483,7 @@ def test_drop_view(): ], ) assert result.exit_code == 0 - db = Database("test.db") assert "hello" not in db.view_names() - db.close() def test_drop_view_error(): @@ -1501,7 +1491,6 @@ def test_drop_view_error(): with runner.isolated_filesystem(): db = Database("test.db") db["t"].create({"pk": int}, pk="pk") - db.close() result = runner.invoke( cli.cli, [ @@ -1739,13 +1728,10 @@ def test_transform(db_path, args, expected_schema): defaults={"age": 1}, pk="id", ) - db.close() result = CliRunner().invoke(cli.cli, ["transform", db_path, "dogs"] + args) print(result.output) assert result.exit_code == 0 - db = Database(db_path) schema = db["dogs"].schema - db.close() assert schema == expected_schema diff --git a/tests/test_cli_memory.py b/tests/test_cli_memory.py index 3d0e81291..c8be35fda 100644 --- a/tests/test_cli_memory.py +++ b/tests/test_cli_memory.py @@ -335,4 +335,3 @@ def test_memory_return_db(tmpdir): db = ctx.invoke(cli.commands["memory"], paths=(path,), return_db=True) assert db.table_names() == ["dogs"] - db.close() diff --git a/tests/test_hypothesis.py b/tests/test_hypothesis.py index 54bdc9db7..f12f86506 100644 --- a/tests/test_hypothesis.py +++ b/tests/test_hypothesis.py @@ -6,39 +6,39 @@ # SQLite integers are -(2^63) to 2^63 - 1 @given(st.integers(-9223372036854775808, 9223372036854775807)) def test_roundtrip_integers(integer): - with sqlite_utils.Database(memory=True) as db: - row = { - "integer": integer, - } - db["test"].insert(row) - assert list(db["test"].rows) == [row] + db = sqlite_utils.Database(memory=True) + row = { + "integer": integer, + } + db["test"].insert(row) + assert list(db["test"].rows) == [row] @given(st.text()) def test_roundtrip_text(text): - with sqlite_utils.Database(memory=True) as db: - row = { - "text": text, - } - db["test"].insert(row) - assert list(db["test"].rows) == [row] + db = sqlite_utils.Database(memory=True) + row = { + "text": text, + } + db["test"].insert(row) + assert list(db["test"].rows) == [row] @given(st.binary(max_size=1024 * 1024)) def test_roundtrip_binary(binary): - with sqlite_utils.Database(memory=True) as db: - row = { - "binary": binary, - } - db["test"].insert(row) - assert list(db["test"].rows) == [row] + db = sqlite_utils.Database(memory=True) + row = { + "binary": binary, + } + db["test"].insert(row) + assert list(db["test"].rows) == [row] @given(st.floats(allow_nan=False)) def test_roundtrip_floats(floats): - with sqlite_utils.Database(memory=True) as db: - row = { - "floats": floats, - } - db["test"].insert(row) - assert list(db["test"].rows) == [row] + db = sqlite_utils.Database(memory=True) + row = { + "floats": floats, + } + db["test"].insert(row) + assert list(db["test"].rows) == [row] From dc9947a5e17be38f58d1a5491d12f34d465a3c56 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Thu, 11 Dec 2025 16:46:05 -0800 Subject: [PATCH 4/5] Fix all remaining resource warnings, refs #693 https://gistpreview.github.io/?0bb8e869b82f6ff0db647de755182502 --- sqlite_utils/cli.py | 52 +++++++++++++------------ sqlite_utils/utils.py | 27 ++++++++++++- tests/test_cli.py | 4 +- tests/test_introspect.py | 10 ++++- tests/test_plugins.py | 4 +- tests/test_recreate.py | 7 +++- tests/test_register_function.py | 67 ++++++++++++++++++--------------- 7 files changed, 110 insertions(+), 61 deletions(-) diff --git a/sqlite_utils/cli.py b/sqlite_utils/cli.py index 9fe8d0a18..78c9f91d3 100644 --- a/sqlite_utils/cli.py +++ b/sqlite_utils/cli.py @@ -905,7 +905,7 @@ def inner(fn): required=True, ), click.argument("table"), - click.argument("file", type=click.File("rb"), required=True), + click.argument("file", type=click.File("rb", lazy=True), required=True), click.option( "--pk", help="Columns to use as the primary key, e.g. id", @@ -2000,6 +2000,7 @@ def memory( for i, path in enumerate(paths): # Path may have a :format suffix fp = None + should_close_fp = False if ":" in path and path.rsplit(":", 1)[-1].upper() in Format.__members__: path, suffix = path.rsplit(":", 1) format = Format[suffix.upper()] @@ -2017,29 +2018,32 @@ def memory( file_table = stem stem_counts[stem] = stem_counts.get(stem, 1) + 1 fp = file_path.open("rb") - rows, format_used = rows_from_file(fp, format=format, encoding=encoding) - tracker = None - if format_used in (Format.CSV, Format.TSV) and not no_detect_types: - tracker = TypeTracker() - rows = tracker.wrap(rows) - if flatten: - rows = (_flatten(row) for row in rows) - - db[file_table].insert_all(rows, alter=True) - if tracker is not None: - db[file_table].transform(types=tracker.types) - # Add convenient t / t1 / t2 views - view_names = ["t{}".format(i + 1)] - if i == 0: - view_names.append("t") - for view_name in view_names: - if not db[view_name].exists(): - db.create_view( - view_name, "select * from {}".format(quote_identifier(file_table)) - ) - - if fp: - fp.close() + should_close_fp = True + try: + rows, format_used = rows_from_file(fp, format=format, encoding=encoding) + tracker = None + if format_used in (Format.CSV, Format.TSV) and not no_detect_types: + tracker = TypeTracker() + rows = tracker.wrap(rows) + if flatten: + rows = (_flatten(row) for row in rows) + + db[file_table].insert_all(rows, alter=True) + if tracker is not None: + db[file_table].transform(types=tracker.types) + # Add convenient t / t1 / t2 views + view_names = ["t{}".format(i + 1)] + if i == 0: + view_names.append("t") + for view_name in view_names: + if not db[view_name].exists(): + db.create_view( + view_name, + "select * from {}".format(quote_identifier(file_table)), + ) + finally: + if should_close_fp and fp: + fp.close() if analyze: _analyze(db, tables=None, columns=None, save=False) diff --git a/sqlite_utils/utils.py b/sqlite_utils/utils.py index 87cf76450..b90329833 100644 --- a/sqlite_utils/utils.py +++ b/sqlite_utils/utils.py @@ -9,7 +9,29 @@ import os import sys from . import recipes -from typing import Dict, cast, BinaryIO, Iterable, Optional, Tuple, Type +from typing import Dict, cast, BinaryIO, Iterable, Iterator, Optional, Tuple, Type + + +class _CloseableIterator(Iterator[dict]): + """Iterator wrapper that closes a file when iteration is complete.""" + + def __init__(self, iterator: Iterator[dict], closeable: io.IOBase): + self._iterator = iterator + self._closeable = closeable + + def __iter__(self) -> "_CloseableIterator": + return self + + def __next__(self) -> dict: + try: + return next(self._iterator) + except StopIteration: + self._closeable.close() + raise + + def close(self) -> None: + self._closeable.close() + import click @@ -299,7 +321,8 @@ class Format(enum.Enum): reader = csv.DictReader(decoded_fp, dialect=dialect) else: reader = csv.DictReader(decoded_fp) - return _extra_key_strategy(reader, ignore_extras, extras_key), Format.CSV + rows = _extra_key_strategy(reader, ignore_extras, extras_key) + return _CloseableIterator(iter(rows), decoded_fp), Format.CSV elif format == Format.TSV: rows = rows_from_file( fp, format=Format.CSV, dialect=csv.excel_tab, encoding=encoding diff --git a/tests/test_cli.py b/tests/test_cli.py index 4198727e2..40c3595c5 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -19,9 +19,11 @@ def _supports_pragma_function_list(): db = Database(memory=True) try: db.execute("select * from pragma_function_list()") + return True except Exception: return False - return True + finally: + db.close() def _has_compiled_ext(): diff --git a/tests/test_introspect.py b/tests/test_introspect.py index 3df169d8a..ab61c1587 100644 --- a/tests/test_introspect.py +++ b/tests/test_introspect.py @@ -2,6 +2,14 @@ import pytest +def _check_supports_strict(): + """Check if SQLite supports strict tables without leaking the database.""" + db = Database(memory=True) + result = db.supports_strict + db.close() + return result + + def test_table_names(existing_db): assert ["foo"] == existing_db.table_names() @@ -282,7 +290,7 @@ def test_use_rowid(fresh_db): @pytest.mark.skipif( - not Database(memory=True).supports_strict, + not _check_supports_strict(), reason="Needs SQLite version that supports strict", ) @pytest.mark.parametrize( diff --git a/tests/test_plugins.py b/tests/test_plugins.py index c6c805944..1d459c992 100644 --- a/tests/test_plugins.py +++ b/tests/test_plugins.py @@ -9,9 +9,11 @@ def _supports_pragma_function_list(): db = Database(memory=True) try: db.execute("select * from pragma_function_list()") + return True except Exception: return False - return True + finally: + db.close() def test_register_commands(): diff --git a/tests/test_recreate.py b/tests/test_recreate.py index 49dba2bf1..224d18241 100644 --- a/tests/test_recreate.py +++ b/tests/test_recreate.py @@ -14,8 +14,11 @@ def test_recreate_ignored_for_in_memory(): def test_recreate_not_allowed_for_connection(): conn = sqlite3.connect(":memory:") - with pytest.raises(AssertionError): - Database(conn, recreate=True) + try: + with pytest.raises(AssertionError): + Database(conn, recreate=True) + finally: + conn.close() @pytest.mark.parametrize( diff --git a/tests/test_register_function.py b/tests/test_register_function.py index e2591f4e6..618bf1e5f 100644 --- a/tests/test_register_function.py +++ b/tests/test_register_function.py @@ -42,39 +42,46 @@ def to_lower(s): def test_register_function_deterministic_tries_again_if_exception_raised(fresh_db): + # Save the original connection so we can close it later + original_conn = fresh_db.conn fresh_db.conn = MagicMock() fresh_db.conn.create_function = MagicMock() - @fresh_db.register_function(deterministic=True) - def to_lower_2(s): - return s.lower() - - fresh_db.conn.create_function.assert_called_with( - "to_lower_2", 1, to_lower_2, deterministic=True - ) - - first = True - - def side_effect(*args, **kwargs): - # Raise exception only first time this is called - nonlocal first - if first: - first = False - raise sqlite3.NotSupportedError() - - # But if sqlite3.NotSupportedError is raised, it tries again - fresh_db.conn.create_function.reset_mock() - fresh_db.conn.create_function.side_effect = side_effect - - @fresh_db.register_function(deterministic=True) - def to_lower_3(s): - return s.lower() - - # Should have been called once with deterministic=True and once without - assert fresh_db.conn.create_function.call_args_list == [ - call("to_lower_3", 1, to_lower_3, deterministic=True), - call("to_lower_3", 1, to_lower_3), - ] + try: + + @fresh_db.register_function(deterministic=True) + def to_lower_2(s): + return s.lower() + + fresh_db.conn.create_function.assert_called_with( + "to_lower_2", 1, to_lower_2, deterministic=True + ) + + first = True + + def side_effect(*args, **kwargs): + # Raise exception only first time this is called + nonlocal first + if first: + first = False + raise sqlite3.NotSupportedError() + + # But if sqlite3.NotSupportedError is raised, it tries again + fresh_db.conn.create_function.reset_mock() + fresh_db.conn.create_function.side_effect = side_effect + + @fresh_db.register_function(deterministic=True) + def to_lower_3(s): + return s.lower() + + # Should have been called once with deterministic=True and once without + assert fresh_db.conn.create_function.call_args_list == [ + call("to_lower_3", 1, to_lower_3, deterministic=True), + call("to_lower_3", 1, to_lower_3), + ] + finally: + # Close the original connection that was replaced with the mock + original_conn.close() def test_register_function_replace(fresh_db): From 50a979b63da08aebaaffa62873c03afc3c7bf72a Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Thu, 11 Dec 2025 16:48:23 -0800 Subject: [PATCH 5/5] Fix lint error --- sqlite_utils/utils.py | 46 +++++++++++++++++++++---------------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/sqlite_utils/utils.py b/sqlite_utils/utils.py index b90329833..62826b76a 100644 --- a/sqlite_utils/utils.py +++ b/sqlite_utils/utils.py @@ -8,33 +8,12 @@ import json import os import sys -from . import recipes from typing import Dict, cast, BinaryIO, Iterable, Iterator, Optional, Tuple, Type - -class _CloseableIterator(Iterator[dict]): - """Iterator wrapper that closes a file when iteration is complete.""" - - def __init__(self, iterator: Iterator[dict], closeable: io.IOBase): - self._iterator = iterator - self._closeable = closeable - - def __iter__(self) -> "_CloseableIterator": - return self - - def __next__(self) -> dict: - try: - return next(self._iterator) - except StopIteration: - self._closeable.close() - raise - - def close(self) -> None: - self._closeable.close() - - import click +from . import recipes + try: import pysqlite3 as sqlite3 # noqa: F401 from pysqlite3 import dbapi2 # noqa: F401 @@ -65,6 +44,27 @@ def close(self) -> None: ORIGINAL_CSV_FIELD_SIZE_LIMIT = csv.field_size_limit() +class _CloseableIterator(Iterator[dict]): + """Iterator wrapper that closes a file when iteration is complete.""" + + def __init__(self, iterator: Iterator[dict], closeable: io.IOBase): + self._iterator = iterator + self._closeable = closeable + + def __iter__(self) -> "_CloseableIterator": + return self + + def __next__(self) -> dict: + try: + return next(self._iterator) + except StopIteration: + self._closeable.close() + raise + + def close(self) -> None: + self._closeable.close() + + def maximize_csv_field_size_limit(): """ Increase the CSV field size limit to the maximum possible.