|
4 | 4 | import json |
5 | 5 | from concurrent.futures import ThreadPoolExecutor |
6 | 6 | from concurrent.futures import Future |
| 7 | +from concurrent.futures import wait |
7 | 8 | from datetime import datetime, timezone |
8 | 9 | from typing import List, Dict, Any, Optional, TYPE_CHECKING |
9 | 10 | from databricks.sql.telemetry.models.event import ( |
@@ -182,6 +183,7 @@ def __init__( |
182 | 183 | self._user_agent = None |
183 | 184 | self._events_batch = [] |
184 | 185 | self._lock = threading.RLock() |
| 186 | + self._pending_futures = set() |
185 | 187 | self._driver_connection_params = None |
186 | 188 | self._host_url = host_url |
187 | 189 | self._executor = executor |
@@ -245,6 +247,9 @@ def _send_telemetry(self, events): |
245 | 247 | timeout=900, |
246 | 248 | ) |
247 | 249 |
|
| 250 | + with self._lock: |
| 251 | + self._pending_futures.add(future) |
| 252 | + |
248 | 253 | future.add_done_callback( |
249 | 254 | lambda fut: self._telemetry_request_callback(fut, sent_count=sent_count) |
250 | 255 | ) |
@@ -303,6 +308,9 @@ def _telemetry_request_callback(self, future, sent_count: int): |
303 | 308 |
|
304 | 309 | except Exception as e: |
305 | 310 | logger.debug("Telemetry request failed with exception: %s", e) |
| 311 | + finally: |
| 312 | + with self._lock: |
| 313 | + self._pending_futures.discard(future) |
306 | 314 |
|
307 | 315 | def _export_telemetry_log(self, **telemetry_event_kwargs): |
308 | 316 | """ |
@@ -356,9 +364,20 @@ def export_latency_log(self, latency_ms, sql_execution_event, sql_statement_id): |
356 | 364 | ) |
357 | 365 |
|
358 | 366 | def close(self): |
359 | | - """Flush remaining events before closing""" |
| 367 | + """Flush remaining events and wait for them to complete before closing""" |
360 | 368 | logger.debug("Closing TelemetryClient for connection %s", self._session_id_hex) |
361 | 369 | self._flush() |
| 370 | + |
| 371 | + with self._lock: |
| 372 | + futures_to_wait_on = list(self._pending_futures) |
| 373 | + |
| 374 | + if futures_to_wait_on: |
| 375 | + logger.debug( |
| 376 | + "Waiting for %s pending telemetry requests to complete.", |
| 377 | + len(futures_to_wait_on), |
| 378 | + ) |
| 379 | + wait(futures_to_wait_on) |
| 380 | + |
362 | 381 | self._http_client.close() |
363 | 382 |
|
364 | 383 |
|
|
0 commit comments