diff --git a/pyiceberg/expressions/__init__.py b/pyiceberg/expressions/__init__.py index 8b006a28f1..2adf898fea 100644 --- a/pyiceberg/expressions/__init__.py +++ b/pyiceberg/expressions/__init__.py @@ -18,11 +18,13 @@ from __future__ import annotations from abc import ABC, abstractmethod -from functools import cached_property, reduce +from functools import cached_property from typing import ( Any, + Callable, Generic, Iterable, + Sequence, Set, Tuple, Type, @@ -79,6 +81,45 @@ def __or__(self, other: BooleanExpression) -> BooleanExpression: return Or(self, other) +def _build_balanced_tree( + operator_: Callable[[BooleanExpression, BooleanExpression], BooleanExpression], items: Sequence[BooleanExpression] +) -> BooleanExpression: + """ + Recursively constructs a balanced binary tree of BooleanExpressions using the provided binary operator. + + This function is a safer and more scalable alternative to: + reduce(operator_, items) + + Using `reduce` creates a deeply nested, unbalanced tree (e.g., operator_(a, operator_(b, operator_(c, ...)))), + which grows linearly with the number of items. This can lead to RecursionError exceptions in Python + when the number of expressions is large (e.g., >1000). + + In contrast, this function builds a balanced binary tree with logarithmic depth (O(log n)), + helping avoid recursion issues and ensuring that expression trees remain stable, predictable, + and safe to traverse — especially in tools like PyIceberg that operate on large logical trees. + + Parameters: + operator_ (Callable): A binary operator function (e.g., pyiceberg.expressions.Or, And) that takes two + BooleanExpressions and returns a combined BooleanExpression. + items (Sequence[BooleanExpression]): A sequence of BooleanExpression objects to combine. + + Returns: + BooleanExpression: The balanced combination of all input BooleanExpressions. + + Raises: + ValueError: If the input sequence is empty. + """ + if not items: + raise ValueError("No expressions to combine") + if len(items) == 1: + return items[0] + mid = len(items) // 2 + + left = _build_balanced_tree(operator_, items[:mid]) + right = _build_balanced_tree(operator_, items[mid:]) + return operator_(left, right) + + class Term(Generic[L], ABC): """A simple expression that evaluates to a value.""" @@ -214,7 +255,7 @@ class And(BooleanExpression): def __new__(cls, left: BooleanExpression, right: BooleanExpression, *rest: BooleanExpression) -> BooleanExpression: # type: ignore if rest: - return reduce(And, (left, right, *rest)) + return _build_balanced_tree(And, (left, right, *rest)) if left is AlwaysFalse() or right is AlwaysFalse(): return AlwaysFalse() elif left is AlwaysTrue(): @@ -257,7 +298,7 @@ class Or(BooleanExpression): def __new__(cls, left: BooleanExpression, right: BooleanExpression, *rest: BooleanExpression) -> BooleanExpression: # type: ignore if rest: - return reduce(Or, (left, right, *rest)) + return _build_balanced_tree(Or, (left, right, *rest)) if left is AlwaysTrue() or right is AlwaysTrue(): return AlwaysTrue() elif left is AlwaysFalse(): diff --git a/pyiceberg/table/upsert_util.py b/pyiceberg/table/upsert_util.py index d2bd48bc99..e67f6c0232 100644 --- a/pyiceberg/table/upsert_util.py +++ b/pyiceberg/table/upsert_util.py @@ -26,6 +26,7 @@ BooleanExpression, EqualTo, In, + Or, ) @@ -39,7 +40,12 @@ def create_match_filter(df: pyarrow_table, join_cols: list[str]) -> BooleanExpre functools.reduce(operator.and_, [EqualTo(col, row[col]) for col in join_cols]) for row in unique_keys.to_pylist() ] - return AlwaysFalse() if len(filters) == 0 else functools.reduce(operator.or_, filters) + if len(filters) == 0: + return AlwaysFalse() + elif len(filters) == 1: + return filters[0] + else: + return Or(*filters) def has_duplicate_rows(df: pyarrow_table, join_cols: list[str]) -> bool: diff --git a/tests/expressions/test_expressions.py b/tests/expressions/test_expressions.py index 12d9ff95a9..858a9ff852 100644 --- a/tests/expressions/test_expressions.py +++ b/tests/expressions/test_expressions.py @@ -591,11 +591,11 @@ def test_negate(lhs: BooleanExpression, rhs: BooleanExpression) -> None: [ ( And(ExpressionA(), ExpressionB(), ExpressionA()), - And(And(ExpressionA(), ExpressionB()), ExpressionA()), + And(ExpressionA(), And(ExpressionB(), ExpressionA())), ), ( Or(ExpressionA(), ExpressionB(), ExpressionA()), - Or(Or(ExpressionA(), ExpressionB()), ExpressionA()), + Or(ExpressionA(), Or(ExpressionB(), ExpressionA())), ), (Not(Not(ExpressionA())), ExpressionA()), ], diff --git a/tests/expressions/test_visitors.py b/tests/expressions/test_visitors.py index 94bfcf076c..586ba9f5d4 100644 --- a/tests/expressions/test_visitors.py +++ b/tests/expressions/test_visitors.py @@ -230,14 +230,14 @@ def test_boolean_expression_visitor() -> None: "NOT", "OR", "EQUALTO", - "OR", "NOTEQUALTO", "OR", + "OR", "EQUALTO", "NOT", - "AND", "NOTEQUALTO", "AND", + "AND", ] @@ -335,14 +335,14 @@ def test_always_false_or_always_true_expression_binding(table_schema_simple: Sch ), ), And( - And( - BoundIn( - BoundReference( - field=NestedField(field_id=1, name="foo", field_type=StringType(), required=False), - accessor=Accessor(position=0, inner=None), - ), - {literal("bar"), literal("baz")}, + BoundIn( + BoundReference( + field=NestedField(field_id=1, name="foo", field_type=StringType(), required=False), + accessor=Accessor(position=0, inner=None), ), + {literal("bar"), literal("baz")}, + ), + And( BoundEqualTo[int]( BoundReference( field=NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True), @@ -350,13 +350,13 @@ def test_always_false_or_always_true_expression_binding(table_schema_simple: Sch ), literal(1), ), - ), - BoundEqualTo( - BoundReference( - field=NestedField(field_id=1, name="foo", field_type=StringType(), required=False), - accessor=Accessor(position=0, inner=None), + BoundEqualTo( + BoundReference( + field=NestedField(field_id=1, name="foo", field_type=StringType(), required=False), + accessor=Accessor(position=0, inner=None), + ), + literal("baz"), ), - literal("baz"), ), ), ), @@ -408,28 +408,28 @@ def test_and_expression_binding( ), ), Or( + BoundIn( + BoundReference( + field=NestedField(field_id=1, name="foo", field_type=StringType(), required=False), + accessor=Accessor(position=0, inner=None), + ), + {literal("bar"), literal("baz")}, + ), Or( BoundIn( BoundReference( field=NestedField(field_id=1, name="foo", field_type=StringType(), required=False), accessor=Accessor(position=0, inner=None), ), - {literal("bar"), literal("baz")}, + {literal("bar")}, ), BoundIn( BoundReference( field=NestedField(field_id=1, name="foo", field_type=StringType(), required=False), accessor=Accessor(position=0, inner=None), ), - {literal("bar")}, - ), - ), - BoundIn( - BoundReference( - field=NestedField(field_id=1, name="foo", field_type=StringType(), required=False), - accessor=Accessor(position=0, inner=None), + {literal("baz")}, ), - {literal("baz")}, ), ), ), diff --git a/tests/io/test_pyarrow_visitor.py b/tests/io/test_pyarrow_visitor.py index 9f5aff3f70..9d5772d01c 100644 --- a/tests/io/test_pyarrow_visitor.py +++ b/tests/io/test_pyarrow_visitor.py @@ -836,5 +836,5 @@ def test_expression_to_complementary_pyarrow( # Notice an isNan predicate on a str column is automatically converted to always false and removed from Or and thus will not appear in the pc.expr. assert ( repr(result) - == """ 100)) or (is_nan(float_field) and (double_field == 0))) or (float_field > 100)) and invert(is_null(double_field, {nan_is_null=false})))) or is_null(float_field, {nan_is_null=false})) or is_null(string_field, {nan_is_null=false})) or is_nan(double_field))>""" + == """ 100)) or ((is_nan(float_field) and (double_field == 0)) or (float_field > 100))) and invert(is_null(double_field, {nan_is_null=false})))) or is_null(float_field, {nan_is_null=false})) or is_null(string_field, {nan_is_null=false})) or is_nan(double_field))>""" )