diff --git a/docs/python-api.rst b/docs/python-api.rst
index c6bf77620..2e81fa3f1 100644
--- a/docs/python-api.rst
+++ b/docs/python-api.rst
@@ -927,6 +927,13 @@ An ``upsert_all()`` method is also available, which behaves like ``insert_all()``
 
 .. note:: ``.upsert()`` and ``.upsert_all()`` in sqlite-utils 1.x worked like ``.insert(..., replace=True)`` and ``.insert_all(..., replace=True)`` do in 2.x. See `issue #66 <https://github.com/simonw/sqlite-utils/issues/66>`__ for details of this change.
 
+.. _python_api_old_upsert:
+
+Alternative upserts using INSERT OR IGNORE
+------------------------------------------
+
+Upserts use ``INSERT INTO ... ON CONFLICT DO UPDATE SET``. Prior to ``sqlite-utils 4.0`` these used a sequence of ``INSERT OR IGNORE`` followed by an ``UPDATE``. This older method is still used for SQLite 3.23.1 and earlier, since ``ON CONFLICT`` was only added in SQLite 3.24.0. You can force the older implementation by passing ``use_old_upsert=True`` to the ``Database()`` constructor.
+
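+For example, upserting the record ``{"id": 1, "name": "Cleo"}`` into a ``dogs``
+table with ``pk="id"`` runs SQL roughly equivalent to this (a sketch; the exact
+statement is built by the internal ``build_insert_queries_and_params()`` method)::
+
+    INSERT INTO [dogs] ([id], [name]) VALUES (?, ?)
+    ON CONFLICT([id]) DO UPDATE SET [name] = excluded.[name]
+
+The older implementation instead runs a pair of statements for each record::
+
+    INSERT OR IGNORE INTO [dogs] ([id]) VALUES (?);
+    UPDATE [dogs] SET [name] = ? WHERE [id] = ?
+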
 .. _python_api_convert:
 
 Converting data in columns
diff --git a/sqlite_utils/cli.py b/sqlite_utils/cli.py
index 795e30de0..a19ee4306 100644
--- a/sqlite_utils/cli.py
+++ b/sqlite_utils/cli.py
@@ -1098,7 +1098,9 @@ def insert_upsert_implementation(
         if (
             isinstance(e, OperationalError)
             and e.args
-            and "has no column named" in e.args[0]
+            and (
+                "has no column named" in e.args[0] or "no such column" in e.args[0]
+            )
         ):
             raise click.ClickException(
                 "{}\n\nTry using --alter to add additional columns".format(
diff --git a/sqlite_utils/db.py b/sqlite_utils/db.py
index 144330a1b..363b069d8 100644
--- a/sqlite_utils/db.py
+++ b/sqlite_utils/db.py
@@ -304,6 +304,8 @@ class Database:
       ``sql, parameters`` every time a SQL query is executed
     :param use_counts_table: set to ``True`` to use a cached counts table, if available. See
       :ref:`python_api_cached_table_counts`
+    :param use_old_upsert: set to ``True`` to force the older upsert implementation. See
+      :ref:`python_api_old_upsert`
     :param strict: Apply STRICT mode to all created tables (unless overridden)
     """
 
@@ -320,10 +322,12 @@ def __init__(
         tracer: Optional[Callable] = None,
         use_counts_table: bool = False,
         execute_plugins: bool = True,
+        use_old_upsert: bool = False,
         strict: bool = False,
     ):
         self.memory_name = None
         self.memory = False
+        self.use_old_upsert = use_old_upsert
         assert (filename_or_conn is not None and (not memory and not memory_name)) or (
             filename_or_conn is None and (memory or memory_name)
         ), "Either specify a filename_or_conn or pass memory=True"
@@ -671,16 +675,46 @@ def schema(self) -> str:
     @property
     def supports_strict(self) -> bool:
         "Does this database support STRICT mode?"
-        try:
+        if not hasattr(self, "_supports_strict"):
+            try:
+                table_name = "t{}".format(secrets.token_hex(16))
+                with self.conn:
+                    self.conn.execute(
+                        "create table {} (name text) strict".format(table_name)
+                    )
+                    self.conn.execute("drop table {}".format(table_name))
+                self._supports_strict = True
+            except Exception:
+                self._supports_strict = False
+        return self._supports_strict
+
+    @property
+    def supports_on_conflict(self) -> bool:
+        # SQLite's upsert is implemented as INSERT INTO ... ON CONFLICT DO ...
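+        # A sketch of the probe below: create a throwaway table, insert a row,
+        # then attempt "insert ... on conflict do update set ...". If that
+        # second insert raises, this SQLite build predates ON CONFLICT support
+        # (added in SQLite 3.24.0) and the result is cached as False.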
+        if not hasattr(self, "_supports_on_conflict"):
             table_name = "t{}".format(secrets.token_hex(16))
-            with self.conn:
-                self.conn.execute(
-                    "create table {} (name text) strict".format(table_name)
-                )
-                self.conn.execute("drop table {}".format(table_name))
-            return True
-        except Exception:
-            return False
+            try:
+                with self.conn:
+                    self.conn.execute(
+                        "create table {} (id integer primary key, name text)".format(
+                            table_name
+                        )
+                    )
+                    self.conn.execute(
+                        "insert into {} (id, name) values (1, 'one')".format(table_name)
+                    )
+                    self.conn.execute(
+                        (
+                            "insert into {} (id, name) values (1, 'two') "
+                            "on conflict do update set name = 'two'"
+                        ).format(table_name)
+                    )
+                self._supports_on_conflict = True
+            except Exception:
+                self._supports_on_conflict = False
+            finally:
+                self.conn.execute("drop table if exists {}".format(table_name))
+        return self._supports_on_conflict
 
     @property
     def sqlite_version(self) -> Tuple[int, ...]:
@@ -2966,13 +3000,19 @@ def build_insert_queries_and_params(
         replace,
         ignore,
     ):
-        # values is the list of insert data that is passed to the
-        # .execute() method - but some of them may be replaced by
-        # new primary keys if we are extracting any columns.
-        values = []
+        """
+        Given a list ``chunk`` of records that should be written to *this* table,
+        return a list of ``(sql, parameters)`` 2-tuples which, when executed in
+        order, perform the desired INSERT / UPSERT / REPLACE operation.
+        """
         if hash_id_columns and hash_id is None:
             hash_id = "id"
 
+        extracts = resolve_extracts(extracts)
+
+        # Build a row-list ready for executemany-style flattening
+        values = []
+
         for record in chunk:
             record_values = []
             for key in all_columns:
@@ -2992,76 +3032,103 @@ def build_insert_queries_and_params(
                 record_values.append(value)
             values.append(record_values)
 
-        queries_and_params = []
-        if upsert:
-            if isinstance(pk, str):
-                pks = [pk]
-            else:
-                pks = pk
-            self.last_pk = None
-            for record_values in values:
-                record = dict(zip(all_columns, record_values))
-                placeholders = list(pks)
-                # Need to populate not-null columns too, or INSERT OR IGNORE ignores
-                # them since it ignores the resulting integrity errors
-                if not_null:
-                    placeholders.extend(not_null)
-                sql = "INSERT OR IGNORE INTO [{table}]({cols}) VALUES({placeholders});".format(
-                    table=self.name,
-                    cols=", ".join(["[{}]".format(p) for p in placeholders]),
-                    placeholders=", ".join(["?" for p in placeholders]),
-                )
-                queries_and_params.append(
-                    (sql, [record[col] for col in pks] + ["" for _ in (not_null or [])])
-                )
-                # UPDATE [book] SET [name] = 'Programming' WHERE [id] = 1001;
-                set_cols = [col for col in all_columns if col not in pks]
-                if set_cols:
-                    sql2 = "UPDATE [{table}] SET {pairs} WHERE {wheres}".format(
-                        table=self.name,
-                        pairs=", ".join(
-                            "[{}] = {}".format(col, conversions.get(col, "?"))
-                            for col in set_cols
-                        ),
-                        wheres=" AND ".join("[{}] = ?".format(pk) for pk in pks),
-                    )
-                    queries_and_params.append(
-                        (
-                            sql2,
-                            [record[col] for col in set_cols]
-                            + [record[pk] for pk in pks],
+        columns_sql = ", ".join(f"[{c}]" for c in all_columns)
+        placeholder_expr = ", ".join(conversions.get(c, "?") for c in all_columns)
+        row_placeholders_sql = ", ".join(f"({placeholder_expr})" for _ in values)
+        flat_params = list(itertools.chain.from_iterable(values))
+
+        # replace=True means INSERT OR REPLACE INTO
+        if replace:
+            sql = (
+                f"INSERT OR REPLACE INTO [{self.name}] "
+                f"({columns_sql}) VALUES {row_placeholders_sql}"
+            )
+            return [(sql, flat_params)]
+
+        # If not an upsert, it's an INSERT, maybe with OR IGNORE
+        if not upsert:
+            or_ignore = ""
+            if ignore:
+                or_ignore = " OR IGNORE"
+            sql = (
+                f"INSERT{or_ignore} INTO [{self.name}] "
+                f"({columns_sql}) VALUES {row_placeholders_sql}"
+            )
+            return [(sql, flat_params)]
+
+        # Everything from here on is for upsert=True
+        pk_cols = [pk] if isinstance(pk, str) else list(pk)
+        non_pk_cols = [c for c in all_columns if c not in pk_cols]
+        conflict_sql = ", ".join(f"[{c}]" for c in pk_cols)
+
+        if self.db.supports_on_conflict and not self.db.use_old_upsert:
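+            # In the DO UPDATE clause built below, "excluded" is SQLite's
+            # built-in alias for the conflicting new row, so each non-pk
+            # column is set to the incoming value, e.g. [name] = excluded.[name]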
+            if non_pk_cols:
+                # DO UPDATE
+                assignments = []
+                for c in non_pk_cols:
+                    if c in conversions:
+                        assignments.append(
+                            f"[{c}] = {conversions[c].replace('?', f'excluded.[{c}]')}"
                         )
-                    )
-                # We can populate .last_pk right here
-                if num_records_processed == 1:
-                    self.last_pk = tuple(record[pk] for pk in pks)
-                    if len(self.last_pk) == 1:
-                        self.last_pk = self.last_pk[0]
+                    else:
+                        assignments.append(f"[{c}] = excluded.[{c}]")
+                do_clause = "DO UPDATE SET " + ", ".join(assignments)
+            else:
+                # All columns are in the PK – nothing to update.
+                do_clause = "DO NOTHING"
+
+            sql = (
+                f"INSERT INTO [{self.name}] ({columns_sql}) "
+                f"VALUES {row_placeholders_sql} "
+                f"ON CONFLICT({conflict_sql}) {do_clause}"
+            )
+            return [(sql, flat_params)]
+
+        # At this point we need the compatibility UPSERT for SQLite < 3.24.0
+        # (INSERT OR IGNORE + second UPDATE stage)
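+        # Unlike the single-statement paths above, this fallback emits a pair
+        # of queries (INSERT OR IGNORE, then UPDATE) for every record in chunk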
+        queries_and_params = []
+        if isinstance(pk, str):
+            pks = [pk]
         else:
-            or_what = ""
-            if replace:
-                or_what = "OR REPLACE "
-            elif ignore:
-                or_what = "OR IGNORE "
-            sql = """
-                INSERT {or_what}INTO [{table}] ({columns}) VALUES {rows};
-            """.strip().format(
-                or_what=or_what,
+            pks = pk
+        self.last_pk = None
+        for record_values in values:
+            record = dict(zip(all_columns, record_values))
+            placeholders = list(pks)
+            # Need to populate not-null columns too, or INSERT OR IGNORE ignores
+            # them since it ignores the resulting integrity errors
+            if not_null:
+                placeholders.extend(not_null)
+            sql = "INSERT OR IGNORE INTO [{table}]({cols}) VALUES({placeholders});".format(
                 table=self.name,
-                columns=", ".join("[{}]".format(c) for c in all_columns),
-                rows=", ".join(
-                    "({placeholders})".format(
-                        placeholders=", ".join(
-                            [conversions.get(col, "?") for col in all_columns]
-                        )
-                    )
-                    for record in chunk
-                ),
+                cols=", ".join(["[{}]".format(p) for p in placeholders]),
+                placeholders=", ".join(["?" for p in placeholders]),
             )
-            flat_values = list(itertools.chain(*values))
-            queries_and_params = [(sql, flat_values)]
-
+            queries_and_params.append(
+                (sql, [record[col] for col in pks] + ["" for _ in (not_null or [])])
+            )
+            # UPDATE [book] SET [name] = 'Programming' WHERE [id] = 1001;
+            set_cols = [col for col in all_columns if col not in pks]
+            if set_cols:
+                sql2 = "UPDATE [{table}] SET {pairs} WHERE {wheres}".format(
+                    table=self.name,
+                    pairs=", ".join(
+                        "[{}] = {}".format(col, conversions.get(col, "?"))
+                        for col in set_cols
+                    ),
+                    wheres=" AND ".join("[{}] = ?".format(pk) for pk in pks),
+                )
+                queries_and_params.append(
+                    (
+                        sql2,
+                        [record[col] for col in set_cols] + [record[pk] for pk in pks],
+                    )
+                )
+            # We can populate .last_pk right here
+            if num_records_processed == 1:
+                self.last_pk = tuple(record[pk] for pk in pks)
+                if len(self.last_pk) == 1:
+                    self.last_pk = self.last_pk[0]
         return queries_and_params
 
     def insert_chunk(
         self,
@@ -3079,7 +3146,7 @@ def insert_chunk(
         num_records_processed,
         replace,
         ignore,
-    ):
+    ) -> Optional[sqlite3.Cursor]:
         queries_and_params = self.build_insert_queries_and_params(
             extracts,
             chunk,
@@ -3094,9 +3161,8 @@ def insert_chunk(
             replace,
             ignore,
         )
-
+        result = None
         with self.db.conn:
-            result = None
             for query, params in queries_and_params:
                 try:
                     result = self.db.execute(query, params)
@@ -3125,7 +3191,7 @@ def insert_chunk(
                         ignore,
                     )
 
-                    self.insert_chunk(
+                    result = self.insert_chunk(
                         alter,
                         extracts,
                         second_half,
@@ -3143,20 +3209,7 @@ def insert_chunk(
             else:
                 raise
 
-        if num_records_processed == 1 and not upsert:
-            self.last_rowid = result.lastrowid
-            self.last_pk = self.last_rowid
-            # self.last_rowid will be 0 if a "INSERT OR IGNORE" happened
-            if (hash_id or pk) and self.last_rowid:
-                row = list(self.rows_where("rowid = ?", [self.last_rowid]))[0]
-                if hash_id:
-                    self.last_pk = row[hash_id]
-                elif isinstance(pk, str):
-                    self.last_pk = row[pk]
-                else:
-                    self.last_pk = tuple(row[p] for p in pk)
-
-        return
+        return result
 
     def insert(
         self,
@@ -3276,6 +3329,7 @@ def insert_all(
         if upsert and (not pk and not hash_id):
             raise PrimaryKeyRequired("upsert() requires a pk")
+        assert not (hash_id and pk), "Use either pk= or hash_id="
         if hash_id_columns and (hash_id is None):
             hash_id = "id"
 
@@ -3307,6 +3361,7 @@ def insert_all(
         self.last_pk = None
         if truncate and self.exists():
             self.db.execute("DELETE FROM [{}];".format(self.name))
+        result = None
         for chunk in chunks(itertools.chain([first_record], records), batch_size):
             chunk = list(chunk)
             num_records_processed += len(chunk)
@@ -3314,6 +3369,12 @@ def insert_all(
             if not self.exists():
                 # Use the first batch to derive the table names
                 column_types = suggest_column_types(chunk)
+                if extracts:
+                    for col in extracts:
+                        if col in column_types:
+                            column_types[col] = (
+                                int  # This will be an integer foreign key
+                            )
                 column_types.update(columns or {})
                 self.create(
                     column_types,
@@ -3341,7 +3402,7 @@ def insert_all(
 
             first = False
 
-            self.insert_chunk(
+            result = self.insert_chunk(
                 alter,
                 extracts,
                 chunk,
@@ -3357,6 +3418,33 @@ def insert_all(
                 ignore,
             )
 
+        # If we only handled a single row, populate self.last_pk
+        if num_records_processed == 1:
+            # For an insert we need to use result.lastrowid
+            if not upsert and result is not None:
+                self.last_rowid = result.lastrowid
+                if (hash_id or pk) and self.last_rowid:
+                    # Set self.last_pk to the pk(s) for that rowid
+                    row = list(self.rows_where("rowid = ?", [self.last_rowid]))[0]
+                    if hash_id:
+                        self.last_pk = row[hash_id]
+                    elif isinstance(pk, str):
+                        self.last_pk = row[pk]
+                    else:
+                        self.last_pk = tuple(row[p] for p in pk)
+                else:
+                    self.last_pk = self.last_rowid
+            else:
+                # For an upsert use first_record from earlier
+                if hash_id:
+                    self.last_pk = hash_record(first_record, hash_id_columns)
+                else:
+                    self.last_pk = (
+                        first_record[pk]
+                        if isinstance(pk, str)
+                        else tuple(first_record[p] for p in pk)
+                    )
+
         if analyze:
             self.analyze()
 
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 4033af641..17e5ed6f6 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -1116,11 +1116,8 @@ def test_upsert_alter(db_path, tmpdir):
         cli.cli, ["upsert", db_path, "dogs", json_path, "--pk", "id"]
     )
     assert result.exit_code == 1
-    assert (
-        "Error: no such column: age\n\n"
-        "sql = UPDATE [dogs] SET [age] = ? WHERE [id] = ?\n"
-        "parameters = [5, 1]"
-    ) == result.output.strip()
+    # Could be one of two errors depending on the SQLite version
+    assert "Try using --alter to add additional columns" in result.output.strip()
     # Should succeed with --alter
     result = CliRunner().invoke(
         cli.cli, ["upsert", db_path, "dogs", json_path, "--pk", "id", "--alter"]
@@ -2248,7 +2245,7 @@ def test_integer_overflow_error(tmpdir):
     assert result.exit_code == 1
     assert result.output == (
         "Error: Python int too large to convert to SQLite INTEGER\n\n"
-        "sql = INSERT INTO [items] ([bignumber]) VALUES (?);\n"
+        "sql = INSERT INTO [items] ([bignumber]) VALUES (?)\n"
         "parameters = [34223049823094832094802398430298048240]\n"
     )
 
diff --git a/tests/test_create.py b/tests/test_create.py
index 7825198ac..a04c8b3d3 100644
--- a/tests/test_create.py
+++ b/tests/test_create.py
@@ -173,14 +173,16 @@ def test_create_table_from_example_with_compound_primary_keys(fresh_db):
 @pytest.mark.parametrize(
     "method_name", ("insert", "upsert", "insert_all", "upsert_all")
 )
-def test_create_table_with_custom_columns(fresh_db, method_name):
-    table = fresh_db["dogs"]
+@pytest.mark.parametrize("use_old_upsert", (False, True))
+def test_create_table_with_custom_columns(method_name, use_old_upsert):
+    db = Database(memory=True, use_old_upsert=use_old_upsert)
+    table = db["dogs"]
     method = getattr(table, method_name)
     record = {"id": 1, "name": "Cleo", "age": "5"}
     if method_name.endswith("_all"):
         record = [record]
     method(record, pk="id", columns={"age": int, "weight": float})
-    assert ["dogs"] == fresh_db.table_names()
+    assert ["dogs"] == db.table_names()
     expected_columns = [
         {"name": "id", "type": "INTEGER"},
         {"name": "name", "type": "TEXT"},
diff --git a/tests/test_tracer.py b/tests/test_tracer.py
index 26318ae06..9dfb490d6 100644
--- a/tests/test_tracer.py
+++ b/tests/test_tracer.py
@@ -18,7 +18,7 @@ def test_tracer():
         ("select name from sqlite_master where type = 'view'", None),
         ("CREATE TABLE [dogs] (\n   [name] TEXT\n);\n        ", None),
         ("select name from sqlite_master where type = 'view'", None),
-        ("INSERT INTO [dogs] ([name]) VALUES (?);", ["Cleopaws"]),
+        ("INSERT INTO [dogs] ([name]) VALUES (?)", ["Cleopaws"]),
         ("select name from sqlite_master where type = 'view'", None),
         (
             "CREATE VIRTUAL TABLE [dogs_fts] USING FTS5 (\n    [name],\n    content=[dogs]\n)",
diff --git a/tests/test_upsert.py b/tests/test_upsert.py
index 8a34d4ebe..de0e362b5 100644
--- a/tests/test_upsert.py
+++ b/tests/test_upsert.py
@@ -1,9 +1,12 @@
 from sqlite_utils.db import PrimaryKeyRequired
+from sqlite_utils import Database
 import pytest
 
 
-def test_upsert(fresh_db):
-    table = fresh_db["table"]
+@pytest.mark.parametrize("use_old_upsert", (False, True))
+def test_upsert(use_old_upsert):
+    db = Database(memory=True, use_old_upsert=use_old_upsert)
+    table = db["table"]
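+    # use_old_upsert=True exercises the pre-SQLite-3.24.0 fallback implementation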
     table.insert({"id": 1, "name": "Cleo"}, pk="id")
     table.upsert({"id": 1, "age": 5}, pk="id", alter=True)
     assert list(table.rows) == [{"id": 1, "name": "Cleo", "age": 5}]