Skip to content

Commit cb1d203

Browse files
committed
changed connection_uuid to thread local in thrift backend
Signed-off-by: Sai Shree Pradhan <saishree.pradhan@databricks.com>
1 parent e9d9ce4 commit cb1d203

File tree

1 file changed

+38
-39
lines changed

1 file changed

+38
-39
lines changed

src/databricks/sql/thrift_backend.py

Lines changed: 38 additions & 39 deletions
Original file line number · Diff line number · Diff line change
@@ -72,6 +72,9 @@
7272
"_retry_delay_default": (float, 5, 1, 60),
7373
}
7474

75+
# Add thread local storage
76+
_connection_uuid = threading.local()
77+
7578

7679
class ThriftBackend:
7780
CLOSED_OP_STATE = ttypes.TOperationState.CLOSED_STATE
@@ -223,7 +226,7 @@ def __init__(
223226
raise
224227

225228
self._request_lock = threading.RLock()
226-
self._connection_uuid = None
229+
_connection_uuid.value = None
227230

228231
# TODO: Move this bounding logic into DatabricksRetryPolicy for v3 (PECO-918)
229232
def _initialize_retry_args(self, kwargs):
@@ -256,13 +259,14 @@ def _initialize_retry_args(self, kwargs):
256259
)
257260

258261
@staticmethod
259-
def _check_response_for_error(response, connection_uuid=None):
262+
def _check_response_for_error(response):
260263
if response.status and response.status.statusCode in [
261264
ttypes.TStatusCode.ERROR_STATUS,
262265
ttypes.TStatusCode.INVALID_HANDLE_STATUS,
263266
]:
264267
raise DatabaseError(
265-
response.status.errorMessage, connection_uuid=connection_uuid
268+
response.status.errorMessage,
269+
connection_uuid=getattr(_connection_uuid, "value", None),
266270
)
267271

268272
@staticmethod
@@ -316,7 +320,7 @@ def _handle_request_error(self, error_info, attempt, elapsed):
316320
network_request_error = RequestError(
317321
user_friendly_error_message,
318322
full_error_info_context,
319-
self._connection_uuid,
323+
getattr(_connection_uuid, "value", None),
320324
error_info.error,
321325
)
322326
logger.info(network_request_error.message_with_context())
@@ -489,7 +493,7 @@ def attempt_request(attempt):
489493
if not isinstance(response_or_error_info, RequestErrorInfo):
490494
# log nothing here, presume that main request logging covers
491495
response = response_or_error_info
492-
ThriftBackend._check_response_for_error(response, self._connection_uuid)
496+
ThriftBackend._check_response_for_error(response)
493497
return response
494498

495499
error_info = response_or_error_info
@@ -504,7 +508,7 @@ def _check_protocol_version(self, t_open_session_resp):
504508
"Error: expected server to use a protocol version >= "
505509
"SPARK_CLI_SERVICE_PROTOCOL_V2, "
506510
"instead got: {}".format(protocol_version),
507-
connection_uuid=self._connection_uuid,
511+
connection_uuid=getattr(_connection_uuid, "value", None),
508512
)
509513

510514
def _check_initial_namespace(self, catalog, schema, response):
@@ -518,15 +522,15 @@ def _check_initial_namespace(self, catalog, schema, response):
518522
raise InvalidServerResponseError(
519523
"Setting initial namespace not supported by the DBR version, "
520524
"Please use a Databricks SQL endpoint or a cluster with DBR >= 9.0.",
521-
connection_uuid=self._connection_uuid,
525+
connection_uuid=getattr(_connection_uuid, "value", None),
522526
)
523527

524528
if catalog:
525529
if not response.canUseMultipleCatalogs:
526530
raise InvalidServerResponseError(
527531
"Unexpected response from server: Trying to set initial catalog to {}, "
528532
+ "but server does not support multiple catalogs.".format(catalog), # type: ignore
529-
connection_uuid=self._connection_uuid,
533+
connection_uuid=getattr(_connection_uuid, "value", None),
530534
)
531535

532536
def _check_session_configuration(self, session_configuration):
@@ -541,7 +545,7 @@ def _check_session_configuration(self, session_configuration):
541545
TIMESTAMP_AS_STRING_CONFIG,
542546
session_configuration[TIMESTAMP_AS_STRING_CONFIG],
543547
),
544-
connection_uuid=self._connection_uuid,
548+
connection_uuid=getattr(_connection_uuid, "value", None),
545549
)
546550

547551
def open_session(self, session_configuration, catalog, schema):
@@ -572,7 +576,7 @@ def open_session(self, session_configuration, catalog, schema):
572576
response = self.make_request(self._client.OpenSession, open_session_req)
573577
self._check_initial_namespace(catalog, schema, response)
574578
self._check_protocol_version(response)
575-
self._connection_uuid = (
579+
_connection_uuid.value = (
576580
self.handle_to_hex_id(response.sessionHandle)
577581
if response.sessionHandle
578582
else None
@@ -601,7 +605,7 @@ def _check_command_not_in_error_or_closed_state(
601605
and self.guid_to_hex_id(op_handle.operationId.guid),
602606
"diagnostic-info": get_operations_resp.diagnosticInfo,
603607
},
604-
connection_uuid=self._connection_uuid,
608+
connection_uuid=getattr(_connection_uuid, "value", None),
605609
)
606610
else:
607611
raise ServerOperationError(
@@ -611,7 +615,7 @@ def _check_command_not_in_error_or_closed_state(
611615
and self.guid_to_hex_id(op_handle.operationId.guid),
612616
"diagnostic-info": None,
613617
},
614-
connection_uuid=self._connection_uuid,
618+
connection_uuid=getattr(_connection_uuid, "value", None),
615619
)
616620
elif get_operations_resp.operationState == ttypes.TOperationState.CLOSED_STATE:
617621
raise DatabaseError(
@@ -622,7 +626,7 @@ def _check_command_not_in_error_or_closed_state(
622626
"operation-id": op_handle
623627
and self.guid_to_hex_id(op_handle.operationId.guid)
624628
},
625-
connection_uuid=self._connection_uuid,
629+
connection_uuid=getattr(_connection_uuid, "value", None),
626630
)
627631

628632
def _poll_for_status(self, op_handle):
@@ -645,7 +649,7 @@ def _create_arrow_table(self, t_row_set, lz4_compressed, schema_bytes, descripti
645649
else:
646650
raise OperationalError(
647651
"Unsupported TRowSet instance {}".format(t_row_set),
648-
connection_uuid=self._connection_uuid,
652+
connection_uuid=getattr(_connection_uuid, "value", None),
649653
)
650654
return convert_decimals_in_arrow_table(arrow_table, description), num_rows
651655

@@ -654,7 +658,7 @@ def _get_metadata_resp(self, op_handle):
654658
return self.make_request(self._client.GetResultSetMetadata, req)
655659

656660
@staticmethod
657-
def _hive_schema_to_arrow_schema(t_table_schema, connection_uuid=None):
661+
def _hive_schema_to_arrow_schema(t_table_schema):
658662
def map_type(t_type_entry):
659663
if t_type_entry.primitiveEntry:
660664
return {
@@ -686,7 +690,7 @@ def map_type(t_type_entry):
686690
# even for complex types
687691
raise OperationalError(
688692
"Thrift protocol error: t_type_entry not a primitiveEntry",
689-
connection_uuid=connection_uuid,
693+
connection_uuid=getattr(_connection_uuid, "value", None),
690694
)
691695

692696
def convert_col(t_column_desc):
@@ -697,7 +701,7 @@ def convert_col(t_column_desc):
697701
return pyarrow.schema([convert_col(col) for col in t_table_schema.columns])
698702

699703
@staticmethod
700-
def _col_to_description(col, connection_uuid=None):
704+
def _col_to_description(col):
701705
type_entry = col.typeDesc.types[0]
702706

703707
if type_entry.primitiveEntry:
@@ -707,7 +711,7 @@ def _col_to_description(col, connection_uuid=None):
707711
else:
708712
raise OperationalError(
709713
"Thrift protocol error: t_type_entry not a primitiveEntry",
710-
connection_uuid=connection_uuid,
714+
connection_uuid=getattr(_connection_uuid, "value", None),
711715
)
712716

713717
if type_entry.primitiveEntry.type == ttypes.TTypeId.DECIMAL_TYPE:
@@ -721,18 +725,17 @@ def _col_to_description(col, connection_uuid=None):
721725
raise OperationalError(
722726
"Decimal type did not provide typeQualifier precision, scale in "
723727
"primitiveEntry {}".format(type_entry.primitiveEntry),
724-
connection_uuid=connection_uuid,
728+
connection_uuid=getattr(_connection_uuid, "value", None),
725729
)
726730
else:
727731
precision, scale = None, None
728732

729733
return col.columnName, cleaned_type, None, None, precision, scale, None
730734

731735
@staticmethod
732-
def _hive_schema_to_description(t_table_schema, connection_uuid=None):
736+
def _hive_schema_to_description(t_table_schema):
733737
return [
734-
ThriftBackend._col_to_description(col, connection_uuid)
735-
for col in t_table_schema.columns
738+
ThriftBackend._col_to_description(col) for col in t_table_schema.columns
736739
]
737740

738741
def _results_message_to_execute_response(self, resp, operation_state):
@@ -753,7 +756,7 @@ def _results_message_to_execute_response(self, resp, operation_state):
753756
t_result_set_metadata_resp.resultFormat
754757
]
755758
),
756-
connection_uuid=self._connection_uuid,
759+
connection_uuid=getattr(_connection_uuid, "value", None),
757760
)
758761
direct_results = resp.directResults
759762
has_been_closed_server_side = direct_results and direct_results.closeOperation
@@ -763,15 +766,13 @@ def _results_message_to_execute_response(self, resp, operation_state):
763766
or direct_results.resultSet.hasMoreRows
764767
)
765768
description = self._hive_schema_to_description(
766-
t_result_set_metadata_resp.schema, self._connection_uuid
769+
t_result_set_metadata_resp.schema
767770
)
768771

769772
if pyarrow:
770773
schema_bytes = (
771774
t_result_set_metadata_resp.arrowSchema
772-
or self._hive_schema_to_arrow_schema(
773-
t_result_set_metadata_resp.schema, self._connection_uuid
774-
)
775+
or self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema)
775776
.serialize()
776777
.to_pybytes()
777778
)
@@ -832,15 +833,13 @@ def get_execution_result(self, op_handle, cursor):
832833
is_staging_operation = t_result_set_metadata_resp.isStagingOperation
833834
has_more_rows = resp.hasMoreRows
834835
description = self._hive_schema_to_description(
835-
t_result_set_metadata_resp.schema, self._connection_uuid
836+
t_result_set_metadata_resp.schema
836837
)
837838

838839
if pyarrow:
839840
schema_bytes = (
840841
t_result_set_metadata_resp.arrowSchema
841-
or self._hive_schema_to_arrow_schema(
842-
t_result_set_metadata_resp.schema, self._connection_uuid
843-
)
842+
or self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema)
844843
.serialize()
845844
.to_pybytes()
846845
)
@@ -894,23 +893,23 @@ def get_query_state(self, op_handle) -> "TOperationState":
894893
return operation_state
895894

896895
@staticmethod
897-
def _check_direct_results_for_error(t_spark_direct_results, connection_uuid=None):
896+
def _check_direct_results_for_error(t_spark_direct_results):
898897
if t_spark_direct_results:
899898
if t_spark_direct_results.operationStatus:
900899
ThriftBackend._check_response_for_error(
901-
t_spark_direct_results.operationStatus, connection_uuid
900+
t_spark_direct_results.operationStatus
902901
)
903902
if t_spark_direct_results.resultSetMetadata:
904903
ThriftBackend._check_response_for_error(
905-
t_spark_direct_results.resultSetMetadata, connection_uuid
904+
t_spark_direct_results.resultSetMetadata
906905
)
907906
if t_spark_direct_results.resultSet:
908907
ThriftBackend._check_response_for_error(
909-
t_spark_direct_results.resultSet, connection_uuid
908+
t_spark_direct_results.resultSet
910909
)
911910
if t_spark_direct_results.closeOperation:
912911
ThriftBackend._check_response_for_error(
913-
t_spark_direct_results.closeOperation, connection_uuid
912+
t_spark_direct_results.closeOperation
914913
)
915914

916915
def execute_command(
@@ -1059,7 +1058,7 @@ def get_columns(
10591058

10601059
def _handle_execute_response(self, resp, cursor):
10611060
cursor.active_op_handle = resp.operationHandle
1062-
self._check_direct_results_for_error(resp.directResults, self._connection_uuid)
1061+
self._check_direct_results_for_error(resp.directResults)
10631062

10641063
final_operation_state = self._wait_until_command_done(
10651064
resp.operationHandle,
@@ -1070,7 +1069,7 @@ def _handle_execute_response(self, resp, cursor):
10701069

10711070
def _handle_execute_response_async(self, resp, cursor):
10721071
cursor.active_op_handle = resp.operationHandle
1073-
self._check_direct_results_for_error(resp.directResults, self._connection_uuid)
1072+
self._check_direct_results_for_error(resp.directResults)
10741073

10751074
def fetch_results(
10761075
self,
@@ -1105,7 +1104,7 @@ def fetch_results(
11051104
"fetch_results failed due to inconsistency in the state between the client and the server. Expected results to start from {} but they instead start at {}, some result batches must have been skipped".format(
11061105
expected_row_start_offset, resp.results.startRowOffset
11071106
),
1108-
connection_uuid=self._connection_uuid,
1107+
connection_uuid=getattr(_connection_uuid, "value", None),
11091108
)
11101109

11111110
queue = ResultSetQueueFactory.build_queue(

0 commit comments

Comments (0)