Skip to content
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 86 additions & 0 deletions mssql_python/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@
from mssql_python import get_settings

if TYPE_CHECKING:
import pyarrow # type: ignore
from mssql_python.connection import Connection
else:
pyarrow = None

# Constants for string handling
MAX_INLINE_CHAR: int = (
Expand Down Expand Up @@ -2198,6 +2201,89 @@ def fetchall(self) -> List[Row]:
# On error, don't increment rownumber - rethrow the error
raise e

def arrow_batch(self, batch_size: int = 8192) -> "pyarrow.RecordBatch":
    """
    Pull the next chunk of the current result set as one pyarrow
    RecordBatch.

    Args:
        batch_size: Upper bound on the number of rows requested for the
            batch. Negative values are clamped to 0 before the fetch.

    Returns:
        A pyarrow RecordBatch imported from the capsules handed to the
        native fetch call.

    Raises:
        ImportError: If pyarrow cannot be imported.
    """
    # The closed-cursor check comes first, before any other work.
    self._check_closed()
    if not self._has_result_set:
        if self.description:
            self._reset_rownumber()

    # pyarrow is an optional dependency, so it is imported lazily here.
    try:
        import pyarrow as pa
    except ImportError as import_err:
        raise ImportError(
            "pyarrow is required for arrow_batch(). Please install pyarrow."
        ) from import_err

    # The list is passed to the native fetch and afterwards unpacked as the
    # positional arguments of the capsule import.
    arrow_capsules = []
    row_limit = max(batch_size, 0)
    status = ddbc_bindings.DDBCSQLFetchArrowBatch(self.hstmt, arrow_capsules, row_limit)
    check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, status)

    return pa.RecordBatch._import_from_c_capsule(*arrow_capsules)

def arrow(self, batch_size: int = 8192) -> "pyarrow.Table":
    """
    Collect the rest of the result set into a single pyarrow Table.

    Batches are requested through arrow_batch() until one comes back with
    fewer than batch_size rows. A non-positive batch_size stops after the
    first batch.

    Args:
        batch_size: Number of rows requested for each Record Batch that
            makes up the Table.

    Returns:
        A pyarrow Table assembled from the collected Record Batches, using
        the schema of the first batch.

    Raises:
        ImportError: If pyarrow cannot be imported.
    """
    try:
        import pyarrow as pa
    except ImportError as import_err:
        raise ImportError(
            "pyarrow is required for arrow(). Please install pyarrow."
        ) from import_err

    collected = []
    exhausted = False
    while not exhausted:
        chunk = self.arrow_batch(batch_size)
        # A short batch (or a non-positive batch_size) ends the loop.
        exhausted = chunk.num_rows < batch_size or batch_size <= 0
        # Drop an empty final batch unless it is the only one: the first
        # collected batch supplies the schema below.
        if not exhausted or not collected or chunk.num_rows > 0:
            collected.append(chunk)

    return pa.Table.from_batches(collected, schema=collected[0].schema)

def arrow_reader(self, batch_size: int = 8192) -> "pyarrow.RecordBatchReader":
    """
    Expose the result set as a pyarrow RecordBatchReader.

    The reader is backed by a generator that calls arrow_batch() each time
    a batch is requested and stops at the first batch with no rows.

    Args:
        batch_size: Number of rows requested for each Record Batch the
            reader produces.

    Returns:
        A pyarrow RecordBatchReader over the result set.

    Raises:
        ImportError: If pyarrow cannot be imported.
    """
    try:
        import pyarrow as pa
    except ImportError as import_err:
        raise ImportError(
            "pyarrow is required for arrow_reader(). Please install pyarrow."
        ) from import_err

    # Request zero rows up front; the resulting batch is used only for its
    # schema (intended to leave the cursor position untouched).
    result_schema = self.arrow_batch(0).schema

    def iter_batches():
        # Fetching happens only while the reader is being consumed.
        while True:
            chunk = self.arrow_batch(batch_size)
            if not chunk.num_rows > 0:
                return
            yield chunk

    return pa.RecordBatchReader.from_batches(result_schema, iter_batches())

def nextset(self) -> Union[bool, None]:
"""
Skip to the next available result set.
Expand Down
Loading
Loading