Skip to content

Commit 3bfbac1

Browse files
committed
feat: Implement host-level telemetry batching to reduce rate limiting
Changes telemetry client architecture from per-session to per-host batching, matching the JDBC driver implementation. This reduces the number of HTTP requests to the telemetry endpoint and prevents rate limiting in test environments. Key changes: - Add _TelemetryClientHolder with reference counting for shared clients - Change TelemetryClientFactory to key clients by host_url instead of session_id - Add getHostUrlSafely() helper for defensive null handling - Update all callers (client.py, exc.py, latency_logger.py) to pass host_url Before: 100 connections to same host = 100 separate TelemetryClients After: 100 connections to same host = 1 shared TelemetryClient (refcount=100) This fixes rate limiting issues seen in e2e tests where 300+ parallel connections were overwhelming the telemetry endpoint with 429 errors.
1 parent d524f0e commit 3bfbac1

File tree

4 files changed

+126
-33
lines changed

4 files changed

+126
-33
lines changed

src/databricks/sql/client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,7 @@ def read(self) -> Optional[OAuthToken]:
341341
)
342342

343343
self._telemetry_client = TelemetryClientFactory.get_telemetry_client(
344-
session_id_hex=self.get_session_id_hex()
344+
host_url=self.session.host
345345
)
346346

347347
# Determine proxy usage
@@ -521,7 +521,7 @@ def _close(self, close_cursors=True) -> None:
521521
except Exception as e:
522522
logger.error(f"Attempt to close session raised a local exception: {e}")
523523

524-
TelemetryClientFactory.close(self.get_session_id_hex())
524+
TelemetryClientFactory.close(host_url=self.session.host)
525525

526526
# Close HTTP client that was created by this connection
527527
if self.http_client:

src/databricks/sql/exc.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,18 @@ class Error(Exception):
1212
"""
1313

1414
def __init__(
15-
self, message=None, context=None, session_id_hex=None, *args, **kwargs
15+
self, message=None, context=None, host_url=None, *args, **kwargs
1616
):
1717
super().__init__(message, *args, **kwargs)
1818
self.message = message
1919
self.context = context or {}
2020

2121
error_name = self.__class__.__name__
22-
if session_id_hex:
22+
if host_url:
2323
from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory
2424

2525
telemetry_client = TelemetryClientFactory.get_telemetry_client(
26-
session_id_hex
26+
host_url=host_url
2727
)
2828
telemetry_client.export_failure_log(error_name, self.message)
2929

src/databricks/sql/telemetry/latency_logger.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ def wrapper(self, *args, **kwargs):
205205

206206
telemetry_client = (
207207
TelemetryClientFactory.get_telemetry_client(
208-
session_id_hex
208+
host_url=connection.session.host
209209
)
210210
)
211211
telemetry_client.export_latency_log(

src/databricks/sql/telemetry/telemetry_client.py

Lines changed: 120 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -409,15 +409,39 @@ def close(self):
409409
self._flush()
410410

411411

412+
class _TelemetryClientHolder:
413+
"""
414+
Holds a telemetry client with reference counting.
415+
Multiple connections to the same host share one client.
416+
"""
417+
418+
def __init__(self, client: BaseTelemetryClient):
419+
self.client = client
420+
self.refcount = 1
421+
422+
def increment(self):
423+
"""Increment reference count when a new connection uses this client"""
424+
self.refcount += 1
425+
426+
def decrement(self):
427+
"""Decrement reference count when a connection closes"""
428+
self.refcount -= 1
429+
return self.refcount
430+
431+
412432
class TelemetryClientFactory:
413433
"""
414434
Static factory class for creating and managing telemetry clients.
415435
It uses a thread pool to handle asynchronous operations and a single flush thread for all clients.
436+
437+
Clients are shared at the HOST level - multiple connections to the same host
438+
share a single TelemetryClient to enable efficient batching and reduce load
439+
on the telemetry endpoint.
416440
"""
417441

418442
_clients: Dict[
419-
str, BaseTelemetryClient
420-
] = {} # Map of session_id_hex -> BaseTelemetryClient
443+
str, _TelemetryClientHolder
444+
] = {} # Map of host_url -> TelemetryClientHolder
421445
_executor: Optional[ThreadPoolExecutor] = None
422446
_initialized: bool = False
423447
_lock = threading.RLock() # Thread safety for factory operations
@@ -431,6 +455,22 @@ class TelemetryClientFactory:
431455
_flush_interval_seconds = 300 # 5 minutes
432456

433457
DEFAULT_BATCH_SIZE = 100
458+
UNKNOWN_HOST = "unknown-host"
459+
460+
@staticmethod
461+
def getHostUrlSafely(host_url):
462+
"""
463+
Safely get host URL with fallback to UNKNOWN_HOST.
464+
465+
Args:
466+
host_url: The host URL to validate
467+
468+
Returns:
469+
The host_url if valid, otherwise UNKNOWN_HOST
470+
"""
471+
if not host_url or not isinstance(host_url, str) or not host_url.strip():
472+
return TelemetryClientFactory.UNKNOWN_HOST
473+
return host_url
434474

435475
@classmethod
436476
def _initialize(cls):
@@ -506,21 +546,38 @@ def initialize_telemetry_client(
506546
batch_size,
507547
client_context,
508548
):
509-
"""Initialize a telemetry client for a specific connection if telemetry is enabled"""
549+
"""
550+
Initialize a telemetry client for a specific connection if telemetry is enabled.
551+
552+
Clients are shared at the HOST level - multiple connections to the same host
553+
will share a single TelemetryClient with reference counting.
554+
"""
510555
try:
556+
# Safely get host_url with fallback to UNKNOWN_HOST
557+
host_url = TelemetryClientFactory.getHostUrlSafely(host_url)
511558

512559
with TelemetryClientFactory._lock:
513560
TelemetryClientFactory._initialize()
514561

515-
if session_id_hex not in TelemetryClientFactory._clients:
562+
if host_url in TelemetryClientFactory._clients:
563+
# Reuse existing client for this host
564+
holder = TelemetryClientFactory._clients[host_url]
565+
holder.increment()
566+
logger.debug(
567+
"Reusing TelemetryClient for host %s (session %s, refcount=%d)",
568+
host_url,
569+
session_id_hex,
570+
holder.refcount,
571+
)
572+
else:
573+
# Create new client for this host
516574
logger.debug(
517-
"Creating new TelemetryClient for connection %s",
575+
"Creating new TelemetryClient for host %s (session %s)",
576+
host_url,
518577
session_id_hex,
519578
)
520579
if telemetry_enabled:
521-
TelemetryClientFactory._clients[
522-
session_id_hex
523-
] = TelemetryClient(
580+
client = TelemetryClient(
524581
telemetry_enabled=telemetry_enabled,
525582
session_id_hex=session_id_hex,
526583
auth_provider=auth_provider,
@@ -529,36 +586,72 @@ def initialize_telemetry_client(
529586
batch_size=batch_size,
530587
client_context=client_context,
531588
)
589+
TelemetryClientFactory._clients[host_url] = _TelemetryClientHolder(
590+
client
591+
)
532592
else:
533-
TelemetryClientFactory._clients[
534-
session_id_hex
535-
] = NoopTelemetryClient()
593+
TelemetryClientFactory._clients[host_url] = _TelemetryClientHolder(
594+
NoopTelemetryClient()
595+
)
536596
except Exception as e:
537597
logger.debug("Failed to initialize telemetry client: %s", e)
538598
# Fallback to NoopTelemetryClient to ensure connection doesn't fail
539-
TelemetryClientFactory._clients[session_id_hex] = NoopTelemetryClient()
599+
TelemetryClientFactory._clients[host_url] = _TelemetryClientHolder(
600+
NoopTelemetryClient()
601+
)
540602

541603
@staticmethod
542-
def get_telemetry_client(session_id_hex):
543-
"""Get the telemetry client for a specific connection"""
544-
return TelemetryClientFactory._clients.get(
545-
session_id_hex, NoopTelemetryClient()
546-
)
604+
def get_telemetry_client(host_url):
605+
"""
606+
Get the shared telemetry client for a specific host.
607+
608+
Args:
609+
host_url: The host URL to look up the client. If None/empty, uses UNKNOWN_HOST.
610+
611+
Returns:
612+
The shared TelemetryClient for this host, or NoopTelemetryClient if not found
613+
"""
614+
host_url = TelemetryClientFactory.getHostUrlSafely(host_url)
615+
616+
if host_url in TelemetryClientFactory._clients:
617+
return TelemetryClientFactory._clients[host_url].client
618+
return NoopTelemetryClient()
547619

548620
@staticmethod
549-
def close(session_id_hex):
550-
"""Close and remove the telemetry client for a specific connection"""
621+
def close(host_url):
622+
"""
623+
Close the telemetry client for a specific host.
624+
625+
Decrements the reference count for the host's client. Only actually closes
626+
the client when the reference count reaches zero (all connections to this host closed).
627+
628+
Args:
629+
host_url: The host URL whose client to close. If None/empty, uses UNKNOWN_HOST.
630+
"""
631+
host_url = TelemetryClientFactory.getHostUrlSafely(host_url)
551632

552633
with TelemetryClientFactory._lock:
553-
if (
554-
telemetry_client := TelemetryClientFactory._clients.pop(
555-
session_id_hex, None
556-
)
557-
) is not None:
634+
# Get the holder for this host
635+
holder = TelemetryClientFactory._clients.get(host_url)
636+
if holder is None:
637+
logger.debug("No telemetry client found for host %s", host_url)
638+
return
639+
640+
# Decrement refcount
641+
remaining_refs = holder.decrement()
642+
logger.debug(
643+
"Decremented refcount for host %s (refcount=%d)",
644+
host_url,
645+
remaining_refs,
646+
)
647+
648+
# Only close if no more references
649+
if remaining_refs <= 0:
558650
logger.debug(
559-
"Removing telemetry client for connection %s", session_id_hex
651+
"Closing telemetry client for host %s (no more references)", host_url
560652
)
561-
telemetry_client.close()
653+
TelemetryClientFactory._clients.pop(host_url, None)
654+
holder.client.close()
562655

563656
# Shutdown executor if no more clients
564657
if not TelemetryClientFactory._clients and TelemetryClientFactory._executor:
@@ -597,7 +690,7 @@ def connection_failure_log(
597690
)
598691

599692
telemetry_client = TelemetryClientFactory.get_telemetry_client(
600-
UNAUTH_DUMMY_SESSION_ID
693+
host_url=host_url
601694
)
602695
telemetry_client._driver_connection_params = DriverConnectionParameters(
603696
http_path=http_path,

0 commit comments

Comments
 (0)