diff --git a/pyiceberg/partitioning.py b/pyiceberg/partitioning.py index 8bf2b817d9..3bb0290008 100644 --- a/pyiceberg/partitioning.py +++ b/pyiceberg/partitioning.py @@ -32,7 +32,8 @@ model_validator, ) -from pyiceberg.schema import Schema +from pyiceberg.exceptions import ValidationError +from pyiceberg.schema import Schema, _index_parents from pyiceberg.transforms import ( BucketTransform, DayTransform, @@ -249,6 +250,36 @@ def partition_to_path(self, data: Record, schema: Schema) -> str: path = "/".join([field_str + "=" + value_str for field_str, value_str in zip(field_strs, value_strs, strict=True)]) return path + def check_compatible(self, schema: Schema, allow_missing_fields: bool = False) -> None: + # if the underlying field is dropped, we cannot check they are compatible -- continue + schema_fields = schema._lazy_id_to_field + parents = _index_parents(schema) + + def validate_parents_are_structs(field_id: int) -> None: + parent_id = parents.get(field_id) + while parent_id: + parent_type = schema.find_type(parent_id) + if not parent_type.is_struct: + raise ValidationError("Invalid partition field parent: %s", parent_type) + parent_id = parents.get(parent_id) + + for field in self.fields: + source_field = schema_fields.get(field.source_id) + if allow_missing_fields and source_field: + continue + + if not isinstance(field.transform, VoidTransform): + if source_field: + source_type = source_field.field_type + if not source_type.is_primitive: + raise ValidationError(f"Cannot partition by non-primitive source field: {source_type}") + if not field.transform.can_transform(source_type): + raise ValidationError(f"Invalid source type {source_type} for transform: {field.transform}") + # The only valid parent types for a PartitionField are StructTypes. This must be checked recursively + validate_parents_are_structs(field.source_id) + else: + raise ValidationError(f"Cannot find source column for partition field: {field}") + UNPARTITIONED_PARTITION_SPEC = PartitionSpec(spec_id=0) diff --git a/pyiceberg/table/sorting.py b/pyiceberg/table/sorting.py index 5243d7b184..1dd41edbb0 100644 --- a/pyiceberg/table/sorting.py +++ b/pyiceberg/table/sorting.py @@ -27,6 +27,7 @@ model_validator, ) +from pyiceberg.exceptions import ValidationError from pyiceberg.schema import Schema from pyiceberg.transforms import IdentityTransform, Transform, parse_transform from pyiceberg.typedef import IcebergBaseModel @@ -169,6 +170,17 @@ def __repr__(self) -> str: fields = f"{', '.join(repr(column) for column in self.fields)}, " if self.fields else "" return f"SortOrder({fields}order_id={self.order_id})" + def check_compatible(self, schema: Schema) -> None: + schema_ids = schema._lazy_id_to_field + for field in self.fields: + if source_field := schema_ids.get(field.source_id): + if not source_field.field_type.is_primitive: + raise ValidationError(f"Cannot sort by non-primitive source field: {source_field}") + if not field.transform.can_transform(source_field.field_type): + raise ValidationError(f"Invalid source type {source_field.field_type} for transform: {field.transform}") + else: + raise ValidationError(f"Cannot find source column for sort field: {field}") + UNSORTED_SORT_ORDER_ID = 0 UNSORTED_SORT_ORDER = SortOrder(order_id=UNSORTED_SORT_ORDER_ID) diff --git a/pyiceberg/table/update/__init__.py b/pyiceberg/table/update/__init__.py index a79e2cb468..a6f9e3ed04 100644 --- a/pyiceberg/table/update/__init__.py +++ b/pyiceberg/table/update/__init__.py @@ -700,6 +700,12 @@ def update_table_metadata( if base_metadata.last_updated_ms == new_metadata.last_updated_ms: new_metadata = new_metadata.model_copy(update={"last_updated_ms": datetime_to_millis(datetime.now().astimezone())}) + # Check correctness of partition spec, and sort order + new_metadata.spec().check_compatible(new_metadata.schema()) + + if sort_order := new_metadata.sort_order_by_id(new_metadata.default_sort_order_id): + sort_order.check_compatible(new_metadata.schema()) + if enforce_validation: return TableMetadataUtil.parse_obj(new_metadata.model_dump()) else: diff --git a/tests/conftest.py b/tests/conftest.py index 85c15d3e0b..36afbcf506 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -70,6 +70,7 @@ from pyiceberg.serializers import ToOutputFile from pyiceberg.table import FileScanTask, Table from pyiceberg.table.metadata import TableMetadataV1, TableMetadataV2, TableMetadataV3 +from pyiceberg.table.sorting import NullOrder, SortField, SortOrder from pyiceberg.transforms import DayTransform, IdentityTransform from pyiceberg.types import ( BinaryType, @@ -1894,6 +1895,11 @@ def test_partition_spec() -> Schema: ) +@pytest.fixture(scope="session") +def test_sort_order() -> SortOrder: + return SortOrder(SortField(source_id=1, transform=IdentityTransform(), null_order=NullOrder.NULLS_FIRST)) + + @pytest.fixture(scope="session") def generated_manifest_entry_file( avro_schema_manifest_entry: dict[str, Any], test_schema: Schema, test_partition_spec: PartitionSpec diff --git a/tests/integration/test_catalog.py b/tests/integration/test_catalog.py index 0c77666568..fa20b8c2f8 100644 --- a/tests/integration/test_catalog.py +++ b/tests/integration/test_catalog.py @@ -33,6 +33,7 @@ NoSuchNamespaceError, NoSuchTableError, TableAlreadyExistsError, + ValidationError, ) from pyiceberg.io import WAREHOUSE from pyiceberg.partitioning import PartitionField, PartitionSpec @@ -601,3 +602,56 @@ def test_register_table_existing(test_catalog: Catalog, table_schema_nested: Sch # Assert that registering the table again raises TableAlreadyExistsError with pytest.raises(TableAlreadyExistsError): test_catalog.register_table(identifier, metadata_location=table.metadata_location) + + +@pytest.mark.integration +@pytest.mark.parametrize("test_catalog", CATALOGS) +def test_incompatible_partitioned_schema_evolution( + test_catalog: Catalog, test_schema: Schema, test_partition_spec: PartitionSpec, database_name: str, table_name: str +) -> None: + if isinstance(test_catalog, HiveCatalog): + pytest.skip("HiveCatalog does not support schema evolution") + + identifier = (database_name, table_name) + test_catalog.create_namespace(database_name) + table = test_catalog.create_table(identifier, test_schema, partition_spec=test_partition_spec) + assert test_catalog.table_exists(identifier) + + with pytest.raises(ValidationError): + with table.update_schema() as update: + update.delete_column("VendorID") + + # Assert column was not dropped + assert "VendorID" in table.schema().column_names + + with table.transaction() as transaction: + with transaction.update_spec() as spec_update: + spec_update.remove_field("VendorID") + + with transaction.update_schema() as schema_update: + schema_update.delete_column("VendorID") + + assert table.spec() == PartitionSpec(PartitionField(2, 1001, DayTransform(), "tpep_pickup_day"), spec_id=1) + assert table.schema() == Schema(NestedField(2, "tpep_pickup_datetime", TimestampType(), False)) + + +@pytest.mark.integration +@pytest.mark.parametrize("test_catalog", CATALOGS) +def test_incompatible_sorted_schema_evolution( + test_catalog: Catalog, test_schema: Schema, test_sort_order: SortOrder, database_name: str, table_name: str +) -> None: + if isinstance(test_catalog, HiveCatalog): + pytest.skip("HiveCatalog does not support schema evolution") + + identifier = (database_name, table_name) + test_catalog.create_namespace(database_name) + table = test_catalog.create_table(identifier, test_schema, sort_order=test_sort_order) + assert test_catalog.table_exists(identifier) + + with pytest.raises(ValidationError): + with table.update_schema() as update: + update.delete_column("VendorID") + + assert table.schema() == Schema( + NestedField(1, "VendorID", IntegerType(), False), NestedField(2, "tpep_pickup_datetime", TimestampType(), False) + ) diff --git a/tests/table/test_partitioning.py b/tests/table/test_partitioning.py index 576297c6f2..284040794f 100644 --- a/tests/table/test_partitioning.py +++ b/tests/table/test_partitioning.py @@ -21,6 +21,7 @@ import pytest +from pyiceberg.exceptions import ValidationError from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC, PartitionField, PartitionSpec from pyiceberg.schema import Schema from pyiceberg.transforms import ( @@ -259,3 +260,36 @@ def test_deserialize_partition_field_v3() -> None: field = PartitionField.model_validate_json(json_partition_spec) assert field == PartitionField(source_id=1, field_id=1000, transform=TruncateTransform(width=19), name="str_truncate") + + +def test_incompatible_source_column_not_found() -> None: + schema = Schema(NestedField(1, "foo", IntegerType()), NestedField(2, "bar", IntegerType())) + + spec = PartitionSpec(PartitionField(3, 1000, IdentityTransform(), "some_partition")) + + with pytest.raises(ValidationError) as exc: + spec.check_compatible(schema) + + assert "Cannot find source column for partition field: 1000: some_partition: identity(3)" in str(exc.value) + + +def test_incompatible_non_primitive_type() -> None: + schema = Schema(NestedField(1, "foo", StructType()), NestedField(2, "bar", IntegerType())) + + spec = PartitionSpec(PartitionField(1, 1000, IdentityTransform(), "some_partition")) + + with pytest.raises(ValidationError) as exc: + spec.check_compatible(schema) + + assert "Cannot partition by non-primitive source field: struct<>" in str(exc.value) + + +def test_incompatible_transform_source_type() -> None: + schema = Schema(NestedField(1, "foo", IntegerType()), NestedField(2, "bar", IntegerType())) + + spec = PartitionSpec(PartitionField(1, 1000, YearTransform(), "some_partition")) + + with pytest.raises(ValidationError) as exc: + spec.check_compatible(schema) + + assert "Invalid source type int for transform: year" in str(exc.value) diff --git a/tests/table/test_sorting.py b/tests/table/test_sorting.py index cb7a2c187a..c1ce5c04d7 100644 --- a/tests/table/test_sorting.py +++ b/tests/table/test_sorting.py @@ -20,6 +20,8 @@ import pytest +from pyiceberg.exceptions import ValidationError +from pyiceberg.schema import Schema from pyiceberg.table.metadata import TableMetadataUtil from pyiceberg.table.sorting import ( UNSORTED_SORT_ORDER, @@ -28,7 +30,8 @@ SortField, SortOrder, ) -from pyiceberg.transforms import BucketTransform, IdentityTransform, VoidTransform +from pyiceberg.transforms import BucketTransform, IdentityTransform, VoidTransform, YearTransform +from pyiceberg.types import IntegerType, NestedField, StructType @pytest.fixture @@ -114,3 +117,36 @@ def test_serialize_sort_field_v3() -> None: expected = SortField(source_id=19, transform=IdentityTransform(), null_order=NullOrder.NULLS_FIRST) payload = '{"source-ids":[19],"transform":"identity","direction":"asc","null-order":"nulls-first"}' assert SortField.model_validate_json(payload) == expected + + +def test_incompatible_source_column_not_found(sort_order: SortOrder) -> None: + schema = Schema(NestedField(1, "foo", IntegerType()), NestedField(2, "bar", IntegerType())) + + with pytest.raises(ValidationError) as exc: + sort_order.check_compatible(schema) + + assert "Cannot find source column for sort field: 19 ASC NULLS FIRST" in str(exc.value) + + +def test_incompatible_non_primitive_type() -> None: + schema = Schema(NestedField(1, "foo", StructType()), NestedField(2, "bar", IntegerType())) + + sort_order = SortOrder(SortField(source_id=1, transform=IdentityTransform(), null_order=NullOrder.NULLS_FIRST)) + + with pytest.raises(ValidationError) as exc: + sort_order.check_compatible(schema) + + assert "Cannot sort by non-primitive source field: 1: foo: optional struct<>" in str(exc.value) + + +def test_incompatible_transform_source_type() -> None: + schema = Schema(NestedField(1, "foo", IntegerType()), NestedField(2, "bar", IntegerType())) + + sort_order = SortOrder( + SortField(source_id=1, transform=YearTransform(), null_order=NullOrder.NULLS_FIRST), + ) + + with pytest.raises(ValidationError) as exc: + sort_order.check_compatible(schema) + + assert "Invalid source type int for transform: year" in str(exc.value)