15 changes: 13 additions & 2 deletions pyiceberg/table/upsert_util.py
@@ -16,12 +16,14 @@
 # under the License.
 import functools
 import operator
+from typing import List, cast

 import pyarrow as pa
 from pyarrow import Table as pyarrow_table
 from pyarrow import compute as pc

 from pyiceberg.expressions import (
+    AlwaysFalse,
     And,
     BooleanExpression,
     EqualTo,
@@ -36,7 +38,16 @@ def create_match_filter(df: pyarrow_table, join_cols: list[str]) -> BooleanExpression:
     if len(join_cols) == 1:
         return In(join_cols[0], unique_keys[0].to_pylist())
     else:
-        return Or(*[And(*[EqualTo(col, row[col]) for col in join_cols]) for row in unique_keys.to_pylist()])
+        filters: List[BooleanExpression] = [
+            cast(BooleanExpression, And(*[EqualTo(col, row[col]) for col in join_cols])) for row in unique_keys.to_pylist()
+        ]
+
+        if len(filters) == 0:
+            return AlwaysFalse()
+        elif len(filters) == 1:
+            return filters[0]
+        else:
+            return functools.reduce(lambda a, b: Or(a, b), filters)


 def has_duplicate_rows(df: pyarrow_table, join_cols: list[str]) -> bool:
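The new branching above replaces the single variadic Or(*children) call, which had no answer for zero or one composite-key predicate: an empty key set now short-circuits to AlwaysFalse(), a single key is returned as its And expression unchanged, and two or more keys are folded into a left-nested binary Or via functools.reduce. A minimal standalone sketch of the same reduction pattern, using plain strings instead of pyiceberg expressions so it runs without the library (all names below are illustrative only):

import functools
from typing import List


def combine_or(filters: List[str]) -> str:
    # Mirrors the branching above: empty -> AlwaysFalse, one -> the filter
    # itself, many -> a left-nested binary Or built with functools.reduce.
    if len(filters) == 0:
        return "AlwaysFalse()"
    elif len(filters) == 1:
        return filters[0]
    else:
        return functools.reduce(lambda a, b: f"Or({a}, {b})", filters)


print(combine_or([]))                            # AlwaysFalse()
print(combine_or(["k=1"]))                       # k=1
print(combine_or(["k=1", "k=2", "k=3"]))         # Or(Or(k=1, k=2), k=3)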
@@ -86,7 +97,7 @@ def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols
     if rows_to_update:
         rows_to_update_table = pa.concat_tables(rows_to_update)
     else:
-        rows_to_update_table = pa.Table.from_arrays([], names=source_table.column_names)
+        rows_to_update_table = source_table.schema.empty_table()

     common_columns = set(source_table.column_names).intersection(set(target_table.column_names))
     rows_to_update_table = rows_to_update_table.select(list(common_columns))
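pa.Schema.empty_table() produces a zero-row table that still carries every field of the schema, so the later .select(...) over the common columns keeps working even when no rows need updating. A short sketch of what the replacement line produces (the column names here are made up for illustration, not taken from the PR):

import pyarrow as pa

# Stand-in for source_table with a couple of illustrative columns.
source_table = pa.table({
    "order_id": pa.array([1, 2], pa.int32()),
    "amount": pa.array([10.0, 20.0]),
})

# Equivalent of the new line: rows_to_update_table = source_table.schema.empty_table()
empty = source_table.schema.empty_table()
print(empty.num_rows)       # 0
print(empty.column_names)   # ['order_id', 'amount']

# Column projection on the empty table behaves like on a populated one.
print(empty.select(["order_id"]).num_rows)  # 0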
22 changes: 22 additions & 0 deletions tests/table/test_upsert.py
@@ -23,8 +23,11 @@

 from pyiceberg.catalog import Catalog
 from pyiceberg.exceptions import NoSuchTableError
+from pyiceberg.expressions import And, EqualTo, Reference
+from pyiceberg.expressions.literals import LongLiteral
 from pyiceberg.schema import Schema
 from pyiceberg.table import UpsertResult
+from pyiceberg.table.upsert_util import create_match_filter
 from pyiceberg.types import IntegerType, NestedField, StringType
 from tests.catalog.test_base import InMemoryCatalog, Table
@@ -366,3 +369,22 @@ def test_upsert_with_identifier_fields(catalog: Catalog) -> None:

     assert upd.rows_updated == 1
     assert upd.rows_inserted == 1
+
+
+def test_create_match_filter_single_condition() -> None:
+    """
+    Test create_match_filter with a composite key where the source yields exactly one unique key.
+    Expected: The function returns the single And condition directly.
+    """
+
+    data = [
+        {"order_id": 101, "order_line_id": 1, "extra": "x"},
+        {"order_id": 101, "order_line_id": 1, "extra": "x"},  # duplicate
+    ]
+    schema = pa.schema([pa.field("order_id", pa.int32()), pa.field("order_line_id", pa.int32()), pa.field("extra", pa.string())])
+    table = pa.Table.from_pylist(data, schema=schema)
+    expr = create_match_filter(table, ["order_id", "order_line_id"])
+    assert expr == And(
+        EqualTo(term=Reference(name="order_id"), literal=LongLiteral(101)),
+        EqualTo(term=Reference(name="order_line_id"), literal=LongLiteral(1)),
+    )
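The test above pins down the single-key branch; the other two branches introduced in upsert_util.py (empty input folding to AlwaysFalse and multiple keys reduced through Or) could be exercised the same way. A hedged sketch of such a companion test, not part of this PR and assuming the same pytest module layout as above:

import pyarrow as pa

from pyiceberg.expressions import AlwaysFalse, Or
from pyiceberg.table.upsert_util import create_match_filter


def test_create_match_filter_empty_and_multiple_conditions() -> None:
    # Illustrative only: covers the AlwaysFalse and reduce-over-Or branches.
    schema = pa.schema([pa.field("order_id", pa.int32()), pa.field("order_line_id", pa.int32())])

    # No rows -> no unique keys -> AlwaysFalse().
    empty = pa.Table.from_pylist([], schema=schema)
    assert create_match_filter(empty, ["order_id", "order_line_id"]) == AlwaysFalse()

    # Two distinct composite keys -> an Or over the two And conditions.
    data = [
        {"order_id": 101, "order_line_id": 1},
        {"order_id": 102, "order_line_id": 2},
    ]
    table = pa.Table.from_pylist(data, schema=schema)
    assert isinstance(create_match_filter(table, ["order_id", "order_line_id"]), Or)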