Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 103 additions & 0 deletions pyiceberg/expressions/iterative_visitors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
from functools import singledispatch
from typing import (
List,
Tuple,
TypeVar,
)

from pyiceberg.expressions import (
AlwaysFalse,
AlwaysTrue,
And,
BooleanExpression,
BoundPredicate,
Not,
Or,
UnboundPredicate,
)
from pyiceberg.expressions.visitors import BindVisitor, BooleanExpressionVisitor
from pyiceberg.schema import Schema
from pyiceberg.typedef import L

T = TypeVar("T")


@singledispatch
def _visit_stack(obj: BooleanExpression, stack: List[T], visitor: BooleanExpressionVisitor[T]) -> None:
raise NotImplementedError(f"Cannot visit unsupported expression: {obj}")


@_visit_stack.register(AlwaysTrue)
def _(_: AlwaysTrue, stack: List[T], visitor: BooleanExpressionVisitor[T]) -> None:
stack.append(visitor.visit_true())


@_visit_stack.register(AlwaysFalse)
def _(_: AlwaysFalse, stack: List[T], visitor: BooleanExpressionVisitor[T]) -> None:
stack.append(visitor.visit_false())


@_visit_stack.register(Not)
def _(_: Not, stack: List[T], visitor: BooleanExpressionVisitor[T]) -> None:
child_result = stack.pop()
stack.append(visitor.visit_not(child_result))


@_visit_stack.register(And)
def _(_: And, stack: List[T], visitor: BooleanExpressionVisitor[T]) -> None:
right_result = stack.pop()
left_result = stack.pop()
stack.append(visitor.visit_and(left_result, right_result))


@_visit_stack.register(UnboundPredicate)
def _(obj: UnboundPredicate[L], stack: List[T], visitor: BooleanExpressionVisitor[T]) -> None:
stack.append(visitor.visit_unbound_predicate(predicate=obj))


@_visit_stack.register(BoundPredicate)
def _(obj: BoundPredicate[L], stack: List[T], visitor: BooleanExpressionVisitor[T]) -> None:
stack.append(visitor.visit_bound_predicate(predicate=obj))


@_visit_stack.register(Or)
def _(_: Or, stack: List[T], visitor: BooleanExpressionVisitor[T]) -> None:
right_result = stack.pop()
left_result = stack.pop()
stack.append(visitor.visit_or(left_result, right_result))


def visit_iterative(expression: BooleanExpression, visitor: BooleanExpressionVisitor[T]) -> T:
# Store (node, visited) pairs in the stack of expressions to process
stack: List[Tuple[BooleanExpression, bool]] = [(expression, False)]
# Store the results of the visit in another stack
results_stack: List[T] = []

while stack:
node, visited = stack.pop()
if not visited:
stack.append((node, True))
# TODO: Make this nicer.
if isinstance(node, Not):
stack.append((node.child, False))
elif isinstance(node, And) or isinstance(node, Or):
stack.append((node.right, False))
stack.append((node.left, False))
else:
_visit_stack(node, results_stack, visitor)

return results_stack.pop()


def bind_iterative(schema: Schema, expression: BooleanExpression, case_sensitive: bool) -> BooleanExpression:
"""Traverse iteratively over an expression to bind the predicates to the schema.

Args:
schema (Schema): A schema to use when binding the expression.
expression (BooleanExpression): An expression containing UnboundPredicates that can be bound.
case_sensitive (bool): Whether to consider case when binding a reference to a field in a schema, defaults to True.

Raises:
TypeError: In the case a predicate is already bound.
"""
return visit_iterative(expression, BindVisitor(schema, case_sensitive))
2 changes: 1 addition & 1 deletion pyiceberg/expressions/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def _(obj: Or, visitor: BooleanExpressionVisitor[T]) -> T:


def bind(schema: Schema, expression: BooleanExpression, case_sensitive: bool) -> BooleanExpression:
"""Travers over an expression to bind the predicates to the schema.
"""Traverse over an expression to bind the predicates to the schema.

Args:
schema (Schema): A schema to use when binding the expression.
Expand Down
6 changes: 3 additions & 3 deletions pyiceberg/io/pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,10 @@
from pyiceberg.conversions import to_bytes
from pyiceberg.exceptions import ResolveError
from pyiceberg.expressions import AlwaysTrue, BooleanExpression, BoundIsNaN, BoundIsNull, BoundTerm, Not, Or
from pyiceberg.expressions.iterative_visitors import bind_iterative
from pyiceberg.expressions.literals import Literal
from pyiceberg.expressions.visitors import (
BoundBooleanExpressionVisitor,
bind,
extract_field_ids,
translate_column_names,
)
Expand Down Expand Up @@ -1376,7 +1376,7 @@ def _task_to_record_batches(
pyarrow_filter = None
if bound_row_filter is not AlwaysTrue():
translated_row_filter = translate_column_names(bound_row_filter, file_schema, case_sensitive=case_sensitive)
bound_file_filter = bind(file_schema, translated_row_filter, case_sensitive=case_sensitive)
bound_file_filter = bind_iterative(file_schema, translated_row_filter, case_sensitive=case_sensitive)
pyarrow_filter = expression_to_pyarrow(bound_file_filter)

# Apply column projection rules
Expand Down Expand Up @@ -1512,7 +1512,7 @@ def __init__(
self._table_metadata = table_metadata
self._io = io
self._projected_schema = projected_schema
self._bound_row_filter = bind(table_metadata.schema(), row_filter, case_sensitive=case_sensitive)
self._bound_row_filter = bind_iterative(table_metadata.schema(), row_filter, case_sensitive=case_sensitive)
self._case_sensitive = case_sensitive
self._limit = limit

Expand Down
4 changes: 3 additions & 1 deletion tests/expressions/test_visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@
StartsWith,
UnboundPredicate,
)

# N.B. Just to test visit_iterative:
from pyiceberg.expressions.iterative_visitors import visit_iterative as visit
from pyiceberg.expressions.literals import Literal, literal
from pyiceberg.expressions.visitors import (
BindVisitor,
Expand All @@ -72,7 +75,6 @@
expression_to_plain_format,
rewrite_not,
rewrite_to_dnf,
visit,
visit_bound_predicate,
)
from pyiceberg.manifest import ManifestFile, PartitionFieldSummary
Expand Down
25 changes: 25 additions & 0 deletions tests/table/test_upsert.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,31 @@ def test_upsert_into_empty_table(catalog: Catalog) -> None:
assert upd.rows_inserted == 4


def test_large_upsert_into_empty_table(catalog: Catalog) -> None:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This now-passing test fails on main with

    def __getitem__(self, key):
>       return self.data[ref(key)]
E       RecursionError: maximum recursion depth exceeded in comparison

/<PATH>/weakref.py:416: RecursionError
!!! Recursion error detected, but an error occurred locating the origin of recursion.
  The following exception happened when comparing locals in the stack frame:
    RecursionError: maximum recursion depth exceeded
  Displaying first and last 10 stack frames out of 962.

identifier = "default.test_upsert_large_table"
_drop_table(catalog, identifier)

num_columns = 50
num_rows = 10000
Comment on lines +415 to +416
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, 20 and 1000 is enough to make main fail this test with (default) recursion depth exceeded


schema = Schema(
*[NestedField(i, f"field_{i}", StringType(), required=True) for i in range(1, num_columns + 1)],
identifier_field_ids=[1, 2],
)

tbl = catalog.create_table(identifier, schema=schema)

arrow_schema = pa.schema([pa.field(f"field_{i}", pa.string(), nullable=False) for i in range(1, num_columns + 1)])

data = [{f"field_{i}": f"value_{i}_{j}" for i in range(1, num_columns + 1)} for j in range(num_rows)]

df = pa.Table.from_pylist(data, schema=arrow_schema)
upd = tbl.upsert(df)

assert upd.rows_updated == 0
assert upd.rows_inserted == num_rows


def test_create_match_filter_single_condition() -> None:
"""
Test create_match_filter with a composite key where the source yields exactly one unique key.
Expand Down
Loading