Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions pyiceberg/table/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Dict, List, Literal, Optional
from typing import Dict, List, Literal, Optional, Union

from pydantic import Field

Expand Down Expand Up @@ -48,7 +48,7 @@ class PartitionStatisticsFile(StatisticsCommonFields):


def filter_statistics_by_snapshot_id(
    statistics: List[Union[StatisticsFile, PartitionStatisticsFile]],
    reject_snapshot_id: int,
) -> List[Union[StatisticsFile, PartitionStatisticsFile]]:
    """Return the statistics entries that do not belong to the rejected snapshot.

    Args:
        statistics: Statistics files (table-level or partition-level) to filter.
        reject_snapshot_id: Snapshot id whose entries are dropped.

    Returns:
        A new list holding every entry whose ``snapshot_id`` differs from
        ``reject_snapshot_id``, in the original order. The input list is not
        modified.
    """
    return [stat for stat in statistics if stat.snapshot_id != reject_snapshot_id]
41 changes: 40 additions & 1 deletion pyiceberg/table/update/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,11 @@
SnapshotLogEntry,
)
from pyiceberg.table.sorting import SortOrder
from pyiceberg.table.statistics import StatisticsFile, filter_statistics_by_snapshot_id
from pyiceberg.table.statistics import (
PartitionStatisticsFile,
StatisticsFile,
filter_statistics_by_snapshot_id,
)
from pyiceberg.typedef import (
IcebergBaseModel,
Properties,
Expand Down Expand Up @@ -198,6 +202,16 @@ class RemoveStatisticsUpdate(IcebergBaseModel):
snapshot_id: int = Field(alias="snapshot-id")


class SetPartitionStatisticsUpdate(IcebergBaseModel):
    """Table update carrying a partition statistics file to set on the table metadata."""

    # Discriminator value identifying this update kind.
    action: Literal["set-partition-statistics"] = Field(default="set-partition-statistics")
    # Use the kebab-case wire name, consistent with the ``snapshot-id`` alias on
    # the sibling update models. Construction by field name keeps working, as it
    # does for those aliased fields.
    # NOTE(review): assumes the base model serializes by alias, as the
    # kebab-case keys in dumps of the statistics models suggest -- confirm.
    partition_statistics: PartitionStatisticsFile = Field(alias="partition-statistics")


class RemovePartitionStatisticsUpdate(IcebergBaseModel):
    """Table update requesting removal of the partition statistics of one snapshot."""

    # Discriminator value identifying this update kind.
    action: Literal["remove-partition-statistics"] = Field("remove-partition-statistics")
    # Required: id of the snapshot whose partition statistics are removed.
    snapshot_id: int = Field(..., alias="snapshot-id")


TableUpdate = Annotated[
Union[
AssignUUIDUpdate,
Expand All @@ -217,6 +231,8 @@ class RemoveStatisticsUpdate(IcebergBaseModel):
RemovePropertiesUpdate,
SetStatisticsUpdate,
RemoveStatisticsUpdate,
SetPartitionStatisticsUpdate,
RemovePartitionStatisticsUpdate,
],
Field(discriminator="action"),
]
Expand Down Expand Up @@ -582,6 +598,29 @@ def _(update: RemoveStatisticsUpdate, base_metadata: TableMetadata, context: _Ta
return base_metadata.model_copy(update={"statistics": statistics})


@_apply_table_update.register(SetPartitionStatisticsUpdate)
def _(update: SetPartitionStatisticsUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata:
    """Apply a set-partition-statistics update to the table metadata.

    Any existing partition statistics entry with the same snapshot id as the
    incoming file is dropped, and the incoming file is appended at the end.

    Args:
        update: The update carrying the partition statistics file to set.
        base_metadata: Metadata the update is applied to; it is copied, not mutated.
        context: Update context on which the applied update is recorded.

    Returns:
        A copy of ``base_metadata`` with the new ``partition_statistics`` list.
    """
    incoming = update.partition_statistics
    # Drop any entry already registered for this snapshot so it is replaced, not duplicated.
    retained = filter_statistics_by_snapshot_id(base_metadata.partition_statistics, incoming.snapshot_id)
    context.add_update(update)

    return base_metadata.model_copy(update={"partition_statistics": [*retained, incoming]})


@_apply_table_update.register(RemovePartitionStatisticsUpdate)
def _(
    update: RemovePartitionStatisticsUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext
) -> TableMetadata:
    """Apply a remove-partition-statistics update to the table metadata.

    Args:
        update: The update naming the snapshot whose partition statistics go away.
        base_metadata: Metadata the update is applied to; it is copied, not mutated.
        context: Update context on which the applied update is recorded.

    Returns:
        A copy of ``base_metadata`` without the partition statistics entries of
        the given snapshot.

    Raises:
        ValueError: If no partition statistics entry exists for the snapshot id.
    """
    existing = base_metadata.partition_statistics
    # Removing something that is not there is treated as an error, not a no-op.
    if all(entry.snapshot_id != update.snapshot_id for entry in existing):
        raise ValueError(f"Partition Statistics with snapshot id {update.snapshot_id} does not exist")

    remaining = filter_statistics_by_snapshot_id(existing, update.snapshot_id)
    context.add_update(update)

    return base_metadata.model_copy(update={"partition_statistics": remaining})


def update_table_metadata(
base_metadata: TableMetadata,
updates: Tuple[TableUpdate, ...],
Expand Down
80 changes: 79 additions & 1 deletion tests/table/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
SortField,
SortOrder,
)
from pyiceberg.table.statistics import BlobMetadata, StatisticsFile
from pyiceberg.table.statistics import BlobMetadata, PartitionStatisticsFile, StatisticsFile
from pyiceberg.table.update import (
AddSnapshotUpdate,
AddSortOrderUpdate,
Expand All @@ -76,11 +76,13 @@
AssertLastAssignedPartitionId,
AssertRefSnapshotId,
AssertTableUUID,
RemovePartitionStatisticsUpdate,
RemovePropertiesUpdate,
RemoveSnapshotRefUpdate,
RemoveSnapshotsUpdate,
RemoveStatisticsUpdate,
SetDefaultSortOrderUpdate,
SetPartitionStatisticsUpdate,
SetPropertiesUpdate,
SetSnapshotRefUpdate,
SetStatisticsUpdate,
Expand Down Expand Up @@ -1359,3 +1361,79 @@ def test_remove_statistics_update(table_v2_with_statistics: Table) -> None:
table_v2_with_statistics.metadata,
(RemoveStatisticsUpdate(snapshot_id=123456789),),
)


def test_set_partition_statistics_update(table_v2_with_statistics: Table) -> None:
    """Setting partition statistics yields one entry for the current snapshot with the expected JSON."""
    base = table_v2_with_statistics.metadata
    snapshot_id = base.current_snapshot_id

    expected = """
{
"snapshot-id": 3055729675574597004,
"statistics-path": "s3://bucket/warehouse/stats.puffin",
"file-size-in-bytes": 124
}"""

    stats_file = PartitionStatisticsFile(
        snapshot_id=snapshot_id,
        statistics_path="s3://bucket/warehouse/stats.puffin",
        file_size_in_bytes=124,
    )
    new_metadata = update_table_metadata(
        base,
        (SetPartitionStatisticsUpdate(partition_statistics=stats_file),),
    )

    assert len(new_metadata.partition_statistics) == 1

    # Exactly one entry must be registered for the snapshot we just set.
    matching = [entry for entry in new_metadata.partition_statistics if entry.snapshot_id == snapshot_id]

    assert len(matching) == 1
    assert json.loads(matching[0].model_dump_json()) == json.loads(expected)


def test_remove_partition_statistics_update(table_v2_with_statistics: Table) -> None:
    """A partition statistics file that was set can be removed again by snapshot id."""
    snapshot_id = table_v2_with_statistics.metadata.current_snapshot_id

    # Step 1: register a partition statistics file for the current snapshot.
    stats_file = PartitionStatisticsFile(
        snapshot_id=snapshot_id,
        statistics_path="s3://bucket/warehouse/stats.puffin",
        file_size_in_bytes=124,
    )
    with_stats = update_table_metadata(
        table_v2_with_statistics.metadata,
        (SetPartitionStatisticsUpdate(partition_statistics=stats_file),),
    )
    assert len(with_stats.partition_statistics) == 1

    # Step 2: remove that same file via its snapshot id.
    without_stats = update_table_metadata(
        with_stats,
        (RemovePartitionStatisticsUpdate(snapshot_id=snapshot_id),),
    )

    assert len(without_stats.partition_statistics) == 0


def test_remove_partition_statistics_update_with_invalid_snapshot_id(table_v2_with_statistics: Table) -> None:
    """Removing partition statistics for a snapshot id without an entry raises ``ValueError``."""
    metadata = table_v2_with_statistics.metadata

    # The fixture is expected to hold no partition statistics for snapshot
    # 123456789, so the removal must be rejected with a descriptive error.
    with pytest.raises(ValueError, match="Partition Statistics with snapshot id 123456789 does not exist"):
        update_table_metadata(
            metadata,
            (RemovePartitionStatisticsUpdate(snapshot_id=123456789),),
        )