Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions pyiceberg/table/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,12 @@ def __exit__(self, exctype: type[BaseException] | None, excinst: BaseException |
if exctype is None and excinst is None and exctb is None:
self.commit_transaction()

def _apply(self, updates: tuple[TableUpdate, ...], requirements: tuple[TableRequirement, ...] = ()) -> Transaction:
def _apply(
self,
updates: tuple[TableUpdate, ...],
requirements: tuple[TableRequirement, ...] = (),
commit_transaction_if_autocommit: bool = True,
) -> Transaction:
"""Check if the requirements are met, and applies the updates to the metadata."""
for requirement in requirements:
requirement.validate(self.table_metadata)
Expand All @@ -289,7 +294,7 @@ def _apply(self, updates: tuple[TableUpdate, ...], requirements: tuple[TableRequ
if type(new_requirement) not in existing_requirements:
self._requirements = self._requirements + (new_requirement,)

if self._autocommit:
if self._autocommit and commit_transaction_if_autocommit:
self.commit_transaction()

return self
Expand Down
45 changes: 45 additions & 0 deletions pyiceberg/table/update/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -843,6 +843,13 @@ def _commit(self) -> UpdatesAndRequirements:
"""Apply the pending changes and commit."""
return self._updates, self._requirements

def _commit_if_ref_updates_exist(self) -> None:
    """Flush buffered ref updates into the transaction and reset the local buffers.

    When ``self._updates`` is non-empty, the pair returned by ``self._commit()`` is
    passed to the transaction's ``_apply`` with ``commit_transaction_if_autocommit=False``,
    and afterwards ``self._updates`` and ``self._requirements`` are reset to empty
    tuples. When there are no buffered updates, nothing happens.
    """
    if not self._updates:
        return

    # Stage only: the flag tells _apply not to commit an autocommit transaction here.
    self._transaction._apply(*self._commit(), commit_transaction_if_autocommit=False)
    self._updates, self._requirements = (), ()

def _remove_ref_snapshot(self, ref_name: str) -> ManageSnapshots:
"""Remove a snapshot ref.

Expand Down Expand Up @@ -941,6 +948,44 @@ def remove_branch(self, branch_name: str) -> ManageSnapshots:
"""
return self._remove_ref_snapshot(ref_name=branch_name)

def set_current_snapshot(self, snapshot_id: int | None = None, ref_name: str | None = None) -> ManageSnapshots:
    """Set the current snapshot to a specific snapshot ID or ref.

    Exactly one of ``snapshot_id`` and ``ref_name`` must be given. The arguments are
    checked before anything else happens, so a call with invalid arguments leaves
    both this builder and the transaction untouched.

    Args:
        snapshot_id: The ID of the snapshot to set as current.
        ref_name: The snapshot reference (branch or tag) to set as current.

    Returns:
        This for method chaining.

    Raises:
        ValueError: If neither or both arguments are provided, or if the snapshot/ref does not exist.
    """
    # Validate the call shape first: raising after the flush below would leave
    # pending ref updates already applied to the transaction.
    if (snapshot_id is None) == (ref_name is None):
        raise ValueError("Either snapshot_id or ref_name must be provided, not both")

    # Push ref updates queued earlier in the chain (e.g. a create_tag) into the
    # transaction so the ref/snapshot lookups below read metadata that includes them.
    self._commit_if_ref_updates_exist()

    metadata = self._transaction.table_metadata

    target_snapshot_id: int
    if snapshot_id is not None:
        target_snapshot_id = snapshot_id
    else:
        if ref_name not in metadata.refs:
            raise ValueError(f"Cannot find matching snapshot ID for ref: {ref_name}")
        target_snapshot_id = metadata.refs[ref_name].snapshot_id

    if metadata.snapshot_by_id(target_snapshot_id) is None:
        raise ValueError(f"Cannot set current snapshot to unknown snapshot id: {target_snapshot_id}")

    # NOTE(review): the PR author suggested ripping the _set_ref_snapshot helper out of
    # Transaction and leaving it in the ManageSnapshots API.
    update, requirement = self._transaction._set_ref_snapshot(
        snapshot_id=target_snapshot_id,
        ref_name=MAIN_BRANCH,
        type="branch",
    )
    self._updates += update
    self._requirements += requirement
    return self


class ExpireSnapshots(UpdateTableMetadata["ExpireSnapshots"]):
"""Expire snapshots by ID.
Expand Down
88 changes: 88 additions & 0 deletions tests/integration/test_snapshot_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,91 @@ def test_remove_branch(catalog: Catalog) -> None:
# now, remove the branch
tbl.manage_snapshots().remove_branch(branch_name=branch_name).commit()
assert tbl.metadata.refs.get(branch_name, None) is None


@pytest.mark.integration
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
def test_set_current_snapshot(catalog: Catalog) -> None:
    """Point the table at an older snapshot by ID, verify it, then restore the original."""
    identifier = "default.test_table_snapshot_operations"
    tbl = catalog.load_table(identifier)
    assert len(tbl.history()) > 2

    # first get the current snapshot and an older one
    current_snapshot_id = tbl.history()[-1].snapshot_id
    older_snapshot_id = tbl.history()[-2].snapshot_id
    # Without this the test could pass vacuously by "moving" to the same snapshot.
    assert older_snapshot_id != current_snapshot_id

    # set the current snapshot to the older one
    tbl.manage_snapshots().set_current_snapshot(snapshot_id=older_snapshot_id).commit()

    try:
        tbl = catalog.load_table(identifier)
        updated_snapshot = tbl.current_snapshot()
        assert updated_snapshot and updated_snapshot.snapshot_id == older_snapshot_id
    finally:
        # restore table even when the check above fails, so other tests that
        # load the same table do not start from the older snapshot
        tbl = catalog.load_table(identifier)
        tbl.manage_snapshots().set_current_snapshot(snapshot_id=current_snapshot_id).commit()

    tbl = catalog.load_table(identifier)
    restored_snapshot = tbl.current_snapshot()
    assert restored_snapshot and restored_snapshot.snapshot_id == current_snapshot_id


@pytest.mark.integration
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
def test_set_current_snapshot_by_ref(catalog: Catalog) -> None:
    """Point the table at an older snapshot through a tag name, then restore and clean up."""
    identifier = "default.test_table_snapshot_operations"
    tbl = catalog.load_table(identifier)
    assert len(tbl.history()) > 2

    # first get the current snapshot and an older one
    current_snapshot_id = tbl.history()[-1].snapshot_id
    older_snapshot_id = tbl.history()[-2].snapshot_id
    assert older_snapshot_id != current_snapshot_id

    # create a tag pointing to the older snapshot
    tag_name = "my-tag"
    tbl.manage_snapshots().create_tag(snapshot_id=older_snapshot_id, tag_name=tag_name).commit()

    try:
        # set current snapshot using the tag name
        tbl = catalog.load_table(identifier)
        tbl.manage_snapshots().set_current_snapshot(ref_name=tag_name).commit()

        tbl = catalog.load_table(identifier)
        updated_snapshot = tbl.current_snapshot()
        assert updated_snapshot and updated_snapshot.snapshot_id == older_snapshot_id
    finally:
        # restore table and drop the tag even on failure; other tests reuse this
        # table and tag name, so a leftover tag would make them fail too
        tbl = catalog.load_table(identifier)
        tbl.manage_snapshots().set_current_snapshot(snapshot_id=current_snapshot_id).commit()
        tbl = catalog.load_table(identifier)
        tbl.manage_snapshots().remove_tag(tag_name=tag_name).commit()

    assert tbl.metadata.refs.get(tag_name, None) is None


@pytest.mark.integration
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
def test_set_current_snapshot_chained_with_create_tag(catalog: Catalog) -> None:
    """Create a tag and set the current snapshot from it in one chained commit."""
    identifier = "default.test_table_snapshot_operations"
    tbl = catalog.load_table(identifier)
    assert len(tbl.history()) > 2

    current_snapshot_id = tbl.history()[-1].snapshot_id
    older_snapshot_id = tbl.history()[-2].snapshot_id
    assert older_snapshot_id != current_snapshot_id

    # create a tag and use it to set current snapshot
    tag_name = "my-tag"
    (
        tbl.manage_snapshots()
        .create_tag(snapshot_id=older_snapshot_id, tag_name=tag_name)
        .set_current_snapshot(ref_name=tag_name)
        .commit()
    )

    try:
        tbl = catalog.load_table(identifier)
        updated_snapshot = tbl.current_snapshot()
        assert updated_snapshot
        assert updated_snapshot.snapshot_id == older_snapshot_id
    finally:
        # restore table and drop the tag even on failure; other tests reuse this
        # table and tag name, so a leftover tag would make them fail too
        tbl = catalog.load_table(identifier)
        tbl.manage_snapshots().set_current_snapshot(snapshot_id=current_snapshot_id).commit()
        tbl = catalog.load_table(identifier)
        tbl.manage_snapshots().remove_tag(tag_name=tag_name).commit()

    assert tbl.metadata.refs.get(tag_name, None) is None
179 changes: 179 additions & 0 deletions tests/table/test_manage_snapshots.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from unittest.mock import MagicMock
from uuid import uuid4

import pytest

from pyiceberg.table import CommitTableResponse, Table
from pyiceberg.table.update import SetSnapshotRefUpdate, TableUpdate


def _mock_commit_response(table: Table) -> CommitTableResponse:
    """Build a canned CommitTableResponse that echoes the table's current metadata."""
    response_fields = {
        "metadata": table.metadata,
        "metadata_location": "s3://bucket/tbl",
        "uuid": uuid4(),
    }
    return CommitTableResponse(**response_fields)


def _get_updates(mock_catalog: MagicMock) -> tuple[TableUpdate, ...]:
    """Return the third positional argument of the most recent ``commit_table`` call.

    The tests in this module treat that argument as the tuple of table updates.
    """
    positional_args, _keyword_args = mock_catalog.commit_table.call_args
    return positional_args[2]


def test_set_current_snapshot_basic(table_v2: Table) -> None:
    """Setting the current snapshot by ID commits one SetSnapshotRefUpdate on the main branch."""
    target_id = 3051729675574597004

    catalog = MagicMock()
    catalog.commit_table.return_value = _mock_commit_response(table_v2)
    table_v2.catalog = catalog

    table_v2.manage_snapshots().set_current_snapshot(snapshot_id=target_id).commit()

    catalog.commit_table.assert_called_once()

    ref_updates = [upd for upd in _get_updates(catalog) if isinstance(upd, SetSnapshotRefUpdate)]
    assert len(ref_updates) == 1

    main_update = ref_updates[0]
    assert main_update.snapshot_id == target_id
    assert main_update.ref_name == "main"
    assert main_update.type == "branch"


def test_set_current_snapshot_unknown_id(table_v2: Table) -> None:
    """An unknown snapshot ID is rejected with ValueError and nothing is committed."""
    catalog = MagicMock()
    table_v2.catalog = catalog
    missing_id = 1234567890000

    with pytest.raises(ValueError, match="Cannot set current snapshot to unknown snapshot id"):
        table_v2.manage_snapshots().set_current_snapshot(snapshot_id=missing_id).commit()

    catalog.commit_table.assert_not_called()


def test_set_current_snapshot_to_current(table_v2: Table) -> None:
    """Re-selecting the snapshot that is already current still results in exactly one commit."""
    existing = table_v2.current_snapshot()
    assert existing is not None

    catalog = MagicMock()
    catalog.commit_table.return_value = _mock_commit_response(table_v2)
    table_v2.catalog = catalog

    builder = table_v2.manage_snapshots().set_current_snapshot(snapshot_id=existing.snapshot_id)
    builder.commit()

    catalog.commit_table.assert_called_once()


def test_set_current_snapshot_chained_with_tag(table_v2: Table) -> None:
    """Chaining create_tag after set_current_snapshot commits once with both ref updates."""
    target_id = 3051729675574597004

    catalog = MagicMock()
    catalog.commit_table.return_value = _mock_commit_response(table_v2)
    table_v2.catalog = catalog

    builder = table_v2.manage_snapshots().set_current_snapshot(snapshot_id=target_id)
    builder.create_tag(target_id, "my-tag").commit()

    catalog.commit_table.assert_called_once()

    ref_updates = [upd for upd in _get_updates(catalog) if isinstance(upd, SetSnapshotRefUpdate)]
    touched_refs = {upd.ref_name for upd in ref_updates}

    assert len(ref_updates) == 2
    assert touched_refs == {"main", "my-tag"}


def test_set_current_snapshot_with_extensive_snapshots(table_v2_with_extensive_snapshots: Table) -> None:
    """Selecting a snapshot from a table with more than 100 snapshots yields one ref update."""
    table = table_v2_with_extensive_snapshots
    all_snapshots = table.metadata.snapshots
    assert len(all_snapshots) > 100

    target_id = all_snapshots[50].snapshot_id

    catalog = MagicMock()
    catalog.commit_table.return_value = _mock_commit_response(table)
    table.catalog = catalog

    table.manage_snapshots().set_current_snapshot(snapshot_id=target_id).commit()

    catalog.commit_table.assert_called_once()

    ref_updates = [upd for upd in _get_updates(catalog) if isinstance(upd, SetSnapshotRefUpdate)]

    assert len(ref_updates) == 1
    assert ref_updates[0].snapshot_id == target_id


def test_set_current_snapshot_by_ref_name(table_v2: Table) -> None:
    """Resolving the target through ref_name="main" commits one update with the current snapshot ID."""
    current_snapshot = table_v2.current_snapshot()
    assert current_snapshot is not None

    table_v2.catalog = MagicMock()
    table_v2.catalog.commit_table.return_value = _mock_commit_response(table_v2)

    table_v2.manage_snapshots().set_current_snapshot(ref_name="main").commit()

    # Same single-commit check the other success-path tests in this module make;
    # it also gives a clear failure instead of an unpacking error in _get_updates.
    table_v2.catalog.commit_table.assert_called_once()

    updates = _get_updates(table_v2.catalog)
    set_ref_updates = [u for u in updates if isinstance(u, SetSnapshotRefUpdate)]

    assert len(set_ref_updates) == 1
    assert set_ref_updates[0].snapshot_id == current_snapshot.snapshot_id
    assert set_ref_updates[0].ref_name == "main"


def test_set_current_snapshot_unknown_ref(table_v2: Table) -> None:
    """A ref name that is not in the table metadata is rejected and nothing is committed."""
    catalog = MagicMock()
    table_v2.catalog = catalog

    with pytest.raises(ValueError, match="Cannot find matching snapshot ID for ref: nonexistent"):
        builder = table_v2.manage_snapshots()
        builder.set_current_snapshot(ref_name="nonexistent").commit()

    catalog.commit_table.assert_not_called()


def test_set_current_snapshot_requires_one_argument(table_v2: Table) -> None:
    """Passing neither or both of snapshot_id / ref_name raises ValueError without committing."""
    catalog = MagicMock()
    table_v2.catalog = catalog

    # First the "neither" case, then the "both" case.
    invalid_calls = ({}, {"snapshot_id": 123, "ref_name": "main"})
    for kwargs in invalid_calls:
        with pytest.raises(ValueError, match="Either snapshot_id or ref_name must be provided, not both"):
            table_v2.manage_snapshots().set_current_snapshot(**kwargs).commit()

    catalog.commit_table.assert_not_called()


def test_set_current_snapshot_chained_with_create_tag(table_v2: Table) -> None:
    """A tag created earlier in the same chain can be used as the ref for set_current_snapshot."""
    target_id = 3051729675574597004

    catalog = MagicMock()
    catalog.commit_table.return_value = _mock_commit_response(table_v2)
    table_v2.catalog = catalog

    # create a tag and immediately use it to set current snapshot
    builder = table_v2.manage_snapshots()
    builder = builder.create_tag(snapshot_id=target_id, tag_name="new-tag")
    builder = builder.set_current_snapshot(ref_name="new-tag")
    builder.commit()

    catalog.commit_table.assert_called_once()

    ref_updates = [upd for upd in _get_updates(catalog) if isinstance(upd, SetSnapshotRefUpdate)]

    # should have the tag and the main branch update
    assert len(ref_updates) == 2
    updates_by_ref = {upd.ref_name: upd for upd in ref_updates}
    assert set(updates_by_ref) == {"new-tag", "main"}

    # The main branch should point to the same snapshot as the tag
    assert updates_by_ref["main"].snapshot_id == target_id