diff --git a/pyiceberg/table/statistics.py b/pyiceberg/table/statistics.py index a2e1b149a1..484391efb1 100644 --- a/pyiceberg/table/statistics.py +++ b/pyiceberg/table/statistics.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Dict, List, Literal, Optional +from typing import Dict, List, Literal, Optional, Union from pydantic import Field @@ -48,7 +48,7 @@ class PartitionStatisticsFile(StatisticsCommonFields): def filter_statistics_by_snapshot_id( - statistics: List[StatisticsFile], + statistics: List[Union[StatisticsFile, PartitionStatisticsFile]], reject_snapshot_id: int, -) -> List[StatisticsFile]: +) -> List[Union[StatisticsFile, PartitionStatisticsFile]]: return [stat for stat in statistics if stat.snapshot_id != reject_snapshot_id] diff --git a/pyiceberg/table/update/__init__.py b/pyiceberg/table/update/__init__.py index 6653f119f0..3f7d43f0ef 100644 --- a/pyiceberg/table/update/__init__.py +++ b/pyiceberg/table/update/__init__.py @@ -36,7 +36,11 @@ SnapshotLogEntry, ) from pyiceberg.table.sorting import SortOrder -from pyiceberg.table.statistics import StatisticsFile, filter_statistics_by_snapshot_id +from pyiceberg.table.statistics import ( + PartitionStatisticsFile, + StatisticsFile, + filter_statistics_by_snapshot_id, +) from pyiceberg.typedef import ( IcebergBaseModel, Properties, @@ -198,6 +202,16 @@ class RemoveStatisticsUpdate(IcebergBaseModel): snapshot_id: int = Field(alias="snapshot-id") +class SetPartitionStatisticsUpdate(IcebergBaseModel): + action: Literal["set-partition-statistics"] = Field(default="set-partition-statistics") + partition_statistics: PartitionStatisticsFile + + +class RemovePartitionStatisticsUpdate(IcebergBaseModel): + action: Literal["remove-partition-statistics"] = Field(default="remove-partition-statistics") + snapshot_id: int = Field(alias="snapshot-id") + + TableUpdate = Annotated[ Union[ AssignUUIDUpdate, @@ -217,6 +231,8 @@ class RemoveStatisticsUpdate(IcebergBaseModel): RemovePropertiesUpdate, SetStatisticsUpdate, RemoveStatisticsUpdate, + SetPartitionStatisticsUpdate, + RemovePartitionStatisticsUpdate, ], Field(discriminator="action"), ] @@ -582,6 +598,29 @@ def _(update: RemoveStatisticsUpdate, base_metadata: TableMetadata, context: _Ta return base_metadata.model_copy(update={"statistics": statistics}) +@_apply_table_update.register(SetPartitionStatisticsUpdate) +def _(update: SetPartitionStatisticsUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata: + partition_statistics = filter_statistics_by_snapshot_id( + base_metadata.partition_statistics, update.partition_statistics.snapshot_id + ) + context.add_update(update) + + return base_metadata.model_copy(update={"partition_statistics": partition_statistics + [update.partition_statistics]}) + + +@_apply_table_update.register(RemovePartitionStatisticsUpdate) +def _( + update: RemovePartitionStatisticsUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext +) -> TableMetadata: + if not any(part_stat.snapshot_id == update.snapshot_id for part_stat in base_metadata.partition_statistics): + raise ValueError(f"Partition Statistics with snapshot id {update.snapshot_id} does not exist") + + statistics = filter_statistics_by_snapshot_id(base_metadata.partition_statistics, update.snapshot_id) + context.add_update(update) + + return base_metadata.model_copy(update={"partition_statistics": statistics}) + + def update_table_metadata( base_metadata: TableMetadata, updates: Tuple[TableUpdate, ...], diff --git a/tests/table/test_init.py b/tests/table/test_init.py index 89524a861c..748a77eee0 100644 --- a/tests/table/test_init.py +++ b/tests/table/test_init.py @@ -64,7 +64,7 @@ SortField, SortOrder, ) -from pyiceberg.table.statistics import BlobMetadata, StatisticsFile +from pyiceberg.table.statistics import BlobMetadata, PartitionStatisticsFile, StatisticsFile from pyiceberg.table.update import ( AddSnapshotUpdate, AddSortOrderUpdate, @@ -76,11 +76,13 @@ AssertLastAssignedPartitionId, AssertRefSnapshotId, AssertTableUUID, + RemovePartitionStatisticsUpdate, RemovePropertiesUpdate, RemoveSnapshotRefUpdate, RemoveSnapshotsUpdate, RemoveStatisticsUpdate, SetDefaultSortOrderUpdate, + SetPartitionStatisticsUpdate, SetPropertiesUpdate, SetSnapshotRefUpdate, SetStatisticsUpdate, @@ -1359,3 +1361,79 @@ def test_remove_statistics_update(table_v2_with_statistics: Table) -> None: table_v2_with_statistics.metadata, (RemoveStatisticsUpdate(snapshot_id=123456789),), ) + + +def test_set_partition_statistics_update(table_v2_with_statistics: Table) -> None: + snapshot_id = table_v2_with_statistics.metadata.current_snapshot_id + + partition_statistics_file = PartitionStatisticsFile( + snapshot_id=snapshot_id, + statistics_path="s3://bucket/warehouse/stats.puffin", + file_size_in_bytes=124, + ) + + update = SetPartitionStatisticsUpdate( + partition_statistics=partition_statistics_file, + ) + + new_metadata = update_table_metadata( + table_v2_with_statistics.metadata, + (update,), + ) + + expected = """ + { + "snapshot-id": 3055729675574597004, + "statistics-path": "s3://bucket/warehouse/stats.puffin", + "file-size-in-bytes": 124 + }""" + + assert len(new_metadata.partition_statistics) == 1 + + updated_statistics = [stat for stat in new_metadata.partition_statistics if stat.snapshot_id == snapshot_id] + + assert len(updated_statistics) == 1 + assert json.loads(updated_statistics[0].model_dump_json()) == json.loads(expected) + + +def test_remove_partition_statistics_update(table_v2_with_statistics: Table) -> None: + # Add partition statistics file. + snapshot_id = table_v2_with_statistics.metadata.current_snapshot_id + + partition_statistics_file = PartitionStatisticsFile( + snapshot_id=snapshot_id, + statistics_path="s3://bucket/warehouse/stats.puffin", + file_size_in_bytes=124, + ) + + update = SetPartitionStatisticsUpdate( + partition_statistics=partition_statistics_file, + ) + + new_metadata = update_table_metadata( + table_v2_with_statistics.metadata, + (update,), + ) + assert len(new_metadata.partition_statistics) == 1 + + # Remove the same partition statistics file. + remove_update = RemovePartitionStatisticsUpdate(snapshot_id=snapshot_id) + + remove_metadata = update_table_metadata( + new_metadata, + (remove_update,), + ) + + assert len(remove_metadata.partition_statistics) == 0 + + +def test_remove_partition_statistics_update_with_invalid_snapshot_id(table_v2_with_statistics: Table) -> None: + # Remove the same partition statistics file. + with pytest.raises( + ValueError, + match="Partition Statistics with snapshot id 123456789 does not exist", + ): + update_table_metadata( + table_v2_with_statistics.metadata, + (RemovePartitionStatisticsUpdate(snapshot_id=123456789),), + )