Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions pyiceberg/table/snapshots.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,12 @@ class Snapshot(IcebergBaseModel):
manifest_list: str = Field(alias="manifest-list", description="Location of the snapshot's manifest list file")
summary: Optional[Summary] = Field(default=None)
schema_id: Optional[int] = Field(alias="schema-id", default=None)
first_row_id: Optional[int] = Field(
alias="first-row-id", default=None, description="assigned to the first row in the first data file in the first manifest"
)
added_rows: Optional[int] = Field(
alias="added-rows", default=None, description="The upper bound of the number of rows with assigned row IDs"
)

def __str__(self) -> str:
"""Return the string representation of the Snapshot class."""
Expand All @@ -253,6 +259,22 @@ def __str__(self) -> str:
result_str = f"{operation}id={self.snapshot_id}{parent_id}{schema_id}"
return result_str

def __repr__(self) -> str:
    """Return the detailed string representation of the Snapshot class.

    Required fields are always shown; optional fields (summary, schema_id,
    first_row_id, added_rows) are appended only when they are set, keeping
    the output compact for snapshots that predate those fields.
    """
    fields = [
        f"snapshot_id={self.snapshot_id}",
        f"parent_snapshot_id={self.parent_snapshot_id}",
        f"sequence_number={self.sequence_number}",
        f"timestamp_ms={self.timestamp_ms}",
        f"manifest_list='{self.manifest_list}'",
        # Check `is not None` (not truthiness) so a present-but-empty summary
        # is still rendered, consistent with the optional fields below.
        f"summary={self.summary!r}" if self.summary is not None else None,
        f"schema_id={self.schema_id}" if self.schema_id is not None else None,
        f"first_row_id={self.first_row_id}" if self.first_row_id is not None else None,
        f"added_rows={self.added_rows}" if self.added_rows is not None else None,
    ]
    filtered_fields = [field for field in fields if field is not None]
    return f"Snapshot({', '.join(filtered_fields)})"

def manifests(self, io: FileIO) -> List[ManifestFile]:
"""Return the manifests for the given snapshot."""
return list(_manifests(io, self.manifest_list))
Expand Down
16 changes: 16 additions & 0 deletions pyiceberg/table/update/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,13 +437,29 @@ def _(update: AddSnapshotUpdate, base_metadata: TableMetadata, context: _TableMe
f"Cannot add snapshot with sequence number {update.snapshot.sequence_number} "
f"older than last sequence number {base_metadata.last_sequence_number}"
)
elif base_metadata.format_version >= 3 and update.snapshot.first_row_id is None:
raise ValueError("Cannot add snapshot without first row id")
elif (
base_metadata.format_version >= 3
and update.snapshot.first_row_id is not None
and base_metadata.next_row_id is not None
and update.snapshot.first_row_id < base_metadata.next_row_id
):
raise ValueError(
f"Cannot add a snapshot with first row id smaller than the table's next-row-id {update.snapshot.first_row_id} < {base_metadata.next_row_id}"
)

context.add_update(update)
return base_metadata.model_copy(
update={
"last_updated_ms": update.snapshot.timestamp_ms,
"last_sequence_number": update.snapshot.sequence_number,
"snapshots": base_metadata.snapshots + [update.snapshot],
"next_row_id": base_metadata.next_row_id + update.snapshot.added_rows
if base_metadata.format_version >= 3
and base_metadata.next_row_id is not None
and update.snapshot.added_rows is not None
else None,
}
)

Expand Down
19 changes: 19 additions & 0 deletions pyiceberg/table/update/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,19 @@ def delete_data_file(self, data_file: DataFile) -> _SnapshotProducer[U]:
self._deleted_data_files.add(data_file)
return self

def _calculate_added_rows(self, manifests: List[ManifestFile]) -> int:
    """Sum the added-row counts over the manifests produced by this snapshot.

    Only manifests attributed to this snapshot (or with no recorded
    added-snapshot-id) contribute to the total.

    Raises:
        ValueError: If a contributing manifest lacks `added-rows-count`.
    """
    total = 0
    for manifest in manifests:
        # Skip manifests that were added by a different (earlier) snapshot.
        if manifest.added_snapshot_id is not None and manifest.added_snapshot_id != self._snapshot_id:
            continue
        count = manifest.added_rows_count
        if count is None:
            raise ValueError(
                "Cannot determine number of added rows in snapshot because "
                f"the entry for manifest {manifest.manifest_path} is missing the field `added-rows-count`"
            )
        total += count
    return total

@abstractmethod
def _deleted_entries(self) -> List[ManifestEntry]: ...

Expand Down Expand Up @@ -284,13 +297,19 @@ def _commit(self) -> UpdatesAndRequirements:
) as writer:
writer.add_manifests(new_manifests)

first_row_id: Optional[int] = None

if self._transaction.table_metadata.format_version >= 3:
first_row_id = self._transaction.table_metadata.next_row_id

snapshot = Snapshot(
snapshot_id=self._snapshot_id,
parent_snapshot_id=self._parent_snapshot_id,
manifest_list=manifest_list_file_path,
sequence_number=next_sequence_number,
summary=summary,
schema_id=self._transaction.table_metadata.current_schema_id,
first_row_id=first_row_id,
)

add_snapshot_update = AddSnapshotUpdate(snapshot=snapshot)
Expand Down
15 changes: 14 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@
from pyiceberg.schema import Accessor, Schema
from pyiceberg.serializers import ToOutputFile
from pyiceberg.table import FileScanTask, Table
from pyiceberg.table.metadata import TableMetadataV1, TableMetadataV2
from pyiceberg.table.metadata import TableMetadataV1, TableMetadataV2, TableMetadataV3
from pyiceberg.transforms import DayTransform, IdentityTransform
from pyiceberg.types import (
BinaryType,
Expand Down Expand Up @@ -920,6 +920,7 @@ def generate_snapshot(
"table-uuid": "9c12d441-03fe-4693-9a96-a0705ddf69c1",
"location": "s3://bucket/test/location",
"last-sequence-number": 34,
"next-row-id": 1,
"last-updated-ms": 1602638573590,
"last-column-id": 3,
"current-schema-id": 1,
Expand Down Expand Up @@ -2489,6 +2490,18 @@ def table_v2(example_table_metadata_v2: Dict[str, Any]) -> Table:
)


@pytest.fixture
def table_v3(example_table_metadata_v3: Dict[str, Any]) -> Table:
    """Return an in-memory v3 `Table` built from the example metadata, backed by a no-op catalog."""
    metadata = TableMetadataV3(**example_table_metadata_v3)
    return Table(
        identifier=("database", "table"),
        metadata=metadata,
        metadata_location=f"{metadata.location}/uuid.metadata.json",
        io=load_file_io(),
        catalog=NoopCatalog("NoopCatalog"),
    )


@pytest.fixture
def table_v2_orc(example_table_metadata_v2: Dict[str, Any]) -> Table:
import copy
Expand Down
40 changes: 39 additions & 1 deletion tests/integration/test_writes/test_writes.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
StringType,
UUIDType,
)
from utils import _create_table
from utils import TABLE_SCHEMA, _create_table


@pytest.fixture(scope="session", autouse=True)
Expand Down Expand Up @@ -2490,3 +2490,41 @@ def test_stage_only_overwrite_files(
assert operations == ["append", "append", "delete", "append", "append"]

assert parent_snapshot_id == [None, first_snapshot, second_snapshot, second_snapshot, second_snapshot]


@pytest.mark.skip("V3 writer support is not enabled.")
@pytest.mark.integration
def test_v3_write_and_read_row_lineage(spark: SparkSession, session_catalog: Catalog) -> None:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the test I would want to write for testing with Spark. Unfortunately, it fails because we don't have v3 writer support (which is needed to append new rows)

"""Test writing to a v3 table and reading with Spark."""
identifier = "default.test_v3_write_and_read"
tbl = _create_table(session_catalog, identifier, {"format-version": "3"})
assert tbl.format_version == 3, f"Expected v3, got: v{tbl.format_version}"
initial_next_row_id = tbl.metadata.next_row_id or 0

test_data = pa.Table.from_pydict(
{
"bool": [True, False, True],
"string": ["a", "b", "c"],
"string_long": ["a_long", "b_long", "c_long"],
"int": [1, 2, 3],
"long": [11, 22, 33],
"float": [1.1, 2.2, 3.3],
"double": [1.11, 2.22, 3.33],
"timestamp": [datetime(2023, 1, 1, 1, 1, 1), datetime(2023, 2, 2, 2, 2, 2), datetime(2023, 3, 3, 3, 3, 3)],
"timestamptz": [
datetime(2023, 1, 1, 1, 1, 1, tzinfo=pytz.utc),
datetime(2023, 2, 2, 2, 2, 2, tzinfo=pytz.utc),
datetime(2023, 3, 3, 3, 3, 3, tzinfo=pytz.utc),
],
"date": [date(2023, 1, 1), date(2023, 2, 2), date(2023, 3, 3)],
"binary": [b"\x01", b"\x02", b"\x03"],
"fixed": [b"1234567890123456", b"1234567890123456", b"1234567890123456"],
},
schema=TABLE_SCHEMA.as_arrow(),
)

tbl.append(test_data)

assert tbl.metadata.next_row_id == initial_next_row_id + len(test_data), (
"Expected next_row_id to be incremented by the number of added rows"
)
54 changes: 54 additions & 0 deletions tests/table/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -1521,3 +1521,57 @@ def test_remove_partition_statistics_update_with_invalid_snapshot_id(table_v2_wi
table_v2_with_statistics.metadata,
(RemovePartitionStatisticsUpdate(snapshot_id=123456789),),
)


def test_add_snapshot_update_fails_without_first_row_id(table_v3: Table) -> None:
    """AddSnapshotUpdate on a v3 table must reject a snapshot lacking `first_row_id`."""
    # `first_row_id` is deliberately omitted here.
    snapshot_without_row_id = Snapshot(
        snapshot_id=25,
        parent_snapshot_id=19,
        sequence_number=200,
        timestamp_ms=1602638593590,
        manifest_list="s3:/a/b/c.avro",
        summary=Summary(Operation.APPEND),
        schema_id=3,
    )

    with pytest.raises(ValueError, match="Cannot add snapshot without first row id"):
        update_table_metadata(table_v3.metadata, (AddSnapshotUpdate(snapshot=snapshot_without_row_id),))


def test_add_snapshot_update_fails_with_smaller_first_row_id(table_v3: Table) -> None:
    """AddSnapshotUpdate must reject a `first_row_id` below the table's next-row-id."""
    # first_row_id=0 is smaller than the fixture table's next-row-id, so the
    # update must be rejected to keep row IDs monotonically increasing.
    stale_snapshot = Snapshot(
        snapshot_id=25,
        parent_snapshot_id=19,
        sequence_number=200,
        timestamp_ms=1602638593590,
        manifest_list="s3:/a/b/c.avro",
        summary=Summary(Operation.APPEND),
        schema_id=3,
        first_row_id=0,
    )

    with pytest.raises(ValueError, match="Cannot add a snapshot with first row id smaller than the table's next-row-id"):
        update_table_metadata(table_v3.metadata, (AddSnapshotUpdate(snapshot=stale_snapshot),))


def test_add_snapshot_update_updates_next_row_id(table_v3: Table) -> None:
    """Applying AddSnapshotUpdate advances the table's `next_row_id` by the snapshot's `added_rows`."""
    appended_snapshot = Snapshot(
        snapshot_id=25,
        parent_snapshot_id=19,
        sequence_number=200,
        timestamp_ms=1602638593590,
        manifest_list="s3:/a/b/c.avro",
        summary=Summary(Operation.APPEND),
        schema_id=3,
        first_row_id=2,
        added_rows=10,
    )

    updated_metadata = update_table_metadata(table_v3.metadata, (AddSnapshotUpdate(snapshot=appended_snapshot),))
    # Base next-row-id (1) advanced by added_rows (10).
    assert updated_metadata.next_row_id == 11