Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions pyiceberg/table/update/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -924,9 +924,15 @@ class ExpireSnapshots(UpdateTableMetadata["ExpireSnapshots"]):
Pending changes are applied on commit.
"""

_snapshot_ids_to_expire: Set[int] = set()
_updates: Tuple[TableUpdate, ...] = ()
_requirements: Tuple[TableRequirement, ...] = ()
_updates: Tuple[TableUpdate, ...]
_requirements: Tuple[TableRequirement, ...]
_snapshot_ids_to_expire: Set[int]

def __init__(self, transaction: Transaction) -> None:
    """Set up per-instance expiration state on top of the base update.

    Args:
        transaction: Forwarded unchanged to the parent class initializer.
    """
    super().__init__(transaction)
    # A fresh set per instance: a mutable set stored on the class itself
    # would be one object shared by every instance.
    self._snapshot_ids_to_expire = set()
    # Pending requirements and updates both start out empty.
    self._requirements = ()
    self._updates = ()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is the smoking gun, since the set() is mutable, and the tuple() isn't 👍


def _commit(self) -> UpdatesAndRequirements:
"""
Expand Down
63 changes: 60 additions & 3 deletions tests/table/test_expire_snapshots.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,16 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import datetime
from unittest.mock import MagicMock
import threading
from datetime import datetime, timedelta
from typing import Dict
from unittest.mock import MagicMock, Mock
from uuid import uuid4

import pytest

from pyiceberg.table import CommitTableResponse, Table
from pyiceberg.table.update.snapshot import ExpireSnapshots


def test_cannot_expire_protected_head_snapshot(table_v2: Table) -> None:
Expand Down Expand Up @@ -143,7 +146,7 @@ def test_expire_snapshots_by_timestamp_skips_protected(table_v2: Table) -> None:
table_v2.catalog = MagicMock()

# Attempt to expire all snapshots before a future timestamp (so both are candidates)
future_datetime = datetime.datetime.now() + datetime.timedelta(days=1)
future_datetime = datetime.now() + timedelta(days=1)

# Mock the catalog's commit_table to return the current metadata (simulate no change)
mock_response = CommitTableResponse(
Expand Down Expand Up @@ -223,3 +226,57 @@ def test_expire_snapshots_by_ids(table_v2: Table) -> None:
assert EXPIRE_SNAPSHOT_1 not in remaining_snapshots
assert EXPIRE_SNAPSHOT_2 not in remaining_snapshots
assert len(table_v2.metadata.snapshots) == 1


def test_thread_safety_fix() -> None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test fails on the old code 👍

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're saying this is the one, good test, right? :)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, indeed :)

"""Test that ExpireSnapshots instances have isolated state."""
# Build two independent ExpireSnapshots instances, each wrapping its own mock transaction
expire1 = ExpireSnapshots(Mock())
expire2 = ExpireSnapshots(Mock())

# Each instance must own a distinct set object (this was the bug!)
# Before the fix: the set was a class-level attribute, so both instances saw the same object (same id)
# After the fix: __init__ assigns a new set per instance, so the ids must differ
assert id(expire1._snapshot_ids_to_expire) != id(expire2._snapshot_ids_to_expire), (
"ExpireSnapshots instances are sharing the same snapshot set - thread safety bug still exists"
)

# Mutate each instance's set with a value unique to that instance
expire1._snapshot_ids_to_expire.add(1001)
expire2._snapshot_ids_to_expire.add(2001)

# Neither instance may contain the ID that was added to the other one
assert 2001 not in expire1._snapshot_ids_to_expire, "Snapshot IDs are leaking between instances"
assert 1001 not in expire2._snapshot_ids_to_expire, "Snapshot IDs are leaking between instances"


def test_concurrent_operations() -> None:
    """Exercise two ExpireSnapshots instances from separate threads.

    Each thread builds its own instance, records a distinct batch of snapshot
    IDs on it, and publishes a copy of the resulting set. The test passes only
    if each published set holds exactly the IDs its own thread added.
    """
    batch_one = [1001, 1002, 1003]
    batch_two = [2001, 2002, 2003]
    results: Dict[str, set[int]] = {"expire1_snapshots": set(), "expire2_snapshots": set()}

    def record(result_key: str, snapshot_ids: list[int]) -> None:
        # One instance per thread; only a snapshot of its set leaves the thread.
        expirer = ExpireSnapshots(Mock())
        expirer._snapshot_ids_to_expire.update(snapshot_ids)
        results[result_key] = set(expirer._snapshot_ids_to_expire)

    workers = [
        threading.Thread(target=record, args=("expire1_snapshots", batch_one)),
        threading.Thread(target=record, args=("expire2_snapshots", batch_two)),
    ]

    # Start every thread before joining any, so the two workers overlap.
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()

    # Each published set must match its own batch exactly (no cross-contamination).
    assert results["expire1_snapshots"] == set(batch_one), "Worker 1 snapshots contaminated"
    assert results["expire2_snapshots"] == set(batch_two), "Worker 2 snapshots contaminated"