diff --git a/pyiceberg/table/update/__init__.py b/pyiceberg/table/update/__init__.py
index 2f0e6c13d2..c30d960d38 100644
--- a/pyiceberg/table/update/__init__.py
+++ b/pyiceberg/table/update/__init__.py
@@ -21,9 +21,9 @@
 from abc import ABC, abstractmethod
 from datetime import datetime
 from functools import singledispatch
-from typing import TYPE_CHECKING, Annotated, Any, Dict, Generic, List, Literal, Optional, Set, Tuple, TypeVar, Union, cast
+from typing import TYPE_CHECKING, Annotated, Any, Dict, Generic, List, Literal, Optional, Tuple, TypeVar, Union, cast
 
-from pydantic import Field, field_validator, model_validator
+from pydantic import Field, field_validator, model_serializer, model_validator
 
 from pyiceberg.exceptions import CommitFailedException
 from pyiceberg.partitioning import PARTITION_FIELD_ID_START, PartitionSpec
@@ -52,6 +52,8 @@
 from pyiceberg.utils.properties import property_as_int
 
 if TYPE_CHECKING:
+    from pydantic.functional_serializers import ModelWrapSerializerWithoutInfo
+
     from pyiceberg.table import Transaction
 
 U = TypeVar("U")
@@ -727,6 +729,21 @@ class AssertRefSnapshotId(ValidatableTableRequirement):
     ref: str = Field(...)
     snapshot_id: Optional[int] = Field(default=None, alias="snapshot-id")
 
+    @model_serializer(mode="wrap")
+    def serialize_model(self, handler: ModelWrapSerializerWithoutInfo) -> dict[str, Any]:
+        """Serialize the requirement, keeping `snapshot-id` even when it is None.
+
+        Args:
+            handler: The wrapped serializer that produces the default output for this model.
+
+        Returns:
+            The default output with the `snapshot-id` key set to the snapshot ID (possibly None).
+        """
+        partial_result = handler(self)
+        # The wrapped serializer may omit a None snapshot ID, so "snapshot-id" is
+        # written back unconditionally; note the key is always the aliased spelling.
+        return {**partial_result, "snapshot-id": self.snapshot_id}
+
     def validate(self, base_metadata: Optional[TableMetadata]) -> None:
         if base_metadata is None:
             raise CommitFailedException("Requirement failed: current table metadata is missing")
@@ -745,13 +762,6 @@ def validate(self, base_metadata: Optional[TableMetadata]) -> None:
         elif self.snapshot_id is not None:
             raise CommitFailedException(f"Requirement failed: branch or tag {self.ref} is missing, expected {self.snapshot_id}")
 
-    # override the override method, allowing None to serialize to `null` instead of being omitted.
-    def model_dump_json(
-        self, exclude_none: bool = False, exclude: Optional[Set[str]] = None, by_alias: bool = True, **kwargs: Any
-    ) -> str:
-        # `snapshot-id` is required in json response, even if null
-        return super().model_dump_json(exclude_none=False)
-
 
 class AssertLastAssignedFieldId(ValidatableTableRequirement):
     """The table's last assigned column id must match the requirement's `last-assigned-field-id`."""
diff --git a/tests/test_serializers.py b/tests/test_serializers.py
index ad40ea08e0..3f2bd73e48 100644
--- a/tests/test_serializers.py
+++ b/tests/test_serializers.py
@@ -18,7 +18,7 @@
 import json
 import os
 import uuid
-from typing import Any, Dict
+from typing import Any, Dict, Tuple
 
 import pytest
 from pytest_mock import MockFixture
@@ -26,6 +26,8 @@
 from pyiceberg.serializers import ToOutputFile
 from pyiceberg.table import StaticTable
 from pyiceberg.table.metadata import TableMetadataV1
+from pyiceberg.table.update import AssertRefSnapshotId, TableRequirement
+from pyiceberg.typedef import IcebergBaseModel
 
 
 def test_legacy_current_snapshot_id(
@@ -48,3 +50,15 @@ def test_legacy_current_snapshot_id(
     backwards_compatible_static_table = StaticTable.from_metadata(metadata_location)
     assert backwards_compatible_static_table.metadata.current_snapshot_id is None
     assert backwards_compatible_static_table.metadata == static_table.metadata
+
+
+def test_null_serializer_field() -> None:
+    """A None snapshot ID on a requirement nested in another model is dumped as JSON null."""
+
+    class ExampleRequest(IcebergBaseModel):
+        requirements: Tuple[TableRequirement, ...]
+
+    request = ExampleRequest(requirements=(AssertRefSnapshotId(ref="main", snapshot_id=None),))
+    dumped_json = request.model_dump_json()
+    expected_json = """{"type":"assert-ref-snapshot-id","ref":"main","snapshot-id":null}"""
+    assert expected_json in dumped_json