diff --git a/pyiceberg/table/upsert_util.py b/pyiceberg/table/upsert_util.py index c2d554dfae..edec13c3e7 100644 --- a/pyiceberg/table/upsert_util.py +++ b/pyiceberg/table/upsert_util.py @@ -57,8 +57,8 @@ def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols """ Return a table with rows that need to be updated in the target table based on the join columns. - The table is joined on the identifier columns, and then checked if there are any updated rows. - Those are selected and everything is renamed correctly. + When a row is matched, an additional scan is done to evaluate the non-key columns to detect if an actual change has occurred. + Only matched rows that have an actual change to a non-key column value will be returned in the final output. """ all_columns = set(source_table.column_names) join_cols_set = set(join_cols) @@ -71,25 +71,39 @@ def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols # When the target table is empty, there is nothing to update :) return source_table.schema.empty_table() - diff_expr = functools.reduce( - operator.or_, - [ - pc.or_kleene( - pc.not_equal(pc.field(f"{col}-lhs"), pc.field(f"{col}-rhs")), - pc.is_null(pc.not_equal(pc.field(f"{col}-lhs"), pc.field(f"{col}-rhs"))), - ) - for col in non_key_cols - ], - ) - - return ( - source_table - # We already know that the schema is compatible, this is to fix large_ types - .cast(target_table.schema) - .join(target_table, keys=list(join_cols_set), join_type="inner", left_suffix="-lhs", right_suffix="-rhs") - .filter(diff_expr) - .drop_columns([f"{col}-rhs" for col in non_key_cols]) - .rename_columns({f"{col}-lhs" if col not in join_cols else col: col for col in source_table.column_names}) - # Finally cast to the original schema since it doesn't carry nullability: - # https://github.com/apache/arrow/issues/45557 - ).cast(target_table.schema) + match_expr = functools.reduce(operator.and_, [pc.field(col).isin(target_table.column(col).to_pylist()) for col in join_cols]) + + matching_source_rows = source_table.filter(match_expr) + + rows_to_update = [] + + for index in range(matching_source_rows.num_rows): + source_row = matching_source_rows.slice(index, 1) + + target_filter = functools.reduce(operator.and_, [pc.field(col) == source_row.column(col)[0].as_py() for col in join_cols]) + + matching_target_row = target_table.filter(target_filter) + + if matching_target_row.num_rows > 0: + needs_update = False + + for non_key_col in non_key_cols: + source_value = source_row.column(non_key_col)[0].as_py() + target_value = matching_target_row.column(non_key_col)[0].as_py() + + if source_value != target_value: + needs_update = True + break + + if needs_update: + rows_to_update.append(source_row) + + if rows_to_update: + rows_to_update_table = pa.concat_tables(rows_to_update) + else: + rows_to_update_table = source_table.schema.empty_table() + + common_columns = set(source_table.column_names).intersection(set(target_table.column_names)) + rows_to_update_table = rows_to_update_table.select(list(common_columns)) + + return rows_to_update_table diff --git a/tests/table/test_upsert.py b/tests/table/test_upsert.py index 5de4a61187..9773f5f0b0 100644 --- a/tests/table/test_upsert.py +++ b/tests/table/test_upsert.py @@ -30,7 +30,7 @@ from pyiceberg.table import UpsertResult from pyiceberg.table.snapshots import Operation from pyiceberg.table.upsert_util import create_match_filter -from pyiceberg.types import IntegerType, NestedField, StringType +from pyiceberg.types import IntegerType, NestedField, StringType, StructType from tests.catalog.test_base import InMemoryCatalog, Table @@ -552,3 +552,137 @@ def test_upsert_with_nulls(catalog: Catalog) -> None: ], schema=schema, ) + + +def test_upsert_with_struct_field(catalog: Catalog) -> None: + identifier = "default.test_upsert_with_struct_field" + _drop_table(catalog, identifier) + + schema = Schema( + NestedField(1, "id", IntegerType(), required=True), + NestedField( + 2, + "nested_type", + StructType( + NestedField(3, "sub1", StringType(), required=True), + NestedField(4, "sub2", StringType(), required=True), + ), + required=False, + ), + identifier_field_ids=[1], + ) + + tbl = catalog.create_table(identifier, schema=schema) + + arrow_schema = pa.schema( + [ + pa.field("id", pa.int32(), nullable=False), + pa.field( + "nested_type", + pa.struct( + [ + pa.field("sub1", pa.large_string(), nullable=False), + pa.field("sub2", pa.large_string(), nullable=False), + ] + ), + nullable=True, + ), + ] + ) + + initial_data = pa.Table.from_pylist( + [ + { + "id": 1, + "nested_type": {"sub1": "bla1", "sub2": "bla"}, + } + ], + schema=arrow_schema, + ) + tbl.append(initial_data) + + update_data = pa.Table.from_pylist( + [ + { + "id": 2, + "nested_type": {"sub1": "bla1", "sub2": "bla"}, + }, + { + "id": 1, + "nested_type": {"sub1": "bla1", "sub2": "bla"}, + }, + ], + schema=arrow_schema, + ) + + upd = tbl.upsert(update_data, join_cols=["id"]) + + assert upd.rows_updated == 0 + assert upd.rows_inserted == 1 + + +def test_upsert_with_struct_field_as_join_key(catalog: Catalog) -> None: + identifier = "default.test_upsert_with_struct_field_as_join_key" + _drop_table(catalog, identifier) + + schema = Schema( + NestedField(1, "id", IntegerType(), required=True), + NestedField( + 2, + "nested_type", + StructType( + NestedField(3, "sub1", StringType(), required=True), + NestedField(4, "sub2", StringType(), required=True), + ), + required=False, + ), + identifier_field_ids=[1], + ) + + tbl = catalog.create_table(identifier, schema=schema) + + arrow_schema = pa.schema( + [ + pa.field("id", pa.int32(), nullable=False), + pa.field( + "nested_type", + pa.struct( + [ + pa.field("sub1", pa.large_string(), nullable=False), + pa.field("sub2", pa.large_string(), nullable=False), + ] + ), + nullable=True, + ), + ] + ) + + initial_data = pa.Table.from_pylist( + [ + { + "id": 1, + "nested_type": {"sub1": "bla1", "sub2": "bla"}, + } + ], + schema=arrow_schema, + ) + tbl.append(initial_data) + + update_data = pa.Table.from_pylist( + [ + { + "id": 2, + "nested_type": {"sub1": "bla1", "sub2": "bla"}, + }, + { + "id": 1, + "nested_type": {"sub1": "bla1", "sub2": "bla"}, + }, + ], + schema=arrow_schema, + ) + + with pytest.raises( + pa.lib.ArrowNotImplementedError, match="Keys of type struct" + ): + _ = tbl.upsert(update_data, join_cols=["nested_type"])