diff --git a/pyiceberg/table/update/snapshot.py b/pyiceberg/table/update/snapshot.py index 42d7a9c2b7..d19f54e2d8 100644 --- a/pyiceberg/table/update/snapshot.py +++ b/pyiceberg/table/update/snapshot.py @@ -924,9 +924,15 @@ class ExpireSnapshots(UpdateTableMetadata["ExpireSnapshots"]): Pending changes are applied on commit. """ - _snapshot_ids_to_expire: Set[int] = set() - _updates: Tuple[TableUpdate, ...] = () - _requirements: Tuple[TableRequirement, ...] = () + _updates: Tuple[TableUpdate, ...] + _requirements: Tuple[TableRequirement, ...] + _snapshot_ids_to_expire: Set[int] + + def __init__(self, transaction: Transaction) -> None: + super().__init__(transaction) + self._updates = () + self._requirements = () + self._snapshot_ids_to_expire = set() def _commit(self) -> UpdatesAndRequirements: """ diff --git a/tests/table/test_expire_snapshots.py b/tests/table/test_expire_snapshots.py index e2b2d47b67..51f5ba687a 100644 --- a/tests/table/test_expire_snapshots.py +++ b/tests/table/test_expire_snapshots.py @@ -14,13 +14,16 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import datetime -from unittest.mock import MagicMock +import threading +from datetime import datetime, timedelta +from typing import Dict +from unittest.mock import MagicMock, Mock from uuid import uuid4 import pytest from pyiceberg.table import CommitTableResponse, Table +from pyiceberg.table.update.snapshot import ExpireSnapshots def test_cannot_expire_protected_head_snapshot(table_v2: Table) -> None: @@ -143,7 +146,7 @@ def test_expire_snapshots_by_timestamp_skips_protected(table_v2: Table) -> None: table_v2.catalog = MagicMock() # Attempt to expire all snapshots before a future timestamp (so both are candidates) - future_datetime = datetime.datetime.now() + datetime.timedelta(days=1) + future_datetime = datetime.now() + timedelta(days=1) # Mock the catalog's commit_table to return the current metadata (simulate no change) mock_response = CommitTableResponse( @@ -223,3 +226,57 @@ def test_expire_snapshots_by_ids(table_v2: Table) -> None: assert EXPIRE_SNAPSHOT_1 not in remaining_snapshots assert EXPIRE_SNAPSHOT_2 not in remaining_snapshots assert len(table_v2.metadata.snapshots) == 1 + + +def test_thread_safety_fix() -> None: + """Test that ExpireSnapshots instances have isolated state.""" + # Create two ExpireSnapshots instances + expire1 = ExpireSnapshots(Mock()) + expire2 = ExpireSnapshots(Mock()) + + # Verify they have separate snapshot sets (this was the bug!) + # Before fix: both would have the same id (shared class attribute) + # After fix: they should have different ids (separate instance attributes) + assert id(expire1._snapshot_ids_to_expire) != id(expire2._snapshot_ids_to_expire), ( + "ExpireSnapshots instances are sharing the same snapshot set - thread safety bug still exists" + ) + + # Test that modifications to one don't affect the other + expire1._snapshot_ids_to_expire.add(1001) + expire2._snapshot_ids_to_expire.add(2001) + + # Verify no cross-contamination of snapshot IDs + assert 2001 not in expire1._snapshot_ids_to_expire, "Snapshot IDs are leaking between instances" + assert 1001 not in expire2._snapshot_ids_to_expire, "Snapshot IDs are leaking between instances" + + +def test_concurrent_operations() -> None: + """Test concurrent operations with separate ExpireSnapshots instances.""" + results: Dict[str, set[int]] = {"expire1_snapshots": set(), "expire2_snapshots": set()} + + def worker1() -> None: + expire1 = ExpireSnapshots(Mock()) + expire1._snapshot_ids_to_expire.update([1001, 1002, 1003]) + results["expire1_snapshots"] = expire1._snapshot_ids_to_expire.copy() + + def worker2() -> None: + expire2 = ExpireSnapshots(Mock()) + expire2._snapshot_ids_to_expire.update([2001, 2002, 2003]) + results["expire2_snapshots"] = expire2._snapshot_ids_to_expire.copy() + + # Run both workers concurrently + thread1 = threading.Thread(target=worker1) + thread2 = threading.Thread(target=worker2) + + thread1.start() + thread2.start() + + thread1.join() + thread2.join() + + # Check for cross-contamination + expected_1 = {1001, 1002, 1003} + expected_2 = {2001, 2002, 2003} + + assert results["expire1_snapshots"] == expected_1, "Worker 1 snapshots contaminated" + assert results["expire2_snapshots"] == expected_2, "Worker 2 snapshots contaminated"