Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ jobs:
run: mypy sqlite_utils tests
- name: run flake8
run: flake8
- name: run ty
if: matrix.os != 'windows-latest'
run: |
pip install uv
uv run ty check sqlite_utils
- name: Check formatting
run: black . --check
- name: Check if cog needs to be run
Expand Down
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def linkcode_resolve(domain, info):
#
# The short X.Y version.
pipe = Popen("git describe --tags --always", stdout=PIPE, shell=True)
git_version = pipe.stdout.read().decode("utf8")
git_version = pipe.stdout.read().decode("utf8") if pipe.stdout else ""

if git_version:
version = git_version.rsplit("-", 1)[0]
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ dev = [
# flake8
"flake8",
"flake8-pyproject",
"ty",
]
docs = [
"beanbag-docutils>=2.0",
Expand Down
86 changes: 47 additions & 39 deletions sqlite_utils/cli.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import base64
from typing import Any
import click
from click_default_group import DefaultGroup # type: ignore
from datetime import datetime, timezone
Expand Down Expand Up @@ -50,20 +51,19 @@ def _register_db_for_cleanup(db):
ctx = click.get_current_context(silent=True)
if ctx is None:
return
if not hasattr(ctx, "_databases_to_close"):
ctx._databases_to_close = []
if "_databases_to_close" not in ctx.meta:
ctx.meta["_databases_to_close"] = []
ctx.call_on_close(lambda: _close_databases(ctx))
ctx._databases_to_close.append(db)
ctx.meta["_databases_to_close"].append(db)


def _close_databases(ctx):
"""Close all databases registered for cleanup."""
if hasattr(ctx, "_databases_to_close"):
for db in ctx._databases_to_close:
try:
db.close()
except Exception:
pass
for db in ctx.meta.get("_databases_to_close", []):
try:
db.close()
except Exception:
pass


VALID_COLUMN_TYPES = ("INTEGER", "TEXT", "FLOAT", "REAL", "BLOB")
Expand Down Expand Up @@ -294,6 +294,7 @@ def views(
\b
sqlite-utils views trees.db
"""
assert tables.callback is not None
tables.callback(
path=path,
fts4=False,
Expand Down Expand Up @@ -338,7 +339,7 @@ def optimize(path, tables, no_vacuum, load_extension):
tables = db.table_names(fts4=True) + db.table_names(fts5=True)
with db.conn:
for table in tables:
db[table].optimize()
db.table(table).optimize()
if not no_vacuum:
db.vacuum()

Expand Down Expand Up @@ -366,7 +367,7 @@ def rebuild_fts(path, tables, load_extension):
tables = db.table_names(fts4=True) + db.table_names(fts5=True)
with db.conn:
for table in tables:
db[table].rebuild_fts()
db.table(table).rebuild_fts()


@cli.command()
Expand All @@ -393,7 +394,7 @@ def analyze(path, names):
else:
db.analyze()
except OperationalError as e:
raise click.ClickException(e)
raise click.ClickException(str(e))


@cli.command()
Expand Down Expand Up @@ -496,7 +497,7 @@ def add_column(
_register_db_for_cleanup(db)
_load_extensions(db, load_extension)
try:
db[table].add_column(
db.table(table).add_column(
col_name, col_type, fk=fk, fk_col=fk_col, not_null_default=not_null_default
)
except OperationalError as ex:
Expand Down Expand Up @@ -534,9 +535,11 @@ def add_foreign_key(
_register_db_for_cleanup(db)
_load_extensions(db, load_extension)
try:
db[table].add_foreign_key(column, other_table, other_column, ignore=ignore)
db.table(table).add_foreign_key(
column, other_table, other_column, ignore=ignore
)
except AlterError as e:
raise click.ClickException(e)
raise click.ClickException(str(e))


@cli.command(name="add-foreign-keys")
Expand Down Expand Up @@ -571,7 +574,7 @@ def add_foreign_keys(path, foreign_key, load_extension):
try:
db.add_foreign_keys(tuples)
except AlterError as e:
raise click.ClickException(e)
raise click.ClickException(str(e))


@cli.command(name="index-foreign-keys")
Expand Down Expand Up @@ -644,7 +647,7 @@ def create_index(
if col.startswith("-"):
col = DescIndex(col[1:])
columns.append(col)
db[table].create_index(
db.table(table).create_index(
columns,
index_name=name,
unique=unique,
Expand Down Expand Up @@ -705,7 +708,7 @@ def enable_fts(
replace=replace,
)
except OperationalError as ex:
raise click.ClickException(ex)
raise click.ClickException(str(ex))


@cli.command(name="populate-fts")
Expand All @@ -728,7 +731,7 @@ def populate_fts(path, table, column, load_extension):
db = sqlite_utils.Database(path)
_register_db_for_cleanup(db)
_load_extensions(db, load_extension)
db[table].populate_fts(column)
db.table(table).populate_fts(column)


@cli.command(name="disable-fts")
Expand All @@ -750,7 +753,7 @@ def disable_fts(path, table, load_extension):
db = sqlite_utils.Database(path)
_register_db_for_cleanup(db)
_load_extensions(db, load_extension)
db[table].disable_fts()
db.table(table).disable_fts()


@cli.command(name="enable-wal")
Expand Down Expand Up @@ -826,7 +829,7 @@ def enable_counts(path, tables, load_extension):
if bad_tables:
raise click.ClickException("Invalid tables: {}".format(bad_tables))
for table in tables:
db[table].enable_counts()
db.table(table).enable_counts()


@cli.command(name="reset-counts")
Expand Down Expand Up @@ -1036,13 +1039,14 @@ def insert_upsert_implementation(
if csv or tsv:
if sniff:
# Read first 2048 bytes and use that to detect
assert sniff_buffer is not None
first_bytes = sniff_buffer.peek(2048)
dialect = csv_std.Sniffer().sniff(
first_bytes.decode(encoding, "ignore")
)
else:
dialect = "excel-tab" if tsv else "excel"
csv_reader_args = {"dialect": dialect}
csv_reader_args: dict[str, Any] = {"dialect": dialect}
if delimiter:
csv_reader_args["delimiter"] = delimiter
if quotechar:
Expand Down Expand Up @@ -1146,7 +1150,7 @@ def insert_upsert_implementation(
return

try:
db[table].insert_all(
db.table(table).insert_all(
docs, pk=pk, batch_size=batch_size, alter=alter, **extra_kwargs
)
except Exception as e:
Expand All @@ -1173,7 +1177,7 @@ def insert_upsert_implementation(
else:
raise
if tracker is not None:
db[table].transform(types=tracker.types)
db.table(table).transform(types=tracker.types)

# Clean up open file-like objects
if sniff_buffer:
Expand Down Expand Up @@ -1636,7 +1640,7 @@ def create_table(
table
)
)
db[table].create(
db.table(table).create(
coltypes,
pk=pks[0] if len(pks) == 1 else pks,
not_null=not_null,
Expand Down Expand Up @@ -1667,7 +1671,7 @@ def duplicate(path, table, new_table, ignore, load_extension):
_register_db_for_cleanup(db)
_load_extensions(db, load_extension)
try:
db[table].duplicate(new_table)
db.table(table).duplicate(new_table)
except NoTable:
if not ignore:
raise click.ClickException('Table "{}" does not exist'.format(table))
Expand Down Expand Up @@ -2028,9 +2032,9 @@ def memory(
if flatten:
rows = (_flatten(row) for row in rows)

db[file_table].insert_all(rows, alter=True)
db.table(file_table).insert_all(rows, alter=True)
if tracker is not None:
db[file_table].transform(types=tracker.types)
db.table(file_table).transform(types=tracker.types)
# Add convenient t / t1 / t2 views
view_names = ["t{}".format(i + 1)]
if i == 0:
Expand Down Expand Up @@ -2119,7 +2123,8 @@ def _execute_query(
else:
headers = [c[0] for c in cursor.description]
if raw:
data = cursor.fetchone()[0]
row = cursor.fetchone() # type: ignore[union-attr]
data = row[0] if row else None
if isinstance(data, bytes):
sys.stdout.buffer.write(data)
else:
Expand Down Expand Up @@ -2200,7 +2205,7 @@ def search(
_register_db_for_cleanup(db)
_load_extensions(db, load_extension)
# Check table exists
table_obj = db[dbtable]
table_obj = db.table(dbtable)
if not table_obj.exists():
raise click.ClickException("Table '{}' does not exist".format(dbtable))
if not table_obj.detect_fts():
Expand Down Expand Up @@ -2612,10 +2617,10 @@ def transform(
kwargs["add_foreign_keys"] = add_foreign_keys

if sql:
for line in db[table].transform_sql(**kwargs):
for line in db.table(table).transform_sql(**kwargs):
click.echo(line)
else:
db[table].transform(**kwargs)
db.table(table).transform(**kwargs)


@cli.command()
Expand Down Expand Up @@ -2656,13 +2661,13 @@ def extract(
db = sqlite_utils.Database(path)
_register_db_for_cleanup(db)
_load_extensions(db, load_extension)
kwargs = dict(
kwargs: dict[str, Any] = dict(
columns=columns,
table=other_table,
fk_column=fk_column,
rename=dict(rename),
)
db[table].extract(**kwargs)
db.table(table).extract(**kwargs)


@cli.command(name="insert-files")
Expand Down Expand Up @@ -2803,7 +2808,7 @@ def _content_text(p):
_load_extensions(db, load_extension)
try:
with db.conn:
db[table].insert_all(
db.table(table).insert_all(
to_insert(),
pk=pks[0] if len(pks) == 1 else pks,
alter=alter,
Expand Down Expand Up @@ -3122,7 +3127,7 @@ def wrapped_fn(value):

fn = wrapped_fn
try:
db[table].convert(
db.table(table).convert(
columns,
fn,
where=where,
Expand Down Expand Up @@ -3212,7 +3217,7 @@ def add_geometry_column(
_load_extensions(db, load_extension)
db.init_spatialite()

if db[table].add_geometry_column(
if db.table(table).add_geometry_column(
column_name, geometry_type, srid, coord_dimension, not_null
):
click.echo(f"Added {geometry_type} column {column_name} to {table}")
Expand Down Expand Up @@ -3250,7 +3255,7 @@ def create_spatial_index(db_path, table, column_name, load_extension):
"You must add a geometry column before creating a spatial index"
)

db[table].create_spatial_index(column_name)
db.table(table).create_spatial_index(column_name)


@cli.command(name="plugins")
Expand Down Expand Up @@ -3361,7 +3366,10 @@ def _load_extensions(db, load_extension):
db.conn.enable_load_extension(True)
for ext in load_extension:
if ext == "spatialite" and not os.path.exists(ext):
ext = find_spatialite()
found = find_spatialite()
if found is None:
raise click.ClickException("Could not find SpatiaLite extension")
ext = found
if ":" in ext:
path, _, entrypoint = ext.partition(":")
db.conn.execute("SELECT load_extension(?, ?)", [path, entrypoint])
Expand Down
Loading
Loading