1+ import random
12import threading
3+ import time
24from unittest .mock import patch
35import pytest
46
@@ -47,35 +49,59 @@ def test_concurrent_queries_sends_telemetry(self):
4749 An E2E test where concurrent threads execute real queries against
4850 the staging endpoint, while we capture and verify the generated telemetry.
4951 """
50- num_threads = 5
52+ num_threads = 30
53+ capture_lock = threading .Lock ()
5154 captured_telemetry = []
52- captured_telemetry_lock = threading .Lock ()
5355 captured_session_ids = []
5456 captured_statement_ids = []
55- capture_info_lock = threading .Lock ()
57+ captured_responses = []
58+ captured_exceptions = []
5659
5760 original_send_telemetry = TelemetryClient ._send_telemetry
61+ original_callback = TelemetryClient ._telemetry_request_callback
5862
5963 def send_telemetry_wrapper (self_client , events ):
60- with captured_telemetry_lock :
64+ with capture_lock :
6165 captured_telemetry .extend (events )
6266 original_send_telemetry (self_client , events )
6367
64- with patch .object (TelemetryClient , "_send_telemetry" , send_telemetry_wrapper ):
68+ def callback_wrapper (self_client , future , sent_count ):
69+ """
70+ Wraps the original callback to capture the server's response
71+ or any exceptions from the async network call.
72+ """
73+ try :
74+ original_callback (self_client , future , sent_count )
75+
76+ # Now, capture the result for our assertions
77+ response = future .result ()
78+ response .raise_for_status () # Raise an exception for 4xx/5xx errors
79+ telemetry_response = response .json ()
80+ with capture_lock :
81+ captured_responses .append (telemetry_response )
82+ except Exception as e :
83+ with capture_lock :
84+ captured_exceptions .append (e )
85+
86+ with patch .object (TelemetryClient , "_send_telemetry" , send_telemetry_wrapper ), \
87+ patch .object (TelemetryClient , "_telemetry_request_callback" , callback_wrapper ):
6588
6689 def execute_query_worker (thread_id ):
6790 """Each thread creates a connection and executes a query."""
91+
92+ time .sleep (random .uniform (0 , 0.05 ))
93+
6894 with self .connection (extra_params = {"enable_telemetry" : True }) as conn :
6995 # Capture the session ID from the connection before executing the query
7096 session_id_hex = conn .get_session_id_hex ()
71- with capture_info_lock :
97+ with capture_lock :
7298 captured_session_ids .append (session_id_hex )
7399
74100 with conn .cursor () as cursor :
75101 cursor .execute (f"SELECT { thread_id } " )
76102 # Capture the statement ID after executing the query
77103 statement_id = cursor .query_id
78- with capture_info_lock :
104+ with capture_lock :
79105 captured_statement_ids .append (statement_id )
80106 cursor .fetchall ()
81107
@@ -86,6 +112,16 @@ def execute_query_worker(thread_id):
86112 TelemetryClientFactory ._executor .shutdown (wait = True )
87113
88114 # --- VERIFICATION ---
115+ assert not captured_exceptions
116+ assert len (captured_responses ) > 0
117+
118+ total_successful_events = 0
119+ for response in captured_responses :
120+ assert "errors" not in response or not response ["errors" ]
121+ if "numProtoSuccess" in response :
122+ total_successful_events += response ["numProtoSuccess" ]
123+ assert total_successful_events == num_threads * 2
124+
89125 assert len (captured_telemetry ) == num_threads * 2 # 2 events per thread (initial_telemetry_log, latency_log (execute))
90126 assert len (captured_session_ids ) == num_threads # One session ID per thread
91127 assert len (captured_statement_ids ) == num_threads # One statement ID per thread (per query)
0 commit comments