Skip to content

Commit 583d4d1

Browse files
committed
minor changes, added checks on server response
Signed-off-by: Sai Shree Pradhan <saishree.pradhan@databricks.com>
1 parent f6281fe commit 583d4d1

File tree

1 file changed

+43
-7
lines changed

1 file changed

+43
-7
lines changed

tests/e2e/test_concurrent_telemetry.py

Lines changed: 43 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import random
12
import threading
3+
import time
24
from unittest.mock import patch
35
import pytest
46

@@ -47,35 +49,59 @@ def test_concurrent_queries_sends_telemetry(self):
4749
An E2E test where concurrent threads execute real queries against
4850
the staging endpoint, while we capture and verify the generated telemetry.
4951
"""
50-
num_threads = 5
52+
num_threads = 30
53+
capture_lock = threading.Lock()
5154
captured_telemetry = []
52-
captured_telemetry_lock = threading.Lock()
5355
captured_session_ids = []
5456
captured_statement_ids = []
55-
capture_info_lock = threading.Lock()
57+
captured_responses = []
58+
captured_exceptions = []
5659

5760
original_send_telemetry = TelemetryClient._send_telemetry
61+
original_callback = TelemetryClient._telemetry_request_callback
5862

5963
def send_telemetry_wrapper(self_client, events):
60-
with captured_telemetry_lock:
64+
with capture_lock:
6165
captured_telemetry.extend(events)
6266
original_send_telemetry(self_client, events)
6367

64-
with patch.object(TelemetryClient, "_send_telemetry", send_telemetry_wrapper):
68+
def callback_wrapper(self_client, future, sent_count):
69+
"""
70+
Wraps the original callback to capture the server's response
71+
or any exceptions from the async network call.
72+
"""
73+
try:
74+
original_callback(self_client, future, sent_count)
75+
76+
# Now, capture the result for our assertions
77+
response = future.result()
78+
response.raise_for_status() # Raise an exception for 4xx/5xx errors
79+
telemetry_response = response.json()
80+
with capture_lock:
81+
captured_responses.append(telemetry_response)
82+
except Exception as e:
83+
with capture_lock:
84+
captured_exceptions.append(e)
85+
86+
with patch.object(TelemetryClient, "_send_telemetry", send_telemetry_wrapper), \
87+
patch.object(TelemetryClient, "_telemetry_request_callback", callback_wrapper):
6588

6689
def execute_query_worker(thread_id):
6790
"""Each thread creates a connection and executes a query."""
91+
92+
time.sleep(random.uniform(0, 0.05))
93+
6894
with self.connection(extra_params={"enable_telemetry": True}) as conn:
6995
# Capture the session ID from the connection before executing the query
7096
session_id_hex = conn.get_session_id_hex()
71-
with capture_info_lock:
97+
with capture_lock:
7298
captured_session_ids.append(session_id_hex)
7399

74100
with conn.cursor() as cursor:
75101
cursor.execute(f"SELECT {thread_id}")
76102
# Capture the statement ID after executing the query
77103
statement_id = cursor.query_id
78-
with capture_info_lock:
104+
with capture_lock:
79105
captured_statement_ids.append(statement_id)
80106
cursor.fetchall()
81107

@@ -86,6 +112,16 @@ def execute_query_worker(thread_id):
86112
TelemetryClientFactory._executor.shutdown(wait=True)
87113

88114
# --- VERIFICATION ---
115+
assert not captured_exceptions
116+
assert len(captured_responses) > 0
117+
118+
total_successful_events = 0
119+
for response in captured_responses:
120+
assert "errors" not in response or not response["errors"]
121+
if "numProtoSuccess" in response:
122+
total_successful_events += response["numProtoSuccess"]
123+
assert total_successful_events == num_threads * 2
124+
89125
assert len(captured_telemetry) == num_threads * 2 # 2 events per thread (initial_telemetry_log, latency_log (execute))
90126
assert len(captured_session_ids) == num_threads # One session ID per thread
91127
assert len(captured_statement_ids) == num_threads # One statement ID per thread (per query)

0 commit comments

Comments
 (0)