1- # tests/e2e/test_telemetry_retry.py
2-
31import pytest
4- import logging
52from unittest .mock import patch , MagicMock
6- from functools import wraps
7- import time
8- from concurrent .futures import Future
3+ import io
94
10- # Imports for the code being tested
115from databricks .sql .telemetry .telemetry_client import TelemetryClientFactory
126from databricks .sql .telemetry .models .event import DriverConnectionParameters , HostDetails , DatabricksClientType
137from databricks .sql .telemetry .models .enums import AuthMech
14- from databricks .sql .auth .retry import DatabricksRetryPolicy , CommandType
15-
16- # Imports for mocking the network layer correctly
17- from urllib3 .connectionpool import HTTPSConnectionPool
18- from urllib3 .exceptions import MaxRetryError
19- from requests .exceptions import ConnectionError as RequestsConnectionError
8+ from databricks .sql .auth .retry import DatabricksRetryPolicy
209
2110PATCH_TARGET = 'urllib3.connectionpool.HTTPSConnectionPool._get_conn'
2211
23- # Helper to create a mock that looks and acts like a urllib3.response.HTTPResponse.
24- def create_urllib3_response (status , headers = None , body = b'{}' ):
25- """Create a proper mock response that simulates urllib3's HTTPResponse"""
26- mock_response = MagicMock ()
27- mock_response .status = status
28- mock_response .headers = headers or {}
29- mock_response .msg = headers or {} # For urllib3~=1.0 compatibility
30- mock_response .data = body
31- mock_response .read .return_value = body
32- mock_response .get_redirect_location .return_value = False
33- mock_response .closed = False
34- mock_response .isclosed .return_value = False
35- return mock_response
def _create_mock_http_response(resp):
    """Build one mock HTTP response from a ``{"status", "headers", "body"}`` dict."""
    mock_http_response = MagicMock()
    mock_http_response.status = resp.get("status")
    mock_http_response.headers = resp.get("headers", {})
    mock_http_response.fp = io.BytesIO(resp.get("body", b'{}'))

    def release():
        # `mock_http_response` is local to this helper call, so each response
        # gets its own closure. Defining `release` directly inside a loop would
        # late-bind the loop variable and make every response close the *last*
        # response's body stream.
        mock_http_response.fp.close()

    mock_http_response.release_conn = release
    return mock_http_response


def create_mock_conn(responses):
    """Create a mock connection whose ``getresponse()`` yields canned responses in order.

    Args:
        responses: Iterable of dicts. Each dict may carry ``"status"``,
            ``"headers"`` (defaults to ``{}``) and ``"body"`` (bytes, defaults
            to ``b'{}'``).

    Returns:
        A ``MagicMock`` whose ``getresponse.side_effect`` is the list of mock
        responses, one per entry in ``responses``. Each response exposes
        ``status``, ``headers``, an ``io.BytesIO`` body as ``fp``, and a
        ``release_conn()`` that closes that response's own ``fp``.
    """
    mock_conn = MagicMock()
    mock_conn.getresponse.side_effect = [
        _create_mock_http_response(resp) for resp in responses
    ]
    return mock_conn
3628
37- @pytest .mark .usefixtures ("caplog" )
3829class TestTelemetryClientRetries :
39- """
40- Test suite for verifying the retry mechanism of the TelemetryClient.
41- This suite patches the low-level urllib3 connection to correctly
42- trigger and test the retry logic configured in the requests adapter.
43- """
44-
4530 @pytest .fixture (autouse = True )
46- def setup_and_teardown (self , caplog ):
47- caplog .set_level (logging .DEBUG )
31+ def setup_and_teardown (self ):
4832 TelemetryClientFactory ._initialized = False
4933 TelemetryClientFactory ._clients = {}
5034 TelemetryClientFactory ._executor = None
@@ -55,159 +39,94 @@ def setup_and_teardown(self, caplog):
5539 TelemetryClientFactory ._clients = {}
5640 TelemetryClientFactory ._executor = None
5741
58- def get_client (self , session_id , total_retries = 3 ):
42+ def get_client (self , session_id , num_retries = 3 ):
43+ """
44+ Configures a client with a specific number of retries.
45+ """
5946 TelemetryClientFactory .initialize_telemetry_client (
6047 telemetry_enabled = True ,
6148 session_id_hex = session_id ,
6249 auth_provider = None ,
6350 host_url = "test.databricks.com" ,
6451 )
6552 client = TelemetryClientFactory .get_telemetry_client (session_id )
66-
53+
6754 retry_policy = DatabricksRetryPolicy (
6855 delay_min = 0.01 ,
6956 delay_max = 0.02 ,
7057 stop_after_attempts_duration = 2.0 ,
71- stop_after_attempts_count = total_retries ,
58+ stop_after_attempts_count = num_retries ,
7259 delay_default = 0.1 ,
7360 force_dangerous_codes = [],
74- urllib3_kwargs = {'total' : total_retries }
61+ urllib3_kwargs = {'total' : num_retries }
7562 )
7663 adapter = client ._session .adapters .get ("https://" )
7764 adapter .max_retries = retry_policy
7865 return client , adapter
7966
80- def wait_for_async_request (self , timeout = 2.0 ):
81- """Wait for async telemetry request to complete"""
82- start_time = time .time ()
83- while time .time () - start_time < timeout :
84- if TelemetryClientFactory ._executor and TelemetryClientFactory ._executor ._threads :
85- # Wait a bit more for threads to complete
86- time .sleep (0.1 )
87- else :
88- break
89- time .sleep (0.1 ) # Extra buffer for completion
90-
9167 def test_success_no_retry (self ):
9268 client , _ = self .get_client ("session-success" )
9369 params = DriverConnectionParameters (
94- http_path = "test-path" ,
95- mode = DatabricksClientType .THRIFT ,
70+ http_path = "test-path" , mode = DatabricksClientType .THRIFT ,
9671 host_info = HostDetails (host_url = "test.databricks.com" , port = 443 ),
9772 auth_mech = AuthMech .PAT
9873 )
99- with patch ( PATCH_TARGET ) as mock_get_conn :
100- mock_get_conn . return_value . getresponse . return_value = create_urllib3_response ( 200 )
101-
74+ mock_responses = [{ "status" : 200 }]
75+
76+ with patch ( PATCH_TARGET , return_value = create_mock_conn ( mock_responses )) as mock_get_conn :
10277 client .export_initial_telemetry_log (params , "test-agent" )
103- self .wait_for_async_request ()
10478 TelemetryClientFactory .close (client ._session_id_hex )
10579
10680 mock_get_conn .return_value .getresponse .assert_called_once ()
107-
108- def test_retry_on_503_then_succeeds (self ):
109- client , _ = self .get_client ("session-retry-once" )
110- with patch (PATCH_TARGET ) as mock_get_conn :
111- mock_get_conn .return_value .getresponse .side_effect = [
112- create_urllib3_response (503 ),
113- create_urllib3_response (200 ),
114- ]
115-
116- client .export_failure_log ("TestError" , "Test message" )
117- self .wait_for_async_request ()
118- TelemetryClientFactory .close (client ._session_id_hex )
119-
120- assert mock_get_conn .return_value .getresponse .call_count == 2
121-
122- def test_respects_retry_after_header (self , caplog ):
123- client , _ = self .get_client ("session-retry-after" )
124- with patch (PATCH_TARGET ) as mock_get_conn :
125- mock_get_conn .return_value .getresponse .side_effect = [
126- create_urllib3_response (429 , headers = {'Retry-After' : '1' }), # Use integer seconds to avoid parsing issues
127- create_urllib3_response (200 )
128- ]
129-
81+ client , _ = self .get_client ("session-retry-once" , num_retries = 1 )
82+ mock_responses = [{"status" : 503 }, {"status" : 200 }]
83+
84+ with patch (PATCH_TARGET , return_value = create_mock_conn (mock_responses )) as mock_get_conn :
13085 client .export_failure_log ("TestError" , "Test message" )
131- self .wait_for_async_request ()
13286 TelemetryClientFactory .close (client ._session_id_hex )
13387
134- # Check that the request was retried (should be 2 calls: initial + 1 retry)
13588 assert mock_get_conn .return_value .getresponse .call_count == 2
136- assert "Retrying after" in caplog .text
13789
138- def test_exceeds_retry_count_limit (self , caplog ):
139- client , _ = self .get_client ("session-exceed-limit" , total_retries = 3 )
140- expected_call_count = 4
141- with patch (PATCH_TARGET ) as mock_get_conn :
142- mock_get_conn .return_value .getresponse .return_value = create_urllib3_response (503 )
143-
90+ @pytest .mark .parametrize (
91+ "status_code, description" ,
92+ [
93+ (401 , "Unauthorized" ),
94+ (403 , "Forbidden" ),
95+ (501 , "Not Implemented" ),
96+ ],
97+ )
98+ def test_non_retryable_status_codes_are_not_retried (self , status_code , description ):
99+ """
100+ Verifies that terminal error codes (401, 403, 501, etc.) are not retried.
101+ """
102+ # Use the status code in the session ID for easier debugging if it fails
103+ client , _ = self .get_client (f"session-{ status_code } " )
104+ mock_responses = [{"status" : status_code }]
105+
106+ with patch (PATCH_TARGET , return_value = create_mock_conn (mock_responses )) as mock_get_conn :
144107 client .export_failure_log ("TestError" , "Test message" )
145- self .wait_for_async_request ()
146108 TelemetryClientFactory .close (client ._session_id_hex )
147-
148- assert mock_get_conn .return_value .getresponse .call_count == expected_call_count
149- assert "Telemetry request failed with exception" in caplog .text
150- assert "Max retries exceeded" in caplog .text
151109
152- def test_no_retry_on_401_unauthorized (self , caplog ):
153- """Test that 401 responses are not retried (per retry policy)"""
154- client , _ = self .get_client ("session-401" )
155- with patch (PATCH_TARGET ) as mock_get_conn :
156- mock_get_conn .return_value .getresponse .return_value = create_urllib3_response (401 )
157-
158- client .export_failure_log ("TestError" , "Test message" )
159- self .wait_for_async_request ()
160- TelemetryClientFactory .close (client ._session_id_hex )
161-
162- # 401 should not be retried based on the retry policy
163110 mock_get_conn .return_value .getresponse .assert_called_once ()
164- assert "Telemetry request failed with status code: 401" in caplog .text
165111
166- def test_retries_on_400_bad_request (self , caplog ):
167- """Test that 400 responses are retried (this is the current behavior for telemetry)"""
168- client , _ = self .get_client ("session-400" )
169- with patch (PATCH_TARGET ) as mock_get_conn :
170- mock_get_conn .return_value .getresponse .return_value = create_urllib3_response (400 )
171-
172- client .export_failure_log ("TestError" , "Test message" )
173- self .wait_for_async_request ()
174- TelemetryClientFactory .close (client ._session_id_hex )
175-
176- # Based on the logs, 400 IS being retried (this is the actual behavior for CommandType.OTHER)
177- expected_call_count = 4 # total + 1 (initial + 3 retries)
178- assert mock_get_conn .return_value .getresponse .call_count == expected_call_count
179- assert "Telemetry request failed with exception" in caplog .text
180- assert "Max retries exceeded" in caplog .text
181-
182- def test_no_retry_on_403_forbidden (self , caplog ):
183- """Test that 403 responses are not retried (per retry policy)"""
184- client , _ = self .get_client ("session-403" )
185- with patch (PATCH_TARGET ) as mock_get_conn :
186- mock_get_conn .return_value .getresponse .return_value = create_urllib3_response (403 )
187-
112+ def test_respects_retry_after_header (self ):
113+ client , _ = self .get_client ("session-retry-after" , num_retries = 1 )
114+ mock_responses = [{"status" : 429 , "headers" : {'Retry-After' : '1' }}, {"status" : 200 }]
115+
116+ with patch (PATCH_TARGET , return_value = create_mock_conn (mock_responses )) as mock_get_conn :
188117 client .export_failure_log ("TestError" , "Test message" )
189- self .wait_for_async_request ()
190118 TelemetryClientFactory .close (client ._session_id_hex )
191119
192- # 403 should not be retried based on the retry policy
193- mock_get_conn .return_value .getresponse .assert_called_once ()
194- assert "Telemetry request failed with status code: 403" in caplog .text
120+ assert mock_get_conn .return_value .getresponse .call_count == 2
195121
196- def test_retry_policy_command_type_is_set_to_other (self ):
197- client , adapter = self .get_client ("session-command-type" )
122+ def test_exceeds_retry_count_limit (self ):
123+ num_retries = 3
124+ expected_total_calls = num_retries + 1
125+ client , _ = self .get_client ("session-exceed-limit" , num_retries = num_retries )
126+ mock_responses = [{"status" : 503 }] * expected_total_calls
198127
199- original_send = adapter .send
200- @wraps (original_send )
201- def wrapper (request , ** kwargs ):
202- assert adapter .max_retries .command_type == CommandType .OTHER
203- return original_send (request , ** kwargs )
204-
205- with patch .object (adapter , 'send' , side_effect = wrapper , autospec = True ), \
206- patch (PATCH_TARGET ) as mock_get_conn :
207- mock_get_conn .return_value .getresponse .return_value = create_urllib3_response (200 )
208-
128+ with patch (PATCH_TARGET , return_value = create_mock_conn (mock_responses )) as mock_get_conn :
209129 client .export_failure_log ("TestError" , "Test message" )
210- self .wait_for_async_request ()
211130 TelemetryClientFactory .close (client ._session_id_hex )
212131
213- assert adapter . send . call_count == 1
132+ assert mock_get_conn . return_value . getresponse . call_count == expected_total_calls
0 commit comments