diff --git a/docs/cli-reference.rst b/docs/cli-reference.rst
index db0ef359..6231dbf2 100644
--- a/docs/cli-reference.rst
+++ b/docs/cli-reference.rst
@@ -615,11 +615,13 @@ See :ref:`cli_convert`.
 
     The following common operations are available as recipe functions:
 
-        r.jsonsplit(value, delimiter=',', type=<class 'str'>)
+        r.jsonsplit(value: 'str', delimiter: 'str' = ',', type: 'Callable[[str],
+        object]' = <class 'str'>) -> 'str'
 
             Convert a string like a,b,c into a JSON array ["a", "b", "c"]
 
-        r.parsedate(value, dayfirst=False, yearfirst=False, errors=None)
+        r.parsedate(value: 'str', dayfirst: 'bool' = False, yearfirst: 'bool' = False,
+        errors: 'Optional[object]' = None) -> 'Optional[str]'
 
             Parse a date and convert it to ISO date format: yyyy-mm-dd
 
@@ -628,7 +630,8 @@ See :ref:`cli_convert`.
             - errors=r.IGNORE to ignore values that cannot be parsed
             - errors=r.SET_NULL to set values that cannot be parsed to null
 
-        r.parsedatetime(value, dayfirst=False, yearfirst=False, errors=None)
+        r.parsedatetime(value: 'str', dayfirst: 'bool' = False, yearfirst: 'bool' =
+        False, errors: 'Optional[object]' = None) -> 'Optional[str]'
 
             Parse a datetime and convert it to ISO datetime format: yyyy-mm-ddTHH:MM:SS
 
diff --git a/mypy.ini b/mypy.ini
index 768d1822..de0dc834 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -1,4 +1,35 @@
 [mypy]
+python_version = 3.10
+warn_return_any = False
+warn_unused_configs = True
+warn_redundant_casts = False
+warn_unused_ignores = False
+check_untyped_defs = True
+disallow_untyped_defs = False
+disallow_incomplete_defs = False
+no_implicit_optional = True
+strict_equality = True
 
-[mypy-pysqlite3,sqlean,sqlite_dump]
-ignore_missing_imports = True
\ No newline at end of file
+[mypy-sqlite_utils.cli]
+ignore_errors = True
+
+[mypy-pysqlite3.*]
+ignore_missing_imports = True
+
+[mypy-sqlean.*]
+ignore_missing_imports = True
+
+[mypy-sqlite_dump.*]
+ignore_missing_imports = True
+
+[mypy-sqlite_fts4.*]
+ignore_missing_imports = True
+
+[mypy-pandas.*]
+ignore_missing_imports = True
+
+[mypy-numpy.*]
+ignore_missing_imports = True
+
+[mypy-tests.*]
+ignore_errors = True
diff --git a/sqlite_utils/cli.py b/sqlite_utils/cli.py
index 54de2655..9b9ee20e 100644
--- a/sqlite_utils/cli.py
+++ b/sqlite_utils/cli.py
@@ -1051,7 +1051,7 @@ def insert_upsert_implementation(
             csv_reader_args["delimiter"] = delimiter
         if quotechar:
             csv_reader_args["quotechar"] = quotechar
-        reader = csv_std.reader(decoded, **csv_reader_args)
+        reader = csv_std.reader(decoded, **csv_reader_args)  # type: ignore
         first_row = next(reader)
         if no_headers:
             headers = ["untitled_{}".format(i + 1) for i in range(len(first_row))]
@@ -2988,7 +2988,7 @@ def _generate_convert_help():
         n
         for n in dir(recipes)
         if not n.startswith("_")
-        and n not in ("json", "parser")
+        and n not in ("json", "parser", "Callable", "Optional")
         and callable(getattr(recipes, n))
     ]
     for name in recipe_names:
diff --git a/sqlite_utils/db.py b/sqlite_utils/db.py
index 0f720eff..aacdc893 100644
--- a/sqlite_utils/db.py
+++ b/sqlite_utils/db.py
@@ -32,6 +32,8 @@
     Generator,
     Iterable,
     Sequence,
+    Set,
+    Type,
     Union,
     Optional,
     List,
@@ -287,7 +289,7 @@ class DescIndex(str):
 class BadMultiValues(Exception):
     "With multi=True code must return a Python dictionary"
 
-    def __init__(self, values):
+    def __init__(self, values: object) -> None:
         self.values = values
 
 
@@ -386,7 +388,12 @@ def __init__(
     def __enter__(self) -> "Database":
         return self
 
-    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_val: Optional[BaseException],
+        exc_tb: Optional[object],
+    ) -> None:
         self.close()
 
     def close(self) -> None:
@@ -394,7 +401,7 @@ def close(self) -> None:
         self.conn.close()
 
     @contextlib.contextmanager
-    def ensure_autocommit_off(self):
+    def ensure_autocommit_off(self) -> Generator[None, None, None]:
         """
         Ensure autocommit is off for this database connection.
 
@@ -413,7 +420,9 @@ def ensure_autocommit_off(self):
         self.conn.isolation_level = old_isolation_level
 
     @contextlib.contextmanager
-    def tracer(self, tracer: Optional[Callable] = None):
+    def tracer(
+        self, tracer: Optional[Callable[[str, Optional[Sequence]], None]] = None
+    ) -> Generator["Database", None, None]:
         """
         Context manager to temporarily set a tracer function - all executed SQL queries
         will be passed to this.
@@ -456,7 +465,7 @@ def register_function(
         deterministic: bool = False,
         replace: bool = False,
         name: Optional[str] = None,
-    ):
+    ) -> Optional[Callable[[Callable], Callable]]:
         """
         ``fn`` will be made available as a function within SQL, with the same name and
         number of arguments. Can be used as a decorator::
@@ -479,12 +488,12 @@ def upper(value):
         :param name: name of the SQLite function - if not specified, the Python function name will be used
         """
 
-        def register(fn):
-            fn_name = name or fn.__name__
+        def register(fn: Callable) -> Callable:
+            fn_name = name or fn.__name__  # type: ignore
            arity = len(inspect.signature(fn).parameters)
            if not replace and (fn_name, arity) in self._registered_functions:
                return fn
-            kwargs = {}
+            kwargs: Dict[str, bool] = {}
            registered = False
            if deterministic:
                # Try this, but fall back if sqlite3.NotSupportedError
@@ -504,12 +513,13 @@ def register(fn):
             return register
         else:
             register(fn)
+            return None
 
-    def register_fts4_bm25(self):
+    def register_fts4_bm25(self) -> None:
         "Register the ``rank_bm25(match_info)`` function used for calculating relevance with SQLite FTS4."
         self.register_function(rank_bm25, deterministic=True, replace=True)
 
-    def attach(self, alias: str, filepath: Union[str, pathlib.Path]):
+    def attach(self, alias: str, filepath: Union[str, pathlib.Path]) -> None:
         """
         Attach another SQLite database file to this connection with the
         specified alias, equivalent to::
@@ -567,7 +577,7 @@ def executescript(self, sql: str) -> sqlite3.Cursor:
             self._tracer(sql, None)
         return self.conn.executescript(sql)
 
-    def table(self, table_name: str, **kwargs) -> "Table":
+    def table(self, table_name: str, **kwargs: Any) -> "Table":
         """
         Return a table object, optionally configured with default options.
 
@@ -766,7 +776,7 @@ def journal_mode(self) -> str:
         """
         return self.execute("PRAGMA journal_mode;").fetchone()[0]
 
-    def enable_wal(self):
+    def enable_wal(self) -> None:
         """
         Sets ``journal_mode`` to ``'wal'`` to enable Write-Ahead Log mode.
         """
@@ -774,17 +784,17 @@ def enable_wal(self):
             with self.ensure_autocommit_off():
                 self.execute("PRAGMA journal_mode=wal;")
 
-    def disable_wal(self):
+    def disable_wal(self) -> None:
         "Sets ``journal_mode`` back to ``'delete'`` to disable Write-Ahead Log mode."
         if self.journal_mode != "delete":
             with self.ensure_autocommit_off():
                 self.execute("PRAGMA journal_mode=delete;")
 
-    def _ensure_counts_table(self):
+    def _ensure_counts_table(self) -> None:
         with self.conn:
             self.execute(_COUNTS_TABLE_CREATE_SQL.format(self._counts_table_name))
 
-    def enable_counts(self):
+    def enable_counts(self) -> None:
         """
         Enable trigger-based count caching for every table in the database, see
         :ref:`python_api_cached_table_counts`.
@@ -814,7 +824,7 @@ def cached_counts(self, tables: Optional[Iterable[str]] = None) -> Dict[str, int
         except OperationalError:
             return {}
 
-    def reset_counts(self):
+    def reset_counts(self) -> None:
         "Re-calculate cached counts for tables."
         tables = [table for table in self.tables if table.has_counts_triggers]
         with self.conn:
@@ -1159,7 +1169,7 @@ def create_table(
             hash_id_columns=hash_id_columns,
         )
 
-    def rename_table(self, name: str, new_name: str):
+    def rename_table(self, name: str, new_name: str) -> None:
         """
         Rename a table.
 
@@ -1174,7 +1184,7 @@ def rename_table(self, name: str, new_name: str):
 
     def create_view(
         self, name: str, sql: str, ignore: bool = False, replace: bool = False
-    ):
+    ) -> "Database":
         """
         Create a new SQL view with the specified name - ``sql`` should start with
         ``SELECT ...``.
 
@@ -1220,7 +1230,9 @@ def m2m_table_candidates(self, table: str, other_table: str) -> List[str]:
                 candidates.append(table_obj.name)
         return candidates
 
-    def add_foreign_keys(self, foreign_keys: Iterable[Tuple[str, str, str, str]]):
+    def add_foreign_keys(
+        self, foreign_keys: Iterable[Tuple[str, str, str, str]]
+    ) -> None:
         """
         See :ref:`python_api_add_foreign_keys`.
 
@@ -1272,7 +1284,7 @@ def add_foreign_keys(self, foreign_keys: Iterable[Tuple[str, str, str, str]]):
 
         self.vacuum()
 
-    def index_foreign_keys(self):
+    def index_foreign_keys(self) -> None:
         "Create indexes for every foreign key column on every table in the database."
         for table_name in self.table_names():
             table = self.table(table_name)
@@ -1283,11 +1295,11 @@ def index_foreign_keys(self):
                 if fk.column not in existing_indexes:
                     table.create_index([fk.column], find_unique_name=True)
 
-    def vacuum(self):
+    def vacuum(self) -> None:
         "Run a SQLite ``VACUUM`` against the database."
         self.execute("VACUUM;")
 
-    def analyze(self, name=None):
+    def analyze(self, name: Optional[str] = None) -> None:
         """
         Run ``ANALYZE`` against the entire database or a named table or index.
 
@@ -1355,18 +1367,21 @@ def init_spatialite(self, path: Optional[str] = None) -> bool:
 
 
 class Queryable:
+    db: "Database"
+    name: str
+
     def exists(self) -> bool:
         "Does this table or view exist yet?"
         return False
 
-    def __init__(self, db, name):
+    def __init__(self, db: "Database", name: str) -> None:
         self.db = db
         self.name = name
 
     def count_where(
         self,
         where: Optional[str] = None,
-        where_args: Optional[Union[Iterable, dict]] = None,
+        where_args: Optional[Union[Sequence, Dict[str, Any]]] = None,
     ) -> int:
         """
         Executes ``SELECT count(*) FROM table WHERE ...`` and returns a count.
 
@@ -1380,7 +1395,7 @@ def count_where(
             sql += " where " + where
         return self.db.execute(sql, where_args or []).fetchone()[0]
 
-    def execute_count(self):
+    def execute_count(self) -> int:
         # Backwards compatibility, see https://github.com/simonw/sqlite-utils/issues/305#issuecomment-890713185
         return self.count_where()
 
@@ -1390,19 +1405,19 @@ def count(self) -> int:
         return self.count_where()
 
     @property
-    def rows(self) -> Generator[dict, None, None]:
+    def rows(self) -> Generator[Dict[str, Any], None, None]:
         "Iterate over every dictionaries for each row in this table or view."
         return self.rows_where()
 
     def rows_where(
         self,
         where: Optional[str] = None,
-        where_args: Optional[Union[Iterable, dict]] = None,
+        where_args: Optional[Union[Sequence, Dict[str, Any]]] = None,
         order_by: Optional[str] = None,
         select: str = "*",
         limit: Optional[int] = None,
         offset: Optional[int] = None,
-    ) -> Generator[dict, None, None]:
+    ) -> Generator[Dict[str, Any], None, None]:
         """
         Iterate over every row in this table or view that matches the specified where
         clause.
 
@@ -1435,11 +1450,11 @@ def rows_where(
     def pks_and_rows_where(
         self,
         where: Optional[str] = None,
-        where_args: Optional[Union[Iterable, dict]] = None,
+        where_args: Optional[Union[Sequence, Dict[str, Any]]] = None,
         order_by: Optional[str] = None,
         limit: Optional[int] = None,
         offset: Optional[int] = None,
-    ) -> Generator[Tuple[Any, Dict], None, None]:
+    ) -> Generator[Tuple[Any, Dict[str, Any]], None, None]:
         """
         Like ``.rows_where()`` but returns ``(pk, row)`` pairs - ``pk`` can be a single
         value or tuple.
 
@@ -1804,18 +1819,18 @@ def create(
             self.name,
             columns,
             pk=pk,
-            foreign_keys=foreign_keys,
-            column_order=column_order,
-            not_null=not_null,
-            defaults=defaults,
-            hash_id=hash_id,
-            hash_id_columns=hash_id_columns,
-            extracts=extracts,
+            foreign_keys=foreign_keys,  # type: ignore[arg-type]
+            column_order=column_order,  # type: ignore[arg-type]
+            not_null=not_null,  # type: ignore[arg-type]
+            defaults=defaults,  # type: ignore[arg-type]
+            hash_id=hash_id,  # type: ignore[arg-type]
+            hash_id_columns=hash_id_columns,  # type: ignore[arg-type]
+            extracts=extracts,  # type: ignore[arg-type]
             if_not_exists=if_not_exists,
             replace=replace,
             ignore=ignore,
             transform=transform,
-            strict=strict,
+            strict=strict,  # type: ignore[arg-type]
         )
         return self
 
@@ -1833,7 +1848,7 @@ def duplicate(self, new_name: str) -> "Table":
             quote_identifier(self.name),
         )
         self.db.execute(sql)
-        return self.db[new_name]
+        return self.db.table(new_name)
 
     def transform(
         self,
@@ -2138,7 +2153,7 @@ def extract(
                 )
             )
         table = table or "_".join(columns)
-        lookup_table = self.db[table]
+        lookup_table = self.db.table(table)
         fk_column = fk_column or "{}_id".format(table)
         magic_lookup_column = "{}_{}".format(fk_column, os.urandom(6).hex())
 
@@ -2350,7 +2365,7 @@ def add_column(
             self.add_foreign_key(col_name, fk, fk_col)
         return self
 
-    def drop(self, ignore: bool = False):
+    def drop(self, ignore: bool = False) -> None:
         """
         Drop this table.
 
@@ -2394,7 +2409,7 @@ def guess_foreign_table(self, column: str) -> str:
                 )
             )
 
-    def guess_foreign_column(self, other_table: str):
+    def guess_foreign_column(self, other_table: str) -> str:
         pks = [c for c in self.db[other_table].columns if c.is_pk]
         if len(pks) != 1:
             raise BadPrimaryKey(
@@ -2453,7 +2468,7 @@ def add_foreign_key(
         self.db.add_foreign_keys([(self.name, column, other_table, other_column)])
         return self
 
-    def enable_counts(self):
+    def enable_counts(self) -> None:
         """
         Set up triggers to update a cache of the count of rows in this table.
 
@@ -2665,7 +2680,7 @@ def disable_fts(self) -> "Table":
         )
         return self
 
-    def rebuild_fts(self):
+    def rebuild_fts(self) -> "Table":
         "Run the ``rebuild`` operation against the associated full-text search index table."
         fts_table = self.detect_fts()
         if fts_table is None:
@@ -2752,7 +2767,7 @@ def search_sql(
             self.name
         )
         fts_table_quoted = quote_identifier(fts_table)
-        virtual_table_using = self.db[fts_table].virtual_table_using
+        virtual_table_using = self.db.table(fts_table).virtual_table_using
         sql = textwrap.dedent(
             """
         with {original} as (
@@ -2849,7 +2864,7 @@ def search(
         for row in cursor:
             yield dict(zip(columns, row))
 
-    def value_or_default(self, key, value):
+    def value_or_default(self, key: str, value: Any) -> Any:
         return self._defaults[key] if value is DEFAULT else value
 
     def delete(self, pk_values: Union[list, tuple, str, int, float]) -> "Table":
@@ -2872,7 +2887,7 @@ def delete(self, pk_values: Union[list, tuple, str, int, float]) -> "Table":
     def delete_where(
         self,
         where: Optional[str] = None,
-        where_args: Optional[Union[Iterable, dict]] = None,
+        where_args: Optional[Union[Sequence, Dict[str, Any]]] = None,
         analyze: bool = False,
     ) -> "Table":
         """
@@ -2963,9 +2978,9 @@ def convert(
         drop: bool = False,
         multi: bool = False,
         where: Optional[str] = None,
-        where_args: Optional[Union[Iterable, dict]] = None,
+        where_args: Optional[Union[Sequence, Dict[str, Any]]] = None,
         show_progress: bool = False,
-    ):
+    ) -> "Table":
         """
         Apply conversion function ``fn`` to every value in the specified columns.
 
@@ -3038,7 +3053,7 @@ def _convert_multi(
     ):
         # First we execute the function
         pk_to_values = {}
-        new_column_types = {}
+        new_column_types: Dict[str, Set[type]] = {}
         pks = [column.name for column in self.columns if column.is_pk]
         if not pks:
             pks = ["rowid"]
@@ -3128,7 +3143,7 @@ def build_insert_queries_and_params(
             if has_extracts:
                 for i, key in enumerate(all_columns):
                     if key in extracts:
-                        record_values[i] = self.db[extracts[key]].lookup(
+                        record_values[i] = self.db.table(extracts[key]).lookup(
                             {"value": record_values[i]}
                         )
                 values.append(record_values)
@@ -3149,7 +3164,7 @@ def build_insert_queries_and_params(
                     )
                     if key in extracts:
                         extract_table = extracts[key]
-                        value = self.db[extract_table].lookup({"value": value})
+                        value = self.db.table(extract_table).lookup({"value": value})
                     record_values.append(value)
                 values.append(record_values)
 
@@ -3544,7 +3559,7 @@ def insert_all(
                     chunk_as_dicts = [dict(zip(column_names, row)) for row in chunk]
                     column_types = suggest_column_types(chunk_as_dicts)
                 else:
-                    column_types = suggest_column_types(chunk)
+                    column_types = suggest_column_types(chunk)  # type: ignore[arg-type]
                 if extracts:
                     for col in extracts:
                         if col in column_types:
@@ -3570,9 +3585,9 @@ def insert_all(
                 if hash_id:
                     all_columns.insert(0, hash_id)
             else:
-                all_columns_set = set()
+                all_columns_set: Set[str] = set()
                 for record in chunk:
-                    all_columns_set.update(record.keys())
+                    all_columns_set.update(record.keys())  # type: ignore[union-attr]
                 all_columns = list(sorted(all_columns_set))
                 if hash_id:
                     all_columns.insert(0, hash_id)
@@ -3795,7 +3810,7 @@ def lookup(
                 )
             )
             try:
-                return rows[0][pk]
+                return rows[0][pk]  # type: ignore[index]
             except IndexError:
                 return self.insert(
                     combined_values,
@@ -3859,7 +3874,7 @@ def m2m(
         already exists.
""" if isinstance(other_table, str): - other_table = cast(Table, self.db.table(other_table, pk=pk)) + other_table = self.db.table(other_table, pk=pk) our_id = self.last_pk if lookup is not None: assert record_or_iterable is None, "Provide lookup= or record, not both" @@ -3913,7 +3928,7 @@ def m2m( ) return self - def analyze(self): + def analyze(self) -> None: "Run ANALYZE against this table" self.db.analyze(self.name) @@ -4105,7 +4120,7 @@ def create_spatial_index(self, column_name) -> bool: class View(Queryable): - def exists(self): + def exists(self) -> bool: return True def __repr__(self) -> str: @@ -4113,7 +4128,7 @@ def __repr__(self) -> str: self.name, ", ".join(c.name for c in self.columns) ) - def drop(self, ignore=False): + def drop(self, ignore: bool = False) -> None: """ Drop this view. @@ -4126,14 +4141,14 @@ def drop(self, ignore=False): if not ignore: raise - def enable_fts(self, *args, **kwargs): + def enable_fts(self, *args: object, **kwargs: object) -> None: "``enable_fts()`` is supported on tables but not on views." raise NotImplementedError( "enable_fts() is supported on tables but not on views" ) -def jsonify_if_needed(value): +def jsonify_if_needed(value: object) -> object: if isinstance(value, decimal.Decimal): return float(value) if isinstance(value, (dict, list, tuple)): @@ -4158,7 +4173,7 @@ def resolve_extracts( return extracts -def _decode_default_value(value): +def _decode_default_value(value: str) -> object: if value.startswith("'") and value.endswith("'"): # It's a string return value[1:-1] diff --git a/sqlite_utils/hookspecs.py b/sqlite_utils/hookspecs.py index 83466bec..a746619d 100644 --- a/sqlite_utils/hookspecs.py +++ b/sqlite_utils/hookspecs.py @@ -1,3 +1,6 @@ +import sqlite3 + +import click from pluggy import HookimplMarker from pluggy import HookspecMarker @@ -6,10 +9,10 @@ @hookspec -def register_commands(cli): +def register_commands(cli: click.Group) -> None: """Register additional CLI commands, e.g. 'sqlite-utils mycommand ...'""" @hookspec -def prepare_connection(conn): +def prepare_connection(conn: sqlite3.Connection) -> None: """Modify SQLite connection in some way e.g. register custom SQL functions""" diff --git a/sqlite_utils/plugins.py b/sqlite_utils/plugins.py index 8d6fb856..457a9071 100644 --- a/sqlite_utils/plugins.py +++ b/sqlite_utils/plugins.py @@ -1,8 +1,10 @@ +from typing import Dict, List, Union + import pluggy import sys from . 
 
-pm = pluggy.PluginManager("sqlite_utils")
+pm: pluggy.PluginManager = pluggy.PluginManager("sqlite_utils")
 pm.add_hookspecs(hookspecs)
 
 if not getattr(sys, "_called_from_test", False):
@@ -10,12 +12,12 @@
     pm.load_setuptools_entrypoints("sqlite_utils")
 
 
-def get_plugins():
-    plugins = []
+def get_plugins() -> List[Dict[str, Union[str, List[str]]]]:
+    plugins: List[Dict[str, Union[str, List[str]]]] = []
     plugin_to_distinfo = dict(pm.list_plugin_distinfo())
     for plugin in pm.get_plugins():
         hookcallers = pm.get_hookcallers(plugin) or []
-        plugin_info = {
+        plugin_info: Dict[str, Union[str, List[str]]] = {
             "name": plugin.__name__,
             "hooks": [h.name for h in hookcallers],
         }
diff --git a/sqlite_utils/recipes.py b/sqlite_utils/recipes.py
index e9c0a6e2..55b55a41 100644
--- a/sqlite_utils/recipes.py
+++ b/sqlite_utils/recipes.py
@@ -1,11 +1,20 @@
+from __future__ import annotations
+
+from typing import Callable, Optional
+
 from dateutil import parser
 import json
 
-IGNORE = object()
-SET_NULL = object()
+IGNORE: object = object()
+SET_NULL: object = object()
 
 
-def parsedate(value, dayfirst=False, yearfirst=False, errors=None):
+def parsedate(
+    value: str,
+    dayfirst: bool = False,
+    yearfirst: bool = False,
+    errors: Optional[object] = None,
+) -> Optional[str]:
     """
     Parse a date and convert it to ISO date format: yyyy-mm-dd
     \b
@@ -31,7 +40,12 @@ def parsedate(value, dayfirst=False, yearfirst=False, errors=None):
         raise
 
 
-def parsedatetime(value, dayfirst=False, yearfirst=False, errors=None):
+def parsedatetime(
+    value: str,
+    dayfirst: bool = False,
+    yearfirst: bool = False,
+    errors: Optional[object] = None,
+) -> Optional[str]:
     """
     Parse a datetime and convert it to ISO datetime format: yyyy-mm-ddTHH:MM:SS
     \b
@@ -53,7 +67,9 @@ def parsedatetime(value, dayfirst=False, yearfirst=False, errors=None):
         raise
 
 
-def jsonsplit(value, delimiter=",", type=str):
+def jsonsplit(
+    value: str, delimiter: str = ",", type: Callable[[str], object] = str
+) -> str:
     """
     Convert a string like a,b,c into a JSON array ["a", "b", "c"]
     """
diff --git a/sqlite_utils/utils.py b/sqlite_utils/utils.py
index 7761415a..05e9a511 100644
--- a/sqlite_utils/utils.py
+++ b/sqlite_utils/utils.py
@@ -8,7 +8,23 @@
 import json
 import os
 import sys
-from typing import Dict, cast, BinaryIO, Iterable, Iterator, Optional, Tuple, Type
+from typing import (
+    Any,
+    BinaryIO,
+    Callable,
+    Dict,
+    Generator,
+    Iterable,
+    Iterator,
+    List,
+    Optional,
+    Set,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+    cast,
+)
 
 import click
 
@@ -43,18 +59,24 @@
 # Mainly so we can restore it if needed in the tests:
 ORIGINAL_CSV_FIELD_SIZE_LIMIT = csv.field_size_limit()
 
+# Type alias for row dictionaries - values can be various SQLite-compatible types
+RowValue = Union[None, int, float, str, bytes, bool]
+Row = Dict[str, RowValue]
+
+T = TypeVar("T")
+
 
-class _CloseableIterator(Iterator[dict]):
+class _CloseableIterator(Iterator[Row]):
     """Iterator wrapper that closes a file when iteration is complete."""
 
-    def __init__(self, iterator: Iterator[dict], closeable: io.IOBase):
+    def __init__(self, iterator: Iterator[Row], closeable: io.IOBase) -> None:
         self._iterator = iterator
         self._closeable = closeable
 
     def __iter__(self) -> "_CloseableIterator":
         return self
 
-    def __next__(self) -> dict:
+    def __next__(self) -> Row:
         try:
             return next(self._iterator)
         except StopIteration:
@@ -65,7 +87,7 @@ def close(self) -> None:
         self._closeable.close()
 
 
-def maximize_csv_field_size_limit():
+def maximize_csv_field_size_limit() -> None:
     """
     Increase the CSV field size limit to the maximum possible.
     """
@@ -108,20 +130,25 @@ def find_spatialite() -> Optional[str]:
     return None
 
 
-def suggest_column_types(records):
-    all_column_types = {}
+def suggest_column_types(
+    records: Iterable[Dict[str, Any]],
+) -> Dict[str, type]:
+    all_column_types: Dict[str, Set[type]] = {}
     for record in records:
         for key, value in record.items():
             all_column_types.setdefault(key, set()).add(type(value))
     return types_for_column_types(all_column_types)
 
 
-def types_for_column_types(all_column_types):
-    column_types = {}
+def types_for_column_types(
+    all_column_types: Dict[str, Set[type]],
+) -> Dict[str, type]:
+    column_types: Dict[str, type] = {}
     for key, types in all_column_types.items():
         # Ignore null values if at least one other type present:
         if len(types) > 1:
             types.discard(None.__class__)
+        t: type
         if {None.__class__} == types:
             t = str
         elif len(types) == 1:
@@ -143,7 +170,7 @@ def types_for_column_types(all_column_types):
     return column_types
 
 
-def column_affinity(column_type):
+def column_affinity(column_type: str) -> type:
     # Implementation of SQLite affinity rules from
     # https://www.sqlite.org/datatype3.html#determination_of_column_affinity
     assert isinstance(column_type, str)
@@ -162,38 +189,42 @@ def column_affinity(column_type):
     return float
 
 
-def decode_base64_values(doc):
+def decode_base64_values(doc: Dict[str, Any]) -> Dict[str, Any]:
     # Looks for '{"$base64": true..., "encoded": ...}' values and decodes them
     to_fix = [
         k
         for k in doc
         if isinstance(doc[k], dict)
-        and doc[k].get("$base64") is True
-        and "encoded" in doc[k]
+        and cast(dict, doc[k]).get("$base64") is True
+        and "encoded" in cast(dict, doc[k])
     ]
     if not to_fix:
         return doc
-    return dict(doc, **{k: base64.b64decode(doc[k]["encoded"]) for k in to_fix})
+    return dict(
+        doc, **{k: base64.b64decode(cast(dict, doc[k])["encoded"]) for k in to_fix}
+    )
 
 
 class UpdateWrapper:
-    def __init__(self, wrapped, update):
+    def __init__(self, wrapped: io.IOBase, update: Callable[[int], None]) -> None:
         self._wrapped = wrapped
         self._update = update
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator[bytes]:
         for line in self._wrapped:
             self._update(len(line))
             yield line
 
-    def read(self, size=-1):
+    def read(self, size: int = -1) -> bytes:
         data = self._wrapped.read(size)
         self._update(len(data))
         return data
 
 
 @contextlib.contextmanager
-def file_progress(file, silent=False, **kwargs):
+def file_progress(
+    file: io.IOBase, silent: bool = False, **kwargs: object
+) -> Generator[Union[io.IOBase, "UpdateWrapper"], None, None]:
     if silent:
         yield file
         return
@@ -206,8 +237,8 @@ def file_progress(file, silent=False, **kwargs):
     if fileno == 0:  # 0 means stdin
         yield file
     else:
-        file_length = os.path.getsize(file.name)
-        with click.progressbar(length=file_length, **kwargs) as bar:
+        file_length = os.path.getsize(file.name)  # type: ignore
+        with click.progressbar(length=file_length, **kwargs) as bar:  # type: ignore
             yield UpdateWrapper(file, bar.update)
 
 
@@ -231,28 +262,30 @@ class RowError(Exception):
 
 
 def _extra_key_strategy(
-    reader: Iterable[dict],
+    reader: Iterable[Dict[Optional[str], object]],
     ignore_extras: Optional[bool] = False,
     extras_key: Optional[str] = None,
-) -> Iterable[dict]:
+) -> Iterable[Row]:
     # Logic for handling CSV rows with more values than there are headings
     for row in reader:
         # DictReader adds a 'None' key with extra row values
         if None not in row:
-            yield row
+            yield cast(Row, row)
         elif ignore_extras:
             # ignoring row.pop(none) because of this issue:
             # https://github.com/simonw/sqlite-utils/issues/440#issuecomment-1155358637
-            row.pop(None)  # type: ignore
-            yield row
+            row.pop(None)
+            yield cast(Row, row)
         elif not extras_key:
-            extras = row.pop(None)  # type: ignore
+            extras = row.pop(None)
             raise RowError(
                 "Row {} contained these extra values: {}".format(row, extras)
             )
         else:
-            row[extras_key] = row.pop(None)  # type: ignore
-            yield row
+            extras_value = row.pop(None)
+            row_out = cast(Row, row)
+            row_out[extras_key] = extras_value  # type: ignore[assignment]
+            yield row_out
 
 
 def rows_from_file(
@@ -262,7 +295,7 @@ def rows_from_file(
     encoding: Optional[str] = None,
     ignore_extras: Optional[bool] = False,
     extras_key: Optional[str] = None,
-) -> Tuple[Iterable[dict], Format]:
+) -> Tuple[Iterable[Row], Format]:
     """
     Load a sequence of dictionaries from a file-like object containing one of four
     different formats.
 
@@ -324,10 +357,17 @@ class Format(enum.Enum):
         rows = _extra_key_strategy(reader, ignore_extras, extras_key)
         return _CloseableIterator(iter(rows), decoded_fp), Format.CSV
     elif format == Format.TSV:
-        rows = rows_from_file(
+        rows, _ = rows_from_file(
             fp, format=Format.CSV, dialect=csv.excel_tab, encoding=encoding
-        )[0]
-        return _extra_key_strategy(rows, ignore_extras, extras_key), Format.TSV
+        )
+        return (
+            _extra_key_strategy(
+                cast(Iterable[Dict[Optional[str], object]], rows),
+                ignore_extras,
+                extras_key,
+            ),
+            Format.TSV,
+        )
     elif format is None:
         # Detect the format, then call this recursively
         buffered = io.BufferedReader(cast(io.RawIOBase, fp), buffer_size=4096)
@@ -349,8 +389,15 @@ class Format(enum.Enum):
             buffered, format=Format.CSV, dialect=dialect, encoding=encoding
         )
         # Make sure we return the format we detected
-        format = Format.TSV if dialect.delimiter == "\t" else Format.CSV
-        return _extra_key_strategy(rows, ignore_extras, extras_key), format
+        detected_format = Format.TSV if dialect.delimiter == "\t" else Format.CSV
+        return (
+            _extra_key_strategy(
+                cast(Iterable[Dict[Optional[str], object]], rows),
+                ignore_extras,
+                extras_key,
+            ),
+            detected_format,
+        )
     else:
         raise RowsFromFileError("Bad format")
 
@@ -376,10 +423,10 @@ class TypeTracker:
         db["creatures"].transform(types=tracker.types)
     """
 
-    def __init__(self):
-        self.trackers = {}
+    def __init__(self) -> None:
+        self.trackers: Dict[str, "ValueTracker"] = {}
 
-    def wrap(self, iterator: Iterable[dict]) -> Iterable[dict]:
+    def wrap(self, iterator: Iterable[Dict[str, Any]]) -> Iterable[Dict[str, Any]]:
         """
         Use this to loop through an existing iterator, tracking the column types
         as part of the iteration.
@@ -402,27 +449,29 @@ def types(self) -> Dict[str, str]:
 
 
 class ValueTracker:
-    def __init__(self):
+    couldbe: Dict[str, Callable[[object], bool]]
+
+    def __init__(self) -> None:
         self.couldbe = {key: getattr(self, "test_" + key) for key in self.get_tests()}
 
     @classmethod
-    def get_tests(cls):
+    def get_tests(cls) -> List[str]:
         return [
             key.split("test_")[-1]
             for key in cls.__dict__.keys()
             if key.startswith("test_")
         ]
 
-    def test_integer(self, value):
+    def test_integer(self, value: object) -> bool:
         try:
-            int(value)
+            int(value)  # type: ignore
             return True
         except (ValueError, TypeError):
             return False
 
-    def test_float(self, value):
+    def test_float(self, value: object) -> bool:
         try:
-            float(value)
+            float(value)  # type: ignore[arg-type]
             return True
         except (ValueError, TypeError):
             return False
@@ -431,7 +480,7 @@ def __repr__(self) -> str:
         return self.guessed_type + ": possibilities = " + repr(self.couldbe)
 
     @property
-    def guessed_type(self):
+    def guessed_type(self) -> str:
         options = set(self.couldbe.keys())
         # Return based on precedence
         for key in self.get_tests():
@@ -439,10 +488,10 @@ def guessed_type(self):
                 return key
         return "text"
 
-    def evaluate(self, value):
+    def evaluate(self, value: object) -> None:
         if not value or not self.couldbe:
             return
-        not_these = []
+        not_these: List[str] = []
         for name, test in self.couldbe.items():
             if not test(value):
                 not_these.append(name)
@@ -451,45 +500,47 @@ def evaluate(self, value):
 
 
 class NullProgressBar:
-    def __init__(self, *args):
+    def __init__(self, *args: Iterable[T]) -> None:
         self.args = args
 
-    def __iter__(self):
-        yield from self.args[0]
+    def __iter__(self) -> Iterator[T]:
+        yield from self.args[0]  # type: ignore
 
-    def update(self, value):
+    def update(self, value: int) -> None:
         pass
 
 
 @contextlib.contextmanager
-def progressbar(*args, **kwargs):
+def progressbar(*args: Iterable[T], **kwargs: Any) -> Generator[Any, None, None]:
     silent = kwargs.pop("silent")
     if silent:
         yield NullProgressBar(*args)
     else:
-        with click.progressbar(*args, **kwargs) as bar:
+        with click.progressbar(*args, **kwargs) as bar:  # type: ignore
             yield bar
 
 
-def _compile_code(code, imports, variable="value"):
-    globals = {"r": recipes, "recipes": recipes}
+def _compile_code(
+    code: str, imports: Iterable[str], variable: str = "value"
+) -> Callable[..., Any]:
+    globals_dict: Dict[str, Any] = {"r": recipes, "recipes": recipes}
     # Handle imports first so they're available for all approaches
     for import_ in imports:
-        globals[import_.split(".")[0]] = __import__(import_)
+        globals_dict[import_.split(".")[0]] = __import__(import_)
 
     # If user defined a convert() function, return that
     try:
-        exec(code, globals)
-        return globals["convert"]
+        exec(code, globals_dict)
+        return cast(Callable[..., object], globals_dict["convert"])
     except (AttributeError, SyntaxError, NameError, KeyError, TypeError):
         pass
 
     # Check if code is a direct callable reference
"r.parsedate" instead of "r.parsedate(value)" try: - fn = eval(code, globals) + fn = eval(code, globals_dict) if callable(fn): - return fn + return cast(Callable[..., object], fn) except Exception: pass @@ -514,11 +565,11 @@ def _compile_code(code, imports, variable="value"): if code_o is None: raise SyntaxError("Could not compile code") - exec(code_o, globals) - return globals["fn"] + exec(code_o, globals_dict) + return cast(Callable[..., object], globals_dict["fn"]) -def chunks(sequence: Iterable, size: int) -> Iterable[Iterable]: +def chunks(sequence: Iterable[T], size: int) -> Iterable[Iterable[T]]: """ Iterate over chunks of the sequence of the given size. @@ -530,7 +581,7 @@ def chunks(sequence: Iterable, size: int) -> Iterable[Iterable]: yield itertools.chain([item], itertools.islice(iterator, size - 1)) -def hash_record(record: Dict, keys: Optional[Iterable[str]] = None): +def hash_record(record: Dict[str, Any], keys: Optional[Iterable[str]] = None) -> str: """ ``record`` should be a Python dictionary. Returns a sha1 hash of the keys and values in that record. @@ -551,7 +602,7 @@ def hash_record(record: Dict, keys: Optional[Iterable[str]] = None): :param record: Record to generate a hash for :param keys: Subset of keys to use for that hash """ - to_hash = record + to_hash: Dict[str, Any] = record if keys is not None: to_hash = {key: record[key] for key in keys} return hashlib.sha1( @@ -561,7 +612,7 @@ def hash_record(record: Dict, keys: Optional[Iterable[str]] = None): ).hexdigest() -def _flatten(d): +def _flatten(d: Dict[str, Any]) -> Generator[Tuple[str, Any], None, None]: for key, value in d.items(): if isinstance(value, dict): for key2, value2 in _flatten(value): @@ -570,7 +621,7 @@ def _flatten(d): yield key, value -def flatten(row: dict) -> dict: +def flatten(row: Dict[str, Any]) -> Dict[str, Any]: """ Turn a nested dict e.g. ``{"a": {"b": 1}}`` into a flat dict: ``{"a_b": 1}`` diff --git a/tests/test_docs.py b/tests/test_docs.py index a36b0532..f657416d 100644 --- a/tests/test_docs.py +++ b/tests/test_docs.py @@ -41,9 +41,9 @@ def test_convert_help(): result = CliRunner().invoke(cli.cli, ["convert", "--help"]) assert result.exit_code == 0 for expected in ( - "r.jsonsplit(value, ", - "r.parsedate(value, ", - "r.parsedatetime(value, ", + "r.jsonsplit(value:", + "r.parsedate(value:", + "r.parsedatetime(value:", ): assert expected in result.output @@ -54,7 +54,7 @@ def test_convert_help(): n for n in dir(recipes) if not n.startswith("_") - and n not in ("json", "parser") + and n not in ("json", "parser", "Callable", "Optional") and callable(getattr(recipes, n)) ], ) diff --git a/tests/test_tracer.py b/tests/test_tracer.py index 9b44fce7..ac490c58 100644 --- a/tests/test_tracer.py +++ b/tests/test_tracer.py @@ -50,7 +50,7 @@ def tracer(sql, params): with db.tracer(tracer): list(dogs.search("Cleopaws")) - assert len(collected) == 5 + assert len(collected) == 4 assert collected == [ ( "SELECT name FROM sqlite_master\n" @@ -70,7 +70,6 @@ def tracer(sql, params): }, ), ("select name from sqlite_master where type = 'view'", None), - ("select name from sqlite_master where type = 'view'", None), ("select sql from sqlite_master where name = ?", ("dogs_fts",)), ( 'with "original" as (\n' @@ -94,4 +93,4 @@ def tracer(sql, params): # Outside the with block collected should not be appended to dogs.insert({"name": "Cleopaws"}) - assert len(collected) == 5 + assert len(collected) == 4