Skip to content

Commit 5a6352e

Browse files
NiallEgan authored and susodapop committed
429, 503 Retries in Thrift Client
This PR introduces changes to retry a request when the response status is 429 or 503 and a `Retry-After` header is present. It also adds new unit tests.
1 parent 3758583 commit 5a6352e

File tree

4 files changed

+128
-8
lines changed

4 files changed

+128
-8
lines changed

cmdexec/clients/python/src/databricks/sql/client.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,8 @@ def __init__(self,
8686
# Set client SSL certificate.
8787
# _session_id
8888
# Specify the session id of the connection. For Redash use only.
89+
# _max_number_of_retries
90+
# The maximum number of times we should retry retryable requests (defaults to 25)
8991

9092
self.host = server_hostname
9193
self.port = kwargs.get("_port", 443)
@@ -117,7 +119,8 @@ def __init__(self,
117119
_tls_client_cert_file=kwargs.get("_tls_client_cert_file"),
118120
_tls_client_cert_key_file=kwargs.get("_tls_client_cert_key_file"),
119121
_tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"),
120-
_connection_uri=kwargs.get("_connection_uri"))
122+
_connection_uri=kwargs.get("_connection_uri"),
123+
_max_number_of_retries=kwargs.get("_max_number_of_retries", 25))
121124

122125
self._session_handle = self.thrift_backend.open_session(
123126
session_id=kwargs.get("_session_id"))

cmdexec/clients/python/src/databricks/sql/thrift_backend.py

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import logging
2+
import time
3+
import threading
24
from uuid import uuid4
35
from ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED, create_default_context
46

@@ -20,6 +22,7 @@ class ThriftBackend:
2022
CLOSED_OP_STATE = ttypes.TOperationState.CLOSED_STATE
2123
ERROR_OP_STATE = ttypes.TOperationState.ERROR_STATE
2224
BIT_MASKS = [1, 2, 4, 8, 16, 32, 64, 128]
25+
ERROR_MSG_HEADER = "X-Thriftserver-Error-Message"
2326

2427
def __init__(self, server_hostname: str, port, http_path: str, http_headers, **kwargs):
2528
# Internal arguments in **kwargs:
@@ -39,6 +42,8 @@ def __init__(self, server_hostname: str, port, http_path: str, http_headers, **k
3942
# See https://docs.python.org/3/library/ssl.html#ssl.SSLContext.load_cert_chain
4043
# _connection_uri
4144
# Overrides server_hostname and http_path.
45+
# _max_number_of_retries
46+
# The maximum number of times we should retry retryable requests (defaults to 25)
4247

4348
port = port or 443
4449
if kwargs.get("_connection_uri"):
@@ -49,6 +54,8 @@ def __init__(self, server_hostname: str, port, http_path: str, http_headers, **k
4954
else:
5055
raise ValueError("No valid connection settings.")
5156

57+
self._max_number_of_retries = kwargs.get("_max_number_of_retries", 25)
58+
5259
# Configure tls context
5360
ssl_context = create_default_context(cafile=kwargs.get("_tls_trusted_ca_file"))
5461
if kwargs.get("_tls_no_verify") is True:
@@ -85,21 +92,53 @@ def __init__(self, server_hostname: str, port, http_path: str, http_headers, **k
8592
self._transport.close()
8693
raise
8794

95+
self._request_lock = threading.RLock()
96+
8897
@staticmethod
8998
def _check_response_for_error(response):
9099
if response.status and response.status.statusCode in \
91100
[ttypes.TStatusCode.ERROR_STATUS, ttypes.TStatusCode.INVALID_HANDLE_STATUS]:
92101
raise DatabaseError(response.status.errorMessage)
93102

94-
@staticmethod
95-
def make_request(method, request):
103+
def make_request(self, method, request, attempt_number=1):
96104
try:
105+
# We have a lock here because .cancel can be called from a separate thread.
106+
# We do not want threads to be simultaneously sharing the Thrift Transport
107+
# because we use its state to determine retries
108+
self._request_lock.acquire()
97109
response = method(request)
98110
logger.debug("Received response: {}".format(response))
99111
ThriftBackend._check_response_for_error(response)
100112
return response
101-
except TException as error:
102-
raise OperationalError("Error during Thrift request", error)
113+
except Exception as error:
114+
# _transport.code isn't necessarily set :(
115+
code_and_headers_is_set = hasattr(self._transport, 'code') \
116+
and hasattr(self._transport, 'headers')
117+
# We only retry if a Retry-After header is set
118+
if code_and_headers_is_set and self._transport.code in [503, 429] and \
119+
"Retry-After" in self._transport.headers and \
120+
attempt_number <= self._max_number_of_retries:
121+
retry_time_seconds = int(self._transport.headers["Retry-After"])
122+
if self.ERROR_MSG_HEADER in self._transport.headers:
123+
error_message = self._transport.headers[self.ERROR_MSG_HEADER]
124+
else:
125+
error_message = str(error)
126+
logger.warning("Received retryable error during {}. Request: {} Error: {}".format(
127+
method, request, error_message))
128+
logger.warning("Retrying in {} seconds. This is attempt number {}".format(
129+
retry_time_seconds, attempt_number))
130+
time.sleep(retry_time_seconds)
131+
return self.make_request(method, request, attempt_number + 1)
132+
else:
133+
logger.error("Received error when issuing: {}".format(request))
134+
if hasattr(self._transport, "headers") and \
135+
self.ERROR_MSG_HEADER in self._transport.headers:
136+
error_message = self._transport.headers[self.ERROR_MSG_HEADER]
137+
raise OperationalError("Error during Thrift request: {}".format(error_message))
138+
else:
139+
raise OperationalError("Error during Thrift request", error)
140+
finally:
141+
self._request_lock.release()
103142

104143
def _check_protocol_version(self, t_open_session_resp):
105144
protocol_version = t_open_session_resp.serverProtocolVersion

cmdexec/clients/python/tests/test_thrift_backend.py

Lines changed: 75 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,9 @@ def test_make_request_checks_thrift_status_code(self):
4343
mock_response = Mock()
4444
mock_response.status.statusCode = ttypes.TStatusCode.ERROR_STATUS
4545
mock_method = lambda _: mock_response
46+
thrift_backend = ThriftBackend("foo", 123, "bar", [])
4647
with self.assertRaises(DatabaseError):
47-
ThriftBackend.make_request(mock_method, Mock())
48+
thrift_backend.make_request(mock_method, Mock())
4849

4950
def _make_type_desc(self, type):
5051
return ttypes.TTypeDesc(types=[ttypes.TTypeEntry(ttypes.TPrimitiveTypeEntry(type=type))])
@@ -174,22 +175,25 @@ def test_hive_schema_to_description_preserves_column_names_and_types(self):
174175

175176
def test_make_request_checks_status_code(self):
176177
error_codes = [ttypes.TStatusCode.ERROR_STATUS, ttypes.TStatusCode.INVALID_HANDLE_STATUS]
178+
thrift_backend = ThriftBackend("foo", 123, "bar", [])
179+
177180
for code in error_codes:
178181
mock_error_response = Mock()
179182
mock_error_response.status.statusCode = code
180183
mock_error_response.status.errorMessage = "a detailed error message"
181184
with self.assertRaises(DatabaseError) as cm:
182-
ThriftBackend.make_request(lambda _: mock_error_response, Mock())
185+
thrift_backend.make_request(lambda _: mock_error_response, Mock())
183186
self.assertIn("a detailed error message", str(cm.exception))
184187

185188
success_codes = [
186189
ttypes.TStatusCode.SUCCESS_STATUS, ttypes.TStatusCode.SUCCESS_WITH_INFO_STATUS,
187190
ttypes.TStatusCode.STILL_EXECUTING_STATUS
188191
]
192+
189193
for code in success_codes:
190194
mock_response = Mock()
191195
mock_response.status.statusCode = code
192-
ThriftBackend.make_request(lambda _: mock_response, Mock())
196+
thrift_backend.make_request(lambda _: mock_response, Mock())
193197

194198
def test_handle_execute_response_checks_operation_state_in_direct_results(self):
195199
for resp_type in self.execute_response_types:
@@ -746,6 +750,74 @@ def test_handle_execute_response_sets_active_op_handle(self):
746750

747751
self.assertEqual(mock_resp.operationHandle, mock_cursor.active_op_handle)
748752

753+
@patch("thrift.transport.THttpClient.THttpClient")
754+
def test_make_request_wont_retry_if_headers_not_present(self, t_transport_class):
755+
t_transport_instance = t_transport_class.return_value
756+
t_transport_instance.code = 429
757+
t_transport_instance.headers = {"foo": "bar"}
758+
mock_method = Mock()
759+
mock_method.side_effect = Exception("This method fails")
760+
761+
thrift_backend = ThriftBackend("foobar", 443, "path", [])
762+
763+
with self.assertRaises(OperationalError) as cm:
764+
thrift_backend.make_request(mock_method, Mock())
765+
766+
self.assertIn("This method fails", str(cm.exception))
767+
768+
@patch("thrift.transport.THttpClient.THttpClient")
769+
def test_make_request_wont_retry_if_error_code_not_429_or_503(self, t_transport_class):
770+
t_transport_instance = t_transport_class.return_value
771+
t_transport_instance.code = 430
772+
t_transport_instance.headers = {"Retry-After": "1"}
773+
mock_method = Mock()
774+
mock_method.side_effect = Exception("This method fails")
775+
776+
thrift_backend = ThriftBackend("foobar", 443, "path", [])
777+
778+
with self.assertRaises(OperationalError) as cm:
779+
thrift_backend.make_request(mock_method, Mock())
780+
781+
self.assertIn("This method fails", str(cm.exception))
782+
783+
@patch("thrift.transport.THttpClient.THttpClient")
784+
def test_make_request_will_retry_max_number_of_retries_times_if_retryable(
785+
self, t_transport_class):
786+
t_transport_instance = t_transport_class.return_value
787+
t_transport_instance.code = 429
788+
t_transport_instance.headers = {"Retry-After": "0"}
789+
mock_method = Mock()
790+
mock_method.side_effect = Exception("This method fails")
791+
792+
thrift_backend = ThriftBackend("foobar", 443, "path", [], _max_number_of_retries=13)
793+
794+
with self.assertRaises(OperationalError) as cm:
795+
thrift_backend.make_request(mock_method, Mock())
796+
797+
self.assertIn("This method fails", str(cm.exception))
798+
799+
self.assertEqual(mock_method.call_count, 13 + 1)
800+
801+
@patch("thrift.transport.THttpClient.THttpClient")
802+
def test_make_request_will_read_X_Thriftserver_Error_Message_if_set(self, t_transport_class):
803+
t_transport_instance = t_transport_class.return_value
804+
t_transport_instance.code = 429
805+
t_transport_instance.headers = {
806+
"Retry-After": "0",
807+
"X-Thriftserver-Error-Message": "message2"
808+
}
809+
mock_method = Mock()
810+
mock_method.side_effect = Exception("This method fails")
811+
812+
thrift_backend = ThriftBackend("foobar", 443, "path", [], _max_number_of_retries=13)
813+
814+
with self.assertRaises(OperationalError) as cm:
815+
thrift_backend.make_request(mock_method, Mock())
816+
817+
self.assertIn("message2", str(cm.exception))
818+
819+
self.assertEqual(mock_method.call_count, 13 + 1)
820+
749821

750822
if __name__ == '__main__':
751823
unittest.main()

cmdexec/clients/python/tests/tests.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,12 @@ def test_cancel_command_will_issue_warning_for_cancel_with_no_executing_command(
325325
self.assertTrue(logger_instance.warning.called)
326326
self.assertFalse(mock_thrift_backend.cancel_command.called)
327327

328+
@patch("%s.client.ThriftBackend" % PACKAGE_NAME)
329+
def test_max_number_of_retries_passthrough(self, mock_client_class):
330+
databricks.sql.connect(_max_number_of_retries=53, **self.DUMMY_CONNECTION_ARGS)
331+
332+
self.assertEqual(mock_client_class.call_args[1]["_max_number_of_retries"], 53)
333+
328334

329335
class ResultSetTests(unittest.TestCase):
330336
def test_parse_type_converts_decimal(self):

0 commit comments

Comments
 (0)