Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 54 additions & 15 deletions pyiceberg/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,36 @@ def primitive(self, primitive: PrimitiveType) -> T:


class PreOrderSchemaVisitor(Generic[T], ABC):
def before_field(self, field: NestedField) -> None:
    """Hook for running an action right before a field is visited.

    The base implementation does nothing; subclasses override it when they need
    to react to a field being entered.

    Args:
        field (NestedField): The field that is about to be visited.
    """

def after_field(self, field: NestedField) -> None:
    """Hook for running an action right after a field has been visited.

    The base implementation does nothing; subclasses override it when they need
    to react to a field being left.

    Args:
        field (NestedField): The field that has just been visited.
    """

def before_list_element(self, element: NestedField) -> None:
    """Hook for running an action right before the element of a ListType is visited.

    Unless overridden, the element is treated like any other field and the call
    is forwarded to ``before_field``.

    Args:
        element (NestedField): The list's element field that is about to be visited.
    """
    self.before_field(element)

def after_list_element(self, element: NestedField) -> None:
    """Hook for running an action right after the element of a ListType has been visited.

    Unless overridden, the element is treated like any other field and the call
    is forwarded to ``after_field``.

    Args:
        element (NestedField): The list's element field that has just been visited.
    """
    self.after_field(element)

def before_map_key(self, key: NestedField) -> None:
    """Hook for running an action right before the key of a MapType is visited.

    Unless overridden, the key is treated like any other field and the call is
    forwarded to ``before_field``.

    Args:
        key (NestedField): The map's key field that is about to be visited.
    """
    self.before_field(key)

def after_map_key(self, key: NestedField) -> None:
    """Hook for running an action right after the key of a MapType has been visited.

    Unless overridden, the key is treated like any other field and the call is
    forwarded to ``after_field``.

    Args:
        key (NestedField): The map's key field that has just been visited.
    """
    self.after_field(key)

def before_map_value(self, value: NestedField) -> None:
    """Hook for running an action right before the value of a MapType is visited.

    Unless overridden, the value is treated like any other field and the call is
    forwarded to ``before_field``.

    Args:
        value (NestedField): The map's value field that is about to be visited.
    """
    self.before_field(value)

def after_map_value(self, value: NestedField) -> None:
    """Hook for running an action right after the value of a MapType has been visited.

    Unless overridden, the value is treated like any other field and the call is
    forwarded to ``after_field``.

    Args:
        value (NestedField): The map's value field that has just been visited.
    """
    self.after_field(value)

@abstractmethod
def schema(self, schema: Schema, struct_result: Callable[[], T]) -> T:
    """Visit a Schema.

    Args:
        schema (Schema): The schema being visited.
        struct_result (Callable[[], T]): Zero-argument callable that produces a result of type T;
            presumably the deferred result for the schema's top-level struct (confirm against
            the Schema dispatch of pre_order_visit, which is not fully visible here).

    Returns:
        T: The visitor's result for the schema.
    """
Expand Down Expand Up @@ -851,9 +881,7 @@ def _(obj: PrimitiveType, visitor: SchemaVisitor[T]) -> T:
def pre_order_visit(obj: Union[Schema, IcebergType], visitor: PreOrderSchemaVisitor[T]) -> T:
"""Apply a schema visitor to any point within a schema.

The function traverses the schema in pre-order fashion. It started out as a slimmed down
version of the post-order traversal (without before and after methods), mostly because
the pre-order traversal is not used much; the before and after hooks are now declared on
PreOrderSchemaVisitor as well.

Args:
obj (Union[Schema, IcebergType]): An instance of a Schema or an IcebergType.
Expand All @@ -874,28 +902,39 @@ def _(obj: Schema, visitor: PreOrderSchemaVisitor[T]) -> T:
@pre_order_visit.register(StructType)
def _(obj: StructType, visitor: PreOrderSchemaVisitor[T]) -> T:
    """Visit a StructType with a concrete PreOrderSchemaVisitor.

    Each field is handed to ``visitor.struct`` as a zero-argument callable, so the
    visitor decides if and when a child is traversed (this is what makes the
    traversal pre-order). The ``before_field`` / ``after_field`` hooks run inside
    the deferred callable, immediately around the traversal of the field's type,
    so they only fire for children the visitor actually descends into.

    Args:
        obj (StructType): The struct whose fields are offered to the visitor.
        visitor (PreOrderSchemaVisitor[T]): The visitor that receives the struct and its deferred field results.

    Returns:
        T: Whatever ``visitor.struct`` returns.
    """

    def visit_field_type(field: NestedField) -> T:
        # Hooks bracket the child traversal itself, not the visitor.field call.
        visitor.before_field(field)
        result = pre_order_visit(field.field_type, visitor)
        visitor.after_field(field)
        return result

    def visit_field(field: NestedField) -> T:
        return visitor.field(field, partial(visit_field_type, field))

    # partial binds each field eagerly, avoiding the late-binding closure pitfall.
    return visitor.struct(obj, [partial(visit_field, field) for field in obj.fields])


@pre_order_visit.register(ListType)
def _(obj: ListType, visitor: PreOrderSchemaVisitor[T]) -> T:
    """Visit a ListType with a concrete PreOrderSchemaVisitor.

    The element result is handed to ``visitor.list`` as a zero-argument callable,
    so the visitor decides if and when the element type is traversed. The
    ``before_list_element`` / ``after_list_element`` hooks run inside that
    callable, immediately around the traversal of the element type.

    Args:
        obj (ListType): The list type being visited.
        visitor (PreOrderSchemaVisitor[T]): The visitor that receives the list and its deferred element result.

    Returns:
        T: Whatever ``visitor.list`` returns.
    """

    def visit_element() -> T:
        visitor.before_list_element(obj.element_field)
        result = pre_order_visit(obj.element_type, visitor)
        visitor.after_list_element(obj.element_field)
        return result

    return visitor.list(obj, visit_element)


@pre_order_visit.register(MapType)
def _(obj: MapType, visitor: PreOrderSchemaVisitor[T]) -> T:
    """Visit a MapType with a concrete PreOrderSchemaVisitor.

    The key and value results are handed to ``visitor.map`` as zero-argument
    callables, so the visitor decides if, when and in which order the key and
    value types are traversed. The ``before_map_key`` / ``after_map_key`` and
    ``before_map_value`` / ``after_map_value`` hooks run inside the respective
    callable, immediately around the traversal of the key or value type.

    Args:
        obj (MapType): The map type being visited.
        visitor (PreOrderSchemaVisitor[T]): The visitor that receives the map and its deferred key/value results.

    Returns:
        T: Whatever ``visitor.map`` returns.
    """

    def visit_key() -> T:
        visitor.before_map_key(obj.key_field)
        result = pre_order_visit(obj.key_type, visitor)
        visitor.after_map_key(obj.key_field)
        return result

    def visit_value() -> T:
        visitor.before_map_value(obj.value_field)
        result = pre_order_visit(obj.value_type, visitor)
        visitor.after_map_value(obj.value_field)
        return result

    return visitor.map(obj, visit_key, visit_value)


@pre_order_visit.register(PrimitiveType)
Expand Down