From f22d46dd2745dcbe4114ea92cda936db544dcbc2 Mon Sep 17 00:00:00 2001 From: chinmay-bhat <12948588+chinmay-bhat@users.noreply.github.com> Date: Tue, 21 May 2024 11:58:01 +0530 Subject: [PATCH 01/15] support rollback and set current snapshot operations --- pyiceberg/table/__init__.py | 53 +++++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 8c1493974b..b6636718d8 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -1975,6 +1975,10 @@ def _commit(self) -> UpdatesAndRequirements: """Apply the pending changes and commit.""" return self._updates, self._requirements + def _commit_if_ref_updates_exist(self) -> None: + self.commit() + self._updates, self._requirements = (), () + def create_tag(self, snapshot_id: int, tag_name: str, max_ref_age_ms: Optional[int] = None) -> ManageSnapshots: """ Create a new tag pointing to the given snapshot id. @@ -2029,6 +2033,55 @@ def create_branch( self._requirements += requirement return self + def rollback_to_snapshot(self, snapshot_id: int) -> ManageSnapshots: + """Rollback the table to the given snapshot id, whose snapshot needs to be an ancestor of the current table state.""" + self._commit_if_ref_updates_exist() + if self._transaction._table.snapshot_by_id(snapshot_id) is None: + raise ValidationError(f"Cannot roll back to unknown snapshot id: {snapshot_id}") + if snapshot_id not in { + ancestor.snapshot_id + for ancestor in ancestors_of(self._transaction._table.current_snapshot(), self._transaction.table_metadata) + }: + raise ValidationError(f"Cannot roll back to snapshot, not an ancestor of the current state: {snapshot_id}") + + update, requirement = self._transaction._set_ref_snapshot(snapshot_id=snapshot_id, ref_name="main", type="branch") + self._updates += update + self._requirements += requirement + return self + + def rollback_to_timestamp(self, timestamp: int) -> ManageSnapshots: + """Rollback the table to the snapshot right before the given timestamp.""" + self._commit_if_ref_updates_exist() + if (snapshot := self._transaction._table.snapshot_as_of_timestamp(timestamp, inclusive=False)) is None: + raise ValidationError(f"Cannot roll back, no valid snapshot older than: {timestamp}") + + update, requirement = self._transaction._set_ref_snapshot( + snapshot_id=snapshot.snapshot_id, ref_name="main", type="branch" + ) + self._updates += update + self._requirements += requirement + return self + + def set_current_snapshot(self, snapshot_id: Optional[int] = None, ref_name: Optional[str] = None) -> ManageSnapshots: + """Set the table to a specific snapshot identified either by its id or the branch/tag its on, not both.""" + self._commit_if_ref_updates_exist() + if (not snapshot_id or ref_name) and (snapshot_id or not ref_name): + raise ValidationError("Either snapshot_id or ref must be provided") + else: + if snapshot_id is None: + target_snapshot_id = self._transaction.table_metadata.refs[ref_name].snapshot_id # type:ignore + else: + target_snapshot_id = snapshot_id + if (snapshot := self._transaction._table.snapshot_by_id(target_snapshot_id)) is None: + raise ValidationError(f"Cannot set snapshot current with snapshot id: {snapshot_id} or ref_name: {ref_name}") + + update, requirement = self._transaction._set_ref_snapshot( + snapshot_id=snapshot.snapshot_id, ref_name="main", type="branch" + ) + self._updates += update + self._requirements += requirement + return self + class UpdateSchema(UpdateTableMetadata["UpdateSchema"]): _schema: Schema From 45c25dbb278953587be5825e464d07d2ac6f5d6f Mon Sep 17 00:00:00 2001 From: chinmay-bhat <12948588+chinmay-bhat@users.noreply.github.com> Date: Sun, 16 Jun 2024 09:38:28 +0530 Subject: [PATCH 02/15] add tests --- tests/integration/test_snapshot_operations.py | 71 +++++++++++++++++++ 1 file changed, 71 insertions(+) diff --git a/tests/integration/test_snapshot_operations.py b/tests/integration/test_snapshot_operations.py index 639193383e..67bafd6de8 100644 --- a/tests/integration/test_snapshot_operations.py +++ b/tests/integration/test_snapshot_operations.py @@ -40,3 +40,74 @@ def test_create_branch(catalog: Catalog) -> None: branch_snapshot_id = tbl.history()[-2].snapshot_id tbl.manage_snapshots().create_branch(snapshot_id=branch_snapshot_id, branch_name="branch123").commit() assert tbl.metadata.refs["branch123"] == SnapshotRef(snapshot_id=branch_snapshot_id, snapshot_ref_type="branch") + + +@pytest.mark.integration +@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")]) +def test_manage_snapshots_context_manager(catalog: Catalog) -> None: + identifier = "default.test_table_snapshot_operations" + tbl = catalog.load_table(identifier) + assert len(tbl.history()) > 3 + current_snapshot_id = tbl.current_snapshot().snapshot_id # type: ignore + rollback_snapshot_id = tbl.history()[-4].snapshot_id + with tbl.manage_snapshots() as ms: + ms.create_tag(snapshot_id=current_snapshot_id, tag_name="testing") + ms.rollback_to_snapshot(snapshot_id=rollback_snapshot_id) + assert tbl.current_snapshot().snapshot_id is not current_snapshot_id # type: ignore + assert tbl.metadata.refs["testing"].snapshot_id == current_snapshot_id + assert tbl.metadata.refs["main"] == SnapshotRef(snapshot_id=rollback_snapshot_id, snapshot_ref_type="branch") + + +@pytest.mark.integration +@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")]) +def test_rollback_to_snapshot(catalog: Catalog) -> None: + identifier = "default.test_table_snapshot_operations" + tbl = catalog.load_table(identifier) + assert len(tbl.history()) > 3 + rollback_snapshot_id = tbl.history()[-3].snapshot_id + current_snapshot_id = tbl.current_snapshot().snapshot_id # type: ignore + tbl.manage_snapshots().rollback_to_snapshot(snapshot_id=rollback_snapshot_id).commit() + assert tbl.current_snapshot().snapshot_id is not current_snapshot_id # type: ignore + assert tbl.metadata.refs["main"] == SnapshotRef(snapshot_id=rollback_snapshot_id, snapshot_ref_type="branch") + + +@pytest.mark.integration +@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")]) +def test_rollback_to_timestamp(catalog: Catalog) -> None: + identifier = "default.test_table_snapshot_operations" + tbl = catalog.load_table(identifier) + assert len(tbl.history()) > 3 + current_snapshot_id = tbl.current_snapshot().snapshot_id # type: ignore + timestamp = tbl.history()[-2].timestamp_ms + expected_snapshot_id = tbl.history()[-3].snapshot_id + # not inclusive of rollback_timestamp + tbl.manage_snapshots().rollback_to_timestamp(timestamp=timestamp).commit() + assert tbl.current_snapshot().snapshot_id is not current_snapshot_id # type: ignore + assert tbl.metadata.refs["main"] == SnapshotRef(snapshot_id=expected_snapshot_id, snapshot_ref_type="branch") + + +@pytest.mark.integration +@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")]) +def test_set_current_snapshot_with_snapshot_id(catalog: Catalog) -> None: + identifier = "default.test_table_snapshot_operations" + tbl = catalog.load_table(identifier) + assert len(tbl.history()) > 3 + current_snapshot_id = tbl.current_snapshot().snapshot_id # type: ignore + expected_snapshot_id = tbl.history()[-3].snapshot_id + tbl.manage_snapshots().set_current_snapshot(snapshot_id=expected_snapshot_id).commit() + assert tbl.current_snapshot().snapshot_id is not current_snapshot_id # type: ignore + assert tbl.metadata.refs["main"] == SnapshotRef(snapshot_id=expected_snapshot_id, snapshot_ref_type="branch") + + +@pytest.mark.integration +@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")]) +def test_set_current_snapshot_with_ref_name(catalog: Catalog) -> None: + identifier = "default.test_table_snapshot_operations" + tbl = catalog.load_table(identifier) + assert len(tbl.history()) > 3 + current_snapshot_id = tbl.current_snapshot().snapshot_id # type: ignore + expected_snapshot_id = tbl.history()[-3].snapshot_id + tbl.manage_snapshots().create_tag(snapshot_id=expected_snapshot_id, tag_name="test-tag19").commit() + tbl.manage_snapshots().set_current_snapshot(ref_name="test-tag19").commit() + assert tbl.current_snapshot().snapshot_id is not current_snapshot_id # type: ignore + assert tbl.metadata.refs["main"] == SnapshotRef(snapshot_id=expected_snapshot_id, snapshot_ref_type="branch") From ca63831730f7ea4b827209c64ba0a6359f5cc2ad Mon Sep 17 00:00:00 2001 From: chinmay-bhat <12948588+chinmay-bhat@users.noreply.github.com> Date: Sun, 16 Jun 2024 10:23:59 +0530 Subject: [PATCH 03/15] use tbl.history() instead of ancestors_of() We don't need to find all the ancestors, we only need to validate that the snapshot is an ancestor, i.e if it was ever current. --- pyiceberg/table/__init__.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index b6636718d8..cc4fb54064 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -113,7 +113,6 @@ SnapshotLogEntry, SnapshotSummaryCollector, Summary, - ancestors_of, update_snapshot_summaries, ) from pyiceberg.table.sorting import UNSORTED_SORT_ORDER, SortOrder @@ -2038,10 +2037,7 @@ def rollback_to_snapshot(self, snapshot_id: int) -> ManageSnapshots: self._commit_if_ref_updates_exist() if self._transaction._table.snapshot_by_id(snapshot_id) is None: raise ValidationError(f"Cannot roll back to unknown snapshot id: {snapshot_id}") - if snapshot_id not in { - ancestor.snapshot_id - for ancestor in ancestors_of(self._transaction._table.current_snapshot(), self._transaction.table_metadata) - }: + if snapshot_id not in {ancestor.snapshot_id for ancestor in self._transaction._table.history()}: raise ValidationError(f"Cannot roll back to snapshot, not an ancestor of the current state: {snapshot_id}") update, requirement = self._transaction._set_ref_snapshot(snapshot_id=snapshot_id, ref_name="main", type="branch") From f7e192a0034e9dfa44a0034edb43b64753bda81e Mon Sep 17 00:00:00 2001 From: chinmay-bhat <12948588+chinmay-bhat@users.noreply.github.com> Date: Sun, 16 Jun 2024 10:41:21 +0530 Subject: [PATCH 04/15] improve docstrings --- pyiceberg/table/__init__.py | 31 ++++++++++++++++++++++++++++--- 1 file changed, 28 insertions(+), 3 deletions(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index cc4fb54064..04e8d7e20d 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -2033,7 +2033,15 @@ def create_branch( return self def rollback_to_snapshot(self, snapshot_id: int) -> ManageSnapshots: - """Rollback the table to the given snapshot id, whose snapshot needs to be an ancestor of the current table state.""" + """Rollback the table to the given snapshot id. + + The snapshot needs to be an ancestor of the current table state. + + Args: + snapshot_id (int): rollback to this snapshot_id that used to be current. + Returns: + This for method chaining + """ self._commit_if_ref_updates_exist() if self._transaction._table.snapshot_by_id(snapshot_id) is None: raise ValidationError(f"Cannot roll back to unknown snapshot id: {snapshot_id}") @@ -2046,7 +2054,15 @@ def rollback_to_snapshot(self, snapshot_id: int) -> ManageSnapshots: return self def rollback_to_timestamp(self, timestamp: int) -> ManageSnapshots: - """Rollback the table to the snapshot right before the given timestamp.""" + """Rollback the table to the snapshot right before the given timestamp. + + The snapshot needs to be an ancestor of the current table state. + + Args: + timestamp (int): rollback to the snapshot that used to be current right before this timestamp. + Returns: + This for method chaining + """ self._commit_if_ref_updates_exist() if (snapshot := self._transaction._table.snapshot_as_of_timestamp(timestamp, inclusive=False)) is None: raise ValidationError(f"Cannot roll back, no valid snapshot older than: {timestamp}") @@ -2059,7 +2075,16 @@ def rollback_to_timestamp(self, timestamp: int) -> ManageSnapshots: return self def set_current_snapshot(self, snapshot_id: Optional[int] = None, ref_name: Optional[str] = None) -> ManageSnapshots: - """Set the table to a specific snapshot identified either by its id or the branch/tag its on, not both.""" + """Set the table to a specific snapshot identified either by its id or the branch/tag its on, not both. + + The snapshot is not required to be an ancestor of the current table state. + + Args: + snapshot_id (Optional[int]): id of the snapshot to be set as current + ref_name (Optional[str]): branch/tag where the snapshot to be set as current exists. + Returns: + This for method chaining + """ self._commit_if_ref_updates_exist() if (not snapshot_id or ref_name) and (snapshot_id or not ref_name): raise ValidationError("Either snapshot_id or ref must be provided") From 6859fa47c48c2c3bc890c861512d0cb20b9b9f5b Mon Sep 17 00:00:00 2001 From: chinmay-bhat <12948588+chinmay-bhat@users.noreply.github.com> Date: Sun, 16 Jun 2024 10:54:46 +0530 Subject: [PATCH 05/15] Revert "use tbl.history() instead of ancestors_of()" This reverts commit f5d489cbfe7eef7bdd89d83fbd7f94e1f9b420c8. --- pyiceberg/table/__init__.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 04e8d7e20d..63bbe0dfe0 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -113,6 +113,7 @@ SnapshotLogEntry, SnapshotSummaryCollector, Summary, + ancestors_of, update_snapshot_summaries, ) from pyiceberg.table.sorting import UNSORTED_SORT_ORDER, SortOrder @@ -2045,7 +2046,10 @@ def rollback_to_snapshot(self, snapshot_id: int) -> ManageSnapshots: self._commit_if_ref_updates_exist() if self._transaction._table.snapshot_by_id(snapshot_id) is None: raise ValidationError(f"Cannot roll back to unknown snapshot id: {snapshot_id}") - if snapshot_id not in {ancestor.snapshot_id for ancestor in self._transaction._table.history()}: + if snapshot_id not in { + ancestor.snapshot_id + for ancestor in ancestors_of(self._transaction._table.current_snapshot(), self._transaction.table_metadata) + }: raise ValidationError(f"Cannot roll back to snapshot, not an ancestor of the current state: {snapshot_id}") update, requirement = self._transaction._set_ref_snapshot(snapshot_id=snapshot_id, ref_name="main", type="branch") From ea0e645f8c4c74985e4f92642d7ceb036ad490db Mon Sep 17 00:00:00 2001 From: chinmay-bhat <12948588+chinmay-bhat@users.noreply.github.com> Date: Sun, 16 Jun 2024 12:35:30 +0530 Subject: [PATCH 06/15] find ancestor before timestamp we cannot use snapshot_as_of_timestamp() as it finds previously current snapshots but not necessarily an ancestor. An example is here: https://iceberg.apache.org/docs/nightly/spark-queries/?h=ancestor#history --- pyiceberg/table/__init__.py | 7 ++++++- pyiceberg/table/snapshots.py | 11 +++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 63bbe0dfe0..c47f200ce2 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -113,6 +113,7 @@ SnapshotLogEntry, SnapshotSummaryCollector, Summary, + ancestor_right_before_timestamp, ancestors_of, update_snapshot_summaries, ) @@ -2068,7 +2069,11 @@ def rollback_to_timestamp(self, timestamp: int) -> ManageSnapshots: This for method chaining """ self._commit_if_ref_updates_exist() - if (snapshot := self._transaction._table.snapshot_as_of_timestamp(timestamp, inclusive=False)) is None: + if ( + snapshot := ancestor_right_before_timestamp( + self._transaction._table.current_snapshot(), self._transaction.table_metadata, timestamp + ) + ) is None: raise ValidationError(f"Cannot roll back, no valid snapshot older than: {timestamp}") update, requirement = self._transaction._set_ref_snapshot( diff --git a/pyiceberg/table/snapshots.py b/pyiceberg/table/snapshots.py index 842d42522a..eaf420f26e 100644 --- a/pyiceberg/table/snapshots.py +++ b/pyiceberg/table/snapshots.py @@ -421,6 +421,17 @@ def set_when_positive(properties: Dict[str, str], num: int, property_name: str) properties[property_name] = str(num) +def ancestor_right_before_timestamp( + current_snapshot: Optional[Snapshot], table_metadata: TableMetadata, timestamp_ms: int +) -> Optional[Snapshot]: + """Get the ancestor right before the given timestamp.""" + if current_snapshot is not None: + for ancestor in ancestors_of(current_snapshot, table_metadata): + if ancestor.timestamp_ms < timestamp_ms: + return ancestor + return None + + def ancestors_of(current_snapshot: Optional[Snapshot], table_metadata: TableMetadata) -> Iterable[Snapshot]: """Get the ancestors of and including the given snapshot.""" snapshot = current_snapshot From dc4028bde5df93c7e6f9cd76c16dd7e202a2417a Mon Sep 17 00:00:00 2001 From: chinmay-bhat <12948588+chinmay-bhat@users.noreply.github.com> Date: Sun, 16 Jun 2024 12:38:07 +0530 Subject: [PATCH 07/15] update tests --- dev/provision.py | 48 +++++++++++++++++++ tests/integration/test_snapshot_operations.py | 27 ++++++----- 2 files changed, 62 insertions(+), 13 deletions(-) diff --git a/dev/provision.py b/dev/provision.py index 6c8fe366d7..d831aa1560 100644 --- a/dev/provision.py +++ b/dev/provision.py @@ -389,3 +389,51 @@ VALUES (4) """ ) + + spark.sql( + f""" + CREATE OR REPLACE TABLE {catalog_name}.default.test_table_rollback_to_snapshot_id ( + timestamp int, + number integer + ) + USING iceberg + TBLPROPERTIES ( + 'format-version'='2' + ); + """ + ) + + spark.sql( + f""" + INSERT INTO {catalog_name}.default.test_table_rollback_to_snapshot_id + VALUES (200, 1) + """ + ) + + spark.sql( + f""" + INSERT INTO {catalog_name}.default.test_table_rollback_to_snapshot_id + VALUES (202, 2) + """ + ) + + spark.sql( + f""" + DELETE FROM {catalog_name}.default.test_table_rollback_to_snapshot_id + WHERE number = 2 + """ + ) + + spark.sql( + f""" + INSERT INTO {catalog_name}.default.test_table_rollback_to_snapshot_id + VALUES (204, 3) + """ + ) + + spark.sql( + f""" + INSERT INTO {catalog_name}.default.test_table_rollback_to_snapshot_id + VALUES (206, 4) + """ + ) diff --git a/tests/integration/test_snapshot_operations.py b/tests/integration/test_snapshot_operations.py index 67bafd6de8..d333b4be4d 100644 --- a/tests/integration/test_snapshot_operations.py +++ b/tests/integration/test_snapshot_operations.py @@ -49,24 +49,26 @@ def test_manage_snapshots_context_manager(catalog: Catalog) -> None: tbl = catalog.load_table(identifier) assert len(tbl.history()) > 3 current_snapshot_id = tbl.current_snapshot().snapshot_id # type: ignore - rollback_snapshot_id = tbl.history()[-4].snapshot_id + expected_snapshot_id = tbl.history()[-4].snapshot_id with tbl.manage_snapshots() as ms: ms.create_tag(snapshot_id=current_snapshot_id, tag_name="testing") - ms.rollback_to_snapshot(snapshot_id=rollback_snapshot_id) + ms.set_current_snapshot(snapshot_id=expected_snapshot_id) + ms.create_tag(snapshot_id=expected_snapshot_id, tag_name="testing2") assert tbl.current_snapshot().snapshot_id is not current_snapshot_id # type: ignore assert tbl.metadata.refs["testing"].snapshot_id == current_snapshot_id - assert tbl.metadata.refs["main"] == SnapshotRef(snapshot_id=rollback_snapshot_id, snapshot_ref_type="branch") + assert tbl.metadata.refs["main"] == SnapshotRef(snapshot_id=expected_snapshot_id, snapshot_ref_type="branch") + assert tbl.metadata.refs["testing2"].snapshot_id == expected_snapshot_id @pytest.mark.integration @pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")]) def test_rollback_to_snapshot(catalog: Catalog) -> None: - identifier = "default.test_table_snapshot_operations" + identifier = "default.test_table_rollback_to_snapshot_id" tbl = catalog.load_table(identifier) assert len(tbl.history()) > 3 - rollback_snapshot_id = tbl.history()[-3].snapshot_id + rollback_snapshot_id = tbl.current_snapshot().parent_snapshot_id # type: ignore current_snapshot_id = tbl.current_snapshot().snapshot_id # type: ignore - tbl.manage_snapshots().rollback_to_snapshot(snapshot_id=rollback_snapshot_id).commit() + tbl.manage_snapshots().rollback_to_snapshot(snapshot_id=rollback_snapshot_id).commit() # type: ignore assert tbl.current_snapshot().snapshot_id is not current_snapshot_id # type: ignore assert tbl.metadata.refs["main"] == SnapshotRef(snapshot_id=rollback_snapshot_id, snapshot_ref_type="branch") @@ -74,12 +76,11 @@ def test_rollback_to_snapshot(catalog: Catalog) -> None: @pytest.mark.integration @pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")]) def test_rollback_to_timestamp(catalog: Catalog) -> None: - identifier = "default.test_table_snapshot_operations" + identifier = "default.test_table_rollback_to_snapshot_id" tbl = catalog.load_table(identifier) - assert len(tbl.history()) > 3 - current_snapshot_id = tbl.current_snapshot().snapshot_id # type: ignore - timestamp = tbl.history()[-2].timestamp_ms - expected_snapshot_id = tbl.history()[-3].snapshot_id + assert len(tbl.history()) > 4 + current_snapshot_id, timestamp = tbl.history()[-1].snapshot_id, tbl.history()[-1].timestamp_ms + expected_snapshot_id = tbl.snapshot_by_id(current_snapshot_id).parent_snapshot_id # type: ignore # not inclusive of rollback_timestamp tbl.manage_snapshots().rollback_to_timestamp(timestamp=timestamp).commit() assert tbl.current_snapshot().snapshot_id is not current_snapshot_id # type: ignore @@ -107,7 +108,7 @@ def test_set_current_snapshot_with_ref_name(catalog: Catalog) -> None: assert len(tbl.history()) > 3 current_snapshot_id = tbl.current_snapshot().snapshot_id # type: ignore expected_snapshot_id = tbl.history()[-3].snapshot_id - tbl.manage_snapshots().create_tag(snapshot_id=expected_snapshot_id, tag_name="test-tag19").commit() - tbl.manage_snapshots().set_current_snapshot(ref_name="test-tag19").commit() + tbl.manage_snapshots().create_tag(snapshot_id=expected_snapshot_id, tag_name="test-tag").commit() + tbl.manage_snapshots().set_current_snapshot(ref_name="test-tag").commit() assert tbl.current_snapshot().snapshot_id is not current_snapshot_id # type: ignore assert tbl.metadata.refs["main"] == SnapshotRef(snapshot_id=expected_snapshot_id, snapshot_ref_type="branch") From 7fba98b7cbdc9e3ce22a933553396dd8f7aeb515 Mon Sep 17 00:00:00 2001 From: chinmay-bhat <12948588+chinmay-bhat@users.noreply.github.com> Date: Sun, 16 Jun 2024 12:52:49 +0530 Subject: [PATCH 08/15] small fix --- pyiceberg/table/snapshots.py | 2 +- tests/integration/test_snapshot_operations.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyiceberg/table/snapshots.py b/pyiceberg/table/snapshots.py index eaf420f26e..70923ff407 100644 --- a/pyiceberg/table/snapshots.py +++ b/pyiceberg/table/snapshots.py @@ -425,7 +425,7 @@ def ancestor_right_before_timestamp( current_snapshot: Optional[Snapshot], table_metadata: TableMetadata, timestamp_ms: int ) -> Optional[Snapshot]: """Get the ancestor right before the given timestamp.""" - if current_snapshot is not None: + if current_snapshot: for ancestor in ancestors_of(current_snapshot, table_metadata): if ancestor.timestamp_ms < timestamp_ms: return ancestor diff --git a/tests/integration/test_snapshot_operations.py b/tests/integration/test_snapshot_operations.py index d333b4be4d..41c87c1b87 100644 --- a/tests/integration/test_snapshot_operations.py +++ b/tests/integration/test_snapshot_operations.py @@ -53,7 +53,7 @@ def test_manage_snapshots_context_manager(catalog: Catalog) -> None: with tbl.manage_snapshots() as ms: ms.create_tag(snapshot_id=current_snapshot_id, tag_name="testing") ms.set_current_snapshot(snapshot_id=expected_snapshot_id) - ms.create_tag(snapshot_id=expected_snapshot_id, tag_name="testing2") + ms.create_branch(snapshot_id=expected_snapshot_id, branch_name="testing2") assert tbl.current_snapshot().snapshot_id is not current_snapshot_id # type: ignore assert tbl.metadata.refs["testing"].snapshot_id == current_snapshot_id assert tbl.metadata.refs["main"] == SnapshotRef(snapshot_id=expected_snapshot_id, snapshot_ref_type="branch") From 1f4a40407fd5ce44f42e8698a14e1213afea6ce2 Mon Sep 17 00:00:00 2001 From: chinmay-bhat <12948588+chinmay-bhat@users.noreply.github.com> Date: Tue, 18 Jun 2024 13:00:25 +0530 Subject: [PATCH 09/15] fix test error --- tests/integration/test_snapshot_operations.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/tests/integration/test_snapshot_operations.py b/tests/integration/test_snapshot_operations.py index 41c87c1b87..5d1a0e0dcd 100644 --- a/tests/integration/test_snapshot_operations.py +++ b/tests/integration/test_snapshot_operations.py @@ -18,6 +18,7 @@ from pyiceberg.catalog import Catalog from pyiceberg.table.refs import SnapshotRef +from pyiceberg.table.snapshots import ancestors_of @pytest.mark.integration @@ -54,7 +55,7 @@ def test_manage_snapshots_context_manager(catalog: Catalog) -> None: ms.create_tag(snapshot_id=current_snapshot_id, tag_name="testing") ms.set_current_snapshot(snapshot_id=expected_snapshot_id) ms.create_branch(snapshot_id=expected_snapshot_id, branch_name="testing2") - assert tbl.current_snapshot().snapshot_id is not current_snapshot_id # type: ignore + assert tbl.current_snapshot().snapshot_id != current_snapshot_id # type: ignore assert tbl.metadata.refs["testing"].snapshot_id == current_snapshot_id assert tbl.metadata.refs["main"] == SnapshotRef(snapshot_id=expected_snapshot_id, snapshot_ref_type="branch") assert tbl.metadata.refs["testing2"].snapshot_id == expected_snapshot_id @@ -69,7 +70,7 @@ def test_rollback_to_snapshot(catalog: Catalog) -> None: rollback_snapshot_id = tbl.current_snapshot().parent_snapshot_id # type: ignore current_snapshot_id = tbl.current_snapshot().snapshot_id # type: ignore tbl.manage_snapshots().rollback_to_snapshot(snapshot_id=rollback_snapshot_id).commit() # type: ignore - assert tbl.current_snapshot().snapshot_id is not current_snapshot_id # type: ignore + assert tbl.current_snapshot().snapshot_id != current_snapshot_id # type: ignore assert tbl.metadata.refs["main"] == SnapshotRef(snapshot_id=rollback_snapshot_id, snapshot_ref_type="branch") @@ -79,11 +80,11 @@ def test_rollback_to_timestamp(catalog: Catalog) -> None: identifier = "default.test_table_rollback_to_snapshot_id" tbl = catalog.load_table(identifier) assert len(tbl.history()) > 4 - current_snapshot_id, timestamp = tbl.history()[-1].snapshot_id, tbl.history()[-1].timestamp_ms - expected_snapshot_id = tbl.snapshot_by_id(current_snapshot_id).parent_snapshot_id # type: ignore + ancestors = list(ancestor for ancestor in ancestors_of(tbl.current_snapshot(), tbl.metadata)) # noqa + ancestor_to_rollback_to = ancestors[-1] + expected_snapshot_id, timestamp = ancestor_to_rollback_to.snapshot_id, ancestor_to_rollback_to.timestamp_ms + 1 # not inclusive of rollback_timestamp tbl.manage_snapshots().rollback_to_timestamp(timestamp=timestamp).commit() - assert tbl.current_snapshot().snapshot_id is not current_snapshot_id # type: ignore assert tbl.metadata.refs["main"] == SnapshotRef(snapshot_id=expected_snapshot_id, snapshot_ref_type="branch") @@ -96,7 +97,7 @@ def test_set_current_snapshot_with_snapshot_id(catalog: Catalog) -> None: current_snapshot_id = tbl.current_snapshot().snapshot_id # type: ignore expected_snapshot_id = tbl.history()[-3].snapshot_id tbl.manage_snapshots().set_current_snapshot(snapshot_id=expected_snapshot_id).commit() - assert tbl.current_snapshot().snapshot_id is not current_snapshot_id # type: ignore + assert tbl.current_snapshot().snapshot_id != current_snapshot_id # type: ignore assert tbl.metadata.refs["main"] == SnapshotRef(snapshot_id=expected_snapshot_id, snapshot_ref_type="branch") @@ -110,5 +111,5 @@ def test_set_current_snapshot_with_ref_name(catalog: Catalog) -> None: expected_snapshot_id = tbl.history()[-3].snapshot_id tbl.manage_snapshots().create_tag(snapshot_id=expected_snapshot_id, tag_name="test-tag").commit() tbl.manage_snapshots().set_current_snapshot(ref_name="test-tag").commit() - assert tbl.current_snapshot().snapshot_id is not current_snapshot_id # type: ignore + assert tbl.current_snapshot().snapshot_id != current_snapshot_id # type: ignore assert tbl.metadata.refs["main"] == SnapshotRef(snapshot_id=expected_snapshot_id, snapshot_ref_type="branch") From 59f1626a66484ab8fa755cbd07f1e139cfca193c Mon Sep 17 00:00:00 2001 From: chinmay-bhat <12948588+chinmay-bhat@users.noreply.github.com> Date: Tue, 2 Jul 2024 23:35:06 +0530 Subject: [PATCH 10/15] fixes based on review --- pyiceberg/table/__init__.py | 31 ++++++++++++++----------------- 1 file changed, 14 insertions(+), 17 deletions(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index c47f200ce2..18f838a892 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -106,7 +106,7 @@ NameMapping, update_mapping, ) -from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRef +from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRef, SnapshotRefType from pyiceberg.table.snapshots import ( Operation, Snapshot, @@ -1980,6 +1980,13 @@ def _commit_if_ref_updates_exist(self) -> None: self.commit() self._updates, self._requirements = (), () + def _stage_main_branch_snapshot_ref(self, snapshot_id: int) -> None: + update, requirement = self._transaction._set_ref_snapshot( + snapshot_id=snapshot_id, ref_name=MAIN_BRANCH, type=SnapshotRefType.BRANCH + ) + self._updates += update + self._requirements += requirement + def create_tag(self, snapshot_id: int, tag_name: str, max_ref_age_ms: Optional[int] = None) -> ManageSnapshots: """ Create a new tag pointing to the given snapshot id. @@ -2052,10 +2059,7 @@ def rollback_to_snapshot(self, snapshot_id: int) -> ManageSnapshots: for ancestor in ancestors_of(self._transaction._table.current_snapshot(), self._transaction.table_metadata) }: raise ValidationError(f"Cannot roll back to snapshot, not an ancestor of the current state: {snapshot_id}") - - update, requirement = self._transaction._set_ref_snapshot(snapshot_id=snapshot_id, ref_name="main", type="branch") - self._updates += update - self._requirements += requirement + self._stage_main_branch_snapshot_ref(snapshot_id=snapshot_id) return self def rollback_to_timestamp(self, timestamp: int) -> ManageSnapshots: @@ -2075,12 +2079,7 @@ def rollback_to_timestamp(self, timestamp: int) -> ManageSnapshots: ) ) is None: raise ValidationError(f"Cannot roll back, no valid snapshot older than: {timestamp}") - - update, requirement = self._transaction._set_ref_snapshot( - snapshot_id=snapshot.snapshot_id, ref_name="main", type="branch" - ) - self._updates += update - self._requirements += requirement + self._stage_main_branch_snapshot_ref(snapshot_id=snapshot.snapshot_id) return self def set_current_snapshot(self, snapshot_id: Optional[int] = None, ref_name: Optional[str] = None) -> ManageSnapshots: @@ -2099,17 +2098,15 @@ def set_current_snapshot(self, snapshot_id: Optional[int] = None, ref_name: Opti raise ValidationError("Either snapshot_id or ref must be provided") else: if snapshot_id is None: - target_snapshot_id = self._transaction.table_metadata.refs[ref_name].snapshot_id # type:ignore + if ref_name not in self._transaction.table_metadata.refs: + raise ValidationError(f"Cannot set snapshot current to unknown ref {ref_name}") + target_snapshot_id = self._transaction.table_metadata.refs[ref_name].snapshot_id else: target_snapshot_id = snapshot_id if (snapshot := self._transaction._table.snapshot_by_id(target_snapshot_id)) is None: raise ValidationError(f"Cannot set snapshot current with snapshot id: {snapshot_id} or ref_name: {ref_name}") - update, requirement = self._transaction._set_ref_snapshot( - snapshot_id=snapshot.snapshot_id, ref_name="main", type="branch" - ) - self._updates += update - self._requirements += requirement + self._stage_main_branch_snapshot_ref(snapshot_id=snapshot.snapshot_id) return self From 7c7907b1754bf8e3fb62333cb2d83887bdff8dea Mon Sep 17 00:00:00 2001 From: chinmay-bhat <12948588+chinmay-bhat@users.noreply.github.com> Date: Wed, 3 Jul 2024 00:31:54 +0530 Subject: [PATCH 11/15] add parameter to control when transaction is committed --- pyiceberg/table/__init__.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 18f838a892..e4c4e3471b 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -300,7 +300,12 @@ def __exit__(self, _: Any, value: Any, traceback: Any) -> None: """Close and commit the transaction.""" self.commit_transaction() - def _apply(self, updates: Tuple[TableUpdate, ...], requirements: Tuple[TableRequirement, ...] = ()) -> Transaction: + def _apply( + self, + updates: Tuple[TableUpdate, ...], + requirements: Tuple[TableRequirement, ...] = (), + commit_transaction_now: bool = True, + ) -> Transaction: """Check if the requirements are met, and applies the updates to the metadata.""" for requirement in requirements: requirement.validate(self.table_metadata) @@ -310,7 +315,7 @@ def _apply(self, updates: Tuple[TableUpdate, ...], requirements: Tuple[TableRequ self.table_metadata = update_table_metadata(self.table_metadata, updates) - if self._autocommit: + if self._autocommit and commit_transaction_now: self.commit_transaction() self._updates = () self._requirements = () @@ -1977,12 +1982,12 @@ def _commit(self) -> UpdatesAndRequirements: return self._updates, self._requirements def _commit_if_ref_updates_exist(self) -> None: - self.commit() + self._transaction._apply(*self._commit(), commit_transaction_now=False) self._updates, self._requirements = (), () def _stage_main_branch_snapshot_ref(self, snapshot_id: int) -> None: update, requirement = self._transaction._set_ref_snapshot( - snapshot_id=snapshot_id, ref_name=MAIN_BRANCH, type=SnapshotRefType.BRANCH + snapshot_id=snapshot_id, ref_name=MAIN_BRANCH, type=str(SnapshotRefType.BRANCH) ) self._updates += update self._requirements += requirement From e563b7eb38012693149bf7d2e621ece8a7ebd198 Mon Sep 17 00:00:00 2001 From: chinmay-bhat <12948588+chinmay-bhat@users.noreply.github.com> Date: Wed, 3 Jul 2024 01:12:45 +0530 Subject: [PATCH 12/15] move _set_ref_snapshot --- pyiceberg/table/__init__.py | 74 +++++++++++++++++++------------------ 1 file changed, 38 insertions(+), 36 deletions(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index e4c4e3471b..286743b5fa 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -408,39 +408,6 @@ def set_ref_snapshot( requirements = (AssertRefSnapshotId(snapshot_id=parent_snapshot_id, ref="main"),) return self._apply(updates, requirements) - def _set_ref_snapshot( - self, - snapshot_id: int, - ref_name: str, - type: str, - max_ref_age_ms: Optional[int] = None, - max_snapshot_age_ms: Optional[int] = None, - min_snapshots_to_keep: Optional[int] = None, - ) -> UpdatesAndRequirements: - """Update a ref to a snapshot. - - Returns: - The updates and requirements for the set-snapshot-ref staged - """ - updates = ( - SetSnapshotRefUpdate( - snapshot_id=snapshot_id, - ref_name=ref_name, - type=type, - max_ref_age_ms=max_ref_age_ms, - max_snapshot_age_ms=max_snapshot_age_ms, - min_snapshots_to_keep=min_snapshots_to_keep, - ), - ) - requirements = ( - AssertRefSnapshotId( - snapshot_id=self.table_metadata.refs[ref_name].snapshot_id if ref_name in self.table_metadata.refs else None, - ref=ref_name, - ), - ) - - return updates, requirements - def update_schema(self, allow_incompatible_changes: bool = False, case_sensitive: bool = True) -> UpdateSchema: """Create a new UpdateSchema to alter the columns of this table. @@ -1986,12 +1953,47 @@ def _commit_if_ref_updates_exist(self) -> None: self._updates, self._requirements = (), () def _stage_main_branch_snapshot_ref(self, snapshot_id: int) -> None: - update, requirement = self._transaction._set_ref_snapshot( + update, requirement = self._set_ref_snapshot( snapshot_id=snapshot_id, ref_name=MAIN_BRANCH, type=str(SnapshotRefType.BRANCH) ) self._updates += update self._requirements += requirement + def _set_ref_snapshot( + self, + snapshot_id: int, + ref_name: str, + type: str, + max_ref_age_ms: Optional[int] = None, + max_snapshot_age_ms: Optional[int] = None, + min_snapshots_to_keep: Optional[int] = None, + ) -> UpdatesAndRequirements: + """Update a ref to a snapshot. + + Returns: + The updates and requirements for the set-snapshot-ref staged + """ + updates = ( + SetSnapshotRefUpdate( + snapshot_id=snapshot_id, + ref_name=ref_name, + type=type, + max_ref_age_ms=max_ref_age_ms, + max_snapshot_age_ms=max_snapshot_age_ms, + min_snapshots_to_keep=min_snapshots_to_keep, + ), + ) + requirements = ( + AssertRefSnapshotId( + snapshot_id=self._transaction.table_metadata.refs[ref_name].snapshot_id + if ref_name in self._transaction.table_metadata.refs + else None, + ref=ref_name, + ), + ) + + return updates, requirements + def create_tag(self, snapshot_id: int, tag_name: str, max_ref_age_ms: Optional[int] = None) -> ManageSnapshots: """ Create a new tag pointing to the given snapshot id. @@ -2004,7 +2006,7 @@ def create_tag(self, snapshot_id: int, tag_name: str, max_ref_age_ms: Optional[i Returns: This for method chaining """ - update, requirement = self._transaction._set_ref_snapshot( + update, requirement = self._set_ref_snapshot( snapshot_id=snapshot_id, ref_name=tag_name, type="tag", @@ -2034,7 +2036,7 @@ def create_branch( Returns: This for method chaining """ - update, requirement = self._transaction._set_ref_snapshot( + update, requirement = self._set_ref_snapshot( snapshot_id=snapshot_id, ref_name=branch_name, type="branch", From 386496f8ec8722d5261623c6063954368e1b4b4e Mon Sep 17 00:00:00 2001 From: chinmay-bhat <12948588+chinmay-bhat@users.noreply.github.com> Date: Thu, 4 Jul 2024 11:31:56 +0530 Subject: [PATCH 13/15] changes after review --- pyiceberg/table/__init__.py | 34 ++++++++++++---------------------- 1 file changed, 12 insertions(+), 22 deletions(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 286743b5fa..8a0f3157cd 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -304,7 +304,7 @@ def _apply( self, updates: Tuple[TableUpdate, ...], requirements: Tuple[TableRequirement, ...] = (), - commit_transaction_now: bool = True, + commit_transaction_if_autocommit: bool = True, ) -> Transaction: """Check if the requirements are met, and applies the updates to the metadata.""" for requirement in requirements: @@ -315,7 +315,7 @@ def _apply( self.table_metadata = update_table_metadata(self.table_metadata, updates) - if self._autocommit and commit_transaction_now: + if self._autocommit and commit_transaction_if_autocommit: self.commit_transaction() self._updates = () self._requirements = () @@ -1949,16 +1949,9 @@ def _commit(self) -> UpdatesAndRequirements: return self._updates, self._requirements def _commit_if_ref_updates_exist(self) -> None: - self._transaction._apply(*self._commit(), commit_transaction_now=False) + self._transaction._apply(*self._commit(), commit_transaction_if_autocommit=False) self._updates, self._requirements = (), () - def _stage_main_branch_snapshot_ref(self, snapshot_id: int) -> None: - update, requirement = self._set_ref_snapshot( - snapshot_id=snapshot_id, ref_name=MAIN_BRANCH, type=str(SnapshotRefType.BRANCH) - ) - self._updates += update - self._requirements += requirement - def _set_ref_snapshot( self, snapshot_id: int, @@ -1967,7 +1960,7 @@ def _set_ref_snapshot( max_ref_age_ms: Optional[int] = None, max_snapshot_age_ms: Optional[int] = None, min_snapshots_to_keep: Optional[int] = None, - ) -> UpdatesAndRequirements: + ) -> ManageSnapshots: """Update a ref to a snapshot. Returns: @@ -1991,8 +1984,9 @@ def _set_ref_snapshot( ref=ref_name, ), ) - - return updates, requirements + self._updates += updates + self._requirements += requirements + return self def create_tag(self, snapshot_id: int, tag_name: str, max_ref_age_ms: Optional[int] = None) -> ManageSnapshots: """ @@ -2006,14 +2000,12 @@ def create_tag(self, snapshot_id: int, tag_name: str, max_ref_age_ms: Optional[i Returns: This for method chaining """ - update, requirement = self._set_ref_snapshot( + self._set_ref_snapshot( snapshot_id=snapshot_id, ref_name=tag_name, type="tag", max_ref_age_ms=max_ref_age_ms, ) - self._updates += update - self._requirements += requirement return self def create_branch( @@ -2036,7 +2028,7 @@ def create_branch( Returns: This for method chaining """ - update, requirement = self._set_ref_snapshot( + self._set_ref_snapshot( snapshot_id=snapshot_id, ref_name=branch_name, type="branch", @@ -2044,8 +2036,6 @@ def create_branch( max_snapshot_age_ms=max_snapshot_age_ms, min_snapshots_to_keep=min_snapshots_to_keep, ) - self._updates += update - self._requirements += requirement return self def rollback_to_snapshot(self, snapshot_id: int) -> ManageSnapshots: @@ -2066,7 +2056,7 @@ def rollback_to_snapshot(self, snapshot_id: int) -> ManageSnapshots: for ancestor in ancestors_of(self._transaction._table.current_snapshot(), self._transaction.table_metadata) }: raise ValidationError(f"Cannot roll back to snapshot, not an ancestor of the current state: {snapshot_id}") - self._stage_main_branch_snapshot_ref(snapshot_id=snapshot_id) + self._set_ref_snapshot(snapshot_id=snapshot_id, ref_name=MAIN_BRANCH, type=str(SnapshotRefType.BRANCH)) return self def rollback_to_timestamp(self, timestamp: int) -> ManageSnapshots: @@ -2086,7 +2076,7 @@ def rollback_to_timestamp(self, timestamp: int) -> ManageSnapshots: ) ) is None: raise ValidationError(f"Cannot roll back, no valid snapshot older than: {timestamp}") - self._stage_main_branch_snapshot_ref(snapshot_id=snapshot.snapshot_id) + self._set_ref_snapshot(snapshot_id=snapshot.snapshot_id, ref_name=MAIN_BRANCH, type=str(SnapshotRefType.BRANCH)) return self def set_current_snapshot(self, snapshot_id: Optional[int] = None, ref_name: Optional[str] = None) -> ManageSnapshots: @@ -2113,7 +2103,7 @@ def set_current_snapshot(self, snapshot_id: Optional[int] = None, ref_name: Opti if (snapshot := self._transaction._table.snapshot_by_id(target_snapshot_id)) is None: raise ValidationError(f"Cannot set snapshot current with snapshot id: {snapshot_id} or ref_name: {ref_name}") - self._stage_main_branch_snapshot_ref(snapshot_id=snapshot.snapshot_id) + self._set_ref_snapshot(snapshot_id=snapshot.snapshot_id, ref_name=MAIN_BRANCH, type=str(SnapshotRefType.BRANCH)) return self From 5adccb96868b2b8cdfb12a9ab70b4f6d159f5c58 Mon Sep 17 00:00:00 2001 From: chinmay-bhat <12948588+chinmay-bhat@users.noreply.github.com> Date: Thu, 4 Jul 2024 12:16:50 +0530 Subject: [PATCH 14/15] move test and use constants --- tests/integration/test_snapshot_operations.py | 45 +++++++++++++++---- tests/table/test_init.py | 24 ---------- 2 files changed, 37 insertions(+), 32 deletions(-) diff --git a/tests/integration/test_snapshot_operations.py b/tests/integration/test_snapshot_operations.py index 5d1a0e0dcd..d08b483b95 100644 --- a/tests/integration/test_snapshot_operations.py +++ b/tests/integration/test_snapshot_operations.py @@ -17,7 +17,7 @@ import pytest from pyiceberg.catalog import Catalog -from pyiceberg.table.refs import SnapshotRef +from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRef, SnapshotRefType from pyiceberg.table.snapshots import ancestors_of @@ -29,7 +29,7 @@ def test_create_tag(catalog: Catalog) -> None: assert len(tbl.history()) > 3 tag_snapshot_id = tbl.history()[-3].snapshot_id tbl.manage_snapshots().create_tag(snapshot_id=tag_snapshot_id, tag_name="tag123").commit() - assert tbl.metadata.refs["tag123"] == SnapshotRef(snapshot_id=tag_snapshot_id, snapshot_ref_type="tag") + assert tbl.metadata.refs["tag123"] == SnapshotRef(snapshot_id=tag_snapshot_id, snapshot_ref_type=str(SnapshotRefType.TAG)) @pytest.mark.integration @@ -40,7 +40,9 @@ def test_create_branch(catalog: Catalog) -> None: assert len(tbl.history()) > 2 branch_snapshot_id = tbl.history()[-2].snapshot_id tbl.manage_snapshots().create_branch(snapshot_id=branch_snapshot_id, branch_name="branch123").commit() - assert tbl.metadata.refs["branch123"] == SnapshotRef(snapshot_id=branch_snapshot_id, snapshot_ref_type="branch") + assert tbl.metadata.refs["branch123"] == SnapshotRef( + snapshot_id=branch_snapshot_id, snapshot_ref_type=str(SnapshotRefType.BRANCH) + ) @pytest.mark.integration @@ -57,10 +59,13 @@ def test_manage_snapshots_context_manager(catalog: Catalog) -> None: ms.create_branch(snapshot_id=expected_snapshot_id, branch_name="testing2") assert tbl.current_snapshot().snapshot_id != current_snapshot_id # type: ignore assert tbl.metadata.refs["testing"].snapshot_id == current_snapshot_id - assert tbl.metadata.refs["main"] == SnapshotRef(snapshot_id=expected_snapshot_id, snapshot_ref_type="branch") + assert tbl.metadata.refs[MAIN_BRANCH] == SnapshotRef( + snapshot_id=expected_snapshot_id, snapshot_ref_type=str(SnapshotRefType.BRANCH) + ) assert tbl.metadata.refs["testing2"].snapshot_id == expected_snapshot_id +# Maintain relative order of tests for following apis like rollback, set_current_snapshot, etc. @pytest.mark.integration @pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")]) def test_rollback_to_snapshot(catalog: Catalog) -> None: @@ -71,7 +76,9 @@ def test_rollback_to_snapshot(catalog: Catalog) -> None: current_snapshot_id = tbl.current_snapshot().snapshot_id # type: ignore tbl.manage_snapshots().rollback_to_snapshot(snapshot_id=rollback_snapshot_id).commit() # type: ignore assert tbl.current_snapshot().snapshot_id != current_snapshot_id # type: ignore - assert tbl.metadata.refs["main"] == SnapshotRef(snapshot_id=rollback_snapshot_id, snapshot_ref_type="branch") + assert tbl.metadata.refs[MAIN_BRANCH] == SnapshotRef( + snapshot_id=rollback_snapshot_id, snapshot_ref_type=str(SnapshotRefType.BRANCH) + ) @pytest.mark.integration @@ -85,7 +92,9 @@ def test_rollback_to_timestamp(catalog: Catalog) -> None: expected_snapshot_id, timestamp = ancestor_to_rollback_to.snapshot_id, ancestor_to_rollback_to.timestamp_ms + 1 # not inclusive of rollback_timestamp tbl.manage_snapshots().rollback_to_timestamp(timestamp=timestamp).commit() - assert tbl.metadata.refs["main"] == SnapshotRef(snapshot_id=expected_snapshot_id, snapshot_ref_type="branch") + assert tbl.metadata.refs[MAIN_BRANCH] == SnapshotRef( + snapshot_id=expected_snapshot_id, snapshot_ref_type=str(SnapshotRefType.BRANCH) + ) @pytest.mark.integration @@ -98,7 +107,9 @@ def test_set_current_snapshot_with_snapshot_id(catalog: Catalog) -> None: expected_snapshot_id = tbl.history()[-3].snapshot_id tbl.manage_snapshots().set_current_snapshot(snapshot_id=expected_snapshot_id).commit() assert tbl.current_snapshot().snapshot_id != current_snapshot_id # type: ignore - assert tbl.metadata.refs["main"] == SnapshotRef(snapshot_id=expected_snapshot_id, snapshot_ref_type="branch") + assert tbl.metadata.refs[MAIN_BRANCH] == SnapshotRef( + snapshot_id=expected_snapshot_id, snapshot_ref_type=str(SnapshotRefType.BRANCH) + ) @pytest.mark.integration @@ -112,4 +123,22 @@ def test_set_current_snapshot_with_ref_name(catalog: Catalog) -> None: tbl.manage_snapshots().create_tag(snapshot_id=expected_snapshot_id, tag_name="test-tag").commit() tbl.manage_snapshots().set_current_snapshot(ref_name="test-tag").commit() assert tbl.current_snapshot().snapshot_id != current_snapshot_id # type: ignore - assert tbl.metadata.refs["main"] == SnapshotRef(snapshot_id=expected_snapshot_id, snapshot_ref_type="branch") + assert tbl.metadata.refs[MAIN_BRANCH] == SnapshotRef( + snapshot_id=expected_snapshot_id, snapshot_ref_type=str(SnapshotRefType.BRANCH) + ) + + +# Always test set_ref_snapshot last. +@pytest.mark.integration +@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")]) +def test_set_ref_snapshot(catalog: Catalog) -> None: + identifier = "default.test_table_snapshot_operations" + tbl = catalog.load_table(identifier) + assert len(tbl.history()) > 3 + target_snapshot_id = tbl.history()[-2].snapshot_id + tbl.manage_snapshots()._set_ref_snapshot( + snapshot_id=target_snapshot_id, ref_name=MAIN_BRANCH, type=str(SnapshotRefType.BRANCH) + ).commit() + assert tbl.metadata.refs[MAIN_BRANCH] == SnapshotRef( + snapshot_id=target_snapshot_id, snapshot_ref_type=str(SnapshotRefType.BRANCH) + ) diff --git a/tests/table/test_init.py b/tests/table/test_init.py index d7c4ffeeaf..92329df7bd 100644 --- a/tests/table/test_init.py +++ b/tests/table/test_init.py @@ -706,30 +706,6 @@ def test_update_metadata_add_snapshot(table_v2: Table) -> None: assert new_metadata.last_updated_ms == new_snapshot.timestamp_ms -def test_update_metadata_set_ref_snapshot(table_v2: Table) -> None: - update, _ = table_v2.transaction()._set_ref_snapshot( - snapshot_id=3051729675574597004, - ref_name="main", - type="branch", - max_ref_age_ms=123123123, - max_snapshot_age_ms=12312312312, - min_snapshots_to_keep=1, - ) - - new_metadata = update_table_metadata(table_v2.metadata, update) - assert len(new_metadata.snapshot_log) == 3 - assert new_metadata.snapshot_log[2].snapshot_id == 3051729675574597004 - assert new_metadata.current_snapshot_id == 3051729675574597004 - assert new_metadata.last_updated_ms > table_v2.metadata.last_updated_ms - assert new_metadata.refs["main"] == SnapshotRef( - snapshot_id=3051729675574597004, - snapshot_ref_type="branch", - min_snapshots_to_keep=1, - max_snapshot_age_ms=12312312312, - max_ref_age_ms=123123123, - ) - - def test_update_metadata_set_snapshot_ref(table_v2: Table) -> None: update = SetSnapshotRefUpdate( ref_name="main", From 8885f7808238def0236c09df7cfa505b1e77dcb9 Mon Sep 17 00:00:00 2001 From: chinmay-bhat <12948588+chinmay-bhat@users.noreply.github.com> Date: Thu, 4 Jul 2024 12:42:36 +0530 Subject: [PATCH 15/15] fix docstring and returns --- pyiceberg/table/__init__.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 8a0f3157cd..4db79f90fa 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -1963,8 +1963,10 @@ def _set_ref_snapshot( ) -> ManageSnapshots: """Update a ref to a snapshot. + Stages the updates and requirements for the set-snapshot-ref + Returns: - The updates and requirements for the set-snapshot-ref staged + This for method chaining """ updates = ( SetSnapshotRefUpdate( @@ -2000,13 +2002,12 @@ def create_tag(self, snapshot_id: int, tag_name: str, max_ref_age_ms: Optional[i Returns: This for method chaining """ - self._set_ref_snapshot( + return self._set_ref_snapshot( snapshot_id=snapshot_id, ref_name=tag_name, type="tag", max_ref_age_ms=max_ref_age_ms, ) - return self def create_branch( self, @@ -2028,7 +2029,7 @@ def create_branch( Returns: This for method chaining """ - self._set_ref_snapshot( + return self._set_ref_snapshot( snapshot_id=snapshot_id, ref_name=branch_name, type="branch", @@ -2036,7 +2037,6 @@ def create_branch( max_snapshot_age_ms=max_snapshot_age_ms, min_snapshots_to_keep=min_snapshots_to_keep, ) - return self def rollback_to_snapshot(self, snapshot_id: int) -> ManageSnapshots: """Rollback the table to the given snapshot id. @@ -2056,8 +2056,7 @@ def rollback_to_snapshot(self, snapshot_id: int) -> ManageSnapshots: for ancestor in ancestors_of(self._transaction._table.current_snapshot(), self._transaction.table_metadata) }: raise ValidationError(f"Cannot roll back to snapshot, not an ancestor of the current state: {snapshot_id}") - self._set_ref_snapshot(snapshot_id=snapshot_id, ref_name=MAIN_BRANCH, type=str(SnapshotRefType.BRANCH)) - return self + return self._set_ref_snapshot(snapshot_id=snapshot_id, ref_name=MAIN_BRANCH, type=str(SnapshotRefType.BRANCH)) def rollback_to_timestamp(self, timestamp: int) -> ManageSnapshots: """Rollback the table to the snapshot right before the given timestamp. @@ -2076,8 +2075,7 @@ def rollback_to_timestamp(self, timestamp: int) -> ManageSnapshots: ) ) is None: raise ValidationError(f"Cannot roll back, no valid snapshot older than: {timestamp}") - self._set_ref_snapshot(snapshot_id=snapshot.snapshot_id, ref_name=MAIN_BRANCH, type=str(SnapshotRefType.BRANCH)) - return self + return self._set_ref_snapshot(snapshot_id=snapshot.snapshot_id, ref_name=MAIN_BRANCH, type=str(SnapshotRefType.BRANCH)) def set_current_snapshot(self, snapshot_id: Optional[int] = None, ref_name: Optional[str] = None) -> ManageSnapshots: """Set the table to a specific snapshot identified either by its id or the branch/tag its on, not both.