Skip to content

Commit 71cb247

Browse files
authored
Support Filters on Top-Level Struct Fields (#1832)
Closes #1778. # Rationale for this change Currently, filters that are applied to the top-level struct column do not work. For example, given a table of schema: ``` table { 2: id: optional int 1: data: required string 3: location: struct<5: latitude: optional float, 6: longitude: optional float> } ``` We want to support applying filters to field `location`, such as `location is not null`. Note that filters like `location == {"latitude": ..., "longitude": ...}` won't work right now, but can be equivalently rewritten to `location.latitude == ... and location.longitude == ...`. # Are these changes tested? Yes, tests were added at both the schema level and table reads. # Are there any user-facing changes? Support some basic filters on struct columns at the top-level.
1 parent 9d19ef7 commit 71cb247

File tree

5 files changed

+43
-5
lines changed

5 files changed

+43
-5
lines changed

dev/provision.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,7 @@
328328
CREATE TABLE {catalog_name}.default.test_table_empty_list_and_map (
329329
col_list array<int>,
330330
col_map map<int, int>,
331+
col_struct struct<test:int>,
331332
col_list_with_struct array<struct<test:int>>
332333
)
333334
USING iceberg
@@ -340,8 +341,8 @@
340341
spark.sql(
341342
f"""
342343
INSERT INTO {catalog_name}.default.test_table_empty_list_and_map
343-
VALUES (null, null, null),
344-
(array(), map(), array(struct(1)))
344+
VALUES (null, null, null, null),
345+
(array(), map(), struct(1), array(struct(1)))
345346
"""
346347
)
347348

pyiceberg/schema.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1240,6 +1240,7 @@ class _BuildPositionAccessors(SchemaVisitor[Dict[Position, Accessor]]):
12401240
... 1: Accessor(position=1, inner=None),
12411241
... 5: Accessor(position=2, inner=Accessor(position=0, inner=None)),
12421242
... 6: Accessor(position=2, inner=Accessor(position=1, inner=None))
1243+
... 3: Accessor(position=2, inner=None),
12431244
... }
12441245
>>> result == expected
12451246
True
@@ -1255,8 +1256,7 @@ def struct(self, struct: StructType, field_results: List[Dict[Position, Accessor
12551256
if field_results[position]:
12561257
for inner_field_id, acc in field_results[position].items():
12571258
result[inner_field_id] = Accessor(position, inner=acc)
1258-
else:
1259-
result[field.field_id] = Accessor(position)
1259+
result[field.field_id] = Accessor(position)
12601260

12611261
return result
12621262

tests/expressions/test_expressions.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,23 @@ def test_notnull_bind_required() -> None:
168168
assert NotNull(Reference("a")).bind(schema) == AlwaysTrue()
169169

170170

171+
def test_notnull_bind_top_struct() -> None:
172+
schema = Schema(
173+
NestedField(
174+
3,
175+
"struct_col",
176+
required=False,
177+
field_type=StructType(
178+
NestedField(1, "id", IntegerType(), required=True),
179+
NestedField(2, "cost", DecimalType(38, 18), required=False),
180+
),
181+
),
182+
schema_id=1,
183+
)
184+
bound = BoundNotNull(BoundReference(schema.find_field(3), schema.accessor_for_field(3)))
185+
assert NotNull(Reference("struct_col")).bind(schema) == bound
186+
187+
171188
def test_isnan_inverse() -> None:
172189
assert ~IsNaN(Reference("f")) == NotNaN(Reference("f"))
173190

tests/integration/test_reads.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
LessThan,
4242
NotEqualTo,
4343
NotNaN,
44+
NotNull,
4445
)
4546
from pyiceberg.io import PYARROW_USE_LARGE_TYPES_ON_READ
4647
from pyiceberg.io.pyarrow import (
@@ -667,6 +668,24 @@ def test_filter_case_insensitive(catalog: Catalog) -> None:
667668
assert arrow_table["b"].to_pylist() == ["2"]
668669

669670

671+
@pytest.mark.integration
672+
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
673+
def test_filters_on_top_level_struct(catalog: Catalog) -> None:
674+
test_empty_struct = catalog.load_table("default.test_table_empty_list_and_map")
675+
676+
arrow_table = test_empty_struct.scan().to_arrow()
677+
assert None in arrow_table["col_struct"].to_pylist()
678+
679+
arrow_table = test_empty_struct.scan(row_filter=NotNull("col_struct")).to_arrow()
680+
assert arrow_table["col_struct"].to_pylist() == [{"test": 1}]
681+
682+
arrow_table = test_empty_struct.scan(row_filter="col_struct is not null", case_sensitive=False).to_arrow()
683+
assert arrow_table["col_struct"].to_pylist() == [{"test": 1}]
684+
685+
arrow_table = test_empty_struct.scan(row_filter="COL_STRUCT is null", case_sensitive=False).to_arrow()
686+
assert arrow_table["col_struct"].to_pylist() == [None]
687+
688+
670689
@pytest.mark.integration
671690
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
672691
def test_upgrade_table_version(catalog: Catalog) -> None:

tests/test_schema.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,7 @@ def test_build_position_accessors(table_schema_nested: Schema) -> None:
398398
4: Accessor(position=3, inner=None),
399399
6: Accessor(position=4, inner=None),
400400
11: Accessor(position=5, inner=None),
401+
15: Accessor(position=6, inner=None),
401402
16: Accessor(position=6, inner=Accessor(position=0, inner=None)),
402403
17: Accessor(position=6, inner=Accessor(position=1, inner=None)),
403404
}
@@ -925,7 +926,7 @@ def primitive_fields() -> List[NestedField]:
925926
]
926927

927928

928-
def test_add_top_level_primitives(primitive_fields: NestedField) -> None:
929+
def test_add_top_level_primitives(primitive_fields: List[NestedField]) -> None:
929930
for primitive_field in primitive_fields:
930931
new_schema = Schema(primitive_field)
931932
applied = UpdateSchema(transaction=None, schema=Schema()).union_by_name(new_schema)._apply() # type: ignore

0 commit comments

Comments
 (0)