Skip to content

Commit 4cb5fe5

Browse files
committed
Refactor ArrowStreamReader to use async/await for improved signal handling and responsiveness
1 parent 8c3ebaf commit 4cb5fe5

File tree

1 file changed

+24
-1
lines changed

1 file changed

+24
-1
lines changed

src/dataframe.rs

Lines changed: 24 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -42,6 +42,8 @@ use pyo3::exceptions::PyValueError;
4242
use pyo3::prelude::*;
4343
use pyo3::pybacked::PyBackedStr;
4444
use pyo3::types::{PyCapsule, PyList, PyTuple, PyTupleMethods};
45+
use tokio::task::JoinHandle;
46+
use tokio::time::{sleep, Duration};
4547

4648
use crate::catalog::PyTable;
4749
use crate::errors::{py_datafusion_err, PyDataFusionError};
@@ -1011,8 +1013,29 @@ impl Iterator for ArrowStreamReader {
10111013
type Item = Result<RecordBatch, ArrowError>;
10121014

10131015
fn next(&mut self) -> Option<Self::Item> {
1016+
const INTERVAL_CHECK_SIGNALS: Duration = Duration::from_millis(1_000);
10141017
let rt = &get_tokio_runtime().0;
1015-
match rt.block_on(crate::record_batch::pull_next_batch(&mut self.stream)) {
1018+
let fut = self.stream.next();
1019+
1020+
let result = Python::with_gil(|py| {
1021+
py.allow_threads(|| {
1022+
rt.block_on(async {
1023+
tokio::pin!(fut);
1024+
loop {
1025+
tokio::select! {
1026+
res = &mut fut => break res,
1027+
_ = sleep(INTERVAL_CHECK_SIGNALS) => {
1028+
if let Err(err) = Python::with_gil(|py| py.check_signals()) {
1029+
break Some(Err(to_datafusion_err(err)));
1030+
}
1031+
}
1032+
}
1033+
}
1034+
})
1035+
})
1036+
});
1037+
1038+
match result {
10161039
Some(Ok(batch)) => {
10171040
let batch = if self.project {
10181041
match record_batch_into_schema(batch, self.schema.as_ref()) {

0 commit comments

Comments
 (0)