diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index c1e0f61137..7c63aa79a1 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -398,7 +398,9 @@ def _build_partition_predicate(self, partition_records: Set[Record]) -> BooleanE
                 expr = Or(expr, match_partition_expression)
         return expr
 
-    def _append_snapshot_producer(self, snapshot_properties: Dict[str, str], branch: Optional[str]) -> _FastAppendFiles:
+    def _append_snapshot_producer(
+        self, snapshot_properties: Dict[str, str], branch: Optional[str] = MAIN_BRANCH
+    ) -> _FastAppendFiles:
         """Determine the append type based on table properties.
 
         Args:
@@ -439,9 +441,6 @@ def update_snapshot(
         Returns:
             A new UpdateSnapshot
         """
-        if branch is None:
-            branch = MAIN_BRANCH
-
         return UpdateSnapshot(self, io=self._table.io, branch=branch, snapshot_properties=snapshot_properties)
 
     def update_statistics(self) -> UpdateStatistics:
@@ -453,7 +452,7 @@ def update_statistics(self) -> UpdateStatistics:
         """
         return UpdateStatistics(transaction=self)
 
-    def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = None) -> None:
+    def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = MAIN_BRANCH) -> None:
         """
         Shorthand API for appending a PyArrow table to a table transaction.
 
@@ -492,7 +491,7 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT,
                 append_files.append_data_file(data_file)
 
     def dynamic_partition_overwrite(
-        self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = None
+        self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = MAIN_BRANCH
     ) -> None:
         """
         Shorthand for overwriting existing partitions with a PyArrow table.
@@ -559,7 +558,7 @@ def overwrite(
         overwrite_filter: Union[BooleanExpression, str] = ALWAYS_TRUE,
         snapshot_properties: Dict[str, str] = EMPTY_DICT,
         case_sensitive: bool = True,
-        branch: Optional[str] = None,
+        branch: Optional[str] = MAIN_BRANCH,
     ) -> None:
         """
         Shorthand for adding a table overwrite with a PyArrow table to the transaction.
@@ -619,7 +618,7 @@ def delete(
         delete_filter: Union[str, BooleanExpression],
         snapshot_properties: Dict[str, str] = EMPTY_DICT,
         case_sensitive: bool = True,
-        branch: Optional[str] = None,
+        branch: Optional[str] = MAIN_BRANCH,
     ) -> None:
         """
         Shorthand for deleting record from a table.
@@ -722,7 +721,7 @@ def upsert(
         when_matched_update_all: bool = True,
         when_not_matched_insert_all: bool = True,
         case_sensitive: bool = True,
-        branch: Optional[str] = None,
+        branch: Optional[str] = MAIN_BRANCH,
     ) -> UpsertResult:
         """Shorthand API for performing an upsert to an iceberg table.
 
@@ -807,7 +806,7 @@ def upsert(
             case_sensitive=case_sensitive,
         )
 
-        if branch is not None:
+        if branch in self.table_metadata.refs:
             matched_iceberg_record_batches_scan = matched_iceberg_record_batches_scan.use_ref(branch)
 
         matched_iceberg_record_batches = matched_iceberg_record_batches_scan.to_arrow_batch_reader()
@@ -1303,7 +1302,7 @@ def upsert(
         when_matched_update_all: bool = True,
         when_not_matched_insert_all: bool = True,
         case_sensitive: bool = True,
-        branch: Optional[str] = None,
+        branch: Optional[str] = MAIN_BRANCH,
     ) -> UpsertResult:
         """Shorthand API for performing an upsert to an iceberg table.
@@ -1350,7 +1349,7 @@ def upsert(
                 branch=branch,
             )
 
-    def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = None) -> None:
+    def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = MAIN_BRANCH) -> None:
         """
         Shorthand API for appending a PyArrow table to the table.
 
@@ -1363,7 +1362,7 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT,
             tx.append(df=df, snapshot_properties=snapshot_properties, branch=branch)
 
     def dynamic_partition_overwrite(
-        self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = None
+        self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = MAIN_BRANCH
     ) -> None:
         """Shorthand for dynamic overwriting the table with a PyArrow table.
 
@@ -1382,7 +1381,7 @@ def overwrite(
         overwrite_filter: Union[BooleanExpression, str] = ALWAYS_TRUE,
         snapshot_properties: Dict[str, str] = EMPTY_DICT,
         case_sensitive: bool = True,
-        branch: Optional[str] = None,
+        branch: Optional[str] = MAIN_BRANCH,
     ) -> None:
         """
         Shorthand for overwriting the table with a PyArrow table.
@@ -1415,7 +1414,7 @@ def delete(
         delete_filter: Union[BooleanExpression, str] = ALWAYS_TRUE,
         snapshot_properties: Dict[str, str] = EMPTY_DICT,
         case_sensitive: bool = True,
-        branch: Optional[str] = None,
+        branch: Optional[str] = MAIN_BRANCH,
     ) -> None:
         """
         Shorthand for deleting rows from the table.
diff --git a/pyiceberg/table/metadata.py b/pyiceberg/table/metadata.py
index 9c2ae29cdd..9ab29815e9 100644
--- a/pyiceberg/table/metadata.py
+++ b/pyiceberg/table/metadata.py
@@ -295,8 +295,10 @@ def new_snapshot_id(self) -> int:
 
         return snapshot_id
 
-    def snapshot_by_name(self, name: str) -> Optional[Snapshot]:
+    def snapshot_by_name(self, name: Optional[str]) -> Optional[Snapshot]:
         """Return the snapshot referenced by the given name or null if no such reference exists."""
+        if name is None:
+            name = MAIN_BRANCH
         if ref := self.refs.get(name):
             return self.snapshot_by_id(ref.snapshot_id)
         return None
diff --git a/pyiceberg/table/update/snapshot.py b/pyiceberg/table/update/snapshot.py
index 2ed03ced73..42d7a9c2b7 100644
--- a/pyiceberg/table/update/snapshot.py
+++ b/pyiceberg/table/update/snapshot.py
@@ -110,7 +110,7 @@ class _SnapshotProducer(UpdateTableMetadata[U], Generic[U]):
     _manifest_num_counter: itertools.count[int]
     _deleted_data_files: Set[DataFile]
     _compression: AvroCompressionCodec
-    _target_branch = MAIN_BRANCH
+    _target_branch: Optional[str]
 
     def __init__(
         self,
@@ -119,7 +119,7 @@ def __init__(
         io: FileIO,
         commit_uuid: Optional[uuid.UUID] = None,
         snapshot_properties: Dict[str, str] = EMPTY_DICT,
-        branch: str = MAIN_BRANCH,
+        branch: Optional[str] = MAIN_BRANCH,
     ) -> None:
         super().__init__(transaction)
         self.commit_uuid = commit_uuid or uuid.uuid4()
@@ -140,14 +140,13 @@ def __init__(
             snapshot.snapshot_id if (snapshot := self._transaction.table_metadata.snapshot_by_name(self._target_branch)) else None
         )
 
-    def _validate_target_branch(self, branch: str) -> str:
-        # Default is already set to MAIN_BRANCH. So branch name can't be None.
-        if branch is None:
-            raise ValueError("Invalid branch name: null")
-        if branch in self._transaction.table_metadata.refs:
-            ref = self._transaction.table_metadata.refs[branch]
-            if ref.snapshot_ref_type != SnapshotRefType.BRANCH:
-                raise ValueError(f"{branch} is a tag, not a branch. Tags cannot be targets for producing snapshots")
+    def _validate_target_branch(self, branch: Optional[str]) -> Optional[str]:
+        # If branch is None, the write is staged as a snapshot without updating any branch ref
+        if branch is not None:
+            if branch in self._transaction.table_metadata.refs:
+                ref = self._transaction.table_metadata.refs[branch]
+                if ref.snapshot_ref_type != SnapshotRefType.BRANCH:
+                    raise ValueError(f"{branch} is a tag, not a branch. Tags cannot be targets for producing snapshots")
         return branch
 
     def append_data_file(self, data_file: DataFile) -> _SnapshotProducer[U]:
@@ -294,25 +293,33 @@ def _commit(self) -> UpdatesAndRequirements:
             schema_id=self._transaction.table_metadata.current_schema_id,
         )
 
-        return (
-            (
-                AddSnapshotUpdate(snapshot=snapshot),
-                SetSnapshotRefUpdate(
-                    snapshot_id=self._snapshot_id,
-                    parent_snapshot_id=self._parent_snapshot_id,
-                    ref_name=self._target_branch,
-                    type=SnapshotRefType.BRANCH,
+        add_snapshot_update = AddSnapshotUpdate(snapshot=snapshot)
+
+        if self._target_branch is None:
+            return (
+                (add_snapshot_update,),
+                (),
+            )
+        else:
+            return (
+                (
+                    add_snapshot_update,
+                    SetSnapshotRefUpdate(
+                        snapshot_id=self._snapshot_id,
+                        parent_snapshot_id=self._parent_snapshot_id,
+                        ref_name=self._target_branch,
+                        type=SnapshotRefType.BRANCH,
+                    ),
                 ),
-            ),
-            (
-                AssertRefSnapshotId(
-                    snapshot_id=self._transaction.table_metadata.refs[self._target_branch].snapshot_id
-                    if self._target_branch in self._transaction.table_metadata.refs
-                    else None,
-                    ref=self._target_branch,
+                (
+                    AssertRefSnapshotId(
+                        snapshot_id=self._transaction.table_metadata.refs[self._target_branch].snapshot_id
+                        if self._target_branch in self._transaction.table_metadata.refs
+                        else None,
+                        ref=self._target_branch,
+                    ),
                 ),
-            ),
-        )
+            )
 
     @property
     def snapshot_id(self) -> int:
@@ -359,7 +366,7 @@ def __init__(
         operation: Operation,
         transaction: Transaction,
         io: FileIO,
-        branch: str,
+        branch: Optional[str] = MAIN_BRANCH,
         commit_uuid: Optional[uuid.UUID] = None,
         snapshot_properties: Dict[str, str] = EMPTY_DICT,
     ):
@@ -530,7 +537,7 @@ def __init__(
        operation: Operation,
         transaction: Transaction,
         io: FileIO,
-        branch: str,
+        branch: Optional[str] = MAIN_BRANCH,
         commit_uuid: Optional[uuid.UUID] = None,
         snapshot_properties: Dict[str, str] = EMPTY_DICT,
     ) -> None:
@@ -651,14 +658,14 @@ def _get_entries(manifest: ManifestFile) -> List[ManifestEntry]:
 class UpdateSnapshot:
     _transaction: Transaction
     _io: FileIO
-    _branch: str
+    _branch: Optional[str]
     _snapshot_properties: Dict[str, str]
 
     def __init__(
         self,
         transaction: Transaction,
         io: FileIO,
-        branch: str,
+        branch: Optional[str] = MAIN_BRANCH,
         snapshot_properties: Dict[str, str] = EMPTY_DICT,
     ) -> None:
         self._transaction = transaction
diff --git a/tests/integration/test_writes/test_partitioned_writes.py b/tests/integration/test_writes/test_partitioned_writes.py
index 4b6c6a4d7b..1913f7beb7 100644
--- a/tests/integration/test_writes/test_partitioned_writes.py
+++ b/tests/integration/test_writes/test_partitioned_writes.py
@@ -1151,3 +1151,61 @@ def test_append_multiple_partitions(
         """
     )
     assert files_df.count() == 6
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_stage_only_dynamic_partition_overwrite_files(
+    spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int
+) -> None:
+    identifier = f"default.test_stage_only_dynamic_partition_overwrite_files_v{format_version}"
+    try:
+        session_catalog.drop_table(identifier=identifier)
+    except NoSuchTableError:
+        pass
+    tbl = session_catalog.create_table(
+        identifier=identifier,
+        schema=TABLE_SCHEMA,
+        partition_spec=PartitionSpec(
+            PartitionField(source_id=1, field_id=1001, transform=IdentityTransform(), name="bool"),
+            PartitionField(source_id=4, field_id=1002, transform=IdentityTransform(), name="int"),
+        ),
+        properties={"format-version": str(format_version)},
+    )
+
+    tbl.append(arrow_table_with_null)
+    current_snapshot = tbl.metadata.current_snapshot_id
+    assert current_snapshot is not None
+
+    original_count = len(tbl.scan().to_arrow())
+    assert original_count == 3
+
+    # stage the dynamic partition overwrite; the main branch is not updated
+    tbl.dynamic_partition_overwrite(arrow_table_with_null.slice(0, 1), branch=None)
+
+    assert current_snapshot == tbl.metadata.current_snapshot_id
+    assert len(tbl.scan().to_arrow()) == original_count
+    snapshots = tbl.snapshots()
+    # dynamic partition overwrite creates 2 snapshots: one delete and one append
+    assert len(snapshots) == 3
+
+    # Write to main branch
+    tbl.append(arrow_table_with_null)
+
+    # Main ref has changed
+    assert current_snapshot != tbl.metadata.current_snapshot_id
+    assert len(tbl.scan().to_arrow()) == 6
+    snapshots = tbl.snapshots()
+    assert len(snapshots) == 4
+
+    rows = spark.sql(
+        f"""
+        SELECT operation, parent_id, snapshot_id
+        FROM {identifier}.snapshots
+        ORDER BY committed_at ASC
+        """
+    ).collect()
+    operations = [row.operation for row in rows]
+    parent_snapshot_id = [row.parent_id for row in rows]
+    assert operations == ["append", "delete", "append", "append"]
+    assert parent_snapshot_id == [None, current_snapshot, current_snapshot, current_snapshot]
diff --git a/tests/integration/test_writes/test_writes.py b/tests/integration/test_writes/test_writes.py
index bda50bd13e..50c7007337 100644
--- a/tests/integration/test_writes/test_writes.py
+++ b/tests/integration/test_writes/test_writes.py
@@ -2271,3 +2271,149 @@ def test_nanosecond_support_on_catalog(
     _create_table(
         session_catalog, identifier, {"format-version": "2"}, schema=arrow_table_schema_with_all_timestamp_precisions
     )
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_stage_only_delete(
+    spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int
+) -> None:
+    identifier = f"default.test_stage_only_delete_files_v{format_version}"
+    iceberg_spec = PartitionSpec(
+        *[PartitionField(source_id=4, field_id=1001, transform=IdentityTransform(), name="integer_partition")]
+    )
+    tbl = _create_table(
+        session_catalog, identifier, {"format-version": str(format_version)}, [arrow_table_with_null], iceberg_spec
+    )
+
+    current_snapshot = tbl.metadata.current_snapshot_id
+    assert current_snapshot is not None
+
+    original_count = len(tbl.scan().to_arrow())
+    assert original_count == 3
+
+    tbl.delete("int = 9", branch=None)
+
+    # a new delete snapshot is added
+    snapshots = tbl.snapshots()
+    assert len(snapshots) == 2
+    # snapshot main ref has not changed
+    assert current_snapshot == tbl.metadata.current_snapshot_id
+    assert len(tbl.scan().to_arrow()) == original_count
+
+    # Write to main branch
+    tbl.append(arrow_table_with_null)
+
+    # Main ref has changed
+    assert current_snapshot != tbl.metadata.current_snapshot_id
+    assert len(tbl.scan().to_arrow()) == 6
+    snapshots = tbl.snapshots()
+    assert len(snapshots) == 3
+
+    rows = spark.sql(
+        f"""
+        SELECT operation, parent_id
+        FROM {identifier}.snapshots
+        ORDER BY committed_at ASC
+        """
+    ).collect()
+    operations = [row.operation for row in rows]
+    parent_snapshot_id = [row.parent_id for row in rows]
+    assert operations == ["append", "delete", "append"]
+    # both subsequent parent ids should be the first snapshot id
+    assert parent_snapshot_id == [None, current_snapshot, current_snapshot]
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_stage_only_append(
+    spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int
+) -> None:
+    identifier = f"default.test_stage_only_fast_append_files_v{format_version}"
+    tbl = _create_table(session_catalog, identifier, {"format-version": str(format_version)}, [arrow_table_with_null])
+
+    current_snapshot = tbl.metadata.current_snapshot_id
+    assert current_snapshot is not None
+
+    original_count = len(tbl.scan().to_arrow())
+    assert original_count == 3
+
+    # stage the append; no branch ref is updated
+    tbl.append(arrow_table_with_null, branch=None)
+
+    # Main ref has not changed and data is not yet appended
+    assert current_snapshot == tbl.metadata.current_snapshot_id
+    assert len(tbl.scan().to_arrow()) == original_count
+    # There should be a new staged snapshot
+    snapshots = tbl.snapshots()
+    assert len(snapshots) == 2
+
+    # Write to main branch
+    tbl.append(arrow_table_with_null)
+
+    # Main ref has changed
+    assert current_snapshot != tbl.metadata.current_snapshot_id
+    assert len(tbl.scan().to_arrow()) == 6
+    snapshots = tbl.snapshots()
+    assert len(snapshots) == 3
+
+    rows = spark.sql(
+        f"""
+        SELECT operation, parent_id
+        FROM {identifier}.snapshots
+        ORDER BY committed_at ASC
+        """
+    ).collect()
+    operations = [row.operation for row in rows]
+    parent_snapshot_id = [row.parent_id for row in rows]
+    assert operations == ["append", "append", "append"]
+    # both subsequent parent ids should be the first snapshot id
+    assert parent_snapshot_id == [None, current_snapshot, current_snapshot]
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_stage_only_overwrite_files(
+    spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int
+) -> None:
+    identifier = f"default.test_stage_only_overwrite_files_v{format_version}"
+    tbl = _create_table(session_catalog, identifier, {"format-version": str(format_version)}, [arrow_table_with_null])
+    first_snapshot = tbl.metadata.current_snapshot_id
+
+    # duplicate data with a new insert
+    tbl.append(arrow_table_with_null)
+
+    second_snapshot = tbl.metadata.current_snapshot_id
+    assert second_snapshot is not None
+    original_count = len(tbl.scan().to_arrow())
+    assert original_count == 6
+
+    # stage the overwrite; the main branch is not updated
+    tbl.overwrite(arrow_table_with_null, branch=None)
+    assert second_snapshot == tbl.metadata.current_snapshot_id
+    assert len(tbl.scan().to_arrow()) == original_count
+    snapshots = tbl.snapshots()
+    # overwrite will create 2 snapshots
+    assert len(snapshots) == 4
+
+    # Write to main branch again
+    tbl.append(arrow_table_with_null)
+
+    # Main ref has changed
+    assert second_snapshot != tbl.metadata.current_snapshot_id
+    assert len(tbl.scan().to_arrow()) == 9
+    snapshots = tbl.snapshots()
+    assert len(snapshots) == 5
+
+    rows = spark.sql(
+        f"""
+        SELECT operation, parent_id, snapshot_id
+        FROM {identifier}.snapshots
+        ORDER BY committed_at ASC
+        """
+    ).collect()
+    operations = [row.operation for row in rows]
+    parent_snapshot_id = [row.parent_id for row in rows]
+    assert operations == ["append", "append", "delete", "append", "append"]
+
+    assert parent_snapshot_id == [None, first_snapshot, second_snapshot, second_snapshot, second_snapshot]
diff --git a/tests/table/test_upsert.py b/tests/table/test_upsert.py
index cc6e008b1e..891d4bbac7 100644
--- a/tests/table/test_upsert.py
+++ b/tests/table/test_upsert.py
@@ -770,3 +770,67 @@ def test_transaction_multiple_upserts(catalog: Catalog) -> None:
         {"id": 1, "name": "Alicia"},
         {"id": 2, "name": "Bob"},
     ]
+
+
+def test_stage_only_upsert(catalog: Catalog) -> None:
+    identifier = "default.test_stage_only_upsert"
+    _drop_table(catalog, identifier)
+
+    schema = Schema(
+        NestedField(1, "city", StringType(), required=True),
+        NestedField(2, "inhabitants", IntegerType(), required=True),
+        # Mark City as the identifier field, also known as the primary-key
+        identifier_field_ids=[1],
+    )
+
+    tbl = catalog.create_table(identifier, schema=schema)
+
+    arrow_schema = pa.schema(
+        [
+            pa.field("city", pa.string(), nullable=False),
+            pa.field("inhabitants", pa.int32(), nullable=False),
+        ]
+    )
+
+    # Write some data
+    df = pa.Table.from_pylist(
+        [
+            {"city": "Amsterdam", "inhabitants": 921402},
+            {"city": "San Francisco", "inhabitants": 808988},
+            {"city": "Drachten", "inhabitants": 45019},
+            {"city": "Paris", "inhabitants": 2103000},
+        ],
+        schema=arrow_schema,
+    )
+
+    tbl.append(df.slice(0, 1))
+    current_snapshot = tbl.metadata.current_snapshot_id
+    assert current_snapshot is not None
+
+    original_count = len(tbl.scan().to_arrow())
+    assert original_count == 1
+
+    # stage the upsert; the main branch is not updated
+    upd = tbl.upsert(df, branch=None)
+    assert upd.rows_updated == 0
+    assert upd.rows_inserted == 3
+
+    assert current_snapshot == tbl.metadata.current_snapshot_id
+    assert len(tbl.scan().to_arrow()) == original_count
+    snapshots = tbl.snapshots()
+    assert len(snapshots) == 2
+
+    # Write to main ref
+    tbl.append(df.slice(1, 1))
+    # Main ref has changed
+    assert current_snapshot != tbl.metadata.current_snapshot_id
+    assert len(tbl.scan().to_arrow()) == 2
+    snapshots = tbl.snapshots()
+    assert len(snapshots) == 3
+
+    sorted_snapshots = sorted(tbl.snapshots(), key=lambda s: s.timestamp_ms)
+    operations = [snapshot.summary.operation.value if snapshot.summary else None for snapshot in sorted_snapshots]
+    parent_snapshot_id = [snapshot.parent_snapshot_id for snapshot in sorted_snapshots]
+    assert operations == ["append", "append", "append"]
+    # both subsequent parent ids should be the first snapshot id
+    assert parent_snapshot_id == [None, current_snapshot, current_snapshot]
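
Usage note (not part of the diff): a minimal sketch of how the stage-only write path introduced above could be exercised. The catalog name, table name, and sample rows are illustrative assumptions, not anything defined in this change; the only behavior relied on is that passing branch=None adds a snapshot without moving any branch ref, while the default (MAIN_BRANCH) keeps the existing commit-and-advance behavior.

    import pyarrow as pa
    from pyiceberg.catalog import load_catalog

    catalog = load_catalog("default")           # assumed catalog name
    tbl = catalog.load_table("default.cities")  # assumed existing table with a matching schema

    df = pa.Table.from_pylist([{"city": "Amsterdam", "inhabitants": 921402}])

    # Default behaviour: the snapshot is committed and the "main" ref advances.
    tbl.append(df)

    # Stage-only write: branch=None records the snapshot in table metadata
    # without updating any branch ref, so readers of "main" do not see it yet.
    before = tbl.metadata.current_snapshot_id
    tbl.append(df, branch=None)
    assert tbl.metadata.current_snapshot_id == before
    assert len(tbl.snapshots()) >= 2  # the staged snapshot is still recorded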