Commit cfb88b3 (1 parent: 865bc39)

feat: add close method to RecordBatchStream for resource management

3 files changed (+82, -9 lines)

python/datafusion/record_batch.py

Lines changed: 13 additions & 0 deletions
@@ -54,6 +54,9 @@ class RecordBatchStream:
 
     These are typically the result of a
     :py:func:`~datafusion.dataframe.DataFrame.execute_stream` operation.
+
+    Call :py:meth:`close` when finished consuming the stream to avoid
+    lingering background tasks.
     """
 
     def __init__(self, record_batch_stream: df_internal.RecordBatchStream) -> None:
@@ -82,6 +85,16 @@ def __iter__(self) -> typing_extensions.Self:
         """Iterator function."""
         return self
 
+    def close(self) -> None:
+        """Close the stream and release associated resources.
+
+        This drains any remaining batches and allows the underlying
+        :class:`SessionContext` to be released. Call this when you are
+        done consuming the stream to avoid leaving tasks running in the
+        background.
+        """
+        self.rbs.close()
+
 
 def to_record_batch_stream(df: DataFrame) -> RecordBatchStream:
     """Convert a DataFrame into a RecordBatchStream.

python/tests/test_dataframe.py

Lines changed: 12 additions & 0 deletions
@@ -1598,6 +1598,18 @@ def test_stream_keeps_context_alive():
     assert table.equals(expected)
 
 
+def test_record_batch_stream_close(df):
+    stream = df.execute_stream()
+
+    batch = next(stream)
+    assert batch is not None
+
+    stream.close()
+
+    with pytest.raises(RuntimeError):
+        next(stream)
+
+
 def test_empty_to_arrow_table(df):
     # Convert empty datafusion dataframe to pyarrow Table
     pyarrow_table = df.limit(0).to_arrow_table()

src/record_batch.rs

Lines changed: 57 additions & 9 deletions
@@ -24,7 +24,7 @@ use datafusion::arrow::record_batch::RecordBatch;
 use datafusion::execution::context::SessionContext;
 use datafusion::physical_plan::SendableRecordBatchStream;
 use futures::StreamExt;
-use pyo3::exceptions::{PyStopAsyncIteration, PyStopIteration};
+use pyo3::exceptions::{PyRuntimeError, PyStopAsyncIteration, PyStopIteration};
 use pyo3::prelude::*;
 use pyo3::{pyclass, pymethods, PyObject, PyResult, Python};
 use tokio::sync::Mutex;
@@ -66,25 +66,33 @@ pub(crate) fn record_batches_to_pyarrow(
 
 #[pyclass(name = "RecordBatchStream", module = "datafusion", subclass)]
 pub struct PyRecordBatchStream {
-    stream: Arc<Mutex<SendableRecordBatchStream>>,
-    // Hold on to the session context to ensure the underlying execution
-    // environment remains alive for the duration of the stream
-    _ctx: Arc<SessionContext>,
+    /// The underlying stream. Wrapped in an [`Option`] so it can be
+    /// taken during [`close`](Self::close) and [`Drop`].
+    stream: Option<Arc<Mutex<SendableRecordBatchStream>>>,
+    /// Hold on to the session context to ensure the underlying execution
+    /// environment remains alive for the duration of the stream. This is
+    /// also wrapped in an [`Option`] so it can be released once the stream
+    /// is closed.
+    _ctx: Option<Arc<SessionContext>>,
 }
 
 impl PyRecordBatchStream {
     pub fn new(stream: SendableRecordBatchStream, ctx: Arc<SessionContext>) -> Self {
         Self {
-            stream: Arc::new(Mutex::new(stream)),
-            _ctx: ctx,
+            stream: Some(Arc::new(Mutex::new(stream))),
+            _ctx: Some(ctx),
         }
     }
 }
 
 #[pymethods]
 impl PyRecordBatchStream {
     fn next(&mut self, py: Python) -> PyResult<PyRecordBatch> {
-        let stream = self.stream.clone();
+        let stream = self
+            .stream
+            .as_ref()
+            .ok_or_else(|| PyRuntimeError::new_err("stream is closed"))?
+            .clone();
         wait_for_future(py, next_stream(stream, true))?
     }
 
@@ -93,7 +101,11 @@ impl PyRecordBatchStream {
     }
 
     fn __anext__<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
-        let stream = self.stream.clone();
+        let stream = self
+            .stream
+            .as_ref()
+            .ok_or_else(|| PyRuntimeError::new_err("stream is closed"))?
+            .clone();
         pyo3_async_runtimes::tokio::future_into_py(py, next_stream(stream, false))
     }
 
@@ -104,6 +116,21 @@ impl PyRecordBatchStream {
     fn __aiter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
         slf
     }
+
+    /// Close the underlying stream.
+    ///
+    /// This drains any remaining record batches and releases the
+    /// [`SessionContext`] associated with the stream. Users should call
+    /// this method when they are finished consuming the stream to avoid
+    /// leaving background tasks running.
+    fn close(&mut self, py: Python) -> PyResult<()> {
+        if let Some(stream) = self.stream.take() {
+            wait_for_future(py, close_stream(stream))?;
+        }
+        // Drop the context once the stream is drained
+        self._ctx.take();
+        Ok(())
+    }
 }
 
 /// Polls the next batch from a `SendableRecordBatchStream`, converting the `Option<Result<_>>` form.
@@ -132,3 +159,24 @@ async fn next_stream(
         Err(e) => Err(PyDataFusionError::from(e))?,
     }
 }
+
+/// Drain the remaining batches from a [`SendableRecordBatchStream`].
+///
+/// Errors are ignored as the stream is being closed.
+async fn close_stream(stream: Arc<Mutex<SendableRecordBatchStream>>) {
+    let mut stream = stream.lock().await;
+    while let Ok(Some(_)) = poll_next_batch(&mut stream).await {}
+}
+
+impl Drop for PyRecordBatchStream {
+    fn drop(&mut self) {
+        if let Some(stream) = self.stream.take() {
+            // Drain the stream while the context is still alive
+            Python::with_gil(|py| {
+                let _ = wait_for_future(py, close_stream(stream));
+            });
+        }
+        // Drop the context after the stream has been fully drained
+        self._ctx.take();
+    }
+}
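
Because __anext__ now takes the same guarded path as next, a closed stream also stops async consumers with the same RuntimeError. A hedged sketch of async-iteration usage, assuming the Python RecordBatchStream wrapper forwards __aiter__/__anext__ to this class (the SQL query is illustrative, not part of this commit):

    import asyncio
    from datafusion import SessionContext

    async def consume() -> None:
        ctx = SessionContext()
        df = ctx.sql("SELECT 1 AS a")   # illustrative query
        stream = df.execute_stream()
        try:
            async for batch in stream:  # drives __aiter__/__anext__
                ...                     # process each batch
        finally:
            stream.close()              # drains remaining batches, drops the ctx reference

    asyncio.run(consume())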
