Skip to content

Commit 0c1e4c7

Browse files
committed
refactor: consolidate snapshot expiration into MaintenanceTable
- Move ExpireSnapshots functionality from the standalone class to MaintenanceTable
- Replace the fluent API (table.expire_snapshots().method().commit()) with direct execution (table.maintenance.method())
- Remove the ExpireSnapshots class and integrate its logic into the maintenance operations
- Update all tests to use the new unified maintenance API
- Maintain all existing validation and protection logic for snapshots

This change consolidates table maintenance operations under a single interface and simplifies the API by removing the need for explicit commit calls.

Breaking change: the table.expire_snapshots() API is replaced with the table.maintenance.expire_*() methods.
1 parent f4d98d2 commit 0c1e4c7

File tree

4 files changed

+99
-119
lines changed

4 files changed

+99
-119
lines changed

pyiceberg/table/__init__.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@
116116
update_table_metadata,
117117
)
118118
from pyiceberg.table.update.schema import UpdateSchema
119-
from pyiceberg.table.update.snapshot import ExpireSnapshots, ManageSnapshots, UpdateSnapshot, _FastAppendFiles
119+
from pyiceberg.table.update.snapshot import ManageSnapshots, UpdateSnapshot, _FastAppendFiles
120120
from pyiceberg.table.update.spec import UpdateSpec
121121
from pyiceberg.table.update.statistics import UpdateStatistics
122122
from pyiceberg.transforms import IdentityTransform
@@ -1220,10 +1220,6 @@ def manage_snapshots(self) -> ManageSnapshots:
12201220
"""
12211221
return ManageSnapshots(transaction=Transaction(self, autocommit=True))
12221222

1223-
def expire_snapshots(self) -> ExpireSnapshots:
1224-
"""Shorthand to run expire snapshots by id or by a timestamp."""
1225-
return ExpireSnapshots(transaction=Transaction(self, autocommit=True))
1226-
12271223
def update_statistics(self) -> UpdateStatistics:
12281224
"""
12291225
Shorthand to run statistics management operations like add statistics and remove statistics.

pyiceberg/table/maintenance.py

Lines changed: 86 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import logging
2020
from datetime import datetime, timedelta, timezone
2121
from functools import reduce
22-
from typing import TYPE_CHECKING, Optional, Set
22+
from typing import TYPE_CHECKING, List, Optional, Set
2323

2424
from pyiceberg.utils.concurrent import ExecutorFactory
2525

@@ -117,3 +117,88 @@ def _delete(file: str) -> None:
117117
logger.warning(f"Files:\n{failed_to_delete_files}")
118118
else:
119119
logger.info(f"No orphaned files found at {location}!")
120+
121+
def expire_snapshot_by_id(self, snapshot_id: int) -> None:
    """Remove one snapshot from the table, identified by its snapshot ID.

    The snapshot is validated inside a transaction opened on the table; the
    removal is only staged once both checks have passed.

    Args:
        snapshot_id: ID of the snapshot to expire.

    Raises:
        ValueError: If no snapshot with that ID exists, or if the snapshot is
            referenced by a branch or tag and is therefore protected.
    """
    with self.tbl.transaction() as txn:
        metadata = txn.table_metadata

        # Unknown IDs are rejected before anything is staged.
        if metadata.snapshot_by_id(snapshot_id) is None:
            raise ValueError(f"Snapshot with ID {snapshot_id} does not exist.")

        # Snapshots referenced by a branch or tag must never be expired.
        if snapshot_id in self._get_protected_snapshot_ids(metadata):
            raise ValueError(f"Snapshot with ID {snapshot_id} is protected and cannot be expired.")

        # Function-scope import kept from the original
        # (presumably to avoid an import cycle -- TODO confirm).
        from pyiceberg.table.update import RemoveSnapshotsUpdate

        removal = RemoveSnapshotsUpdate(snapshot_ids=[snapshot_id])
        txn._apply((removal,))
143+
144+
def expire_snapshots_by_ids(self, snapshot_ids: List[int]) -> None:
    """Expire multiple snapshots by their IDs.

    Every requested ID is validated before any removal is staged, so a single
    invalid ID means nothing is expired. Duplicate IDs are collapsed, and an
    empty request is a no-op that never opens a transaction.

    Args:
        snapshot_ids: List of snapshot IDs to expire.

    Raises:
        ValueError: If any snapshot does not exist or is protected.
    """
    # Collapse duplicates while keeping first-seen order, so each snapshot is
    # validated once and appears once in the staged removal.
    unique_ids = list(dict.fromkeys(snapshot_ids))

    # Nothing to expire: return before opening a transaction, otherwise an
    # update with an empty snapshot list would be staged for no effect.
    if not unique_ids:
        return

    with self.tbl.transaction() as txn:
        metadata = txn.table_metadata
        protected_ids = self._get_protected_snapshot_ids(metadata)

        # Validate all snapshots before expiring any
        for snapshot_id in unique_ids:
            if metadata.snapshot_by_id(snapshot_id) is None:
                raise ValueError(f"Snapshot with ID {snapshot_id} does not exist.")
            if snapshot_id in protected_ids:
                raise ValueError(f"Snapshot with ID {snapshot_id} is protected and cannot be expired.")

        # Function-scope import kept from the original
        # (presumably to avoid an import cycle -- TODO confirm).
        from pyiceberg.table.update import RemoveSnapshotsUpdate

        # Remove all snapshots in a single update
        txn._apply((RemoveSnapshotsUpdate(snapshot_ids=unique_ids),))
166+
167+
def expire_snapshots_older_than(self, timestamp_ms: int) -> None:
    """Expire every unprotected snapshot created before a cutoff timestamp.

    Candidates are selected from the table's current metadata. When nothing
    qualifies, the method returns without opening a transaction.

    Args:
        timestamp_ms: Cutoff; only snapshots whose timestamp_ms is strictly
            less than this value are expired.
    """
    metadata = self.tbl.metadata
    protected = self._get_protected_snapshot_ids(metadata)

    # Old enough AND not referenced by a branch or tag.
    expirable = [
        candidate.snapshot_id
        for candidate in metadata.snapshots
        if candidate.timestamp_ms < timestamp_ms and candidate.snapshot_id not in protected
    ]

    # Avoid an unnecessary transaction when there is nothing to remove.
    if not expirable:
        return

    with self.tbl.transaction() as txn:
        from pyiceberg.table.update import RemoveSnapshotsUpdate

        txn._apply((RemoveSnapshotsUpdate(snapshot_ids=expirable),))
185+
186+
def _get_protected_snapshot_ids(self, table_metadata) -> Set[int]:
    """Collect the snapshot IDs that must be excluded from expiration.

    A snapshot is protected when a ref of type TAG or BRANCH in the given
    metadata points at it, i.e. it is a branch head or a tagged snapshot.

    Args:
        table_metadata: The table metadata whose refs are inspected.

    Returns:
        Set of protected snapshot IDs to exclude from expiration.
    """
    from pyiceberg.table.refs import SnapshotRefType

    protecting_types = (SnapshotRefType.TAG, SnapshotRefType.BRANCH)
    return {
        ref.snapshot_id
        for ref in table_metadata.refs.values()
        if ref.snapshot_ref_type in protecting_types
    }

pyiceberg/table/update/snapshot.py

Lines changed: 0 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -858,103 +858,3 @@ def remove_branch(self, branch_name: str) -> ManageSnapshots:
858858
This for method chaining
859859
"""
860860
return self._remove_ref_snapshot(ref_name=branch_name)
861-
862-
863-
class ExpireSnapshots(UpdateTableMetadata["ExpireSnapshots"]):
864-
"""
865-
Expire snapshots by ID.
866-
867-
Use table.expire_snapshots().<operation>().commit() to run a specific operation.
868-
Use table.expire_snapshots().<operation-one>().<operation-two>().commit() to run multiple operations.
869-
Pending changes are applied on commit.
870-
"""
871-
872-
_snapshot_ids_to_expire: Set[int] = set()
873-
_updates: Tuple[TableUpdate, ...] = ()
874-
_requirements: Tuple[TableRequirement, ...] = ()
875-
876-
def _commit(self) -> UpdatesAndRequirements:
877-
"""
878-
Commit the staged updates and requirements.
879-
880-
This will remove the snapshots with the given IDs, but will always skip protected snapshots (branch/tag heads).
881-
882-
Returns:
883-
Tuple of updates and requirements to be committed,
884-
as required by the calling parent apply functions.
885-
"""
886-
# Remove any protected snapshot IDs from the set to expire, just in case
887-
protected_ids = self._get_protected_snapshot_ids()
888-
self._snapshot_ids_to_expire -= protected_ids
889-
update = RemoveSnapshotsUpdate(snapshot_ids=self._snapshot_ids_to_expire)
890-
self._updates += (update,)
891-
return self._updates, self._requirements
892-
893-
def _get_protected_snapshot_ids(self) -> Set[int]:
894-
"""
895-
Get the IDs of protected snapshots.
896-
897-
These are the HEAD snapshots of all branches and all tagged snapshots. These ids are to be excluded from expiration.
898-
899-
Returns:
900-
Set of protected snapshot IDs to exclude from expiration.
901-
"""
902-
protected_ids: Set[int] = set()
903-
904-
for ref in self._transaction.table_metadata.refs.values():
905-
if ref.snapshot_ref_type in [SnapshotRefType.TAG, SnapshotRefType.BRANCH]:
906-
protected_ids.add(ref.snapshot_id)
907-
908-
return protected_ids
909-
910-
def expire_snapshot_by_id(self, snapshot_id: int) -> ExpireSnapshots:
911-
"""
912-
Expire a snapshot by its ID.
913-
914-
This will mark the snapshot for expiration.
915-
916-
Args:
917-
snapshot_id (int): The ID of the snapshot to expire.
918-
Returns:
919-
This for method chaining.
920-
"""
921-
if self._transaction.table_metadata.snapshot_by_id(snapshot_id) is None:
922-
raise ValueError(f"Snapshot with ID {snapshot_id} does not exist.")
923-
924-
if snapshot_id in self._get_protected_snapshot_ids():
925-
raise ValueError(f"Snapshot with ID {snapshot_id} is protected and cannot be expired.")
926-
927-
self._snapshot_ids_to_expire.add(snapshot_id)
928-
929-
return self
930-
931-
def expire_snapshots_by_ids(self, snapshot_ids: List[int]) -> "ExpireSnapshots":
932-
"""
933-
Expire multiple snapshots by their IDs.
934-
935-
This will mark the snapshots for expiration.
936-
937-
Args:
938-
snapshot_ids (List[int]): List of snapshot IDs to expire.
939-
Returns:
940-
This for method chaining.
941-
"""
942-
for snapshot_id in snapshot_ids:
943-
self.expire_snapshot_by_id(snapshot_id)
944-
return self
945-
946-
def expire_snapshots_older_than(self, timestamp_ms: int) -> "ExpireSnapshots":
947-
"""
948-
Expire all unprotected snapshots with a timestamp older than a given value.
949-
950-
Args:
951-
timestamp_ms (int): Only snapshots with timestamp_ms < this value will be expired.
952-
953-
Returns:
954-
This for method chaining.
955-
"""
956-
protected_ids = self._get_protected_snapshot_ids()
957-
for snapshot in self._transaction.table_metadata.snapshots:
958-
if snapshot.timestamp_ms < timestamp_ms and snapshot.snapshot_id not in protected_ids:
959-
self._snapshot_ids_to_expire.add(snapshot.snapshot_id)
960-
return self

tests/table/test_expire_snapshots.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def test_cannot_expire_protected_head_snapshot(table_v2: Table) -> None:
4343

4444
# Attempt to expire the HEAD snapshot and expect a ValueError
4545
with pytest.raises(ValueError, match=f"Snapshot with ID {HEAD_SNAPSHOT} is protected and cannot be expired."):
46-
table_v2.expire_snapshots().expire_snapshot_by_id(HEAD_SNAPSHOT).commit()
46+
table_v2.maintenance.expire_snapshot_by_id(HEAD_SNAPSHOT)
4747

4848
table_v2.catalog.commit_table.assert_not_called()
4949

@@ -66,7 +66,7 @@ def test_cannot_expire_tagged_snapshot(table_v2: Table) -> None:
6666
assert any(ref.snapshot_id == TAGGED_SNAPSHOT for ref in table_v2.metadata.refs.values())
6767

6868
with pytest.raises(ValueError, match=f"Snapshot with ID {TAGGED_SNAPSHOT} is protected and cannot be expired."):
69-
table_v2.expire_snapshots().expire_snapshot_by_id(TAGGED_SNAPSHOT).commit()
69+
table_v2.maintenance.expire_snapshot_by_id(TAGGED_SNAPSHOT)
7070

7171
table_v2.catalog.commit_table.assert_not_called()
7272

@@ -98,9 +98,11 @@ def test_expire_unprotected_snapshot(table_v2: Table) -> None:
9898
assert all(ref.snapshot_id != EXPIRE_SNAPSHOT for ref in table_v2.metadata.refs.values())
9999

100100
# Expire the snapshot
101-
table_v2.expire_snapshots().expire_snapshot_by_id(EXPIRE_SNAPSHOT).commit()
101+
table_v2.maintenance.expire_snapshot_by_id(EXPIRE_SNAPSHOT)
102102

103103
table_v2.catalog.commit_table.assert_called_once()
104+
# Update metadata to reflect the commit
105+
table_v2.metadata = mock_response.metadata
104106
remaining_snapshots = table_v2.metadata.snapshots
105107
assert EXPIRE_SNAPSHOT not in remaining_snapshots
106108
assert len(table_v2.metadata.snapshots) == 1
@@ -114,7 +116,7 @@ def test_expire_nonexistent_snapshot_raises(table_v2: Table) -> None:
114116
table_v2.metadata = table_v2.metadata.model_copy(update={"refs": {}})
115117

116118
with pytest.raises(ValueError, match=f"Snapshot with ID {NONEXISTENT_SNAPSHOT} does not exist."):
117-
table_v2.expire_snapshots().expire_snapshot_by_id(NONEXISTENT_SNAPSHOT).commit()
119+
table_v2.maintenance.expire_snapshot_by_id(NONEXISTENT_SNAPSHOT)
118120

119121
table_v2.catalog.commit_table.assert_not_called()
120122

@@ -152,7 +154,7 @@ def test_expire_snapshots_by_timestamp_skips_protected(table_v2: Table) -> None:
152154
)
153155
table_v2.catalog.commit_table.return_value = mock_response
154156

155-
table_v2.expire_snapshots().expire_snapshots_older_than(future_timestamp).commit()
157+
table_v2.maintenance.expire_snapshots_older_than(future_timestamp)
156158
# Update metadata to reflect the commit (as in other tests)
157159
table_v2.metadata = mock_response.metadata
158160

@@ -161,13 +163,8 @@ def test_expire_snapshots_by_timestamp_skips_protected(table_v2: Table) -> None:
161163
assert HEAD_SNAPSHOT in remaining_ids
162164
assert TAGGED_SNAPSHOT in remaining_ids
163165

164-
# No snapshots should have been expired (commit_table called, but with empty snapshot_ids)
165-
args, kwargs = table_v2.catalog.commit_table.call_args
166-
updates = args[2] if len(args) > 2 else ()
167-
# Find RemoveSnapshotsUpdate in updates
168-
remove_update = next((u for u in updates if getattr(u, "action", None) == "remove-snapshots"), None)
169-
assert remove_update is not None
170-
assert remove_update.snapshot_ids == []
166+
# Since all snapshots were protected, commit_table should not be called
167+
table_v2.catalog.commit_table.assert_not_called()
171168

172169

173170
def test_expire_snapshots_by_ids(table_v2: Table) -> None:
@@ -215,9 +212,11 @@ def test_expire_snapshots_by_ids(table_v2: Table) -> None:
215212
assert all(ref.snapshot_id not in (EXPIRE_SNAPSHOT_1, EXPIRE_SNAPSHOT_2) for ref in table_v2.metadata.refs.values())
216213

217214
# Expire the snapshots
218-
table_v2.expire_snapshots().expire_snapshots_by_ids([EXPIRE_SNAPSHOT_1, EXPIRE_SNAPSHOT_2]).commit()
215+
table_v2.maintenance.expire_snapshots_by_ids([EXPIRE_SNAPSHOT_1, EXPIRE_SNAPSHOT_2])
219216

220217
table_v2.catalog.commit_table.assert_called_once()
218+
# Update metadata to reflect the commit
219+
table_v2.metadata = mock_response.metadata
221220
remaining_snapshots = table_v2.metadata.snapshots
222221
assert EXPIRE_SNAPSHOT_1 not in remaining_snapshots
223222
assert EXPIRE_SNAPSHOT_2 not in remaining_snapshots

0 commit comments

Comments
 (0)