pyiceberg/io/pyarrow.py (9 additions, 3 deletions)

@@ -175,6 +175,7 @@
 from pyiceberg.utils.concurrent import ExecutorFactory
 from pyiceberg.utils.config import Config
 from pyiceberg.utils.datetime import millis_to_datetime
+from pyiceberg.utils.decimal import unscaled_to_decimal
 from pyiceberg.utils.deprecated import deprecation_message
 from pyiceberg.utils.properties import get_first_property_value, property_as_bool, property_as_int
 from pyiceberg.utils.singleton import Singleton

@@ -1888,7 +1889,7 @@ def visit_fixed(self, fixed_type: FixedType) -> str:
         return "FIXED_LEN_BYTE_ARRAY"

     def visit_decimal(self, decimal_type: DecimalType) -> str:
-        return "FIXED_LEN_BYTE_ARRAY"
+        return "INT32" if decimal_type.precision <= 9 else "INT64" if decimal_type.precision <= 18 else "FIXED_LEN_BYTE_ARRAY"

     def visit_boolean(self, boolean_type: BooleanType) -> str:
         return "BOOLEAN"
@@ -2362,8 +2363,13 @@ def data_file_statistics_from_parquet_metadata(
                             stats_col.iceberg_type, statistics.physical_type, stats_col.mode.length
                         )

-                    col_aggs[field_id].update_min(statistics.min)
-                    col_aggs[field_id].update_max(statistics.max)
+                    if isinstance(stats_col.iceberg_type, DecimalType) and statistics.physical_type != "FIXED_LEN_BYTE_ARRAY":
+                        scale = stats_col.iceberg_type.scale
+                        col_aggs[field_id].update_min(unscaled_to_decimal(statistics.min_raw, scale))
+                        col_aggs[field_id].update_max(unscaled_to_decimal(statistics.max_raw, scale))
+                    else:
+                        col_aggs[field_id].update_min(statistics.min)
+                        col_aggs[field_id].update_max(statistics.max)

                 except pyarrow.lib.ArrowNotImplementedError as e:
                     invalidate_col.add(field_id)
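
Why min_raw/max_raw: for decimals stored as INT32/INT64, PyArrow's column
statistics expose the physical value, i.e. the unscaled integer, so it has to
be rescaled before it reaches the aggregator. A minimal sketch of what the
imported unscaled_to_decimal helper is assumed to do, using only the standard
library:

from decimal import Decimal

def unscaled_to_decimal(unscaled: int, scale: int) -> Decimal:
    """Reapply the scale to an unscaled integer: (12345, 2) -> Decimal('123.45')."""
    return Decimal(unscaled).scaleb(-scale)

assert unscaled_to_decimal(12345, 2) == Decimal("123.45")
assert unscaled_to_decimal(-7, 3) == Decimal("-0.007")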

tests/io/test_pyarrow_stats.py (21 additions, 5 deletions)

@@ -27,6 +27,7 @@
     timedelta,
     timezone,
 )
+from decimal import Decimal
 from typing import (
     Any,
     Dict,
@@ -446,6 +447,9 @@ def construct_test_table_primitive_types() -> Tuple[pq.FileMetaData, Union[TableMetadataV1, TableMetadataV2]]:
             {"id": 10, "name": "strings", "required": False, "type": "string"},
             {"id": 11, "name": "uuids", "required": False, "type": "uuid"},
             {"id": 12, "name": "binaries", "required": False, "type": "binary"},
+            {"id": 13, "name": "decimal8", "required": False, "type": "decimal(5, 2)"},
+            {"id": 14, "name": "decimal16", "required": False, "type": "decimal(16, 6)"},
+            {"id": 15, "name": "decimal32", "required": False, "type": "decimal(19, 6)"},
         ],
     },
 ],

@@ -470,6 +474,9 @@ def construct_test_table_primitive_types() -> Tuple[pq.FileMetaData, Union[TableMetadataV1, TableMetadataV2]]:
     strings = ["hello", "world"]
     uuids = [uuid.uuid3(uuid.NAMESPACE_DNS, "foo").bytes, uuid.uuid3(uuid.NAMESPACE_DNS, "bar").bytes]
     binaries = [b"hello", b"world"]
+    decimal8 = pa.array([Decimal("123.45"), Decimal("678.91")], pa.decimal128(8, 2))
+    decimal16 = pa.array([Decimal("12345679.123456"), Decimal("67891234.678912")], pa.decimal128(16, 6))
+    decimal32 = pa.array([Decimal("1234567890123.123456"), Decimal("9876543210703.654321")], pa.decimal128(19, 6))

     table = pa.Table.from_pydict(
         {
@@ -485,14 +492,17 @@ def construct_test_table_primitive_types() -> Tuple[pq.FileMetaData, Union[TableMetadataV1, TableMetadataV2]]:
             "strings": strings,
             "uuids": uuids,
             "binaries": binaries,
+            "decimal8": decimal8,
+            "decimal16": decimal16,
+            "decimal32": decimal32,
         },
         schema=arrow_schema,
     )

     metadata_collector: List[Any] = []

     with pa.BufferOutputStream() as f:
-        with pq.ParquetWriter(f, table.schema, metadata_collector=metadata_collector) as writer:
+        with pq.ParquetWriter(f, table.schema, metadata_collector=metadata_collector, store_decimal_as_integer=True) as writer:
             writer.write_table(table)

     return metadata_collector[0], table_metadata
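
The store_decimal_as_integer=True flag is what makes this test meaningful:
without it PyArrow writes every decimal as FIXED_LEN_BYTE_ARRAY and the new
INT32/INT64 statistics path is never exercised. One way to confirm which
physical types the writer produced (a sketch; the file name is hypothetical,
and the flag requires a recent PyArrow release):

import pyarrow.parquet as pq

# Sketch: print the Parquet physical type chosen for each column.
meta = pq.read_metadata("decimals.parquet")  # hypothetical file written as above
for i in range(meta.num_columns):
    col = meta.row_group(0).column(i)
    print(col.path_in_schema, col.physical_type)
# Expected with store_decimal_as_integer=True:
#   decimal8   INT32                 (precision 8  <= 9)
#   decimal16  INT64                 (precision 16 <= 18)
#   decimal32  FIXED_LEN_BYTE_ARRAY  (precision 19)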
@@ -510,13 +520,13 @@ def test_metrics_primitive_types() -> None:
     )
     datafile = DataFile(**statistics.to_serialized_dict())

-    assert len(datafile.value_counts) == 12
-    assert len(datafile.null_value_counts) == 12
+    assert len(datafile.value_counts) == 15
+    assert len(datafile.null_value_counts) == 15
     assert len(datafile.nan_value_counts) == 0

     tz = timezone(timedelta(seconds=19800))

-    assert len(datafile.lower_bounds) == 12
+    assert len(datafile.lower_bounds) == 15
     assert datafile.lower_bounds[1] == STRUCT_BOOL.pack(False)
     assert datafile.lower_bounds[2] == STRUCT_INT32.pack(23)
     assert datafile.lower_bounds[3] == STRUCT_INT64.pack(2)
@@ -529,8 +539,11 @@ def test_metrics_primitive_types() -> None:
     assert datafile.lower_bounds[10] == b"he"
     assert datafile.lower_bounds[11] == uuid.uuid3(uuid.NAMESPACE_DNS, "foo").bytes
     assert datafile.lower_bounds[12] == b"he"
+    assert datafile.lower_bounds[13][::-1].ljust(4, b"\x00") == STRUCT_INT32.pack(12345)
+    assert datafile.lower_bounds[14][::-1].ljust(8, b"\x00") == STRUCT_INT64.pack(12345679123456)
+    assert str(int.from_bytes(datafile.lower_bounds[15], byteorder="big", signed=True)).encode("utf-8") == b"1234567890123123456"

-    assert len(datafile.upper_bounds) == 12
+    assert len(datafile.upper_bounds) == 15
     assert datafile.upper_bounds[1] == STRUCT_BOOL.pack(True)
     assert datafile.upper_bounds[2] == STRUCT_INT32.pack(89)
     assert datafile.upper_bounds[3] == STRUCT_INT64.pack(54)

@@ -543,6 +556,9 @@ def test_metrics_primitive_types() -> None:
     assert datafile.upper_bounds[10] == b"wp"
     assert datafile.upper_bounds[11] == uuid.uuid3(uuid.NAMESPACE_DNS, "bar").bytes
     assert datafile.upper_bounds[12] == b"wp"
+    assert datafile.upper_bounds[13][::-1].ljust(4, b"\x00") == STRUCT_INT32.pack(67891)
+    assert datafile.upper_bounds[14][::-1].ljust(8, b"\x00") == STRUCT_INT64.pack(67891234678912)
+    assert str(int.from_bytes(datafile.upper_bounds[15], byteorder="big", signed=True)).encode("utf-8") == b"9876543210703654321"


 def construct_test_table_invalid_upper_bound() -> Tuple[pq.FileMetaData, Union[TableMetadataV1, TableMetadataV2]]:
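
A note on the byte gymnastics in the new assertions: Iceberg serializes decimal
bounds as the unscaled value in minimal big-endian two's-complement bytes,
while STRUCT_INT32 and STRUCT_INT64 in this test module pack little-endian
integers, so reversing the bound and right-padding with zero bytes converts one
encoding into the other. A standalone sketch of that round-trip (the helper
below is illustrative, not a pyiceberg API):

import struct
from decimal import Decimal

STRUCT_INT32 = struct.Struct("<i")  # little-endian, as in the test module

def decimal_bound_bytes(value: Decimal, scale: int) -> bytes:
    """Iceberg-style bound: unscaled value as minimal big-endian two's-complement."""
    unscaled = int(value.scaleb(scale))
    length = max(1, (unscaled.bit_length() + 8) // 8)  # room for the sign bit
    return unscaled.to_bytes(length, byteorder="big", signed=True)

bound = decimal_bound_bytes(Decimal("123.45"), scale=2)  # unscaled 12345
assert bound == b"\x30\x39"
assert bound[::-1].ljust(4, b"\x00") == STRUCT_INT32.pack(12345)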