Skip to content

Commit f6bd9ad

Browse files
committed
added lock in TelemetryClientFactory and specified dict type in _client
Signed-off-by: Sai Shree Pradhan <saishree.pradhan@databricks.com>
1 parent 906c187 commit f6bd9ad

File tree

1 file changed

+46
-35
lines changed

1 file changed

+46
-35
lines changed

src/databricks/sql/telemetry/telemetry_client.py

Lines changed: 46 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import requests
55
import logging
66
from concurrent.futures import ThreadPoolExecutor
7+
from typing import Dict
78
from databricks.sql.telemetry.models.event import (
89
TelemetryEvent,
910
DriverSystemConfiguration,
@@ -248,23 +249,27 @@ class TelemetryClientFactory:
248249
It uses a thread pool to handle asynchronous operations.
249250
"""
250251

251-
_clients = {} # Map of connection_uuid -> TelemetryClient
252-
_executor = None
253-
_initialized = False
252+
_clients: Dict[
253+
str, TelemetryClient
254+
] = {} # Map of connection_uuid -> TelemetryClient
255+
_executor: ThreadPoolExecutor = None
256+
_initialized: bool = False
257+
_lock = threading.Lock() # Thread safety for factory operations
254258

255259
@classmethod
256260
def _initialize(cls):
257261
"""Initialize the factory if not already initialized"""
258-
if not cls._initialized:
259-
logger.info("Initializing TelemetryClientFactory")
260-
cls._clients = {}
261-
cls._executor = ThreadPoolExecutor(
262-
max_workers=10
263-
) # Thread pool for async operations TODO: Decide on max workers
264-
cls._initialized = True
265-
logger.debug(
266-
"TelemetryClientFactory initialized with thread pool (max_workers=10)"
267-
)
262+
with cls._lock:
263+
if not cls._initialized:
264+
logger.info("Initializing TelemetryClientFactory")
265+
cls._clients = {}
266+
cls._executor = ThreadPoolExecutor(
267+
max_workers=10
268+
) # Thread pool for async operations TODO: Decide on max workers
269+
cls._initialized = True
270+
logger.debug(
271+
"TelemetryClientFactory initialized with thread pool (max_workers=10)"
272+
)
268273

269274
@staticmethod
270275
def initialize_telemetry_client(
@@ -277,18 +282,19 @@ def initialize_telemetry_client(
277282
TelemetryClientFactory._initialize()
278283

279284
if telemetry_enabled:
280-
if connection_uuid not in TelemetryClientFactory._clients:
281-
logger.info(
282-
f"Creating new TelemetryClient for connection {connection_uuid}"
283-
)
284-
TelemetryClientFactory._clients[connection_uuid] = TelemetryClient(
285-
telemetry_enabled=telemetry_enabled,
286-
connection_uuid=connection_uuid,
287-
auth_provider=auth_provider,
288-
host_url=host_url,
289-
executor=TelemetryClientFactory._executor,
290-
)
291-
return TelemetryClientFactory._clients[connection_uuid]
285+
with TelemetryClientFactory._lock:
286+
if connection_uuid not in TelemetryClientFactory._clients:
287+
logger.info(
288+
f"Creating new TelemetryClient for connection {connection_uuid}"
289+
)
290+
TelemetryClientFactory._clients[connection_uuid] = TelemetryClient(
291+
telemetry_enabled=telemetry_enabled,
292+
connection_uuid=connection_uuid,
293+
auth_provider=auth_provider,
294+
host_url=host_url,
295+
executor=TelemetryClientFactory._executor,
296+
)
297+
return TelemetryClientFactory._clients[connection_uuid]
292298
else:
293299
return NoopTelemetryClient()
294300

@@ -304,13 +310,18 @@ def get_telemetry_client(connection_uuid):
304310
def close(connection_uuid):
305311
"""Close and remove the telemetry client for a specific connection"""
306312

307-
if connection_uuid in TelemetryClientFactory._clients:
308-
logger.debug(f"Removing telemetry client for connection {connection_uuid}")
309-
del TelemetryClientFactory._clients[connection_uuid]
310-
311-
# Shutdown executor if no more clients
312-
if not TelemetryClientFactory._clients and TelemetryClientFactory._executor:
313-
logger.info("No more telemetry clients, shutting down thread pool executor")
314-
TelemetryClientFactory._executor.shutdown(wait=True)
315-
TelemetryClientFactory._executor = None
316-
TelemetryClientFactory._initialized = False
313+
with TelemetryClientFactory._lock:
314+
if connection_uuid in TelemetryClientFactory._clients:
315+
logger.debug(
316+
f"Removing telemetry client for connection {connection_uuid}"
317+
)
318+
del TelemetryClientFactory._clients[connection_uuid]
319+
320+
# Shutdown executor if no more clients
321+
if not TelemetryClientFactory._clients and TelemetryClientFactory._executor:
322+
logger.info(
323+
"No more telemetry clients, shutting down thread pool executor"
324+
)
325+
TelemetryClientFactory._executor.shutdown(wait=True)
326+
TelemetryClientFactory._executor = None
327+
TelemetryClientFactory._initialized = False

0 commit comments

Comments
 (0)