diff --git a/tests/integration/test_catalog.py b/tests/integration/test_catalog.py index db4679bc1f..3590d0837e 100644 --- a/tests/integration/test_catalog.py +++ b/tests/integration/test_catalog.py @@ -39,7 +39,7 @@ from pyiceberg.schema import INITIAL_SCHEMA_ID, Schema from pyiceberg.table.metadata import INITIAL_SPEC_ID from pyiceberg.table.sorting import INITIAL_SORT_ORDER_ID, SortField, SortOrder -from pyiceberg.transforms import DayTransform, IdentityTransform +from pyiceberg.transforms import BucketTransform, DayTransform, IdentityTransform from pyiceberg.types import IntegerType, LongType, NestedField, TimestampType, UUIDType from tests.conftest import clean_up @@ -503,6 +503,69 @@ def test_update_namespace_properties(test_catalog: Catalog, database_name: str) assert "updated test description" == test_catalog.load_namespace_properties(database_name)["comment"] +@pytest.mark.integration +@pytest.mark.parametrize("test_catalog", CATALOGS) +def test_update_table_spec(test_catalog: Catalog, test_schema: Schema, table_name: str, database_name: str) -> None: + identifier = (database_name, table_name) + test_catalog.create_namespace(database_name) + table = test_catalog.create_table(identifier, test_schema) + + with table.update_spec() as update: + update.add_field(source_column_name="VendorID", transform=BucketTransform(16), partition_field_name="shard") + + loaded = test_catalog.load_table(identifier) + expected_spec = PartitionSpec( + PartitionField(source_id=1, field_id=1000, transform=BucketTransform(16), name="shard"), spec_id=1 + ) + # The spec ID may not match, so check equality of the fields + assert loaded.spec() == expected_spec + + +@pytest.mark.integration +@pytest.mark.parametrize("test_catalog", CATALOGS) +def test_update_table_spec_conflict(test_catalog: Catalog, test_schema: Schema, table_name: str, database_name: str) -> None: + identifier = (database_name, table_name) + test_catalog.create_namespace(database_name) + spec = PartitionSpec(PartitionField(source_id=1, field_id=1000, transform=BucketTransform(16), name="id_bucket")) + table = test_catalog.create_table(identifier, test_schema, partition_spec=spec) + + update = table.update_spec() + update.add_field(source_column_name="tpep_pickup_datetime", transform=BucketTransform(16), partition_field_name="shard") + + # update with conflict + conflict_table = test_catalog.load_table(identifier) + with conflict_table.update_spec() as conflict_update: + conflict_update.remove_field("id_bucket") + + with pytest.raises( + CommitFailedException, match="Requirement failed: default spec id has changed|default partition spec changed" + ): + update.commit() + + loaded = test_catalog.load_table(identifier) + assert loaded.spec() == PartitionSpec(spec_id=1) + + +@pytest.mark.integration +@pytest.mark.parametrize("test_catalog", CATALOGS) +def test_update_table_spec_then_revert(test_catalog: Catalog, test_schema: Schema, table_name: str, database_name: str) -> None: + identifier = (database_name, table_name) + test_catalog.create_namespace(database_name) + + initial_spec = PartitionSpec(PartitionField(source_id=1, field_id=1000, transform=BucketTransform(16), name="id_bucket")) + + table = test_catalog.create_table(identifier, test_schema, partition_spec=initial_spec, properties={"format-version": "2"}) + assert table.format_version == 2 + + with table.update_spec() as update: + update.add_identity(source_column_name="tpep_pickup_datetime") + + with table.update_spec() as update: + update.remove_field("tpep_pickup_datetime") + + assert table.spec() == initial_spec + + @pytest.mark.integration @pytest.mark.parametrize("test_catalog", CATALOGS) def test_register_table(test_catalog: Catalog, table_schema_nested: Schema, table_name: str, database_name: str) -> None: