Skip to content

Commit 68d83a9

Browse files
NiallEgansusodapop
authored andcommitted
Add socket timeout param
1 parent d00fc9d commit 68d83a9

File tree

4 files changed

+26
-0
lines changed

4 files changed

+26
-0
lines changed

cmdexec/clients/python/src/databricks/sql/client.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ def __init__(self,
6060
# Set client SSL certificate.
6161
# _retry_stop_after_attempts_count
6262
# The maximum number of attempts during a request retry sequence (defaults to 24)
63+
# _socket_timeout
64+
# The timeout in seconds for socket send, recv and connect operations. Defaults to None for
65+
# no timeout. Should be a positive float or integer.
6366

6467
self.host = server_hostname
6568
self.port = kwargs.get("_port", 443)

cmdexec/clients/python/src/databricks/sql/thrift_backend.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,11 @@ def __init__(self, server_hostname: str, port, http_path: str, http_headers, **k
7171
# next calculated pre-retry delay would go past
7272
# _retry_stop_after_attempts_duration, stop now.)
7373
#
74+
# _retry_stop_after_attempts_count
75+
# The maximum number of times we should retry retryable requests (defaults to 24)
76+
# _socket_timeout
77+
# The timeout in seconds for socket send, recv and connect operations. Defaults to None for
78+
# no timeout. Should be a positive float or integer.
7479

7580
port = port or 443
7681
if kwargs.get("_connection_uri"):
@@ -109,6 +114,10 @@ def __init__(self, server_hostname: str, port, http_path: str, http_headers, **k
109114
ssl_context=ssl_context,
110115
)
111116

117+
timeout = kwargs.get("_socket_timeout")
118+
# setTimeout defaults to None (i.e. no timeout), and is expected in ms
119+
self._transport.setTimeout(timeout and (float(timeout) * 1000.0))
120+
112121
self._transport.setCustomHeaders(dict(http_headers))
113122
protocol = thrift.protocol.TBinaryProtocol.TBinaryProtocol(self._transport)
114123
self._client = TCLIService.Client(protocol)

cmdexec/clients/python/tests/test_thrift_backend.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,15 @@ def test_port_and_host_are_respected(self, t_http_client_class):
193193
self.assertEqual(t_http_client_class.call_args[1]["uri_or_host"],
194194
"https://hostname:123/path_value")
195195

196+
@patch("thrift.transport.THttpClient.THttpClient")
197+
def test_socket_timeout_is_propagated(self, t_http_client_class):
198+
ThriftBackend("hostname", 123, "path_value", [], _socket_timeout=129)
199+
self.assertEqual(t_http_client_class.return_value.setTimeout.call_args[0][0], 129 * 1000)
200+
ThriftBackend("hostname", 123, "path_value", [], _socket_timeout=0)
201+
self.assertEqual(t_http_client_class.return_value.setTimeout.call_args[0][0], 0)
202+
ThriftBackend("hostname", 123, "path_value", [], _socket_timeout=None)
203+
self.assertEqual(t_http_client_class.return_value.setTimeout.call_args[0][0], None)
204+
196205
def test_non_primitive_types_raise_error(self):
197206
columns = [
198207
ttypes.TColumnDesc(

cmdexec/clients/python/tests/tests.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,11 @@ def test_max_number_of_retries_passthrough(self, mock_client_class):
333333

334334
self.assertEqual(mock_client_class.call_args[1]["_retry_stop_after_attempts_count"], 54)
335335

336+
@patch("%s.client.ThriftBackend" % PACKAGE_NAME)
337+
def test_socket_timeout_passthrough(self, mock_client_class):
338+
databricks.sql.connect(_socket_timeout=234, **self.DUMMY_CONNECTION_ARGS)
339+
self.assertEqual(mock_client_class.call_args[1]["_socket_timeout"], 234)
340+
336341
def test_version_is_canonical(self):
337342
version = databricks.sql.__version__
338343
canonical_version_re = r'^([1-9][0-9]*!)?(0|[1-9][0-9]*)(\.(0|[1-9][0-9]*))*((a|b|rc)' \

0 commit comments

Comments
 (0)