Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions pyiceberg/table/update/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@
from abc import ABC, abstractmethod
from datetime import datetime
from functools import singledispatch
from typing import TYPE_CHECKING, Annotated, Any, Dict, Generic, List, Literal, Optional, Set, Tuple, TypeVar, Union, cast
from typing import TYPE_CHECKING, Annotated, Any, Dict, Generic, List, Literal, Optional, Tuple, TypeVar, Union, cast

from pydantic import Field, field_validator, model_validator
from pydantic import Field, field_validator, model_serializer, model_validator

from pyiceberg.exceptions import CommitFailedException
from pyiceberg.partitioning import PARTITION_FIELD_ID_START, PartitionSpec
Expand Down Expand Up @@ -52,6 +52,8 @@
from pyiceberg.utils.properties import property_as_int

if TYPE_CHECKING:
from pydantic.functional_serializers import ModelWrapSerializerWithoutInfo

from pyiceberg.table import Transaction

U = TypeVar("U")
Expand Down Expand Up @@ -727,6 +729,12 @@ class AssertRefSnapshotId(ValidatableTableRequirement):
ref: str = Field(...)
snapshot_id: Optional[int] = Field(default=None, alias="snapshot-id")

@model_serializer(mode="wrap")
def serialize_model(self, handler: ModelWrapSerializerWithoutInfo) -> dict[str, Any]:
partial_result = handler(self)
# Ensure "snapshot-id" is always present, even if value is None
return {**partial_result, "snapshot-id": self.snapshot_id}

def validate(self, base_metadata: Optional[TableMetadata]) -> None:
if base_metadata is None:
raise CommitFailedException("Requirement failed: current table metadata is missing")
Expand All @@ -745,13 +753,6 @@ def validate(self, base_metadata: Optional[TableMetadata]) -> None:
elif self.snapshot_id is not None:
raise CommitFailedException(f"Requirement failed: branch or tag {self.ref} is missing, expected {self.snapshot_id}")

# override the override method, allowing None to serialize to `null` instead of being omitted.
def model_dump_json(
self, exclude_none: bool = False, exclude: Optional[Set[str]] = None, by_alias: bool = True, **kwargs: Any
) -> str:
# `snapshot-id` is required in json response, even if null
return super().model_dump_json(exclude_none=False)


class AssertLastAssignedFieldId(ValidatableTableRequirement):
"""The table's last assigned column id must match the requirement's `last-assigned-field-id`."""
Expand Down
14 changes: 13 additions & 1 deletion tests/test_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,16 @@
import json
import os
import uuid
from typing import Any, Dict
from typing import Any, Dict, Tuple

import pytest
from pytest_mock import MockFixture

from pyiceberg.serializers import ToOutputFile
from pyiceberg.table import StaticTable
from pyiceberg.table.metadata import TableMetadataV1
from pyiceberg.table.update import AssertRefSnapshotId, TableRequirement
from pyiceberg.typedef import IcebergBaseModel


def test_legacy_current_snapshot_id(
Expand All @@ -48,3 +50,13 @@ def test_legacy_current_snapshot_id(
backwards_compatible_static_table = StaticTable.from_metadata(metadata_location)
assert backwards_compatible_static_table.metadata.current_snapshot_id is None
assert backwards_compatible_static_table.metadata == static_table.metadata


def test_null_serializer_field() -> None:
class ExampleRequest(IcebergBaseModel):
requirements: Tuple[TableRequirement, ...]

request = ExampleRequest(requirements=(AssertRefSnapshotId(ref="main", snapshot_id=None),))
dumped_json = request.model_dump_json()
expected_json = """{"type":"assert-ref-snapshot-id","ref":"main","snapshot-id":null}"""
assert expected_json in dumped_json