diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 763c6ad9..6eb7e389 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -45,6 +45,11 @@ jobs: run: mypy sqlite_utils tests - name: run flake8 run: flake8 + - name: run ty + if: matrix.os != 'windows-latest' + run: | + pip install uv + uv run ty check sqlite_utils - name: Check formatting run: black . --check - name: Check if cog needs to be run diff --git a/docs/conf.py b/docs/conf.py index 859c6b90..04a23014 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -79,7 +79,7 @@ def linkcode_resolve(domain, info): # # The short X.Y version. pipe = Popen("git describe --tags --always", stdout=PIPE, shell=True) -git_version = pipe.stdout.read().decode("utf8") +git_version = pipe.stdout.read().decode("utf8") if pipe.stdout else "" if git_version: version = git_version.rsplit("-", 1)[0] diff --git a/pyproject.toml b/pyproject.toml index 30ceeb0e..a50a3a8e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,7 @@ dev = [ # flake8 "flake8", "flake8-pyproject", + "ty", ] docs = [ "beanbag-docutils>=2.0", diff --git a/sqlite_utils/cli.py b/sqlite_utils/cli.py index 78c9f91d..54de2655 100644 --- a/sqlite_utils/cli.py +++ b/sqlite_utils/cli.py @@ -1,4 +1,5 @@ import base64 +from typing import Any import click from click_default_group import DefaultGroup # type: ignore from datetime import datetime, timezone @@ -50,20 +51,19 @@ def _register_db_for_cleanup(db): ctx = click.get_current_context(silent=True) if ctx is None: return - if not hasattr(ctx, "_databases_to_close"): - ctx._databases_to_close = [] + if "_databases_to_close" not in ctx.meta: + ctx.meta["_databases_to_close"] = [] ctx.call_on_close(lambda: _close_databases(ctx)) - ctx._databases_to_close.append(db) + ctx.meta["_databases_to_close"].append(db) def _close_databases(ctx): """Close all databases registered for cleanup.""" - if hasattr(ctx, "_databases_to_close"): - for db in ctx._databases_to_close: - try: - db.close() - except Exception: - pass + for db in ctx.meta.get("_databases_to_close", []): + try: + db.close() + except Exception: + pass VALID_COLUMN_TYPES = ("INTEGER", "TEXT", "FLOAT", "REAL", "BLOB") @@ -294,6 +294,7 @@ def views( \b sqlite-utils views trees.db """ + assert tables.callback is not None tables.callback( path=path, fts4=False, @@ -338,7 +339,7 @@ def optimize(path, tables, no_vacuum, load_extension): tables = db.table_names(fts4=True) + db.table_names(fts5=True) with db.conn: for table in tables: - db[table].optimize() + db.table(table).optimize() if not no_vacuum: db.vacuum() @@ -366,7 +367,7 @@ def rebuild_fts(path, tables, load_extension): tables = db.table_names(fts4=True) + db.table_names(fts5=True) with db.conn: for table in tables: - db[table].rebuild_fts() + db.table(table).rebuild_fts() @cli.command() @@ -393,7 +394,7 @@ def analyze(path, names): else: db.analyze() except OperationalError as e: - raise click.ClickException(e) + raise click.ClickException(str(e)) @cli.command() @@ -496,7 +497,7 @@ def add_column( _register_db_for_cleanup(db) _load_extensions(db, load_extension) try: - db[table].add_column( + db.table(table).add_column( col_name, col_type, fk=fk, fk_col=fk_col, not_null_default=not_null_default ) except OperationalError as ex: @@ -534,9 +535,11 @@ def add_foreign_key( _register_db_for_cleanup(db) _load_extensions(db, load_extension) try: - db[table].add_foreign_key(column, other_table, other_column, ignore=ignore) + db.table(table).add_foreign_key( + column, other_table, other_column, ignore=ignore + ) except AlterError as e: - raise click.ClickException(e) + raise click.ClickException(str(e)) @cli.command(name="add-foreign-keys") @@ -571,7 +574,7 @@ def add_foreign_keys(path, foreign_key, load_extension): try: db.add_foreign_keys(tuples) except AlterError as e: - raise click.ClickException(e) + raise click.ClickException(str(e)) @cli.command(name="index-foreign-keys") @@ -644,7 +647,7 @@ def create_index( if col.startswith("-"): col = DescIndex(col[1:]) columns.append(col) - db[table].create_index( + db.table(table).create_index( columns, index_name=name, unique=unique, @@ -705,7 +708,7 @@ def enable_fts( replace=replace, ) except OperationalError as ex: - raise click.ClickException(ex) + raise click.ClickException(str(ex)) @cli.command(name="populate-fts") @@ -728,7 +731,7 @@ def populate_fts(path, table, column, load_extension): db = sqlite_utils.Database(path) _register_db_for_cleanup(db) _load_extensions(db, load_extension) - db[table].populate_fts(column) + db.table(table).populate_fts(column) @cli.command(name="disable-fts") @@ -750,7 +753,7 @@ def disable_fts(path, table, load_extension): db = sqlite_utils.Database(path) _register_db_for_cleanup(db) _load_extensions(db, load_extension) - db[table].disable_fts() + db.table(table).disable_fts() @cli.command(name="enable-wal") @@ -826,7 +829,7 @@ def enable_counts(path, tables, load_extension): if bad_tables: raise click.ClickException("Invalid tables: {}".format(bad_tables)) for table in tables: - db[table].enable_counts() + db.table(table).enable_counts() @cli.command(name="reset-counts") @@ -1036,13 +1039,14 @@ def insert_upsert_implementation( if csv or tsv: if sniff: # Read first 2048 bytes and use that to detect + assert sniff_buffer is not None first_bytes = sniff_buffer.peek(2048) dialect = csv_std.Sniffer().sniff( first_bytes.decode(encoding, "ignore") ) else: dialect = "excel-tab" if tsv else "excel" - csv_reader_args = {"dialect": dialect} + csv_reader_args: dict[str, Any] = {"dialect": dialect} if delimiter: csv_reader_args["delimiter"] = delimiter if quotechar: @@ -1146,7 +1150,7 @@ def insert_upsert_implementation( return try: - db[table].insert_all( + db.table(table).insert_all( docs, pk=pk, batch_size=batch_size, alter=alter, **extra_kwargs ) except Exception as e: @@ -1173,7 +1177,7 @@ def insert_upsert_implementation( else: raise if tracker is not None: - db[table].transform(types=tracker.types) + db.table(table).transform(types=tracker.types) # Clean up open file-like objects if sniff_buffer: @@ -1636,7 +1640,7 @@ def create_table( table ) ) - db[table].create( + db.table(table).create( coltypes, pk=pks[0] if len(pks) == 1 else pks, not_null=not_null, @@ -1667,7 +1671,7 @@ def duplicate(path, table, new_table, ignore, load_extension): _register_db_for_cleanup(db) _load_extensions(db, load_extension) try: - db[table].duplicate(new_table) + db.table(table).duplicate(new_table) except NoTable: if not ignore: raise click.ClickException('Table "{}" does not exist'.format(table)) @@ -2028,9 +2032,9 @@ def memory( if flatten: rows = (_flatten(row) for row in rows) - db[file_table].insert_all(rows, alter=True) + db.table(file_table).insert_all(rows, alter=True) if tracker is not None: - db[file_table].transform(types=tracker.types) + db.table(file_table).transform(types=tracker.types) # Add convenient t / t1 / t2 views view_names = ["t{}".format(i + 1)] if i == 0: @@ -2119,7 +2123,8 @@ def _execute_query( else: headers = [c[0] for c in cursor.description] if raw: - data = cursor.fetchone()[0] + row = cursor.fetchone() # type: ignore[union-attr] + data = row[0] if row else None if isinstance(data, bytes): sys.stdout.buffer.write(data) else: @@ -2200,7 +2205,7 @@ def search( _register_db_for_cleanup(db) _load_extensions(db, load_extension) # Check table exists - table_obj = db[dbtable] + table_obj = db.table(dbtable) if not table_obj.exists(): raise click.ClickException("Table '{}' does not exist".format(dbtable)) if not table_obj.detect_fts(): @@ -2612,10 +2617,10 @@ def transform( kwargs["add_foreign_keys"] = add_foreign_keys if sql: - for line in db[table].transform_sql(**kwargs): + for line in db.table(table).transform_sql(**kwargs): click.echo(line) else: - db[table].transform(**kwargs) + db.table(table).transform(**kwargs) @cli.command() @@ -2656,13 +2661,13 @@ def extract( db = sqlite_utils.Database(path) _register_db_for_cleanup(db) _load_extensions(db, load_extension) - kwargs = dict( + kwargs: dict[str, Any] = dict( columns=columns, table=other_table, fk_column=fk_column, rename=dict(rename), ) - db[table].extract(**kwargs) + db.table(table).extract(**kwargs) @cli.command(name="insert-files") @@ -2803,7 +2808,7 @@ def _content_text(p): _load_extensions(db, load_extension) try: with db.conn: - db[table].insert_all( + db.table(table).insert_all( to_insert(), pk=pks[0] if len(pks) == 1 else pks, alter=alter, @@ -3122,7 +3127,7 @@ def wrapped_fn(value): fn = wrapped_fn try: - db[table].convert( + db.table(table).convert( columns, fn, where=where, @@ -3212,7 +3217,7 @@ def add_geometry_column( _load_extensions(db, load_extension) db.init_spatialite() - if db[table].add_geometry_column( + if db.table(table).add_geometry_column( column_name, geometry_type, srid, coord_dimension, not_null ): click.echo(f"Added {geometry_type} column {column_name} to {table}") @@ -3250,7 +3255,7 @@ def create_spatial_index(db_path, table, column_name, load_extension): "You must add a geometry column before creating a spatial index" ) - db[table].create_spatial_index(column_name) + db.table(table).create_spatial_index(column_name) @cli.command(name="plugins") @@ -3361,7 +3366,10 @@ def _load_extensions(db, load_extension): db.conn.enable_load_extension(True) for ext in load_extension: if ext == "spatialite" and not os.path.exists(ext): - ext = find_spatialite() + found = find_spatialite() + if found is None: + raise click.ClickException("Could not find SpatiaLite extension") + ext = found if ":" in ext: path, _, entrypoint = ext.partition(":") db.conn.execute("SELECT load_extension(?, ?)", [path, entrypoint]) diff --git a/sqlite_utils/db.py b/sqlite_utils/db.py index f6fe5edb..0f720eff 100644 --- a/sqlite_utils/db.py +++ b/sqlite_utils/db.py @@ -41,7 +41,7 @@ from sqlite_utils.plugins import pm try: - from sqlite_dump import iterdump + from sqlite_dump import iterdump # type: ignore[import-not-found] except ImportError: iterdump = None @@ -328,6 +328,7 @@ class Database: _counts_table_name = "_counts" use_counts_table = False + conn: sqlite3.Connection def __init__( self, @@ -525,7 +526,7 @@ def attach(self, alias: str, filepath: Union[str, pathlib.Path]): self.execute(attach_sql) def query( - self, sql: str, params: Optional[Union[Iterable, dict]] = None + self, sql: str, params: Optional[Union[Sequence, Dict[str, Any]]] = None ) -> Generator[dict, None, None]: """ Execute ``sql`` and return an iterable of dictionaries representing each row. @@ -540,7 +541,7 @@ def query( yield dict(zip(keys, row)) def execute( - self, sql: str, parameters: Optional[Union[Iterable, dict]] = None + self, sql: str, parameters: Optional[Union[Sequence, Dict[str, Any]]] = None ) -> sqlite3.Cursor: """ Execute SQL query and return a ``sqlite3.Cursor``. @@ -805,10 +806,11 @@ def cached_counts(self, tables: Optional[Iterable[str]] = None) -> Dict[str, int :param tables: Subset list of tables to return counts for. """ sql = 'select "table", count from {}'.format(self._counts_table_name) - if tables: - sql += ' where "table" in ({})'.format(", ".join("?" for table in tables)) + tables_list = list(tables) if tables else None + if tables_list: + sql += ' where "table" in ({})'.format(", ".join("?" for _ in tables_list)) try: - return {r[0]: r[1] for r in self.execute(sql, tables).fetchall()} + return {r[0]: r[1] for r in self.execute(sql, tables_list).fetchall()} except OperationalError: return {} @@ -817,7 +819,7 @@ def reset_counts(self): tables = [table for table in self.tables if table.has_counts_triggers] with self.conn: self._ensure_counts_table() - counts_table = self[self._counts_table_name] + counts_table = self.table(self._counts_table_name) counts_table.delete_where() counts_table.insert_all( {"table": table.name, "count": table.execute_count()} @@ -825,7 +827,7 @@ def reset_counts(self): ) def execute_returning_dicts( - self, sql: str, params: Optional[Union[Iterable, dict]] = None + self, sql: str, params: Optional[Union[Sequence, Dict[str, Any]]] = None ) -> List[dict]: return list(self.query(sql, params)) @@ -1273,7 +1275,7 @@ def add_foreign_keys(self, foreign_keys: Iterable[Tuple[str, str, str, str]]): def index_foreign_keys(self): "Create indexes for every foreign key column on every table in the database." for table_name in self.table_names(): - table = self[table_name] + table = self.table(table_name) existing_indexes = { i.columns[0] for i in table.indexes if len(i.columns) == 1 } @@ -1339,6 +1341,8 @@ def init_spatialite(self, path: Optional[str] = None) -> bool: """ if path is None: path = find_spatialite() + if path is None: + raise OSError("Could not find SpatiaLite extension") self.conn.enable_load_extension(True) self.conn.load_extension(path) @@ -3005,7 +3009,7 @@ def convert_value(v): bar.update(1) return jsonify_if_needed(fn(v)) - fn_name = fn.__name__ + fn_name = getattr(fn, "__name__", "fn") if fn_name == "": fn_name = f"lambda_{abs(hash(fn))}" self.db.register_function(convert_value, name=fn_name) @@ -3250,9 +3254,11 @@ def build_insert_queries_and_params( ) # We can populate .last_pk right here if num_records_processed == 1: - self.last_pk = tuple(record[pk] for pk in pks) - if len(self.last_pk) == 1: - self.last_pk = self.last_pk[0] + pk_values = tuple(record[pk] for pk in pks) + if len(pk_values) == 1: + self.last_pk = pk_values[0] + else: + self.last_pk = pk_values return queries_and_params def insert_chunk( diff --git a/sqlite_utils/plugins.py b/sqlite_utils/plugins.py index 1e45e623..8d6fb856 100644 --- a/sqlite_utils/plugins.py +++ b/sqlite_utils/plugins.py @@ -14,9 +14,10 @@ def get_plugins(): plugins = [] plugin_to_distinfo = dict(pm.list_plugin_distinfo()) for plugin in pm.get_plugins(): + hookcallers = pm.get_hookcallers(plugin) or [] plugin_info = { "name": plugin.__name__, - "hooks": [h.name for h in pm.get_hookcallers(plugin)], + "hooks": [h.name for h in hookcallers], } distinfo = plugin_to_distinfo.get(plugin) if distinfo: diff --git a/sqlite_utils/utils.py b/sqlite_utils/utils.py index 62826b76..7761415a 100644 --- a/sqlite_utils/utils.py +++ b/sqlite_utils/utils.py @@ -15,14 +15,14 @@ from . import recipes try: - import pysqlite3 as sqlite3 # noqa: F401 - from pysqlite3 import dbapi2 # noqa: F401 + import pysqlite3 as sqlite3 # type: ignore[import-not-found] # noqa: F401 + from pysqlite3 import dbapi2 # type: ignore[import-not-found] # noqa: F401 OperationalError = dbapi2.OperationalError except ImportError: try: - import sqlean as sqlite3 # noqa: F401 - from sqlean import dbapi2 # noqa: F401 + import sqlean as sqlite3 # type: ignore[import-not-found] # noqa: F401 + from sqlean import dbapi2 # type: ignore[import-not-found] # noqa: F401 OperationalError = dbapi2.OperationalError except ImportError: diff --git a/tests/conftest.py b/tests/conftest.py index 4a43dd54..3990d76e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,7 +11,7 @@ def pytest_configure(config): import sys - sys._called_from_test = True + sys._called_from_test = True # type: ignore[attr-defined] @pytest.fixture(autouse=True) @@ -24,9 +24,9 @@ def tracking_init(self, *args, **kwargs): original_init(self, *args, **kwargs) databases.append(self) - Database.__init__ = tracking_init + Database.__init__ = tracking_init # type: ignore[method-assign] yield - Database.__init__ = original_init + Database.__init__ = original_init # type: ignore[method-assign] for db in databases: try: db.close() diff --git a/tests/test_cli_convert.py b/tests/test_cli_convert.py index 6d1292be..387b3181 100644 --- a/tests/test_cli_convert.py +++ b/tests/test_cli_convert.py @@ -535,7 +535,7 @@ def test_convert_where(test_db_and_path): "id = :id", "-p", "id", - 2, + "2", ], ) assert result.exit_code == 0, result.output @@ -564,7 +564,7 @@ def test_convert_where_multi(fresh_db_and_path): "id = :id", "-p", "id", - 2, + "2", "--multi", ], ) diff --git a/tests/test_cli_memory.py b/tests/test_cli_memory.py index c8be35fd..cf619855 100644 --- a/tests/test_cli_memory.py +++ b/tests/test_cli_memory.py @@ -331,7 +331,7 @@ def test_memory_return_db(tmpdir): with open(path, "w") as f: f.write("id,name\n1,Cleo") - with click.Context(cli) as ctx: + with click.Context(cli) as ctx: # type: ignore[attr-defined] db = ctx.invoke(cli.commands["memory"], paths=(path,), return_db=True) assert db.table_names() == ["dogs"] diff --git a/tests/test_fts.py b/tests/test_fts.py index f7219bd5..9c635ffa 100644 --- a/tests/test_fts.py +++ b/tests/test_fts.py @@ -424,7 +424,7 @@ def test_enable_fts_error_message_on_views(): db = Database(memory=True) db.create_view("hello", "select 1 + 1") with pytest.raises(NotImplementedError) as e: - db["hello"].enable_fts() + db["hello"].enable_fts() # type: ignore[call-arg] assert e.value.args[0] == "enable_fts() is supported on tables but not on views" diff --git a/tests/test_gis.py b/tests/test_gis.py index a4ee75ec..1b5ed704 100644 --- a/tests/test_gis.py +++ b/tests/test_gis.py @@ -7,7 +7,7 @@ from sqlite_utils.utils import find_spatialite, sqlite3 try: - import sqlean + import sqlean # type: ignore[import-not-found] except ImportError: sqlean = None @@ -50,7 +50,7 @@ def test_add_geometry_column(): column_name="geometry", geometry_type="Point", srid=4326, - coord_dimension=2, + coord_dimension="XY", ) assert db["geometry_columns"].get(["locations", "geometry"]) == { diff --git a/tests/test_rows_from_file.py b/tests/test_rows_from_file.py index 5316b86a..a19fed6e 100644 --- a/tests/test_rows_from_file.py +++ b/tests/test_rows_from_file.py @@ -48,7 +48,7 @@ def test_rows_from_file_extra_fields_strategies(ignore_extras, extras_key, expec def test_rows_from_file_error_on_string_io(): with pytest.raises(TypeError) as ex: - rows_from_file(StringIO("id,name\r\n1,Cleo")) + rows_from_file(StringIO("id,name\r\n1,Cleo")) # type: ignore[arg-type] assert ex.value.args == ( "rows_from_file() requires a file-like object that supports peek(), such as io.BytesIO", )