From 73832906b769feeea8865c0ad1311032d01de026 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 20 May 2025 09:55:05 +0530 Subject: [PATCH 1/8] decouple session class from existing Connection ensure maintenance of current APIs of Connection while delegating responsibility Signed-off-by: varun-edachali-dbx --- src/databricks/sql/client.py | 152 ++++++++++------------------------ src/databricks/sql/session.py | 146 ++++++++++++++++++++++++++++++++ 2 files changed, 190 insertions(+), 108 deletions(-) create mode 100644 src/databricks/sql/session.py diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index f24a6584a..0fbd20df5 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -19,6 +19,8 @@ OperationalError, SessionAlreadyClosedError, CursorAlreadyClosedError, + Error, + NotSupportedError, ) from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.thrift_backend import ThriftBackend @@ -45,6 +47,7 @@ from databricks.sql.types import Row, SSLOptions from databricks.sql.auth.auth import get_python_sql_connector_auth_provider from databricks.sql.experimental.oauth_persistence import OAuthPersistence +from databricks.sql.session import Session from databricks.sql.thrift_api.TCLIService.ttypes import ( TSparkParameter, @@ -218,66 +221,24 @@ def read(self) -> Optional[OAuthToken]: access_token_kv = {"access_token": access_token} kwargs = {**kwargs, **access_token_kv} - self.open = False - self.host = server_hostname - self.port = kwargs.get("_port", 443) self.disable_pandas = kwargs.get("_disable_pandas", False) self.lz4_compression = kwargs.get("enable_query_result_lz4_compression", True) + self.use_cloud_fetch = kwargs.get("use_cloud_fetch", True) + self._cursors = [] # type: List[Cursor] - auth_provider = get_python_sql_connector_auth_provider( - server_hostname, **kwargs - ) - - user_agent_entry = kwargs.get("user_agent_entry") - if user_agent_entry is None: - user_agent_entry = kwargs.get("_user_agent_entry") - if user_agent_entry is not None: - logger.warning( - "[WARN] Parameter '_user_agent_entry' is deprecated; use 'user_agent_entry' instead. " - "This parameter will be removed in the upcoming releases." - ) - - if user_agent_entry: - useragent_header = "{}/{} ({})".format( - USER_AGENT_NAME, __version__, user_agent_entry - ) - else: - useragent_header = "{}/{}".format(USER_AGENT_NAME, __version__) - - base_headers = [("User-Agent", useragent_header)] - - self._ssl_options = SSLOptions( - # Double negation is generally a bad thing, but we have to keep backward compatibility - tls_verify=not kwargs.get( - "_tls_no_verify", False - ), # by default - verify cert and host - tls_verify_hostname=kwargs.get("_tls_verify_hostname", True), - tls_trusted_ca_file=kwargs.get("_tls_trusted_ca_file"), - tls_client_cert_file=kwargs.get("_tls_client_cert_file"), - tls_client_cert_key_file=kwargs.get("_tls_client_cert_key_file"), - tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"), - ) - - self.thrift_backend = ThriftBackend( - self.host, - self.port, + # Create the session + self.session = Session( + server_hostname, http_path, - (http_headers or []) + base_headers, - auth_provider, - ssl_options=self._ssl_options, - _use_arrow_native_complex_types=_use_arrow_native_complex_types, - **kwargs, - ) - - self._open_session_resp = self.thrift_backend.open_session( - session_configuration, catalog, schema + http_headers, + session_configuration, + catalog, + schema, + _use_arrow_native_complex_types, + **kwargs ) - self._session_handle = self._open_session_resp.sessionHandle - self.protocol_version = self.get_protocol_version(self._open_session_resp) - self.use_cloud_fetch = kwargs.get("use_cloud_fetch", True) - self.open = True - logger.info("Successfully opened session " + str(self.get_session_id_hex())) - self._cursors = [] # type: List[Cursor] + + logger.info("Successfully opened connection with session " + str(self.get_session_id_hex())) self.use_inline_params = self._set_use_inline_params_with_warning( kwargs.get("use_inline_params", False) @@ -318,7 +279,7 @@ def __exit__(self, exc_type, exc_value, traceback): self.close() def __del__(self): - if self.open: + if self.session.open: logger.debug( "Closing unclosed connection for session " "{}".format(self.get_session_id_hex()) @@ -330,34 +291,27 @@ def __del__(self): logger.debug("Couldn't close unclosed connection: {}".format(e.message)) def get_session_id(self): - return self.thrift_backend.handle_to_id(self._session_handle) + """Get the session ID from the Session object""" + return self.session.get_session_id() - @staticmethod - def get_protocol_version(openSessionResp): - """ - Since the sessionHandle will sometimes have a serverProtocolVersion, it takes - precedence over the serverProtocolVersion defined in the OpenSessionResponse. - """ - if ( - openSessionResp.sessionHandle - and hasattr(openSessionResp.sessionHandle, "serverProtocolVersion") - and openSessionResp.sessionHandle.serverProtocolVersion - ): - return openSessionResp.sessionHandle.serverProtocolVersion - return openSessionResp.serverProtocolVersion + def get_session_id_hex(self): + """Get the session ID in hex format from the Session object""" + return self.session.get_session_id_hex() @staticmethod def server_parameterized_queries_enabled(protocolVersion): - if ( - protocolVersion - and protocolVersion >= ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V8 - ): - return True - else: - return False + """Delegate to Session class static method""" + return Session.server_parameterized_queries_enabled(protocolVersion) - def get_session_id_hex(self): - return self.thrift_backend.handle_to_hex_id(self._session_handle) + @property + def protocol_version(self): + """Get the protocol version from the Session object""" + return self.session.protocol_version + + @staticmethod + def get_protocol_version(openSessionResp): + """Delegate to Session class static method""" + return Session.get_protocol_version(openSessionResp) def cursor( self, @@ -369,12 +323,12 @@ def cursor( Will throw an Error if the connection has been closed. """ - if not self.open: + if not self.session.open: raise Error("Cannot create cursor from closed connection") cursor = Cursor( self, - self.thrift_backend, + self.session.thrift_backend, arraysize=arraysize, result_buffer_size_bytes=buffer_size_bytes, ) @@ -390,28 +344,10 @@ def _close(self, close_cursors=True) -> None: for cursor in self._cursors: cursor.close() - logger.info(f"Closing session {self.get_session_id_hex()}") - if not self.open: - logger.debug("Session appears to have been closed already") - try: - self.thrift_backend.close_session(self._session_handle) - except RequestError as e: - if isinstance(e.args[1], SessionAlreadyClosedError): - logger.info("Session was closed by a prior request") - except DatabaseError as e: - if "Invalid SessionHandle" in str(e): - logger.warning( - f"Attempted to close session that was already closed: {e}" - ) - else: - logger.warning( - f"Attempt to close session raised an exception at the server: {e}" - ) + self.session.close() except Exception as e: - logger.error(f"Attempt to close session raised a local exception: {e}") - - self.open = False + logger.error(f"Attempt to close session raised an exception: {e}") def commit(self): """No-op because Databricks does not support transactions""" @@ -811,7 +747,7 @@ def execute( self._close_and_clear_active_result_set() execute_response = self.thrift_backend.execute_command( operation=prepared_operation, - session_handle=self.connection._session_handle, + session_handle=self.connection.session._session_handle, max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, lz4_compression=self.connection.lz4_compression, @@ -874,7 +810,7 @@ def execute_async( self._close_and_clear_active_result_set() self.thrift_backend.execute_command( operation=prepared_operation, - session_handle=self.connection._session_handle, + session_handle=self.connection.session._session_handle, max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, lz4_compression=self.connection.lz4_compression, @@ -970,7 +906,7 @@ def catalogs(self) -> "Cursor": self._check_not_closed() self._close_and_clear_active_result_set() execute_response = self.thrift_backend.get_catalogs( - session_handle=self.connection._session_handle, + session_handle=self.connection.session._session_handle, max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, cursor=self, @@ -996,7 +932,7 @@ def schemas( self._check_not_closed() self._close_and_clear_active_result_set() execute_response = self.thrift_backend.get_schemas( - session_handle=self.connection._session_handle, + session_handle=self.connection.session._session_handle, max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, cursor=self, @@ -1029,7 +965,7 @@ def tables( self._close_and_clear_active_result_set() execute_response = self.thrift_backend.get_tables( - session_handle=self.connection._session_handle, + session_handle=self.connection.session._session_handle, max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, cursor=self, @@ -1064,7 +1000,7 @@ def columns( self._close_and_clear_active_result_set() execute_response = self.thrift_backend.get_columns( - session_handle=self.connection._session_handle, + session_handle=self.connection.session._session_handle, max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, cursor=self, @@ -1493,7 +1429,7 @@ def close(self) -> None: if ( self.op_state != self.thrift_backend.CLOSED_OP_STATE and not self.has_been_closed_server_side - and self.connection.open + and self.connection.session.open ): self.thrift_backend.close_command(self.command_id) except RequestError as e: diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py new file mode 100644 index 000000000..4920550e7 --- /dev/null +++ b/src/databricks/sql/session.py @@ -0,0 +1,146 @@ +import logging +from typing import Dict, Tuple, List, Optional, Any + +from databricks.sql.thrift_api.TCLIService import ttypes +from databricks.sql.types import SSLOptions +from databricks.sql.auth.auth import get_python_sql_connector_auth_provider +from databricks.sql.exc import SessionAlreadyClosedError, DatabaseError, RequestError +from databricks.sql import __version__ +from databricks.sql import USER_AGENT_NAME +from databricks.sql.thrift_backend import ThriftBackend + +logger = logging.getLogger(__name__) + + +class Session: + def __init__( + self, + server_hostname: str, + http_path: str, + http_headers: Optional[List[Tuple[str, str]]] = None, + session_configuration: Optional[Dict[str, Any]] = None, + catalog: Optional[str] = None, + schema: Optional[str] = None, + _use_arrow_native_complex_types: Optional[bool] = True, + **kwargs, + ) -> None: + """ + Create a session to a Databricks SQL endpoint or a Databricks cluster. + + This class handles all session-related behavior and communication with the backend. + """ + self.open = False + self.host = server_hostname + self.port = kwargs.get("_port", 443) + + auth_provider = get_python_sql_connector_auth_provider( + server_hostname, **kwargs + ) + + user_agent_entry = kwargs.get("user_agent_entry") + if user_agent_entry is None: + user_agent_entry = kwargs.get("_user_agent_entry") + if user_agent_entry is not None: + logger.warning( + "[WARN] Parameter '_user_agent_entry' is deprecated; use 'user_agent_entry' instead. " + "This parameter will be removed in the upcoming releases." + ) + + if user_agent_entry: + useragent_header = "{}/{} ({})".format( + USER_AGENT_NAME, __version__, user_agent_entry + ) + else: + useragent_header = "{}/{}".format(USER_AGENT_NAME, __version__) + + base_headers = [("User-Agent", useragent_header)] + + self._ssl_options = SSLOptions( + # Double negation is generally a bad thing, but we have to keep backward compatibility + tls_verify=not kwargs.get( + "_tls_no_verify", False + ), # by default - verify cert and host + tls_verify_hostname=kwargs.get("_tls_verify_hostname", True), + tls_trusted_ca_file=kwargs.get("_tls_trusted_ca_file"), + tls_client_cert_file=kwargs.get("_tls_client_cert_file"), + tls_client_cert_key_file=kwargs.get("_tls_client_cert_key_file"), + tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"), + ) + + self.thrift_backend = ThriftBackend( + self.host, + self.port, + http_path, + (http_headers or []) + base_headers, + auth_provider, + ssl_options=self._ssl_options, + _use_arrow_native_complex_types=_use_arrow_native_complex_types, + **kwargs, + ) + + self._open_session_resp = self.thrift_backend.open_session( + session_configuration, catalog, schema + ) + self._session_handle = self._open_session_resp.sessionHandle + self.protocol_version = self.get_protocol_version(self._open_session_resp) + self.open = True + logger.info("Successfully opened session " + str(self.get_session_id_hex())) + + @staticmethod + def get_protocol_version(openSessionResp): + """ + Since the sessionHandle will sometimes have a serverProtocolVersion, it takes + precedence over the serverProtocolVersion defined in the OpenSessionResponse. + """ + if ( + openSessionResp.sessionHandle + and hasattr(openSessionResp.sessionHandle, "serverProtocolVersion") + and openSessionResp.sessionHandle.serverProtocolVersion + ): + return openSessionResp.sessionHandle.serverProtocolVersion + return openSessionResp.serverProtocolVersion + + @staticmethod + def server_parameterized_queries_enabled(protocolVersion): + if ( + protocolVersion + and protocolVersion >= ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V8 + ): + return True + else: + return False + + def get_session_handle(self): + return self._session_handle + + def get_session_id(self): + return self.thrift_backend.handle_to_id(self._session_handle) + + def get_session_id_hex(self): + return self.thrift_backend.handle_to_hex_id(self._session_handle) + + def close(self) -> None: + """Close the underlying session.""" + logger.info(f"Closing session {self.get_session_id_hex()}") + if not self.open: + logger.debug("Session appears to have been closed already") + return + + try: + self.thrift_backend.close_session(self._session_handle) + except RequestError as e: + if isinstance(e.args[1], SessionAlreadyClosedError): + logger.info("Session was closed by a prior request") + except DatabaseError as e: + if "Invalid SessionHandle" in str(e): + logger.warning( + f"Attempted to close session that was already closed: {e}" + ) + else: + logger.warning( + f"Attempt to close session raised an exception at the server: {e}" + ) + except Exception as e: + logger.error(f"Attempt to close session raised a local exception: {e}") + + self.open = False \ No newline at end of file From 0e6efd88af09da9c4e5783fc77c14bcc744fb339 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 20 May 2025 10:35:55 +0530 Subject: [PATCH 2/8] add open property to Connection to ensure maintenance of existing API Signed-off-by: varun-edachali-dbx --- src/databricks/sql/client.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 0fbd20df5..6c89ef0a1 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -313,6 +313,11 @@ def get_protocol_version(openSessionResp): """Delegate to Session class static method""" return Session.get_protocol_version(openSessionResp) + @property + def open(self) -> bool: + """Return whether the connection is open by checking if the session is open.""" + return self.session.open + def cursor( self, arraysize: int = DEFAULT_ARRAY_SIZE, From e7ebe2bb5584501225a79cc2ac6bdfdb44696c59 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 20 May 2025 10:36:36 +0530 Subject: [PATCH 3/8] update unit tests to address ThriftBackend through session instead of through Connection Signed-off-by: varun-edachali-dbx --- tests/unit/test_client.py | 49 ++++++++++++++++++++++----------------- 1 file changed, 28 insertions(+), 21 deletions(-) diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index c39aeb524..58607cf49 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -80,7 +80,7 @@ class ClientTestSuite(unittest.TestCase): "access_token": "tok", } - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) def test_close_uses_the_correct_session_id(self, mock_client_class): instance = mock_client_class.return_value @@ -95,7 +95,7 @@ def test_close_uses_the_correct_session_id(self, mock_client_class): close_session_id = instance.close_session.call_args[0][0].sessionId self.assertEqual(close_session_id, b"\x22") - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) def test_auth_args(self, mock_client_class): # Test that the following auth args work: # token = foo, @@ -122,7 +122,7 @@ def test_auth_args(self, mock_client_class): self.assertEqual(args["http_path"], http_path) connection.close() - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) def test_http_header_passthrough(self, mock_client_class): http_headers = [("foo", "bar")] databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS, http_headers=http_headers) @@ -130,7 +130,7 @@ def test_http_header_passthrough(self, mock_client_class): call_args = mock_client_class.call_args[0][3] self.assertIn(("foo", "bar"), call_args) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) def test_tls_arg_passthrough(self, mock_client_class): databricks.sql.connect( **self.DUMMY_CONNECTION_ARGS, @@ -146,7 +146,7 @@ def test_tls_arg_passthrough(self, mock_client_class): self.assertEqual(kwargs["_tls_client_cert_key_file"], "trusted client cert") self.assertEqual(kwargs["_tls_client_cert_key_password"], "key password") - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) def test_useragent_header(self, mock_client_class): databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) @@ -167,7 +167,7 @@ def test_useragent_header(self, mock_client_class): http_headers = mock_client_class.call_args[0][3] self.assertIn(user_agent_header_with_entry, http_headers) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME, ThriftBackendMockFactory.new()) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME, ThriftBackendMockFactory.new()) @patch("%s.client.ResultSet" % PACKAGE_NAME) def test_closing_connection_closes_commands(self, mock_result_set_class): # Test once with has_been_closed_server side, once without @@ -184,7 +184,7 @@ def test_closing_connection_closes_commands(self, mock_result_set_class): ) mock_result_set_class.return_value.close.assert_called_once_with() - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) def test_cant_open_cursor_on_closed_connection(self, mock_client_class): connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) self.assertTrue(connection.open) @@ -194,7 +194,7 @@ def test_cant_open_cursor_on_closed_connection(self, mock_client_class): connection.cursor() self.assertIn("closed", str(cm.exception)) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) @patch("%s.client.Cursor" % PACKAGE_NAME) def test_arraysize_buffer_size_passthrough( self, mock_cursor_class, mock_client_class @@ -214,7 +214,10 @@ def test_closing_result_set_with_closed_connection_soft_closes_commands(self): thrift_backend=mock_backend, execute_response=Mock(), ) - mock_connection.open = False + # Setup session mock on the mock_connection + mock_session = Mock() + mock_session.open = False + type(mock_connection).session = PropertyMock(return_value=mock_session) result_set.close() @@ -226,7 +229,11 @@ def test_closing_result_set_hard_closes_commands(self): mock_results_response.has_been_closed_server_side = False mock_connection = Mock() mock_thrift_backend = Mock() - mock_connection.open = True + # Setup session mock on the mock_connection + mock_session = Mock() + mock_session.open = True + type(mock_connection).session = PropertyMock(return_value=mock_session) + result_set = client.ResultSet( mock_connection, mock_results_response, mock_thrift_backend ) @@ -283,7 +290,7 @@ def test_context_manager_closes_cursor(self): cursor.close = mock_close mock_close.assert_called_once_with() - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) def test_context_manager_closes_connection(self, mock_client_class): instance = mock_client_class.return_value @@ -396,7 +403,7 @@ def test_cancel_command_will_issue_warning_for_cancel_with_no_executing_command( self.assertTrue(logger_instance.warning.called) self.assertFalse(mock_thrift_backend.cancel_command.called) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) def test_max_number_of_retries_passthrough(self, mock_client_class): databricks.sql.connect( _retry_stop_after_attempts_count=54, **self.DUMMY_CONNECTION_ARGS @@ -406,7 +413,7 @@ def test_max_number_of_retries_passthrough(self, mock_client_class): mock_client_class.call_args[1]["_retry_stop_after_attempts_count"], 54 ) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) def test_socket_timeout_passthrough(self, mock_client_class): databricks.sql.connect(_socket_timeout=234, **self.DUMMY_CONNECTION_ARGS) self.assertEqual(mock_client_class.call_args[1]["_socket_timeout"], 234) @@ -419,7 +426,7 @@ def test_version_is_canonical(self): ) self.assertIsNotNone(re.match(canonical_version_re, version)) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) def test_configuration_passthrough(self, mock_client_class): mock_session_config = Mock() databricks.sql.connect( @@ -431,7 +438,7 @@ def test_configuration_passthrough(self, mock_client_class): mock_session_config, ) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) def test_initial_namespace_passthrough(self, mock_client_class): mock_cat = Mock() mock_schem = Mock() @@ -505,7 +512,7 @@ def test_executemany_parameter_passhthrough_and_uses_last_result_set( "last operation", ) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) def test_commit_a_noop(self, mock_thrift_backend_class): c = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) c.commit() @@ -518,7 +525,7 @@ def test_setoutputsizes_a_noop(self): cursor = client.Cursor(Mock(), Mock()) cursor.setoutputsize(1) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) def test_rollback_not_supported(self, mock_thrift_backend_class): c = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) with self.assertRaises(NotSupportedError): @@ -603,7 +610,7 @@ def test_column_name_api(self): }, ) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) def test_finalizer_closes_abandoned_connection(self, mock_client_class): instance = mock_client_class.return_value @@ -620,7 +627,7 @@ def test_finalizer_closes_abandoned_connection(self, mock_client_class): close_session_id = instance.close_session.call_args[0][0].sessionId self.assertEqual(close_session_id, b"\x22") - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) def test_cursor_keeps_connection_alive(self, mock_client_class): instance = mock_client_class.return_value @@ -639,7 +646,7 @@ def test_cursor_keeps_connection_alive(self, mock_client_class): @patch("%s.utils.ExecuteResponse" % PACKAGE_NAME, autospec=True) @patch("%s.client.Cursor._handle_staging_operation" % PACKAGE_NAME) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) def test_staging_operation_response_is_handled( self, mock_client_class, mock_handle_staging_operation, mock_execute_response ): @@ -658,7 +665,7 @@ def test_staging_operation_response_is_handled( mock_handle_staging_operation.call_count == 1 - @patch("%s.client.ThriftBackend" % PACKAGE_NAME, ThriftBackendMockFactory.new()) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME, ThriftBackendMockFactory.new()) def test_access_current_query_id(self): operation_id = "EE6A8778-21FC-438B-92D8-96AC51EE3821" From c63f6fd5fd254e7c4b97761583e621b9b0c65ecc Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 20 May 2025 10:43:24 +0530 Subject: [PATCH 4/8] chore: move session specific tests from test_client to test_session Signed-off-by: varun-edachali-dbx --- tests/unit/test_client.py | 161 ------------------------------- tests/unit/test_session.py | 187 +++++++++++++++++++++++++++++++++++++ 2 files changed, 187 insertions(+), 161 deletions(-) create mode 100644 tests/unit/test_session.py diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 58607cf49..ecbf3493b 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -80,93 +80,6 @@ class ClientTestSuite(unittest.TestCase): "access_token": "tok", } - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) - def test_close_uses_the_correct_session_id(self, mock_client_class): - instance = mock_client_class.return_value - - mock_open_session_resp = MagicMock(spec=TOpenSessionResp)() - mock_open_session_resp.sessionHandle.sessionId = b"\x22" - instance.open_session.return_value = mock_open_session_resp - - connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) - connection.close() - - # Check the close session request has an id of x22 - close_session_id = instance.close_session.call_args[0][0].sessionId - self.assertEqual(close_session_id, b"\x22") - - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) - def test_auth_args(self, mock_client_class): - # Test that the following auth args work: - # token = foo, - # token = None, _tls_client_cert_file = something, _use_cert_as_auth = True - connection_args = [ - { - "server_hostname": "foo", - "http_path": None, - "access_token": "tok", - }, - { - "server_hostname": "foo", - "http_path": None, - "_tls_client_cert_file": "something", - "_use_cert_as_auth": True, - "access_token": None, - }, - ] - - for args in connection_args: - connection = databricks.sql.connect(**args) - host, port, http_path, *_ = mock_client_class.call_args[0] - self.assertEqual(args["server_hostname"], host) - self.assertEqual(args["http_path"], http_path) - connection.close() - - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) - def test_http_header_passthrough(self, mock_client_class): - http_headers = [("foo", "bar")] - databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS, http_headers=http_headers) - - call_args = mock_client_class.call_args[0][3] - self.assertIn(("foo", "bar"), call_args) - - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) - def test_tls_arg_passthrough(self, mock_client_class): - databricks.sql.connect( - **self.DUMMY_CONNECTION_ARGS, - _tls_verify_hostname="hostname", - _tls_trusted_ca_file="trusted ca file", - _tls_client_cert_key_file="trusted client cert", - _tls_client_cert_key_password="key password", - ) - - kwargs = mock_client_class.call_args[1] - self.assertEqual(kwargs["_tls_verify_hostname"], "hostname") - self.assertEqual(kwargs["_tls_trusted_ca_file"], "trusted ca file") - self.assertEqual(kwargs["_tls_client_cert_key_file"], "trusted client cert") - self.assertEqual(kwargs["_tls_client_cert_key_password"], "key password") - - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) - def test_useragent_header(self, mock_client_class): - databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) - - http_headers = mock_client_class.call_args[0][3] - user_agent_header = ( - "User-Agent", - "{}/{}".format(databricks.sql.USER_AGENT_NAME, databricks.sql.__version__), - ) - self.assertIn(user_agent_header, http_headers) - - databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS, user_agent_entry="foobar") - user_agent_header_with_entry = ( - "User-Agent", - "{}/{} ({})".format( - databricks.sql.USER_AGENT_NAME, databricks.sql.__version__, "foobar" - ), - ) - http_headers = mock_client_class.call_args[0][3] - self.assertIn(user_agent_header_with_entry, http_headers) - @patch("%s.session.ThriftBackend" % PACKAGE_NAME, ThriftBackendMockFactory.new()) @patch("%s.client.ResultSet" % PACKAGE_NAME) def test_closing_connection_closes_commands(self, mock_result_set_class): @@ -290,21 +203,6 @@ def test_context_manager_closes_cursor(self): cursor.close = mock_close mock_close.assert_called_once_with() - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) - def test_context_manager_closes_connection(self, mock_client_class): - instance = mock_client_class.return_value - - mock_open_session_resp = MagicMock(spec=TOpenSessionResp)() - mock_open_session_resp.sessionHandle.sessionId = b"\x22" - instance.open_session.return_value = mock_open_session_resp - - with databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) as connection: - pass - - # Check the close session request has an id of x22 - close_session_id = instance.close_session.call_args[0][0].sessionId - self.assertEqual(close_session_id, b"\x22") - def dict_product(self, dicts): """ Generate cartesion product of values in input dictionary, outputting a dictionary @@ -403,21 +301,6 @@ def test_cancel_command_will_issue_warning_for_cancel_with_no_executing_command( self.assertTrue(logger_instance.warning.called) self.assertFalse(mock_thrift_backend.cancel_command.called) - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) - def test_max_number_of_retries_passthrough(self, mock_client_class): - databricks.sql.connect( - _retry_stop_after_attempts_count=54, **self.DUMMY_CONNECTION_ARGS - ) - - self.assertEqual( - mock_client_class.call_args[1]["_retry_stop_after_attempts_count"], 54 - ) - - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) - def test_socket_timeout_passthrough(self, mock_client_class): - databricks.sql.connect(_socket_timeout=234, **self.DUMMY_CONNECTION_ARGS) - self.assertEqual(mock_client_class.call_args[1]["_socket_timeout"], 234) - def test_version_is_canonical(self): version = databricks.sql.__version__ canonical_version_re = ( @@ -426,33 +309,6 @@ def test_version_is_canonical(self): ) self.assertIsNotNone(re.match(canonical_version_re, version)) - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) - def test_configuration_passthrough(self, mock_client_class): - mock_session_config = Mock() - databricks.sql.connect( - session_configuration=mock_session_config, **self.DUMMY_CONNECTION_ARGS - ) - - self.assertEqual( - mock_client_class.return_value.open_session.call_args[0][0], - mock_session_config, - ) - - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) - def test_initial_namespace_passthrough(self, mock_client_class): - mock_cat = Mock() - mock_schem = Mock() - - databricks.sql.connect( - **self.DUMMY_CONNECTION_ARGS, catalog=mock_cat, schema=mock_schem - ) - self.assertEqual( - mock_client_class.return_value.open_session.call_args[0][1], mock_cat - ) - self.assertEqual( - mock_client_class.return_value.open_session.call_args[0][2], mock_schem - ) - def test_execute_parameter_passthrough(self): mock_thrift_backend = ThriftBackendMockFactory.new() cursor = client.Cursor(Mock(), mock_thrift_backend) @@ -610,23 +466,6 @@ def test_column_name_api(self): }, ) - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) - def test_finalizer_closes_abandoned_connection(self, mock_client_class): - instance = mock_client_class.return_value - - mock_open_session_resp = MagicMock(spec=TOpenSessionResp)() - mock_open_session_resp.sessionHandle.sessionId = b"\x22" - instance.open_session.return_value = mock_open_session_resp - - databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) - - # not strictly necessary as the refcount is 0, but just to be sure - gc.collect() - - # Check the close session request has an id of x22 - close_session_id = instance.close_session.call_args[0][0].sessionId - self.assertEqual(close_session_id, b"\x22") - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) def test_cursor_keeps_connection_alive(self, mock_client_class): instance = mock_client_class.return_value diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py new file mode 100644 index 000000000..6a49abef6 --- /dev/null +++ b/tests/unit/test_session.py @@ -0,0 +1,187 @@ +import unittest +from unittest.mock import patch, MagicMock, Mock, PropertyMock +import gc + +from databricks.sql.thrift_api.TCLIService.ttypes import ( + TOpenSessionResp, +) + +import databricks.sql + + +class SessionTestSuite(unittest.TestCase): + """ + Unit tests for Session functionality + """ + + PACKAGE_NAME = "databricks.sql" + DUMMY_CONNECTION_ARGS = { + "server_hostname": "foo", + "http_path": "dummy_path", + "access_token": "tok", + } + + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + def test_close_uses_the_correct_session_id(self, mock_client_class): + instance = mock_client_class.return_value + + mock_open_session_resp = MagicMock(spec=TOpenSessionResp)() + mock_open_session_resp.sessionHandle.sessionId = b"\x22" + instance.open_session.return_value = mock_open_session_resp + + connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) + connection.close() + + # Check the close session request has an id of x22 + close_session_id = instance.close_session.call_args[0][0].sessionId + self.assertEqual(close_session_id, b"\x22") + + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + def test_auth_args(self, mock_client_class): + # Test that the following auth args work: + # token = foo, + # token = None, _tls_client_cert_file = something, _use_cert_as_auth = True + connection_args = [ + { + "server_hostname": "foo", + "http_path": None, + "access_token": "tok", + }, + { + "server_hostname": "foo", + "http_path": None, + "_tls_client_cert_file": "something", + "_use_cert_as_auth": True, + "access_token": None, + }, + ] + + for args in connection_args: + connection = databricks.sql.connect(**args) + host, port, http_path, *_ = mock_client_class.call_args[0] + self.assertEqual(args["server_hostname"], host) + self.assertEqual(args["http_path"], http_path) + connection.close() + + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + def test_http_header_passthrough(self, mock_client_class): + http_headers = [("foo", "bar")] + databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS, http_headers=http_headers) + + call_args = mock_client_class.call_args[0][3] + self.assertIn(("foo", "bar"), call_args) + + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + def test_tls_arg_passthrough(self, mock_client_class): + databricks.sql.connect( + **self.DUMMY_CONNECTION_ARGS, + _tls_verify_hostname="hostname", + _tls_trusted_ca_file="trusted ca file", + _tls_client_cert_key_file="trusted client cert", + _tls_client_cert_key_password="key password", + ) + + kwargs = mock_client_class.call_args[1] + self.assertEqual(kwargs["_tls_verify_hostname"], "hostname") + self.assertEqual(kwargs["_tls_trusted_ca_file"], "trusted ca file") + self.assertEqual(kwargs["_tls_client_cert_key_file"], "trusted client cert") + self.assertEqual(kwargs["_tls_client_cert_key_password"], "key password") + + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + def test_useragent_header(self, mock_client_class): + databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) + + http_headers = mock_client_class.call_args[0][3] + user_agent_header = ( + "User-Agent", + "{}/{}".format(databricks.sql.USER_AGENT_NAME, databricks.sql.__version__), + ) + self.assertIn(user_agent_header, http_headers) + + databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS, user_agent_entry="foobar") + user_agent_header_with_entry = ( + "User-Agent", + "{}/{} ({})".format( + databricks.sql.USER_AGENT_NAME, databricks.sql.__version__, "foobar" + ), + ) + http_headers = mock_client_class.call_args[0][3] + self.assertIn(user_agent_header_with_entry, http_headers) + + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + def test_context_manager_closes_connection(self, mock_client_class): + instance = mock_client_class.return_value + + mock_open_session_resp = MagicMock(spec=TOpenSessionResp)() + mock_open_session_resp.sessionHandle.sessionId = b"\x22" + instance.open_session.return_value = mock_open_session_resp + + with databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) as connection: + pass + + # Check the close session request has an id of x22 + close_session_id = instance.close_session.call_args[0][0].sessionId + self.assertEqual(close_session_id, b"\x22") + + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + def test_max_number_of_retries_passthrough(self, mock_client_class): + databricks.sql.connect( + _retry_stop_after_attempts_count=54, **self.DUMMY_CONNECTION_ARGS + ) + + self.assertEqual( + mock_client_class.call_args[1]["_retry_stop_after_attempts_count"], 54 + ) + + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + def test_socket_timeout_passthrough(self, mock_client_class): + databricks.sql.connect(_socket_timeout=234, **self.DUMMY_CONNECTION_ARGS) + self.assertEqual(mock_client_class.call_args[1]["_socket_timeout"], 234) + + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + def test_configuration_passthrough(self, mock_client_class): + mock_session_config = Mock() + databricks.sql.connect( + session_configuration=mock_session_config, **self.DUMMY_CONNECTION_ARGS + ) + + self.assertEqual( + mock_client_class.return_value.open_session.call_args[0][0], + mock_session_config, + ) + + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + def test_initial_namespace_passthrough(self, mock_client_class): + mock_cat = Mock() + mock_schem = Mock() + + databricks.sql.connect( + **self.DUMMY_CONNECTION_ARGS, catalog=mock_cat, schema=mock_schem + ) + self.assertEqual( + mock_client_class.return_value.open_session.call_args[0][1], mock_cat + ) + self.assertEqual( + mock_client_class.return_value.open_session.call_args[0][2], mock_schem + ) + + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + def test_finalizer_closes_abandoned_connection(self, mock_client_class): + instance = mock_client_class.return_value + + mock_open_session_resp = MagicMock(spec=TOpenSessionResp)() + mock_open_session_resp.sessionHandle.sessionId = b"\x22" + instance.open_session.return_value = mock_open_session_resp + + databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) + + # not strictly necessary as the refcount is 0, but just to be sure + gc.collect() + + # Check the close session request has an id of x22 + close_session_id = instance.close_session.call_args[0][0].sessionId + self.assertEqual(close_session_id, b"\x22") + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file From 23cc72a0e6243e63c26b6887b7ea799c5adccca0 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 20 May 2025 11:14:16 +0530 Subject: [PATCH 5/8] formatting (black) as in CONTRIBUTING.md Signed-off-by: varun-edachali-dbx --- src/databricks/sql/client.py | 9 ++++++--- src/databricks/sql/session.py | 6 +++--- tests/unit/test_client.py | 2 +- tests/unit/test_session.py | 4 ++-- tests/unit/test_thrift_backend.py | 4 +++- 5 files changed, 15 insertions(+), 10 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 6c89ef0a1..098aa0548 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -235,10 +235,13 @@ def read(self) -> Optional[OAuthToken]: catalog, schema, _use_arrow_native_complex_types, - **kwargs + **kwargs, + ) + + logger.info( + "Successfully opened connection with session " + + str(self.get_session_id_hex()) ) - - logger.info("Successfully opened connection with session " + str(self.get_session_id_hex())) self.use_inline_params = self._set_use_inline_params_with_warning( kwargs.get("use_inline_params", False) diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py index 4920550e7..a308b71d5 100644 --- a/src/databricks/sql/session.py +++ b/src/databricks/sql/session.py @@ -26,13 +26,13 @@ def __init__( ) -> None: """ Create a session to a Databricks SQL endpoint or a Databricks cluster. - + This class handles all session-related behavior and communication with the backend. """ self.open = False self.host = server_hostname self.port = kwargs.get("_port", 443) - + auth_provider = get_python_sql_connector_auth_provider( server_hostname, **kwargs ) @@ -143,4 +143,4 @@ def close(self) -> None: except Exception as e: logger.error(f"Attempt to close session raised a local exception: {e}") - self.open = False \ No newline at end of file + self.open = False diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index ecbf3493b..b67101943 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -146,7 +146,7 @@ def test_closing_result_set_hard_closes_commands(self): mock_session = Mock() mock_session.open = True type(mock_connection).session = PropertyMock(return_value=mock_session) - + result_set = client.ResultSet( mock_connection, mock_results_response, mock_thrift_backend ) diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 6a49abef6..eb392a229 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -11,7 +11,7 @@ class SessionTestSuite(unittest.TestCase): """ - Unit tests for Session functionality + Unit tests for Session functionality """ PACKAGE_NAME = "databricks.sql" @@ -184,4 +184,4 @@ def test_finalizer_closes_abandoned_connection(self, mock_client_class): if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 7fe318446..458ea9a82 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -86,7 +86,9 @@ def test_make_request_checks_thrift_status_code(self): def _make_type_desc(self, type): return ttypes.TTypeDesc( - types=[ttypes.TTypeEntry(primitiveEntry=ttypes.TPrimitiveTypeEntry(type=type))] + types=[ + ttypes.TTypeEntry(primitiveEntry=ttypes.TPrimitiveTypeEntry(type=type)) + ] ) def _make_fake_thrift_backend(self): From 49d2bcd97b8af11a4135749a74913c7d67eee828 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 20 May 2025 16:18:44 +0530 Subject: [PATCH 6/8] use connection open property instead of long chain through session Signed-off-by: varun-edachali-dbx --- src/databricks/sql/client.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 098aa0548..54a097641 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -282,7 +282,7 @@ def __exit__(self, exc_type, exc_value, traceback): self.close() def __del__(self): - if self.session.open: + if self.open: logger.debug( "Closing unclosed connection for session " "{}".format(self.get_session_id_hex()) @@ -331,7 +331,7 @@ def cursor( Will throw an Error if the connection has been closed. """ - if not self.session.open: + if not self.open: raise Error("Cannot create cursor from closed connection") cursor = Cursor( @@ -1437,7 +1437,7 @@ def close(self) -> None: if ( self.op_state != self.thrift_backend.CLOSED_OP_STATE and not self.has_been_closed_server_side - and self.connection.session.open + and self.connection.open ): self.thrift_backend.close_command(self.command_id) except RequestError as e: From 078f91aa101778db773ccf170f6cada552da275f Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 20 May 2025 16:47:02 +0530 Subject: [PATCH 7/8] trigger integration workflow Signed-off-by: varun-edachali-dbx From c676f9b0281cc3e4fe9c6d8216cc62fc75eade3b Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 21 May 2025 10:36:31 +0530 Subject: [PATCH 8/8] fix: ensure open attribute of Connection never fails in case the openSession takes long, the initialisation of the session will not complete immediately. This could make the session attribute inaccessible. If the Connection is deleted in this time, the open() check will throw because the session attribute does not exist. Thus, we default to the Connection being closed in this case. This was not an issue before because open was a direct attribute of the Connection class. Caught in the integration tests. Signed-off-by: varun-edachali-dbx --- src/databricks/sql/client.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 54a097641..4ddc5069b 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -19,8 +19,6 @@ OperationalError, SessionAlreadyClosedError, CursorAlreadyClosedError, - Error, - NotSupportedError, ) from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.thrift_backend import ThriftBackend @@ -319,7 +317,9 @@ def get_protocol_version(openSessionResp): @property def open(self) -> bool: """Return whether the connection is open by checking if the session is open.""" - return self.session.open + # NOTE: we have to check for the existence of session in case the __del__ is called + # before the session is instantiated + return hasattr(self, "session") and self.session.open def cursor( self,