Skip to content

Commit 2a5b08e

Browse files
committed
feat: check whether table ops conflict when committing
1 parent 8adf246 commit 2a5b08e

File tree

3 files changed

+62
-0
lines changed

3 files changed

+62
-0
lines changed

pyiceberg/table/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -918,6 +918,15 @@ def refresh(self) -> Table:
918918
self.metadata_location = fresh.metadata_location
919919
return self
920920

921+
def check_and_refresh_table(self) -> Optional[Table]:
    """Reload the table from the catalog and adopt its state if the snapshot moved.

    The table is loaded again through ``self.catalog`` using
    ``self._identifier``. Only ``current_snapshot_id`` is compared; when it
    differs, this instance takes over the reloaded ``metadata``, ``io`` and
    ``metadata_location``.

    Returns:
        The freshly loaded table when the current snapshot id changed,
        otherwise ``None`` (this instance is then left untouched).
    """
    latest = self.catalog.load_table(self._identifier)

    # Same snapshot on both sides: nothing to adopt.
    if self.metadata.current_snapshot_id == latest.metadata.current_snapshot_id:
        return None

    self.metadata = latest.metadata
    self.io = latest.io
    self.metadata_location = latest.metadata_location
    return latest
929+
921930
def name(self) -> Identifier:
922931
"""Return the identifier of this table.
923932

pyiceberg/table/update/snapshot.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,16 @@ def _summary(self, snapshot_properties: Dict[str, str] = EMPTY_DICT) -> Summary:
239239
truncate_full_table=self._operation == Operation.OVERWRITE,
240240
)
241241

242+
@abstractmethod
243+
def _validate(self) -> None:
244+
pass
245+
242246
def _commit(self) -> UpdatesAndRequirements:
247+
from pyiceberg.table import StagedTable
248+
249+
if not isinstance(self._transaction._table, StagedTable):
250+
self._validate()
251+
243252
new_manifests = self._manifests()
244253
next_sequence_number = self._transaction.table_metadata.next_sequence_number()
245254

@@ -435,6 +444,9 @@ def _existing_manifests(self) -> List[ManifestFile]:
435444
def _deleted_entries(self) -> List[ManifestEntry]:
436445
return self._compute_deletes[1]
437446

447+
def _validate(self) -> None:
448+
return
449+
438450
@property
439451
def rewrites_needed(self) -> bool:
440452
"""Indicate if data files need to be rewritten."""
@@ -474,6 +486,15 @@ def _deleted_entries(self) -> List[ManifestEntry]:
474486
"""
475487
return []
476488

489+
def _validate(self) -> None:
490+
refresh_table = self._transaction._table.check_and_refresh_table()
491+
if refresh_table is None:
492+
return
493+
current_snapshot = refresh_table.metadata.current_snapshot()
494+
if current_snapshot is not None and current_snapshot.snapshot_id != self._parent_snapshot_id:
495+
self._parent_snapshot_id = current_snapshot.snapshot_id
496+
self._transaction.table_metadata = refresh_table.metadata
497+
477498

478499
class _MergeAppendFiles(_FastAppendFiles):
479500
_target_size_bytes: int
@@ -602,6 +623,9 @@ def _get_entries(manifest: ManifestFile) -> List[ManifestEntry]:
602623
else:
603624
return []
604625

626+
def _validate(self) -> None:
627+
return
628+
605629

606630
class UpdateSnapshot:
607631
_transaction: Transaction

tests/integration/test_add_files.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -898,3 +898,32 @@ def test_add_files_that_referenced_by_current_snapshot_with_check_duplicate_file
898898
with pytest.raises(ValueError) as exc_info:
899899
tbl.add_files(file_paths=[existing_files_in_table], check_duplicate_files=True)
900900
assert f"Cannot add files that are already referenced by table, files: {existing_files_in_table}" in str(exc_info.value)
901+
902+
903+
@pytest.mark.integration
@pytest.mark.parametrize("format_version", [1, 2])
def test_conflict_delete_append(
    spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int
) -> None:
    """Delete through one table handle, then append through a second, older handle.

    The second handle is loaded before the delete is committed, so its
    snapshot is out of date when it appends. The test makes no assertions; it
    only checks that neither commit raises.
    """
    identifier = "default.test_conflict"
    first_handle = _create_table(session_catalog, identifier, format_version, schema=arrow_table_with_null.schema)
    first_handle.append(arrow_table_with_null)
    second_handle = session_catalog.load_table(identifier)

    # A delete followed by an append from the stale handle is allowed.
    first_handle.delete("string == 'z'")
    second_handle.append(arrow_table_with_null)
916+
917+
918+
@pytest.mark.integration
@pytest.mark.parametrize("format_version", [1, 2])
def test_conflict_append_append(
    spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int
) -> None:
    """Append through one table handle, then append through a second, older handle.

    The second handle is loaded before the first handle's second append, so
    its snapshot is out of date when it appends. The test makes no
    assertions; it only checks that neither commit raises.
    """
    identifier = "default.test_conflict"
    first_handle = _create_table(session_catalog, identifier, format_version, schema=arrow_table_with_null.schema)
    first_handle.append(arrow_table_with_null)
    second_handle = session_catalog.load_table(identifier)

    first_handle.append(arrow_table_with_null)
    second_handle.append(arrow_table_with_null)

0 commit comments

Comments
 (0)