1919 OperationalError ,
2020 SessionAlreadyClosedError ,
2121 CursorAlreadyClosedError ,
22+ Error ,
23+ NotSupportedError ,
2224)
2325from databricks .sql .thrift_api .TCLIService import ttypes
2426from databricks .sql .thrift_backend import ThriftBackend
4547from databricks .sql .types import Row , SSLOptions
4648from databricks .sql .auth .auth import get_python_sql_connector_auth_provider
4749from databricks .sql .experimental .oauth_persistence import OAuthPersistence
50+ from databricks .sql .session import Session
4851
4952from databricks .sql .thrift_api .TCLIService .ttypes import (
5053 TSparkParameter ,
@@ -218,66 +221,24 @@ def read(self) -> Optional[OAuthToken]:
218221 access_token_kv = {"access_token" : access_token }
219222 kwargs = {** kwargs , ** access_token_kv }
220223
221- self .open = False
222- self .host = server_hostname
223- self .port = kwargs .get ("_port" , 443 )
224224 self .disable_pandas = kwargs .get ("_disable_pandas" , False )
225225 self .lz4_compression = kwargs .get ("enable_query_result_lz4_compression" , True )
226+ self .use_cloud_fetch = kwargs .get ("use_cloud_fetch" , True )
227+ self ._cursors = [] # type: List[Cursor]
226228
227- auth_provider = get_python_sql_connector_auth_provider (
228- server_hostname , ** kwargs
229- )
230-
231- user_agent_entry = kwargs .get ("user_agent_entry" )
232- if user_agent_entry is None :
233- user_agent_entry = kwargs .get ("_user_agent_entry" )
234- if user_agent_entry is not None :
235- logger .warning (
236- "[WARN] Parameter '_user_agent_entry' is deprecated; use 'user_agent_entry' instead. "
237- "This parameter will be removed in the upcoming releases."
238- )
239-
240- if user_agent_entry :
241- useragent_header = "{}/{} ({})" .format (
242- USER_AGENT_NAME , __version__ , user_agent_entry
243- )
244- else :
245- useragent_header = "{}/{}" .format (USER_AGENT_NAME , __version__ )
246-
247- base_headers = [("User-Agent" , useragent_header )]
248-
249- self ._ssl_options = SSLOptions (
250- # Double negation is generally a bad thing, but we have to keep backward compatibility
251- tls_verify = not kwargs .get (
252- "_tls_no_verify" , False
253- ), # by default - verify cert and host
254- tls_verify_hostname = kwargs .get ("_tls_verify_hostname" , True ),
255- tls_trusted_ca_file = kwargs .get ("_tls_trusted_ca_file" ),
256- tls_client_cert_file = kwargs .get ("_tls_client_cert_file" ),
257- tls_client_cert_key_file = kwargs .get ("_tls_client_cert_key_file" ),
258- tls_client_cert_key_password = kwargs .get ("_tls_client_cert_key_password" ),
259- )
260-
261- self .thrift_backend = ThriftBackend (
262- self .host ,
263- self .port ,
229+ # Create the session
230+ self .session = Session (
231+ server_hostname ,
264232 http_path ,
265- (http_headers or []) + base_headers ,
266- auth_provider ,
267- ssl_options = self ._ssl_options ,
268- _use_arrow_native_complex_types = _use_arrow_native_complex_types ,
269- ** kwargs ,
270- )
271-
272- self ._open_session_resp = self .thrift_backend .open_session (
273- session_configuration , catalog , schema
233+ http_headers ,
234+ session_configuration ,
235+ catalog ,
236+ schema ,
237+ _use_arrow_native_complex_types ,
238+ ** kwargs
274239 )
275- self ._session_handle = self ._open_session_resp .sessionHandle
276- self .protocol_version = self .get_protocol_version (self ._open_session_resp )
277- self .use_cloud_fetch = kwargs .get ("use_cloud_fetch" , True )
278- self .open = True
279- logger .info ("Successfully opened session " + str (self .get_session_id_hex ()))
280- self ._cursors = [] # type: List[Cursor]
240+
241+ logger .info ("Successfully opened connection with session " + str (self .get_session_id_hex ()))
281242
282243 self .use_inline_params = self ._set_use_inline_params_with_warning (
283244 kwargs .get ("use_inline_params" , False )
@@ -318,7 +279,7 @@ def __exit__(self, exc_type, exc_value, traceback):
318279 self .close ()
319280
320281 def __del__ (self ):
321- if self .open :
282+ if self .session . open :
322283 logger .debug (
323284 "Closing unclosed connection for session "
324285 "{}" .format (self .get_session_id_hex ())
@@ -330,34 +291,27 @@ def __del__(self):
330291 logger .debug ("Couldn't close unclosed connection: {}" .format (e .message ))
331292
332293 def get_session_id (self ):
333- return self .thrift_backend .handle_to_id (self ._session_handle )
294+ """Get the session ID from the Session object"""
295+ return self .session .get_session_id ()
334296
335- @staticmethod
336- def get_protocol_version (openSessionResp ):
337- """
338- Since the sessionHandle will sometimes have a serverProtocolVersion, it takes
339- precedence over the serverProtocolVersion defined in the OpenSessionResponse.
340- """
341- if (
342- openSessionResp .sessionHandle
343- and hasattr (openSessionResp .sessionHandle , "serverProtocolVersion" )
344- and openSessionResp .sessionHandle .serverProtocolVersion
345- ):
346- return openSessionResp .sessionHandle .serverProtocolVersion
347- return openSessionResp .serverProtocolVersion
297+ def get_session_id_hex (self ):
298+ """Get the session ID in hex format from the Session object"""
299+ return self .session .get_session_id_hex ()
348300
349301 @staticmethod
350302 def server_parameterized_queries_enabled (protocolVersion ):
351- if (
352- protocolVersion
353- and protocolVersion >= ttypes .TProtocolVersion .SPARK_CLI_SERVICE_PROTOCOL_V8
354- ):
355- return True
356- else :
357- return False
303+ """Delegate to Session class static method"""
304+ return Session .server_parameterized_queries_enabled (protocolVersion )
358305
359- def get_session_id_hex (self ):
360- return self .thrift_backend .handle_to_hex_id (self ._session_handle )
306+ @property
307+ def protocol_version (self ):
308+ """Get the protocol version from the Session object"""
309+ return self .session .protocol_version
310+
311+ @staticmethod
312+ def get_protocol_version (openSessionResp ):
313+ """Delegate to Session class static method"""
314+ return Session .get_protocol_version (openSessionResp )
361315
362316 def cursor (
363317 self ,
@@ -369,12 +323,12 @@ def cursor(
369323
370324 Will throw an Error if the connection has been closed.
371325 """
372- if not self .open :
326+ if not self .session . open :
373327 raise Error ("Cannot create cursor from closed connection" )
374328
375329 cursor = Cursor (
376330 self ,
377- self .thrift_backend ,
331+ self .session . thrift_backend ,
378332 arraysize = arraysize ,
379333 result_buffer_size_bytes = buffer_size_bytes ,
380334 )
@@ -390,28 +344,10 @@ def _close(self, close_cursors=True) -> None:
390344 for cursor in self ._cursors :
391345 cursor .close ()
392346
393- logger .info (f"Closing session { self .get_session_id_hex ()} " )
394- if not self .open :
395- logger .debug ("Session appears to have been closed already" )
396-
397347 try :
398- self .thrift_backend .close_session (self ._session_handle )
399- except RequestError as e :
400- if isinstance (e .args [1 ], SessionAlreadyClosedError ):
401- logger .info ("Session was closed by a prior request" )
402- except DatabaseError as e :
403- if "Invalid SessionHandle" in str (e ):
404- logger .warning (
405- f"Attempted to close session that was already closed: { e } "
406- )
407- else :
408- logger .warning (
409- f"Attempt to close session raised an exception at the server: { e } "
410- )
348+ self .session .close ()
411349 except Exception as e :
412- logger .error (f"Attempt to close session raised a local exception: { e } " )
413-
414- self .open = False
350+ logger .error (f"Attempt to close session raised an exception: { e } " )
415351
416352 def commit (self ):
417353 """No-op because Databricks does not support transactions"""
@@ -811,7 +747,7 @@ def execute(
811747 self ._close_and_clear_active_result_set ()
812748 execute_response = self .thrift_backend .execute_command (
813749 operation = prepared_operation ,
814- session_handle = self .connection ._session_handle ,
750+ session_handle = self .connection .session . _session_handle ,
815751 max_rows = self .arraysize ,
816752 max_bytes = self .buffer_size_bytes ,
817753 lz4_compression = self .connection .lz4_compression ,
@@ -874,7 +810,7 @@ def execute_async(
874810 self ._close_and_clear_active_result_set ()
875811 self .thrift_backend .execute_command (
876812 operation = prepared_operation ,
877- session_handle = self .connection ._session_handle ,
813+ session_handle = self .connection .session . _session_handle ,
878814 max_rows = self .arraysize ,
879815 max_bytes = self .buffer_size_bytes ,
880816 lz4_compression = self .connection .lz4_compression ,
@@ -970,7 +906,7 @@ def catalogs(self) -> "Cursor":
970906 self ._check_not_closed ()
971907 self ._close_and_clear_active_result_set ()
972908 execute_response = self .thrift_backend .get_catalogs (
973- session_handle = self .connection ._session_handle ,
909+ session_handle = self .connection .session . _session_handle ,
974910 max_rows = self .arraysize ,
975911 max_bytes = self .buffer_size_bytes ,
976912 cursor = self ,
@@ -996,7 +932,7 @@ def schemas(
996932 self ._check_not_closed ()
997933 self ._close_and_clear_active_result_set ()
998934 execute_response = self .thrift_backend .get_schemas (
999- session_handle = self .connection ._session_handle ,
935+ session_handle = self .connection .session . _session_handle ,
1000936 max_rows = self .arraysize ,
1001937 max_bytes = self .buffer_size_bytes ,
1002938 cursor = self ,
@@ -1029,7 +965,7 @@ def tables(
1029965 self ._close_and_clear_active_result_set ()
1030966
1031967 execute_response = self .thrift_backend .get_tables (
1032- session_handle = self .connection ._session_handle ,
968+ session_handle = self .connection .session . _session_handle ,
1033969 max_rows = self .arraysize ,
1034970 max_bytes = self .buffer_size_bytes ,
1035971 cursor = self ,
@@ -1064,7 +1000,7 @@ def columns(
10641000 self ._close_and_clear_active_result_set ()
10651001
10661002 execute_response = self .thrift_backend .get_columns (
1067- session_handle = self .connection ._session_handle ,
1003+ session_handle = self .connection .session . _session_handle ,
10681004 max_rows = self .arraysize ,
10691005 max_bytes = self .buffer_size_bytes ,
10701006 cursor = self ,
@@ -1493,7 +1429,7 @@ def close(self) -> None:
14931429 if (
14941430 self .op_state != self .thrift_backend .CLOSED_OP_STATE
14951431 and not self .has_been_closed_server_side
1496- and self .connection .open
1432+ and self .connection .session . open
14971433 ):
14981434 self .thrift_backend .close_command (self .command_id )
14991435 except RequestError as e :
0 commit comments