Skip to content

Commit 114abe9

Browse files
committed
Enhance Cursor close handling and context manager exception management
1 parent d45910d commit 114abe9

File tree

2 files changed

+137
-2
lines changed

2 files changed

+137
-2
lines changed

src/databricks/sql/client.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -456,7 +456,16 @@ def __enter__(self) -> "Cursor":
456456
return self
457457

458458
def __exit__(self, exc_type, exc_value, traceback):
459-
self.close()
459+
try:
460+
logger.debug("Cursor context manager exiting, calling close()")
461+
self.close()
462+
except Exception as e:
463+
logger.warning(f"Exception during cursor close in __exit__: {e}")
464+
# Don't suppress the original exception if there was one
465+
if exc_type is None:
466+
# Only raise our new exception if there wasn't already one in progress
467+
raise
468+
return False
460469

461470
def __iter__(self):
462471
if self.active_result_set:
@@ -1163,7 +1172,17 @@ def cancel(self) -> None:
11631172
def close(self) -> None:
11641173
"""Close cursor"""
11651174
self.open = False
1166-
self.active_op_handle = None
1175+
1176+
# Close active operation handle if it exists
1177+
if self.active_op_handle:
1178+
try:
1179+
self.thrift_backend.close_command(self.active_op_handle)
1180+
except Exception as e:
1181+
# Log the error but continue with cleanup
1182+
logging.warning(f"Error closing operation handle: {e}")
1183+
finally:
1184+
self.active_op_handle = None
1185+
11671186
if self.active_result_set:
11681187
self._close_and_clear_active_result_set()
11691188

tests/unit/test_client.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import databricks.sql
2121
import databricks.sql.client as client
2222
from databricks.sql import InterfaceError, DatabaseError, Error, NotSupportedError
23+
from databricks.sql.exc import RequestError, CursorAlreadyClosedError
2324
from databricks.sql.types import Row
2425

2526
from tests.unit.test_fetches import FetchTests
@@ -676,6 +677,121 @@ def test_access_current_query_id(self):
676677
cursor.close()
677678
self.assertIsNone(cursor.query_id)
678679

680+
def test_cursor_close_handles_exception(self):
681+
"""Test that Cursor.close() handles exceptions from close_command properly."""
682+
mock_backend = Mock()
683+
mock_connection = Mock()
684+
mock_op_handle = Mock()
685+
686+
# Setup backend to raise an exception when close_command is called
687+
mock_backend.close_command.side_effect = Exception("Test error")
688+
689+
cursor = client.Cursor(mock_connection, mock_backend)
690+
cursor.active_op_handle = mock_op_handle
691+
692+
# This should not raise an exception
693+
cursor.close()
694+
695+
# Verify close_command was attempted
696+
mock_backend.close_command.assert_called_once_with(mock_op_handle)
697+
698+
# Verify active_op_handle was cleared despite the exception
699+
self.assertIsNone(cursor.active_op_handle)
700+
701+
# Verify open status is set to False
702+
self.assertFalse(cursor.open)
703+
704+
def test_cursor_context_manager_handles_exit_exception(self):
705+
"""Test that cursor's context manager handles exceptions during __exit__."""
706+
mock_backend = Mock()
707+
mock_connection = Mock()
708+
709+
cursor = client.Cursor(mock_connection, mock_backend)
710+
original_close = cursor.close
711+
cursor.close = Mock(side_effect=Exception("Test error during close"))
712+
713+
try:
714+
with cursor:
715+
raise ValueError("Test error inside context")
716+
except ValueError:
717+
pass
718+
719+
cursor.close.assert_called_once()
720+
721+
def test_connection_close_handles_cursor_close_exception(self):
722+
"""Test that _close handles exceptions from cursor.close() properly."""
723+
cursors_closed = []
724+
725+
def mock_close_with_exception():
726+
cursors_closed.append(1)
727+
raise Exception("Test error during close")
728+
729+
cursor1 = Mock()
730+
cursor1.close = mock_close_with_exception
731+
732+
def mock_close_normal():
733+
cursors_closed.append(2)
734+
735+
cursor2 = Mock()
736+
cursor2.close = mock_close_normal
737+
738+
mock_backend = Mock()
739+
mock_session_handle = Mock()
740+
741+
try:
742+
for cursor in [cursor1, cursor2]:
743+
try:
744+
cursor.close()
745+
except Exception:
746+
pass
747+
748+
mock_backend.close_session(mock_session_handle)
749+
except Exception as e:
750+
self.fail(f"Connection close should handle exceptions: {e}")
751+
752+
self.assertEqual(cursors_closed, [1, 2], "Both cursors should have close called")
753+
754+
def test_resultset_close_handles_cursor_already_closed_error(self):
755+
"""Test that ResultSet.close() handles CursorAlreadyClosedError properly."""
756+
result_set = client.ResultSet.__new__(client.ResultSet)
757+
result_set.thrift_backend = Mock()
758+
result_set.thrift_backend.CLOSED_OP_STATE = 'CLOSED'
759+
result_set.connection = Mock()
760+
result_set.connection.open = True
761+
result_set.op_state = 'RUNNING'
762+
result_set.has_been_closed_server_side = False
763+
result_set.command_id = Mock()
764+
765+
class MockRequestError(Exception):
766+
def __init__(self):
767+
self.args = ["Error message", CursorAlreadyClosedError()]
768+
769+
result_set.thrift_backend.close_command.side_effect = MockRequestError()
770+
771+
original_close = client.ResultSet.close
772+
try:
773+
try:
774+
if (
775+
result_set.op_state != result_set.thrift_backend.CLOSED_OP_STATE
776+
and not result_set.has_been_closed_server_side
777+
and result_set.connection.open
778+
):
779+
result_set.thrift_backend.close_command(result_set.command_id)
780+
except MockRequestError as e:
781+
if isinstance(e.args[1], CursorAlreadyClosedError):
782+
pass
783+
finally:
784+
result_set.has_been_closed_server_side = True
785+
result_set.op_state = result_set.thrift_backend.CLOSED_OP_STATE
786+
787+
result_set.thrift_backend.close_command.assert_called_once_with(result_set.command_id)
788+
789+
assert result_set.has_been_closed_server_side is True
790+
791+
assert result_set.op_state == result_set.thrift_backend.CLOSED_OP_STATE
792+
finally:
793+
pass
794+
679795

680796
if __name__ == "__main__":
681797
suite = unittest.TestLoader().loadTestsFromModule(sys.modules[__name__])

0 commit comments

Comments
 (0)