Skip to content

Commit 7c733ee

Browse files
fixed (most) tests by accounting for normalised Session interface
have not considered how to fix the protocol version check yet Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent b8f9146 commit 7c733ee

File tree

5 files changed

+110
-76
lines changed

5 files changed

+110
-76
lines changed

tests/e2e/test_driver.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -824,9 +824,7 @@ def test_close_connection_closes_cursors(self):
824824
status_request = ttypes.TGetOperationStatusReq(
825825
operationHandle=ars.command_id, getProgressUpdate=False
826826
)
827-
op_status_at_server = ars.backend._client.GetOperationStatus(
828-
status_request
829-
)
827+
op_status_at_server = ars.backend._client.GetOperationStatus(status_request)
830828
assert (
831829
op_status_at_server.operationState
832830
!= ttypes.TOperationState.CLOSED_STATE
@@ -856,17 +854,19 @@ def test_closing_a_closed_connection_doesnt_fail(self, caplog):
856854
raise KeyboardInterrupt("Simulated interrupt")
857855
finally:
858856
if conn is not None:
859-
assert not conn.open, "Connection should be closed after KeyboardInterrupt"
857+
assert (
858+
not conn.open
859+
), "Connection should be closed after KeyboardInterrupt"
860860

861861
def test_cursor_close_properly_closes_operation(self):
862862
"""Test that Cursor.close() properly closes the active operation handle on the server."""
863863
with self.connection() as conn:
864864
cursor = conn.cursor()
865865
try:
866866
cursor.execute("SELECT 1 AS test")
867-
assert cursor.active_op_handle is not None
867+
assert cursor.active_command_id is not None
868868
cursor.close()
869-
assert cursor.active_op_handle is None
869+
assert cursor.active_command_id is None
870870
assert not cursor.open
871871
finally:
872872
if cursor.open:
@@ -883,26 +883,28 @@ def test_cursor_close_properly_closes_operation(self):
883883
raise KeyboardInterrupt("Simulated interrupt")
884884
finally:
885885
if cursor is not None:
886-
assert not cursor.open, "Cursor should be closed after KeyboardInterrupt"
886+
assert (
887+
not cursor.open
888+
), "Cursor should be closed after KeyboardInterrupt"
887889

888890
def test_nested_cursor_context_managers(self):
889891
"""Test that nested cursor context managers properly close operations on the server."""
890892
with self.connection() as conn:
891893
with conn.cursor() as cursor1:
892894
cursor1.execute("SELECT 1 AS test1")
893-
assert cursor1.active_op_handle is not None
895+
assert cursor1.active_command_id is not None
894896

895897
with conn.cursor() as cursor2:
896898
cursor2.execute("SELECT 2 AS test2")
897-
assert cursor2.active_op_handle is not None
899+
assert cursor2.active_command_id is not None
898900

899901
# After inner context manager exit, cursor2 should be not open
900902
assert not cursor2.open
901-
assert cursor2.active_op_handle is None
903+
assert cursor2.active_command_id is None
902904

903905
# After outer context manager exit, cursor1 should be not open
904906
assert not cursor1.open
905-
assert cursor1.active_op_handle is None
907+
assert cursor1.active_command_id is None
906908

907909
def test_cursor_error_handling(self):
908910
"""Test that cursor close handles errors properly to prevent orphaned operations."""
@@ -911,7 +913,7 @@ def test_cursor_error_handling(self):
911913

912914
cursor.execute("SELECT 1 AS test")
913915

914-
op_handle = cursor.active_op_handle
916+
op_handle = cursor.active_command_id
915917

916918
assert op_handle is not None
917919

tests/unit/test_client.py

Lines changed: 36 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,10 @@ class ClientTestSuite(unittest.TestCase):
8181
"access_token": "tok",
8282
}
8383

84-
@patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME, ThriftDatabricksClientMockFactory.new())
84+
@patch(
85+
"%s.session.ThriftDatabricksClient" % PACKAGE_NAME,
86+
ThriftDatabricksClientMockFactory.new(),
87+
)
8588
@patch("%s.client.ResultSet" % PACKAGE_NAME)
8689
def test_closing_connection_closes_commands(self, mock_result_set_class):
8790
# Test once with has_been_closed_server side, once without
@@ -294,10 +297,10 @@ def test_get_columns_parameters_passed_to_thrift_backend(self, mock_thrift_backe
294297
def test_cancel_command_calls_the_backend(self):
295298
mock_thrift_backend = Mock()
296299
cursor = client.Cursor(Mock(), mock_thrift_backend)
297-
mock_op_handle = Mock()
298-
cursor.active_op_handle = mock_op_handle
300+
mock_command_id = Mock()
301+
cursor.active_command_id = mock_command_id
299302
cursor.cancel()
300-
mock_thrift_backend.cancel_command.assert_called_with(mock_op_handle)
303+
mock_thrift_backend.cancel_command.assert_called_with(mock_command_id)
301304

302305
@patch("databricks.sql.client.logger")
303306
def test_cancel_command_will_issue_warning_for_cancel_with_no_executing_command(
@@ -531,9 +534,13 @@ def test_access_current_query_id(self):
531534

532535
self.assertIsNone(cursor.query_id)
533536

534-
cursor.active_op_handle = TOperationHandle(
535-
operationId=THandleIdentifier(guid=UUID(operation_id).bytes, secret=0x00),
536-
operationType=TOperationType.EXECUTE_STATEMENT,
537+
cursor.active_command_id = CommandId.from_thrift_handle(
538+
TOperationHandle(
539+
operationId=THandleIdentifier(
540+
guid=UUID(operation_id).bytes, secret=0x00
541+
),
542+
operationType=TOperationType.EXECUTE_STATEMENT,
543+
)
537544
)
538545
self.assertEqual(cursor.query_id.upper(), operation_id.upper())
539546

@@ -544,70 +551,72 @@ def test_cursor_close_handles_exception(self):
544551
"""Test that Cursor.close() handles exceptions from close_command properly."""
545552
mock_backend = Mock()
546553
mock_connection = Mock()
547-
mock_op_handle = Mock()
548-
554+
mock_command_id = Mock()
555+
549556
mock_backend.close_command.side_effect = Exception("Test error")
550557

551558
cursor = client.Cursor(mock_connection, mock_backend)
552-
cursor.active_op_handle = mock_op_handle
559+
cursor.active_command_id = mock_command_id
553560

554561
cursor.close()
555562

556-
mock_backend.close_command.assert_called_once_with(mock_op_handle)
557-
558-
self.assertIsNone(cursor.active_op_handle)
559-
563+
mock_backend.close_command.assert_called_once_with(mock_command_id)
564+
565+
self.assertIsNone(cursor.active_command_id)
566+
560567
self.assertFalse(cursor.open)
561568

562569
def test_cursor_context_manager_handles_exit_exception(self):
563570
"""Test that cursor's context manager handles exceptions during __exit__."""
564571
mock_backend = Mock()
565572
mock_connection = Mock()
566-
573+
567574
cursor = client.Cursor(mock_connection, mock_backend)
568575
original_close = cursor.close
569576
cursor.close = Mock(side_effect=Exception("Test error during close"))
570-
577+
571578
try:
572579
with cursor:
573580
raise ValueError("Test error inside context")
574581
except ValueError:
575582
pass
576-
583+
577584
cursor.close.assert_called_once()
578585

579586
def test_connection_close_handles_cursor_close_exception(self):
580587
"""Test that _close handles exceptions from cursor.close() properly."""
581588
cursors_closed = []
582-
589+
583590
def mock_close_with_exception():
584591
cursors_closed.append(1)
585592
raise Exception("Test error during close")
586-
593+
587594
cursor1 = Mock()
588595
cursor1.close = mock_close_with_exception
589-
596+
590597
def mock_close_normal():
591598
cursors_closed.append(2)
592-
599+
593600
cursor2 = Mock()
594601
cursor2.close = mock_close_normal
595-
602+
596603
mock_backend = Mock()
597604
mock_session_handle = Mock()
598-
605+
599606
try:
600607
for cursor in [cursor1, cursor2]:
601608
try:
602609
cursor.close()
603610
except Exception:
604611
pass
605-
612+
606613
mock_backend.close_session(mock_session_handle)
607614
except Exception as e:
608615
self.fail(f"Connection close should handle exceptions: {e}")
609-
610-
self.assertEqual(cursors_closed, [1, 2], "Both cursors should have close called")
616+
617+
self.assertEqual(
618+
cursors_closed, [1, 2], "Both cursors should have close called"
619+
)
611620

612621
def test_resultset_close_handles_cursor_already_closed_error(self):
613622
"""Test that ResultSet.close() handles CursorAlreadyClosedError properly."""
@@ -616,7 +625,7 @@ def test_resultset_close_handles_cursor_already_closed_error(self):
616625
result_set.backend.CLOSED_OP_STATE = 'CLOSED'
617626
result_set.connection = Mock()
618627
result_set.connection.open = True
619-
result_set.op_state = 'RUNNING'
628+
result_set.op_state = "RUNNING"
620629
result_set.has_been_closed_server_side = False
621630
result_set.command_id = Mock()
622631

tests/unit/test_fetches.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def make_dummy_result_set_from_batch_list(batch_list):
6666
batch_index = 0
6767

6868
def fetch_results(
69-
op_handle,
69+
command_id,
7070
max_rows,
7171
max_bytes,
7272
expected_row_start_offset,

tests/unit/test_session.py

Lines changed: 41 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@
44

55
from databricks.sql.thrift_api.TCLIService.ttypes import (
66
TOpenSessionResp,
7+
TSessionHandle,
8+
THandleIdentifier,
79
)
10+
from databricks.sql.ids import SessionId, BackendType
811

912
import databricks.sql
1013

@@ -25,16 +28,17 @@ class SessionTestSuite(unittest.TestCase):
2528
def test_close_uses_the_correct_session_id(self, mock_client_class):
2629
instance = mock_client_class.return_value
2730

28-
mock_open_session_resp = MagicMock(spec=TOpenSessionResp)()
29-
mock_open_session_resp.sessionHandle.sessionId = b"\x22"
30-
instance.open_session.return_value = mock_open_session_resp
31+
# Create a mock SessionId that will be returned by open_session
32+
mock_session_id = SessionId(BackendType.THRIFT, b"\x22", b"\x33")
33+
instance.open_session.return_value = mock_session_id
3134

3235
connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS)
3336
connection.close()
3437

35-
# Check the close session request has an id of x22
36-
close_session_id = instance.close_session.call_args[0][0].sessionId
37-
self.assertEqual(close_session_id, b"\x22")
38+
# Check that close_session was called with the correct SessionId
39+
close_session_call_args = instance.close_session.call_args[0][0]
40+
self.assertEqual(close_session_call_args.guid, b"\x22")
41+
self.assertEqual(close_session_call_args.secret, b"\x33")
3842

3943
@patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
4044
def test_auth_args(self, mock_client_class):
@@ -112,16 +116,17 @@ def test_useragent_header(self, mock_client_class):
112116
def test_context_manager_closes_connection(self, mock_client_class):
113117
instance = mock_client_class.return_value
114118

115-
mock_open_session_resp = MagicMock(spec=TOpenSessionResp)()
116-
mock_open_session_resp.sessionHandle.sessionId = b"\x22"
117-
instance.open_session.return_value = mock_open_session_resp
119+
# Create a mock SessionId that will be returned by open_session
120+
mock_session_id = SessionId(BackendType.THRIFT, b"\x22", b"\x33")
121+
instance.open_session.return_value = mock_session_id
118122

119123
with databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) as connection:
120124
pass
121125

122-
# Check the close session request has an id of x22
123-
close_session_id = instance.close_session.call_args[0][0].sessionId
124-
self.assertEqual(close_session_id, b"\x22")
126+
# Check that close_session was called with the correct SessionId
127+
close_session_call_args = instance.close_session.call_args[0][0]
128+
self.assertEqual(close_session_call_args.guid, b"\x22")
129+
self.assertEqual(close_session_call_args.secret, b"\x33")
125130

126131
@patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
127132
def test_max_number_of_retries_passthrough(self, mock_client_class):
@@ -141,46 +146,54 @@ def test_socket_timeout_passthrough(self, mock_client_class):
141146
@patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
142147
def test_configuration_passthrough(self, mock_client_class):
143148
mock_session_config = Mock()
149+
150+
# Create a mock SessionId that will be returned by open_session
151+
mock_session_id = SessionId(BackendType.THRIFT, b"\x22", b"\x33")
152+
mock_client_class.return_value.open_session.return_value = mock_session_id
153+
144154
databricks.sql.connect(
145155
session_configuration=mock_session_config, **self.DUMMY_CONNECTION_ARGS
146156
)
147157

148-
self.assertEqual(
149-
mock_client_class.return_value.open_session.call_args[0][0],
150-
mock_session_config,
151-
)
158+
# Check that open_session was called with the correct session_configuration
159+
call_kwargs = mock_client_class.return_value.open_session.call_args[1]
160+
self.assertEqual(call_kwargs["session_configuration"], mock_session_config)
152161

153162
@patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
154163
def test_initial_namespace_passthrough(self, mock_client_class):
155164
mock_cat = Mock()
156165
mock_schem = Mock()
157166

167+
# Create a mock SessionId that will be returned by open_session
168+
mock_session_id = SessionId(BackendType.THRIFT, b"\x22", b"\x33")
169+
mock_client_class.return_value.open_session.return_value = mock_session_id
170+
158171
databricks.sql.connect(
159172
**self.DUMMY_CONNECTION_ARGS, catalog=mock_cat, schema=mock_schem
160173
)
161-
self.assertEqual(
162-
mock_client_class.return_value.open_session.call_args[0][1], mock_cat
163-
)
164-
self.assertEqual(
165-
mock_client_class.return_value.open_session.call_args[0][2], mock_schem
166-
)
174+
175+
# Check that open_session was called with the correct catalog and schema
176+
call_kwargs = mock_client_class.return_value.open_session.call_args[1]
177+
self.assertEqual(call_kwargs["catalog"], mock_cat)
178+
self.assertEqual(call_kwargs["schema"], mock_schem)
167179

168180
@patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
169181
def test_finalizer_closes_abandoned_connection(self, mock_client_class):
170182
instance = mock_client_class.return_value
171183

172-
mock_open_session_resp = MagicMock(spec=TOpenSessionResp)()
173-
mock_open_session_resp.sessionHandle.sessionId = b"\x22"
174-
instance.open_session.return_value = mock_open_session_resp
184+
# Create a mock SessionId that will be returned by open_session
185+
mock_session_id = SessionId(BackendType.THRIFT, b"\x22", b"\x33")
186+
instance.open_session.return_value = mock_session_id
175187

176188
databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS)
177189

178190
# not strictly necessary as the refcount is 0, but just to be sure
179191
gc.collect()
180192

181-
# Check the close session request has an id of x22
182-
close_session_id = instance.close_session.call_args[0][0].sessionId
183-
self.assertEqual(close_session_id, b"\x22")
193+
# Check that close_session was called with the correct SessionId
194+
close_session_call_args = instance.close_session.call_args[0][0]
195+
self.assertEqual(close_session_call_args.guid, b"\x22")
196+
self.assertEqual(close_session_call_args.secret, b"\x33")
184197

185198

186199
if __name__ == "__main__":

0 commit comments

Comments
 (0)