From ba18a0200c441faaad72e867f83f66c759267608 Mon Sep 17 00:00:00 2001 From: geruh Date: Mon, 29 Dec 2025 01:22:35 -0800 Subject: [PATCH] feat: Add support set current snapshot Co-authored-by: Chinmay Bhat <12948588+chinmay-bhat@users.noreply.github.com> --- pyiceberg/table/__init__.py | 9 +- pyiceberg/table/update/snapshot.py | 45 +++++ tests/integration/test_snapshot_operations.py | 88 +++++++++ tests/table/test_manage_snapshots.py | 179 ++++++++++++++++++ 4 files changed, 319 insertions(+), 2 deletions(-) create mode 100644 tests/table/test_manage_snapshots.py diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 2e26a4ccc2..8c249a362f 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -275,7 +275,12 @@ def __exit__(self, exctype: type[BaseException] | None, excinst: BaseException | if exctype is None and excinst is None and exctb is None: self.commit_transaction() - def _apply(self, updates: tuple[TableUpdate, ...], requirements: tuple[TableRequirement, ...] = ()) -> Transaction: + def _apply( + self, + updates: tuple[TableUpdate, ...], + requirements: tuple[TableRequirement, ...] = (), + commit_transaction_if_autocommit: bool = True, + ) -> Transaction: """Check if the requirements are met, and applies the updates to the metadata.""" for requirement in requirements: requirement.validate(self.table_metadata) @@ -289,7 +294,7 @@ def _apply(self, updates: tuple[TableUpdate, ...], requirements: tuple[TableRequ if type(new_requirement) not in existing_requirements: self._requirements = self._requirements + (new_requirement,) - if self._autocommit: + if self._autocommit and commit_transaction_if_autocommit: self.commit_transaction() return self diff --git a/pyiceberg/table/update/snapshot.py b/pyiceberg/table/update/snapshot.py index e89cd45d34..6e468285d2 100644 --- a/pyiceberg/table/update/snapshot.py +++ b/pyiceberg/table/update/snapshot.py @@ -843,6 +843,13 @@ def _commit(self) -> UpdatesAndRequirements: """Apply the pending changes and commit.""" return self._updates, self._requirements + def _commit_if_ref_updates_exist(self) -> None: + """Commit any pending ref updates to the transaction.""" + if self._updates: + self._transaction._apply(*self._commit(), commit_transaction_if_autocommit=False) + self._updates = () + self._requirements = () + def _remove_ref_snapshot(self, ref_name: str) -> ManageSnapshots: """Remove a snapshot ref. @@ -941,6 +948,44 @@ def remove_branch(self, branch_name: str) -> ManageSnapshots: """ return self._remove_ref_snapshot(ref_name=branch_name) + def set_current_snapshot(self, snapshot_id: int | None = None, ref_name: str | None = None) -> ManageSnapshots: + """Set the current snapshot to a specific snapshot ID or ref. + + Args: + snapshot_id: The ID of the snapshot to set as current. + ref_name: The snapshot reference (branch or tag) to set as current. + + Returns: + This for method chaining. + + Raises: + ValueError: If neither or both arguments are provided, or if the snapshot/ref does not exist. + """ + self._commit_if_ref_updates_exist() + + if (snapshot_id is None) == (ref_name is None): + raise ValueError("Either snapshot_id or ref_name must be provided, not both") + + target_snapshot_id: int + if snapshot_id is not None: + target_snapshot_id = snapshot_id + else: + if ref_name not in self._transaction.table_metadata.refs: + raise ValueError(f"Cannot find matching snapshot ID for ref: {ref_name}") + target_snapshot_id = self._transaction.table_metadata.refs[ref_name].snapshot_id + + if self._transaction.table_metadata.snapshot_by_id(target_snapshot_id) is None: + raise ValueError(f"Cannot set current snapshot to unknown snapshot id: {target_snapshot_id}") + + update, requirement = self._transaction._set_ref_snapshot( + snapshot_id=target_snapshot_id, + ref_name=MAIN_BRANCH, + type="branch", + ) + self._updates += update + self._requirements += requirement + return self + class ExpireSnapshots(UpdateTableMetadata["ExpireSnapshots"]): """Expire snapshots by ID. diff --git a/tests/integration/test_snapshot_operations.py b/tests/integration/test_snapshot_operations.py index 1b7f2d3a29..2f0447ec52 100644 --- a/tests/integration/test_snapshot_operations.py +++ b/tests/integration/test_snapshot_operations.py @@ -72,3 +72,91 @@ def test_remove_branch(catalog: Catalog) -> None: # now, remove the branch tbl.manage_snapshots().remove_branch(branch_name=branch_name).commit() assert tbl.metadata.refs.get(branch_name, None) is None + + +@pytest.mark.integration +@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")]) +def test_set_current_snapshot(catalog: Catalog) -> None: + identifier = "default.test_table_snapshot_operations" + tbl = catalog.load_table(identifier) + assert len(tbl.history()) > 2 + + # first get the current snapshot and an older one + current_snapshot_id = tbl.history()[-1].snapshot_id + older_snapshot_id = tbl.history()[-2].snapshot_id + + # set the current snapshot to the older one + tbl.manage_snapshots().set_current_snapshot(snapshot_id=older_snapshot_id).commit() + + tbl = catalog.load_table(identifier) + updated_snapshot = tbl.current_snapshot() + assert updated_snapshot and updated_snapshot.snapshot_id == older_snapshot_id + + # restore table + tbl.manage_snapshots().set_current_snapshot(snapshot_id=current_snapshot_id).commit() + tbl = catalog.load_table(identifier) + restored_snapshot = tbl.current_snapshot() + assert restored_snapshot and restored_snapshot.snapshot_id == current_snapshot_id + + +@pytest.mark.integration +@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")]) +def test_set_current_snapshot_by_ref(catalog: Catalog) -> None: + identifier = "default.test_table_snapshot_operations" + tbl = catalog.load_table(identifier) + assert len(tbl.history()) > 2 + + # first get the current snapshot and an older one + current_snapshot_id = tbl.history()[-1].snapshot_id + older_snapshot_id = tbl.history()[-2].snapshot_id + assert older_snapshot_id != current_snapshot_id + + # create a tag pointing to the older snapshot + tag_name = "my-tag" + tbl.manage_snapshots().create_tag(snapshot_id=older_snapshot_id, tag_name=tag_name).commit() + + # set current snapshot using the tag name + tbl = catalog.load_table(identifier) + tbl.manage_snapshots().set_current_snapshot(ref_name=tag_name).commit() + + tbl = catalog.load_table(identifier) + updated_snapshot = tbl.current_snapshot() + assert updated_snapshot and updated_snapshot.snapshot_id == older_snapshot_id + + # restore table + tbl.manage_snapshots().set_current_snapshot(snapshot_id=current_snapshot_id).commit() + tbl = catalog.load_table(identifier) + tbl.manage_snapshots().remove_tag(tag_name=tag_name).commit() + assert tbl.metadata.refs.get(tag_name, None) is None + + +@pytest.mark.integration +@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")]) +def test_set_current_snapshot_chained_with_create_tag(catalog: Catalog) -> None: + identifier = "default.test_table_snapshot_operations" + tbl = catalog.load_table(identifier) + assert len(tbl.history()) > 2 + + current_snapshot_id = tbl.history()[-1].snapshot_id + older_snapshot_id = tbl.history()[-2].snapshot_id + assert older_snapshot_id != current_snapshot_id + + # create a tag and use it to set current snapshot + tag_name = "my-tag" + ( + tbl.manage_snapshots() + .create_tag(snapshot_id=older_snapshot_id, tag_name=tag_name) + .set_current_snapshot(ref_name=tag_name) + .commit() + ) + + tbl = catalog.load_table(identifier) + updated_snapshot = tbl.current_snapshot() + assert updated_snapshot + assert updated_snapshot.snapshot_id == older_snapshot_id + + # restore table + tbl.manage_snapshots().set_current_snapshot(snapshot_id=current_snapshot_id).commit() + tbl = catalog.load_table(identifier) + tbl.manage_snapshots().remove_tag(tag_name=tag_name).commit() + assert tbl.metadata.refs.get(tag_name, None) is None diff --git a/tests/table/test_manage_snapshots.py b/tests/table/test_manage_snapshots.py new file mode 100644 index 0000000000..93301a01c7 --- /dev/null +++ b/tests/table/test_manage_snapshots.py @@ -0,0 +1,179 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from unittest.mock import MagicMock +from uuid import uuid4 + +import pytest + +from pyiceberg.table import CommitTableResponse, Table +from pyiceberg.table.update import SetSnapshotRefUpdate, TableUpdate + + +def _mock_commit_response(table: Table) -> CommitTableResponse: + return CommitTableResponse( + metadata=table.metadata, + metadata_location="s3://bucket/tbl", + uuid=uuid4(), + ) + + +def _get_updates(mock_catalog: MagicMock) -> tuple[TableUpdate, ...]: + args, _ = mock_catalog.commit_table.call_args + return args[2] + + +def test_set_current_snapshot_basic(table_v2: Table) -> None: + snapshot_one = 3051729675574597004 + + table_v2.catalog = MagicMock() + table_v2.catalog.commit_table.return_value = _mock_commit_response(table_v2) + + table_v2.manage_snapshots().set_current_snapshot(snapshot_id=snapshot_one).commit() + + table_v2.catalog.commit_table.assert_called_once() + + updates = _get_updates(table_v2.catalog) + set_ref_updates = [u for u in updates if isinstance(u, SetSnapshotRefUpdate)] + + assert len(set_ref_updates) == 1 + update = set_ref_updates[0] + assert update.snapshot_id == snapshot_one + assert update.ref_name == "main" + assert update.type == "branch" + + +def test_set_current_snapshot_unknown_id(table_v2: Table) -> None: + invalid_snapshot_id = 1234567890000 + table_v2.catalog = MagicMock() + + with pytest.raises(ValueError, match="Cannot set current snapshot to unknown snapshot id"): + table_v2.manage_snapshots().set_current_snapshot(snapshot_id=invalid_snapshot_id).commit() + + table_v2.catalog.commit_table.assert_not_called() + + +def test_set_current_snapshot_to_current(table_v2: Table) -> None: + current_snapshot = table_v2.current_snapshot() + assert current_snapshot is not None + + table_v2.catalog = MagicMock() + table_v2.catalog.commit_table.return_value = _mock_commit_response(table_v2) + + table_v2.manage_snapshots().set_current_snapshot(snapshot_id=current_snapshot.snapshot_id).commit() + + table_v2.catalog.commit_table.assert_called_once() + + +def test_set_current_snapshot_chained_with_tag(table_v2: Table) -> None: + snapshot_one = 3051729675574597004 + table_v2.catalog = MagicMock() + table_v2.catalog.commit_table.return_value = _mock_commit_response(table_v2) + + (table_v2.manage_snapshots().set_current_snapshot(snapshot_id=snapshot_one).create_tag(snapshot_one, "my-tag").commit()) + + table_v2.catalog.commit_table.assert_called_once() + + updates = _get_updates(table_v2.catalog) + set_ref_updates = [u for u in updates if isinstance(u, SetSnapshotRefUpdate)] + + assert len(set_ref_updates) == 2 + assert {u.ref_name for u in set_ref_updates} == {"main", "my-tag"} + + +def test_set_current_snapshot_with_extensive_snapshots(table_v2_with_extensive_snapshots: Table) -> None: + snapshots = table_v2_with_extensive_snapshots.metadata.snapshots + assert len(snapshots) > 100 + + target_snapshot = snapshots[50].snapshot_id + + table_v2_with_extensive_snapshots.catalog = MagicMock() + table_v2_with_extensive_snapshots.catalog.commit_table.return_value = _mock_commit_response(table_v2_with_extensive_snapshots) + + table_v2_with_extensive_snapshots.manage_snapshots().set_current_snapshot(snapshot_id=target_snapshot).commit() + + table_v2_with_extensive_snapshots.catalog.commit_table.assert_called_once() + + updates = _get_updates(table_v2_with_extensive_snapshots.catalog) + set_ref_updates = [u for u in updates if isinstance(u, SetSnapshotRefUpdate)] + + assert len(set_ref_updates) == 1 + assert set_ref_updates[0].snapshot_id == target_snapshot + + +def test_set_current_snapshot_by_ref_name(table_v2: Table) -> None: + current_snapshot = table_v2.current_snapshot() + assert current_snapshot is not None + + table_v2.catalog = MagicMock() + table_v2.catalog.commit_table.return_value = _mock_commit_response(table_v2) + + table_v2.manage_snapshots().set_current_snapshot(ref_name="main").commit() + + updates = _get_updates(table_v2.catalog) + set_ref_updates = [u for u in updates if isinstance(u, SetSnapshotRefUpdate)] + + assert len(set_ref_updates) == 1 + assert set_ref_updates[0].snapshot_id == current_snapshot.snapshot_id + assert set_ref_updates[0].ref_name == "main" + + +def test_set_current_snapshot_unknown_ref(table_v2: Table) -> None: + table_v2.catalog = MagicMock() + + with pytest.raises(ValueError, match="Cannot find matching snapshot ID for ref: nonexistent"): + table_v2.manage_snapshots().set_current_snapshot(ref_name="nonexistent").commit() + + table_v2.catalog.commit_table.assert_not_called() + + +def test_set_current_snapshot_requires_one_argument(table_v2: Table) -> None: + table_v2.catalog = MagicMock() + + with pytest.raises(ValueError, match="Either snapshot_id or ref_name must be provided, not both"): + table_v2.manage_snapshots().set_current_snapshot().commit() + + with pytest.raises(ValueError, match="Either snapshot_id or ref_name must be provided, not both"): + table_v2.manage_snapshots().set_current_snapshot(snapshot_id=123, ref_name="main").commit() + + table_v2.catalog.commit_table.assert_not_called() + + +def test_set_current_snapshot_chained_with_create_tag(table_v2: Table) -> None: + snapshot_one = 3051729675574597004 + table_v2.catalog = MagicMock() + table_v2.catalog.commit_table.return_value = _mock_commit_response(table_v2) + + # create a tag and immediately use it to set current snapshot + ( + table_v2.manage_snapshots() + .create_tag(snapshot_id=snapshot_one, tag_name="new-tag") + .set_current_snapshot(ref_name="new-tag") + .commit() + ) + + table_v2.catalog.commit_table.assert_called_once() + + updates = _get_updates(table_v2.catalog) + set_ref_updates = [u for u in updates if isinstance(u, SetSnapshotRefUpdate)] + + # should have the tag and the main branch update + assert len(set_ref_updates) == 2 + assert {u.ref_name for u in set_ref_updates} == {"new-tag", "main"} + + # The main branch should point to the same snapshot as the tag + main_update = next(u for u in set_ref_updates if u.ref_name == "main") + assert main_update.snapshot_id == snapshot_one