diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index 360e3c43cc..1d7db412aa 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -2728,9 +2728,11 @@ def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.T for partition, name in zip(spec.fields, partition_fields): source_field = schema.find_field(partition.source_id) - arrow_table = arrow_table.append_column( - name, partition.transform.pyarrow_transform(source_field.field_type)(arrow_table[source_field.name]) - ) + full_field_name = schema.find_column_name(partition.source_id) + if full_field_name is None: + raise ValueError(f"Could not find column name for field ID: {partition.source_id}") + field_array = _get_field_from_arrow_table(arrow_table, full_field_name) + arrow_table = arrow_table.append_column(name, partition.transform.pyarrow_transform(source_field.field_type)(field_array)) unique_partition_fields = arrow_table.select(partition_fields).group_by(partition_fields).aggregate([]) @@ -2765,3 +2767,32 @@ def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.T ) return table_partitions + + +def _get_field_from_arrow_table(arrow_table: pa.Table, field_path: str) -> pa.Array: + """Get a field from an Arrow table, supporting both literal field names and nested field paths. + + This function handles two cases: + 1. Literal field names that may contain dots (e.g., "some.id") + 2. 
Nested field paths using dot notation (e.g., "bar.baz" for nested access) + + Args: + arrow_table: The Arrow table containing the field + field_path: Field name or dot-separated path + + Returns: + The field as a PyArrow Array + + Raises: + KeyError: If the field path cannot be resolved + """ + # Try exact column name match (handles field names containing literal dots) + if field_path in arrow_table.column_names: + return arrow_table[field_path] + + # If not found as exact name, treat as nested field path + path_parts = field_path.split(".") + # Get the struct column from the table (e.g., "bar" from "bar.baz") + field_array = arrow_table[path_parts[0]] + # Navigate into the struct using the remaining path parts + return pc.struct_field(field_array, path_parts[1:]) diff --git a/tests/io/test_pyarrow.py b/tests/io/test_pyarrow.py index db4f04dedf..4f121ba3bc 100644 --- a/tests/io/test_pyarrow.py +++ b/tests/io/test_pyarrow.py @@ -84,7 +84,7 @@ from pyiceberg.table import FileScanTask, TableProperties from pyiceberg.table.metadata import TableMetadataV2 from pyiceberg.table.name_mapping import create_mapping_from_schema -from pyiceberg.transforms import IdentityTransform +from pyiceberg.transforms import HourTransform, IdentityTransform from pyiceberg.typedef import UTF8, Properties, Record from pyiceberg.types import ( BinaryType, @@ -2350,6 +2350,102 @@ def test_partition_for_demo() -> None: ) +def test_partition_for_nested_field() -> None: + schema = Schema( + NestedField(id=1, name="foo", field_type=StringType(), required=True), + NestedField( + id=2, + name="bar", + field_type=StructType( + NestedField(id=3, name="baz", field_type=TimestampType(), required=False), + NestedField(id=4, name="qux", field_type=IntegerType(), required=False), + ), + required=True, + ), + ) + + spec = PartitionSpec(PartitionField(source_id=3, field_id=1000, transform=HourTransform(), name="ts")) + + from datetime import datetime + + t1 = datetime(2025, 7, 11, 9, 30, 0) + t2 = 
def test_partition_for_deep_nested_field() -> None:
    """Partitioning on a field nested two struct levels deep (foo.bar.baz)."""
    baz_field = NestedField(id=3, name="baz", field_type=StringType(), required=False)
    bar_field = NestedField(id=2, name="bar", field_type=StructType(baz_field), required=True)
    nested_schema = Schema(NestedField(id=1, name="foo", field_type=StructType(bar_field), required=True))

    partition_spec = PartitionSpec(
        PartitionField(source_id=3, field_id=1000, transform=IdentityTransform(), name="qux")
    )

    rows = [{"foo": {"bar": {"baz": value}}} for value in ("data-1", "data-2", "data-1")]
    table = pa.Table.from_pylist(rows, schema=nested_schema.as_arrow())

    result = _determine_partitions(partition_spec, nested_schema, table)

    # Three rows collapse into two distinct partition values.
    assert len(result) == 2
    assert {part.partition_key.partition[0] for part in result} == {"data-1", "data-2"}


def test_inspect_partition_for_nested_field(catalog: InMemoryCatalog) -> None:
    """inspect.partitions() reports partition values sourced from a struct field."""
    table_schema = Schema(
        NestedField(id=1, name="foo", field_type=StringType(), required=True),
        NestedField(
            id=2,
            name="bar",
            field_type=StructType(
                NestedField(id=3, name="baz", field_type=StringType(), required=False),
                NestedField(id=4, name="qux", field_type=IntegerType(), required=False),
            ),
            required=True,
        ),
    )
    partition_spec = PartitionSpec(
        PartitionField(source_id=3, field_id=1000, transform=IdentityTransform(), name="part")
    )
    catalog.create_namespace("default")
    tbl = catalog.create_table("default.test_partition_in_struct", schema=table_schema, partition_spec=partition_spec)

    rows = [
        {"foo": letter, "bar": {"baz": f"data-{letter}", "qux": number}}
        for letter, number in (("a", 1), ("b", 2))
    ]
    tbl.append(pa.Table.from_pylist(rows, schema=tbl.schema().as_arrow()))

    reported = tbl.inspect.partitions()["partition"].to_pylist()

    assert len(reported) == 2
    assert {entry["part"] for entry in reported} == {"data-a", "data-b"}