
Commit a1ba264

Enhance DataFrame streaming to preserve partition order and update related tests
1 parent f2e41cd commit a1ba264

3 files changed: +52 -38 lines changed

python/datafusion/dataframe.py

Lines changed: 4 additions & 2 deletions

@@ -1117,7 +1117,8 @@ def __arrow_c_stream__(self, requested_schema: object | None = None) -> object:
             Arrow PyCapsule object representing an ``ArrowArrayStream``.
         """
         # ``DataFrame.__arrow_c_stream__`` in the Rust extension leverages
-        # ``execute_stream`` under the hood to stream batches one at a time.
+        # ``execute_stream_partitioned`` under the hood to stream batches while
+        # preserving the original partition order.
         return self.df.__arrow_c_stream__(requested_schema)
 
     def __iter__(self) -> Iterator[pa.RecordBatch]:
@@ -1126,7 +1127,8 @@ def __iter__(self) -> Iterator[pa.RecordBatch]:
         This implementation streams record batches via the Arrow C Stream
         interface, allowing callers such as :func:`pyarrow.Table.from_batches` to
         consume results lazily. The DataFrame is executed using DataFusion's
-        streaming APIs so ``collect`` is never invoked.
+        partitioned streaming APIs so ``collect`` is never invoked and batch
+        order across partitions is preserved.
         """
         import pyarrow as pa
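Taken together, the two docstrings describe the user-visible contract: a DataFrame can be consumed lazily as an iterator of record batches, now yielded in partition order. A minimal consumer sketch (the context setup and column values here are illustrative, mirroring the test below):

    import pyarrow as pa
    from datafusion import SessionContext

    # Two single-batch partitions; each inner list becomes one partition.
    ctx = SessionContext()
    batch1 = pa.RecordBatch.from_pydict({"a": [1, 2]})
    batch2 = pa.RecordBatch.from_pydict({"a": [3, 4]})
    df = ctx.create_dataframe([[batch1], [batch2]])

    # ``__iter__`` streams batches via the Arrow C Stream interface, so the
    # table is assembled lazily; batch1's rows now always precede batch2's.
    table = pa.Table.from_batches(df)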
python/tests/test_dataframe.py

Lines changed: 4 additions & 2 deletions

@@ -1603,9 +1603,11 @@ def test_arrow_c_stream_to_table(fail_collect):
     df = ctx.create_dataframe([[batch1], [batch2]])
 
     table = pa.Table.from_batches(df)
-    expected = pa.Table.from_batches([batch1, batch2])
+    batches = table.to_batches()
 
-    assert table.equals(expected)
+    assert len(batches) == 2
+    assert batches[0].equals(batch1)
+    assert batches[1].equals(batch2)
     assert table.schema == df.schema()
     assert table.column("a").num_chunks == 2
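Asserting on ``table.to_batches()`` element by element is stricter than the old ``table.equals(expected)``: it pins down which partition produced which chunk, not just the concatenated values. An illustrative way to inspect the chunk layout the test relies on:

    # Illustrative: one chunk per streamed partition batch, in partition order.
    for i, chunk in enumerate(table.column("a").chunks):
        print(f"chunk {i}: {chunk.to_pylist()}")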

src/dataframe.rs

Lines changed: 44 additions & 34 deletions

@@ -383,50 +383,59 @@ impl PyDataFrame {
         Ok(html_str)
     }
 }
-/// Synchronous wrapper around a [`SendableRecordBatchStream`] used for
-/// the `__arrow_c_stream__` implementation.
+
+/// Synchronous wrapper around partitioned [`SendableRecordBatchStream`]s used
+/// for the `__arrow_c_stream__` implementation.
 ///
-/// It uses `runtime.block_on` to consume the underlying async stream,
-/// providing synchronous iteration. When a `projection` is set, each
-/// batch is converted via `record_batch_into_schema` to apply schema
-/// changes per batch.
-struct DataFrameStreamReader {
-    stream: SendableRecordBatchStream,
+/// It drains each partition's stream sequentially, yielding record batches in
+/// their original partition order. When a `projection` is set, each batch is
+/// converted via `record_batch_into_schema` to apply schema changes per batch.
+struct PartitionedDataFrameStreamReader {
+    streams: Vec<SendableRecordBatchStream>,
     schema: SchemaRef,
     projection: Option<SchemaRef>,
+    current: usize,
 }
 
-impl Iterator for DataFrameStreamReader {
+impl Iterator for PartitionedDataFrameStreamReader {
     type Item = Result<RecordBatch, ArrowError>;
 
     fn next(&mut self) -> Option<Self::Item> {
-        // Use wait_for_future to poll the underlying async stream while
-        // respecting Python signal handling (e.g. ``KeyboardInterrupt``).
-        // This mirrors the behaviour of other synchronous wrappers and
-        // prevents blocking indefinitely when a Python interrupt is raised.
-        let fut = poll_next_batch(&mut self.stream);
-        let result = Python::with_gil(|py| wait_for_future(py, fut));
-
-        match result {
-            Ok(Ok(Some(batch))) => {
-                let batch = if let Some(ref schema) = self.projection {
-                    match record_batch_into_schema(batch, schema.as_ref()) {
-                        Ok(b) => b,
-                        Err(e) => return Some(Err(e)),
-                    }
-                } else {
-                    batch
-                };
-                Some(Ok(batch))
+        while self.current < self.streams.len() {
+            let stream = &mut self.streams[self.current];
+            let fut = poll_next_batch(stream);
+            let result = Python::with_gil(|py| wait_for_future(py, fut));
+
+            match result {
+                Ok(Ok(Some(batch))) => {
+                    let batch = if let Some(ref schema) = self.projection {
+                        match record_batch_into_schema(batch, schema.as_ref()) {
+                            Ok(b) => b,
+                            Err(e) => return Some(Err(e)),
+                        }
+                    } else {
+                        batch
+                    };
+                    return Some(Ok(batch));
+                }
+                Ok(Ok(None)) => {
+                    self.current += 1;
+                    continue;
+                }
+                Ok(Err(e)) => {
+                    return Some(Err(ArrowError::ExternalError(Box::new(e))));
+                }
+                Err(e) => {
+                    return Some(Err(ArrowError::ExternalError(Box::new(e))));
+                }
             }
-            Ok(Ok(None)) => None,
-            Ok(Err(e)) => Some(Err(ArrowError::ExternalError(Box::new(e)))),
-            Err(e) => Some(Err(ArrowError::ExternalError(Box::new(e)))),
         }
+
+        None
     }
 }
 
-impl RecordBatchReader for DataFrameStreamReader {
+impl RecordBatchReader for PartitionedDataFrameStreamReader {
     fn schema(&self) -> SchemaRef {
         self.schema.clone()
     }
@@ -958,7 +967,7 @@ impl PyDataFrame {
         requested_schema: Option<Bound<'py, PyCapsule>>,
     ) -> PyDataFusionResult<Bound<'py, PyCapsule>> {
         let df = self.df.as_ref().clone();
-        let stream = spawn_stream(py, async move { df.execute_stream().await })?;
+        let streams = spawn_streams(py, async move { df.execute_stream_partitioned().await })?;
 
         let mut schema: Schema = self.df.schema().to_owned().into();
         let mut projection: Option<SchemaRef> = None;
@@ -975,10 +984,11 @@ impl PyDataFrame {
 
         let schema_ref = Arc::new(schema.clone());
 
-        let reader = DataFrameStreamReader {
-            stream,
+        let reader = PartitionedDataFrameStreamReader {
+            streams,
             schema: schema_ref,
             projection,
+            current: 0,
         };
         let reader: Box<dyn RecordBatchReader + Send> = Box::new(reader);
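The reader's control flow reduces to draining each partition stream to exhaustion before advancing to the next. A Python analogue of that strategy (illustrative only, not part of the patch):

    from typing import Iterable, Iterator, TypeVar

    T = TypeVar("T")

    def drain_in_order(streams: Iterable[Iterable[T]]) -> Iterator[T]:
        # Mirrors PartitionedDataFrameStreamReader::next: exhaust partition 0,
        # then partition 1, and so on, so cross-partition order is preserved.
        for stream in streams:
            yield from stream

    # list(drain_in_order([[1, 2], [3, 4]])) == [1, 2, 3, 4]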