Skip to content

Commit e57b252

Browse files
committed
Use a balanced tree instead of an unbalanced one to prevent a RecursionError in create_match_filter
1 parent c06e320 commit e57b252

File tree

1 file changed

+39
-1
lines changed

1 file changed

+39
-1
lines changed

pyiceberg/table/upsert_util.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
# under the License.
1717
import functools
1818
import operator
19+
from typing import Callable, List, TypeVar
1920

2021
import pyarrow as pa
2122
from pyarrow import Table as pyarrow_table
@@ -28,6 +29,43 @@
2829
In,
2930
)
3031

32+
T = TypeVar("T")


def build_balanced_tree(operator_: Callable[[T, T], T], items: List[T]) -> T:
    """
    Combine ``items`` into a balanced binary tree using the binary ``operator_``.

    This function is a safer and more scalable alternative to:
        reduce(operator_, items)

    ``reduce`` folds from the left, producing a deeply nested, unbalanced tree
    (e.g., operator_(operator_(operator_(a, b), c), ...)) whose depth grows linearly
    with the number of items. Recursively traversing such a tree can lead to
    RecursionError exceptions in Python when the number of expressions is large
    (e.g., >1000).

    In contrast, this function splits the input in half at every level, so the
    resulting tree has logarithmic depth (O(log n)). The left half receives
    ``len // 2`` items and the right half receives the remainder; item order is
    preserved from left to right.

    Parameters:
        operator_ (Callable[[T, T], T]): A binary operator function (e.g., pyiceberg.expressions.Or, And).
        items (List[T]): A list of expression objects to combine. The list is not modified.

    Returns:
        T: The single item itself when ``items`` has one element, otherwise the
        balanced combination of all input items.

    Raises:
        ValueError: If the input list is empty.
    """
    if not items:
        raise ValueError("No expressions to combine")

    def _combine(start: int, stop: int) -> T:
        # Combine items[start:stop] by index bounds instead of slicing, so no
        # temporary sub-lists are allocated at each recursion level.
        if stop - start == 1:
            return items[start]
        mid = start + (stop - start) // 2
        left = _combine(start, mid)
        right = _combine(mid, stop)
        return operator_(left, right)

    return _combine(0, len(items))
68+
3169

3270
def create_match_filter(df: pyarrow_table, join_cols: list[str]) -> BooleanExpression:
3371
unique_keys = df.select(join_cols).group_by(join_cols).aggregate([])
@@ -39,7 +77,7 @@ def create_match_filter(df: pyarrow_table, join_cols: list[str]) -> BooleanExpre
3977
functools.reduce(operator.and_, [EqualTo(col, row[col]) for col in join_cols]) for row in unique_keys.to_pylist()
4078
]
4179

42-
return AlwaysFalse() if len(filters) == 0 else functools.reduce(operator.or_, filters)
80+
return AlwaysFalse() if len(filters) == 0 else build_balanced_tree(operator.or_, filters)
4381

4482

4583
def has_duplicate_rows(df: pyarrow_table, join_cols: list[str]) -> bool:

0 commit comments

Comments
 (0)