Skip to content

Commit 8c3ebaf

Browse files
committed
Refactor async execution in DataFusion by replacing manual JoinHandle spawning with a spawn_and_wait utility, improving readability and maintainability.
1 parent e621b64 commit 8c3ebaf

File tree

4 files changed

+30
-31
lines changed

4 files changed

+30
-31
lines changed

src/context.rs

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ use pyo3::prelude::*;
3434
use crate::catalog::{PyCatalog, PyTable, RustWrappedPyCatalogProvider};
3535
use crate::dataframe::PyDataFrame;
3636
use crate::dataset::Dataset;
37-
use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionResult};
37+
use crate::errors::{py_datafusion_err, PyDataFusionResult};
3838
use crate::expr::sort_expr::PySortExpr;
3939
use crate::physical_plan::PyExecutionPlan;
4040
use crate::record_batch::PyRecordBatchStream;
@@ -45,7 +45,7 @@ use crate::udaf::PyAggregateUDF;
4545
use crate::udf::PyScalarUDF;
4646
use crate::udtf::PyTableFunction;
4747
use crate::udwf::PyWindowUDF;
48-
use crate::utils::{get_global_ctx, get_tokio_runtime, validate_pycapsule, wait_for_future};
48+
use crate::utils::{get_global_ctx, spawn_and_wait, validate_pycapsule, wait_for_future};
4949
use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef};
5050
use datafusion::arrow::pyarrow::PyArrowType;
5151
use datafusion::arrow::record_batch::RecordBatch;
@@ -66,15 +66,13 @@ use datafusion::execution::disk_manager::DiskManagerMode;
6666
use datafusion::execution::memory_pool::{FairSpillPool, GreedyMemoryPool, UnboundedMemoryPool};
6767
use datafusion::execution::options::ReadOptions;
6868
use datafusion::execution::runtime_env::RuntimeEnvBuilder;
69-
use datafusion::physical_plan::SendableRecordBatchStream;
7069
use datafusion::prelude::{
7170
AvroReadOptions, CsvReadOptions, DataFrame, NdJsonReadOptions, ParquetReadOptions,
7271
};
7372
use datafusion_ffi::catalog_provider::{FFI_CatalogProvider, ForeignCatalogProvider};
7473
use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider};
7574
use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple, PyType};
7675
use pyo3::IntoPyObjectExt;
77-
use tokio::task::JoinHandle;
7876

7977
/// Configuration options for a SessionContext
8078
#[pyclass(name = "SessionConfig", module = "datafusion", subclass)]
@@ -1132,12 +1130,8 @@ impl PySessionContext {
11321130
py: Python,
11331131
) -> PyDataFusionResult<PyRecordBatchStream> {
11341132
let ctx: TaskContext = TaskContext::from(&self.ctx.state());
1135-
// create a Tokio runtime to run the async code
1136-
let rt = &get_tokio_runtime().0;
11371133
let plan = plan.plan.clone();
1138-
let fut: JoinHandle<datafusion::common::Result<SendableRecordBatchStream>> =
1139-
rt.spawn(async move { plan.execute(part, Arc::new(ctx)) });
1140-
let stream = wait_for_future(py, async { fut.await.map_err(to_datafusion_err) })???;
1134+
let stream = spawn_and_wait(py, async move { plan.execute(part, Arc::new(ctx)) })?;
11411135
Ok(PyRecordBatchStream::new(stream))
11421136
}
11431137
}

src/dataframe.rs

Lines changed: 7 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -42,16 +42,16 @@ use pyo3::exceptions::PyValueError;
4242
use pyo3::prelude::*;
4343
use pyo3::pybacked::PyBackedStr;
4444
use pyo3::types::{PyCapsule, PyList, PyTuple, PyTupleMethods};
45-
use tokio::task::JoinHandle;
4645

4746
use crate::catalog::PyTable;
48-
use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError};
47+
use crate::errors::{py_datafusion_err, PyDataFusionError};
4948
use crate::expr::sort_expr::to_sort_expressions;
5049
use crate::physical_plan::PyExecutionPlan;
5150
use crate::record_batch::PyRecordBatchStream;
5251
use crate::sql::logical::PyLogicalPlan;
5352
use crate::utils::{
54-
get_tokio_runtime, is_ipython_env, py_obj_to_scalar_value, validate_pycapsule, wait_for_future,
53+
get_tokio_runtime, is_ipython_env, py_obj_to_scalar_value, spawn_and_wait, validate_pycapsule,
54+
wait_for_future,
5555
};
5656
use crate::{
5757
errors::PyDataFusionResult,
@@ -880,11 +880,8 @@ impl PyDataFrame {
880880
requested_schema: Option<Bound<'py, PyCapsule>>,
881881
) -> PyDataFusionResult<Bound<'py, PyCapsule>> {
882882
// execute query lazily using a stream
883-
let rt = &get_tokio_runtime().0;
884883
let df = self.df.as_ref().clone();
885-
let fut: JoinHandle<datafusion::common::Result<SendableRecordBatchStream>> =
886-
rt.spawn(async move { df.execute_stream().await });
887-
let stream = wait_for_future(py, async { fut.await.map_err(to_datafusion_err) })???;
884+
let stream = spawn_and_wait(py, async move { df.execute_stream().await })?;
888885

889886
// Determine the schema and handle optional projection
890887
let stream_schema = stream.schema();
@@ -911,24 +908,14 @@ impl PyDataFrame {
911908
}
912909

913910
fn execute_stream(&self, py: Python) -> PyDataFusionResult<PyRecordBatchStream> {
914-
// create a Tokio runtime to run the async code
915-
let rt = &get_tokio_runtime().0;
916911
let df = self.df.as_ref().clone();
917-
let fut: JoinHandle<datafusion::common::Result<SendableRecordBatchStream>> =
918-
rt.spawn(async move { df.execute_stream().await });
919-
let stream = wait_for_future(py, async { fut.await.map_err(to_datafusion_err) })???;
912+
let stream = spawn_and_wait(py, async move { df.execute_stream().await })?;
920913
Ok(PyRecordBatchStream::new(stream))
921914
}
922915

923916
fn execute_stream_partitioned(&self, py: Python) -> PyResult<Vec<PyRecordBatchStream>> {
924-
// create a Tokio runtime to run the async code
925-
let rt = &get_tokio_runtime().0;
926917
let df = self.df.as_ref().clone();
927-
let fut: JoinHandle<datafusion::common::Result<Vec<SendableRecordBatchStream>>> =
928-
rt.spawn(async move { df.execute_stream_partitioned().await });
929-
let stream = wait_for_future(py, async { fut.await.map_err(to_datafusion_err) })?
930-
.map_err(py_datafusion_err)?
931-
.map_err(py_datafusion_err)?;
918+
let stream = spawn_and_wait(py, async move { df.execute_stream_partitioned().await })?;
932919

933920
Ok(stream.into_iter().map(PyRecordBatchStream::new).collect())
934921
}
@@ -1025,7 +1012,7 @@ impl Iterator for ArrowStreamReader {
10251012

10261013
fn next(&mut self) -> Option<Self::Item> {
10271014
let rt = &get_tokio_runtime().0;
1028-
match rt.block_on(self.stream.next()) {
1015+
match rt.block_on(crate::record_batch::pull_next_batch(&mut self.stream)) {
10291016
Some(Ok(batch)) => {
10301017
let batch = if self.project {
10311018
match record_batch_into_schema(batch, self.schema.as_ref()) {

src/record_batch.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,12 @@ impl PyRecordBatchStream {
5959
}
6060
}
6161

62+
pub(crate) async fn pull_next_batch(
63+
stream: &mut SendableRecordBatchStream,
64+
) -> Option<datafusion::common::Result<RecordBatch>> {
65+
stream.next().await
66+
}
67+
6268
#[pymethods]
6369
impl PyRecordBatchStream {
6470
fn next(&mut self, py: Python) -> PyResult<PyRecordBatch> {
@@ -89,7 +95,7 @@ async fn next_stream(
8995
sync: bool,
9096
) -> PyResult<PyRecordBatch> {
9197
let mut stream = stream.lock().await;
92-
match stream.next().await {
98+
match pull_next_batch(&mut stream).await {
9399
Some(Ok(batch)) => Ok(batch.into()),
94100
Some(Err(e)) => Err(PyDataFusionError::from(e))?,
95101
None => {

src/utils.rs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
use crate::{
1919
common::data_type::PyScalarValue,
20-
errors::{PyDataFusionError, PyDataFusionResult},
20+
errors::{to_datafusion_err, PyDataFusionError, PyDataFusionResult},
2121
TokioRuntime,
2222
};
2323
use datafusion::{
@@ -84,6 +84,18 @@ where
8484
})
8585
}
8686

87+
pub fn spawn_and_wait<F, T>(py: Python, fut: F) -> PyDataFusionResult<T>
88+
where
89+
F: Future<Output = datafusion::common::Result<T>> + Send + 'static,
90+
T: Send + 'static,
91+
{
92+
let rt = &get_tokio_runtime().0;
93+
let handle = rt.spawn(fut);
94+
Ok(wait_for_future(py, async {
95+
handle.await.map_err(to_datafusion_err)
96+
})???)
97+
}
98+
8799
pub(crate) fn parse_volatility(value: &str) -> PyDataFusionResult<Volatility> {
88100
Ok(match value {
89101
"immutable" => Volatility::Immutable,

0 commit comments

Comments (0)