Skip to content

Commit 5da8efe

Browse files
committed
feat: enhance DataFrame and RecordBatchStream iteration to yield pyarrow.RecordBatch objects
1 parent f459c60 commit 5da8efe

File tree

3 files changed

+25
-27
lines changed

3 files changed

+25
-27
lines changed

python/datafusion/dataframe.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1028,6 +1028,22 @@ def to_arrow_table(self) -> pa.Table:
10281028
"""
10291029
return self.df.to_arrow_table()
10301030

1031+
    def __iter__(self) -> Iterator[pa.RecordBatch]:
        """Iterate over :py:class:`pyarrow.RecordBatch` objects.

        Executes the DataFrame via :py:meth:`execute_stream` (a stream over a
        single partition) and yields each record batch from that stream as a
        native :py:class:`pyarrow.RecordBatch`.

        Yields:
            pyarrow.RecordBatch: the next batch in the result stream.
        """
        for batch in self.execute_stream():
            # ``execute_stream`` yields batches that may be ``RecordBatch``
            # wrappers or ``pyarrow.RecordBatch`` objects directly. Convert
            # to native PyArrow batches when necessary to provide a consistent
            # iterator interface.
            yield batch.to_pyarrow() if hasattr(batch, "to_pyarrow") else batch
1046+
10311047
def execute_stream(self) -> RecordBatchStream:
10321048
"""Executes this DataFrame and returns a stream over a single partition.
10331049

python/datafusion/record_batch.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,10 +69,10 @@ async def __anext__(self) -> RecordBatch:
6969
next_batch = await self.rbs.__anext__()
7070
return RecordBatch(next_batch)
7171

72-
    def __next__(self) -> pa.RecordBatch:
        """Return the next batch in the stream as a native PyArrow batch.

        Pulls the next batch from the underlying stream and converts it with
        ``to_pyarrow()`` so synchronous iteration yields
        :py:class:`pyarrow.RecordBatch` objects directly.

        Raises:
            StopIteration: when the underlying stream is exhausted.
        """
        next_batch = next(self.rbs)
        return next_batch.to_pyarrow()
7676

7777
def __aiter__(self) -> typing_extensions.Self:
7878
"""Async iterator function."""

python/tests/test_dataframe.py

Lines changed: 7 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1505,7 +1505,8 @@ def test_to_arrow_table(df):
15051505

15061506
def test_execute_stream(df):
    """Streaming execution yields native pyarrow.RecordBatch objects exactly once."""
    stream = df.execute_stream()
    collected = [item for item in stream]
    for item in collected:
        assert isinstance(item, pa.RecordBatch)
    assert not list(stream)  # after one iteration the generator must be exhausted
15101511

15111512

@@ -1526,30 +1527,9 @@ def test_execute_stream_to_arrow_table(df, schema):
15261527
stream = df.execute_stream()
15271528

15281529
if schema:
1529-
pyarrow_table = pa.Table.from_batches(
1530-
(batch.to_pyarrow() for batch in stream), schema=df.schema()
1531-
)
1530+
pyarrow_table = pa.Table.from_batches(stream, schema=df.schema())
15321531
else:
1533-
pyarrow_table = pa.Table.from_batches(batch.to_pyarrow() for batch in stream)
1534-
1535-
assert isinstance(pyarrow_table, pa.Table)
1536-
assert pyarrow_table.shape == (3, 3)
1537-
assert set(pyarrow_table.column_names) == {"a", "b", "c"}
1538-
1539-
1540-
@pytest.mark.asyncio
1541-
@pytest.mark.parametrize("schema", [True, False])
1542-
async def test_execute_stream_to_arrow_table_async(df, schema):
1543-
stream = df.execute_stream()
1544-
1545-
if schema:
1546-
pyarrow_table = pa.Table.from_batches(
1547-
[batch.to_pyarrow() async for batch in stream], schema=df.schema()
1548-
)
1549-
else:
1550-
pyarrow_table = pa.Table.from_batches(
1551-
[batch.to_pyarrow() async for batch in stream]
1552-
)
1532+
pyarrow_table = pa.Table.from_batches(stream)
15531533

15541534
assert isinstance(pyarrow_table, pa.Table)
15551535
assert pyarrow_table.shape == (3, 3)
@@ -1558,7 +1538,9 @@ async def test_execute_stream_to_arrow_table_async(df, schema):
15581538

15591539
def test_execute_stream_partitioned(df):
    """Each partitioned stream yields pyarrow.RecordBatch objects exactly once."""
    streams = df.execute_stream_partitioned()
    for partition in streams:
        for item in partition:
            assert isinstance(item, pa.RecordBatch)
    # after one iteration all generators must be exhausted
    for partition in streams:
        assert not list(partition)

0 commit comments

Comments
 (0)