Skip to content

Commit 949e140

Browse files
committed
feat: validate snapshot write compatibility
1 parent 06404a5 commit 949e140

File tree

2 files changed

+81
-0
lines changed

2 files changed

+81
-0
lines changed

pyiceberg/table/update/snapshot.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@
8080
from pyiceberg.utils.bin_packing import ListPacker
8181
from pyiceberg.utils.concurrent import ExecutorFactory
8282
from pyiceberg.utils.properties import property_as_bool, property_as_int
83+
from pyiceberg.utils.snapshot import ancestors_between
8384

8485
if TYPE_CHECKING:
8586
from pyiceberg.table import Transaction
@@ -251,6 +252,12 @@ def _commit(self) -> UpdatesAndRequirements:
251252
)
252253
location_provider = self._transaction._table.location_provider()
253254
manifest_list_file_path = location_provider.new_metadata_location(file_name)
255+
256+
# get current snapshot id and starting snapshot id, and validate that there are no conflicts
257+
starting_snapshot_id = self._parent_snapshot_id
258+
current_snapshot_id = self._transaction._table.refresh().metadata.current_snapshot_id
259+
self._validate(starting_snapshot_id, current_snapshot_id)
260+
254261
with write_manifest_list(
255262
format_version=self._transaction.table_metadata.format_version,
256263
output_file=self._io.new_output(manifest_list_file_path),
@@ -279,6 +286,27 @@ def _commit(self) -> UpdatesAndRequirements:
279286
(AssertRefSnapshotId(snapshot_id=self._transaction.table_metadata.current_snapshot_id, ref="main"),),
280287
)
281288

289+
def _validate(self, starting_snapshot_id: Optional[int], current_snapshot_id: Optional[int]) -> None:
290+
# get all the snapshots between the current snapshot id and the parent id
291+
snapshots = ancestors_between(starting_snapshot_id, current_snapshot_id, self._transaction._table.metadata.snapshot_by_id)
292+
293+
# Define allowed operations for each type of operation
294+
allowed_operations = {
295+
Operation.APPEND: {Operation.APPEND, Operation.REPLACE, Operation.OVERWRITE, Operation.DELETE},
296+
Operation.REPLACE: {Operation.APPEND},
297+
Operation.OVERWRITE: set(),
298+
Operation.DELETE: set(),
299+
}
300+
301+
for snapshot in snapshots:
302+
snapshot_operation = snapshot.summary.operation
303+
304+
if snapshot_operation not in allowed_operations[self._operation]:
305+
raise ValueError(
306+
f"Operation {snapshot_operation} is not allowed when performing {self._operation}. "
307+
"Check for overlaps or conflicts."
308+
)
309+
282310
@property
283311
def snapshot_id(self) -> int:
284312
return self._snapshot_id

pyiceberg/utils/snapshot.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
from typing import Callable, Iterable, Iterator, Optional
19+
from pyiceberg.table.snapshots import Snapshot
20+
21+
22+
def ancestors_of(snapshot_id: Optional[int], lookup_fn: Callable[[int], Optional[Snapshot]]) -> Iterable[Snapshot]:
23+
def _snapshot_iterator(snapshot: Snapshot) -> Iterator[Snapshot]:
24+
next_snapshot: Optional[Snapshot] = snapshot
25+
consumed = False
26+
27+
while next_snapshot is not None:
28+
if not consumed:
29+
yield next_snapshot
30+
consumed = True
31+
32+
parent_id = next_snapshot.parent_snapshot_id
33+
if parent_id is None:
34+
break
35+
36+
next_snapshot = lookup_fn(parent_id)
37+
consumed = False
38+
39+
snapshot: Optional[Snapshot] = lookup_fn(snapshot_id)
40+
if snapshot is not None:
41+
return _snapshot_iterator(snapshot)
42+
else:
43+
return iter([])
44+
45+
def ancestors_between(starting_snapshot_id: Optional[int], current_snapshot_id: Optional[int], lookup_fn: Callable[[int], Optional[Snapshot]]) -> Iterable[Snapshot]:
46+
if starting_snapshot_id == current_snapshot_id:
47+
return iter([])
48+
49+
return ancestors_of(
50+
current_snapshot_id,
51+
lambda snapshot_id: lookup_fn(snapshot_id) if snapshot_id != starting_snapshot_id else None
52+
)
53+

0 commit comments

Comments
 (0)