
Commit c2eec90 (1 parent: eb8756a)

Write small decimals as INTs

Resolves #1979

File tree: 2 files changed (+77 −2 lines)


pyiceberg/io/pyarrow.py (12 additions, 2 deletions)

@@ -638,7 +638,13 @@ def visit_fixed(self, fixed_type: FixedType) -> pa.DataType:
         return pa.binary(len(fixed_type))
 
     def visit_decimal(self, decimal_type: DecimalType) -> pa.DataType:
-        return pa.decimal128(decimal_type.precision, decimal_type.scale)
+        return (
+            pa.decimal32(decimal_type.precision, decimal_type.scale)
+            if decimal_type.precision <= 9
+            else pa.decimal64(decimal_type.precision, decimal_type.scale)
+            if decimal_type.precision <= 18
+            else pa.decimal128(decimal_type.precision, decimal_type.scale)
+        )
 
     def visit_boolean(self, _: BooleanType) -> pa.DataType:
         return pa.bool_()
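Why 9 and 18: these are the standard bounds for integer-backed decimals. The largest 9-digit unscaled value, 999,999,999, fits in an INT32 (max 2,147,483,647), and the largest 18-digit value, 10^18 − 1, fits in an INT64 (max 2^63 − 1 ≈ 9.22 × 10^18); anything wider still needs decimal128. A standalone sketch of the same selection, assuming a pyarrow release that ships `pa.decimal32`/`pa.decimal64` (introduced around 18.0):

```python
import pyarrow as pa


def narrowest_decimal(precision: int, scale: int) -> pa.DataType:
    # Mirror the visit_decimal change: pick the narrowest Arrow
    # decimal type whose integer backing can hold `precision` digits.
    if precision <= 9:
        return pa.decimal32(precision, scale)  # 4-byte backing
    elif precision <= 18:
        return pa.decimal64(precision, scale)  # 8-byte backing
    return pa.decimal128(precision, scale)  # 16-byte backing


print(narrowest_decimal(8, 2))   # decimal32(8, 2)
print(narrowest_decimal(16, 6))  # decimal64(16, 6)
print(narrowest_decimal(19, 6))  # decimal128(19, 6)
```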
@@ -1749,6 +1755,8 @@ def _cast_if_needed(self, field: NestedField, values: pa.Array) -> pa.Array:
                         elif target_type.unit == "us" and values.type.unit in {"s", "ms", "us"}:
                             return values.cast(target_type)
                     raise ValueError(f"Unsupported schema projection from {values.type} to {target_type}")
+                else:
+                    pass
         return values
 
     def _construct_field(self, field: NestedField, arrow_type: pa.DataType) -> pa.Field:
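Note that the `raise` is an unchanged context line, so the added `else: pass` attaches to an enclosing `if`: only timestamp mismatches are treated as unsupported projections, and other benign mismatches (for example, a column that now materializes as `decimal32` while the projected schema models it more widely) fall through to `return values` untouched. A simplified, runnable sketch of that control flow, not the actual pyiceberg source:

```python
import pyarrow as pa


def cast_if_needed_sketch(values: pa.Array, target_type: pa.DataType) -> pa.Array:
    # Simplified reading of _cast_if_needed's control flow after this change.
    if target_type != values.type:
        if pa.types.is_timestamp(target_type) and pa.types.is_timestamp(values.type):
            if target_type.unit == "us" and values.type.unit in {"s", "ms", "us"}:
                return values.cast(target_type)
            raise ValueError(f"Unsupported schema projection from {values.type} to {target_type}")
        else:
            pass  # e.g. narrow decimals: hand the values back unchanged
    return values
```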
@@ -2437,7 +2445,9 @@ def write_parquet(task: WriteTask) -> DataFile:
         )
         fo = io.new_output(file_path)
         with fo.create(overwrite=True) as fos:
-            with pq.ParquetWriter(fos, schema=arrow_table.schema, **parquet_writer_kwargs) as writer:
+            with pq.ParquetWriter(
+                fos, schema=arrow_table.schema, store_decimal_as_integer=True, **parquet_writer_kwargs
+            ) as writer:
                 writer.write(arrow_table, row_group_size=row_group_size)
         statistics = data_file_statistics_from_parquet_metadata(
             parquet_metadata=writer.writer.metadata,
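`store_decimal_as_integer=True` is the matching pyarrow Parquet writer option: per its documentation it allows decimals with precision <= 18 to be stored with the INT32/INT64 physical types instead of the default FIXED_LEN_BYTE_ARRAY, which is what actually lands the small decimals on disk as ints. A quick way to observe the effect (hypothetical path; assumes a pyarrow new enough to expose the flag):

```python
from decimal import Decimal

import pyarrow as pa
import pyarrow.parquet as pq

table = pa.table({"d": pa.array([Decimal("123.45")], pa.decimal128(8, 2))})

for flag in (False, True):
    path = f"/tmp/decimals_{flag}.parquet"  # hypothetical location
    with pq.ParquetWriter(path, table.schema, store_decimal_as_integer=flag) as writer:
        writer.write(table)
    # Physical type of the column chunk as recorded in the file footer.
    print(flag, pq.ParquetFile(path).metadata.row_group(0).column(0).physical_type)

# Expected: False -> FIXED_LEN_BYTE_ARRAY, True -> INT32
```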

tests/integration/test_writes/test_writes.py (65 additions, 0 deletions)

@@ -20,6 +20,7 @@
 import random
 import time
 from datetime import date, datetime, timedelta
+from decimal import Decimal
 from pathlib import Path
 from typing import Any, Dict
 from urllib.parse import urlparse
@@ -50,6 +51,7 @@
 from pyiceberg.transforms import DayTransform, HourTransform, IdentityTransform
 from pyiceberg.types import (
     DateType,
+    DecimalType,
     DoubleType,
     IntegerType,
     ListType,
@@ -1684,3 +1686,66 @@ def test_write_optional_list(session_catalog: Catalog) -> None:
     session_catalog.load_table(identifier).append(df_2)
 
     assert len(session_catalog.load_table(identifier).scan().to_arrow()) == 4
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_evolve_and_write(
+    spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int
+) -> None:
+    identifier = "default.test_evolve_and_write"
+    tbl = _create_table(session_catalog, identifier, properties={"format-version": format_version}, schema=Schema())
+    other_table = session_catalog.load_table(identifier)
+
+    numbers = pa.array([1, 2, 3, 4], type=pa.int32())
+
+    with tbl.update_schema() as upd:
+        # This is not known by other_table
+        upd.add_column("id", IntegerType())
+
+    with other_table.transaction() as tx:
+        # Refreshes the underlying metadata, and the schema
+        other_table.refresh()
+        tx.append(
+            pa.Table.from_arrays(
+                [
+                    numbers,
+                ],
+                schema=pa.schema(
+                    [
+                        pa.field("id", pa.int32(), nullable=True),
+                    ]
+                ),
+            )
+        )
+
+    assert session_catalog.load_table(identifier).scan().to_arrow().column(0).combine_chunks() == numbers
+
+
+@pytest.mark.integration
+def test_read_write_decimals(session_catalog: Catalog) -> None:
+    """Roundtrip decimal types to make sure that we correctly write them as ints"""
+    identifier = "default.test_read_write_decimals"
+
+    arrow_table = pa.Table.from_pydict(
+        {
+            "decimal8": pa.array([Decimal("123.45"), Decimal("678.91")], pa.decimal128(8, 2)),
+            "decimal16": pa.array([Decimal("12345679.123456"), Decimal("67891234.678912")], pa.decimal128(16, 6)),
+            "decimal19": pa.array([Decimal("1234567890123.123456"), Decimal("9876543210703.654321")], pa.decimal128(19, 6)),
+        },
+    )
+
+    tbl = _create_table(
+        session_catalog,
+        identifier,
+        properties={"format-version": 2},
+        schema=Schema(
+            NestedField(1, "decimal8", DecimalType(8, 2)),
+            NestedField(2, "decimal16", DecimalType(16, 6)),
+            NestedField(3, "decimal19", DecimalType(19, 6)),
+        ),
+    )
+
+    tbl.append(arrow_table)
+
+    assert tbl.scan().to_arrow() == arrow_table
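The test drives the full catalog round-trip; the same guarantee can be checked at the bare Parquet level. A minimal standalone analogue (hypothetical path; the cast back to the original schema covers pyarrow versions that rehydrate integer-backed decimals as a narrower decimal type):

```python
from decimal import Decimal

import pyarrow as pa
import pyarrow.parquet as pq

original = pa.table(
    {
        "decimal8": pa.array([Decimal("123.45"), Decimal("678.91")], pa.decimal128(8, 2)),
        "decimal19": pa.array([Decimal("1234567890123.123456")] * 2, pa.decimal128(19, 6)),
    }
)
pq.write_table(original, "/tmp/roundtrip.parquet", store_decimal_as_integer=True)

restored = pq.read_table("/tmp/roundtrip.parquet")
# decimal8 is INT32 on disk while decimal19 stays FIXED_LEN_BYTE_ARRAY,
# but the logical values are identical either way.
assert restored.cast(original.schema).equals(original)
```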
