Skip to content

Commit 317a471

Browse files
author
Jesse
authored
[PECO-1134] v3 Retries: allow users to bound the number of redirects to follow (#244)
Signed-off-by: Jesse Whitehouse <jesse.whitehouse@databricks.com>
1 parent 910bb5c commit 317a471

File tree

4 files changed

+138
-13
lines changed

4 files changed

+138
-13
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
- Other: Introduce SQLAlchemy dialect compliance test suite and enumerate all excluded tests
66
- Add integration tests for Databricks UC Volumes ingestion queries
7+
- Add `_retry_max_redirects` config
78

89
## 2.9.3 (2023-08-24)
910

examples/v3_retries_query_execute.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,20 @@
2020
# for 502 (Bad Gateway) codes etc. In these cases, there is a possibility that the initial command _did_ reach
2121
# Databricks compute and retrying it could result in additional executions. Retrying under these conditions uses
2222
# an exponential back-off since a Retry-After header is not present.
23+
#
24+
# This new retry behaviour allows you to configure the maximum number of redirects that the connector will follow.
25+
# Just set `_retry_max_redirects` to the integer number of redirects you want to allow. The default is None,
26+
# which means all redirects will be followed. In this case, a redirect will count toward the
27+
# _retry_stop_after_attempts_count which means that by default the connector will not enter an endless retry loop.
28+
#
29+
# For complete information about configuring retries, see the docstring for databricks.sql.thrift_backend.ThriftBackend
2330

2431
with sql.connect(server_hostname = os.getenv("DATABRICKS_SERVER_HOSTNAME"),
2532
http_path = os.getenv("DATABRICKS_HTTP_PATH"),
2633
access_token = os.getenv("DATABRICKS_TOKEN"),
2734
_enable_v3_retries = True,
28-
_retry_dangerous_codes=[502,400]) as connection:
35+
_retry_dangerous_codes=[502,400],
36+
_retry_max_redirects=2) as connection:
2937

3038
with connection.cursor() as cursor:
3139
cursor.execute("SELECT * FROM default.diamonds LIMIT 2")

src/databricks/sql/thrift_backend.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,10 @@ def __init__(
130130
# _enable_v3_retries
131131
# Whether to use the DatabricksRetryPolicy implemented in urllib3
132132
# (defaults to False)
133+
# _retry_max_redirects
134+
# An integer representing the maximum number of redirects to follow for a request.
135+
# This number must be <= _retry_stop_after_attempts_count.
136+
# (defaults to None)
133137
# max_download_threads
134138
# Number of threads for handling cloud fetch downloads. Defaults to 10
135139

@@ -185,6 +189,16 @@ def __init__(
185189
self.force_dangerous_codes = kwargs.get("_retry_dangerous_codes", [])
186190

187191
additional_transport_args = {}
192+
_max_redirects: Union[None, int] = kwargs.get("_retry_max_redirects")
193+
194+
if _max_redirects:
195+
if _max_redirects > self._retry_stop_after_attempts_count:
196+
logger.warn(
197+
"_retry_max_redirects > _retry_stop_after_attempts_count so it will have no affect!"
198+
)
199+
urllib3_kwargs = {"redirect": _max_redirects}
200+
else:
201+
urllib3_kwargs = {}
188202
if self.enable_v3_retries:
189203
self.retry_policy = databricks.sql.auth.thrift_http_client.DatabricksRetryPolicy(
190204
delay_min=self._retry_delay_min,
@@ -193,6 +207,7 @@ def __init__(
193207
stop_after_attempts_duration=self._retry_stop_after_attempts_duration,
194208
delay_default=self._retry_delay_default,
195209
force_dangerous_codes=self.force_dangerous_codes,
210+
urllib3_kwargs=urllib3_kwargs,
196211
)
197212

198213
additional_transport_args["retry_policy"] = self.retry_policy

tests/e2e/common/retry_test_mixins.py

Lines changed: 113 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -58,17 +58,21 @@ def _test_retry_disabled_with_message(self, error_msg_substring, exception_type)
5858

5959

6060
@contextmanager
61-
def mocked_server_response(status: int = 200, headers: dict = {}):
61+
def mocked_server_response(
62+
status: int = 200, headers: dict = {}, redirect_location: str = None
63+
):
6264
"""Context manager for patching urllib3 responses"""
6365

6466
# When mocking a BaseHTTPResponse for urllib3 the mock must include
6567
# 1. A status code
6668
# 2. A headers dict
67-
# 3. mock.get_redirect_location() return falsy
69+
# 3. mock.get_redirect_location() return falsy by default
6870

6971
# `msg` is included for testing when urllib3~=1.0.0 is installed
7072
mock_response = MagicMock(headers=headers, msg=headers, status=status)
71-
mock_response.get_redirect_location.return_value = False
73+
mock_response.get_redirect_location.return_value = (
74+
False if redirect_location is None else redirect_location
75+
)
7276

7377
with patch("urllib3.connectionpool.HTTPSConnectionPool._get_conn") as getconn_mock:
7478
getconn_mock.return_value.getresponse.return_value = mock_response
@@ -86,6 +90,7 @@ def mock_sequential_server_responses(responses: List[dict]):
8690
`responses` should be a list of dictionaries containing these members:
8791
- status: int
8892
- headers: dict
93+
- redirect_location: str
8994
"""
9095

9196
mock_responses = []
@@ -96,7 +101,9 @@ def mock_sequential_server_responses(responses: List[dict]):
96101
_mock = MagicMock(
97102
headers=resp["headers"], msg=resp["headers"], status=resp["status"]
98103
)
99-
_mock.get_redirect_location.return_value = False
104+
_mock.get_redirect_location.return_value = (
105+
False if resp["redirect_location"] is None else resp["redirect_location"]
106+
)
100107
mock_responses.append(_mock)
101108

102109
with patch("urllib3.connectionpool.HTTPSConnectionPool._get_conn") as getconn_mock:
@@ -220,7 +227,7 @@ def test_retry_dangerous_codes(self):
220227
with self.connection(extra_params={**self._retry_policy}) as conn:
221228
with conn.cursor() as cursor:
222229
for dangerous_code in DANGEROUS_CODES:
223-
with mocked_server_response(status=dangerous_code) as mock_obj:
230+
with mocked_server_response(status=dangerous_code):
224231
with self.assertRaises(RequestError) as cm:
225232
cursor.execute("Not a real query")
226233
assert isinstance(cm.exception.args[1], UnsafeToRetryError)
@@ -231,7 +238,7 @@ def test_retry_dangerous_codes(self):
231238
) as conn:
232239
with conn.cursor() as cursor:
233240
for dangerous_code in DANGEROUS_CODES:
234-
with mocked_server_response(status=dangerous_code) as mock_obj:
241+
with mocked_server_response(status=dangerous_code):
235242
with pytest.raises(MaxRetryError) as cm:
236243
cursor.execute("Not a real query")
237244

@@ -242,8 +249,8 @@ def test_retry_safe_execute_statement_retry_condition(self):
242249
"""
243250

244251
responses = [
245-
{"status": 429, "headers": {"Retry-After": "1"}},
246-
{"status": 503, "headers": {}},
252+
{"status": 429, "headers": {"Retry-After": "1"}, "redirect_location": None},
253+
{"status": 503, "headers": {}, "redirect_location": None},
247254
]
248255

249256
with self.connection(
@@ -265,8 +272,8 @@ def test_retry_abort_close_session_on_404(self):
265272
# First response is a Bad Gateway -> Result is the command actually goes through
266273
# Second response is a 404 because the session is no longer found
267274
responses = [
268-
{"status": 502, "headers": {"Retry-After": "1"}},
269-
{"status": 404, "headers": {}},
275+
{"status": 502, "headers": {"Retry-After": "1"}, "redirect_location": None},
276+
{"status": 404, "headers": {}, "redirect_location": None},
270277
]
271278

272279
with self.connection(extra_params={**self._retry_policy}) as conn:
@@ -295,8 +302,8 @@ def test_retry_abort_close_operation_on_404(self):
295302
# First response is a Bad Gateway -> Result is the command actually goes through
296303
# Second response is a 404 because the session is no longer found
297304
responses = [
298-
{"status": 502, "headers": {"Retry-After": "1"}},
299-
{"status": 404, "headers": {}},
305+
{"status": 502, "headers": {"Retry-After": "1"}, "redirect_location": None},
306+
{"status": 404, "headers": {}, "redirect_location": None},
300307
]
301308

302309
with self.connection(extra_params={**self._retry_policy}) as conn:
@@ -323,3 +330,97 @@ def test_retry_abort_close_operation_on_404(self):
323330
self.assertTrue(
324331
expected_message_was_found, "Did not find expected log messages"
325332
)
333+
334+
def test_retry_max_redirects_raises_too_many_redirects_exception(self):
335+
"""GIVEN the connector is configured with a custom max_redirects
336+
WHEN the DatabricksRetryPolicy is created
337+
THEN the connector raises a MaxRetryError ("too many redirects") if that number is exceeded
338+
"""
339+
340+
max_redirects, expected_call_count = 1, 2
341+
342+
# Code 302 is a redirect
343+
with mocked_server_response(
344+
status=302, redirect_location="/foo.bar"
345+
) as mock_obj:
346+
with self.assertRaises(MaxRetryError) as cm:
347+
with self.connection(
348+
extra_params={
349+
**self._retry_policy,
350+
"_retry_max_redirects": max_redirects,
351+
}
352+
):
353+
pass
354+
assert "too many redirects" == str(cm.exception.reason)
355+
# Total call count should be 2 (original + 1 retry)
356+
assert mock_obj.return_value.getresponse.call_count == expected_call_count
357+
358+
def test_retry_max_redirects_unset_doesnt_redirect_forever(self):
359+
"""GIVEN the connector is configured without a custom max_redirects
360+
WHEN the DatabricksRetryPolicy is used
361+
THEN the connector raises a MaxRetryError once the retry attempt limit is exhausted
362+
363+
This test effectively guarantees that regardless of _retry_max_redirects,
364+
_stop_after_attempts_count is enforced.
365+
"""
366+
# Code 302 is a redirect
367+
with mocked_server_response(
368+
status=302, redirect_location="/foo.bar/"
369+
) as mock_obj:
370+
with self.assertRaises(MaxRetryError) as cm:
371+
with self.connection(
372+
extra_params={
373+
**self._retry_policy,
374+
}
375+
):
376+
pass
377+
378+
# Total call count should be 6 (original + _retry_stop_after_attempts_count)
379+
assert mock_obj.return_value.getresponse.call_count == 6
380+
381+
def test_retry_max_redirects_is_bounded_by_stop_after_attempts_count(self):
382+
# If I add another 503 or 302 here the test will fail with a MaxRetryError
383+
responses = [
384+
{"status": 302, "headers": {}, "redirect_location": "/foo.bar"},
385+
{"status": 500, "headers": {}, "redirect_location": None},
386+
]
387+
388+
additional_settings = {
389+
"_retry_max_redirects": 1,
390+
"_retry_stop_after_attempts_count": 2,
391+
}
392+
393+
with pytest.raises(RequestError) as cm:
394+
with mock_sequential_server_responses(responses):
395+
with self.connection(
396+
extra_params={**self._retry_policy, **additional_settings}
397+
):
398+
pass
399+
400+
# The error should be the result of the 500, not because of too many redirects.
401+
assert "too many redirects" not in str(cm.value.message)
402+
assert "Error during request to server" in str(cm.value.message)
403+
404+
def test_retry_max_redirects_exceeds_max_attempts_count_warns_user(self):
405+
with self.assertLogs(
406+
"databricks.sql",
407+
level="WARN",
408+
) as cm:
409+
with self.connection(
410+
extra_params={
411+
**self._retry_policy,
412+
**{
413+
"_retry_max_redirects": 100,
414+
"_retry_stop_after_attempts_count": 1,
415+
},
416+
}
417+
):
418+
pass
419+
expected_message_was_found = False
420+
for log in cm.output:
421+
if expected_message_was_found:
422+
break
423+
target = "it will have no affect!"
424+
expected_message_was_found = target in log
425+
426+
assert expected_message_was_found, "Did not find expected log messages"

0 commit comments

Comments
 (0)