Skip to content

Commit a24e280

Browse files
committed
fix: add utility to get Tokio Runtime with time enabled and update wait_for_future to use it
1 parent fb814cc commit a24e280

File tree

1 file changed

+33
-9
lines changed

1 file changed

+33
-9
lines changed

src/utils.rs

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -42,20 +42,44 @@ pub(crate) fn get_tokio_runtime() -> &'static TokioRuntime {
4242
RUNTIME.get_or_init(|| TokioRuntime(tokio::runtime::Runtime::new().unwrap()))
4343
}
4444

45+
/// Utility to get a Tokio Runtime with time explicitly enabled
46+
#[inline]
47+
pub(crate) fn get_tokio_runtime_with_time() -> &'static TokioRuntime {
48+
static RUNTIME_WITH_TIME: OnceLock<TokioRuntime> = OnceLock::new();
49+
RUNTIME_WITH_TIME.get_or_init(|| {
50+
let runtime = tokio::runtime::Builder::new_multi_thread()
51+
.enable_time()
52+
.build()
53+
.unwrap();
54+
55+
TokioRuntime(runtime)
56+
})
57+
}
58+
4559
/// Utility to get the Global Datafussion CTX
4660
#[inline]
4761
pub(crate) fn get_global_ctx() -> &'static SessionContext {
4862
static CTX: OnceLock<SessionContext> = OnceLock::new();
4963
CTX.get_or_init(SessionContext::new)
5064
}
5165

66+
/// Gets the Tokio runtime with time enabled and enters it, returning both the runtime and enter guard
67+
/// This helps ensure that we don't forget to call enter() after getting the runtime
68+
#[inline]
69+
pub(crate) fn get_and_enter_tokio_runtime(
70+
) -> (&'static Runtime, tokio::runtime::EnterGuard<'static>) {
71+
let runtime = &get_tokio_runtime_with_time().0;
72+
let enter_guard = runtime.enter();
73+
(runtime, enter_guard)
74+
}
75+
5276
/// Utility to collect rust futures with GIL released and interrupt support
5377
pub fn wait_for_future<F>(py: Python, f: F) -> PyResult<F::Output>
5478
where
5579
F: Future + Send + 'static,
5680
F::Output: Send + 'static,
5781
{
58-
let runtime: &Runtime = &get_tokio_runtime().0;
82+
let (runtime, _enter_guard) = get_and_enter_tokio_runtime();
5983

6084
// Spawn the task so we can poll it with timeouts
6185
let mut handle = runtime.spawn(f);
@@ -65,21 +89,21 @@ where
6589
loop {
6690
// Poll the future with a timeout to allow periodic signal checking
6791
match runtime.block_on(timeout(Duration::from_millis(100), &mut handle)) {
68-
Ok(result) => {
69-
return result.map_err(|e| {
92+
Ok(join_result) => {
93+
// The inner task has completed before timeout
94+
return join_result.map_err(|e| {
7095
PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
7196
"Task failed: {}",
7297
e
7398
))
7499
});
75100
}
76-
Err(_) => {
77-
// Timeout occurred, check for Python signals
78-
// We need to re-acquire the GIL temporarily to check signals
79-
if let Err(e) = Python::with_gil(|py| py.check_signals()) {
80-
return Err(e);
101+
Err(_elapsed) => {
102+
// 100 ms elapsed without task completion → check Python signals
103+
if let Err(py_exc) = Python::with_gil(|py| py.check_signals()) {
104+
return Err(py_exc);
81105
}
82-
// Continue polling if no signal was received
106+
// Loop again, reintroducing another 100 ms timeout slice
83107
}
84108
}
85109
}

0 commit comments

Comments
 (0)