Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 44 additions & 3 deletions pyiceberg/expressions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from functools import cached_property, reduce
from functools import cached_property
from typing import (
Any,
Callable,
Generic,
Iterable,
Sequence,
Set,
Tuple,
Type,
Expand Down Expand Up @@ -79,6 +81,45 @@ def __or__(self, other: BooleanExpression) -> BooleanExpression:
return Or(self, other)


def _build_balanced_tree(
operator_: Callable[[BooleanExpression, BooleanExpression], BooleanExpression], items: Sequence[BooleanExpression]
) -> BooleanExpression:
"""
Recursively constructs a balanced binary tree of BooleanExpressions using the provided binary operator.

This function is a safer and more scalable alternative to:
reduce(operator_, items)

Using `reduce` creates a deeply nested, unbalanced tree (e.g., operator_(a, operator_(b, operator_(c, ...)))),
which grows linearly with the number of items. This can lead to RecursionError exceptions in Python
when the number of expressions is large (e.g., >1000).

In contrast, this function builds a balanced binary tree with logarithmic depth (O(log n)),
helping avoid recursion issues and ensuring that expression trees remain stable, predictable,
and safe to traverse — especially in tools like PyIceberg that operate on large logical trees.

Parameters:
operator_ (Callable): A binary operator function (e.g., pyiceberg.expressions.Or, And) that takes two
BooleanExpressions and returns a combined BooleanExpression.
items (Sequence[BooleanExpression]): A sequence of BooleanExpression objects to combine.

Returns:
BooleanExpression: The balanced combination of all input BooleanExpressions.

Raises:
ValueError: If the input sequence is empty.
"""
if not items:
raise ValueError("No expressions to combine")
if len(items) == 1:
return items[0]
mid = len(items) // 2

left = _build_balanced_tree(operator_, items[:mid])
right = _build_balanced_tree(operator_, items[mid:])
return operator_(left, right)


class Term(Generic[L], ABC):
"""A simple expression that evaluates to a value."""

Expand Down Expand Up @@ -214,7 +255,7 @@ class And(BooleanExpression):

def __new__(cls, left: BooleanExpression, right: BooleanExpression, *rest: BooleanExpression) -> BooleanExpression: # type: ignore
if rest:
return reduce(And, (left, right, *rest))
return _build_balanced_tree(And, (left, right, *rest))
if left is AlwaysFalse() or right is AlwaysFalse():
return AlwaysFalse()
elif left is AlwaysTrue():
Expand Down Expand Up @@ -257,7 +298,7 @@ class Or(BooleanExpression):

def __new__(cls, left: BooleanExpression, right: BooleanExpression, *rest: BooleanExpression) -> BooleanExpression: # type: ignore
if rest:
return reduce(Or, (left, right, *rest))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we also want to apply this in the And situation :)

return _build_balanced_tree(Or, (left, right, *rest))
if left is AlwaysTrue() or right is AlwaysTrue():
return AlwaysTrue()
elif left is AlwaysFalse():
Expand Down
8 changes: 7 additions & 1 deletion pyiceberg/table/upsert_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
BooleanExpression,
EqualTo,
In,
Or,
)


Expand All @@ -39,7 +40,12 @@ def create_match_filter(df: pyarrow_table, join_cols: list[str]) -> BooleanExpre
functools.reduce(operator.and_, [EqualTo(col, row[col]) for col in join_cols]) for row in unique_keys.to_pylist()
]

return AlwaysFalse() if len(filters) == 0 else functools.reduce(operator.or_, filters)
if len(filters) == 0:
return AlwaysFalse()
elif len(filters) == 1:
return filters[0]
else:
return Or(*filters)


def has_duplicate_rows(df: pyarrow_table, join_cols: list[str]) -> bool:
Expand Down
4 changes: 2 additions & 2 deletions tests/expressions/test_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,11 +591,11 @@ def test_negate(lhs: BooleanExpression, rhs: BooleanExpression) -> None:
[
(
And(ExpressionA(), ExpressionB(), ExpressionA()),
And(And(ExpressionA(), ExpressionB()), ExpressionA()),
And(ExpressionA(), And(ExpressionB(), ExpressionA())),
),
(
Or(ExpressionA(), ExpressionB(), ExpressionA()),
Or(Or(ExpressionA(), ExpressionB()), ExpressionA()),
Or(ExpressionA(), Or(ExpressionB(), ExpressionA())),
),
(Not(Not(ExpressionA())), ExpressionA()),
],
Expand Down
48 changes: 24 additions & 24 deletions tests/expressions/test_visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,14 +230,14 @@ def test_boolean_expression_visitor() -> None:
"NOT",
"OR",
"EQUALTO",
"OR",
"NOTEQUALTO",
"OR",
"OR",
"EQUALTO",
"NOT",
"AND",
"NOTEQUALTO",
"AND",
"AND",
]


Expand Down Expand Up @@ -335,28 +335,28 @@ def test_always_false_or_always_true_expression_binding(table_schema_simple: Sch
),
),
And(
And(
BoundIn(
BoundReference(
field=NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
accessor=Accessor(position=0, inner=None),
),
{literal("bar"), literal("baz")},
BoundIn(
BoundReference(
field=NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
accessor=Accessor(position=0, inner=None),
),
{literal("bar"), literal("baz")},
),
And(
BoundEqualTo[int](
BoundReference(
field=NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True),
accessor=Accessor(position=1, inner=None),
),
literal(1),
),
),
BoundEqualTo(
BoundReference(
field=NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
accessor=Accessor(position=0, inner=None),
BoundEqualTo(
BoundReference(
field=NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
accessor=Accessor(position=0, inner=None),
),
literal("baz"),
),
literal("baz"),
),
),
),
Expand Down Expand Up @@ -408,28 +408,28 @@ def test_and_expression_binding(
),
),
Or(
BoundIn(
BoundReference(
field=NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
accessor=Accessor(position=0, inner=None),
),
{literal("bar"), literal("baz")},
),
Or(
BoundIn(
BoundReference(
field=NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
accessor=Accessor(position=0, inner=None),
),
{literal("bar"), literal("baz")},
{literal("bar")},
),
BoundIn(
BoundReference(
field=NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
accessor=Accessor(position=0, inner=None),
),
{literal("bar")},
),
),
BoundIn(
BoundReference(
field=NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
accessor=Accessor(position=0, inner=None),
{literal("baz")},
),
{literal("baz")},
),
),
),
Expand Down
2 changes: 1 addition & 1 deletion tests/io/test_pyarrow_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,5 +836,5 @@ def test_expression_to_complementary_pyarrow(
# Notice an isNan predicate on a str column is automatically converted to always false and removed from Or and thus will not appear in the pc.expr.
assert (
repr(result)
== """<pyarrow.compute.Expression (((invert((((((string_field == "hello") and (float_field > 100)) or (is_nan(float_field) and (double_field == 0))) or (float_field > 100)) and invert(is_null(double_field, {nan_is_null=false})))) or is_null(float_field, {nan_is_null=false})) or is_null(string_field, {nan_is_null=false})) or is_nan(double_field))>"""
== """<pyarrow.compute.Expression (((invert(((((string_field == "hello") and (float_field > 100)) or ((is_nan(float_field) and (double_field == 0)) or (float_field > 100))) and invert(is_null(double_field, {nan_is_null=false})))) or is_null(float_field, {nan_is_null=false})) or is_null(string_field, {nan_is_null=false})) or is_nan(double_field))>"""
)