Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 34 additions & 3 deletions pyiceberg/io/pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -2728,9 +2728,11 @@ def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.T

for partition, name in zip(spec.fields, partition_fields):
source_field = schema.find_field(partition.source_id)
arrow_table = arrow_table.append_column(
name, partition.transform.pyarrow_transform(source_field.field_type)(arrow_table[source_field.name])
)
full_field_name = schema.find_column_name(partition.source_id)
if full_field_name is None:
raise ValueError(f"Could not find column name for field ID: {partition.source_id}")
field_array = _get_field_from_arrow_table(arrow_table, full_field_name)
arrow_table = arrow_table.append_column(name, partition.transform.pyarrow_transform(source_field.field_type)(field_array))

unique_partition_fields = arrow_table.select(partition_fields).group_by(partition_fields).aggregate([])

Expand Down Expand Up @@ -2765,3 +2767,32 @@ def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.T
)

return table_partitions


def _get_field_from_arrow_table(arrow_table: pa.Table, field_path: str) -> pa.Array:
    """Get a field from an Arrow table, supporting both literal field names and nested field paths.

    This function handles two cases:
    1. Literal field names that may contain dots (e.g., "some.id")
    2. Nested field paths using dot notation (e.g., "bar.baz" for nested access)

    Args:
        arrow_table: The Arrow table containing the field
        field_path: Field name or dot-separated path to a (possibly nested) field

    Returns:
        The field as a PyArrow Array

    Raises:
        KeyError: If the field path cannot be resolved
    """
    # Exact column-name match takes precedence: a top-level column whose name
    # itself contains dots (e.g. "some.id") must not be treated as a nested path.
    if field_path in arrow_table.column_names:
        return arrow_table[field_path]

    # Otherwise interpret the dots as struct nesting, e.g. "bar.baz" -> column
    # "bar", member "baz".
    root_name, *nested_parts = field_path.split(".")
    if not nested_parts or root_name not in arrow_table.column_names:
        # Fail with a descriptive KeyError instead of the bare one pyarrow would
        # raise for a missing column name.
        raise KeyError(f"Field path '{field_path}' could not be resolved in table columns {arrow_table.column_names}")

    # Navigate into the struct using the remaining path parts.
    return pc.struct_field(arrow_table[root_name], nested_parts)
98 changes: 97 additions & 1 deletion tests/io/test_pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@
from pyiceberg.table import FileScanTask, TableProperties
from pyiceberg.table.metadata import TableMetadataV2
from pyiceberg.table.name_mapping import create_mapping_from_schema
from pyiceberg.transforms import IdentityTransform
from pyiceberg.transforms import HourTransform, IdentityTransform
from pyiceberg.typedef import UTF8, Properties, Record
from pyiceberg.types import (
BinaryType,
Expand Down Expand Up @@ -2350,6 +2350,102 @@ def test_partition_for_demo() -> None:
)


def test_partition_for_nested_field() -> None:
    """Partitioning on a field nested one struct level deep resolves via dot notation."""
    # Hoisted from mid-function: keep local imports at the top of the test body.
    from datetime import datetime

    schema = Schema(
        NestedField(id=1, name="foo", field_type=StringType(), required=True),
        NestedField(
            id=2,
            name="bar",
            field_type=StructType(
                NestedField(id=3, name="baz", field_type=TimestampType(), required=False),
                NestedField(id=4, name="qux", field_type=IntegerType(), required=False),
            ),
            required=True,
        ),
    )

    # Partition on the nested "bar.baz" timestamp, bucketed by hour.
    spec = PartitionSpec(PartitionField(source_id=3, field_id=1000, transform=HourTransform(), name="ts"))

    t1 = datetime(2025, 7, 11, 9, 30, 0)
    t2 = datetime(2025, 7, 11, 10, 30, 0)

    test_data = [
        {"foo": "a", "bar": {"baz": t1, "qux": 1}},
        {"foo": "b", "bar": {"baz": t2, "qux": 2}},
    ]

    arrow_table = pa.Table.from_pylist(test_data, schema=schema.as_arrow())
    partitions = _determine_partitions(spec, schema, arrow_table)
    partition_values = {p.partition_key.partition[0] for p in partitions}

    # 486729 / 486730 are the hours-since-epoch for 2025-07-11T09 and T10.
    assert partition_values == {486729, 486730}


def test_partition_for_deep_nested_field() -> None:
    """Identity-partition on a field two struct levels deep ("foo.bar.baz")."""
    schema = Schema(
        NestedField(
            id=1,
            name="foo",
            field_type=StructType(
                NestedField(
                    id=2,
                    name="bar",
                    field_type=StructType(NestedField(id=3, name="baz", field_type=StringType(), required=False)),
                    required=True,
                )
            ),
            required=True,
        )
    )

    spec = PartitionSpec(PartitionField(source_id=3, field_id=1000, transform=IdentityTransform(), name="qux"))

    # Three rows carrying only two distinct values for the partition source.
    rows = [{"foo": {"bar": {"baz": value}}} for value in ("data-1", "data-2", "data-1")]
    arrow_table = pa.Table.from_pylist(rows, schema=schema.as_arrow())

    partitions = _determine_partitions(spec, schema, arrow_table)

    # The duplicate value collapses: exactly two partitions remain.
    assert len(partitions) == 2  # 2 unique partitions
    assert {p.partition_key.partition[0] for p in partitions} == {"data-1", "data-2"}


def test_inspect_partition_for_nested_field(catalog: InMemoryCatalog) -> None:
    """inspect.partitions() reports partition values sourced from a struct member."""
    schema = Schema(
        NestedField(id=1, name="foo", field_type=StringType(), required=True),
        NestedField(
            id=2,
            name="bar",
            field_type=StructType(
                NestedField(id=3, name="baz", field_type=StringType(), required=False),
                NestedField(id=4, name="qux", field_type=IntegerType(), required=False),
            ),
            required=True,
        ),
    )
    # Partition on the nested "bar.baz" string via the identity transform.
    spec = PartitionSpec(PartitionField(source_id=3, field_id=1000, transform=IdentityTransform(), name="part"))

    catalog.create_namespace("default")
    table = catalog.create_table("default.test_partition_in_struct", schema=schema, partition_spec=spec)

    rows = [
        {"foo": "a", "bar": {"baz": "data-a", "qux": 1}},
        {"foo": "b", "bar": {"baz": "data-b", "qux": 2}},
    ]
    table.append(pa.Table.from_pylist(rows, schema=table.schema().as_arrow()))

    reported = table.inspect.partitions()["partition"].to_pylist()

    assert len(reported) == 2
    assert {entry["part"] for entry in reported} == {"data-a", "data-b"}


def test_identity_partition_on_multi_columns() -> None:
test_pa_schema = pa.schema([("born_year", pa.int64()), ("n_legs", pa.int64()), ("animal", pa.string())])
test_schema = Schema(
Expand Down