From 7abfee903267e0c62f72df0fe25e6fe6aff572c4 Mon Sep 17 00:00:00 2001 From: Koen Vossen Date: Wed, 19 Mar 2025 13:49:31 +0100 Subject: [PATCH 01/19] Move actual implementation of upsert from Table to Transaction --- pyiceberg/table/__init__.py | 179 +++++++++++++++++++++++------------- 1 file changed, 113 insertions(+), 66 deletions(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 9e9de52dee..d57b6463af 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -695,6 +695,115 @@ def delete( if not delete_snapshot.files_affected and not delete_snapshot.rewrites_needed: warnings.warn("Delete operation did not match any records") + def upsert( + self, + df: pa.Table, + join_cols: Optional[List[str]] = None, + when_matched_update_all: bool = True, + when_not_matched_insert_all: bool = True, + case_sensitive: bool = True, + ) -> UpsertResult: + """Shorthand API for performing an upsert to an iceberg table. + + Args: + + df: The input dataframe to upsert with the table's data. + join_cols: Columns to join on, if not provided, it will use the identifier-field-ids. + when_matched_update_all: Bool indicating to update rows that are matched but require an update due to a value in a non-key column changing + when_not_matched_insert_all: Bool indicating new rows to be inserted that do not match any existing rows in the table + case_sensitive: Bool indicating if the match should be case-sensitive + + To learn more about the identifier-field-ids: https://iceberg.apache.org/spec/#identifier-field-ids + + Example Use Cases: + Case 1: Both Parameters = True (Full Upsert) + Existing row found → Update it + New row found → Insert it + + Case 2: when_matched_update_all = False, when_not_matched_insert_all = True + Existing row found → Do nothing (no updates) + New row found → Insert it + + Case 3: when_matched_update_all = True, when_not_matched_insert_all = False + Existing row found → Update it + New row found → Do nothing (no inserts) + + Case 4: Both Parameters = False (No Merge Effect) + Existing row found → Do nothing + New row found → Do nothing + (Function effectively does nothing) + + + Returns: + An UpsertResult class (contains details of rows updated and inserted) + """ + try: + import pyarrow as pa # noqa: F401 + except ModuleNotFoundError as e: + raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e + + from pyiceberg.io.pyarrow import expression_to_pyarrow + from pyiceberg.table import upsert_util + + if join_cols is None: + join_cols = [] + for field_id in df.schema.identifier_field_ids: + col = df.schema.find_column_name(field_id) + if col is not None: + join_cols.append(col) + else: + raise ValueError(f"Field-ID could not be found: {join_cols}") + + if len(join_cols) == 0: + raise ValueError("Join columns could not be found, please set identifier-field-ids or pass in explicitly.") + + if not when_matched_update_all and not when_not_matched_insert_all: + raise ValueError("no upsert options selected...exiting") + + if upsert_util.has_duplicate_rows(df, join_cols): + raise ValueError("Duplicate rows found in source dataset based on the key columns. 
No upsert executed") + + from pyiceberg.io.pyarrow import _check_pyarrow_schema_compatible + + downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False + _check_pyarrow_schema_compatible( + df.schema, provided_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us + ) + + # get list of rows that exist so we don't have to load the entire target table + matched_predicate = upsert_util.create_match_filter(df, join_cols) + matched_iceberg_table = df.scan(row_filter=matched_predicate, case_sensitive=case_sensitive).to_arrow() + + update_row_cnt = 0 + insert_row_cnt = 0 + + if when_matched_update_all: + # function get_rows_to_update is doing a check on non-key columns to see if any of the values have actually changed + # we don't want to do just a blanket overwrite for matched rows if the actual non-key column data hasn't changed + # this extra step avoids unnecessary IO and writes + rows_to_update = upsert_util.get_rows_to_update(df, matched_iceberg_table, join_cols) + + update_row_cnt = len(rows_to_update) + + if len(rows_to_update) > 0: + # build the match predicate filter + overwrite_mask_predicate = upsert_util.create_match_filter(rows_to_update, join_cols) + + self.overwrite(rows_to_update, overwrite_filter=overwrite_mask_predicate) + + if when_not_matched_insert_all: + expr_match = upsert_util.create_match_filter(matched_iceberg_table, join_cols) + expr_match_bound = bind(df.schema, expr_match, case_sensitive=case_sensitive) + expr_match_arrow = expression_to_pyarrow(expr_match_bound) + rows_to_insert = df.filter(~expr_match_arrow) + + insert_row_cnt = len(rows_to_insert) + + if insert_row_cnt > 0: + self.append(rows_to_insert) + + return UpsertResult(rows_updated=update_row_cnt, rows_inserted=insert_row_cnt) + def add_files( self, file_paths: List[str], snapshot_properties: Dict[str, str] = EMPTY_DICT, check_duplicate_files: bool = True ) -> None: @@ -1159,73 +1268,11 @@ def upsert( Returns: An UpsertResult class (contains details of rows updated and inserted) """ - try: - import pyarrow as pa # noqa: F401 - except ModuleNotFoundError as e: - raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e - - from pyiceberg.io.pyarrow import expression_to_pyarrow - from pyiceberg.table import upsert_util - - if join_cols is None: - join_cols = [] - for field_id in self.schema().identifier_field_ids: - col = self.schema().find_column_name(field_id) - if col is not None: - join_cols.append(col) - else: - raise ValueError(f"Field-ID could not be found: {join_cols}") - - if len(join_cols) == 0: - raise ValueError("Join columns could not be found, please set identifier-field-ids or pass in explicitly.") - - if not when_matched_update_all and not when_not_matched_insert_all: - raise ValueError("no upsert options selected...exiting") - - if upsert_util.has_duplicate_rows(df, join_cols): - raise ValueError("Duplicate rows found in source dataset based on the key columns. 
No upsert executed") - - from pyiceberg.io.pyarrow import _check_pyarrow_schema_compatible - - downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False - _check_pyarrow_schema_compatible( - self.schema(), provided_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us - ) - - # get list of rows that exist so we don't have to load the entire target table - matched_predicate = upsert_util.create_match_filter(df, join_cols) - matched_iceberg_table = self.scan(row_filter=matched_predicate, case_sensitive=case_sensitive).to_arrow() - - update_row_cnt = 0 - insert_row_cnt = 0 - with self.transaction() as tx: - if when_matched_update_all: - # function get_rows_to_update is doing a check on non-key columns to see if any of the values have actually changed - # we don't want to do just a blanket overwrite for matched rows if the actual non-key column data hasn't changed - # this extra step avoids unnecessary IO and writes - rows_to_update = upsert_util.get_rows_to_update(df, matched_iceberg_table, join_cols) - - update_row_cnt = len(rows_to_update) - - if len(rows_to_update) > 0: - # build the match predicate filter - overwrite_mask_predicate = upsert_util.create_match_filter(rows_to_update, join_cols) - - tx.overwrite(rows_to_update, overwrite_filter=overwrite_mask_predicate) - - if when_not_matched_insert_all: - expr_match = upsert_util.create_match_filter(matched_iceberg_table, join_cols) - expr_match_bound = bind(self.schema(), expr_match, case_sensitive=case_sensitive) - expr_match_arrow = expression_to_pyarrow(expr_match_bound) - rows_to_insert = df.filter(~expr_match_arrow) - - insert_row_cnt = len(rows_to_insert) - - if insert_row_cnt > 0: - tx.append(rows_to_insert) - - return UpsertResult(rows_updated=update_row_cnt, rows_inserted=insert_row_cnt) + return tx.upsert( + df=df, join_cols=join_cols, when_matched_update_all=when_matched_update_all, when_not_matched_insert_all=when_not_matched_insert_all, + case_sensitive=case_sensitive + ) def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) -> None: """ From db334ae4bfeefff1ab7373ea8c5e55f18fdace84 Mon Sep 17 00:00:00 2001 From: Koen Vossen Date: Wed, 19 Mar 2025 14:32:19 +0100 Subject: [PATCH 02/19] Fix some incorrect usage of schema --- pyiceberg/table/__init__.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index d57b6463af..99f2fee388 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -747,8 +747,8 @@ def upsert( if join_cols is None: join_cols = [] - for field_id in df.schema.identifier_field_ids: - col = df.schema.find_column_name(field_id) + for field_id in self.table_metadata.schema().identifier_field_ids: + col = self.table_metadata.schema().find_column_name(field_id) if col is not None: join_cols.append(col) else: @@ -767,12 +767,12 @@ def upsert( downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False _check_pyarrow_schema_compatible( - df.schema, provided_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us + self.table_metadata.schema(), provided_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us ) # get list of rows that exist so we don't have to load the entire target table matched_predicate = upsert_util.create_match_filter(df, join_cols) - matched_iceberg_table = df.scan(row_filter=matched_predicate, 
case_sensitive=case_sensitive).to_arrow() + matched_iceberg_table = self._table.scan(row_filter=matched_predicate, case_sensitive=case_sensitive).to_arrow() update_row_cnt = 0 insert_row_cnt = 0 @@ -793,7 +793,7 @@ def upsert( if when_not_matched_insert_all: expr_match = upsert_util.create_match_filter(matched_iceberg_table, join_cols) - expr_match_bound = bind(df.schema, expr_match, case_sensitive=case_sensitive) + expr_match_bound = bind(self.table_metadata.schema(), expr_match, case_sensitive=case_sensitive) expr_match_arrow = expression_to_pyarrow(expr_match_bound) rows_to_insert = df.filter(~expr_match_arrow) @@ -1270,8 +1270,11 @@ def upsert( """ with self.transaction() as tx: return tx.upsert( - df=df, join_cols=join_cols, when_matched_update_all=when_matched_update_all, when_not_matched_insert_all=when_not_matched_insert_all, - case_sensitive=case_sensitive + df=df, + join_cols=join_cols, + when_matched_update_all=when_matched_update_all, + when_not_matched_insert_all=when_not_matched_insert_all, + case_sensitive=case_sensitive, ) def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) -> None: From cebfda373efd0ea17460ff73f419062027d093da Mon Sep 17 00:00:00 2001 From: Koen Vossen Date: Tue, 25 Mar 2025 15:41:37 +0100 Subject: [PATCH 03/19] Write a test for upsert transaction --- tests/table/test_upsert.py | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/tests/table/test_upsert.py b/tests/table/test_upsert.py index 70203fd162..429c78091c 100644 --- a/tests/table/test_upsert.py +++ b/tests/table/test_upsert.py @@ -23,7 +23,7 @@ from pyiceberg.catalog import Catalog from pyiceberg.exceptions import NoSuchTableError -from pyiceberg.expressions import And, EqualTo, Reference +from pyiceberg.expressions import AlwaysTrue, And, EqualTo, Reference from pyiceberg.expressions.literals import LongLiteral from pyiceberg.io.pyarrow import schema_to_pyarrow from pyiceberg.schema import Schema @@ -709,3 +709,26 @@ def test_upsert_with_nulls(catalog: Catalog) -> None: ], schema=schema, ) + + +def test_transaction(catalog: Catalog) -> None: + """Test the upsert within a Transaction. Make sure that if something fails the entire Transaction is + rolled back.""" + identifier = "default.test_merge_source_dups" + _drop_table(catalog, identifier) + + ctx = SessionContext() + + table = gen_target_iceberg_table(1, 10, False, ctx, catalog, identifier) + df_before_transaction = table.scan().to_arrow() + + source_df = gen_source_dataset(5, 15, False, True, ctx) + + with pytest.raises(Exception, match="Duplicate rows found in source dataset based on the key columns. 
No upsert executed"): + with table.transaction() as tx: + tx.delete(delete_filter=AlwaysTrue()) + tx.upsert(df=source_df, join_cols=["order_id"]) + + df = table.scan().to_arrow() + + assert df_before_transaction == df From 52fd35eebb81dd351be5923bd6f2aeb8341d6671 Mon Sep 17 00:00:00 2001 From: Koen Vossen Date: Thu, 27 Mar 2025 21:14:46 +0100 Subject: [PATCH 04/19] Add failing test for multiple upserts in same transaction --- tests/table/test_upsert.py | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/tests/table/test_upsert.py b/tests/table/test_upsert.py index 429c78091c..553b1ef5b3 100644 --- a/tests/table/test_upsert.py +++ b/tests/table/test_upsert.py @@ -732,3 +732,39 @@ def test_transaction(catalog: Catalog) -> None: df = table.scan().to_arrow() assert df_before_transaction == df + + +def test_transaction_multiple_upserts(catalog: Catalog) -> None: + identifier = "default.test_multi_upsert" + _drop_table(catalog, identifier) + + schema = Schema( + NestedField(1, "id", IntegerType(), required=True), + NestedField(2, "name", StringType(), required=True), + identifier_field_ids=[1], + ) + + tbl = catalog.create_table(identifier, schema=schema) + + # Define exact schema: required int32 and required string + arrow_schema = pa.schema([ + pa.field("id", pa.int32(), nullable=False), + pa.field("name", pa.string(), nullable=False), + ]) + + tbl.append(pa.Table.from_pylist([{"id": 1, "name": "Alice"}], schema=arrow_schema)) + + df = pa.Table.from_pylist([{"id": 2, "name": "Bob"}, {"id": 1, "name": "Alicia"}], schema=arrow_schema) + + with tbl.transaction() as txn: + # This should read the uncommitted changes? + txn.upsert(df, join_cols=["id"]) + + txn.upsert(df, join_cols=["id"]) + + result = tbl.scan().to_arrow().to_pylist() + assert sorted(result, key=lambda x: x["id"]) == [ + {"id": 1, "name": "Alicia"}, + {"id": 2, "name": "Bob"}, + ] + From f336c0b6b92a2516f34b6f52cb332fda1602c6a7 Mon Sep 17 00:00:00 2001 From: Koen Vossen Date: Wed, 2 Apr 2025 22:02:32 +0200 Subject: [PATCH 05/19] Fix test --- tests/table/test_upsert.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/tests/table/test_upsert.py b/tests/table/test_upsert.py index 553b1ef5b3..2e65c97259 100644 --- a/tests/table/test_upsert.py +++ b/tests/table/test_upsert.py @@ -734,6 +734,7 @@ def test_transaction(catalog: Catalog) -> None: assert df_before_transaction == df +@pytest.mark.skip("This test is just for reference. Multiple upserts or delete+upsert doesn't work in a transaction") def test_transaction_multiple_upserts(catalog: Catalog) -> None: identifier = "default.test_multi_upsert" _drop_table(catalog, identifier) @@ -747,24 +748,28 @@ def test_transaction_multiple_upserts(catalog: Catalog) -> None: tbl = catalog.create_table(identifier, schema=schema) # Define exact schema: required int32 and required string - arrow_schema = pa.schema([ - pa.field("id", pa.int32(), nullable=False), - pa.field("name", pa.string(), nullable=False), - ]) + arrow_schema = pa.schema( + [ + pa.field("id", pa.int32(), nullable=False), + pa.field("name", pa.string(), nullable=False), + ] + ) tbl.append(pa.Table.from_pylist([{"id": 1, "name": "Alice"}], schema=arrow_schema)) df = pa.Table.from_pylist([{"id": 2, "name": "Bob"}, {"id": 1, "name": "Alicia"}], schema=arrow_schema) with tbl.transaction() as txn: + txn.append(df) + txn.delete(delete_filter="id = 1") + txn.append(df) # This should read the uncommitted changes? 
txn.upsert(df, join_cols=["id"]) - txn.upsert(df, join_cols=["id"]) + # txn.upsert(df, join_cols=["id"]) result = tbl.scan().to_arrow().to_pylist() assert sorted(result, key=lambda x: x["id"]) == [ {"id": 1, "name": "Alicia"}, {"id": 2, "name": "Bob"}, ] - From 07890ac4d73cc051dd92e27cff7981f3e134534d Mon Sep 17 00:00:00 2001 From: Koen Vossen Date: Tue, 13 May 2025 10:47:34 +0200 Subject: [PATCH 06/19] Add failing test --- tests/table/test_upsert.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/table/test_upsert.py b/tests/table/test_upsert.py index 2e65c97259..85315a81db 100644 --- a/tests/table/test_upsert.py +++ b/tests/table/test_upsert.py @@ -734,7 +734,7 @@ def test_transaction(catalog: Catalog) -> None: assert df_before_transaction == df -@pytest.mark.skip("This test is just for reference. Multiple upserts or delete+upsert doesn't work in a transaction") +# @pytest.mark.skip("This test is just for reference. Multiple upserts or delete+upsert doesn't work in a transaction") def test_transaction_multiple_upserts(catalog: Catalog) -> None: identifier = "default.test_multi_upsert" _drop_table(catalog, identifier) @@ -760,13 +760,12 @@ def test_transaction_multiple_upserts(catalog: Catalog) -> None: df = pa.Table.from_pylist([{"id": 2, "name": "Bob"}, {"id": 1, "name": "Alicia"}], schema=arrow_schema) with tbl.transaction() as txn: - txn.append(df) txn.delete(delete_filter="id = 1") txn.append(df) - # This should read the uncommitted changes? - txn.upsert(df, join_cols=["id"]) - # txn.upsert(df, join_cols=["id"]) + # This should read the uncommitted changes + # TODO: currently fails because it only reads {"id": 1, "name": "Alice"} + txn.upsert(df, join_cols=["id"]) result = tbl.scan().to_arrow().to_pylist() assert sorted(result, key=lambda x: x["id"]) == [ From ae0e60fa4ca725d42a8b1832cdbb4fec936ddc60 Mon Sep 17 00:00:00 2001 From: Koen Vossen Date: Tue, 13 May 2025 10:56:08 +0200 Subject: [PATCH 07/19] Use Transaction.table_metadata when doing the data scan in upsert --- pyiceberg/table/__init__.py | 9 ++++++++- tests/table/test_upsert.py | 1 - 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 99f2fee388..78676a774a 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -772,7 +772,14 @@ def upsert( # get list of rows that exist so we don't have to load the entire target table matched_predicate = upsert_util.create_match_filter(df, join_cols) - matched_iceberg_table = self._table.scan(row_filter=matched_predicate, case_sensitive=case_sensitive).to_arrow() + + # We must use Transaction.table_metadata for the scan. This includes all uncommitted - but relevant - changes. 
+ matched_iceberg_table = DataScan( + table_metadata=self.table_metadata, + io=self._table.io, + row_filter=matched_predicate, + case_sensitive=case_sensitive, + ).to_arrow() update_row_cnt = 0 insert_row_cnt = 0 diff --git a/tests/table/test_upsert.py b/tests/table/test_upsert.py index 85315a81db..10593ea62e 100644 --- a/tests/table/test_upsert.py +++ b/tests/table/test_upsert.py @@ -764,7 +764,6 @@ def test_transaction_multiple_upserts(catalog: Catalog) -> None: txn.append(df) # This should read the uncommitted changes - # TODO: currently fails because it only reads {"id": 1, "name": "Alice"} txn.upsert(df, join_cols=["id"]) result = tbl.scan().to_arrow().to_pylist() From ce8d9efc72110e29330355d84e011ba6707d5802 Mon Sep 17 00:00:00 2001 From: Koen Vossen Date: Tue, 13 May 2025 10:58:45 +0200 Subject: [PATCH 08/19] Remove as it's resolved --- tests/table/test_upsert.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/table/test_upsert.py b/tests/table/test_upsert.py index 10593ea62e..9fecbbb7bb 100644 --- a/tests/table/test_upsert.py +++ b/tests/table/test_upsert.py @@ -734,7 +734,6 @@ def test_transaction(catalog: Catalog) -> None: assert df_before_transaction == df -# @pytest.mark.skip("This test is just for reference. Multiple upserts or delete+upsert doesn't work in a transaction") def test_transaction_multiple_upserts(catalog: Catalog) -> None: identifier = "default.test_multi_upsert" _drop_table(catalog, identifier) From 5bdb0b8aa50b0185c8dcaacd259b4aa15fb84cd9 Mon Sep 17 00:00:00 2001 From: Koen Vossen Date: Tue, 13 May 2025 12:19:13 +0200 Subject: [PATCH 09/19] Use to_arrow_batch_reader instead of to_arrow in upsert to prevent OOM when updating large tables --- pyiceberg/table/__init__.py | 62 +++++++++++++++++++++++++------------ 1 file changed, 42 insertions(+), 20 deletions(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 78676a774a..5f38828024 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -16,7 +16,9 @@ # under the License. from __future__ import annotations +import functools import itertools +import operator import os import uuid import warnings @@ -774,39 +776,59 @@ def upsert( matched_predicate = upsert_util.create_match_filter(df, join_cols) # We must use Transaction.table_metadata for the scan. This includes all uncommitted - but relevant - changes. 
- matched_iceberg_table = DataScan( + matched_iceberg_record_batches = DataScan( table_metadata=self.table_metadata, io=self._table.io, row_filter=matched_predicate, case_sensitive=case_sensitive, - ).to_arrow() + ).to_arrow_batch_reader() - update_row_cnt = 0 - insert_row_cnt = 0 + batches_to_overwrite = [] + overwrite_predicates = [] + insert_filters = [] - if when_matched_update_all: - # function get_rows_to_update is doing a check on non-key columns to see if any of the values have actually changed - # we don't want to do just a blanket overwrite for matched rows if the actual non-key column data hasn't changed - # this extra step avoids unnecessary IO and writes - rows_to_update = upsert_util.get_rows_to_update(df, matched_iceberg_table, join_cols) + for batch in matched_iceberg_record_batches: + rows = pa.Table.from_batches([batch]) - update_row_cnt = len(rows_to_update) + if when_matched_update_all: + # function get_rows_to_update is doing a check on non-key columns to see if any of the values have actually changed + # we don't want to do just a blanket overwrite for matched rows if the actual non-key column data hasn't changed + # this extra step avoids unnecessary IO and writes + rows_to_update = upsert_util.get_rows_to_update(df, rows, join_cols) + + if len(rows_to_update) > 0: + # build the match predicate filter + overwrite_mask_predicate = upsert_util.create_match_filter(rows_to_update, join_cols) - if len(rows_to_update) > 0: - # build the match predicate filter - overwrite_mask_predicate = upsert_util.create_match_filter(rows_to_update, join_cols) + batches_to_overwrite.append(rows_to_update) + overwrite_predicates.append(overwrite_mask_predicate) - self.overwrite(rows_to_update, overwrite_filter=overwrite_mask_predicate) + if when_not_matched_insert_all: + expr_match = upsert_util.create_match_filter(rows, join_cols) + expr_match_bound = bind(self.table_metadata.schema(), expr_match, case_sensitive=case_sensitive) + expr_match_arrow = expression_to_pyarrow(expr_match_bound) + + insert_filters.append(~expr_match_arrow) + + update_row_cnt = 0 + insert_row_cnt = 0 + + if batches_to_overwrite: + rows_to_update = pa.concat_tables(batches_to_overwrite) + update_row_cnt = len(rows_to_update) + self.overwrite( + rows_to_update, + overwrite_filter=Or(*overwrite_predicates) if len(overwrite_predicates) > 1 else overwrite_predicates[0], + ) if when_not_matched_insert_all: - expr_match = upsert_util.create_match_filter(matched_iceberg_table, join_cols) - expr_match_bound = bind(self.table_metadata.schema(), expr_match, case_sensitive=case_sensitive) - expr_match_arrow = expression_to_pyarrow(expr_match_bound) - rows_to_insert = df.filter(~expr_match_arrow) + if insert_filters: + rows_to_insert = df.filter(functools.reduce(operator.and_, insert_filters)) + else: + rows_to_insert = df insert_row_cnt = len(rows_to_insert) - - if insert_row_cnt > 0: + if rows_to_insert: self.append(rows_to_insert) return UpsertResult(rows_updated=update_row_cnt, rows_inserted=insert_row_cnt) From 65fe36d10944516e40cce472e65234d9fd928657 Mon Sep 17 00:00:00 2001 From: Koen Vossen Date: Tue, 13 May 2025 13:02:41 +0200 Subject: [PATCH 10/19] Filter rows to insert on each iteration instead of keeping a list of all filter expressions. 
Prevents memory pressure due to large filters --- pyiceberg/table/__init__.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 5f38828024..ab41138fd2 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -16,9 +16,7 @@ # under the License. from __future__ import annotations -import functools import itertools -import operator import os import uuid import warnings @@ -785,7 +783,7 @@ def upsert( batches_to_overwrite = [] overwrite_predicates = [] - insert_filters = [] + rows_to_insert = df for batch in matched_iceberg_record_batches: rows = pa.Table.from_batches([batch]) @@ -808,7 +806,8 @@ def upsert( expr_match_bound = bind(self.table_metadata.schema(), expr_match, case_sensitive=case_sensitive) expr_match_arrow = expression_to_pyarrow(expr_match_bound) - insert_filters.append(~expr_match_arrow) + # Filter rows per batch. + rows_to_insert = rows_to_insert.filter(~expr_match_arrow) update_row_cnt = 0 insert_row_cnt = 0 @@ -822,11 +821,6 @@ def upsert( ) if when_not_matched_insert_all: - if insert_filters: - rows_to_insert = df.filter(functools.reduce(operator.and_, insert_filters)) - else: - rows_to_insert = df - insert_row_cnt = len(rows_to_insert) if rows_to_insert: self.append(rows_to_insert) From 88a4ad2ebc00ac5f7c7d392a3472d3d9c87185fa Mon Sep 17 00:00:00 2001 From: Koen Vossen Date: Mon, 2 Jun 2025 23:32:05 +0200 Subject: [PATCH 11/19] Accept concurrent_tasks when fetching record_batches --- pyiceberg/io/pyarrow.py | 19 +++++++++++++++++-- pyiceberg/table/__init__.py | 4 ++-- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index 1aaab32dbe..160671cd63 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -1625,7 +1625,9 @@ def _table_from_scan_task(task: FileScanTask) -> pa.Table: return result - def to_record_batches(self, tasks: Iterable[FileScanTask]) -> Iterator[pa.RecordBatch]: + def to_record_batches( + self, tasks: Iterable[FileScanTask], concurrent_tasks: Optional[int] = None + ) -> Iterator[pa.RecordBatch]: """Scan the Iceberg table and return an Iterator[pa.RecordBatch]. Returns an Iterator of pa.RecordBatch with data from the Iceberg table @@ -1634,6 +1636,7 @@ def to_record_batches(self, tasks: Iterable[FileScanTask]) -> Iterator[pa.Record Args: tasks: FileScanTasks representing the data files and delete files to read from. + concurrent_tasks: number of concurrent tasks Returns: An Iterator of PyArrow RecordBatches. 
@@ -1643,8 +1646,20 @@ def to_record_batches(self, tasks: Iterable[FileScanTask]) -> Iterator[pa.Record ResolveError: When a required field cannot be found in the file ValueError: When a field type in the file cannot be projected to the schema type """ + from concurrent.futures import ThreadPoolExecutor + deletes_per_file = _read_all_delete_files(self._io, tasks) - return self._record_batches_from_scan_tasks_and_deletes(tasks, deletes_per_file) + + if concurrent_tasks is not None: + with ThreadPoolExecutor(max_workers=concurrent_tasks) as pool: + for batches in pool.map( + lambda task: list(self._record_batches_from_scan_tasks_and_deletes([task], deletes_per_file)), tasks + ): + for batch in batches: + yield batch + + else: + return self._record_batches_from_scan_tasks_and_deletes(tasks, deletes_per_file) def _record_batches_from_scan_tasks_and_deletes( self, tasks: Iterable[FileScanTask], deletes_per_file: Dict[str, List[ChunkedArray]] diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index ab41138fd2..46bda600a4 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -1864,7 +1864,7 @@ def to_arrow(self) -> pa.Table: self.table_metadata, self.io, self.projection(), self.row_filter, self.case_sensitive, self.limit ).to_table(self.plan_files()) - def to_arrow_batch_reader(self) -> pa.RecordBatchReader: + def to_arrow_batch_reader(self, concurrent_tasks: Optional[int] = None) -> pa.RecordBatchReader: """Return an Arrow RecordBatchReader from this DataScan. For large results, using a RecordBatchReader requires less memory than @@ -1882,7 +1882,7 @@ def to_arrow_batch_reader(self) -> pa.RecordBatchReader: target_schema = schema_to_pyarrow(self.projection()) batches = ArrowScan( self.table_metadata, self.io, self.projection(), self.row_filter, self.case_sensitive, self.limit - ).to_record_batches(self.plan_files()) + ).to_record_batches(self.plan_files(), concurrent_tasks=concurrent_tasks) return pa.RecordBatchReader.from_batches( target_schema, From f8acdb05a4e81c1e961b6d499f9722f903f6f1ab Mon Sep 17 00:00:00 2001 From: Koen Vossen Date: Tue, 3 Jun 2025 07:11:08 +0200 Subject: [PATCH 12/19] Use ExecutorFactory --- pyiceberg/io/pyarrow.py | 4 +--- pyiceberg/utils/concurrent.py | 4 ++++ 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index 160671cd63..12dfc51e1f 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -1646,12 +1646,10 @@ def to_record_batches( ResolveError: When a required field cannot be found in the file ValueError: When a field type in the file cannot be projected to the schema type """ - from concurrent.futures import ThreadPoolExecutor - deletes_per_file = _read_all_delete_files(self._io, tasks) if concurrent_tasks is not None: - with ThreadPoolExecutor(max_workers=concurrent_tasks) as pool: + with ExecutorFactory.create(max_workers=concurrent_tasks) as pool: for batches in pool.map( lambda task: list(self._record_batches_from_scan_tasks_and_deletes([task], deletes_per_file)), tasks ): diff --git a/pyiceberg/utils/concurrent.py b/pyiceberg/utils/concurrent.py index 805599bf41..8b112a1147 100644 --- a/pyiceberg/utils/concurrent.py +++ b/pyiceberg/utils/concurrent.py @@ -38,3 +38,7 @@ def get_or_create() -> Executor: def max_workers() -> Optional[int]: """Return the max number of workers configured.""" return Config().get_int("max-workers") + + @staticmethod + def create(max_workers: int) -> Executor: + return ThreadPoolExecutor(max_workers=max_workers) 
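
The remaining patches rework how scan results are consumed: instead of materializing one Arrow table up front, rows are read as record batches, which is what lets upsert avoid OOM on large targets (patch 09) and what the later to_arrow/to_record_batches changes build on. The same streaming pattern is available directly on a scan. A minimal sketch, assuming a configured catalog named "default" and an existing table "default.orders" with an "order_id" column (illustrative names, not taken from these patches):

    from pyiceberg.catalog import load_catalog

    catalog = load_catalog("default")            # assumed catalog configuration
    tbl = catalog.load_table("default.orders")   # assumed existing table

    # Stream matching rows batch by batch instead of building a single pa.Table.
    reader = tbl.scan(row_filter="order_id > 100").to_arrow_batch_reader()

    total_rows = 0
    for batch in reader:  # each item is a pyarrow.RecordBatch
        total_rows += batch.num_rows
    print(total_rows)

Compared with scan(...).to_arrow(), the caller consumes batches incrementally rather than concatenating everything into one table, which is the memory behavior the later patches preserve while re-adding parallel task execution and limit handling.
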
From 119d92fb4397bcee04cca2201a217a15bc9edc92 Mon Sep 17 00:00:00 2001 From: Koen Vossen Date: Tue, 3 Jun 2025 11:46:00 +0200 Subject: [PATCH 13/19] Simplify to_arrow to use the optimized to_record_batches --- pyiceberg/io/pyarrow.py | 79 +++++++++++++---------------------- pyiceberg/table/__init__.py | 4 +- pyiceberg/utils/concurrent.py | 4 -- 3 files changed, 31 insertions(+), 56 deletions(-) diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index 12dfc51e1f..15838b9236 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -1570,47 +1570,17 @@ def to_table(self, tasks: Iterable[FileScanTask]) -> pa.Table: ResolveError: When a required field cannot be found in the file ValueError: When a field type in the file cannot be projected to the schema type """ - deletes_per_file = _read_all_delete_files(self._io, tasks) - executor = ExecutorFactory.get_or_create() - - def _table_from_scan_task(task: FileScanTask) -> pa.Table: - batches = list(self._record_batches_from_scan_tasks_and_deletes([task], deletes_per_file)) - if len(batches) > 0: - return pa.Table.from_batches(batches) - else: - return None - - futures = [ - executor.submit( - _table_from_scan_task, - task, - ) - for task in tasks - ] - total_row_count = 0 - # for consistent ordering, we need to maintain future order - futures_index = {f: i for i, f in enumerate(futures)} - completed_futures: SortedList[Future[pa.Table]] = SortedList(iterable=[], key=lambda f: futures_index[f]) - for future in concurrent.futures.as_completed(futures): - completed_futures.add(future) - if table_result := future.result(): - total_row_count += len(table_result) - # stop early if limit is satisfied - if self._limit is not None and total_row_count >= self._limit: - break - - # by now, we've either completed all tasks or satisfied the limit - if self._limit is not None: - _ = [f.cancel() for f in futures if not f.done()] - - tables = [f.result() for f in completed_futures if f.result()] arrow_schema = schema_to_pyarrow(self._projected_schema, include_field_ids=False) - if len(tables) < 1: + batches = self.to_record_batches(tasks) + try: + first_batch = next(batches) + except StopIteration: + # Empty return pa.Table.from_batches([], schema=arrow_schema) - result = pa.concat_tables(tables, promote_options="permissive") + result = pa.Table.from_batches(itertools.chain([first_batch], batches)) if property_as_bool(self._io.properties, PYARROW_USE_LARGE_TYPES_ON_READ, False): deprecation_message( @@ -1620,13 +1590,10 @@ def _table_from_scan_task(task: FileScanTask) -> pa.Table: ) result = result.cast(arrow_schema) - if self._limit is not None: - return result.slice(0, self._limit) - return result def to_record_batches( - self, tasks: Iterable[FileScanTask], concurrent_tasks: Optional[int] = None + self, tasks: Iterable[FileScanTask] ) -> Iterator[pa.RecordBatch]: """Scan the Iceberg table and return an Iterator[pa.RecordBatch]. @@ -1636,7 +1603,6 @@ def to_record_batches( Args: tasks: FileScanTasks representing the data files and delete files to read from. - concurrent_tasks: number of concurrent tasks Returns: An Iterator of PyArrow RecordBatches. 
@@ -1648,16 +1614,29 @@ def to_record_batches( """ deletes_per_file = _read_all_delete_files(self._io, tasks) - if concurrent_tasks is not None: - with ExecutorFactory.create(max_workers=concurrent_tasks) as pool: - for batches in pool.map( - lambda task: list(self._record_batches_from_scan_tasks_and_deletes([task], deletes_per_file)), tasks - ): - for batch in batches: - yield batch + total_row_count = 0 + executor = ExecutorFactory.get_or_create() - else: - return self._record_batches_from_scan_tasks_and_deletes(tasks, deletes_per_file) + with executor as pool: + should_stop = False + for batches in pool.map( + lambda task: list(self._record_batches_from_scan_tasks_and_deletes([task], deletes_per_file)), tasks + ): + for batch in batches: + current_batch_size = len(batch) + if self._limit is not None: + if total_row_count + current_batch_size >= self._limit: + yield batch.slice(0, self._limit - total_row_count) + + # This break will also cancel all tasks in the Pool + should_stop = True + break + + yield batch + total_row_count += current_batch_size + + if should_stop: + break def _record_batches_from_scan_tasks_and_deletes( self, tasks: Iterable[FileScanTask], deletes_per_file: Dict[str, List[ChunkedArray]] diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 46bda600a4..ab41138fd2 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -1864,7 +1864,7 @@ def to_arrow(self) -> pa.Table: self.table_metadata, self.io, self.projection(), self.row_filter, self.case_sensitive, self.limit ).to_table(self.plan_files()) - def to_arrow_batch_reader(self, concurrent_tasks: Optional[int] = None) -> pa.RecordBatchReader: + def to_arrow_batch_reader(self) -> pa.RecordBatchReader: """Return an Arrow RecordBatchReader from this DataScan. 
For large results, using a RecordBatchReader requires less memory than @@ -1882,7 +1882,7 @@ def to_arrow_batch_reader(self, concurrent_tasks: Optional[int] = None) -> pa.Re target_schema = schema_to_pyarrow(self.projection()) batches = ArrowScan( self.table_metadata, self.io, self.projection(), self.row_filter, self.case_sensitive, self.limit - ).to_record_batches(self.plan_files(), concurrent_tasks=concurrent_tasks) + ).to_record_batches(self.plan_files()) return pa.RecordBatchReader.from_batches( target_schema, diff --git a/pyiceberg/utils/concurrent.py b/pyiceberg/utils/concurrent.py index 8b112a1147..805599bf41 100644 --- a/pyiceberg/utils/concurrent.py +++ b/pyiceberg/utils/concurrent.py @@ -38,7 +38,3 @@ def get_or_create() -> Executor: def max_workers() -> Optional[int]: """Return the max number of workers configured.""" return Config().get_int("max-workers") - - @staticmethod - def create(max_workers: int) -> Executor: - return ThreadPoolExecutor(max_workers=max_workers) From 5d3a6aa012fb1f24988ea3756d875655075b3155 Mon Sep 17 00:00:00 2001 From: Koen Vossen Date: Tue, 3 Jun 2025 11:59:13 +0200 Subject: [PATCH 14/19] Fix for shutdown pool after doing a map --- pyiceberg/io/pyarrow.py | 41 +++++++++++++++++------------------------ 1 file changed, 17 insertions(+), 24 deletions(-) diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index 15838b9236..4c86780252 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -25,7 +25,6 @@ from __future__ import annotations -import concurrent.futures import fnmatch import functools import itertools @@ -36,7 +35,6 @@ import uuid import warnings from abc import ABC, abstractmethod -from concurrent.futures import Future from copy import copy from dataclasses import dataclass from enum import Enum @@ -71,7 +69,6 @@ FileType, FSSpecHandler, ) -from sortedcontainers import SortedList from pyiceberg.conversions import to_bytes from pyiceberg.exceptions import ResolveError @@ -1570,7 +1567,6 @@ def to_table(self, tasks: Iterable[FileScanTask]) -> pa.Table: ResolveError: When a required field cannot be found in the file ValueError: When a field type in the file cannot be projected to the schema type """ - arrow_schema = schema_to_pyarrow(self._projected_schema, include_field_ids=False) batches = self.to_record_batches(tasks) @@ -1592,9 +1588,7 @@ def to_table(self, tasks: Iterable[FileScanTask]) -> pa.Table: return result - def to_record_batches( - self, tasks: Iterable[FileScanTask] - ) -> Iterator[pa.RecordBatch]: + def to_record_batches(self, tasks: Iterable[FileScanTask]) -> Iterator[pa.RecordBatch]: """Scan the Iceberg table and return an Iterator[pa.RecordBatch]. 
Returns an Iterator of pa.RecordBatch with data from the Iceberg table @@ -1617,26 +1611,25 @@ def to_record_batches( total_row_count = 0 executor = ExecutorFactory.get_or_create() - with executor as pool: - should_stop = False - for batches in pool.map( - lambda task: list(self._record_batches_from_scan_tasks_and_deletes([task], deletes_per_file)), tasks - ): - for batch in batches: - current_batch_size = len(batch) - if self._limit is not None: - if total_row_count + current_batch_size >= self._limit: - yield batch.slice(0, self._limit - total_row_count) + limit_reached = False + for batches in executor.map( + lambda task: list(self._record_batches_from_scan_tasks_and_deletes([task], deletes_per_file)), tasks + ): + for batch in batches: + current_batch_size = len(batch) + if self._limit is not None: + if total_row_count + current_batch_size >= self._limit: + yield batch.slice(0, self._limit - total_row_count) - # This break will also cancel all tasks in the Pool - should_stop = True - break + # This break will also cancel all tasks in the Pool + limit_reached = True + break - yield batch - total_row_count += current_batch_size + yield batch + total_row_count += current_batch_size - if should_stop: - break + if limit_reached: + break def _record_batches_from_scan_tasks_and_deletes( self, tasks: Iterable[FileScanTask], deletes_per_file: Dict[str, List[ChunkedArray]] From 445845d7a7c5a7d894a00a0e881f51e5f2478dc5 Mon Sep 17 00:00:00 2001 From: Koen Vossen Date: Tue, 3 Jun 2025 12:02:33 +0200 Subject: [PATCH 15/19] minor --- pyiceberg/io/pyarrow.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index 4c86780252..47e92cf1b0 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -1617,13 +1617,12 @@ def to_record_batches(self, tasks: Iterable[FileScanTask]) -> Iterator[pa.Record ): for batch in batches: current_batch_size = len(batch) - if self._limit is not None: - if total_row_count + current_batch_size >= self._limit: - yield batch.slice(0, self._limit - total_row_count) + if self._limit is not None and total_row_count + current_batch_size >= self._limit: + yield batch.slice(0, self._limit - total_row_count) - # This break will also cancel all tasks in the Pool - limit_reached = True - break + # This break will also cancel all running tasks + limit_reached = True + break yield batch total_row_count += current_batch_size From 6b8daceced28001be0b1c6513ee86911d9661e36 Mon Sep 17 00:00:00 2001 From: Koen Vossen Date: Tue, 3 Jun 2025 16:22:33 +0200 Subject: [PATCH 16/19] Improve comments in to_record_batches --- pyiceberg/io/pyarrow.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index 47e92cf1b0..968dbce540 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -1611,16 +1611,19 @@ def to_record_batches(self, tasks: Iterable[FileScanTask]) -> Iterator[pa.Record total_row_count = 0 executor = ExecutorFactory.get_or_create() + def batches_for_task(task: FileScanTask) -> List[pa.RecordBatch]: + # Materialize the iterator here to ensure execution happens within the executor. + # Otherwise, the iterator would be lazily consumed later (in the main thread), + # defeating the purpose of using executor.map. 
+ return list(self._record_batches_from_scan_tasks_and_deletes([task], deletes_per_file)) + limit_reached = False - for batches in executor.map( - lambda task: list(self._record_batches_from_scan_tasks_and_deletes([task], deletes_per_file)), tasks - ): + for batches in executor.map(batches_for_task, tasks): for batch in batches: current_batch_size = len(batch) if self._limit is not None and total_row_count + current_batch_size >= self._limit: yield batch.slice(0, self._limit - total_row_count) - # This break will also cancel all running tasks limit_reached = True break @@ -1628,6 +1631,7 @@ def to_record_batches(self, tasks: Iterable[FileScanTask]) -> Iterator[pa.Record total_row_count += current_batch_size if limit_reached: + # This break will also cancel all running tasks in the executor break def _record_batches_from_scan_tasks_and_deletes( From 2cd21373586a107ae6ba0067a42eb83e4e52f08b Mon Sep 17 00:00:00 2001 From: Koen Vossen Date: Sat, 14 Jun 2025 22:23:51 +0200 Subject: [PATCH 17/19] Make sure the 'infer the types when reading (#1669)' works again --- pyiceberg/io/pyarrow.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index 968dbce540..1f56dddb8a 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -1576,7 +1576,11 @@ def to_table(self, tasks: Iterable[FileScanTask]) -> pa.Table: # Empty return pa.Table.from_batches([], schema=arrow_schema) - result = pa.Table.from_batches(itertools.chain([first_batch], batches)) + # Note: cannot use pa.Table.from_batches(itertools.chain([first_batch], batches))) + # as different batches can use different schema's (due to large_ types) + result = pa.concat_tables( + (pa.Table.from_batches([batch]) for batch in itertools.chain([first_batch], batches)), promote_options="permissive" + ) if property_as_bool(self._io.properties, PYARROW_USE_LARGE_TYPES_ON_READ, False): deprecation_message( From 8e6f5e901a125be51f3a42d0e33f824209bea774 Mon Sep 17 00:00:00 2001 From: koenvo Date: Fri, 20 Jun 2025 17:08:32 +0200 Subject: [PATCH 18/19] Update pyiceberg/io/pyarrow.py Use arrow_schema.empty_table() Co-authored-by: Fokko Driesprong --- pyiceberg/io/pyarrow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index 1f56dddb8a..5755220631 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -1574,7 +1574,7 @@ def to_table(self, tasks: Iterable[FileScanTask]) -> pa.Table: first_batch = next(batches) except StopIteration: # Empty - return pa.Table.from_batches([], schema=arrow_schema) + return arrow_schema.empty_table() # Note: cannot use pa.Table.from_batches(itertools.chain([first_batch], batches))) # as different batches can use different schema's (due to large_ types) From 68130619d771a74b5e74ccc0f8add58fea036b4c Mon Sep 17 00:00:00 2001 From: koenvo Date: Fri, 20 Jun 2025 17:08:46 +0200 Subject: [PATCH 19/19] Update pyiceberg/io/pyarrow.py Co-authored-by: Fokko Driesprong --- pyiceberg/io/pyarrow.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index 5755220631..62e6ed8af1 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -1630,9 +1630,9 @@ def batches_for_task(task: FileScanTask) -> List[pa.RecordBatch]: limit_reached = True break - - yield batch - total_row_count += current_batch_size + else: + yield batch + total_row_count += current_batch_size if limit_reached: # This break will also 
cancel all running tasks in the executor
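
Taken together, the series makes upsert usable inside an explicit Transaction, with the upsert's internal scan reading the transaction's own uncommitted metadata. A minimal end-to-end sketch modeled on test_transaction_multiple_upserts above; the catalog configuration and table identifier are assumptions, not part of the patches:

    import pyarrow as pa

    from pyiceberg.catalog import load_catalog
    from pyiceberg.schema import Schema
    from pyiceberg.types import IntegerType, NestedField, StringType

    catalog = load_catalog("default")  # assumes a configured catalog named "default"

    schema = Schema(
        NestedField(1, "id", IntegerType(), required=True),
        NestedField(2, "name", StringType(), required=True),
        identifier_field_ids=[1],
    )
    tbl = catalog.create_table("default.upsert_demo", schema=schema)  # illustrative identifier

    arrow_schema = pa.schema(
        [
            pa.field("id", pa.int32(), nullable=False),
            pa.field("name", pa.string(), nullable=False),
        ]
    )
    tbl.append(pa.Table.from_pylist([{"id": 1, "name": "Alice"}], schema=arrow_schema))

    df = pa.Table.from_pylist([{"id": 2, "name": "Bob"}, {"id": 1, "name": "Alicia"}], schema=arrow_schema)

    with tbl.transaction() as txn:
        txn.delete(delete_filter="id = 1")
        txn.append(df)
        # The scan inside upsert uses Transaction.table_metadata, so the
        # uncommitted delete/append above are visible when matching rows.
        result = txn.upsert(df, join_cols=["id"])

    print(result.rows_updated, result.rows_inserted)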