diff --git a/pyiceberg/table/upsert_util.py b/pyiceberg/table/upsert_util.py
index 723a89aa20..c12351d45c 100644
--- a/pyiceberg/table/upsert_util.py
+++ b/pyiceberg/table/upsert_util.py
@@ -16,12 +16,14 @@
 # under the License.
 import functools
 import operator
+from typing import List, cast
 
 import pyarrow as pa
 from pyarrow import Table as pyarrow_table
 from pyarrow import compute as pc
 
 from pyiceberg.expressions import (
+    AlwaysFalse,
     And,
     BooleanExpression,
     EqualTo,
@@ -36,7 +38,16 @@ def create_match_filter(df: pyarrow_table, join_cols: list[str]) -> BooleanExpre
     if len(join_cols) == 1:
         return In(join_cols[0], unique_keys[0].to_pylist())
     else:
-        return Or(*[And(*[EqualTo(col, row[col]) for col in join_cols]) for row in unique_keys.to_pylist()])
+        filters: List[BooleanExpression] = [
+            cast(BooleanExpression, And(*[EqualTo(col, row[col]) for col in join_cols])) for row in unique_keys.to_pylist()
+        ]
+
+        if len(filters) == 0:
+            return AlwaysFalse()
+        elif len(filters) == 1:
+            return filters[0]
+        else:
+            return functools.reduce(lambda a, b: Or(a, b), filters)
 
 
 def has_duplicate_rows(df: pyarrow_table, join_cols: list[str]) -> bool:
@@ -86,7 +97,7 @@ def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols
     if rows_to_update:
         rows_to_update_table = pa.concat_tables(rows_to_update)
     else:
-        rows_to_update_table = pa.Table.from_arrays([], names=source_table.column_names)
+        rows_to_update_table = source_table.schema.empty_table()
 
     common_columns = set(source_table.column_names).intersection(set(target_table.column_names))
     rows_to_update_table = rows_to_update_table.select(list(common_columns))
diff --git a/tests/table/test_upsert.py b/tests/table/test_upsert.py
index 0cfb0ba609..c97015e650 100644
--- a/tests/table/test_upsert.py
+++ b/tests/table/test_upsert.py
@@ -23,8 +23,11 @@
 
 from pyiceberg.catalog import Catalog
 from pyiceberg.exceptions import NoSuchTableError
+from pyiceberg.expressions import And, EqualTo, Reference
+from pyiceberg.expressions.literals import LongLiteral
 from pyiceberg.schema import Schema
 from pyiceberg.table import UpsertResult
+from pyiceberg.table.upsert_util import create_match_filter
 from pyiceberg.types import IntegerType, NestedField, StringType
 from tests.catalog.test_base import InMemoryCatalog, Table
 
@@ -366,3 +369,22 @@ def test_upsert_with_identifier_fields(catalog: Catalog) -> None:
 
     assert upd.rows_updated == 1
     assert upd.rows_inserted == 1
+
+
+def test_create_match_filter_single_condition() -> None:
+    """
+    Test create_match_filter with a composite key where the source yields exactly one unique key.
+    Expected: The function returns the single And condition directly.
+    """
+
+    data = [
+        {"order_id": 101, "order_line_id": 1, "extra": "x"},
+        {"order_id": 101, "order_line_id": 1, "extra": "x"},  # duplicate
+    ]
+    schema = pa.schema([pa.field("order_id", pa.int32()), pa.field("order_line_id", pa.int32()), pa.field("extra", pa.string())])
+    table = pa.Table.from_pylist(data, schema=schema)
+    expr = create_match_filter(table, ["order_id", "order_line_id"])
+    assert expr == And(
+        EqualTo(term=Reference(name="order_id"), literal=LongLiteral(101)),
+        EqualTo(term=Reference(name="order_line_id"), literal=LongLiteral(1)),
+    )