7272 "_retry_delay_default" : (float , 5 , 1 , 60 ),
7373}
7474
75+ # Add thread local storage
76+ _connection_uuid = threading .local ()
77+
7578
7679class ThriftBackend :
7780 CLOSED_OP_STATE = ttypes .TOperationState .CLOSED_STATE
@@ -223,7 +226,7 @@ def __init__(
223226 raise
224227
225228 self ._request_lock = threading .RLock ()
226- self . _connection_uuid = None
229+ _connection_uuid . value = None
227230
228231 # TODO: Move this bounding logic into DatabricksRetryPolicy for v3 (PECO-918)
229232 def _initialize_retry_args (self , kwargs ):
@@ -256,13 +259,14 @@ def _initialize_retry_args(self, kwargs):
256259 )
257260
258261 @staticmethod
259- def _check_response_for_error (response , connection_uuid = None ):
262+ def _check_response_for_error (response ):
260263 if response .status and response .status .statusCode in [
261264 ttypes .TStatusCode .ERROR_STATUS ,
262265 ttypes .TStatusCode .INVALID_HANDLE_STATUS ,
263266 ]:
264267 raise DatabaseError (
265- response .status .errorMessage , connection_uuid = connection_uuid
268+ response .status .errorMessage ,
269+ connection_uuid = getattr (_connection_uuid , "value" , None ),
266270 )
267271
268272 @staticmethod
@@ -316,7 +320,7 @@ def _handle_request_error(self, error_info, attempt, elapsed):
316320 network_request_error = RequestError (
317321 user_friendly_error_message ,
318322 full_error_info_context ,
319- self . _connection_uuid ,
323+ getattr ( _connection_uuid , "value" , None ) ,
320324 error_info .error ,
321325 )
322326 logger .info (network_request_error .message_with_context ())
@@ -489,7 +493,7 @@ def attempt_request(attempt):
489493 if not isinstance (response_or_error_info , RequestErrorInfo ):
490494 # log nothing here, presume that main request logging covers
491495 response = response_or_error_info
492- ThriftBackend ._check_response_for_error (response , self . _connection_uuid )
496+ ThriftBackend ._check_response_for_error (response )
493497 return response
494498
495499 error_info = response_or_error_info
@@ -504,7 +508,7 @@ def _check_protocol_version(self, t_open_session_resp):
504508 "Error: expected server to use a protocol version >= "
505509 "SPARK_CLI_SERVICE_PROTOCOL_V2, "
506510 "instead got: {}" .format (protocol_version ),
507- connection_uuid = self . _connection_uuid ,
511+ connection_uuid = getattr ( _connection_uuid , "value" , None ) ,
508512 )
509513
510514 def _check_initial_namespace (self , catalog , schema , response ):
@@ -518,15 +522,15 @@ def _check_initial_namespace(self, catalog, schema, response):
518522 raise InvalidServerResponseError (
519523 "Setting initial namespace not supported by the DBR version, "
520524 "Please use a Databricks SQL endpoint or a cluster with DBR >= 9.0." ,
521- connection_uuid = self . _connection_uuid ,
525+ connection_uuid = getattr ( _connection_uuid , "value" , None ) ,
522526 )
523527
524528 if catalog :
525529 if not response .canUseMultipleCatalogs :
526530 raise InvalidServerResponseError (
527531 "Unexpected response from server: Trying to set initial catalog to {}, "
528532 + "but server does not support multiple catalogs." .format (catalog ), # type: ignore
529- connection_uuid = self . _connection_uuid ,
533+ connection_uuid = getattr ( _connection_uuid , "value" , None ) ,
530534 )
531535
532536 def _check_session_configuration (self , session_configuration ):
@@ -541,7 +545,7 @@ def _check_session_configuration(self, session_configuration):
541545 TIMESTAMP_AS_STRING_CONFIG ,
542546 session_configuration [TIMESTAMP_AS_STRING_CONFIG ],
543547 ),
544- connection_uuid = self . _connection_uuid ,
548+ connection_uuid = getattr ( _connection_uuid , "value" , None ) ,
545549 )
546550
547551 def open_session (self , session_configuration , catalog , schema ):
@@ -572,7 +576,7 @@ def open_session(self, session_configuration, catalog, schema):
572576 response = self .make_request (self ._client .OpenSession , open_session_req )
573577 self ._check_initial_namespace (catalog , schema , response )
574578 self ._check_protocol_version (response )
575- self . _connection_uuid = (
579+ _connection_uuid . value = (
576580 self .handle_to_hex_id (response .sessionHandle )
577581 if response .sessionHandle
578582 else None
@@ -601,7 +605,7 @@ def _check_command_not_in_error_or_closed_state(
601605 and self .guid_to_hex_id (op_handle .operationId .guid ),
602606 "diagnostic-info" : get_operations_resp .diagnosticInfo ,
603607 },
604- connection_uuid = self . _connection_uuid ,
608+ connection_uuid = getattr ( _connection_uuid , "value" , None ) ,
605609 )
606610 else :
607611 raise ServerOperationError (
@@ -611,7 +615,7 @@ def _check_command_not_in_error_or_closed_state(
611615 and self .guid_to_hex_id (op_handle .operationId .guid ),
612616 "diagnostic-info" : None ,
613617 },
614- connection_uuid = self . _connection_uuid ,
618+ connection_uuid = getattr ( _connection_uuid , "value" , None ) ,
615619 )
616620 elif get_operations_resp .operationState == ttypes .TOperationState .CLOSED_STATE :
617621 raise DatabaseError (
@@ -622,7 +626,7 @@ def _check_command_not_in_error_or_closed_state(
622626 "operation-id" : op_handle
623627 and self .guid_to_hex_id (op_handle .operationId .guid )
624628 },
625- connection_uuid = self . _connection_uuid ,
629+ connection_uuid = getattr ( _connection_uuid , "value" , None ) ,
626630 )
627631
628632 def _poll_for_status (self , op_handle ):
@@ -645,7 +649,7 @@ def _create_arrow_table(self, t_row_set, lz4_compressed, schema_bytes, descripti
645649 else :
646650 raise OperationalError (
647651 "Unsupported TRowSet instance {}" .format (t_row_set ),
648- connection_uuid = self . _connection_uuid ,
652+ connection_uuid = getattr ( _connection_uuid , "value" , None ) ,
649653 )
650654 return convert_decimals_in_arrow_table (arrow_table , description ), num_rows
651655
@@ -654,7 +658,7 @@ def _get_metadata_resp(self, op_handle):
654658 return self .make_request (self ._client .GetResultSetMetadata , req )
655659
656660 @staticmethod
657- def _hive_schema_to_arrow_schema (t_table_schema , connection_uuid = None ):
661+ def _hive_schema_to_arrow_schema (t_table_schema ):
658662 def map_type (t_type_entry ):
659663 if t_type_entry .primitiveEntry :
660664 return {
@@ -686,7 +690,7 @@ def map_type(t_type_entry):
686690 # even for complex types
687691 raise OperationalError (
688692 "Thrift protocol error: t_type_entry not a primitiveEntry" ,
689- connection_uuid = connection_uuid ,
693+ connection_uuid = getattr ( _connection_uuid , "value" , None ) ,
690694 )
691695
692696 def convert_col (t_column_desc ):
@@ -697,7 +701,7 @@ def convert_col(t_column_desc):
697701 return pyarrow .schema ([convert_col (col ) for col in t_table_schema .columns ])
698702
699703 @staticmethod
700- def _col_to_description (col , connection_uuid = None ):
704+ def _col_to_description (col ):
701705 type_entry = col .typeDesc .types [0 ]
702706
703707 if type_entry .primitiveEntry :
@@ -707,7 +711,7 @@ def _col_to_description(col, connection_uuid=None):
707711 else :
708712 raise OperationalError (
709713 "Thrift protocol error: t_type_entry not a primitiveEntry" ,
710- connection_uuid = connection_uuid ,
714+ connection_uuid = getattr ( _connection_uuid , "value" , None ) ,
711715 )
712716
713717 if type_entry .primitiveEntry .type == ttypes .TTypeId .DECIMAL_TYPE :
@@ -721,18 +725,17 @@ def _col_to_description(col, connection_uuid=None):
721725 raise OperationalError (
722726 "Decimal type did not provide typeQualifier precision, scale in "
723727 "primitiveEntry {}" .format (type_entry .primitiveEntry ),
724- connection_uuid = connection_uuid ,
728+ connection_uuid = getattr ( _connection_uuid , "value" , None ) ,
725729 )
726730 else :
727731 precision , scale = None , None
728732
729733 return col .columnName , cleaned_type , None , None , precision , scale , None
730734
731735 @staticmethod
732- def _hive_schema_to_description (t_table_schema , connection_uuid = None ):
736+ def _hive_schema_to_description (t_table_schema ):
733737 return [
734- ThriftBackend ._col_to_description (col , connection_uuid )
735- for col in t_table_schema .columns
738+ ThriftBackend ._col_to_description (col ) for col in t_table_schema .columns
736739 ]
737740
738741 def _results_message_to_execute_response (self , resp , operation_state ):
@@ -753,7 +756,7 @@ def _results_message_to_execute_response(self, resp, operation_state):
753756 t_result_set_metadata_resp .resultFormat
754757 ]
755758 ),
756- connection_uuid = self . _connection_uuid ,
759+ connection_uuid = getattr ( _connection_uuid , "value" , None ) ,
757760 )
758761 direct_results = resp .directResults
759762 has_been_closed_server_side = direct_results and direct_results .closeOperation
@@ -763,15 +766,13 @@ def _results_message_to_execute_response(self, resp, operation_state):
763766 or direct_results .resultSet .hasMoreRows
764767 )
765768 description = self ._hive_schema_to_description (
766- t_result_set_metadata_resp .schema , self . _connection_uuid
769+ t_result_set_metadata_resp .schema
767770 )
768771
769772 if pyarrow :
770773 schema_bytes = (
771774 t_result_set_metadata_resp .arrowSchema
772- or self ._hive_schema_to_arrow_schema (
773- t_result_set_metadata_resp .schema , self ._connection_uuid
774- )
775+ or self ._hive_schema_to_arrow_schema (t_result_set_metadata_resp .schema )
775776 .serialize ()
776777 .to_pybytes ()
777778 )
@@ -832,15 +833,13 @@ def get_execution_result(self, op_handle, cursor):
832833 is_staging_operation = t_result_set_metadata_resp .isStagingOperation
833834 has_more_rows = resp .hasMoreRows
834835 description = self ._hive_schema_to_description (
835- t_result_set_metadata_resp .schema , self . _connection_uuid
836+ t_result_set_metadata_resp .schema
836837 )
837838
838839 if pyarrow :
839840 schema_bytes = (
840841 t_result_set_metadata_resp .arrowSchema
841- or self ._hive_schema_to_arrow_schema (
842- t_result_set_metadata_resp .schema , self ._connection_uuid
843- )
842+ or self ._hive_schema_to_arrow_schema (t_result_set_metadata_resp .schema )
844843 .serialize ()
845844 .to_pybytes ()
846845 )
@@ -894,23 +893,23 @@ def get_query_state(self, op_handle) -> "TOperationState":
894893 return operation_state
895894
896895 @staticmethod
897- def _check_direct_results_for_error (t_spark_direct_results , connection_uuid = None ):
896+ def _check_direct_results_for_error (t_spark_direct_results ):
898897 if t_spark_direct_results :
899898 if t_spark_direct_results .operationStatus :
900899 ThriftBackend ._check_response_for_error (
901- t_spark_direct_results .operationStatus , connection_uuid
900+ t_spark_direct_results .operationStatus
902901 )
903902 if t_spark_direct_results .resultSetMetadata :
904903 ThriftBackend ._check_response_for_error (
905- t_spark_direct_results .resultSetMetadata , connection_uuid
904+ t_spark_direct_results .resultSetMetadata
906905 )
907906 if t_spark_direct_results .resultSet :
908907 ThriftBackend ._check_response_for_error (
909- t_spark_direct_results .resultSet , connection_uuid
908+ t_spark_direct_results .resultSet
910909 )
911910 if t_spark_direct_results .closeOperation :
912911 ThriftBackend ._check_response_for_error (
913- t_spark_direct_results .closeOperation , connection_uuid
912+ t_spark_direct_results .closeOperation
914913 )
915914
916915 def execute_command (
@@ -1059,7 +1058,7 @@ def get_columns(
10591058
10601059 def _handle_execute_response (self , resp , cursor ):
10611060 cursor .active_op_handle = resp .operationHandle
1062- self ._check_direct_results_for_error (resp .directResults , self . _connection_uuid )
1061+ self ._check_direct_results_for_error (resp .directResults )
10631062
10641063 final_operation_state = self ._wait_until_command_done (
10651064 resp .operationHandle ,
@@ -1070,7 +1069,7 @@ def _handle_execute_response(self, resp, cursor):
10701069
10711070 def _handle_execute_response_async (self , resp , cursor ):
10721071 cursor .active_op_handle = resp .operationHandle
1073- self ._check_direct_results_for_error (resp .directResults , self . _connection_uuid )
1072+ self ._check_direct_results_for_error (resp .directResults )
10741073
10751074 def fetch_results (
10761075 self ,
@@ -1105,7 +1104,7 @@ def fetch_results(
11051104 "fetch_results failed due to inconsistency in the state between the client and the server. Expected results to start from {} but they instead start at {}, some result batches must have been skipped" .format (
11061105 expected_row_start_offset , resp .results .startRowOffset
11071106 ),
1108- connection_uuid = self . _connection_uuid ,
1107+ connection_uuid = getattr ( _connection_uuid , "value" , None ) ,
11091108 )
11101109
11111110 queue = ResultSetQueueFactory .build_queue (
0 commit comments