Skip to content

Commit 60ff7f2

Browse files
committed
fix: enhance error handling in async wait_for_future function
1 parent 0cc9b0a commit 60ff7f2

File tree

4 files changed

+62
-24
lines changed

4 files changed

+62
-24
lines changed

src/catalog.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ use std::sync::Arc;
2121
use pyo3::exceptions::PyKeyError;
2222
use pyo3::prelude::*;
2323

24-
use crate::errors::{PyDataFusionError, PyDataFusionResult};
24+
use crate::errors::{py_datafusion_err, PyDataFusionError, PyDataFusionResult};
2525
use crate::utils::wait_for_future;
2626
use datafusion::{
2727
arrow::pyarrow::ToPyArrow,
@@ -97,7 +97,10 @@ impl PyDatabase {
9797
}
9898

9999
fn table(&self, name: &str, py: Python) -> PyDataFusionResult<PyTable> {
100-
if let Some(table) = wait_for_future(py, self.database.table(name))? {
100+
let table_option = wait_for_future(py, self.database.table(name))
101+
.map_err(py_datafusion_err)?
102+
.map_err(PyDataFusionError::from)?;
103+
if let Some(table) = table_option {
101104
Ok(PyTable::new(table))
102105
} else {
103106
Err(PyDataFusionError::Common(format!(

src/context.rs

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ use pyo3::prelude::*;
3434
use crate::catalog::{PyCatalog, PyTable};
3535
use crate::dataframe::PyDataFrame;
3636
use crate::dataset::Dataset;
37-
use crate::errors::{py_datafusion_err, PyDataFusionResult};
37+
use crate::errors::{py_datafusion_err, PyDataFusionError, PyDataFusionResult};
3838
use crate::expr::sort_expr::PySortExpr;
3939
use crate::physical_plan::PyExecutionPlan;
4040
use crate::record_batch::PyRecordBatchStream;
@@ -375,7 +375,7 @@ impl PySessionContext {
375375
None => {
376376
let state = self.ctx.state();
377377
let schema = options.infer_schema(&state, &table_path);
378-
wait_for_future(py, schema)?
378+
wait_for_future(py, schema)?.map_err(PyDataFusionError::from)?
379379
}
380380
};
381381
let config = ListingTableConfig::new(table_path)
@@ -400,7 +400,7 @@ impl PySessionContext {
400400
/// Returns a PyDataFrame whose plan corresponds to the SQL statement.
401401
pub fn sql(&mut self, query: &str, py: Python) -> PyDataFusionResult<PyDataFrame> {
402402
let result = self.ctx.sql(query);
403-
let df = wait_for_future(py, result)?;
403+
let df = wait_for_future(py, result)?.map_err(PyDataFusionError::from)?;
404404
Ok(PyDataFrame::new(df))
405405
}
406406

@@ -417,7 +417,7 @@ impl PySessionContext {
417417
SQLOptions::new()
418418
};
419419
let result = self.ctx.sql_with_options(query, options);
420-
let df = wait_for_future(py, result)?;
420+
let df = wait_for_future(py, result)?.map_err(PyDataFusionError::from)?;
421421
Ok(PyDataFrame::new(df))
422422
}
423423

@@ -451,7 +451,8 @@ impl PySessionContext {
451451

452452
self.ctx.register_table(&*table_name, Arc::new(table))?;
453453

454-
let table = wait_for_future(py, self._table(&table_name))?;
454+
let table =
455+
wait_for_future(py, self._table(&table_name))?.map_err(PyDataFusionError::from)?;
455456

456457
let df = PyDataFrame::new(table);
457458
Ok(df)
@@ -826,6 +827,7 @@ impl PySessionContext {
826827

827828
pub fn table(&self, name: &str, py: Python) -> PyResult<PyDataFrame> {
828829
let x = wait_for_future(py, self.ctx.table(name))
830+
.map_err(|e| PyKeyError::new_err(e.to_string()))?
829831
.map_err(|e| PyKeyError::new_err(e.to_string()))?;
830832
Ok(PyDataFrame::new(x))
831833
}
@@ -865,10 +867,10 @@ impl PySessionContext {
865867
let df = if let Some(schema) = schema {
866868
options.schema = Some(&schema.0);
867869
let result = self.ctx.read_json(path, options);
868-
wait_for_future(py, result)?
870+
wait_for_future(py, result)?.map_err(PyDataFusionError::from)?
869871
} else {
870872
let result = self.ctx.read_json(path, options);
871-
wait_for_future(py, result)?
873+
wait_for_future(py, result)?.map_err(PyDataFusionError::from)?
872874
};
873875
Ok(PyDataFrame::new(df))
874876
}
@@ -915,13 +917,13 @@ impl PySessionContext {
915917
let paths = path.extract::<Vec<String>>()?;
916918
let paths = paths.iter().map(|p| p as &str).collect::<Vec<&str>>();
917919
let result = self.ctx.read_csv(paths, options);
918-
let df = PyDataFrame::new(wait_for_future(py, result)?);
919-
Ok(df)
920+
let df = wait_for_future(py, result)?.map_err(PyDataFusionError::from)?;
921+
Ok(PyDataFrame::new(df))
920922
} else {
921923
let path = path.extract::<String>()?;
922924
let result = self.ctx.read_csv(path, options);
923-
let df = PyDataFrame::new(wait_for_future(py, result)?);
924-
Ok(df)
925+
let df = wait_for_future(py, result)?.map_err(PyDataFusionError::from)?;
926+
Ok(PyDataFrame::new(df))
925927
}
926928
}
927929

@@ -958,7 +960,7 @@ impl PySessionContext {
958960
.collect();
959961

960962
let result = self.ctx.read_parquet(path, options);
961-
let df = PyDataFrame::new(wait_for_future(py, result)?);
963+
let df = PyDataFrame::new(wait_for_future(py, result)?.map_err(PyDataFusionError::from)?);
962964
Ok(df)
963965
}
964966

@@ -978,10 +980,10 @@ impl PySessionContext {
978980
let df = if let Some(schema) = schema {
979981
options.schema = Some(&schema.0);
980982
let read_future = self.ctx.read_avro(path, options);
981-
wait_for_future(py, read_future)?
983+
wait_for_future(py, read_future)?.map_err(PyDataFusionError::from)?
982984
} else {
983985
let read_future = self.ctx.read_avro(path, options);
984-
wait_for_future(py, read_future)?
986+
wait_for_future(py, read_future)?.map_err(PyDataFusionError::from)?
985987
};
986988
Ok(PyDataFrame::new(df))
987989
}
@@ -1021,8 +1023,10 @@ impl PySessionContext {
10211023
let plan = plan.plan.clone();
10221024
let fut: JoinHandle<datafusion::common::Result<SendableRecordBatchStream>> =
10231025
rt.spawn(async move { plan.execute(part, Arc::new(ctx)) });
1024-
let stream = wait_for_future(py, fut).map_err(py_datafusion_err)?;
1025-
Ok(PyRecordBatchStream::new(stream?))
1026+
let stream = wait_for_future(py, fut)
1027+
.map_err(py_datafusion_err)?
1028+
.map_err(PyDataFusionError::from)?;
1029+
Ok(PyRecordBatchStream::new(stream))
10261030
}
10271031
}
10281032

src/dataframe.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,8 @@ impl PyDataFrame {
233233
let (batches, has_more) = wait_for_future(
234234
py,
235235
collect_record_batches_to_display(self.df.as_ref().clone(), config),
236-
)?;
236+
)?
237+
.map_err(PyDataFusionError::from)?;
237238
if batches.is_empty() {
238239
// This should not be reached, but do it for safety since we index into the vector below
239240
return Ok("No data to display".to_string());
@@ -256,7 +257,8 @@ impl PyDataFrame {
256257
let (batches, has_more) = wait_for_future(
257258
py,
258259
collect_record_batches_to_display(self.df.as_ref().clone(), config),
259-
)?;
260+
)?
261+
.map_err(PyDataFusionError::from)?;
260262
if batches.is_empty() {
261263
// This should not be reached, but do it for safety since we index into the vector below
262264
return Ok("No data to display".to_string());
@@ -288,7 +290,7 @@ impl PyDataFrame {
288290
/// Calculate summary statistics for a DataFrame
289291
fn describe(&self, py: Python) -> PyDataFusionResult<Self> {
290292
let df = self.df.as_ref().clone();
291-
let stat_df = wait_for_future(py, df.describe())?;
293+
let stat_df = wait_for_future(py, df.describe())?.map_err(PyDataFusionError::from)?;
292294
Ok(Self::new(stat_df))
293295
}
294296

src/utils.rs

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@ use pyo3::prelude::*;
2626
use pyo3::types::PyCapsule;
2727
use std::future::Future;
2828
use std::sync::OnceLock;
29+
use std::time::Duration;
2930
use tokio::runtime::Runtime;
31+
use tokio::time::timeout;
3032

3133
/// Utility to get the Tokio Runtime from Python
3234
#[inline]
@@ -47,14 +49,41 @@ pub(crate) fn get_global_ctx() -> &'static SessionContext {
4749
CTX.get_or_init(SessionContext::new)
4850
}
4951

50-
/// Utility to collect rust futures with GIL released
51-
pub fn wait_for_future<F>(py: Python, f: F) -> F::Output
52+
/// Utility to collect rust futures with GIL released and interrupt support
53+
pub fn wait_for_future<F>(py: Python, f: F) -> PyResult<F::Output>
5254
where
5355
F: Future + Send,
5456
F::Output: Send,
5557
{
5658
let runtime: &Runtime = &get_tokio_runtime().0;
57-
py.allow_threads(|| runtime.block_on(f))
59+
60+
// Spawn the task so we can poll it with timeouts
61+
let mut handle = runtime.spawn(f);
62+
63+
// Release the GIL and poll the future with periodic signal checks
64+
py.allow_threads(|| {
65+
loop {
66+
// Poll the future with a timeout to allow periodic signal checking
67+
match runtime.block_on(timeout(Duration::from_millis(100), &mut handle)) {
68+
Ok(result) => {
69+
return result.map_err(|e| {
70+
PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
71+
"Task failed: {}",
72+
e
73+
))
74+
});
75+
}
76+
Err(_) => {
77+
// Timeout occurred, check for Python signals
78+
// We need to re-acquire the GIL temporarily to check signals
79+
if let Err(e) = Python::with_gil(|py| py.check_signals()) {
80+
return Err(e);
81+
}
82+
// Continue polling if no signal was received
83+
}
84+
}
85+
}
86+
})
5887
}
5988

6089
pub(crate) fn parse_volatility(value: &str) -> PyDataFusionResult<Volatility> {

0 commit comments

Comments
 (0)