Skip to content

Commit 0e5dc2e

Browse files
Yftach Zur and claude committed
Fix: Support nested struct field filtering with PyArrow (#953)
Fixes filtering on nested struct fields when using PyArrow for scan operations.

## Problem

When filtering on nested struct fields (e.g., `mazeMetadata.run_id == 'value'`), PyArrow would fail with:

```
ArrowInvalid: No match for FieldRef.Name(run_id) in ...
```

The issue occurred because PyArrow requires nested field references as tuples (e.g., `("parent", "child")`) rather than dotted strings (e.g., `"parent.child"`).

## Solution

1. Modified `_ConvertToArrowExpression` to accept an optional `Schema` parameter
2. Added `_get_field_name()` method that converts dotted field paths to tuples for nested struct fields
3. Updated `expression_to_pyarrow()` to accept and pass the schema parameter
4. Updated all call sites to pass the schema when available

## Changes

- `pyiceberg/io/pyarrow.py`:
  - Modified `_ConvertToArrowExpression` class to handle nested field paths
  - Updated `expression_to_pyarrow()` signature to accept schema
  - Updated `_expression_to_complementary_pyarrow()` signature
- `pyiceberg/table/__init__.py`:
  - Updated call to `_expression_to_complementary_pyarrow()` to pass schema
- Tests:
  - Added `test_ref_binding_nested_struct_field()` for comprehensive nested field testing
  - Enhanced `test_nested_fields()` with issue #953 scenarios

## Example

```python
# Now works correctly:
table.scan(row_filter="mazeMetadata.run_id == 'abc123'").to_polars()
```

The fix converts the field reference from:

- ❌ `FieldRef.Name(run_id)` (fails - field not found)
- ✅ `FieldRef.Nested(FieldRef.Name(mazeMetadata) FieldRef.Name(run_id))` (works!)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 2a9f2ea commit 0e5dc2e

File tree

4 files changed

+117
-20
lines changed

4 files changed

+117
-20
lines changed

pyiceberg/io/pyarrow.py

Lines changed: 61 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -810,51 +810,83 @@ def _convert_scalar(value: Any, iceberg_type: IcebergType) -> pa.scalar:
810810

811811

812812
class _ConvertToArrowExpression(BoundBooleanExpressionVisitor[pc.Expression]):
813+
"""Convert Iceberg bound expressions to PyArrow expressions.
814+
815+
Args:
816+
schema: Optional Iceberg schema to resolve full field paths for nested fields.
817+
If not provided, only the field name will be used (not dotted path).
818+
"""
819+
820+
_schema: Optional[Schema]
821+
822+
def __init__(self, schema: Optional[Schema] = None):
823+
self._schema = schema
824+
825+
def _get_field_name(self, term: BoundTerm[Any]) -> Union[str, Tuple[str, ...]]:
826+
"""Get the field name or nested field path for a bound term.
827+
828+
For nested struct fields, returns a tuple of field names (e.g., ("mazeMetadata", "run_id")).
829+
For top-level fields, returns just the field name as a string.
830+
831+
PyArrow requires nested field references as tuples, not dotted strings.
832+
"""
833+
if self._schema is not None:
834+
# Use the schema to get the full dotted path for nested fields
835+
full_name = self._schema.find_column_name(term.ref().field.field_id)
836+
if full_name is not None:
837+
# A dot in the full name is treated as a nesting separator (NOTE: a column whose own name contains a literal dot would be split as well)
838+
# Convert "parent.child" to ("parent", "child") for PyArrow
839+
if '.' in full_name:
840+
return tuple(full_name.split('.'))
841+
return full_name
842+
# Fall back to just the field name if no schema was given or the schema has no name for this field ID
843+
return term.ref().field.name
844+
813845
def visit_in(self, term: BoundTerm[Any], literals: Set[Any]) -> pc.Expression:
814846
pyarrow_literals = pa.array(literals, type=schema_to_pyarrow(term.ref().field.field_type))
815-
return pc.field(term.ref().field.name).isin(pyarrow_literals)
847+
return pc.field(self._get_field_name(term)).isin(pyarrow_literals)
816848

817849
def visit_not_in(self, term: BoundTerm[Any], literals: Set[Any]) -> pc.Expression:
818850
pyarrow_literals = pa.array(literals, type=schema_to_pyarrow(term.ref().field.field_type))
819-
return ~pc.field(term.ref().field.name).isin(pyarrow_literals)
851+
return ~pc.field(self._get_field_name(term)).isin(pyarrow_literals)
820852

821853
def visit_is_nan(self, term: BoundTerm[Any]) -> pc.Expression:
822-
ref = pc.field(term.ref().field.name)
854+
ref = pc.field(self._get_field_name(term))
823855
return pc.is_nan(ref)
824856

825857
def visit_not_nan(self, term: BoundTerm[Any]) -> pc.Expression:
826-
ref = pc.field(term.ref().field.name)
858+
ref = pc.field(self._get_field_name(term))
827859
return ~pc.is_nan(ref)
828860

829861
def visit_is_null(self, term: BoundTerm[Any]) -> pc.Expression:
830-
return pc.field(term.ref().field.name).is_null(nan_is_null=False)
862+
return pc.field(self._get_field_name(term)).is_null(nan_is_null=False)
831863

832864
def visit_not_null(self, term: BoundTerm[Any]) -> pc.Expression:
833-
return pc.field(term.ref().field.name).is_valid()
865+
return pc.field(self._get_field_name(term)).is_valid()
834866

835867
def visit_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression:
836-
return pc.field(term.ref().field.name) == _convert_scalar(literal.value, term.ref().field.field_type)
868+
return pc.field(self._get_field_name(term)) == _convert_scalar(literal.value, term.ref().field.field_type)
837869

838870
def visit_not_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression:
839-
return pc.field(term.ref().field.name) != _convert_scalar(literal.value, term.ref().field.field_type)
871+
return pc.field(self._get_field_name(term)) != _convert_scalar(literal.value, term.ref().field.field_type)
840872

841873
def visit_greater_than_or_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression:
842-
return pc.field(term.ref().field.name) >= _convert_scalar(literal.value, term.ref().field.field_type)
874+
return pc.field(self._get_field_name(term)) >= _convert_scalar(literal.value, term.ref().field.field_type)
843875

844876
def visit_greater_than(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression:
845-
return pc.field(term.ref().field.name) > _convert_scalar(literal.value, term.ref().field.field_type)
877+
return pc.field(self._get_field_name(term)) > _convert_scalar(literal.value, term.ref().field.field_type)
846878

847879
def visit_less_than(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression:
848-
return pc.field(term.ref().field.name) < _convert_scalar(literal.value, term.ref().field.field_type)
880+
return pc.field(self._get_field_name(term)) < _convert_scalar(literal.value, term.ref().field.field_type)
849881

850882
def visit_less_than_or_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression:
851-
return pc.field(term.ref().field.name) <= _convert_scalar(literal.value, term.ref().field.field_type)
883+
return pc.field(self._get_field_name(term)) <= _convert_scalar(literal.value, term.ref().field.field_type)
852884

853885
def visit_starts_with(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression:
854-
return pc.starts_with(pc.field(term.ref().field.name), literal.value)
886+
return pc.starts_with(pc.field(self._get_field_name(term)), literal.value)
855887

856888
def visit_not_starts_with(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression:
857-
return ~pc.starts_with(pc.field(term.ref().field.name), literal.value)
889+
return ~pc.starts_with(pc.field(self._get_field_name(term)), literal.value)
858890

859891
def visit_true(self) -> pc.Expression:
860892
return pc.scalar(True)
@@ -990,11 +1022,21 @@ def collect(
9901022
boolean_expression_visit(expr, self)
9911023

9921024

993-
def expression_to_pyarrow(expr: BooleanExpression) -> pc.Expression:
994-
return boolean_expression_visit(expr, _ConvertToArrowExpression())
1025+
def expression_to_pyarrow(expr: BooleanExpression, schema: Optional[Schema] = None) -> pc.Expression:
1026+
"""Convert an Iceberg boolean expression to a PyArrow expression.
1027+
1028+
Args:
1029+
expr: The Iceberg boolean expression to convert.
1030+
schema: Optional Iceberg schema to resolve full field paths for nested fields.
1031+
If provided, nested struct fields are referenced by their full path as a tuple (e.g., ("parent", "child")).
1032+
1033+
Returns:
1034+
A PyArrow compute expression.
1035+
"""
1036+
return boolean_expression_visit(expr, _ConvertToArrowExpression(schema))
9951037

9961038

997-
def _expression_to_complementary_pyarrow(expr: BooleanExpression) -> pc.Expression:
1039+
def _expression_to_complementary_pyarrow(expr: BooleanExpression, schema: Optional[Schema] = None) -> pc.Expression:
9981040
"""Complementary filter conversion function of expression_to_pyarrow.
9991041
10001042
Could not use expression_to_pyarrow(Not(expr)) to achieve this complementary effect because ~ in pyarrow.compute.Expression does not handle null.
@@ -1015,7 +1057,7 @@ def _expression_to_complementary_pyarrow(expr: BooleanExpression) -> pc.Expressi
10151057
preserve_expr = Or(preserve_expr, BoundIsNull(term=term))
10161058
for term in nan_unmentioned_bound_terms:
10171059
preserve_expr = Or(preserve_expr, BoundIsNaN(term=term))
1018-
return expression_to_pyarrow(preserve_expr)
1060+
return expression_to_pyarrow(preserve_expr, schema)
10191061

10201062

10211063
@lru_cache
@@ -1550,7 +1592,7 @@ def _task_to_record_batches(
15501592
bound_row_filter, file_schema, case_sensitive=case_sensitive, projected_field_values=projected_missing_fields
15511593
)
15521594
bound_file_filter = bind(file_schema, translated_row_filter, case_sensitive=case_sensitive)
1553-
pyarrow_filter = expression_to_pyarrow(bound_file_filter)
1595+
pyarrow_filter = expression_to_pyarrow(bound_file_filter, file_schema)
15541596

15551597
file_project_schema = prune_columns(file_schema, projected_field_ids, select_full_types=False)
15561598

pyiceberg/table/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -677,7 +677,7 @@ def delete(
677677
# Check if there are any files that require an actual rewrite of a data file
678678
if delete_snapshot.rewrites_needed is True:
679679
bound_delete_filter = bind(self.table_metadata.schema(), delete_filter, case_sensitive)
680-
preserve_row_filter = _expression_to_complementary_pyarrow(bound_delete_filter)
680+
preserve_row_filter = _expression_to_complementary_pyarrow(bound_delete_filter, self.table_metadata.schema())
681681

682682
file_scan = self._scan(row_filter=delete_filter, case_sensitive=case_sensitive)
683683
if branch is not None:

tests/expressions/test_expressions.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,58 @@ def test_ref_binding_case_insensitive_failure(table_schema_simple: Schema) -> No
248248
ref.bind(table_schema_simple, case_sensitive=False)
249249

250250

251+
def test_ref_binding_nested_struct_field() -> None:
252+
"""Test binding references to nested struct fields (issue #953)."""
253+
schema = Schema(
254+
NestedField(field_id=1, name="age", field_type=IntegerType(), required=True),
255+
NestedField(
256+
field_id=2,
257+
name="employment",
258+
field_type=StructType(
259+
NestedField(field_id=3, name="status", field_type=StringType(), required=False),
260+
NestedField(field_id=4, name="company", field_type=StringType(), required=False),
261+
),
262+
required=False,
263+
),
264+
NestedField(
265+
field_id=5,
266+
name="contact",
267+
field_type=StructType(
268+
NestedField(field_id=6, name="email", field_type=StringType(), required=False),
269+
),
270+
required=False,
271+
),
272+
schema_id=1,
273+
)
274+
275+
# Test that nested field names are in the index
276+
assert "employment.status" in schema._name_to_id
277+
assert "employment.company" in schema._name_to_id
278+
assert "contact.email" in schema._name_to_id
279+
280+
# Test binding a reference to nested fields
281+
ref = Reference("employment.status")
282+
bound = ref.bind(schema, case_sensitive=True)
283+
assert bound.field.field_id == 3
284+
assert bound.field.name == "status"
285+
286+
# Test with different nested field
287+
ref2 = Reference("contact.email")
288+
bound2 = ref2.bind(schema, case_sensitive=True)
289+
assert bound2.field.field_id == 6
290+
assert bound2.field.name == "email"
291+
292+
# Test case-insensitive binding
293+
ref3 = Reference("EMPLOYMENT.STATUS")
294+
bound3 = ref3.bind(schema, case_sensitive=False)
295+
assert bound3.field.field_id == 3
296+
297+
# Test that binding fails for non-existent nested field
298+
ref4 = Reference("employment.department")
299+
with pytest.raises(ValueError):
300+
ref4.bind(schema, case_sensitive=True)
301+
302+
251303
def test_in_to_eq() -> None:
252304
assert In("x", (34.56,)) == EqualTo("x", 34.56)
253305

tests/expressions/test_parser.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,9 @@ def test_with_function() -> None:
225225
def test_nested_fields() -> None:
226226
assert EqualTo("foo.bar", "data") == parser.parse("foo.bar = 'data'")
227227
assert LessThan("location.x", DecimalLiteral(Decimal(52.00))) == parser.parse("location.x < 52.00")
228+
# Test issue #953 scenario - nested struct field filtering
229+
assert EqualTo("employment.status", "Employed") == parser.parse("employment.status = 'Employed'")
230+
assert EqualTo("contact.email", "test@example.com") == parser.parse("contact.email = 'test@example.com'")
228231

229232

230233
def test_quoted_column_with_dots() -> None:

0 commit comments

Comments
 (0)