Skip to content

Commit facd588

Browse files
committed
shifted thread pool executor to telemetry manager
Signed-off-by: Sai Shree Pradhan <saishree.pradhan@databricks.com>
1 parent 05b93fe commit facd588

File tree

2 files changed

+86
-135
lines changed

2 files changed

+86
-135
lines changed

src/databricks/sql/client.py

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
TSparkParameter,
5050
TOperationState,
5151
)
52-
from databricks.sql.telemetry.telemetry_client import telemetry_client
52+
from databricks.sql.telemetry.telemetry_client import telemetry_manager
5353

5454

5555
logger = logging.getLogger(__name__)
@@ -297,23 +297,29 @@ def read(self) -> Optional[OAuthToken]:
297297
self.use_inline_params = self._set_use_inline_params_with_warning(
298298
kwargs.get("use_inline_params", False)
299299
)
300+
301+
telemetry_kwargs = {
302+
"auth_provider": auth_provider,
303+
"is_authenticated": True, # TODO: Add authentication logic later
304+
"user_agent": useragent_header,
305+
"host_url": server_hostname
306+
}
307+
telemetry_manager.initialize_telemetry_client(
308+
telemetry_enabled=self.telemetry_enabled,
309+
batch_size=telemetry_batch_size,
310+
connection_uuid=self.get_session_id_hex(),
311+
**telemetry_kwargs
312+
)
300313

301-
if self.telemetry_enabled:
302-
telemetry_client.initialize(
303-
host=self.host,
304-
connection_uuid=self.get_session_id_hex(),
305-
batch_size=telemetry_batch_size,
306-
auth_provider=auth_provider,
307-
is_authenticated=True, # TODO: Add authentication logic later
308-
user_agent=useragent_header,
309-
)
310-
311-
telemetry_client.export_initial_telemetry_log(
312-
http_path,
313-
self.port,
314-
kwargs.get("_socket_timeout", None),
315-
self.get_session_id_hex(),
316-
)
314+
intial_telmetry_kwargs = {
315+
"http_path": http_path,
316+
"port": self.port,
317+
"socket_timeout": kwargs.get("_socket_timeout", None),
318+
}
319+
telemetry_manager.export_initial_telemetry_log(
320+
connection_uuid=self.get_session_id_hex(),
321+
**intial_telmetry_kwargs
322+
)
317323

318324
def _set_use_inline_params_with_warning(self, value: Union[bool, str]):
319325
"""Valid values are True, False, and "silent"
@@ -451,8 +457,7 @@ def _close(self, close_cursors=True) -> None:
451457

452458
self.open = False
453459

454-
if self.telemetry_enabled:
455-
telemetry_client.close(self.get_session_id_hex())
460+
telemetry_manager.close_telemetry_client(self.get_session_id_hex())
456461

457462
def commit(self):
458463
"""No-op because Databricks does not support transactions"""

src/databricks/sql/telemetry/telemetry_client.py

Lines changed: 62 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -25,24 +25,20 @@
2525
class TelemetryClient:
2626
def __init__(
2727
self,
28-
host,
29-
connection_uuid,
28+
telemetry_enabled,
3029
batch_size,
31-
auth_provider=None,
32-
is_authenticated=False,
33-
user_agent=None,
30+
connection_uuid,
31+
**kwargs
3432
):
35-
self.host_url = host
36-
self.connection_uuid = connection_uuid
37-
self.auth_provider = auth_provider
38-
self.is_authenticated = is_authenticated
33+
self.telemetry_enabled = telemetry_enabled
3934
self.batch_size = batch_size
40-
self.user_agent = user_agent
35+
self.connection_uuid = connection_uuid
36+
self.host_url = kwargs.get("host_url", None)
37+
self.auth_provider = kwargs.get("auth_provider", None)
38+
self.is_authenticated = kwargs.get("is_authenticated", False)
39+
self.user_agent = kwargs.get("user_agent", None)
4140
self.events_batch = []
4241
self.lock = threading.Lock()
43-
self.executor = ThreadPoolExecutor(
44-
max_workers=10 # TODO: Decide on max workers
45-
) # Thread pool for async operations
4642
self.DriverConnectionParameters = None
4743

4844
def export_event(self, event):
@@ -59,58 +55,17 @@ def flush(self):
5955
self.events_batch = []
6056

6157
if events_to_flush:
62-
self.executor.submit(self._send_telemetry, events_to_flush)
63-
64-
def _send_telemetry(self, events):
65-
"""Send telemetry events to the server"""
66-
request = {
67-
"uploadTime": int(time.time() * 1000),
68-
"items": [],
69-
"protoLogs": [event.to_json() for event in events],
70-
}
71-
72-
path = "/telemetry-ext" if self.is_authenticated else "/telemetry-unauth"
73-
url = f"https://{self.host_url}{path}"
74-
75-
headers = {"Accept": "application/json", "Content-Type": "application/json"}
76-
77-
if self.is_authenticated and self.auth_provider:
78-
self.auth_provider.add_headers(headers)
79-
80-
# print("\n=== Request Details ===", flush=True)
81-
# print(f"URL: {url}", flush=True)
82-
# print("\nHeaders:", flush=True)
83-
# for key, value in headers.items():
84-
# print(f" {key}: {value}", flush=True)
85-
86-
# print("\nRequest Body:", flush=True)
87-
# print(json.dumps(request, indent=2), flush=True)
88-
# sys.stdout.flush()
89-
90-
response = requests.post(
91-
url, data=json.dumps(request), headers=headers, timeout=10
92-
)
93-
94-
# print("\n=== Response Details ===", flush=True)
95-
# print(f"Status Code: {response.status_code}", flush=True)
96-
# print("\nResponse Headers:", flush=True)
97-
# for key, value in response.headers.items():
98-
# print(f" {key}: {value}", flush=True)
99-
100-
# print("\nResponse Body:", flush=True)
101-
# try:
102-
# response_json = response.json()
103-
# print(json.dumps(response_json, indent=2), flush=True)
104-
# except json.JSONDecodeError:
105-
# print(response.text, flush=True)
106-
# sys.stdout.flush()
58+
telemetry_manager._send_telemetry(events_to_flush, self.host_url, self.is_authenticated, self.auth_provider)
10759

10860
def close(self):
109-
"""Flush remaining events and shut down executor"""
61+
"""Flush remaining events before closing"""
11062
self.flush()
111-
self.executor.shutdown(wait=True)
11263

113-
def export_initial_telemetry_log(self, http_path, port, socket_timeout):
64+
def export_initial_telemetry_log(self, **kwargs):
65+
http_path = kwargs.get("http_path", None)
66+
port = kwargs.get("port", None)
67+
socket_timeout = kwargs.get("socket_timeout", None)
68+
11469
discovery_url = None
11570
if hasattr(self.auth_provider, "oauth_manager") and hasattr(
11671
self.auth_provider.oauth_manager, "idp_endpoint"
@@ -147,19 +102,6 @@ def export_initial_telemetry_log(self, http_path, port, socket_timeout):
147102

148103
self.export_event(telemetry_frontend_log)
149104

150-
def export_failure_log(self, errorName, errorMessage):
151-
pass
152-
153-
def export_sql_latency_log(
154-
self, latency_ms, sql_execution_event, sql_statement_id=None
155-
):
156-
"""Export telemetry for sql execution"""
157-
pass
158-
159-
def export_volume_latency_log(self, latency_ms, volume_operation):
160-
"""Export telemetry for volume operation"""
161-
pass
162-
163105

164106
class TelemetryManager:
165107
"""A singleton manager class that handles telemetry operations for SQL connections.
@@ -189,56 +131,56 @@ def __init__(self):
189131
return
190132

191133
self._clients = {} # Map of connection_uuid -> TelemetryClient
134+
self.executor = ThreadPoolExecutor(max_workers=10) # Thread pool for async operations TODO: Decide on max workers
192135
self._initialized = True
193136

194-
def initialize(
137+
def initialize_telemetry_client(
195138
self,
196-
host,
197-
connection_uuid,
139+
telemetry_enabled,
198140
batch_size,
199-
auth_provider=None,
200-
is_authenticated=False,
201-
user_agent=None,
141+
connection_uuid,
142+
**kwargs
202143
):
203-
"""Initialize a telemetry client for a specific connection"""
204-
if connection_uuid not in self._clients:
205-
self._clients[connection_uuid] = TelemetryClient(
206-
host=host,
207-
connection_uuid=connection_uuid,
208-
batch_size=batch_size,
209-
auth_provider=auth_provider,
210-
is_authenticated=is_authenticated,
211-
user_agent=user_agent,
212-
)
144+
"""Initialize a telemetry client for a specific connection if telemetry is enabled"""
145+
if telemetry_enabled:
146+
if connection_uuid not in self._clients:
147+
self._clients[connection_uuid] = TelemetryClient(
148+
telemetry_enabled=telemetry_enabled,
149+
batch_size=batch_size,
150+
connection_uuid=connection_uuid,
151+
**kwargs
152+
)
153+
154+
def _send_telemetry(self, events, host_url, is_authenticated, auth_provider):
155+
"""Send telemetry events to the server"""
156+
request = {
157+
"uploadTime": int(time.time() * 1000),
158+
"items": [],
159+
"protoLogs": [event.to_json() for event in events],
160+
}
161+
162+
path = "/telemetry-ext" if is_authenticated else "/telemetry-unauth"
163+
url = f"https://{host_url}{path}"
164+
165+
headers = {"Accept": "application/json", "Content-Type": "application/json"}
166+
167+
if is_authenticated and auth_provider:
168+
auth_provider.add_headers(headers)
213169

214-
def export_failure_log(self, error_name, error_message, connection_uuid):
215-
"""Export error logs for a specific connection or all connections if connection_uuid is None"""
216-
pass
170+
self.executor.submit(
171+
requests.post,
172+
url,
173+
data=json.dumps(request),
174+
headers=headers,
175+
timeout=10
176+
)
217177

218178
def export_initial_telemetry_log(
219-
self, http_path, port, socket_timeout, connection_uuid
179+
self, connection_uuid, **kwargs
220180
):
221181
"""Export initial telemetry for a specific connection"""
222182
if connection_uuid in self._clients:
223-
self._clients[connection_uuid].export_initial_telemetry_log(
224-
http_path, port, socket_timeout
225-
)
226-
227-
def export_sql_latency_log(
228-
self,
229-
latency_ms,
230-
sql_execution_event,
231-
sql_statement_id=None,
232-
connection_uuid=None,
233-
):
234-
"""Export latency logs for sql execution for a specific connection"""
235-
pass
236-
237-
def export_volume_latency_log(
238-
self, latency_ms, volume_operation, connection_uuid=None
239-
):
240-
"""Export latency logs for volume operation for a specific connection"""
241-
pass
183+
self._clients[connection_uuid].export_initial_telemetry_log(**kwargs)
242184

243185
@classmethod
244186
def getDriverSystemConfiguration(cls) -> DriverSystemConfiguration:
@@ -260,13 +202,17 @@ def getDriverSystemConfiguration(cls) -> DriverSystemConfiguration:
260202
)
261203
return cls._DRIVER_SYSTEM_CONFIGURATION
262204

263-
def close(self, connection_uuid):
264-
"""Close telemetry client(s)"""
205+
def close_telemetry_client(self, connection_uuid):
206+
"""Close telemetry client"""
265207
if connection_uuid:
266208
if connection_uuid in self._clients:
267209
self._clients[connection_uuid].close()
268210
del self._clients[connection_uuid]
211+
212+
# Shutdown executor if no more clients
213+
if not self._clients:
214+
self.executor.shutdown(wait=True)
269215

270216

271217
# Create a global instance
272-
telemetry_client = TelemetryManager()
218+
telemetry_manager = TelemetryManager()

0 commit comments

Comments
 (0)