106106 NameMapping ,
107107 update_mapping ,
108108)
109- from pyiceberg .table .refs import MAIN_BRANCH , SnapshotRef
109+ from pyiceberg .table .refs import MAIN_BRANCH , SnapshotRef , SnapshotRefType
110110from pyiceberg .table .snapshots import (
111111 Operation ,
112112 Snapshot ,
@@ -1980,6 +1980,13 @@ def _commit_if_ref_updates_exist(self) -> None:
19801980 self .commit ()
19811981 self ._updates , self ._requirements = (), ()
19821982
1983+ def _stage_main_branch_snapshot_ref (self , snapshot_id : int ) -> None :
1984+ update , requirement = self ._transaction ._set_ref_snapshot (
1985+ snapshot_id = snapshot_id , ref_name = MAIN_BRANCH , type = SnapshotRefType .BRANCH
1986+ )
1987+ self ._updates += update
1988+ self ._requirements += requirement
1989+
19831990 def create_tag (self , snapshot_id : int , tag_name : str , max_ref_age_ms : Optional [int ] = None ) -> ManageSnapshots :
19841991 """
19851992 Create a new tag pointing to the given snapshot id.
@@ -2052,10 +2059,7 @@ def rollback_to_snapshot(self, snapshot_id: int) -> ManageSnapshots:
20522059 for ancestor in ancestors_of (self ._transaction ._table .current_snapshot (), self ._transaction .table_metadata )
20532060 }:
20542061 raise ValidationError (f"Cannot roll back to snapshot, not an ancestor of the current state: { snapshot_id } " )
2055-
2056- update , requirement = self ._transaction ._set_ref_snapshot (snapshot_id = snapshot_id , ref_name = "main" , type = "branch" )
2057- self ._updates += update
2058- self ._requirements += requirement
2062+ self ._stage_main_branch_snapshot_ref (snapshot_id = snapshot_id )
20592063 return self
20602064
20612065 def rollback_to_timestamp (self , timestamp : int ) -> ManageSnapshots :
@@ -2075,12 +2079,7 @@ def rollback_to_timestamp(self, timestamp: int) -> ManageSnapshots:
20752079 )
20762080 ) is None :
20772081 raise ValidationError (f"Cannot roll back, no valid snapshot older than: { timestamp } " )
2078-
2079- update , requirement = self ._transaction ._set_ref_snapshot (
2080- snapshot_id = snapshot .snapshot_id , ref_name = "main" , type = "branch"
2081- )
2082- self ._updates += update
2083- self ._requirements += requirement
2082+ self ._stage_main_branch_snapshot_ref (snapshot_id = snapshot .snapshot_id )
20842083 return self
20852084
20862085 def set_current_snapshot (self , snapshot_id : Optional [int ] = None , ref_name : Optional [str ] = None ) -> ManageSnapshots :
@@ -2099,17 +2098,15 @@ def set_current_snapshot(self, snapshot_id: Optional[int] = None, ref_name: Opti
20992098 raise ValidationError ("Either snapshot_id or ref must be provided" )
21002099 else :
21012100 if snapshot_id is None :
2102- target_snapshot_id = self ._transaction .table_metadata .refs [ref_name ].snapshot_id # type:ignore
2101+ if ref_name not in self ._transaction .table_metadata .refs :
2102+ raise ValidationError (f"Cannot set snapshot current to unknown ref { ref_name } " )
2103+ target_snapshot_id = self ._transaction .table_metadata .refs [ref_name ].snapshot_id
21032104 else :
21042105 target_snapshot_id = snapshot_id
21052106 if (snapshot := self ._transaction ._table .snapshot_by_id (target_snapshot_id )) is None :
21062107 raise ValidationError (f"Cannot set snapshot current with snapshot id: { snapshot_id } or ref_name: { ref_name } " )
21072108
2108- update , requirement = self ._transaction ._set_ref_snapshot (
2109- snapshot_id = snapshot .snapshot_id , ref_name = "main" , type = "branch"
2110- )
2111- self ._updates += update
2112- self ._requirements += requirement
2109+ self ._stage_main_branch_snapshot_ref (snapshot_id = snapshot .snapshot_id )
21132110 return self
21142111
21152112
0 commit comments