Skip to content

Commit ed4d7ab

Browse files
Merge branch 'sea-migration' into ext-links-sea
2 parents dfbbf79 + 4f11ff0 commit ed4d7ab

File tree

5 files changed: +85 −12 lines changed

src/databricks/sql/backend/databricks_client.py

Lines changed: 2 additions & 0 deletions

@@ -94,6 +94,7 @@ def execute_command(
         parameters: List,
         async_op: bool,
         enforce_embedded_schema_correctness: bool,
+        row_limit: Optional[int] = None,
     ) -> Union["ResultSet", None]:
         """
         Executes a SQL command or query within the specified session.
@@ -112,6 +113,7 @@ def execute_command(
             parameters: List of parameters to bind to the query
             async_op: Whether to execute the command asynchronously
             enforce_embedded_schema_correctness: Whether to enforce schema correctness
+            row_limit: Maximum number of rows in the operation result.

         Returns:
             If async_op is False, returns a ResultSet object containing the

src/databricks/sql/backend/sea/backend.py

Lines changed: 2 additions & 1 deletion

@@ -407,6 +407,7 @@ def execute_command(
         parameters: List[Dict[str, Any]],
         async_op: bool,
         enforce_embedded_schema_correctness: bool,
+        row_limit: Optional[int] = None,
     ) -> Union[SeaResultSet, None]:
         """
         Execute a SQL command using the SEA backend.
@@ -464,7 +465,7 @@ def execute_command(
             format=format,
             wait_timeout=(WaitTimeout.ASYNC if async_op else WaitTimeout.SYNC).value,
             on_wait_timeout="CONTINUE",
-            row_limit=max_rows,
+            row_limit=row_limit,
             parameters=sea_parameters if sea_parameters else None,
             result_compression=result_compression,
         )

src/databricks/sql/backend/thrift_backend.py

Lines changed: 3 additions & 1 deletion

@@ -4,7 +4,7 @@
 import math
 import time
 import threading
-from typing import List, Union, Any, TYPE_CHECKING
+from typing import List, Optional, Union, Any, TYPE_CHECKING

 if TYPE_CHECKING:
     from databricks.sql.client import Cursor
@@ -925,6 +925,7 @@ def execute_command(
         parameters=[],
         async_op=False,
         enforce_embedded_schema_correctness=False,
+        row_limit: Optional[int] = None,
     ) -> Union["ResultSet", None]:
         thrift_handle = session_id.to_thrift_handle()
         if not thrift_handle:
@@ -965,6 +966,7 @@ def execute_command(
             useArrowNativeTypes=spark_arrow_types,
             parameters=parameters,
             enforceEmbeddedSchemaCorrectness=enforce_embedded_schema_correctness,
+            resultRowLimit=row_limit,
         )
         resp = self.make_request(self._client.ExecuteStatement, req)

src/databricks/sql/client.py

Lines changed: 20 additions & 8 deletions

@@ -335,8 +335,14 @@ def cursor(
         self,
         arraysize: int = DEFAULT_ARRAY_SIZE,
         buffer_size_bytes: int = DEFAULT_RESULT_BUFFER_SIZE_BYTES,
+        row_limit: Optional[int] = None,
     ) -> "Cursor":
         """
+        Args:
+            arraysize: The maximum number of rows in direct results.
+            buffer_size_bytes: The maximum number of bytes in direct results.
+            row_limit: The maximum number of rows in the result.
+
         Return a new Cursor object using the connection.

         Will throw an Error if the connection has been closed.
@@ -349,6 +355,7 @@ def cursor(
             self.session.backend,
             arraysize=arraysize,
             result_buffer_size_bytes=buffer_size_bytes,
+            row_limit=row_limit,
         )
         self._cursors.append(cursor)
         return cursor
@@ -382,6 +389,7 @@ def __init__(
         backend: DatabricksClient,
         result_buffer_size_bytes: int = DEFAULT_RESULT_BUFFER_SIZE_BYTES,
         arraysize: int = DEFAULT_ARRAY_SIZE,
+        row_limit: Optional[int] = None,
     ) -> None:
         """
         These objects represent a database cursor, which is used to manage the context of a fetch
@@ -391,16 +399,18 @@ def __init__(
         visible by other cursors or connections.
         """

-        self.connection = connection
-        self.rowcount = -1  # Return -1 as this is not supported
-        self.buffer_size_bytes = result_buffer_size_bytes
+        self.connection: Connection = connection
+
+        self.rowcount: int = -1  # Return -1 as this is not supported
+        self.buffer_size_bytes: int = result_buffer_size_bytes
         self.active_result_set: Union[ResultSet, None] = None
-        self.arraysize = arraysize
+        self.arraysize: int = arraysize
+        self.row_limit: Optional[int] = row_limit
         # Note that Cursor closed => active result set closed, but not vice versa
-        self.open = True
-        self.executing_command_id = None
-        self.backend = backend
-        self.active_command_id = None
+        self.open: bool = True
+        self.executing_command_id: Optional[CommandId] = None
+        self.backend: DatabricksClient = backend
+        self.active_command_id: Optional[CommandId] = None
         self.escaper = ParamEscaper()
         self.lastrowid = None

@@ -779,6 +789,7 @@ def execute(
             parameters=prepared_params,
             async_op=False,
             enforce_embedded_schema_correctness=enforce_embedded_schema_correctness,
+            row_limit=self.row_limit,
         )

         if self.active_result_set and self.active_result_set.is_staging_operation:
@@ -835,6 +846,7 @@ def execute_async(
             parameters=prepared_params,
             async_op=True,
             enforce_embedded_schema_correctness=enforce_embedded_schema_correctness,
+            row_limit=self.row_limit,
         )

         return self

tests/e2e/test_driver.py

Lines changed: 58 additions & 2 deletions

@@ -113,10 +113,12 @@ def connection(self, extra_params=()):
             conn.close()

     @contextmanager
-    def cursor(self, extra_params=()):
+    def cursor(self, extra_params=(), extra_cursor_params=()):
         with self.connection(extra_params) as conn:
             cursor = conn.cursor(
-                arraysize=self.arraysize, buffer_size_bytes=self.buffer_size_bytes
+                arraysize=self.arraysize,
+                buffer_size_bytes=self.buffer_size_bytes,
+                **dict(extra_cursor_params),
             )
             try:
                 yield cursor
@@ -988,6 +990,60 @@ def test_catalogs_returns_arrow_table(self):
         results = cursor.fetchall_arrow()
         assert isinstance(results, pyarrow.Table)

+    def test_row_limit_with_larger_result(self):
+        """Test that row_limit properly constrains results when query would return more rows"""
+        row_limit = 1000
+        with self.cursor(extra_cursor_params={"row_limit": row_limit}) as cursor:
+            # Execute a query that returns more than row_limit rows
+            cursor.execute("SELECT * FROM range(2000)")
+            rows = cursor.fetchall()
+
+            # Check if the number of rows is limited to row_limit
+            assert len(rows) == row_limit, f"Expected {row_limit} rows, got {len(rows)}"
+
+    def test_row_limit_with_smaller_result(self):
+        """Test that row_limit doesn't affect results when query returns fewer rows than limit"""
+        row_limit = 100
+        expected_rows = 50
+        with self.cursor(extra_cursor_params={"row_limit": row_limit}) as cursor:
+            # Execute a query that returns fewer than row_limit rows
+            cursor.execute(f"SELECT * FROM range({expected_rows})")
+            rows = cursor.fetchall()
+
+            # Check if all rows are returned (not limited by row_limit)
+            assert (
+                len(rows) == expected_rows
+            ), f"Expected {expected_rows} rows, got {len(rows)}"
+
+    @skipUnless(pysql_supports_arrow(), "arrow test needs arrow support")
+    def test_row_limit_with_arrow_larger_result(self):
+        """Test that row_limit properly constrains arrow results when query would return more rows"""
+        row_limit = 800
+        with self.cursor(extra_cursor_params={"row_limit": row_limit}) as cursor:
+            # Execute a query that returns more than row_limit rows
+            cursor.execute("SELECT * FROM range(1500)")
+            arrow_table = cursor.fetchall_arrow()
+
+            # Check if the number of rows in the arrow table is limited to row_limit
+            assert (
+                arrow_table.num_rows == row_limit
+            ), f"Expected {row_limit} rows, got {arrow_table.num_rows}"
+
+    @skipUnless(pysql_supports_arrow(), "arrow test needs arrow support")
+    def test_row_limit_with_arrow_smaller_result(self):
+        """Test that row_limit doesn't affect arrow results when query returns fewer rows than limit"""
+        row_limit = 200
+        expected_rows = 100
+        with self.cursor(extra_cursor_params={"row_limit": row_limit}) as cursor:
+            # Execute a query that returns fewer than row_limit rows
+            cursor.execute(f"SELECT * FROM range({expected_rows})")
+            arrow_table = cursor.fetchall_arrow()
+
+            # Check if all rows are returned (not limited by row_limit)
+            assert (
+                arrow_table.num_rows == expected_rows
+            ), f"Expected {expected_rows} rows, got {arrow_table.num_rows}"

 # use a RetrySuite to encapsulate these tests which we'll typically want to run together; however keep
 # the 429/503 subsuites separate since they execute under different circumstances.
1046+
9911047

9921048
# use a RetrySuite to encapsulate these tests which we'll typically want to run together; however keep
9931049
# the 429/503 subsuites separate since they execute under different circumstances.

0 commit comments

Comments (0)