From b7bdb6c3f76866a91580ec377c70b442e4ee5edb Mon Sep 17 00:00:00 2001 From: chinmay-bhat <12948588+chinmay-bhat@users.noreply.github.com> Date: Sun, 16 Jun 2024 18:48:15 +0530 Subject: [PATCH 1/3] add public and private APIs, register RemoveSnapshotRefUpdate with apply metadata fn --- pyiceberg/table/__init__.py | 64 +++++++++++++++++++ tests/integration/test_snapshot_operations.py | 32 ++++++++++ 2 files changed, 96 insertions(+) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 2eec4d3036..372187eb9d 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -435,6 +435,24 @@ def _set_ref_snapshot( return updates, requirements + def _remove_ref_snapshot(self, ref_name: str) -> UpdatesAndRequirements: + """Remove a snapshot ref. + + Args: + ref_name: branch / tag name to remove + + Returns + The updates and requirements for the remove-snapshot-ref. + """ + updates = (RemoveSnapshotRefUpdate(ref_name=ref_name),) + requirements = ( + AssertRefSnapshotId( + snapshot_id=self.table_metadata.refs[ref_name].snapshot_id if ref_name in self.table_metadata.refs else None, + ref=ref_name, + ), + ) + return updates, requirements + def update_schema(self, allow_incompatible_changes: bool = False, case_sensitive: bool = True) -> UpdateSchema: """Create a new UpdateSchema to alter the columns of this table. @@ -1023,6 +1041,23 @@ def _(update: SetSnapshotRefUpdate, base_metadata: TableMetadata, context: _Tabl return base_metadata.model_copy(update=metadata_updates) +@_apply_table_update.register(RemoveSnapshotRefUpdate) +def _(update: RemoveSnapshotRefUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata: + if (existing_ref := base_metadata.refs.get(update.ref_name)) is None: + return base_metadata + + if base_metadata.snapshot_by_id(existing_ref.snapshot_id) is None: + raise ValueError(f"Cannot remove {update.ref_name} ref with unknown snapshot {existing_ref.snapshot_id}") + + if update.ref_name == MAIN_BRANCH: + raise ValueError("Cannot remove main branch") + + metadata_refs = {**base_metadata.refs} + metadata_refs.pop(update.ref_name, None) + context.add_update(update) + return base_metadata.model_copy(update={"refs": metadata_refs}) + + @_apply_table_update.register(AddSortOrderUpdate) def _(update: AddSortOrderUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata: context.add_update(update) @@ -1997,6 +2032,21 @@ def create_tag(self, snapshot_id: int, tag_name: str, max_ref_age_ms: Optional[i self._requirements += requirement return self + def remove_tag(self, tag_name: str) -> ManageSnapshots: + """ + Remove a tag. + + Args: + tag_name (str): name of tag to remove + + Returns: + This for method chaining + """ + update, requirement = self._transaction._remove_ref_snapshot(ref_name=tag_name) + self._updates += update + self._requirements += requirement + return self + def create_branch( self, snapshot_id: int, @@ -2029,6 +2079,20 @@ def create_branch( self._requirements += requirement return self + def remove_branch(self, branch_name: str) -> ManageSnapshots: + """ + Remove a branch. + + Args: + branch_name (str): name of branch to remove + Returns: + This for method chaining + """ + update, requirement = self._transaction._remove_ref_snapshot(ref_name=branch_name) + self._updates += update + self._requirements += requirement + return self + class UpdateSchema(UpdateTableMetadata["UpdateSchema"]): _schema: Schema diff --git a/tests/integration/test_snapshot_operations.py b/tests/integration/test_snapshot_operations.py index 639193383e..1b7f2d3a29 100644 --- a/tests/integration/test_snapshot_operations.py +++ b/tests/integration/test_snapshot_operations.py @@ -40,3 +40,35 @@ def test_create_branch(catalog: Catalog) -> None: branch_snapshot_id = tbl.history()[-2].snapshot_id tbl.manage_snapshots().create_branch(snapshot_id=branch_snapshot_id, branch_name="branch123").commit() assert tbl.metadata.refs["branch123"] == SnapshotRef(snapshot_id=branch_snapshot_id, snapshot_ref_type="branch") + + +@pytest.mark.integration +@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")]) +def test_remove_tag(catalog: Catalog) -> None: + identifier = "default.test_table_snapshot_operations" + tbl = catalog.load_table(identifier) + assert len(tbl.history()) > 3 + # first, create the tag to remove + tag_name = "tag_to_remove" + tag_snapshot_id = tbl.history()[-3].snapshot_id + tbl.manage_snapshots().create_tag(snapshot_id=tag_snapshot_id, tag_name=tag_name).commit() + assert tbl.metadata.refs[tag_name] == SnapshotRef(snapshot_id=tag_snapshot_id, snapshot_ref_type="tag") + # now, remove the tag + tbl.manage_snapshots().remove_tag(tag_name=tag_name).commit() + assert tbl.metadata.refs.get(tag_name, None) is None + + +@pytest.mark.integration +@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")]) +def test_remove_branch(catalog: Catalog) -> None: + identifier = "default.test_table_snapshot_operations" + tbl = catalog.load_table(identifier) + assert len(tbl.history()) > 2 + # first, create the branch to remove + branch_name = "branch_to_remove" + branch_snapshot_id = tbl.history()[-2].snapshot_id + tbl.manage_snapshots().create_branch(snapshot_id=branch_snapshot_id, branch_name=branch_name).commit() + assert tbl.metadata.refs[branch_name] == SnapshotRef(snapshot_id=branch_snapshot_id, snapshot_ref_type="branch") + # now, remove the branch + tbl.manage_snapshots().remove_branch(branch_name=branch_name).commit() + assert tbl.metadata.refs.get(branch_name, None) is None From 43b1378c2f8e49d2ee862bcc0e5d86cb0670ec2d Mon Sep 17 00:00:00 2001 From: chinmay-bhat <12948588+chinmay-bhat@users.noreply.github.com> Date: Thu, 4 Jul 2024 13:46:57 +0530 Subject: [PATCH 2/3] updates --- pyiceberg/table/__init__.py | 52 ++++++++++++++++++------------------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 372187eb9d..49381c2450 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -435,24 +435,6 @@ def _set_ref_snapshot( return updates, requirements - def _remove_ref_snapshot(self, ref_name: str) -> UpdatesAndRequirements: - """Remove a snapshot ref. - - Args: - ref_name: branch / tag name to remove - - Returns - The updates and requirements for the remove-snapshot-ref. - """ - updates = (RemoveSnapshotRefUpdate(ref_name=ref_name),) - requirements = ( - AssertRefSnapshotId( - snapshot_id=self.table_metadata.refs[ref_name].snapshot_id if ref_name in self.table_metadata.refs else None, - ref=ref_name, - ), - ) - return updates, requirements - def update_schema(self, allow_incompatible_changes: bool = False, case_sensitive: bool = True) -> UpdateSchema: """Create a new UpdateSchema to alter the columns of this table. @@ -2010,6 +1992,30 @@ def _commit(self) -> UpdatesAndRequirements: """Apply the pending changes and commit.""" return self._updates, self._requirements + def _remove_ref_snapshot(self, ref_name: str) -> ManageSnapshots: + """Remove a snapshot ref. + + Args: + ref_name: branch / tag name to remove + + Stages the updates and requirements for the remove-snapshot-ref. + + Returns + This method for chaining + """ + updates = (RemoveSnapshotRefUpdate(ref_name=ref_name),) + requirements = ( + AssertRefSnapshotId( + snapshot_id=self._transaction.table_metadata.refs[ref_name].snapshot_id + if ref_name in self._transaction.table_metadata.refs + else None, + ref=ref_name, + ), + ) + self._updates += updates + self._requirements += requirements + return self + def create_tag(self, snapshot_id: int, tag_name: str, max_ref_age_ms: Optional[int] = None) -> ManageSnapshots: """ Create a new tag pointing to the given snapshot id. @@ -2042,10 +2048,7 @@ def remove_tag(self, tag_name: str) -> ManageSnapshots: Returns: This for method chaining """ - update, requirement = self._transaction._remove_ref_snapshot(ref_name=tag_name) - self._updates += update - self._requirements += requirement - return self + return self._remove_ref_snapshot(ref_name=tag_name) def create_branch( self, @@ -2088,10 +2091,7 @@ def remove_branch(self, branch_name: str) -> ManageSnapshots: Returns: This for method chaining """ - update, requirement = self._transaction._remove_ref_snapshot(ref_name=branch_name) - self._updates += update - self._requirements += requirement - return self + return self._remove_ref_snapshot(ref_name=branch_name) class UpdateSchema(UpdateTableMetadata["UpdateSchema"]): From cca7d2b7f4d313c277dbe91ef1fabc17beebe76c Mon Sep 17 00:00:00 2001 From: chinmay-bhat <12948588+chinmay-bhat@users.noreply.github.com> Date: Thu, 4 Jul 2024 13:56:48 +0530 Subject: [PATCH 3/3] small fix --- pyiceberg/table/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 49381c2450..d0d646f037 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -1025,7 +1025,7 @@ def _(update: SetSnapshotRefUpdate, base_metadata: TableMetadata, context: _Tabl @_apply_table_update.register(RemoveSnapshotRefUpdate) def _(update: RemoveSnapshotRefUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata: - if (existing_ref := base_metadata.refs.get(update.ref_name)) is None: + if (existing_ref := base_metadata.refs.get(update.ref_name, None)) is None: return base_metadata if base_metadata.snapshot_by_id(existing_ref.snapshot_id) is None: