@@ -24,7 +24,7 @@ use datafusion::arrow::record_batch::RecordBatch;
2424use datafusion:: execution:: context:: SessionContext ;
2525use datafusion:: physical_plan:: SendableRecordBatchStream ;
2626use futures:: StreamExt ;
27- use pyo3:: exceptions:: { PyStopAsyncIteration , PyStopIteration } ;
27+ use pyo3:: exceptions:: { PyRuntimeError , PyStopAsyncIteration , PyStopIteration } ;
2828use pyo3:: prelude:: * ;
2929use pyo3:: { pyclass, pymethods, PyObject , PyResult , Python } ;
3030use tokio:: sync:: Mutex ;
@@ -66,25 +66,33 @@ pub(crate) fn record_batches_to_pyarrow(
6666
6767#[ pyclass( name = "RecordBatchStream" , module = "datafusion" , subclass) ]
6868pub struct PyRecordBatchStream {
69- stream : Arc < Mutex < SendableRecordBatchStream > > ,
70- // Hold on to the session context to ensure the underlying execution
71- // environment remains alive for the duration of the stream
72- _ctx : Arc < SessionContext > ,
69+ /// The underlying stream. Wrapped in an [`Option`] so it can be
70+ /// taken during [`close`](Self::close) and [`Drop`].
71+ stream : Option < Arc < Mutex < SendableRecordBatchStream > > > ,
72+ /// Hold on to the session context to ensure the underlying execution
73+ /// environment remains alive for the duration of the stream. This is
74+ /// also wrapped in an [`Option`] so it can be released once the stream
75+ /// is closed.
76+ _ctx : Option < Arc < SessionContext > > ,
7377}
7478
7579impl PyRecordBatchStream {
7680 pub fn new ( stream : SendableRecordBatchStream , ctx : Arc < SessionContext > ) -> Self {
7781 Self {
78- stream : Arc :: new ( Mutex :: new ( stream) ) ,
79- _ctx : ctx,
82+ stream : Some ( Arc :: new ( Mutex :: new ( stream) ) ) ,
83+ _ctx : Some ( ctx) ,
8084 }
8185 }
8286}
8387
8488#[ pymethods]
8589impl PyRecordBatchStream {
8690 fn next ( & mut self , py : Python ) -> PyResult < PyRecordBatch > {
87- let stream = self . stream . clone ( ) ;
91+ let stream = self
92+ . stream
93+ . as_ref ( )
94+ . ok_or_else ( || PyRuntimeError :: new_err ( "stream is closed" ) ) ?
95+ . clone ( ) ;
8896 wait_for_future ( py, next_stream ( stream, true ) ) ?
8997 }
9098
@@ -93,7 +101,11 @@ impl PyRecordBatchStream {
93101 }
94102
95103 fn __anext__ < ' py > ( & ' py self , py : Python < ' py > ) -> PyResult < Bound < ' py , PyAny > > {
96- let stream = self . stream . clone ( ) ;
104+ let stream = self
105+ . stream
106+ . as_ref ( )
107+ . ok_or_else ( || PyRuntimeError :: new_err ( "stream is closed" ) ) ?
108+ . clone ( ) ;
97109 pyo3_async_runtimes:: tokio:: future_into_py ( py, next_stream ( stream, false ) )
98110 }
99111
@@ -104,6 +116,21 @@ impl PyRecordBatchStream {
104116 fn __aiter__ ( slf : PyRef < ' _ , Self > ) -> PyRef < ' _ , Self > {
105117 slf
106118 }
119+
120+ /// Close the underlying stream.
121+ ///
122+ /// This drains any remaining record batches and releases the
123+ /// [`SessionContext`] associated with the stream. Users should call
124+ /// this method when they are finished consuming the stream to avoid
125+ /// leaving background tasks running.
126+ fn close ( & mut self , py : Python ) -> PyResult < ( ) > {
127+ if let Some ( stream) = self . stream . take ( ) {
128+ wait_for_future ( py, close_stream ( stream) ) ?;
129+ }
130+ // Drop the context once the stream is drained
131+ self . _ctx . take ( ) ;
132+ Ok ( ( ) )
133+ }
107134}
108135
109136/// Polls the next batch from a `SendableRecordBatchStream`, converting the `Option<Result<_>>` form.
@@ -132,3 +159,24 @@ async fn next_stream(
132159 Err ( e) => Err ( PyDataFusionError :: from ( e) ) ?,
133160 }
134161}
162+
163+ /// Drain the remaining batches from a [`SendableRecordBatchStream`].
164+ ///
165+ /// Errors are ignored as the stream is being closed.
166+ async fn close_stream ( stream : Arc < Mutex < SendableRecordBatchStream > > ) {
167+ let mut stream = stream. lock ( ) . await ;
168+ while let Ok ( Some ( _) ) = poll_next_batch ( & mut stream) . await { }
169+ }
170+
171+ impl Drop for PyRecordBatchStream {
172+ fn drop ( & mut self ) {
173+ if let Some ( stream) = self . stream . take ( ) {
174+ // Drain the stream while the context is still alive
175+ Python :: with_gil ( |py| {
176+ let _ = wait_for_future ( py, close_stream ( stream) ) ;
177+ } ) ;
178+ }
179+ // Drop the context after the stream has been fully drained
180+ self . _ctx . take ( ) ;
181+ }
182+ }
0 commit comments