33 changes: 32 additions & 1 deletion pyiceberg/partitioning.py
@@ -32,7 +32,8 @@
model_validator,
)

from pyiceberg.schema import Schema
from pyiceberg.exceptions import ValidationError
from pyiceberg.schema import Schema, _index_parents
from pyiceberg.transforms import (
BucketTransform,
DayTransform,
@@ -249,6 +250,36 @@ def partition_to_path(self, data: Record, schema: Schema) -> str:
path = "/".join([field_str + "=" + value_str for field_str, value_str in zip(field_strs, value_strs, strict=True)])
return path

def check_compatible(self, schema: Schema, allow_missing_fields: bool = False) -> None:
# if the underlying source field has been dropped from the schema, we cannot check compatibility -- continue
schema_fields = schema._lazy_id_to_field
parents = _index_parents(schema)

def validate_parents_are_structs(field_id: int) -> None:
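# Walk up the chain of parent fields; any non-struct ancestor (e.g. a list or map) is an invalid parent for a partition source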
parent_id = parents.get(field_id)
while parent_id:
parent_type = schema.find_type(parent_id)
if not parent_type.is_struct:
raise ValidationError("Invalid partition field parent: %s", parent_type)
Reviewer comment (Contributor): nit: should we also use f string here to align with others
parent_id = parents.get(parent_id)

for field in self.fields:
source_field = schema_fields.get(field.source_id)
if allow_missing_fields and not source_field:
continue

if not isinstance(field.transform, VoidTransform):
if source_field:
source_type = source_field.field_type
if not source_type.is_primitive:
raise ValidationError(f"Cannot partition by non-primitive source field: {source_type}")
if not field.transform.can_transform(source_type):
raise ValidationError(f"Invalid source type {source_type} for transform: {field.transform}")
# The only valid parent types for a PartitionField are StructTypes. This must be checked recursively
validate_parents_are_structs(field.source_id)
else:
raise ValidationError(f"Cannot find source column for partition field: {field}")


UNPARTITIONED_PARTITION_SPEC = PartitionSpec(spec_id=0)

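For reference, a minimal sketch of how the new PartitionSpec.check_compatible is exercised; the schema, field IDs, and partition names below are illustrative and mirror the unit tests added in this PR:

from pyiceberg.exceptions import ValidationError
from pyiceberg.partitioning import PartitionField, PartitionSpec
from pyiceberg.schema import Schema
from pyiceberg.transforms import IdentityTransform
from pyiceberg.types import IntegerType, NestedField

schema = Schema(NestedField(1, "foo", IntegerType()), NestedField(2, "bar", IntegerType()))

# An identity partition on an existing primitive column passes silently
PartitionSpec(PartitionField(1, 1000, IdentityTransform(), "foo_partition")).check_compatible(schema)

# A partition whose source column (id 3) is missing from the schema is rejected
try:
    PartitionSpec(PartitionField(3, 1001, IdentityTransform(), "missing_partition")).check_compatible(schema)
except ValidationError as exc:
    print(exc)  # Cannot find source column for partition field: ...
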
12 changes: 12 additions & 0 deletions pyiceberg/table/sorting.py
@@ -27,6 +27,7 @@
model_validator,
)

from pyiceberg.exceptions import ValidationError
from pyiceberg.schema import Schema
from pyiceberg.transforms import IdentityTransform, Transform, parse_transform
from pyiceberg.typedef import IcebergBaseModel
@@ -169,6 +170,17 @@ def __repr__(self) -> str:
fields = f"{', '.join(repr(column) for column in self.fields)}, " if self.fields else ""
return f"SortOrder({fields}order_id={self.order_id})"

def check_compatible(self, schema: Schema) -> None:
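# Every sort field must reference an existing primitive source column whose type its transform can handle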
schema_ids = schema._lazy_id_to_field
for field in self.fields:
if source_field := schema_ids.get(field.source_id):
if not source_field.field_type.is_primitive:
raise ValidationError(f"Cannot sort by non-primitive source field: {source_field}")
if not field.transform.can_transform(source_field.field_type):
raise ValidationError(f"Invalid source type {source_field.field_type} for transform: {field.transform}")
else:
raise ValidationError(f"Cannot find source column for sort field: {field}")


UNSORTED_SORT_ORDER_ID = 0
UNSORTED_SORT_ORDER = SortOrder(order_id=UNSORTED_SORT_ORDER_ID)
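Likewise, a minimal sketch of SortOrder.check_compatible under an illustrative schema (field IDs and names mirror the unit tests added below):

from pyiceberg.exceptions import ValidationError
from pyiceberg.schema import Schema
from pyiceberg.table.sorting import NullOrder, SortField, SortOrder
from pyiceberg.transforms import IdentityTransform, YearTransform
from pyiceberg.types import IntegerType, NestedField

schema = Schema(NestedField(1, "foo", IntegerType()), NestedField(2, "bar", IntegerType()))

# An identity sort on an existing primitive column passes silently
SortOrder(SortField(source_id=1, transform=IdentityTransform(), null_order=NullOrder.NULLS_FIRST)).check_compatible(schema)

# A year transform cannot be applied to an int column, so this is rejected
try:
    SortOrder(SortField(source_id=1, transform=YearTransform(), null_order=NullOrder.NULLS_FIRST)).check_compatible(schema)
except ValidationError as exc:
    print(exc)  # Invalid source type int for transform: year
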
6 changes: 6 additions & 0 deletions pyiceberg/table/update/__init__.py
@@ -700,6 +700,12 @@ def update_table_metadata(
if base_metadata.last_updated_ms == new_metadata.last_updated_ms:
new_metadata = new_metadata.model_copy(update={"last_updated_ms": datetime_to_millis(datetime.now().astimezone())})

# Check that the default partition spec and sort order are compatible with the current schema
new_metadata.spec().check_compatible(new_metadata.schema())

if sort_order := new_metadata.sort_order_by_id(new_metadata.default_sort_order_id):
sort_order.check_compatible(new_metadata.schema())

if enforce_validation:
return TableMetadataUtil.parse_obj(new_metadata.model_dump())
else:
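The user-visible effect of wiring these checks into update_table_metadata: a schema change that leaves the default partition spec (or sort order) referencing a dropped or unusable column is rejected at commit time. A rough sketch, assuming a hypothetical `table` partitioned by a transform on a "VendorID" column (this mirrors the integration tests added below):

from pyiceberg.exceptions import ValidationError

try:
    with table.update_schema() as update:
        update.delete_column("VendorID")  # "VendorID" is still the source of a partition field
except ValidationError as exc:
    print(exc)  # the commit is rejected and the column is not dropped
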
6 changes: 6 additions & 0 deletions tests/conftest.py
@@ -70,6 +70,7 @@
from pyiceberg.serializers import ToOutputFile
from pyiceberg.table import FileScanTask, Table
from pyiceberg.table.metadata import TableMetadataV1, TableMetadataV2, TableMetadataV3
from pyiceberg.table.sorting import NullOrder, SortField, SortOrder
from pyiceberg.transforms import DayTransform, IdentityTransform
from pyiceberg.types import (
BinaryType,
@@ -1894,6 +1895,11 @@ def test_partition_spec() -> Schema:
)


@pytest.fixture(scope="session")
def test_sort_order() -> SortOrder:
return SortOrder(SortField(source_id=1, transform=IdentityTransform(), null_order=NullOrder.NULLS_FIRST))


@pytest.fixture(scope="session")
def generated_manifest_entry_file(
avro_schema_manifest_entry: dict[str, Any], test_schema: Schema, test_partition_spec: PartitionSpec
54 changes: 54 additions & 0 deletions tests/integration/test_catalog.py
@@ -33,6 +33,7 @@
NoSuchNamespaceError,
NoSuchTableError,
TableAlreadyExistsError,
ValidationError,
)
from pyiceberg.io import WAREHOUSE
from pyiceberg.partitioning import PartitionField, PartitionSpec
@@ -601,3 +602,56 @@ def test_register_table_existing(test_catalog: Catalog, table_schema_nested: Sch
# Assert that registering the table again raises TableAlreadyExistsError
with pytest.raises(TableAlreadyExistsError):
test_catalog.register_table(identifier, metadata_location=table.metadata_location)


@pytest.mark.integration
@pytest.mark.parametrize("test_catalog", CATALOGS)
def test_incompatible_partitioned_schema_evolution(
test_catalog: Catalog, test_schema: Schema, test_partition_spec: PartitionSpec, database_name: str, table_name: str
) -> None:
if isinstance(test_catalog, HiveCatalog):
pytest.skip("HiveCatalog does not support schema evolution")

identifier = (database_name, table_name)
test_catalog.create_namespace(database_name)
table = test_catalog.create_table(identifier, test_schema, partition_spec=test_partition_spec)
assert test_catalog.table_exists(identifier)

with pytest.raises(ValidationError):
with table.update_schema() as update:
update.delete_column("VendorID")

# Assert column was not dropped
assert "VendorID" in table.schema().column_names

with table.transaction() as transaction:
with transaction.update_spec() as spec_update:
spec_update.remove_field("VendorID")

with transaction.update_schema() as schema_update:
schema_update.delete_column("VendorID")

assert table.spec() == PartitionSpec(PartitionField(2, 1001, DayTransform(), "tpep_pickup_day"), spec_id=1)
assert table.schema() == Schema(NestedField(2, "tpep_pickup_datetime", TimestampType(), False))


@pytest.mark.integration
@pytest.mark.parametrize("test_catalog", CATALOGS)
def test_incompatible_sorted_schema_evolution(
test_catalog: Catalog, test_schema: Schema, test_sort_order: SortOrder, database_name: str, table_name: str
) -> None:
if isinstance(test_catalog, HiveCatalog):
pytest.skip("HiveCatalog does not support schema evolution")

identifier = (database_name, table_name)
test_catalog.create_namespace(database_name)
table = test_catalog.create_table(identifier, test_schema, sort_order=test_sort_order)
assert test_catalog.table_exists(identifier)

with pytest.raises(ValidationError):
with table.update_schema() as update:
update.delete_column("VendorID")

assert table.schema() == Schema(
NestedField(1, "VendorID", IntegerType(), False), NestedField(2, "tpep_pickup_datetime", TimestampType(), False)
)
34 changes: 34 additions & 0 deletions tests/table/test_partitioning.py
@@ -21,6 +21,7 @@

import pytest

from pyiceberg.exceptions import ValidationError
from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC, PartitionField, PartitionSpec
from pyiceberg.schema import Schema
from pyiceberg.transforms import (
@@ -259,3 +260,36 @@ def test_deserialize_partition_field_v3() -> None:

field = PartitionField.model_validate_json(json_partition_spec)
assert field == PartitionField(source_id=1, field_id=1000, transform=TruncateTransform(width=19), name="str_truncate")


def test_incompatible_source_column_not_found() -> None:
schema = Schema(NestedField(1, "foo", IntegerType()), NestedField(2, "bar", IntegerType()))

spec = PartitionSpec(PartitionField(3, 1000, IdentityTransform(), "some_partition"))

with pytest.raises(ValidationError) as exc:
spec.check_compatible(schema)

assert "Cannot find source column for partition field: 1000: some_partition: identity(3)" in str(exc.value)


def test_incompatible_non_primitive_type() -> None:
schema = Schema(NestedField(1, "foo", StructType()), NestedField(2, "bar", IntegerType()))

spec = PartitionSpec(PartitionField(1, 1000, IdentityTransform(), "some_partition"))

with pytest.raises(ValidationError) as exc:
spec.check_compatible(schema)

assert "Cannot partition by non-primitive source field: struct<>" in str(exc.value)


def test_incompatible_transform_source_type() -> None:
schema = Schema(NestedField(1, "foo", IntegerType()), NestedField(2, "bar", IntegerType()))

spec = PartitionSpec(PartitionField(1, 1000, YearTransform(), "some_partition"))

with pytest.raises(ValidationError) as exc:
spec.check_compatible(schema)

assert "Invalid source type int for transform: year" in str(exc.value)
38 changes: 37 additions & 1 deletion tests/table/test_sorting.py
@@ -20,6 +20,8 @@

import pytest

from pyiceberg.exceptions import ValidationError
from pyiceberg.schema import Schema
from pyiceberg.table.metadata import TableMetadataUtil
from pyiceberg.table.sorting import (
UNSORTED_SORT_ORDER,
@@ -28,7 +30,8 @@
SortField,
SortOrder,
)
from pyiceberg.transforms import BucketTransform, IdentityTransform, VoidTransform
from pyiceberg.transforms import BucketTransform, IdentityTransform, VoidTransform, YearTransform
from pyiceberg.types import IntegerType, NestedField, StructType


@pytest.fixture
@@ -114,3 +117,36 @@ def test_serialize_sort_field_v3() -> None:
expected = SortField(source_id=19, transform=IdentityTransform(), null_order=NullOrder.NULLS_FIRST)
payload = '{"source-ids":[19],"transform":"identity","direction":"asc","null-order":"nulls-first"}'
assert SortField.model_validate_json(payload) == expected


def test_incompatible_source_column_not_found(sort_order: SortOrder) -> None:
schema = Schema(NestedField(1, "foo", IntegerType()), NestedField(2, "bar", IntegerType()))

with pytest.raises(ValidationError) as exc:
sort_order.check_compatible(schema)

assert "Cannot find source column for sort field: 19 ASC NULLS FIRST" in str(exc.value)


def test_incompatible_non_primitive_type() -> None:
schema = Schema(NestedField(1, "foo", StructType()), NestedField(2, "bar", IntegerType()))

sort_order = SortOrder(SortField(source_id=1, transform=IdentityTransform(), null_order=NullOrder.NULLS_FIRST))

with pytest.raises(ValidationError) as exc:
sort_order.check_compatible(schema)

assert "Cannot sort by non-primitive source field: 1: foo: optional struct<>" in str(exc.value)


def test_incompatible_transform_source_type() -> None:
schema = Schema(NestedField(1, "foo", IntegerType()), NestedField(2, "bar", IntegerType()))

sort_order = SortOrder(
SortField(source_id=1, transform=YearTransform(), null_order=NullOrder.NULLS_FIRST),
)

with pytest.raises(ValidationError) as exc:
sort_order.check_compatible(schema)

assert "Invalid source type int for transform: year" in str(exc.value)