Commit c88674d
feat: add Python bindings for accessing ExecutionMetrics
1 parent 4cd5674 commit c88674d

File tree

7 files changed: +453 −6 lines changed

python/datafusion/__init__.py

Lines changed: 3 additions & 1 deletion

@@ -55,7 +55,7 @@
 from .expr import Expr, WindowFrame
 from .io import read_avro, read_csv, read_json, read_parquet
 from .options import CsvReadOptions
-from .plan import ExecutionPlan, LogicalPlan
+from .plan import ExecutionPlan, LogicalPlan, Metric, MetricsSet
 from .record_batch import RecordBatch, RecordBatchStream
 from .user_defined import (
     Accumulator,
@@ -85,6 +85,8 @@
     "Expr",
     "InsertOp",
     "LogicalPlan",
+    "Metric",
+    "MetricsSet",
     "ParquetColumnOptions",
     "ParquetWriterOptions",
     "RecordBatch",

python/datafusion/plan.py

Lines changed: 106 additions & 0 deletions

@@ -29,6 +29,8 @@
 __all__ = [
     "ExecutionPlan",
     "LogicalPlan",
+    "Metric",
+    "MetricsSet",
 ]


@@ -151,3 +153,107 @@ def to_proto(self) -> bytes:
         Tables created in memory from record batches are currently not supported.
         """
         return self._raw_plan.to_proto()
+
+    def metrics(self) -> MetricsSet | None:
+        """Return metrics for this plan node after execution, or None if unavailable."""
+        raw = self._raw_plan.metrics()
+        if raw is None:
+            return None
+        return MetricsSet(raw)
+
+    def collect_metrics(self) -> list[tuple[str, MetricsSet]]:
+        """Walk the plan tree and collect metrics from all operators.
+
+        Returns a list of (operator_name, MetricsSet) tuples.
+        """
+        result: list[tuple[str, MetricsSet]] = []
+
+        def _walk(node: ExecutionPlan) -> None:
+            ms = node.metrics()
+            if ms is not None:
+                result.append((node.display(), ms))
+            for child in node.children():
+                _walk(child)
+
+        _walk(self)
+        return result
+
+
+class MetricsSet:
+    """A set of metrics for a single execution plan operator.
+
+    Provides both individual metric access and convenience aggregations
+    across partitions.
+    """
+
+    def __init__(self, raw: df_internal.MetricsSet) -> None:
+        """This constructor should not be called by the end user."""
+        self._raw = raw
+
+    def metrics(self) -> list[Metric]:
+        """Return all individual metrics in this set."""
+        return [Metric(m) for m in self._raw.metrics()]
+
+    @property
+    def output_rows(self) -> int | None:
+        """Sum of output_rows across all partitions."""
+        return self._raw.output_rows()
+
+    @property
+    def elapsed_compute(self) -> int | None:
+        """Sum of elapsed_compute across all partitions, in nanoseconds."""
+        return self._raw.elapsed_compute()
+
+    @property
+    def spill_count(self) -> int | None:
+        """Sum of spill_count across all partitions."""
+        return self._raw.spill_count()
+
+    @property
+    def spilled_bytes(self) -> int | None:
+        """Sum of spilled_bytes across all partitions."""
+        return self._raw.spilled_bytes()
+
+    @property
+    def spilled_rows(self) -> int | None:
+        """Sum of spilled_rows across all partitions."""
+        return self._raw.spilled_rows()
+
+    def sum_by_name(self, name: str) -> int | None:
+        """Return the sum of metrics matching the given name."""
+        return self._raw.sum_by_name(name)
+
+    def __repr__(self) -> str:
+        """Return a string representation of the metrics set."""
+        return repr(self._raw)
+
+
+class Metric:
+    """A single execution metric with name, value, partition, and labels."""
+
+    def __init__(self, raw: df_internal.Metric) -> None:
+        """This constructor should not be called by the end user."""
+        self._raw = raw
+
+    @property
+    def name(self) -> str:
+        """The name of this metric (e.g. ``output_rows``)."""
+        return self._raw.name
+
+    @property
+    def value(self) -> int | None:
+        """The numeric value of this metric, or None for non-numeric types."""
+        return self._raw.value
+
+    @property
+    def partition(self) -> int | None:
+        """The partition this metric applies to, or None if global."""
+        return self._raw.partition
+
+    def labels(self) -> dict[str, str]:
+        """Return the labels associated with this metric."""
+        return self._raw.labels()
+
+    def __repr__(self) -> str:
+        """Return a string representation of the metric."""
+        return repr(self._raw)
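Taken together, ExecutionPlan.collect_metrics() plus the MetricsSet and Metric wrappers give a post-execution view of per-operator statistics. A usage sketch based on the API above (the table name and query are illustrative, mirroring the tests below):

    from datafusion import SessionContext

    ctx = SessionContext()
    ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')")
    df = ctx.sql("SELECT * FROM t WHERE column1 > 1")
    df.collect()  # metrics are only recorded once the plan has executed

    plan = df.execution_plan()
    for name, ms in plan.collect_metrics():
        # name is the operator's display string; ms aggregates across partitions
        print(name, ms.output_rows, ms.elapsed_compute)
        # any metric can also be summed by name
        print(ms.sum_by_name("output_rows"))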

python/tests/test_plans.py

Lines changed: 151 additions & 1 deletion

@@ -16,7 +16,13 @@
 # under the License.

 import pytest
-from datafusion import ExecutionPlan, LogicalPlan, SessionContext
+from datafusion import (
+    ExecutionPlan,
+    LogicalPlan,
+    Metric,
+    MetricsSet,
+    SessionContext,
+)


 # Note: We must use CSV because memory tables are currently not supported for
@@ -40,3 +46,147 @@ def test_logical_plan_to_proto(ctx, df) -> None:
     execution_plan = ExecutionPlan.from_proto(ctx, execution_plan_bytes)

     assert str(original_execution_plan) == str(execution_plan)
+
+
+def test_execution_plan_metrics() -> None:
+    ctx = SessionContext()
+    ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')")
+    df = ctx.sql("SELECT * FROM t WHERE column1 > 1")
+
+    df.collect()
+    plan = df.execution_plan()
+
+    found_metrics = False
+
+    def _check(node):
+        nonlocal found_metrics
+        ms = node.metrics()
+        if ms is not None and ms.output_rows is not None and ms.output_rows > 0:
+            found_metrics = True
+        for child in node.children():
+            _check(child)
+
+    _check(plan)
+    assert found_metrics
+
+
+def test_metric_properties() -> None:
+    ctx = SessionContext()
+    ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')")
+    df = ctx.sql("SELECT * FROM t WHERE column1 > 1")
+
+    df.collect()
+    plan = df.execution_plan()
+
+    for _, ms in plan.collect_metrics():
+        for metric in ms.metrics():
+            assert isinstance(metric, Metric)
+            assert isinstance(metric.name, str)
+            assert len(metric.name) > 0
+            assert metric.partition is None or isinstance(metric.partition, int)
+            assert isinstance(metric.labels(), dict)
+        return
+    pytest.skip("No metrics found")
+
+
+def test_metrics_tree_walk() -> None:
+    ctx = SessionContext()
+    ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'a'), (4, 'b')")
+    df = ctx.sql("SELECT column2, COUNT(*) FROM t GROUP BY column2")
+
+    df.collect()
+    plan = df.execution_plan()
+
+    results = plan.collect_metrics()
+    assert len(results) >= 2
+    for name, ms in results:
+        assert isinstance(name, str)
+        assert isinstance(ms, MetricsSet)
+
+
+def test_no_metrics_before_execution() -> None:
+    ctx = SessionContext()
+    ctx.sql("CREATE TABLE t AS VALUES (1), (2), (3)")
+    df = ctx.sql("SELECT * FROM t")
+    plan = df.execution_plan()
+    ms = plan.metrics()
+    assert ms is None or ms.output_rows is None or ms.output_rows == 0
+
+
+def test_metrics_repr() -> None:
+    ctx = SessionContext()
+    ctx.sql("CREATE TABLE t AS VALUES (1), (2), (3)")
+    df = ctx.sql("SELECT * FROM t")
+
+    df.collect()
+    plan = df.execution_plan()
+
+    for _, ms in plan.collect_metrics():
+        r = repr(ms)
+        assert isinstance(r, str)
+        for metric in ms.metrics():
+            mr = repr(metric)
+            assert isinstance(mr, str)
+            assert len(mr) > 0
+        return
+    pytest.skip("No metrics found")
+
+
+def test_collect_partitioned_metrics() -> None:
+    ctx = SessionContext()
+    ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')")
+    df = ctx.sql("SELECT * FROM t WHERE column1 > 1")
+
+    partitions = df.collect_partitioned()
+    plan = df.execution_plan()
+    assert len(partitions) == plan.partition_count
+
+    # Metrics should be populated after collecting
+    found_metrics = False
+    for _, ms in plan.collect_metrics():
+        if ms.output_rows is not None and ms.output_rows > 0:
+            found_metrics = True
+    assert found_metrics
+
+
+def test_execute_stream_metrics() -> None:
+    ctx = SessionContext()
+    ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')")
+    df = ctx.sql("SELECT * FROM t WHERE column1 > 1")
+
+    stream = df.execute_stream()
+
+    # Consume the stream (iterates over RecordBatches)
+    batches = list(stream)
+    assert len(batches) >= 1
+
+    # Metrics should be populated after consuming the stream
+    plan = df.execution_plan()
+    found_metrics = False
+    for name, ms in plan.collect_metrics():
+        assert isinstance(name, str)
+        assert isinstance(ms, MetricsSet)
+        if ms.output_rows is not None and ms.output_rows > 0:
+            found_metrics = True
+    assert found_metrics
+
+
+def test_execute_stream_partitioned_metrics() -> None:
+    ctx = SessionContext()
+    ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')")
+    df = ctx.sql("SELECT * FROM t WHERE column1 > 1")
+
+    streams = df.execute_stream_partitioned()
+
+    # Consume all partition streams
+    for stream in streams:
+        for _ in stream:
+            pass
+
+    # Metrics should be populated (FilterExec reports output_rows)
+    plan = df.execution_plan()
+    found_metrics = False
+    for _, ms in plan.collect_metrics():
+        if ms.output_rows is not None and ms.output_rows > 0:
+            found_metrics = True
+    assert found_metrics

src/dataframe.rs

Lines changed: 41 additions & 4 deletions

@@ -48,6 +48,14 @@ use pyo3::prelude::*;
 use pyo3::pybacked::PyBackedStr;
 use pyo3::types::{PyCapsule, PyList, PyTuple, PyTupleMethods};

+use datafusion::physical_plan::{
+    ExecutionPlan as DFExecutionPlan,
+    collect as df_collect,
+    collect_partitioned as df_collect_partitioned,
+    execute_stream as df_execute_stream,
+    execute_stream_partitioned as df_execute_stream_partitioned,
+};
+
 use crate::errors::{PyDataFusionError, PyDataFusionResult, py_datafusion_err};
 use crate::expr::PyExpr;
 use crate::expr::sort_expr::{PySortExpr, to_sort_expressions};
@@ -290,6 +298,9 @@ pub struct PyDataFrame {

     // In IPython environment cache batches between __repr__ and _repr_html_ calls.
     batches: SharedCachedBatches,
+
+    // Cache the last physical plan so that metrics are available after execution.
+    last_plan: Arc<Mutex<Option<Arc<dyn DFExecutionPlan>>>>,
 }

 impl PyDataFrame {
@@ -298,6 +309,7 @@ impl PyDataFrame {
         Self {
             df: Arc::new(df),
             batches: Arc::new(Mutex::new(None)),
+            last_plan: Arc::new(Mutex::new(None)),
         }
     }

@@ -627,7 +639,12 @@
     /// Unless some order is specified in the plan, there is no
     /// guarantee of the order of the result.
     fn collect<'py>(&self, py: Python<'py>) -> PyResult<Vec<Bound<'py, PyAny>>> {
-        let batches = wait_for_future(py, self.df.as_ref().clone().collect())?
+        let df = self.df.as_ref().clone();
+        let plan = wait_for_future(py, df.create_physical_plan())?
+            .map_err(PyDataFusionError::from)?;
+        *self.last_plan.lock() = Some(Arc::clone(&plan));
+        let task_ctx = Arc::new(self.df.as_ref().task_ctx());
+        let batches = wait_for_future(py, df_collect(plan, task_ctx))?
             .map_err(PyDataFusionError::from)?;
         // cannot use PyResult<Vec<RecordBatch>> return type due to
         // https://github.com/PyO3/pyo3/issues/1813
@@ -643,7 +660,12 @@
     /// Executes this DataFrame and collects all results into a vector of vector of RecordBatch
     /// maintaining the input partitioning.
     fn collect_partitioned<'py>(&self, py: Python<'py>) -> PyResult<Vec<Vec<Bound<'py, PyAny>>>> {
-        let batches = wait_for_future(py, self.df.as_ref().clone().collect_partitioned())?
+        let df = self.df.as_ref().clone();
+        let plan = wait_for_future(py, df.create_physical_plan())?
+            .map_err(PyDataFusionError::from)?;
+        *self.last_plan.lock() = Some(Arc::clone(&plan));
+        let task_ctx = Arc::new(self.df.as_ref().task_ctx());
+        let batches = wait_for_future(py, df_collect_partitioned(plan, task_ctx))?
             .map_err(PyDataFusionError::from)?;

         batches
@@ -803,7 +825,13 @@
     }

     /// Get the execution plan for this `DataFrame`
+    ///
+    /// If the DataFrame has already been executed (e.g. via `collect()`),
+    /// returns the cached plan which includes populated metrics.
     fn execution_plan(&self, py: Python) -> PyDataFusionResult<PyExecutionPlan> {
+        if let Some(plan) = self.last_plan.lock().as_ref() {
+            return Ok(PyExecutionPlan::new(Arc::clone(plan)));
+        }
         let plan = wait_for_future(py, self.df.as_ref().clone().create_physical_plan())??;
         Ok(plan.into())
     }
@@ -1128,13 +1156,22 @@

     fn execute_stream(&self, py: Python) -> PyDataFusionResult<PyRecordBatchStream> {
         let df = self.df.as_ref().clone();
-        let stream = spawn_future(py, async move { df.execute_stream().await })?;
+        let plan = wait_for_future(py, df.create_physical_plan())??;
+        *self.last_plan.lock() = Some(Arc::clone(&plan));
+        let task_ctx = Arc::new(self.df.as_ref().task_ctx());
+        let stream = spawn_future(py, async move { df_execute_stream(plan, task_ctx) })?;
         Ok(PyRecordBatchStream::new(stream))
     }

     fn execute_stream_partitioned(&self, py: Python) -> PyResult<Vec<PyRecordBatchStream>> {
         let df = self.df.as_ref().clone();
-        let streams = spawn_future(py, async move { df.execute_stream_partitioned().await })?;
+        let plan = wait_for_future(py, df.create_physical_plan())?
+            .map_err(PyDataFusionError::from)?;
+        *self.last_plan.lock() = Some(Arc::clone(&plan));
+        let task_ctx = Arc::new(self.df.as_ref().task_ctx());
+        let streams = spawn_future(py, async move {
+            df_execute_stream_partitioned(plan, task_ctx)
+        })?;
         Ok(streams.into_iter().map(PyRecordBatchStream::new).collect())
     }
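Seen from Python, the cached last_plan is what makes the metrics API useful after execution: execution_plan() now returns the same plan instance that actually ran. A minimal sketch of the before/after contrast, assuming the bindings above are built:

    from datafusion import SessionContext

    ctx = SessionContext()
    ctx.sql("CREATE TABLE t AS VALUES (1), (2), (3)")
    df = ctx.sql("SELECT * FROM t")

    before = df.execution_plan()  # freshly created plan; nothing has run yet
    ms = before.metrics()
    assert ms is None or ms.output_rows in (None, 0)

    df.collect()                  # executes and caches the physical plan
    after = df.execution_plan()   # returns the cached, executed plan
    ms = after.metrics()
    if ms is not None:
        print(ms.output_rows)     # populated by the collect() above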
