diff --git a/pyiceberg/table/upsert_util.py b/pyiceberg/table/upsert_util.py
index c2d554dfae..cefdd101a0 100644
--- a/pyiceberg/table/upsert_util.py
+++ b/pyiceberg/table/upsert_util.py
@@ -62,7 +62,8 @@ def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols
     """
     all_columns = set(source_table.column_names)
     join_cols_set = set(join_cols)
-    non_key_cols = all_columns - join_cols_set
+
+    non_key_cols = list(all_columns - join_cols_set)
 
     if has_duplicate_rows(target_table, join_cols):
         raise ValueError("Target table has duplicate rows, aborting upsert")
@@ -71,25 +72,51 @@
         # When the target table is empty, there is nothing to update :)
         return source_table.schema.empty_table()
 
-    diff_expr = functools.reduce(
-        operator.or_,
-        [
-            pc.or_kleene(
-                pc.not_equal(pc.field(f"{col}-lhs"), pc.field(f"{col}-rhs")),
-                pc.is_null(pc.not_equal(pc.field(f"{col}-lhs"), pc.field(f"{col}-rhs"))),
-            )
-            for col in non_key_cols
-        ],
+    # We need to compare non_key_cols in Python, as PyArrow
+    # 1. Cannot do a join when non-join columns have complex types
+    # 2. Cannot compare columns with complex types
+    # See: https://github.com/apache/arrow/issues/35785
+    SOURCE_INDEX_COLUMN_NAME = "__source_index"
+    TARGET_INDEX_COLUMN_NAME = "__target_index"
+
+    if SOURCE_INDEX_COLUMN_NAME in join_cols or TARGET_INDEX_COLUMN_NAME in join_cols:
+        raise ValueError(
+            f"{SOURCE_INDEX_COLUMN_NAME} and {TARGET_INDEX_COLUMN_NAME} are reserved for joining "
+            f"DataFrames, and cannot be used as column names"
+        ) from None
+
+    # Step 1: Prepare the source index with the join keys and a positional marker
+    # Cast to the target table schema, so we can do the join
+    # See: https://github.com/apache/arrow/issues/37542
+    source_index = (
+        source_table.cast(target_table.schema)
+        .select(join_cols_set)
+        .append_column(SOURCE_INDEX_COLUMN_NAME, pa.array(range(len(source_table))))
     )
-    return (
-        source_table
-        # We already know that the schema is compatible, this is to fix large_ types
-        .cast(target_table.schema)
-        .join(target_table, keys=list(join_cols_set), join_type="inner", left_suffix="-lhs", right_suffix="-rhs")
-        .filter(diff_expr)
-        .drop_columns([f"{col}-rhs" for col in non_key_cols])
-        .rename_columns({f"{col}-lhs" if col not in join_cols else col: col for col in source_table.column_names})
-        # Finally cast to the original schema since it doesn't carry nullability:
-        # https://github.com/apache/arrow/issues/45557
-    ).cast(target_table.schema)
+    # Step 2: Prepare the target index with the join keys and a positional marker
+    target_index = target_table.select(join_cols_set).append_column(TARGET_INDEX_COLUMN_NAME, pa.array(range(len(target_table))))
+
+    # Step 3: Perform an inner join to find which rows from source exist in target
+    matching_indices = source_index.join(target_index, keys=list(join_cols_set), join_type="inner")
+
+    # Step 4: Compare the non-key columns of each matched row pair in Python
+    to_update_indices = []
+    for source_idx, target_idx in zip(
+        matching_indices[SOURCE_INDEX_COLUMN_NAME].to_pylist(), matching_indices[TARGET_INDEX_COLUMN_NAME].to_pylist()
+    ):
+        source_row = source_table.slice(source_idx, 1)
+        target_row = target_table.slice(target_idx, 1)
+
+        for key in non_key_cols:
+            source_val = source_row.column(key)[0].as_py()
+            target_val = target_row.column(key)[0].as_py()
+            if source_val != target_val:
+                to_update_indices.append(source_idx)
+                break
+
+    # Step 5: Take the rows to update from the source table using the collected indices
+    if to_update_indices:
+        return source_table.take(to_update_indices)
+    else:
+        return source_table.schema.empty_table()
 
diff --git a/tests/table/test_upsert.py b/tests/table/test_upsert.py
index 5de4a61187..70203fd162 100644
--- a/tests/table/test_upsert.py
+++ b/tests/table/test_upsert.py
@@ -30,7 +30,7 @@
 from pyiceberg.table import UpsertResult
 from pyiceberg.table.snapshots import Operation
 from pyiceberg.table.upsert_util import create_match_filter
-from pyiceberg.types import IntegerType, NestedField, StringType
+from pyiceberg.types import IntegerType, NestedField, StringType, StructType
 from tests.catalog.test_base import InMemoryCatalog, Table
 
 
@@ -511,6 +511,163 @@ def test_upsert_without_identifier_fields(catalog: Catalog) -> None:
         tbl.upsert(df)
 
 
+def test_upsert_with_struct_field_as_non_join_key(catalog: Catalog) -> None:
+    identifier = "default.test_upsert_struct_field_fails"
+    _drop_table(catalog, identifier)
+
+    schema = Schema(
+        NestedField(1, "id", IntegerType(), required=True),
+        NestedField(
+            2,
+            "nested_type",
+            StructType(
+                NestedField(3, "sub1", StringType(), required=True),
+                NestedField(4, "sub2", StringType(), required=True),
+            ),
+            required=False,
+        ),
+        identifier_field_ids=[1],
+    )
+
+    tbl = catalog.create_table(identifier, schema=schema)
+
+    arrow_schema = pa.schema(
+        [
+            pa.field("id", pa.int32(), nullable=False),
+            pa.field(
+                "nested_type",
+                pa.struct(
+                    [
+                        pa.field("sub1", pa.large_string(), nullable=False),
+                        pa.field("sub2", pa.large_string(), nullable=False),
+                    ]
+                ),
+                nullable=True,
+            ),
+        ]
+    )
+
+    initial_data = pa.Table.from_pylist(
+        [
+            {
+                "id": 1,
+                "nested_type": {"sub1": "bla1", "sub2": "bla"},
+            }
+        ],
+        schema=arrow_schema,
+    )
+    tbl.append(initial_data)
+
+    update_data = pa.Table.from_pylist(
+        [
+            {
+                "id": 2,
+                "nested_type": {"sub1": "bla1", "sub2": "bla"},
+            },
+            {
+                "id": 1,
+                "nested_type": {"sub1": "bla1", "sub2": "bla2"},
+            },
+        ],
+        schema=arrow_schema,
+    )
+
+    res = tbl.upsert(update_data, join_cols=["id"])
+
+    expected_updated = 1
+    expected_inserted = 1
+
+    assert_upsert_result(res, expected_updated, expected_inserted)
+
+    update_data = pa.Table.from_pylist(
+        [
+            {
+                "id": 2,
+                "nested_type": {"sub1": "bla1", "sub2": "bla"},
+            },
+            {
+                "id": 1,
+                "nested_type": {"sub1": "bla1", "sub2": "bla2"},
+            },
+        ],
+        schema=arrow_schema,
+    )
+
+    res = tbl.upsert(update_data, join_cols=["id"])
+
+    expected_updated = 0
+    expected_inserted = 0
+
+    assert_upsert_result(res, expected_updated, expected_inserted)
+
+
+def test_upsert_with_struct_field_as_join_key(catalog: Catalog) -> None:
+    identifier = "default.test_upsert_with_struct_field_as_join_key"
+    _drop_table(catalog, identifier)
+
+    schema = Schema(
+        NestedField(1, "id", IntegerType(), required=True),
+        NestedField(
+            2,
+            "nested_type",
+            StructType(
+                NestedField(3, "sub1", StringType(), required=True),
+                NestedField(4, "sub2", StringType(), required=True),
+            ),
+            required=False,
+        ),
+        identifier_field_ids=[1],
+    )
+
+    tbl = catalog.create_table(identifier, schema=schema)
+
+    arrow_schema = pa.schema(
+        [
+            pa.field("id", pa.int32(), nullable=False),
+            pa.field(
+                "nested_type",
+                pa.struct(
+                    [
+                        pa.field("sub1", pa.large_string(), nullable=False),
+                        pa.field("sub2", pa.large_string(), nullable=False),
+                    ]
+                ),
+                nullable=True,
+            ),
+        ]
+    )
+
+    initial_data = pa.Table.from_pylist(
+        [
+            {
+                "id": 1,
+                "nested_type": {"sub1": "bla1", "sub2": "bla"},
+            }
+        ],
+        schema=arrow_schema,
+    )
+    tbl.append(initial_data)
+
+    update_data = pa.Table.from_pylist(
+        [
+            {
+                "id": 2,
"sub2": "bla"}, + }, + { + "id": 1, + "nested_type": {"sub1": "bla1", "sub2": "bla"}, + }, + ], + schema=arrow_schema, + ) + + with pytest.raises( + pa.lib.ArrowNotImplementedError, match="Keys of type struct" + ): + _ = tbl.upsert(update_data, join_cols=["nested_type"]) + + def test_upsert_with_nulls(catalog: Catalog) -> None: identifier = "default.test_upsert_with_nulls" _drop_table(catalog, identifier)