Skip to content

Commit 895dff5

Browse files
author
redpheonixx
committed
changes_as_per_pr1799_comments
1 parent 71cb247 commit 895dff5

File tree

2 files changed

+36
-9
lines changed

2 files changed

+36
-9
lines changed

pyiceberg/io/pyarrow.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@
177177
from pyiceberg.utils.properties import get_first_property_value, property_as_bool, property_as_int
178178
from pyiceberg.utils.singleton import Singleton
179179
from pyiceberg.utils.truncate import truncate_upper_bound_binary_string, truncate_upper_bound_text_string
180+
from decimal import Decimal
180181

181182
if TYPE_CHECKING:
182183
from pyiceberg.table import FileScanTask, WriteTask
@@ -1876,7 +1877,11 @@ def visit_fixed(self, fixed_type: FixedType) -> str:
18761877
return "FIXED_LEN_BYTE_ARRAY"
18771878

18781879
def visit_decimal(self, decimal_type: DecimalType) -> str:
1879-
return "FIXED_LEN_BYTE_ARRAY"
1880+
return (
1881+
"INT32" if decimal_type.precision <= 9
1882+
else "INT64" if decimal_type.precision <= 18
1883+
else "FIXED_LEN_BYTE_ARRAY"
1884+
)
18801885

18811886
def visit_boolean(self, boolean_type: BooleanType) -> str:
18821887
return "BOOLEAN"
@@ -2350,8 +2355,15 @@ def data_file_statistics_from_parquet_metadata(
23502355
stats_col.iceberg_type, statistics.physical_type, stats_col.mode.length
23512356
)
23522357

2353-
col_aggs[field_id].update_min(statistics.min)
2354-
col_aggs[field_id].update_max(statistics.max)
2358+
if isinstance(stats_col.iceberg_type, DecimalType) and statistics.physical_type != "FIXED_LEN_BYTE_ARRAY":
2359+
precision= stats_col.iceberg_type.precision
2360+
scale = stats_col.iceberg_type.scale
2361+
decimal_type = pa.decimal128(precision, scale)
2362+
col_aggs[field_id].update_min(pa.array([Decimal(statistics.min_raw)/ (10 ** scale)], decimal_type)[0].as_py())
2363+
col_aggs[field_id].update_max(pa.array([Decimal(statistics.max_raw)/ (10 ** scale)], decimal_type)[0].as_py())
2364+
else:
2365+
col_aggs[field_id].update_min(statistics.min)
2366+
col_aggs[field_id].update_max(statistics.max)
23552367

23562368
except pyarrow.lib.ArrowNotImplementedError as e:
23572369
invalidate_col.add(field_id)

tests/io/test_pyarrow_stats.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@
7272
StringType,
7373
)
7474
from pyiceberg.utils.datetime import date_to_days, datetime_to_micros, time_to_micros
75-
75+
from decimal import Decimal
7676

7777
@dataclass(frozen=True)
7878
class TestStruct:
@@ -446,6 +446,9 @@ def construct_test_table_primitive_types() -> Tuple[pq.FileMetaData, Union[Table
446446
{"id": 10, "name": "strings", "required": False, "type": "string"},
447447
{"id": 11, "name": "uuids", "required": False, "type": "uuid"},
448448
{"id": 12, "name": "binaries", "required": False, "type": "binary"},
449+
{"id": 13, "name": "decimal8", "required": False, "type": "decimal(5, 2)"},
450+
{"id": 14, "name": "decimal16", "required": False, "type": "decimal(16, 6)"},
451+
{"id": 15, "name": "decimal32", "required": False, "type": "decimal(19, 6)"},
449452
],
450453
},
451454
],
@@ -470,6 +473,9 @@ def construct_test_table_primitive_types() -> Tuple[pq.FileMetaData, Union[Table
470473
strings = ["hello", "world"]
471474
uuids = [uuid.uuid3(uuid.NAMESPACE_DNS, "foo").bytes, uuid.uuid3(uuid.NAMESPACE_DNS, "bar").bytes]
472475
binaries = [b"hello", b"world"]
476+
decimal8 = pa.array([Decimal('123.45'), Decimal('678.91')], pa.decimal128(8, 2))
477+
decimal16 = pa.array([Decimal("12345679.123456"), Decimal("67891234.678912")], pa.decimal128(16, 6))
478+
decimal32 = pa.array([Decimal("1234567890123.123456"), Decimal("9876543210703.654321")], pa.decimal128(19, 6))
473479

474480
table = pa.Table.from_pydict(
475481
{
@@ -485,14 +491,17 @@ def construct_test_table_primitive_types() -> Tuple[pq.FileMetaData, Union[Table
485491
"strings": strings,
486492
"uuids": uuids,
487493
"binaries": binaries,
494+
"decimal8": decimal8,
495+
"decimal16": decimal16,
496+
"decimal32": decimal32,
488497
},
489498
schema=arrow_schema,
490499
)
491500

492501
metadata_collector: List[Any] = []
493502

494503
with pa.BufferOutputStream() as f:
495-
with pq.ParquetWriter(f, table.schema, metadata_collector=metadata_collector) as writer:
504+
with pq.ParquetWriter(f, table.schema, metadata_collector=metadata_collector, store_decimal_as_integer=True) as writer:
496505
writer.write_table(table)
497506

498507
return metadata_collector[0], table_metadata
@@ -510,13 +519,13 @@ def test_metrics_primitive_types() -> None:
510519
)
511520
datafile = DataFile(**statistics.to_serialized_dict())
512521

513-
assert len(datafile.value_counts) == 12
514-
assert len(datafile.null_value_counts) == 12
522+
assert len(datafile.value_counts) == 15
523+
assert len(datafile.null_value_counts) == 15
515524
assert len(datafile.nan_value_counts) == 0
516525

517526
tz = timezone(timedelta(seconds=19800))
518527

519-
assert len(datafile.lower_bounds) == 12
528+
assert len(datafile.lower_bounds) == 15
520529
assert datafile.lower_bounds[1] == STRUCT_BOOL.pack(False)
521530
assert datafile.lower_bounds[2] == STRUCT_INT32.pack(23)
522531
assert datafile.lower_bounds[3] == STRUCT_INT64.pack(2)
@@ -529,8 +538,11 @@ def test_metrics_primitive_types() -> None:
529538
assert datafile.lower_bounds[10] == b"he"
530539
assert datafile.lower_bounds[11] == uuid.uuid3(uuid.NAMESPACE_DNS, "foo").bytes
531540
assert datafile.lower_bounds[12] == b"he"
541+
assert datafile.lower_bounds[13][::-1].ljust(4, b'\x00') == STRUCT_INT32.pack(12345)
542+
assert datafile.lower_bounds[14][::-1].ljust(8, b'\x00') == STRUCT_INT64.pack(12345679123456)
543+
assert str(int.from_bytes(datafile.lower_bounds[15], byteorder='big', signed=True)).encode('utf-8')== b"1234567890123123456"
532544

533-
assert len(datafile.upper_bounds) == 12
545+
assert len(datafile.upper_bounds) == 15
534546
assert datafile.upper_bounds[1] == STRUCT_BOOL.pack(True)
535547
assert datafile.upper_bounds[2] == STRUCT_INT32.pack(89)
536548
assert datafile.upper_bounds[3] == STRUCT_INT64.pack(54)
@@ -543,6 +555,9 @@ def test_metrics_primitive_types() -> None:
543555
assert datafile.upper_bounds[10] == b"wp"
544556
assert datafile.upper_bounds[11] == uuid.uuid3(uuid.NAMESPACE_DNS, "bar").bytes
545557
assert datafile.upper_bounds[12] == b"wp"
558+
assert datafile.upper_bounds[13][::-1].ljust(4, b'\x00')== STRUCT_INT32.pack(67891)
559+
assert datafile.upper_bounds[14][::-1].ljust(8, b'\x00')== STRUCT_INT64.pack(67891234678912)
560+
assert str(int.from_bytes(datafile.upper_bounds[15], byteorder='big', signed=True)).encode('utf-8')== b"9876543210703654321"
546561

547562

548563
def construct_test_table_invalid_upper_bound() -> Tuple[pq.FileMetaData, Union[TableMetadataV1, TableMetadataV2]]:

0 commit comments

Comments
 (0)