|
20 | 20 | from abc import ABC, abstractmethod |
21 | 21 | from datetime import datetime |
22 | 22 | from functools import singledispatch |
23 | | -from typing import TYPE_CHECKING, Any, Dict, Generic, List, Literal, Optional, Tuple, TypeVar, Union |
| 23 | +from typing import TYPE_CHECKING, Any, Dict, Generic, List, Literal, Optional, Tuple, TypeVar, Union, cast |
24 | 24 |
|
25 | | -from pydantic import Field, field_validator |
| 25 | +from pydantic import Field, field_validator, model_validator |
26 | 26 | from typing_extensions import Annotated |
27 | 27 |
|
28 | 28 | from pyiceberg.exceptions import CommitFailedException |
@@ -177,8 +177,20 @@ class RemovePropertiesUpdate(IcebergBaseModel): |
177 | 177 |
|
178 | 178 | class SetStatisticsUpdate(IcebergBaseModel): |
179 | 179 | action: Literal["set-statistics"] = Field(default="set-statistics") |
180 | | - snapshot_id: int = Field(alias="snapshot-id") |
181 | 180 | statistics: StatisticsFile |
| 181 | + snapshot_id: Optional[int] = Field( |
| 182 | + None, |
| 183 | + alias="snapshot-id", |
| 184 | + description="snapshot-id is **DEPRECATED for REMOVAL** since it contains redundant information. Use `statistics.snapshot-id` field instead.", |
| 185 | + ) |
| 186 | + |
| 187 | + @model_validator(mode="before") |
| 188 | + def validate_snapshot_id(cls, data: Dict[str, Any]) -> Dict[str, Any]: |
| 189 | + stats = cast(StatisticsFile, data["statistics"]) |
| 190 | + |
| 191 | + data["snapshot_id"] = stats.snapshot_id |
| 192 | + |
| 193 | + return data |
182 | 194 |
|
183 | 195 |
|
184 | 196 | class RemoveStatisticsUpdate(IcebergBaseModel): |
@@ -491,10 +503,7 @@ def _( |
491 | 503 |
|
492 | 504 | @_apply_table_update.register(SetStatisticsUpdate) |
493 | 505 | def _(update: SetStatisticsUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata: |
494 | | - if update.snapshot_id != update.statistics.snapshot_id: |
495 | | - raise ValueError("Snapshot id in statistics does not match the snapshot id in the update") |
496 | | - |
497 | | - statistics = filter_statistics_by_snapshot_id(base_metadata.statistics, update.snapshot_id) |
| 506 | + statistics = filter_statistics_by_snapshot_id(base_metadata.statistics, update.statistics.snapshot_id) |
498 | 507 | context.add_update(update) |
499 | 508 |
|
500 | 509 | return base_metadata.model_copy(update={"statistics": statistics + [update.statistics]}) |
|
0 commit comments