Merged
20 changes: 13 additions & 7 deletions pyiceberg/io/pyarrow.py
@@ -149,7 +149,7 @@
visit,
visit_with_partner,
)
from pyiceberg.table import TableProperties
from pyiceberg.table import DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE, TableProperties
from pyiceberg.table.locations import load_location_provider
from pyiceberg.table.metadata import TableMetadata
from pyiceberg.table.name_mapping import NameMapping, apply_name_mapping
@@ -1487,17 +1487,20 @@ def _task_to_record_batches(
name_mapping: Optional[NameMapping] = None,
partition_spec: Optional[PartitionSpec] = None,
format_version: TableVersion = TableProperties.DEFAULT_FORMAT_VERSION,
downcast_ns_timestamp_to_us: Optional[bool] = None,
) -> Iterator[pa.RecordBatch]:
arrow_format = ds.ParquetFileFormat(pre_buffer=True, buffer_size=(ONE_MEGABYTE * 8))
with io.new_input(task.file.file_path).open() as fin:
fragment = arrow_format.make_fragment(fin)
physical_schema = fragment.physical_schema
# In V1 and V2 table formats, we only support Timestamp 'us' in Iceberg Schema
# Hence it is reasonable to always cast 'ns' timestamp to 'us' on read.
# When V3 support is introduced, we will update `downcast_ns_timestamp_to_us` flag based on
# the table format version.

# For V1 and V2, we only support Timestamp 'us' in the Iceberg Schema, so it is reasonable to always cast 'ns' timestamps to 'us' on read.
# For V3 this has to be set explicitly to avoid nanosecond timestamps being down-cast by default.
downcast_ns_timestamp_to_us = (
downcast_ns_timestamp_to_us if downcast_ns_timestamp_to_us is not None else format_version <= 2
)
file_schema = pyarrow_to_schema(
physical_schema, name_mapping, downcast_ns_timestamp_to_us=True, format_version=format_version
physical_schema, name_mapping, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us, format_version=format_version
)

# Apply column projection rules: https://iceberg.apache.org/spec/#column-projection
@@ -1555,7 +1558,7 @@ def _task_to_record_batches(
projected_schema,
file_project_schema,
current_batch,
downcast_ns_timestamp_to_us=True,
downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us,
projected_missing_fields=projected_missing_fields,
)

@@ -1586,6 +1589,7 @@ class ArrowScan:
_bound_row_filter: BooleanExpression
_case_sensitive: bool
_limit: Optional[int]
_downcast_ns_timestamp_to_us: Optional[bool]
"""Scan the Iceberg Table and create an Arrow construct.

Attributes:
@@ -1612,6 +1616,7 @@ def __init__(
self._bound_row_filter = bind(table_metadata.schema(), row_filter, case_sensitive=case_sensitive)
self._case_sensitive = case_sensitive
self._limit = limit
self._downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE)
Contributor comment:
Do we need to check the format version for downcasting? (We have the table_metadata already, so we have access to it)
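
One way to act on that suggestion, sketched below as a hypothetical variant of the __init__ body (the merged change instead keeps the Config lookup here and applies the format-version fallback inside _task_to_record_batches):

# Hypothetical sketch only; names mirror the PR, but this is not what the merged code does.
downcast_flag = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE)
if downcast_flag is None:
    # V1/V2 tables only support microsecond timestamps, so downcast by default;
    # V3 supports timestamp_ns, so keep nanoseconds unless explicitly requested.
    downcast_flag = table_metadata.format_version <= 2
self._downcast_ns_timestamp_to_us = downcast_flag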


@property
def _projected_field_ids(self) -> Set[int]:
@@ -1728,6 +1733,7 @@ def _record_batches_from_scan_tasks_and_deletes(
self._table_metadata.name_mapping(),
self._table_metadata.specs().get(task.file.spec_id),
self._table_metadata.format_version,
self._downcast_ns_timestamp_to_us,
)
for batch in batches:
if self._limit is not None:
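The net effect of the new default resolution in _task_to_record_batches can be shown with a small standalone sketch (hypothetical helper name, same rule as the diff above): an explicit flag always wins; otherwise V1/V2 tables downcast and V3 tables keep nanosecond precision.

from typing import Optional

def resolve_downcast(flag: Optional[bool], format_version: int) -> bool:
    # Explicit request wins; otherwise infer from the table format version.
    return flag if flag is not None else format_version <= 2

assert resolve_downcast(None, 2) is True   # V1/V2 default: cast 'ns' down to 'us' on read
assert resolve_downcast(None, 3) is False  # V3 default: preserve nanosecond timestamps
assert resolve_downcast(True, 3) is True   # explicit opt-in still downcasts on a V3 table
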
75 changes: 69 additions & 6 deletions tests/io/test_pyarrow.py
@@ -20,7 +20,7 @@
import tempfile
import uuid
import warnings
from datetime import date
from datetime import date, datetime, timezone
from typing import Any, List, Optional
from unittest.mock import MagicMock, patch
from uuid import uuid4
@@ -61,6 +61,7 @@
from pyiceberg.io import S3_RETRY_STRATEGY_IMPL, InputStream, OutputStream, load_file_io
from pyiceberg.io.pyarrow import (
ICEBERG_SCHEMA,
PYARROW_PARQUET_FIELD_ID_KEY,
ArrowScan,
PyArrowFile,
PyArrowFileIO,
@@ -70,6 +71,7 @@
_determine_partitions,
_primitive_to_physical,
_read_deletes,
_task_to_record_batches,
_to_requested_schema,
bin_pack_arrow_table,
compute_statistics_plan,
@@ -85,7 +87,7 @@
from pyiceberg.table.metadata import TableMetadataV2
from pyiceberg.table.name_mapping import create_mapping_from_schema
from pyiceberg.transforms import HourTransform, IdentityTransform
from pyiceberg.typedef import UTF8, Properties, Record
from pyiceberg.typedef import UTF8, Properties, Record, TableVersion
from pyiceberg.types import (
BinaryType,
BooleanType,
@@ -102,6 +104,7 @@
PrimitiveType,
StringType,
StructType,
TimestampNanoType,
TimestampType,
TimestamptzType,
TimeType,
@@ -873,6 +876,18 @@ def _write_table_to_file(filepath: str, schema: pa.Schema, table: pa.Table) -> s
return filepath


def _write_table_to_data_file(filepath: str, schema: pa.Schema, table: pa.Table) -> DataFile:
filepath = _write_table_to_file(filepath, schema, table)
return DataFile.from_args(
content=DataFileContent.DATA,
file_path=filepath,
file_format=FileFormat.PARQUET,
partition={},
record_count=len(table),
file_size_in_bytes=22, # This is not relevant for now
)


@pytest.fixture
def file_int(schema_int: Schema, tmpdir: str) -> str:
pyarrow_schema = schema_to_pyarrow(schema_int, metadata={ICEBERG_SCHEMA: bytes(schema_int.model_dump_json(), UTF8)})
Expand Down Expand Up @@ -2411,8 +2426,6 @@ def test_partition_for_nested_field() -> None:

spec = PartitionSpec(PartitionField(source_id=3, field_id=1000, transform=HourTransform(), name="ts"))

from datetime import datetime

t1 = datetime(2025, 7, 11, 9, 30, 0)
t2 = datetime(2025, 7, 11, 10, 30, 0)

@@ -2551,8 +2564,6 @@ def test_initial_value() -> None:


def test__to_requested_schema_timestamp_to_timestamptz_projection() -> None:
from datetime import datetime, timezone

# file is written with timestamp without timezone
file_schema = Schema(NestedField(1, "ts_field", TimestampType(), required=False))
batch = pa.record_batch(
@@ -2722,3 +2733,55 @@ def test_retry_strategy_not_found() -> None:
io = PyArrowFileIO(properties={S3_RETRY_STRATEGY_IMPL: "pyiceberg.DoesNotExist"})
with pytest.warns(UserWarning, match="Could not initialize S3 retry strategy: pyiceberg.DoesNotExist"):
io.new_input("s3://bucket/path/to/file")


@pytest.mark.parametrize("format_version", [1, 2, 3])
def test_task_to_record_batches_nanos(format_version: TableVersion, tmpdir: str) -> None:
arrow_table = pa.table(
[
pa.array(
[
datetime(2025, 8, 14, 12, 0, 0),
datetime(2025, 8, 14, 13, 0, 0),
],
type=pa.timestamp("ns"),
)
],
pa.schema((pa.field("ts_field", pa.timestamp("ns"), nullable=True, metadata={PYARROW_PARQUET_FIELD_ID_KEY: "1"}),)),
)

data_file = _write_table_to_data_file(f"{tmpdir}/test_task_to_record_batches_nanos.parquet", arrow_table.schema, arrow_table)

if format_version <= 2:
table_schema = Schema(NestedField(1, "ts_field", TimestampType(), required=False))
else:
table_schema = Schema(NestedField(1, "ts_field", TimestampNanoType(), required=False))

actual_result = list(
_task_to_record_batches(
PyArrowFileIO(),
FileScanTask(data_file),
bound_row_filter=AlwaysTrue(),
projected_schema=table_schema,
projected_field_ids={1},
positional_deletes=None,
case_sensitive=True,
format_version=format_version,
)
)[0]

def _expected_batch(unit: str) -> pa.RecordBatch:
return pa.record_batch(
[
pa.array(
[
datetime(2025, 8, 14, 12, 0, 0),
datetime(2025, 8, 14, 13, 0, 0),
],
type=pa.timestamp(unit),
)
],
names=["ts_field"],
)

assert _expected_batch("ns" if format_version > 2 else "us").equals(actual_result)