pyiceberg/table/__init__.py (29 additions, 21 deletions)
@@ -31,6 +31,7 @@
Callable,
Dict,
Iterable,
Iterator,
List,
Optional,
Set,
@@ -1942,11 +1943,11 @@ def _check_sequence_number(min_sequence_number: int, manifest: ManifestFile) ->
and (manifest.sequence_number or INITIAL_SEQUENCE_NUMBER) >= min_sequence_number
)

def plan_files(self) -> Iterable[FileScanTask]:
"""Plans the relevant files by filtering on the PartitionSpecs.
def scan_plan_helper(self) -> Iterator[List[ManifestEntry]]:
"""Filter and return manifest entries based on partition and metrics evaluators.

Returns:
List of FileScanTasks that contain both data and delete files.
Iterator of lists of ManifestEntry objects that match the scan's partition and metrics filters.
"""
snapshot = self.snapshot()
if not snapshot:
@@ -1957,8 +1958,6 @@ def plan_files(self) -> Iterable[FileScanTask]:

manifest_evaluators: Dict[int, Callable[[ManifestFile], bool]] = KeyDefaultDict(self._build_manifest_evaluator)

residual_evaluators: Dict[int, Callable[[DataFile], ResidualEvaluator]] = KeyDefaultDict(self._build_residual_evaluator)

manifests = [
manifest_file
for manifest_file in snapshot.manifests(self.io)
@@ -1972,25 +1971,34 @@

min_sequence_number = _min_sequence_number(manifests)

executor = ExecutorFactory.get_or_create()

return executor.map(
lambda args: _open_manifest(*args),
[
(
self.io,
manifest,
partition_evaluators[manifest.partition_spec_id],
self._build_metrics_evaluator(),
)
for manifest in manifests
if self._check_sequence_number(min_sequence_number, manifest)
],
)

def plan_files(self) -> Iterable[FileScanTask]:
"""Plans the relevant files by filtering on the PartitionSpecs.

Returns:
List of FileScanTasks that contain both data and delete files.
"""
data_entries: List[ManifestEntry] = []
positional_delete_entries = SortedList(key=lambda entry: entry.sequence_number or INITIAL_SEQUENCE_NUMBER)

executor = ExecutorFactory.get_or_create()
for manifest_entry in chain(
*executor.map(
lambda args: _open_manifest(*args),
[
(
self.io,
manifest,
partition_evaluators[manifest.partition_spec_id],
self._build_metrics_evaluator(),
)
for manifest in manifests
if self._check_sequence_number(min_sequence_number, manifest)
],
)
):
residual_evaluators: Dict[int, Callable[[DataFile], ResidualEvaluator]] = KeyDefaultDict(self._build_residual_evaluator)

for manifest_entry in chain.from_iterable(self.scan_plan_helper()):
data_file = manifest_entry.data_file
if data_file.content == DataFileContent.DATA:
data_entries.append(manifest_entry)
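`scan_plan_helper` is an internal step of `DataScan`, so callers keep going through `plan_files()`. A minimal usage sketch of that unchanged entry point, assuming a hypothetical catalog named `default` and a hypothetical table `default.events`:

```python
from pyiceberg.catalog import load_catalog

catalog = load_catalog("default")           # hypothetical catalog name
tbl = catalog.load_table("default.events")  # hypothetical table identifier

# plan_files() still yields FileScanTasks carrying both data and delete files;
# scan_plan_helper() is only the manifest-filtering step it now reuses.
scan = tbl.scan(row_filter="dt >= '2021-01-01'")
for task in scan.plan_files():
    print(task.file.file_path, task.file.record_count, len(task.delete_files))
```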
pyiceberg/table/inspect.py (70 additions, 71 deletions)
@@ -16,11 +16,13 @@
# under the License.
from __future__ import annotations

import itertools
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Set, Tuple
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Set, Tuple, Union

from pyiceberg.conversions import from_bytes
from pyiceberg.manifest import DataFileContent, ManifestContent, ManifestFile, PartitionFieldSummary
from pyiceberg.expressions import AlwaysTrue, BooleanExpression
from pyiceberg.manifest import DataFile, DataFileContent, ManifestContent, ManifestFile, PartitionFieldSummary
from pyiceberg.partitioning import PartitionSpec
from pyiceberg.table.snapshots import Snapshot, ancestors_of
from pyiceberg.types import PrimitiveType
@@ -32,6 +34,8 @@

from pyiceberg.table import Table

ALWAYS_TRUE = AlwaysTrue()


class InspectTable:
tbl: Table
@@ -255,10 +259,16 @@ def refs(self) -> "pa.Table":

return pa.Table.from_pylist(ref_results, schema=ref_schema)

def partitions(self, snapshot_id: Optional[int] = None) -> "pa.Table":
def partitions(
self,
snapshot_id: Optional[int] = None,
row_filter: Union[str, BooleanExpression] = ALWAYS_TRUE,
case_sensitive: bool = True,
) -> "pa.Table":
import pyarrow as pa

from pyiceberg.io.pyarrow import schema_to_pyarrow
from pyiceberg.table import DataScan

table_schema = pa.schema(
[
@@ -289,85 +299,74 @@ def partitions(self, snapshot_id: Optional[int] = None) -> "pa.Table":
table_schema = pa.unify_schemas([partitions_schema, table_schema])

snapshot = self._get_snapshot(snapshot_id)
executor = ExecutorFactory.get_or_create()
local_partitions_maps = executor.map(self._process_manifest, snapshot.manifests(self.tbl.io))

partitions_map: Dict[Tuple[str, Any], Any] = {}
for local_map in local_partitions_maps:
for partition_record_key, partition_row in local_map.items():
if partition_record_key not in partitions_map:
partitions_map[partition_record_key] = partition_row
else:
existing = partitions_map[partition_record_key]
existing["record_count"] += partition_row["record_count"]
existing["file_count"] += partition_row["file_count"]
existing["total_data_file_size_in_bytes"] += partition_row["total_data_file_size_in_bytes"]
existing["position_delete_record_count"] += partition_row["position_delete_record_count"]
existing["position_delete_file_count"] += partition_row["position_delete_file_count"]
existing["equality_delete_record_count"] += partition_row["equality_delete_record_count"]
existing["equality_delete_file_count"] += partition_row["equality_delete_file_count"]

if partition_row["last_updated_at"] and (
not existing["last_updated_at"] or partition_row["last_updated_at"] > existing["last_updated_at"]
):
existing["last_updated_at"] = partition_row["last_updated_at"]
existing["last_updated_snapshot_id"] = partition_row["last_updated_snapshot_id"]

return pa.Table.from_pylist(
partitions_map.values(),
schema=table_schema,
scan = DataScan(
table_metadata=self.tbl.metadata,
io=self.tbl.io,
row_filter=row_filter,
case_sensitive=case_sensitive,
snapshot_id=snapshot.snapshot_id,
)

def _process_manifest(self, manifest: ManifestFile) -> Dict[Tuple[str, Any], Any]:
partitions_map: Dict[Tuple[str, Any], Any] = {}
for entry in manifest.fetch_manifest_entry(io=self.tbl.io):

for entry in itertools.chain.from_iterable(scan.scan_plan_helper()):
partition = entry.data_file.partition
partition_record_dict = {
field.name: partition[pos]
for pos, field in enumerate(self.tbl.metadata.specs()[manifest.partition_spec_id].fields)
field.name: partition[pos] for pos, field in enumerate(self.tbl.metadata.specs()[entry.data_file.spec_id].fields)
}
entry_snapshot = self.tbl.snapshot_by_id(entry.snapshot_id) if entry.snapshot_id is not None else None
self._update_partitions_map_from_manifest_entry(
partitions_map, entry.data_file, partition_record_dict, entry_snapshot
)

partition_record_key = _convert_to_hashable_type(partition_record_dict)
if partition_record_key not in partitions_map:
partitions_map[partition_record_key] = {
"partition": partition_record_dict,
"spec_id": entry.data_file.spec_id,
"record_count": 0,
"file_count": 0,
"total_data_file_size_in_bytes": 0,
"position_delete_record_count": 0,
"position_delete_file_count": 0,
"equality_delete_record_count": 0,
"equality_delete_file_count": 0,
"last_updated_at": entry_snapshot.timestamp_ms if entry_snapshot else None,
"last_updated_snapshot_id": entry_snapshot.snapshot_id if entry_snapshot else None,
}
return pa.Table.from_pylist(
partitions_map.values(),
schema=table_schema,
)

partition_row = partitions_map[partition_record_key]

if entry_snapshot is not None:
if (
partition_row["last_updated_at"] is None
or partition_row["last_updated_snapshot_id"] < entry_snapshot.timestamp_ms
):
partition_row["last_updated_at"] = entry_snapshot.timestamp_ms
partition_row["last_updated_snapshot_id"] = entry_snapshot.snapshot_id

if entry.data_file.content == DataFileContent.DATA:
partition_row["record_count"] += entry.data_file.record_count
partition_row["file_count"] += 1
partition_row["total_data_file_size_in_bytes"] += entry.data_file.file_size_in_bytes
elif entry.data_file.content == DataFileContent.POSITION_DELETES:
partition_row["position_delete_record_count"] += entry.data_file.record_count
partition_row["position_delete_file_count"] += 1
elif entry.data_file.content == DataFileContent.EQUALITY_DELETES:
partition_row["equality_delete_record_count"] += entry.data_file.record_count
partition_row["equality_delete_file_count"] += 1
else:
raise ValueError(f"Unknown DataFileContent ({entry.data_file.content})")
def _update_partitions_map_from_manifest_entry(
self,
partitions_map: Dict[Tuple[str, Any], Any],
file: DataFile,
partition_record_dict: Dict[str, Any],
snapshot: Optional[Snapshot],
) -> None:
partition_record_key = _convert_to_hashable_type(partition_record_dict)
if partition_record_key not in partitions_map:
partitions_map[partition_record_key] = {
"partition": partition_record_dict,
"spec_id": file.spec_id,
"record_count": 0,
"file_count": 0,
"total_data_file_size_in_bytes": 0,
"position_delete_record_count": 0,
"position_delete_file_count": 0,
"equality_delete_record_count": 0,
"equality_delete_file_count": 0,
"last_updated_at": snapshot.timestamp_ms if snapshot else None,
"last_updated_snapshot_id": snapshot.snapshot_id if snapshot else None,
}

return partitions_map
partition_row = partitions_map[partition_record_key]

if snapshot is not None:
if partition_row["last_updated_at"] is None or partition_row["last_updated_snapshot_id"] < snapshot.timestamp_ms:
partition_row["last_updated_at"] = snapshot.timestamp_ms
partition_row["last_updated_snapshot_id"] = snapshot.snapshot_id

if file.content == DataFileContent.DATA:
partition_row["record_count"] += file.record_count
partition_row["file_count"] += 1
partition_row["total_data_file_size_in_bytes"] += file.file_size_in_bytes
elif file.content == DataFileContent.POSITION_DELETES:
partition_row["position_delete_record_count"] += file.record_count
partition_row["position_delete_file_count"] += 1
elif file.content == DataFileContent.EQUALITY_DELETES:
partition_row["equality_delete_record_count"] += file.record_count
partition_row["equality_delete_file_count"] += 1
else:
raise ValueError(f"Unknown DataFileContent ({file.content})")

def _get_manifests_schema(self) -> "pa.Schema":
import pyarrow as pa
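A hedged usage sketch of the new `row_filter` and `case_sensitive` parameters on `inspect.partitions()`; the catalog and table names are hypothetical, and the expression classes are the same ones exercised in the tests below:

```python
from pyiceberg.catalog import load_catalog
from pyiceberg.expressions import And, GreaterThanOrEqual, LessThan

catalog = load_catalog("default")                            # hypothetical catalog name
tbl = catalog.load_table("default.table_partitioned_by_dt")  # hypothetical table identifier

# String predicates are parsed the same way as in DataScan.
df = tbl.inspect.partitions(row_filter="dt >= '2021-01-01'")

# Equivalent BooleanExpression form, with explicit case sensitivity.
df = tbl.inspect.partitions(
    row_filter=And(GreaterThanOrEqual("dt", "2021-01-01"), LessThan("dt", "2021-03-01")),
    case_sensitive=True,
)
print(df.to_pandas())
```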
tests/integration/test_inspect_table.py (90 additions, 8 deletions)
@@ -18,6 +18,7 @@

import math
from datetime import date, datetime
from typing import Union

import pyarrow as pa
import pytest
@@ -26,6 +27,13 @@

from pyiceberg.catalog import Catalog
from pyiceberg.exceptions import NoSuchTableError
from pyiceberg.expressions import (
And,
BooleanExpression,
EqualTo,
GreaterThanOrEqual,
LessThan,
)
from pyiceberg.schema import Schema
from pyiceberg.table import Table
from pyiceberg.typedef import Properties
@@ -198,6 +206,14 @@ def _inspect_files_asserts(df: pa.Table, spark_df: DataFrame) -> None:
assert left == right, f"Difference in column {column}: {left} != {right}"


def _check_pyiceberg_df_equals_spark_df(df: pa.Table, spark_df: DataFrame) -> None:
lhs = df.to_pandas().sort_values("last_updated_at")
rhs = spark_df.toPandas().sort_values("last_updated_at")
for column in df.column_names:
for left, right in zip(lhs[column].to_list(), rhs[column].to_list()):
assert left == right, f"Difference in column {column}: {left} != {right}"


@pytest.mark.integration
@pytest.mark.parametrize("format_version", [1, 2])
def test_inspect_snapshots(
@@ -581,18 +597,84 @@ def test_inspect_partitions_partitioned(spark: SparkSession, session_catalog: Ca
"""
)

def check_pyiceberg_df_equals_spark_df(df: pa.Table, spark_df: DataFrame) -> None:
lhs = df.to_pandas().sort_values("spec_id")
rhs = spark_df.toPandas().sort_values("spec_id")
for column in df.column_names:
for left, right in zip(lhs[column].to_list(), rhs[column].to_list()):
assert left == right, f"Difference in column {column}: {left} != {right}"

tbl = session_catalog.load_table(identifier)
for snapshot in tbl.metadata.snapshots:
df = tbl.inspect.partitions(snapshot_id=snapshot.snapshot_id)
spark_df = spark.sql(f"SELECT * FROM {identifier}.partitions VERSION AS OF {snapshot.snapshot_id}")
check_pyiceberg_df_equals_spark_df(df, spark_df)
_check_pyiceberg_df_equals_spark_df(df, spark_df)


@pytest.mark.integration
@pytest.mark.parametrize("format_version", [1, 2])
def test_inspect_partitions_partitioned_with_filter(spark: SparkSession, session_catalog: Catalog, format_version: int) -> None:
identifier = "default.table_metadata_partitions_with_filter"
try:
session_catalog.drop_table(identifier=identifier)
except NoSuchTableError:
pass

spark.sql(
f"""
CREATE TABLE {identifier} (
name string,
dt date
)
PARTITIONED BY (dt)
"""
)

spark.sql(
f"""
INSERT INTO {identifier} VALUES ('John', CAST('2021-01-01' AS date))
"""
)

spark.sql(
f"""
INSERT INTO {identifier} VALUES ('Doe', CAST('2021-01-05' AS date))
"""
)

spark.sql(
f"""
INSERT INTO {identifier} VALUES ('Jenny', CAST('2021-02-01' AS date))
"""
)

tbl = session_catalog.load_table(identifier)
for snapshot in tbl.metadata.snapshots:
test_cases: list[tuple[Union[str, BooleanExpression], str]] = [
("dt >= '2021-01-01'", "partition.dt >= '2021-01-01'"),
(GreaterThanOrEqual("dt", "2021-01-01"), "partition.dt >= '2021-01-01'"),
("dt >= '2021-01-01' and dt < '2021-03-01'", "partition.dt >= '2021-01-01' AND partition.dt < '2021-03-01'"),
(
And(GreaterThanOrEqual("dt", "2021-01-01"), LessThan("dt", "2021-03-01")),
"partition.dt >= '2021-01-01' AND partition.dt < '2021-03-01'",
),
("dt == '2021-02-01'", "partition.dt = '2021-02-01'"),
(EqualTo("dt", "2021-02-01"), "partition.dt = '2021-02-01'"),
]
for filter_predicate_lt, filter_predicate_rt in test_cases:
df = tbl.inspect.partitions(snapshot_id=snapshot.snapshot_id, row_filter=filter_predicate_lt)
spark_df = spark.sql(
f"SELECT * FROM {identifier}.partitions VERSION AS OF {snapshot.snapshot_id} WHERE {filter_predicate_rt}"
)
_check_pyiceberg_df_equals_spark_df(df, spark_df)


@pytest.mark.integration
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog")])
def test_inspect_partitions_partitioned_transform_with_filter(spark: SparkSession, catalog: Catalog) -> None:
for table_name, predicate, partition_predicate in [
("test_partitioned_by_identity", "ts >= '2023-03-05T00:00:00+00:00'", "ts >= '2023-03-05T00:00:00+00:00'"),
("test_partitioned_by_years", "dt >= '2023-03-05'", "dt_year >= 53"),
("test_partitioned_by_months", "dt >= '2023-03-05'", "dt_month >= 638"),
("test_partitioned_by_days", "ts >= '2023-03-05T00:00:00+00:00'", "ts_day >= '2023-03-05'"),
]:
table = catalog.load_table(f"default.{table_name}")
df = table.inspect.partitions(row_filter=predicate)
expected_df = spark.sql(f"select * from default.{table_name}.partitions where partition.{partition_predicate}")
assert len(df.to_pandas()) == len(expected_df.toPandas())


@pytest.mark.integration
Expand Down