diff --git a/docs/changelog.rst b/docs/changelog.rst
index b47c06a35..cd694d619 100644
--- a/docs/changelog.rst
+++ b/docs/changelog.rst
@@ -4,6 +4,13 @@
 Changelog
 ===========
 
+.. _unreleased:
+
+Unreleased
+----------
+
+- The ``table.insert_all()`` and ``table.upsert_all()`` methods can now accept an iterator of lists or tuples as an alternative to dictionaries. The first item should be a list/tuple of column names. See :ref:`python_api_insert_lists` for details. (:issue:`672`)
+
 .. _v4_0a0:
 
 4.0a0 (2025-05-08)
diff --git a/docs/python-api.rst b/docs/python-api.rst
index 47cb30b19..bb2b3eeba 100644
--- a/docs/python-api.rst
+++ b/docs/python-api.rst
@@ -810,6 +810,35 @@ You can delete all the existing rows in the table before inserting the new recor
 
 Pass ``analyze=True`` to run ``ANALYZE`` against the table after inserting the new records.
 
+.. _python_api_insert_lists:
+
+Inserting data from a list or tuple iterator
+--------------------------------------------
+
+As an alternative to passing an iterator of dictionaries, you can pass an iterator of lists or tuples. The first item yielded by the iterator must be a list or tuple of string column names; subsequent items should be lists or tuples of values:
+
+.. code-block:: python
+
+    db["creatures"].insert_all([
+        ["name", "species"],
+        ["Cleo", "dog"],
+        ["Lila", "chicken"],
+        ["Bants", "chicken"],
+    ])
+
+This also works with generators:
+
+.. code-block:: python
+
+    def creatures():
+        yield "id", "name", "city"
+        yield 1, "Cleo", "San Francisco"
+        yield 2, "Lila", "Los Angeles"
+
+    db["creatures"].insert_all(creatures())
+
+Tuples and lists are both supported, and the two can be mixed freely within the same iterator.
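+
+Rows shorter than the column list are padded with ``None`` values, while rows longer than the column list are truncated to the declared columns:
+
+.. code-block:: python
+
+    db["creatures"].insert_all([
+        ["name", "species"],
+        ["Cleo", "dog"],
+        ["Lila"],  # species will be stored as None
+    ])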
+
 .. _python_api_insert_replace:
 
 Insert-replacing data
diff --git a/sqlite_utils/db.py b/sqlite_utils/db.py
index 2be7a6d4d..dd83b4c90 100644
--- a/sqlite_utils/db.py
+++ b/sqlite_utils/db.py
@@ -31,6 +31,7 @@
     Dict,
     Generator,
     Iterable,
+    Sequence,
     Union,
     Optional,
     List,
@@ -3010,6 +3011,7 @@ def build_insert_queries_and_params(
         num_records_processed,
         replace,
         ignore,
+        list_mode=False,
     ):
         """
         Given a list ``chunk`` of records that should be written to *this* table,
@@ -3024,24 +3026,47 @@
         # Build a row-list ready for executemany-style flattening
         values = []
-        for record in chunk:
-            record_values = []
-            for key in all_columns:
-                value = jsonify_if_needed(
-                    record.get(
-                        key,
-                        (
-                            None
-                            if key != hash_id
-                            else hash_record(record, hash_id_columns)
-                        ),
-                    )
-                )
-                if key in extracts:
-                    extract_table = extracts[key]
-                    value = self.db[extract_table].lookup({"value": value})
-                record_values.append(value)
-            values.append(record_values)
+        if list_mode:
+            # In list mode, records are already lists of values
+            num_columns = len(all_columns)
+            has_extracts = bool(extracts)
+            for record in chunk:
+                # Pad short records with None, truncate long ones
+                record_len = len(record)
+                if record_len < num_columns:
+                    record_values = [jsonify_if_needed(v) for v in record] + [None] * (
+                        num_columns - record_len
+                    )
+                else:
+                    record_values = [jsonify_if_needed(v) for v in record[:num_columns]]
+                # Only process extracts if there are any
+                if has_extracts:
+                    for i, key in enumerate(all_columns):
+                        if key in extracts:
+                            record_values[i] = self.db[extracts[key]].lookup(
+                                {"value": record_values[i]}
+                            )
+                values.append(record_values)
+        else:
+            # Dict mode: original logic
+            for record in chunk:
+                record_values = []
+                for key in all_columns:
+                    value = jsonify_if_needed(
+                        record.get(
+                            key,
+                            (
+                                None
+                                if key != hash_id
+                                else hash_record(record, hash_id_columns)
+                            ),
+                        )
+                    )
+                    if key in extracts:
+                        extract_table = extracts[key]
+                        value = self.db[extract_table].lookup({"value": value})
+                    record_values.append(value)
+                values.append(record_values)
 
         columns_sql = ", ".join(f"[{c}]" for c in all_columns)
         placeholder_expr = ", ".join(conversions.get(c, "?") for c in all_columns)
@@ -3157,6 +3182,7 @@ def insert_chunk(
         num_records_processed,
         replace,
         ignore,
+        list_mode=False,
     ) -> Optional[sqlite3.Cursor]:
         queries_and_params = self.build_insert_queries_and_params(
             extracts,
@@ -3171,6 +3197,7 @@
             num_records_processed,
             replace,
             ignore,
+            list_mode,
         )
         result = None
         with self.db.conn:
@@ -3200,6 +3227,7 @@
                         num_records_processed,
                         replace,
                         ignore,
+                        list_mode,
                     )
                     result = self.insert_chunk(
@@ -3216,6 +3244,7 @@
                         num_records_processed,
                         replace,
                         ignore,
+                        list_mode,
                     )
 
                 else:
@@ -3293,7 +3322,10 @@ def insert(
 
     def insert_all(
         self,
-        records,
+        records: Union[
+            Iterable[Dict[str, Any]],
+            Iterable[Sequence[Any]],
+        ],
         pk=DEFAULT,
         foreign_keys=DEFAULT,
         column_order=DEFAULT,
@@ -3353,17 +3385,54 @@
         all_columns = []
         first = True
         num_records_processed = 0
-        # Fix up any records with square braces in the column names
-        records = fix_square_braces(records)
-        # We can only handle a max of 999 variables in a SQL insert, so
-        # we need to adjust the batch_size down if we have too many cols
-        records = iter(records)
-        # Peek at first record to count its columns:
+
+        # Detect if we're using list-based iteration or dict-based iteration
+        list_mode = False
+        column_names: List[str] = []
+
+        # Fix up any records with square braces in the column names (dict mode
+        # only; list mode records are positional, so no fix-up is needed)
+        records_iter = iter(records)
+
+        # Peek at first record to determine mode:
         try:
-            first_record = next(records)
+            first_record = next(records_iter)
         except StopIteration:
             return self  # It was an empty list
-        num_columns = len(first_record.keys())
+
+        # Check if this is list mode or dict mode
+        if isinstance(first_record, (list, tuple)):
+            # List/tuple mode: first record should be column names
+            list_mode = True
+            if not all(isinstance(col, str) for col in first_record):
+                raise ValueError(
+                    "When using list-based iteration, the first yielded value must be a list or tuple of column name strings"
+                )
+            column_names = list(first_record)
+            all_columns = column_names
+            num_columns = len(column_names)
+            # Get the actual first data record
+            try:
+                first_record = next(records_iter)
+            except StopIteration:
+                return self  # Only headers, no data
+            if not isinstance(first_record, (list, tuple)):
+                raise ValueError(
+                    "After the column names, subsequent records must also be lists or tuples"
+                )
+        else:
+            # Dict mode: traditional behavior
+            records_iter = itertools.chain([first_record], records_iter)
+            records_iter = fix_square_braces(
+                cast(Iterable[Dict[str, Any]], records_iter)
+            )
+            try:
+                first_record = next(records_iter)
+            except StopIteration:
+                return self
+            first_record = cast(Dict[str, Any], first_record)
+            num_columns = len(first_record.keys())
+
         assert (
             num_columns <= SQLITE_MAX_VARS
         ), "Rows can have a maximum of {} columns".format(SQLITE_MAX_VARS)
@@ -3373,13 +3442,18 @@
         if truncate and self.exists():
             self.db.execute("DELETE FROM [{}];".format(self.name))
         result = None
-        for chunk in chunks(itertools.chain([first_record], records), batch_size):
+        for chunk in chunks(itertools.chain([first_record], records_iter), batch_size):
             chunk = list(chunk)
             num_records_processed += len(chunk)
             if first:
                 if not self.exists():
                     # Use the first batch to derive the table names
-                    column_types = suggest_column_types(chunk)
+                    if list_mode:
+                        # Convert list records to dicts for type detection
+                        chunk_as_dicts = [dict(zip(column_names, row)) for row in chunk]
+                        column_types = suggest_column_types(chunk_as_dicts)
+                    else:
+                        column_types = suggest_column_types(chunk)
                     if extracts:
                         for col in extracts:
                             if col in column_types:
@@ -3399,17 +3473,24 @@
                         extracts=extracts,
                         strict=strict,
                     )
-                all_columns_set = set()
-                for record in chunk:
-                    all_columns_set.update(record.keys())
-                all_columns = list(sorted(all_columns_set))
-                if hash_id:
-                    all_columns.insert(0, hash_id)
+                if list_mode:
+                    # In list mode, columns are already known
+                    all_columns = list(column_names)
+                    if hash_id:
+                        all_columns.insert(0, hash_id)
+                else:
+                    all_columns_set = set()
+                    for record in chunk:
+                        all_columns_set.update(record.keys())
+                    all_columns = list(sorted(all_columns_set))
+                    if hash_id:
+                        all_columns.insert(0, hash_id)
             else:
-                for record in chunk:
-                    all_columns += [
-                        column for column in record if column not in all_columns
-                    ]
+                if not list_mode:
+                    for record in chunk:
+                        all_columns += [
+                            column for column in record if column not in all_columns
+                        ]
 
             first = False
@@ -3427,6 +3508,7 @@
                 num_records_processed,
                 replace,
                 ignore,
+                list_mode,
             )
 
             # If we only handled a single row populate self.last_pk
@@ -3447,14 +3529,29 @@
                     self.last_pk = self.last_rowid
                 else:
                     # For an upsert use first_record from earlier
-                    if hash_id:
-                        self.last_pk = hash_record(first_record, hash_id_columns)
+                    if list_mode:
+                        # In list mode, look up pk value by column index
+                        first_record_list = cast(Sequence[Any], first_record)
+                        if hash_id:
+                            # hash_id not supported in list mode for last_pk
+                            pass
+                        elif isinstance(pk, str):
+                            pk_index = column_names.index(pk)
+                            self.last_pk = first_record_list[pk_index]
+                        else:
+                            self.last_pk = tuple(
+                                first_record_list[column_names.index(p)] for p in pk
+                            )
                     else:
-                        self.last_pk = (
-                            first_record[pk]
-                            if isinstance(pk, str)
-                            else tuple(first_record[p] for p in pk)
-                        )
+                        first_record_dict = cast(Dict[str, Any], first_record)
+                        if hash_id:
+                            self.last_pk = hash_record(first_record_dict, hash_id_columns)
+                        else:
+                            self.last_pk = (
+                                first_record_dict[pk]
+                                if isinstance(pk, str)
+                                else tuple(first_record_dict[p] for p in pk)
+                            )
         if analyze:
             self.analyze()
@@ -3501,7 +3598,10 @@ def upsert(
 
     def upsert_all(
         self,
-        records,
+        records: Union[
+            Iterable[Dict[str, Any]],
+            Iterable[Sequence[Any]],
+        ],
        pk=DEFAULT,
         foreign_keys=DEFAULT,
         column_order=DEFAULT,
diff --git a/tests/test_list_mode.py b/tests/test_list_mode.py
new file mode 100644
index 000000000..746c9c176
--- /dev/null
+++ b/tests/test_list_mode.py
@@ -0,0 +1,288 @@
+"""
+Tests for list-based iteration in insert_all and upsert_all
+"""
+
+import pytest
+from sqlite_utils import Database
+
+
+def test_insert_all_list_mode_basic():
+    """Test basic insert_all with list-based iteration"""
+    db = Database(memory=True)
+
+    def data_generator():
+        # First yield column names
+        yield ["id", "name", "age"]
+        # Then yield data rows
+        yield [1, "Alice", 30]
+        yield [2, "Bob", 25]
+        yield [3, "Charlie", 35]
+
+    db["people"].insert_all(data_generator())
+
+    rows = list(db["people"].rows)
+    assert len(rows) == 3
+    assert rows[0] == {"id": 1, "name": "Alice", "age": 30}
+    assert rows[1] == {"id": 2, "name": "Bob", "age": 25}
+    assert rows[2] == {"id": 3, "name": "Charlie", "age": 35}
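+
+
+def test_insert_all_empty_iterator():
+    """An iterator that yields nothing at all is a no-op; this was already the
+    behavior in dict mode, asserted here as a regression check for the new
+    mode-detection code"""
+    db = Database(memory=True)
+
+    result = db["people"].insert_all(iter([]))
+
+    assert result is not None
+    assert not db["people"].exists()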
+
+
+def test_insert_all_list_mode_with_pk():
+    """Test insert_all with list mode and primary key"""
+    db = Database(memory=True)
+
+    def data_generator():
+        yield ["id", "name", "score"]
+        yield [1, "Alice", 95]
+        yield [2, "Bob", 87]
+
+    db["scores"].insert_all(data_generator(), pk="id")
+
+    assert db["scores"].pks == ["id"]
+    rows = list(db["scores"].rows)
+    assert len(rows) == 2
+
+
+def test_upsert_all_list_mode():
+    """Test upsert_all with list-based iteration"""
+    db = Database(memory=True)
+
+    # Initial insert
+    def initial_data():
+        yield ["id", "name", "value"]
+        yield [1, "Alice", 100]
+        yield [2, "Bob", 200]
+
+    db["data"].insert_all(initial_data(), pk="id")
+
+    # Upsert with some updates and new records
+    def upsert_data():
+        yield ["id", "name", "value"]
+        yield [1, "Alice", 150]  # Update existing
+        yield [3, "Charlie", 300]  # Insert new
+
+    db["data"].upsert_all(upsert_data(), pk="id")
+
+    rows = list(db["data"].rows_where(order_by="id"))
+    assert len(rows) == 3
+    assert rows[0] == {"id": 1, "name": "Alice", "value": 150}
+    assert rows[1] == {"id": 2, "name": "Bob", "value": 200}
+    assert rows[2] == {"id": 3, "name": "Charlie", "value": 300}
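+
+
+def test_upsert_all_list_mode_compound_pk():
+    """Upsert in list mode with a compound primary key; per the list-mode
+    branch in insert_all, a single-record upsert is expected to populate
+    last_pk as a tuple of values looked up by column index"""
+    db = Database(memory=True)
+
+    def initial_data():
+        yield ["a", "b", "value"]
+        yield [1, 2, 100]
+
+    table = db["data"]
+    table.insert_all(initial_data(), pk=("a", "b"))
+
+    def upsert_data():
+        yield ["a", "b", "value"]
+        yield [1, 2, 150]  # Update the existing row
+
+    table.upsert_all(upsert_data(), pk=("a", "b"))
+
+    assert list(table.rows) == [{"a": 1, "b": 2, "value": 150}]
+    assert table.last_pk == (1, 2)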
+
+
+def test_list_mode_with_various_types():
+    """Test list mode with different data types"""
+    db = Database(memory=True)
+
+    def data_generator():
+        yield ["id", "name", "score", "active"]
+        yield [1, "Alice", 95.5, True]
+        yield [2, "Bob", 87.3, False]
+        yield [3, "Charlie", None, True]
+
+    db["mixed"].insert_all(data_generator())
+
+    rows = list(db["mixed"].rows)
+    assert len(rows) == 3
+    assert rows[0]["score"] == 95.5
+    assert rows[1]["active"] == 0  # SQLite stores booleans as integers
+    assert rows[2]["score"] is None
+
+
+def test_list_mode_error_non_string_columns():
+    """Test that non-string column names raise an error"""
+    db = Database(memory=True)
+
+    def bad_data():
+        yield [1, 2, 3]  # Non-string column names
+        yield ["a", "b", "c"]
+
+    with pytest.raises(
+        ValueError, match="must be a list or tuple of column name strings"
+    ):
+        db["bad"].insert_all(bad_data())
+
+
+def test_list_mode_error_mixed_types():
+    """Test that mixing list and dict records raises an error"""
+    db = Database(memory=True)
+
+    def bad_data():
+        yield ["id", "name"]
+        yield {"id": 1, "name": "Alice"}  # Should be a list, not a dict
+
+    with pytest.raises(ValueError, match="must also be lists"):
+        db["bad"].insert_all(bad_data())
+
+
+def test_list_mode_empty_after_headers():
+    """Test that only headers without data works gracefully"""
+    db = Database(memory=True)
+
+    def data_generator():
+        yield ["id", "name", "age"]
+        # No data rows
+
+    result = db["people"].insert_all(data_generator())
+    assert result is not None
+    assert not db["people"].exists()
+
+
+def test_list_mode_batch_processing():
+    """Test list mode with a dataset large enough to require batching"""
+    db = Database(memory=True)
+
+    def large_data():
+        yield ["id", "value"]
+        for i in range(1000):
+            yield [i, f"value_{i}"]
+
+    db["large"].insert_all(large_data(), batch_size=100)
+
+    count = db.execute("SELECT COUNT(*) as c FROM large").fetchone()[0]
+    assert count == 1000
+
+
+def test_list_mode_shorter_rows():
+    """Test that rows shorter than the column list get NULL values"""
+    db = Database(memory=True)
+
+    def data_generator():
+        yield ["id", "name", "age", "city"]
+        yield [1, "Alice", 30, "NYC"]
+        yield [2, "Bob"]  # Missing age and city
+        yield [3, "Charlie", 35]  # Missing city
+
+    db["people"].insert_all(data_generator())
+
+    rows = list(db["people"].rows_where(order_by="id"))
+    assert rows[0] == {"id": 1, "name": "Alice", "age": 30, "city": "NYC"}
+    assert rows[1] == {"id": 2, "name": "Bob", "age": None, "city": None}
+    assert rows[2] == {"id": 3, "name": "Charlie", "age": 35, "city": None}
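+
+
+def test_list_mode_longer_rows():
+    """Rows longer than the column list are truncated; this mirrors the
+    record[:num_columns] slice in build_insert_queries_and_params"""
+    db = Database(memory=True)
+
+    def data_generator():
+        yield ["id", "name"]
+        yield [1, "Alice", "extra", "ignored"]  # Trailing values are dropped
+        yield [2, "Bob"]
+
+    db["people"].insert_all(data_generator())
+
+    rows = list(db["people"].rows_where(order_by="id"))
+    assert rows[0] == {"id": 1, "name": "Alice"}
+    assert rows[1] == {"id": 2, "name": "Bob"}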
"Bob", "age": None, "city": None} + assert rows[2] == {"id": 3, "name": "Charlie", "age": 35, "city": None} + + +def test_backwards_compatibility_dict_mode(): + """Ensure dict mode still works (backward compatibility)""" + db = Database(memory=True) + + # Traditional dict-based insert + data = [ + {"id": 1, "name": "Alice", "age": 30}, + {"id": 2, "name": "Bob", "age": 25}, + ] + + db["people"].insert_all(data) + + rows = list(db["people"].rows) + assert len(rows) == 2 + assert rows[0] == {"id": 1, "name": "Alice", "age": 30} + + +def test_insert_all_tuple_mode_basic(): + """Test basic insert_all with tuple-based iteration""" + db = Database(memory=True) + + def data_generator(): + # First yield column names as tuple + yield ("id", "name", "age") + # Then yield data rows as tuples + yield (1, "Alice", 30) + yield (2, "Bob", 25) + yield (3, "Charlie", 35) + + db["people"].insert_all(data_generator()) + + rows = list(db["people"].rows) + assert len(rows) == 3 + assert rows[0] == {"id": 1, "name": "Alice", "age": 30} + assert rows[1] == {"id": 2, "name": "Bob", "age": 25} + assert rows[2] == {"id": 3, "name": "Charlie", "age": 35} + + +def test_insert_all_mixed_list_tuple(): + """Test insert_all with mixed lists and tuples for data rows""" + db = Database(memory=True) + + def data_generator(): + # Column names as list + yield ["id", "name", "age"] + # Mix of list and tuple data rows + yield [1, "Alice", 30] + yield (2, "Bob", 25) + yield [3, "Charlie", 35] + yield (4, "Diana", 40) + + db["people"].insert_all(data_generator()) + + rows = list(db["people"].rows) + assert len(rows) == 4 + assert rows[0] == {"id": 1, "name": "Alice", "age": 30} + assert rows[1] == {"id": 2, "name": "Bob", "age": 25} + assert rows[2] == {"id": 3, "name": "Charlie", "age": 35} + assert rows[3] == {"id": 4, "name": "Diana", "age": 40} + + +def test_upsert_all_tuple_mode(): + """Test upsert_all with tuple-based iteration""" + db = Database(memory=True) + + # Initial insert with tuples + def initial_data(): + yield ("id", "name", "value") + yield (1, "Alice", 100) + yield (2, "Bob", 200) + + db["data"].insert_all(initial_data(), pk="id") + + # Upsert with tuples + def upsert_data(): + yield ("id", "name", "value") + yield (1, "Alice", 150) # Update existing + yield (3, "Charlie", 300) # Insert new + + db["data"].upsert_all(upsert_data(), pk="id") + + rows = list(db["data"].rows_where(order_by="id")) + assert len(rows) == 3 + assert rows[0] == {"id": 1, "name": "Alice", "value": 150} + assert rows[1] == {"id": 2, "name": "Bob", "value": 200} + assert rows[2] == {"id": 3, "name": "Charlie", "value": 300} + + +def test_tuple_mode_shorter_rows(): + """Test that tuple rows shorter than column list get NULL values""" + db = Database(memory=True) + + def data_generator(): + yield "id", "name", "age", "city" + yield 1, "Alice", 30, "NYC" + yield 2, "Bob" # Missing age and city + yield 3, "Charlie", 35 # Missing city + + db["people"].insert_all(data_generator()) + + rows = list(db["people"].rows_where(order_by="id")) + assert rows[0] == {"id": 1, "name": "Alice", "age": 30, "city": "NYC"} + assert rows[1] == {"id": 2, "name": "Bob", "age": None, "city": None} + assert rows[2] == {"id": 3, "name": "Charlie", "age": 35, "city": None} + + +def test_list_mode_single_record_upsert_last_pk(): + """Test that last_pk is populated correctly for single-record upserts in list mode""" + db = Database(memory=True) + + # Create table first + db["data"].insert({"id": 1, "name": "Alice", "value": 100}, pk="id") + + # Now upsert a single 
+
+
+def test_list_mode_single_record_upsert_last_pk():
+    """Test that last_pk is populated correctly for single-record upserts in list mode"""
+    db = Database(memory=True)
+
+    # Create table first
+    db["data"].insert({"id": 1, "name": "Alice", "value": 100}, pk="id")
+
+    # Now upsert a single record using list mode
+    def upsert_data():
+        yield ["id", "name", "value"]
+        yield [1, "Alice", 150]  # Update existing
+
+    table = db["data"]
+    table.upsert_all(upsert_data(), pk="id")
+
+    # Verify the data was updated
+    rows = list(db["data"].rows)
+    assert rows == [{"id": 1, "name": "Alice", "value": 150}]
+
+    # Verify last_pk is populated correctly
+    assert table.last_pk == 1
diff --git a/tests/test_recipes.py b/tests/test_recipes.py
index ff0422531..a7c7ef7e6 100644
--- a/tests/test_recipes.py
+++ b/tests/test_recipes.py
@@ -62,27 +62,37 @@ def test_dayfirst_yearfirst(fresh_db, recipe, kwargs, expected):
 ]
 
 
-@pytest.mark.parametrize("fn", ("parsedate", "parsedatetime"))
-@pytest.mark.parametrize("errors", (None, recipes.SET_NULL, recipes.IGNORE))
 @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning")
-def test_dateparse_errors(fresh_db, fn, errors):
+@pytest.mark.parametrize("fn", ("parsedate", "parsedatetime"))
+def test_dateparse_errors_raises(fresh_db, fn):
+    """Test that invalid dates raise errors when errors=None"""
     fresh_db["example"].insert_all(
         [
             {"id": 1, "dt": "invalid"},
         ],
         pk="id",
     )
-    if errors is None:
-        # Should raise an error
-        with pytest.raises(sqlite3.OperationalError):
-            fresh_db["example"].convert("dt", lambda value: getattr(recipes, fn)(value))
-    else:
-        fresh_db["example"].convert(
-            "dt", lambda value: getattr(recipes, fn)(value, errors=errors)
-        )
-        rows = list(fresh_db["example"].rows)
-        expected = [{"id": 1, "dt": None if errors is recipes.SET_NULL else "invalid"}]
-        assert rows == expected
+    # Exception in SQLite callback surfaces as OperationalError
+    with pytest.raises(sqlite3.OperationalError):
+        fresh_db["example"].convert("dt", lambda value: getattr(recipes, fn)(value))
+
+
+@pytest.mark.parametrize("fn", ("parsedate", "parsedatetime"))
+@pytest.mark.parametrize("errors", (recipes.SET_NULL, recipes.IGNORE))
+def test_dateparse_errors_handled(fresh_db, fn, errors):
+    """Test error handling modes for invalid dates"""
+    fresh_db["example"].insert_all(
+        [
+            {"id": 1, "dt": "invalid"},
+        ],
+        pk="id",
+    )
+    fresh_db["example"].convert(
+        "dt", lambda value: getattr(recipes, fn)(value, errors=errors)
+    )
+    rows = list(fresh_db["example"].rows)
+    expected = [{"id": 1, "dt": None if errors is recipes.SET_NULL else "invalid"}]
+    assert rows == expected
 
 
 @pytest.mark.parametrize("delimiter", [None, ";", "-"])