Skip to content

Commit ab8c31b

Browse files
committed
feat: add read throughput micro-benchmark for ArrowScan configurations
1 parent 13feb8d commit ab8c31b

File tree

1 file changed

+160
-0
lines changed

1 file changed

+160
-0
lines changed
Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""Read throughput micro-benchmark for ArrowScan configurations.
18+
19+
Measures records/sec and peak Arrow memory across streaming, concurrent_files,
20+
and batch_size configurations introduced for issue #3036.
21+
22+
Memory is measured using pa.total_allocated_bytes() which tracks PyArrow's C++
23+
memory pool (Arrow buffers, Parquet decompression), not Python heap allocations.
24+
25+
Run with: uv run pytest tests/benchmark/test_read_benchmark.py -v -s -m benchmark
26+
"""
27+
28+
import gc
29+
import statistics
30+
import timeit
31+
from datetime import datetime, timezone
32+
33+
import pyarrow as pa
34+
import pyarrow.parquet as pq
35+
import pytest
36+
37+
from pyiceberg.catalog.sql import SqlCatalog
38+
from pyiceberg.table import Table
39+
40+
NUM_FILES = 32
41+
ROWS_PER_FILE = 500_000
42+
TOTAL_ROWS = NUM_FILES * ROWS_PER_FILE
43+
NUM_RUNS = 3
44+
45+
46+
def _generate_parquet_file(path: str, num_rows: int, seed: int) -> pa.Schema:
    """Write a synthetic Parquet file to *path* and return its schema.

    Only the ``id`` column depends on *seed* (each file gets a disjoint id
    range); every other column is a deterministic function of the row index,
    and all rows share a single capture of the current UTC timestamp.
    """
    indices = range(num_rows)
    now = datetime.now(timezone.utc)
    columns = {
        "id": pa.array(range(seed, seed + num_rows), type=pa.int64()),
        "value": pa.array([i * 0.1 for i in indices], type=pa.float64()),
        "label": pa.array([f"row_{i}" for i in indices], type=pa.string()),
        "flag": pa.array([i % 2 == 0 for i in indices], type=pa.bool_()),
        "ts": pa.array([now] * num_rows, type=pa.timestamp("us", tz="UTC")),
    }
    synthetic = pa.table(columns)
    pq.write_table(synthetic, path)
    return synthetic.schema
59+
60+
61+
@pytest.fixture(scope="session")
def benchmark_table(tmp_path_factory: pytest.TempPathFactory) -> Table:
    """Create a catalog and table with synthetic Parquet files for benchmarking.

    Session-scoped so the (expensive) data generation and appends happen once
    and are shared by every parametrized benchmark case.
    """
    warehouse = str(tmp_path_factory.mktemp("benchmark_warehouse"))
    catalog = SqlCatalog(
        "benchmark",
        uri=f"sqlite:///{warehouse}/pyiceberg_catalog.db",
        warehouse=f"file://{warehouse}",
    )
    catalog.create_namespace("default")

    table: Table | None = None
    for index in range(NUM_FILES):
        data_path = f"{warehouse}/data_{index}.parquet"
        _generate_parquet_file(data_path, ROWS_PER_FILE, seed=index * ROWS_PER_FILE)

        loaded = pq.read_table(data_path)
        if table is None:
            # First file: create the table from the schema just written.
            table = catalog.create_table("default.benchmark_read", schema=loaded.schema)
        table.append(loaded)

    assert table is not None
    return table
85+
86+
87+
@pytest.mark.benchmark
@pytest.mark.parametrize(
    "streaming,concurrent_files,batch_size",
    [
        pytest.param(False, 1, None, id="default"),
        pytest.param(True, 1, None, id="streaming-cf1"),
        pytest.param(True, 2, None, id="streaming-cf2"),
        pytest.param(True, 4, None, id="streaming-cf4"),
        pytest.param(True, 8, None, id="streaming-cf8"),
        pytest.param(True, 16, None, id="streaming-cf16"),
    ],
)
def test_read_throughput(
    benchmark_table: Table,
    streaming: bool,
    concurrent_files: int,
    batch_size: int | None,
) -> None:
    """Measure records/sec and peak Arrow memory for a scan configuration.

    Performs NUM_RUNS full scans per configuration, printing per-run and mean
    elapsed time, throughput, and peak PyArrow-pool memory above the run's
    baseline. Peak memory is sampled once per batch, so allocations that
    spike and are released within a single batch are not observed.
    """
    effective_batch_size = batch_size or 131_072  # PyArrow default
    if streaming:
        config_str = f"streaming=True, concurrent_files={concurrent_files}, batch_size={effective_batch_size}"
    else:
        config_str = f"streaming=False (executor.map, all files parallel), batch_size={effective_batch_size}"
    # Fixed: was an f-string with no placeholders (ruff F541).
    print("\n--- ArrowScan Read Throughput Benchmark ---")
    print(f"Config: {config_str}")
    print(f" Files: {NUM_FILES}, Rows per file: {ROWS_PER_FILE}, Total rows: {TOTAL_ROWS}")

    elapsed_times: list[float] = []
    throughputs: list[float] = []
    peak_memories: list[int] = []

    for run in range(NUM_RUNS):
        # Start each run from a clean memory baseline so peak measurements
        # are comparable across runs.
        gc.collect()
        pa.default_memory_pool().release_unused()
        baseline_mem = pa.total_allocated_bytes()
        peak_mem = baseline_mem

        start = timeit.default_timer()
        total_rows = 0
        for batch in benchmark_table.scan().to_arrow_batch_reader(
            batch_size=batch_size,
            streaming=streaming,
            concurrent_files=concurrent_files,
        ):
            total_rows += len(batch)
            # Sample the Arrow C++ pool after each batch (batch-granularity
            # peak tracking only).
            current_mem = pa.total_allocated_bytes()
            if current_mem > peak_mem:
                peak_mem = current_mem
        elapsed = timeit.default_timer() - start

        peak_above_baseline = peak_mem - baseline_mem
        rows_per_sec = total_rows / elapsed if elapsed > 0 else 0
        elapsed_times.append(elapsed)
        throughputs.append(rows_per_sec)
        peak_memories.append(peak_above_baseline)

        print(
            f" Run {run + 1}: {elapsed:.2f}s, {rows_per_sec:,.0f} rows/s, "
            f"peak arrow mem: {peak_above_baseline / (1024 * 1024):.1f} MB"
        )

        # Every run must see the complete dataset.
        assert total_rows == TOTAL_ROWS, f"Expected {TOTAL_ROWS} rows, got {total_rows}"

    mean_elapsed = statistics.mean(elapsed_times)
    stdev_elapsed = statistics.stdev(elapsed_times) if len(elapsed_times) > 1 else 0.0
    mean_throughput = statistics.mean(throughputs)
    mean_peak_mem = statistics.mean(peak_memories)

    print(
        f" Mean: {mean_elapsed:.2f}s ± {stdev_elapsed:.2f}s, {mean_throughput:,.0f} rows/s, "
        f"peak arrow mem: {mean_peak_mem / (1024 * 1024):.1f} MB"
    )

0 commit comments

Comments
 (0)