Commit 1958e5c

Author: Yingjian Wu
implement stageOnly Commit

1 parent 904c0b7 · commit 1958e5c

File tree: 3 files changed, +216 −21 lines

pyiceberg/table/__init__.py

Lines changed: 6 additions & 2 deletions
@@ -432,7 +432,9 @@ def update_schema(self, allow_incompatible_changes: bool = False, case_sensitive
             name_mapping=self.table_metadata.name_mapping(),
         )
 
-    def update_snapshot(self, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = None) -> UpdateSnapshot:
+    def update_snapshot(
+        self, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = None, stage_only: bool = False
+    ) -> UpdateSnapshot:
         """Create a new UpdateSnapshot to produce a new snapshot for the table.
 
         Returns:
@@ -441,7 +443,9 @@ def update_snapshot(self, snapshot_properties: Dict[str, str] = EMPTY_DICT, bran
         if branch is None:
             branch = MAIN_BRANCH
 
-        return UpdateSnapshot(self, io=self._table.io, branch=branch, snapshot_properties=snapshot_properties)
+        return UpdateSnapshot(
+            self, io=self._table.io, branch=branch, snapshot_properties=snapshot_properties, stage_only=stage_only
+        )
 
     def update_statistics(self) -> UpdateStatistics:
         """

pyiceberg/table/update/snapshot.py

Lines changed: 39 additions & 19 deletions
@@ -109,6 +109,7 @@ class _SnapshotProducer(UpdateTableMetadata[U], Generic[U]):
     _deleted_data_files: Set[DataFile]
     _compression: AvroCompressionCodec
     _target_branch = MAIN_BRANCH
+    _stage_only = False
 
     def __init__(
         self,
@@ -118,6 +119,7 @@ def __init__(
         commit_uuid: Optional[uuid.UUID] = None,
         snapshot_properties: Dict[str, str] = EMPTY_DICT,
         branch: str = MAIN_BRANCH,
+        stage_only: bool = False,
     ) -> None:
         super().__init__(transaction)
         self.commit_uuid = commit_uuid or uuid.uuid4()
@@ -137,6 +139,7 @@ def __init__(
         self._parent_snapshot_id = (
             snapshot.snapshot_id if (snapshot := self._transaction.table_metadata.snapshot_by_name(self._target_branch)) else None
         )
+        self._stage_only = stage_only
 
     def _validate_target_branch(self, branch: str) -> str:
         # Default is already set to MAIN_BRANCH. So branch name can't be None.
@@ -292,25 +295,33 @@ def _commit(self) -> UpdatesAndRequirements:
             schema_id=self._transaction.table_metadata.current_schema_id,
         )
 
-        return (
-            (
-                AddSnapshotUpdate(snapshot=snapshot),
-                SetSnapshotRefUpdate(
-                    snapshot_id=self._snapshot_id,
-                    parent_snapshot_id=self._parent_snapshot_id,
-                    ref_name=self._target_branch,
-                    type=SnapshotRefType.BRANCH,
+        add_snapshot_update = AddSnapshotUpdate(snapshot=snapshot)
+
+        if self._stage_only:
+            return (
+                (add_snapshot_update,),
+                (),
+            )
+        else:
+            return (
+                (
+                    add_snapshot_update,
+                    SetSnapshotRefUpdate(
+                        snapshot_id=self._snapshot_id,
+                        parent_snapshot_id=self._parent_snapshot_id,
+                        ref_name=self._target_branch,
+                        type=SnapshotRefType.BRANCH,
+                    ),
                 ),
-            ),
-            (
-                AssertRefSnapshotId(
-                    snapshot_id=self._transaction.table_metadata.refs[self._target_branch].snapshot_id
-                    if self._target_branch in self._transaction.table_metadata.refs
-                    else None,
-                    ref=self._target_branch,
+                (
+                    AssertRefSnapshotId(
+                        snapshot_id=self._transaction.table_metadata.refs[self._target_branch].snapshot_id
+                        if self._target_branch in self._transaction.table_metadata.refs
+                        else None,
+                        ref=self._target_branch,
+                    ),
                 ),
-            ),
-        )
+            )
 
     @property
     def snapshot_id(self) -> int:
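The core of the change is the branch above in _SnapshotProducer._commit: when staging, only the AddSnapshotUpdate is emitted and the requirements tuple is empty, because no ref is moved and there is therefore no branch head to assert against concurrent commits. A condensed sketch of the two return shapes, with hypothetical stand-ins for the real TableUpdate/TableRequirement objects:

from typing import Any, Tuple

def commit_shapes(
    stage_only: bool, add_snapshot: Any, set_ref: Any, assert_ref: Any
) -> Tuple[Tuple[Any, ...], Tuple[Any, ...]]:
    if stage_only:
        # Stage-only: record the snapshot; no ref update, nothing to assert.
        return ((add_snapshot,), ())
    # Normal commit: record the snapshot and move the branch ref, guarded by
    # an assertion that the branch head has not moved underneath us.
    return ((add_snapshot, set_ref), (assert_ref,))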
@@ -360,8 +371,9 @@ def __init__(
         branch: str,
         commit_uuid: Optional[uuid.UUID] = None,
         snapshot_properties: Dict[str, str] = EMPTY_DICT,
+        stage_only: bool = False,
     ):
-        super().__init__(operation, transaction, io, commit_uuid, snapshot_properties, branch)
+        super().__init__(operation, transaction, io, commit_uuid, snapshot_properties, branch, stage_only)
         self._predicate = AlwaysFalse()
         self._case_sensitive = True
 
@@ -530,10 +542,11 @@ def __init__(
         branch: str,
         commit_uuid: Optional[uuid.UUID] = None,
         snapshot_properties: Dict[str, str] = EMPTY_DICT,
+        stage_only: bool = False,
     ) -> None:
         from pyiceberg.table import TableProperties
 
-        super().__init__(operation, transaction, io, commit_uuid, snapshot_properties, branch)
+        super().__init__(operation, transaction, io, commit_uuid, snapshot_properties, branch, stage_only)
         self._target_size_bytes = property_as_int(
             self._transaction.table_metadata.properties,
             TableProperties.MANIFEST_TARGET_SIZE_BYTES,
@@ -649,19 +662,22 @@ class UpdateSnapshot:
     _transaction: Transaction
     _io: FileIO
     _branch: str
+    _stage_only: bool
     _snapshot_properties: Dict[str, str]
 
     def __init__(
         self,
         transaction: Transaction,
         io: FileIO,
         branch: str,
+        stage_only: bool = False,
         snapshot_properties: Dict[str, str] = EMPTY_DICT,
     ) -> None:
         self._transaction = transaction
         self._io = io
         self._snapshot_properties = snapshot_properties
         self._branch = branch
+        self._stage_only = stage_only
 
     def fast_append(self) -> _FastAppendFiles:
         return _FastAppendFiles(
@@ -670,6 +686,7 @@ def fast_append(self) -> _FastAppendFiles:
             io=self._io,
             branch=self._branch,
             snapshot_properties=self._snapshot_properties,
+            stage_only=self._stage_only,
         )
 
     def merge_append(self) -> _MergeAppendFiles:
@@ -679,6 +696,7 @@ def merge_append(self) -> _MergeAppendFiles:
             io=self._io,
             branch=self._branch,
             snapshot_properties=self._snapshot_properties,
+            stage_only=self._stage_only,
         )
 
     def overwrite(self, commit_uuid: Optional[uuid.UUID] = None) -> _OverwriteFiles:
@@ -691,6 +709,7 @@ def overwrite(self, commit_uuid: Optional[uuid.UUID] = None) -> _OverwriteFiles:
             io=self._io,
             branch=self._branch,
             snapshot_properties=self._snapshot_properties,
+            stage_only=self._stage_only,
         )
 
     def delete(self) -> _DeleteFiles:
@@ -700,6 +719,7 @@ def delete(self) -> _DeleteFiles:
             io=self._io,
             branch=self._branch,
             snapshot_properties=self._snapshot_properties,
+            stage_only=self._stage_only,
        )
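Because UpdateSnapshot forwards the flag to all four producers, staging works uniformly for fast_append, merge_append, overwrite, and delete. For example, a staged delete (a sketch with a hypothetical tbl, following the integration test below):

from pyiceberg.expressions import EqualTo

with tbl.transaction() as txn:
    # The staged delete snapshot is written, but main still serves the old data.
    with txn.update_snapshot(stage_only=True).delete() as delete:
        delete.delete_by_predicate(EqualTo("int", 9))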

tests/integration/test_writes/test_writes.py

Lines changed: 171 additions & 0 deletions
@@ -2098,3 +2098,174 @@ def test_branch_py_write_spark_read(session_catalog: Catalog, spark: SparkSessio
     )
     assert main_df.count() == 3
     assert branch_df.count() == 2
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_stage_only_delete(
+    spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int
+) -> None:
+    identifier = f"default.test_stage_only_delete_files_v{format_version}"
+    iceberg_spec = PartitionSpec(
+        *[PartitionField(source_id=4, field_id=1001, transform=IdentityTransform(), name="integer_partition")]
+    )
+    tbl = _create_table(
+        session_catalog, identifier, {"format-version": str(format_version)}, [arrow_table_with_null], iceberg_spec
+    )
+
+    current_snapshot = tbl.metadata.current_snapshot_id
+    assert current_snapshot is not None
+
+    original_count = len(tbl.scan().to_arrow())
+    assert original_count == 3
+
+    files_to_delete = []
+    for file_task in tbl.scan().plan_files():
+        files_to_delete.append(file_task.file)
+    assert len(files_to_delete) > 0
+
+    with tbl.transaction() as txn:
+        with txn.update_snapshot(stage_only=True).delete() as delete:
+            delete.delete_by_predicate(EqualTo("int", 9))
+
+    # a new delete snapshot is added
+    snapshots = tbl.snapshots()
+    assert len(snapshots) == 2
+
+    rows = spark.sql(
+        f"""
+        SELECT operation, summary
+        FROM {identifier}.snapshots
+        ORDER BY committed_at ASC
+        """
+    ).collect()
+    operations = [row.operation for row in rows]
+    assert operations == ["append", "delete"]
+
+    # snapshot main ref has not changed
+    assert current_snapshot == tbl.metadata.current_snapshot_id
+    assert len(tbl.scan().to_arrow()) == original_count
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_stage_only_fast_append(
+    spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int
+) -> None:
+    identifier = f"default.test_stage_only_fast_append_files_v{format_version}"
+    tbl = _create_table(session_catalog, identifier, {"format-version": str(format_version)}, [arrow_table_with_null])
+
+    current_snapshot = tbl.metadata.current_snapshot_id
+    assert current_snapshot is not None
+
+    original_count = len(tbl.scan().to_arrow())
+    assert original_count == 3
+
+    with tbl.transaction() as txn:
+        with txn.update_snapshot(stage_only=True).fast_append() as fast_append:
+            for data_file in _dataframe_to_data_files(
+                table_metadata=txn.table_metadata, df=arrow_table_with_null, io=txn._table.io
+            ):
+                fast_append.append_data_file(data_file=data_file)
+
+    # Main ref has not changed and data is not yet appended
+    assert current_snapshot == tbl.metadata.current_snapshot_id
+    assert len(tbl.scan().to_arrow()) == original_count
+
+    # There should be a new staged snapshot
+    snapshots = tbl.snapshots()
+    assert len(snapshots) == 2
+
+    rows = spark.sql(
+        f"""
+        SELECT operation, summary
+        FROM {identifier}.snapshots
+        ORDER BY committed_at ASC
+        """
+    ).collect()
+    operations = [row.operation for row in rows]
+    assert operations == ["append", "append"]
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_stage_only_merge_append(
+    spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int
+) -> None:
+    identifier = f"default.test_stage_only_merge_append_files_v{format_version}"
+    tbl = _create_table(session_catalog, identifier, {"format-version": str(format_version)}, [arrow_table_with_null])
+
+    current_snapshot = tbl.metadata.current_snapshot_id
+    assert current_snapshot is not None
+
+    original_count = len(tbl.scan().to_arrow())
+    assert original_count == 3
+
+    with tbl.transaction() as txn:
+        with txn.update_snapshot(stage_only=True).merge_append() as merge_append:
+            for data_file in _dataframe_to_data_files(
+                table_metadata=txn.table_metadata, df=arrow_table_with_null, io=txn._table.io
+            ):
+                merge_append.append_data_file(data_file=data_file)
+
+    # Main ref has not changed and data is not yet appended
+    assert current_snapshot == tbl.metadata.current_snapshot_id
+    assert len(tbl.scan().to_arrow()) == original_count
+
+    # There should be a new staged snapshot
+    snapshots = tbl.snapshots()
+    assert len(snapshots) == 2
+
+    rows = spark.sql(
+        f"""
+        SELECT operation, summary
+        FROM {identifier}.snapshots
+        ORDER BY committed_at ASC
+        """
+    ).collect()
+    operations = [row.operation for row in rows]
+    assert operations == ["append", "append"]
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_stage_only_overwrite_files(
+    spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int
+) -> None:
+    identifier = f"default.test_stage_only_overwrite_files_v{format_version}"
+    tbl = _create_table(session_catalog, identifier, {"format-version": str(format_version)}, [arrow_table_with_null])
+
+    current_snapshot = tbl.metadata.current_snapshot_id
+    assert current_snapshot is not None
+
+    original_count = len(tbl.scan().to_arrow())
+    assert original_count == 3
+
+    files_to_delete = []
+    for file_task in tbl.scan().plan_files():
+        files_to_delete.append(file_task.file)
+    assert len(files_to_delete) > 0
+
+    with tbl.transaction() as txn:
+        with txn.update_snapshot(stage_only=True).overwrite() as overwrite:
+            for data_file in _dataframe_to_data_files(
+                table_metadata=txn.table_metadata, df=arrow_table_with_null, io=txn._table.io
+            ):
+                overwrite.append_data_file(data_file=data_file)
+            overwrite.delete_data_file(files_to_delete[0])
+
+    assert current_snapshot == tbl.metadata.current_snapshot_id
+    assert len(tbl.scan().to_arrow()) == original_count
+
+    snapshots = tbl.snapshots()
+    assert len(snapshots) == 2
+
+    rows = spark.sql(
+        f"""
+        SELECT operation, summary
+        FROM {identifier}.snapshots
+        ORDER BY committed_at ASC
+        """
+    ).collect()
+    operations = [row.operation for row in rows]
+    assert operations == ["append", "overwrite"]
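The tests verify that the staged snapshot exists while the branch ref and the visible data are unchanged. To actually read staged data, a scan can be pinned to the staged snapshot id; a sketch, assuming (not asserted by this commit) that the newest entry in tbl.snapshots() is the staged one:

staged = tbl.snapshots()[-1]  # assumption: snapshots() preserves commit order
staged_df = tbl.scan(snapshot_id=staged.snapshot_id).to_arrow()

# The main branch still resolves to the pre-staging snapshot:
assert tbl.metadata.current_snapshot_id != staged.snapshot_id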
