diff --git a/pyiceberg/expressions/iterative_visitors.py b/pyiceberg/expressions/iterative_visitors.py
new file mode 100644
index 0000000000..bb11b635ec
--- /dev/null
+++ b/pyiceberg/expressions/iterative_visitors.py
@@ -0,0 +1,130 @@
+from functools import singledispatch
+from typing import (
+    List,
+    Tuple,
+    TypeVar,
+)
+
+from pyiceberg.expressions import (
+    AlwaysFalse,
+    AlwaysTrue,
+    And,
+    BooleanExpression,
+    BoundPredicate,
+    Not,
+    Or,
+    UnboundPredicate,
+)
+from pyiceberg.expressions.visitors import BindVisitor, BooleanExpressionVisitor
+from pyiceberg.schema import Schema
+from pyiceberg.typedef import L
+
+T = TypeVar("T")
+
+
+@singledispatch
+def _visit_stack(obj: BooleanExpression, stack: List[T], visitor: BooleanExpressionVisitor[T]) -> None:
+    """Apply the visitor to one node, exchanging child results through ``stack``.
+
+    Handlers registered for Not, And and Or pop the results of their children from
+    ``stack``; every handler pushes exactly one result for the node it handles.
+
+    Raises:
+        NotImplementedError: If no handler is registered for the type of ``obj``.
+    """
+    raise NotImplementedError(f"Cannot visit unsupported expression: {obj}")
+
+
+@_visit_stack.register(AlwaysTrue)
+def _(_: AlwaysTrue, stack: List[T], visitor: BooleanExpressionVisitor[T]) -> None:
+    stack.append(visitor.visit_true())
+
+
+@_visit_stack.register(AlwaysFalse)
+def _(_: AlwaysFalse, stack: List[T], visitor: BooleanExpressionVisitor[T]) -> None:
+    stack.append(visitor.visit_false())
+
+
+@_visit_stack.register(Not)
+def _(_: Not, stack: List[T], visitor: BooleanExpressionVisitor[T]) -> None:
+    child_result = stack.pop()
+    stack.append(visitor.visit_not(child_result))
+
+
+@_visit_stack.register(And)
+def _(_: And, stack: List[T], visitor: BooleanExpressionVisitor[T]) -> None:
+    # The right operand is visited last, so its result is on top of the stack.
+    right_result = stack.pop()
+    left_result = stack.pop()
+    stack.append(visitor.visit_and(left_result, right_result))
+
+
+@_visit_stack.register(UnboundPredicate)
+def _(obj: UnboundPredicate[L], stack: List[T], visitor: BooleanExpressionVisitor[T]) -> None:
+    stack.append(visitor.visit_unbound_predicate(predicate=obj))
+
+
+@_visit_stack.register(BoundPredicate)
+def _(obj: BoundPredicate[L], stack: List[T], visitor: BooleanExpressionVisitor[T]) -> None:
+    stack.append(visitor.visit_bound_predicate(predicate=obj))
+
+
+@_visit_stack.register(Or)
+def _(_: Or, stack: List[T], visitor: BooleanExpressionVisitor[T]) -> None:
+    # The right operand is visited last, so its result is on top of the stack.
+    right_result = stack.pop()
+    left_result = stack.pop()
+    stack.append(visitor.visit_or(left_result, right_result))
+
+
+def visit_iterative(expression: BooleanExpression, visitor: BooleanExpressionVisitor[T]) -> T:
+    """Visit a boolean expression with an explicit stack, children before their parent.
+
+    Args:
+        expression (BooleanExpression): The expression to traverse.
+        visitor (BooleanExpressionVisitor[T]): The visitor applied to every node of the expression.
+
+    Returns:
+        T: The result the visitor produced for the root of the expression.
+
+    Raises:
+        NotImplementedError: If the expression contains a node type without a registered handler.
+    """
+    # Pending (node, children_scheduled) pairs. A Not/And/Or node is pushed back with
+    # children_scheduled=True so it is handled after the results of its children exist.
+    stack: List[Tuple[BooleanExpression, bool]] = [(expression, False)]
+    # Results of visited nodes that have not been consumed by their parent yet.
+    results_stack: List[T] = []
+
+    while stack:
+        node, children_scheduled = stack.pop()
+        if not children_scheduled and isinstance(node, Not):
+            stack.append((node, True))
+            stack.append((node.child, False))
+        elif not children_scheduled and isinstance(node, (And, Or)):
+            stack.append((node, True))
+            # Push right before left so that left is popped, and thus visited, first.
+            stack.append((node.right, False))
+            stack.append((node.left, False))
+        else:
+            # Any other node, or a Not/And/Or node whose child results are already on results_stack.
+            _visit_stack(node, results_stack, visitor)
+
+    return results_stack.pop()
+
+
+def bind_iterative(schema: Schema, expression: BooleanExpression, case_sensitive: bool) -> BooleanExpression:
+    """Traverse iteratively over an expression to bind the predicates to the schema.
+
+    Args:
+        schema (Schema): A schema to use when binding the expression.
+        expression (BooleanExpression): An expression containing UnboundPredicates that can be bound.
+        case_sensitive (bool): Whether to consider case when binding a reference to a field in a schema.
+
+    Returns:
+        BooleanExpression: The result of visiting the expression with a BindVisitor for the schema.
+
+    Raises:
+        TypeError: In the case a predicate is already bound.
+    """
+    return visit_iterative(expression, BindVisitor(schema, case_sensitive))
diff --git a/pyiceberg/expressions/visitors.py b/pyiceberg/expressions/visitors.py
index abac19bc19..912a9bb1de 100644
--- a/pyiceberg/expressions/visitors.py
+++ b/pyiceberg/expressions/visitors.py
@@ -201,7 +201,7 @@ def _(obj: Or, visitor: BooleanExpressionVisitor[T]) -> T:
 
 
 def bind(schema: Schema, expression: BooleanExpression, case_sensitive: bool) -> BooleanExpression:
-    """Travers over an expression to bind the predicates to the schema.
+    """Traverse over an expression to bind the predicates to the schema.
 
     Args:
         schema (Schema): A schema to use when binding the expression.
diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py
index 7c8aaaab1b..2b35fcfdfa 100644
--- a/pyiceberg/io/pyarrow.py
+++ b/pyiceberg/io/pyarrow.py
@@ -76,10 +76,10 @@
 from pyiceberg.conversions import to_bytes
 from pyiceberg.exceptions import ResolveError
 from pyiceberg.expressions import AlwaysTrue, BooleanExpression, BoundIsNaN, BoundIsNull, BoundTerm, Not, Or
+from pyiceberg.expressions.iterative_visitors import bind_iterative
 from pyiceberg.expressions.literals import Literal
 from pyiceberg.expressions.visitors import (
     BoundBooleanExpressionVisitor,
-    bind,
     extract_field_ids,
     translate_column_names,
 )
@@ -1376,7 +1376,7 @@ def _task_to_record_batches(
         pyarrow_filter = None
         if bound_row_filter is not AlwaysTrue():
             translated_row_filter = translate_column_names(bound_row_filter, file_schema, case_sensitive=case_sensitive)
-            bound_file_filter = bind(file_schema, translated_row_filter, case_sensitive=case_sensitive)
+            bound_file_filter = bind_iterative(file_schema, translated_row_filter, case_sensitive=case_sensitive)
             pyarrow_filter = expression_to_pyarrow(bound_file_filter)
 
         # Apply column projection rules
@@ -1512,7 +1512,7 @@ def __init__(
         self._table_metadata = table_metadata
         self._io = io
         self._projected_schema = projected_schema
-        self._bound_row_filter = bind(table_metadata.schema(), row_filter, case_sensitive=case_sensitive)
+        self._bound_row_filter = bind_iterative(table_metadata.schema(), row_filter, case_sensitive=case_sensitive)
         self._case_sensitive = case_sensitive
         self._limit = limit
 
diff --git a/tests/expressions/test_visitors.py b/tests/expressions/test_visitors.py
index 94bfcf076c..795b451471 100644
--- a/tests/expressions/test_visitors.py
+++ b/tests/expressions/test_visitors.py
@@ -62,6 +62,9 @@
     StartsWith,
     UnboundPredicate,
 )
+
+# N.B. Just to test visit_iterative:
+from pyiceberg.expressions.iterative_visitors import visit_iterative as visit
 from pyiceberg.expressions.literals import Literal, literal
 from pyiceberg.expressions.visitors import (
     BindVisitor,
@@ -72,7 +75,6 @@
     expression_to_plain_format,
     rewrite_not,
     rewrite_to_dnf,
-    visit,
     visit_bound_predicate,
 )
 from pyiceberg.manifest import ManifestFile, PartitionFieldSummary
diff --git a/tests/table/test_upsert.py b/tests/table/test_upsert.py
index 7f9e13b5a1..e79de81310 100644
--- a/tests/table/test_upsert.py
+++ b/tests/table/test_upsert.py
@@ -408,6 +408,31 @@ def test_upsert_into_empty_table(catalog: Catalog) -> None:
     assert upd.rows_inserted == 4
 
 
+def test_large_upsert_into_empty_table(catalog: Catalog) -> None:
+    identifier = "default.test_upsert_large_table"
+    _drop_table(catalog, identifier)
+
+    num_columns = 50
+    num_rows = 10000
+
+    schema = Schema(
+        *[NestedField(i, f"field_{i}", StringType(), required=True) for i in range(1, num_columns + 1)],
+        identifier_field_ids=[1, 2],
+    )
+
+    tbl = catalog.create_table(identifier, schema=schema)
+
+    arrow_schema = pa.schema([pa.field(f"field_{i}", pa.string(), nullable=False) for i in range(1, num_columns + 1)])
+
+    data = [{f"field_{i}": f"value_{i}_{j}" for i in range(1, num_columns + 1)} for j in range(num_rows)]
+
+    df = pa.Table.from_pylist(data, schema=arrow_schema)
+    upd = tbl.upsert(df)
+
+    assert upd.rows_updated == 0
+    assert upd.rows_inserted == num_rows
+
+
 def test_create_match_filter_single_condition() -> None:
     """
     Test create_match_filter with a composite key where the source yields exactly one unique key.