|
20 | 20 | import random |
21 | 21 | import time |
22 | 22 | from datetime import date, datetime, timedelta |
| 23 | +from decimal import Decimal |
23 | 24 | from pathlib import Path |
24 | 25 | from typing import Any, Dict |
25 | 26 | from urllib.parse import urlparse |
|
50 | 51 | from pyiceberg.transforms import DayTransform, HourTransform, IdentityTransform |
51 | 52 | from pyiceberg.types import ( |
52 | 53 | DateType, |
| 54 | + DecimalType, |
53 | 55 | DoubleType, |
54 | 56 | IntegerType, |
55 | 57 | ListType, |
@@ -1684,3 +1686,66 @@ def test_write_optional_list(session_catalog: Catalog) -> None: |
1684 | 1686 | session_catalog.load_table(identifier).append(df_2) |
1685 | 1687 |
|
1686 | 1688 | assert len(session_catalog.load_table(identifier).scan().to_arrow()) == 4 |
| 1689 | + |
| 1690 | + |
@pytest.mark.integration
@pytest.mark.parametrize("format_version", [1, 2])
def test_evolve_and_write(
    spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int
) -> None:
    """Append through a stale table handle after the schema evolved elsewhere.

    A second handle to the same table is loaded before the schema evolution,
    then refreshed inside an open transaction; the append must see the newly
    added ``id`` column and succeed.
    """
    identifier = "default.test_evolve_and_write"
    tbl = _create_table(session_catalog, identifier, properties={"format-version": format_version}, schema=Schema())
    stale_handle = session_catalog.load_table(identifier)

    expected_ids = pa.array([1, 2, 3, 4], type=pa.int32())

    # Evolve the schema through the first handle; stale_handle is unaware of this.
    with tbl.update_schema() as update:
        update.add_column("id", IntegerType())

    with stale_handle.transaction() as txn:
        # Refreshing pulls in the latest metadata, and with it the new column.
        stale_handle.refresh()
        id_field = pa.field("id", pa.int32(), nullable=True)
        txn.append(pa.Table.from_arrays([expected_ids], schema=pa.schema([id_field])))

    assert session_catalog.load_table(identifier).scan().to_arrow().column(0).combine_chunks() == expected_ids
| 1723 | + |
| 1724 | + |
@pytest.mark.integration
def test_read_write_decimals(session_catalog: Catalog) -> None:
    """Roundtrip decimal columns to verify they are written correctly as ints.

    Covers three precision tiers (8, 16, and 19 digits) so that each physical
    storage width for decimals is exercised.
    """
    identifier = "default.test_read_write_decimals"

    # (field id, column name, precision, scale, sample values)
    specs = [
        (1, "decimal8", 8, 2, [Decimal("123.45"), Decimal("678.91")]),
        (2, "decimal16", 16, 6, [Decimal("12345679.123456"), Decimal("67891234.678912")]),
        (3, "decimal19", 19, 6, [Decimal("1234567890123.123456"), Decimal("9876543210703.654321")]),
    ]

    expected = pa.Table.from_pydict(
        {name: pa.array(values, pa.decimal128(precision, scale)) for _, name, precision, scale, values in specs},
    )

    tbl = _create_table(
        session_catalog,
        identifier,
        properties={"format-version": 2},
        schema=Schema(
            *[NestedField(field_id, name, DecimalType(precision, scale)) for field_id, name, precision, scale, _ in specs]
        ),
    )

    tbl.append(expected)

    assert tbl.scan().to_arrow() == expected
0 commit comments