diff --git a/pyiceberg/table/snapshots.py b/pyiceberg/table/snapshots.py
index 60ad7219e1..13ce52b7eb 100644
--- a/pyiceberg/table/snapshots.py
+++ b/pyiceberg/table/snapshots.py
@@ -244,6 +244,12 @@ class Snapshot(IcebergBaseModel):
     manifest_list: str = Field(alias="manifest-list", description="Location of the snapshot's manifest list file")
     summary: Optional[Summary] = Field(default=None)
     schema_id: Optional[int] = Field(alias="schema-id", default=None)
+    first_row_id: Optional[int] = Field(
+        alias="first-row-id", default=None, description="Row ID of the first row in the first data file in the first manifest"
+    )
+    added_rows: Optional[int] = Field(
+        alias="added-rows", default=None, description="The upper bound of the number of rows with assigned row IDs"
+    )
 
     def __str__(self) -> str:
         """Return the string representation of the Snapshot class."""
@@ -253,6 +259,22 @@ def __str__(self) -> str:
         result_str = f"{operation}id={self.snapshot_id}{parent_id}{schema_id}"
         return result_str
 
+    def __repr__(self) -> str:
+        """Return a constructor-style representation of the Snapshot, omitting optional fields that are unset."""
+        fields = [
+            f"snapshot_id={self.snapshot_id}",
+            f"parent_snapshot_id={self.parent_snapshot_id}",
+            f"sequence_number={self.sequence_number}",
+            f"timestamp_ms={self.timestamp_ms}",
+            f"manifest_list='{self.manifest_list}'",
+            f"summary={self.summary!r}" if self.summary is not None else None,  # same None test as the fields below
+            f"schema_id={self.schema_id}" if self.schema_id is not None else None,
+            f"first_row_id={self.first_row_id}" if self.first_row_id is not None else None,
+            f"added_rows={self.added_rows}" if self.added_rows is not None else None,
+        ]
+        filtered_fields = [field for field in fields if field is not None]
+        return f"Snapshot({', '.join(filtered_fields)})"
+
     def manifests(self, io: FileIO) -> List[ManifestFile]:
         """Return the manifests for the given snapshot."""
         return list(_manifests(io, self.manifest_list))
diff --git a/pyiceberg/table/update/__init__.py b/pyiceberg/table/update/__init__.py
index 
038b952bb3..bcbe429688 100644
--- a/pyiceberg/table/update/__init__.py
+++ b/pyiceberg/table/update/__init__.py
@@ -437,6 +437,17 @@ def _(update: AddSnapshotUpdate, base_metadata: TableMetadata, context: _TableMe
             f"Cannot add snapshot with sequence number {update.snapshot.sequence_number} "
             f"older than last sequence number {base_metadata.last_sequence_number}"
         )
+    elif base_metadata.format_version >= 3 and update.snapshot.first_row_id is None:
+        raise ValueError("Cannot add snapshot without first row id")
+    elif (
+        base_metadata.format_version >= 3
+        and update.snapshot.first_row_id is not None
+        and base_metadata.next_row_id is not None
+        and update.snapshot.first_row_id < base_metadata.next_row_id
+    ):
+        raise ValueError(
+            f"Cannot add a snapshot with first row id smaller than the table's next-row-id {update.snapshot.first_row_id} < {base_metadata.next_row_id}"
+        )
 
     context.add_update(update)
     return base_metadata.model_copy(
@@ -444,6 +455,11 @@ def _(update: AddSnapshotUpdate, base_metadata: TableMetadata, context: _TableMe
             "last_updated_ms": update.snapshot.timestamp_ms,
             "last_sequence_number": update.snapshot.sequence_number,
             "snapshots": base_metadata.snapshots + [update.snapshot],
+            "next_row_id": base_metadata.next_row_id + update.snapshot.added_rows
+            if base_metadata.format_version >= 3
+            and base_metadata.next_row_id is not None
+            and update.snapshot.added_rows is not None
+            else base_metadata.next_row_id,  # nothing to add: keep the current value instead of resetting it to None
         }
     )
 
diff --git a/pyiceberg/table/update/snapshot.py b/pyiceberg/table/update/snapshot.py
index 148dacd22f..aed7ec0449 100644
--- a/pyiceberg/table/update/snapshot.py
+++ b/pyiceberg/table/update/snapshot.py
@@ -157,6 +157,19 @@ def delete_data_file(self, data_file: DataFile) -> _SnapshotProducer[U]:
         self._deleted_data_files.add(data_file)
         return self
 
+    def _calculate_added_rows(self, manifests: List[ManifestFile]) -> int:
+        """Calculate the number of added rows from a list of manifest files."""
+        added_rows = 0
+        for manifest in manifests:
+            if 
manifest.added_snapshot_id is None or manifest.added_snapshot_id == self._snapshot_id:
+                if manifest.added_rows_count is None:
+                    raise ValueError(
+                        "Cannot determine number of added rows in snapshot because "
+                        f"the entry for manifest {manifest.manifest_path} is missing the field `added-rows-count`"
+                    )
+                added_rows += manifest.added_rows_count
+        return added_rows
+
     @abstractmethod
     def _deleted_entries(self) -> List[ManifestEntry]: ...
 
@@ -284,6 +297,14 @@ def _commit(self) -> UpdatesAndRequirements:
         ) as writer:
             writer.add_manifests(new_manifests)
 
+        first_row_id: Optional[int] = None
+        added_rows: Optional[int] = None
+
+        if self._transaction.table_metadata.format_version >= 3:
+            first_row_id = self._transaction.table_metadata.next_row_id
+            # Record how many rows this snapshot adds so the table's next-row-id can advance on commit.
+            added_rows = self._calculate_added_rows(new_manifests)
+
         snapshot = Snapshot(
             snapshot_id=self._snapshot_id,
             parent_snapshot_id=self._parent_snapshot_id,
@@ -291,6 +312,8 @@ def _commit(self) -> UpdatesAndRequirements:
             sequence_number=next_sequence_number,
             summary=summary,
             schema_id=self._transaction.table_metadata.current_schema_id,
+            first_row_id=first_row_id,
+            added_rows=added_rows,
         )
 
         add_snapshot_update = AddSnapshotUpdate(snapshot=snapshot)
diff --git a/tests/conftest.py b/tests/conftest.py
index 21f33858d5..6734932993 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -72,7 +72,7 @@
 from pyiceberg.schema import Accessor, Schema
 from pyiceberg.serializers import ToOutputFile
 from pyiceberg.table import FileScanTask, Table
-from pyiceberg.table.metadata import TableMetadataV1, TableMetadataV2
+from pyiceberg.table.metadata import TableMetadataV1, TableMetadataV2, TableMetadataV3
 from pyiceberg.transforms import DayTransform, IdentityTransform
 from pyiceberg.types import (
     BinaryType,
@@ -920,6 +920,7 @@ def generate_snapshot(
     "table-uuid": "9c12d441-03fe-4693-9a96-a0705ddf69c1",
     "location": "s3://bucket/test/location",
     "last-sequence-number": 34,
+    "next-row-id": 1,
     "last-updated-ms": 1602638573590,
     "last-column-id": 3,
     "current-schema-id": 1,
@@ -2489,6 +2490,18 @@ def table_v2(example_table_metadata_v2: Dict[str, Any]) -> 
Table:
     )
 
 
+@pytest.fixture
+def table_v3(example_table_metadata_v3: Dict[str, Any]) -> Table:
+    table_metadata = TableMetadataV3(**example_table_metadata_v3)  # build v3 metadata from the raw fixture dict
+    return Table(
+        identifier=("database", "table"),
+        metadata=table_metadata,
+        metadata_location=f"{table_metadata.location}/uuid.metadata.json",
+        io=load_file_io(),
+        catalog=NoopCatalog("NoopCatalog"),
+    )
+
+
 @pytest.fixture
 def table_v2_orc(example_table_metadata_v2: Dict[str, Any]) -> Table:
     import copy
diff --git a/tests/integration/test_writes/test_writes.py b/tests/integration/test_writes/test_writes.py
index c7d79f2c37..dcd465a7ca 100644
--- a/tests/integration/test_writes/test_writes.py
+++ b/tests/integration/test_writes/test_writes.py
@@ -64,7 +64,7 @@
     StringType,
     UUIDType,
 )
-from utils import _create_table
+from utils import TABLE_SCHEMA, _create_table
 
 
 @pytest.fixture(scope="session", autouse=True)
@@ -2490,3 +2490,41 @@ def test_stage_only_overwrite_files(
 
     assert operations == ["append", "append", "delete", "append", "append"]
     assert parent_snapshot_id == [None, first_snapshot, second_snapshot, second_snapshot, second_snapshot]
+
+
+@pytest.mark.skip("V3 writer support is not enabled.")
+@pytest.mark.integration
+def test_v3_write_and_read_row_lineage(spark: SparkSession, session_catalog: Catalog) -> None:
+    """Test that appending to a v3 table advances next-row-id by the number of appended rows."""
+    identifier = "default.test_v3_write_and_read"
+    tbl = _create_table(session_catalog, identifier, {"format-version": "3"})
+    assert tbl.format_version == 3, f"Expected v3, got: v{tbl.format_version}"
+    initial_next_row_id = tbl.metadata.next_row_id or 0  # an unset next-row-id is treated as 0
+
+    test_data = pa.Table.from_pydict(
+        {
+            "bool": [True, False, True],
+            "string": ["a", "b", "c"],
+            "string_long": ["a_long", "b_long", "c_long"],
+            "int": [1, 2, 3],
+            "long": [11, 22, 33],
+            "float": [1.1, 2.2, 3.3],
+            "double": [1.11, 2.22, 3.33],
+            "timestamp": [datetime(2023, 1, 1, 1, 1, 1), datetime(2023, 2, 2, 2, 2, 2), datetime(2023, 3, 3, 3, 3, 3)],
+            "timestamptz": [
+                
datetime(2023, 1, 1, 1, 1, 1, tzinfo=pytz.utc),
+                datetime(2023, 2, 2, 2, 2, 2, tzinfo=pytz.utc),
+                datetime(2023, 3, 3, 3, 3, 3, tzinfo=pytz.utc),
+            ],
+            "date": [date(2023, 1, 1), date(2023, 2, 2), date(2023, 3, 3)],
+            "binary": [b"\x01", b"\x02", b"\x03"],
+            "fixed": [b"1234567890123456", b"1234567890123456", b"1234567890123456"],
+        },
+        schema=TABLE_SCHEMA.as_arrow(),
+    )
+
+    tbl.append(test_data)
+
+    assert tbl.metadata.next_row_id == initial_next_row_id + len(test_data), (
+        "Expected next_row_id to be incremented by the number of added rows"
+    )
diff --git a/tests/table/test_init.py b/tests/table/test_init.py
index 95c5d822aa..5cc68b62a4 100644
--- a/tests/table/test_init.py
+++ b/tests/table/test_init.py
@@ -1521,3 +1521,57 @@ def test_remove_partition_statistics_update_with_invalid_snapshot_id(table_v2_wi
             table_v2_with_statistics.metadata,
             (RemovePartitionStatisticsUpdate(snapshot_id=123456789),),
         )
+
+
+def test_add_snapshot_update_fails_without_first_row_id(table_v3: Table) -> None:
+    new_snapshot = Snapshot(
+        snapshot_id=25,
+        parent_snapshot_id=19,
+        sequence_number=200,
+        timestamp_ms=1602638593590,
+        manifest_list="s3:/a/b/c.avro",
+        summary=Summary(Operation.APPEND),
+        schema_id=3,  # first_row_id is deliberately left unset for this case
+    )
+
+    with pytest.raises(
+        ValueError,
+        match="Cannot add snapshot without first row id",
+    ):
+        update_table_metadata(table_v3.metadata, (AddSnapshotUpdate(snapshot=new_snapshot),))
+
+
+def test_add_snapshot_update_fails_with_smaller_first_row_id(table_v3: Table) -> None:
+    new_snapshot = Snapshot(
+        snapshot_id=25,
+        parent_snapshot_id=19,
+        sequence_number=200,
+        timestamp_ms=1602638593590,
+        manifest_list="s3:/a/b/c.avro",
+        summary=Summary(Operation.APPEND),
+        schema_id=3,
+        first_row_id=0,  # must be below the table_v3 fixture's next-row-id for the rejection to fire
+    )
+
+    with pytest.raises(
+        ValueError,
+        match="Cannot add a snapshot with first row id smaller than the table's next-row-id",
+    ):
+        update_table_metadata(table_v3.metadata, (AddSnapshotUpdate(snapshot=new_snapshot),))
+
+
+def 
test_add_snapshot_update_updates_next_row_id(table_v3: Table) -> None:
+    new_snapshot = Snapshot(
+        snapshot_id=25,
+        parent_snapshot_id=19,
+        sequence_number=200,
+        timestamp_ms=1602638593590,
+        manifest_list="s3:/a/b/c.avro",
+        summary=Summary(Operation.APPEND),
+        schema_id=3,
+        first_row_id=2,  # accepted by the update, so not below the fixture's next-row-id
+        added_rows=10,
+    )
+
+    new_metadata = update_table_metadata(table_v3.metadata, (AddSnapshotUpdate(snapshot=new_snapshot),))
+    assert new_metadata.next_row_id == 11  # added_rows (10) plus the fixture's prior next-row-id (presumably 1 — confirm in conftest)