Skip to content

Commit b7a4677

Browse files
cleaner HTTP client using requests.sessions
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent 112958d commit b7a4677

File tree

1 file changed

+97
-191
lines changed

1 file changed

+97
-191
lines changed
Lines changed: 97 additions & 191 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
import json
22
import logging
33
import ssl
4-
import urllib.parse
54
import urllib.request
65
from typing import Dict, Any, Optional, List, Tuple, Union
76
from urllib.parse import urljoin
87

9-
from urllib3 import HTTPConnectionPool, HTTPSConnectionPool, ProxyManager
10-
from urllib3.util import make_headers
8+
import requests
9+
from requests.adapters import HTTPAdapter
10+
from requests.exceptions import RequestException, HTTPError, ConnectionError
1111
from urllib3.exceptions import MaxRetryError
1212

1313
from databricks.sql.auth.authenticators import AuthProvider
@@ -23,20 +23,46 @@
2323
logger = logging.getLogger(__name__)
2424

2525

26-
class SeaHttpClient:
26+
class SSLContextAdapter(HTTPAdapter):
    """
    An HTTP adapter that uses a custom SSLContext to handle advanced SSL settings,
    including client certificate key passwords.

    The context is built once from the supplied SSLOptions and injected into
    every pool manager this adapter creates, for both direct and proxied
    connections.
    """

    def __init__(self, ssl_options: SSLOptions, **kwargs):
        """
        Build the SSL context, then initialise the underlying HTTPAdapter.

        Args:
            ssl_options: TLS settings (verification flag, trusted CA file,
                client certificate, key file and key password).
            **kwargs: Forwarded unchanged to ``HTTPAdapter.__init__``.
        """
        # Must be assigned before super().__init__(): the requests HTTPAdapter
        # initialiser invokes init_poolmanager(), which reads self.ssl_context.
        self.ssl_context = self._create_ssl_context(ssl_options)
        super().__init__(**kwargs)

    def _create_ssl_context(self, ssl_options: SSLOptions) -> ssl.SSLContext:
        """
        Build a custom client-side SSLContext based on the provided SSLOptions.

        Args:
            ssl_options: TLS settings to apply to the context.

        Returns:
            An ``ssl.SSLContext`` for outgoing connections. Server certificates
            and hostnames are verified unless ``tls_verify`` is falsy.
        """
        # Purpose.SERVER_AUTH is the purpose for a *client* that authenticates
        # servers. Purpose.CLIENT_AUTH yields a server-side context, which does
        # not verify peers by default and, on Python 3.10+, cannot be used to
        # open client sockets at all.
        context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
        if not ssl_options.tls_verify:
            # check_hostname has to be cleared first; the ssl module refuses
            # CERT_NONE while hostname checking is still enabled.
            context.check_hostname = False
            context.verify_mode = ssl.CERT_NONE
        elif ssl_options.tls_trusted_ca_file:
            context.load_verify_locations(cafile=ssl_options.tls_trusted_ca_file)
        if ssl_options.tls_client_cert_file:
            context.load_cert_chain(
                certfile=ssl_options.tls_client_cert_file,
                keyfile=ssl_options.tls_client_cert_key_file,
                password=ssl_options.tls_client_cert_key_password,
            )
        return context

    def init_poolmanager(self, *args, **kwargs):
        """Create the direct-connection pool manager with the custom SSL context."""
        kwargs["ssl_context"] = self.ssl_context
        return super().init_poolmanager(*args, **kwargs)

    def proxy_manager_for(self, *args, **kwargs):
        """
        Create the proxy manager with the custom SSL context.

        Without this override, requests routed through a proxy would use a
        default context and ignore the CA file, client certificate and key
        password configured in SSLOptions.
        """
        kwargs["ssl_context"] = self.ssl_context
        return super().proxy_manager_for(*args, **kwargs)
57+
58+
59+
class SeaHttpClient:
60+
"""
61+
HTTP client for Statement Execution API (SEA), using the requests library.
3262
"""
3363

3464
retry_policy: Union[DatabricksRetryPolicy, int]
35-
_pool: Optional[Union[HTTPConnectionPool, HTTPSConnectionPool]]
36-
proxy_uri: Optional[str]
37-
realhost: Optional[str]
38-
realport: Optional[int]
39-
proxy_auth: Optional[Dict[str, str]]
65+
_session: requests.Session
4066

4167
def __init__(
4268
self,
@@ -48,39 +74,16 @@ def __init__(
4874
ssl_options: SSLOptions,
4975
**kwargs,
5076
):
51-
"""
52-
Initialize the SEA HTTP client.
53-
54-
Args:
55-
server_hostname: Hostname of the Databricks server
56-
port: Port number for the connection
57-
http_path: HTTP path for the connection
58-
http_headers: List of HTTP headers to include in requests
59-
auth_provider: Authentication provider
60-
ssl_options: SSL configuration options
61-
**kwargs: Additional keyword arguments including retry policy settings
62-
"""
63-
6477
self.server_hostname = server_hostname
6578
self.port = port or 443
66-
self.http_path = http_path
6779
self.auth_provider = auth_provider
6880
self.ssl_options = ssl_options
69-
70-
# Build base URL
71-
self.base_url = f"https://{server_hostname}:{self.port}"
72-
73-
# Parse URL for proxy handling
74-
parsed_url = urllib.parse.urlparse(self.base_url)
75-
self.scheme = parsed_url.scheme
76-
self.host = parsed_url.hostname
77-
self.port = parsed_url.port or (443 if self.scheme == "https" else 80)
78-
79-
# Setup headers
81+
self.scheme = "https"
82+
self.base_url = f"{self.scheme}://{server_hostname}:{self.port}"
83+
self._session = requests.Session()
8084
self.headers: Dict[str, str] = dict(http_headers)
8185
self.headers.update({"Content-Type": "application/json"})
82-
83-
# Extract retry policy settings
86+
self._session.headers.update(self.headers)
8487
self._retry_delay_min = kwargs.get("_retry_delay_min", 1.0)
8588
self._retry_delay_max = kwargs.get("_retry_delay_max", 60.0)
8689
self._retry_stop_after_attempts_count = kwargs.get(
@@ -91,23 +94,36 @@ def __init__(
9194
)
9295
self._retry_delay_default = kwargs.get("_retry_delay_default", 5.0)
9396
self.force_dangerous_codes = kwargs.get("_retry_dangerous_codes", [])
94-
95-
# Connection pooling settings
9697
self.max_connections = kwargs.get("max_connections", 10)
97-
98-
# Setup retry policy
98+
self._configure_proxies()
9999
self.enable_v3_retries = kwargs.get("_enable_v3_retries", True)
100+
self._configure_retries_and_ssl(**kwargs)
101+
102+
def _configure_proxies(self):
    """
    Route the session through the system proxy configured for ``self.scheme``.

    The proxy is looked up with ``urllib.request.getproxies()``. The session's
    proxy mapping is left untouched when the lookup fails, when the server
    hostname matches a proxy-bypass rule, or when no proxy is configured.
    Otherwise the same proxy URL is registered for both "http" and "https".
    """
    try:
        proxy_url = urllib.request.getproxies().get(self.scheme)
    except (KeyError, AttributeError):
        return

    host = self.server_hostname
    if host and urllib.request.proxy_bypass(host):
        # The target host is exempt from proxying.
        return

    if proxy_url:
        self._session.proxies = dict.fromkeys(("http", "https"), proxy_url)
100114

115+
def _configure_retries_and_ssl(self, **kwargs):
101116
if self.enable_v3_retries:
102117
urllib3_kwargs = {"allowed_methods": ["GET", "POST", "DELETE"]}
103118
_max_redirects = kwargs.get("_retry_max_redirects")
104119
if _max_redirects:
105120
if _max_redirects > self._retry_stop_after_attempts_count:
106121
logger.warning(
107-
"_retry_max_redirects > _retry_stop_after_attempts_count so it will have no affect!"
122+
"_retry_max_redirects > _retry_stop_after_attempts_count "
123+
"so it will have no effect!"
108124
)
109125
urllib3_kwargs["redirect"] = _max_redirects
110-
126+
self._session.max_redirects = _max_redirects
111127
self.retry_policy = DatabricksRetryPolicy(
112128
delay_min=self._retry_delay_min,
113129
delay_max=self._retry_delay_max,
@@ -117,104 +133,34 @@ def __init__(
117133
force_dangerous_codes=self.force_dangerous_codes,
118134
urllib3_kwargs=urllib3_kwargs,
119135
)
136+
retry_strategy = self.retry_policy
120137
else:
121-
# Legacy behavior - no automatic retries
122138
logger.warning(
123139
"Legacy retry behavior is enabled for this connection."
124140
" This behaviour is not supported for the SEA backend."
125141
)
126142
self.retry_policy = 0
127-
128-
# Handle proxy settings
129-
try:
130-
proxy = urllib.request.getproxies().get(self.scheme)
131-
except (KeyError, AttributeError):
132-
proxy = None
133-
else:
134-
if self.host and urllib.request.proxy_bypass(self.host):
135-
proxy = None
136-
137-
if proxy:
138-
parsed_proxy = urllib.parse.urlparse(proxy)
139-
self.realhost = self.host
140-
self.realport = self.port
141-
self.proxy_uri = proxy
142-
self.host = parsed_proxy.hostname
143-
self.port = parsed_proxy.port or (443 if self.scheme == "https" else 80)
144-
self.proxy_auth = self._basic_proxy_auth_headers(parsed_proxy)
145-
else:
146-
self.realhost = None
147-
self.realport = None
148-
self.proxy_auth = None
149-
self.proxy_uri = None
150-
151-
# Initialize connection pool
152-
self._pool = None
153-
self._open()
154-
155-
def _basic_proxy_auth_headers(self, proxy_parsed) -> Optional[Dict[str, str]]:
156-
"""Create basic auth headers for proxy if credentials are provided."""
157-
if proxy_parsed is None or not proxy_parsed.username:
158-
return None
159-
ap = f"{urllib.parse.unquote(proxy_parsed.username)}:{urllib.parse.unquote(proxy_parsed.password)}"
160-
return make_headers(proxy_basic_auth=ap)
161-
162-
def _open(self):
163-
"""Initialize the connection pool."""
164-
pool_kwargs = {"maxsize": self.max_connections}
165-
166-
if self.scheme == "http":
167-
pool_class = HTTPConnectionPool
168-
else: # https
169-
pool_class = HTTPSConnectionPool
170-
pool_kwargs.update(
171-
{
172-
"cert_reqs": ssl.CERT_REQUIRED
173-
if self.ssl_options.tls_verify
174-
else ssl.CERT_NONE,
175-
"ca_certs": self.ssl_options.tls_trusted_ca_file,
176-
"cert_file": self.ssl_options.tls_client_cert_file,
177-
"key_file": self.ssl_options.tls_client_cert_key_file,
178-
"key_password": self.ssl_options.tls_client_cert_key_password,
179-
}
180-
)
181-
182-
if self.using_proxy():
183-
proxy_manager = ProxyManager(
184-
self.proxy_uri,
185-
num_pools=1,
186-
proxy_headers=self.proxy_auth,
187-
)
188-
self._pool = proxy_manager.connection_from_host(
189-
host=self.realhost,
190-
port=self.realport,
191-
scheme=self.scheme,
192-
pool_kwargs=pool_kwargs,
193-
)
194-
else:
195-
self._pool = pool_class(self.host, self.port, **pool_kwargs)
143+
retry_strategy = 0
144+
adapter = SSLContextAdapter(
145+
ssl_options=self.ssl_options,
146+
pool_connections=self.max_connections,
147+
max_retries=retry_strategy,
148+
)
149+
self._session.mount("https://", adapter)
150+
self._session.mount("http://", adapter)
196151

197152
def close(self):
    """Close the underlying requests session held in ``self._session``."""
    session = self._session
    session.close()
205154

206155
def set_retry_command_type(self, command_type: CommandType):
    """
    Record the command type on the retry policy.

    Does nothing when ``self.retry_policy`` is not a DatabricksRetryPolicy
    (for example the integer 0 used when v3 retries are disabled).

    Args:
        command_type: The command type to store on the policy.
    """
    policy = self.retry_policy
    if not isinstance(policy, DatabricksRetryPolicy):
        return
    policy.command_type = command_type
210158

211159
def start_retry_timer(self):
    """
    Start the retry policy's timer.

    Does nothing when ``self.retry_policy`` is not a DatabricksRetryPolicy
    (for example the integer 0 used when v3 retries are disabled).
    """
    policy = self.retry_policy
    if not isinstance(policy, DatabricksRetryPolicy):
        return
    policy.start_retry_timer()
215162

216163
def _get_auth_headers(self) -> Dict[str, str]:
    """
    Collect authentication headers from the auth provider.

    Returns:
        A new dict that ``self.auth_provider.add_headers`` has been asked to
        populate.
    """
    auth_headers: Dict[str, str] = dict()
    self.auth_provider.add_headers(auth_headers)
    return auth_headers
@@ -225,91 +171,51 @@ def _make_request(
225171
path: str,
226172
data: Optional[Dict[str, Any]] = None,
227173
) -> Dict[str, Any]:
228-
"""
229-
Make an HTTP request to the SEA endpoint.
230-
231-
Args:
232-
method: HTTP method (GET, POST, DELETE)
233-
path: API endpoint path
234-
data: Request payload data
235-
236-
Returns:
237-
Dict[str, Any]: Response data parsed from JSON
238-
239-
Raises:
240-
RequestError: If the request fails after retries
241-
"""
242-
243-
# Prepare headers
244-
headers = {**self.headers, **self._get_auth_headers()}
245-
246-
# Prepare request body
247-
body = json.dumps(data).encode("utf-8") if data else b""
248-
if body:
249-
headers["Content-Length"] = str(len(body))
250-
251-
# Set command type for retry policy
174+
full_url = urljoin(self.base_url, path)
175+
auth_headers = self._get_auth_headers()
252176
command_type = self._get_command_type_from_path(path, method)
253177
self.set_retry_command_type(command_type)
254178
self.start_retry_timer()
255-
256-
logger.debug(f"Making {method} request to {path}")
257-
258-
# When v3 retries are enabled, urllib3 handles retries internally via DatabricksRetryPolicy
259-
# When disabled, we let exceptions bubble up (similar to Thrift backend approach)
260-
if self._pool is None:
261-
raise RequestError("Connection pool not initialized", None)
262-
179+
logger.debug(f"Making {method} request to {full_url}")
263180
try:
264-
response = self._pool.request(
181+
with self._session.request(
265182
method=method.upper(),
266-
url=path,
267-
body=body,
268-
headers=headers,
269-
preload_content=False,
270-
retries=self.retry_policy,
271-
)
272-
except MaxRetryError as e:
273-
# urllib3 MaxRetryError should bubble up for redirect tests to catch
274-
logger.error(f"SEA HTTP request failed with MaxRetryError: {e}")
275-
raise
276-
except Exception as e:
277-
logger.error(f"SEA HTTP request failed with exception: {e}")
278-
error_message = f"Error during request to server. {e}"
279-
# Construct RequestError with proper 3-argument format (message, context, error)
183+
url=full_url,
184+
json=data,
185+
headers=auth_headers,
186+
) as response:
187+
logger.debug(f"Response status: {response.status_code}")
188+
response.raise_for_status()
189+
return response.json()
190+
except requests.exceptions.ConnectionError as e:
191+
# Check if the first argument of the ConnectionError is a MaxRetryError
192+
if e.args and isinstance(e.args[0], MaxRetryError):
193+
# We want to raise the original MaxRetryError, not the wrapper
194+
original_error = e.args[0]
195+
logger.error(
196+
f"SEA HTTP request failed with MaxRetryError: {original_error}"
197+
)
198+
raise original_error
199+
else:
200+
logger.error(f"SEA HTTP request failed with ConnectionError: {e}")
201+
raise RequestError("Error during request to server.", None, None, e)
202+
except RequestException as e:
203+
error_message = f"Error during request to server: {e}"
280204
raise RequestError(error_message, None, None, e)
281205

282-
logger.debug(f"Response status: {response.status}")
283-
284-
# Handle successful responses
285-
if 200 <= response.status < 300:
286-
return response.json()
287-
288-
error_message = f"SEA HTTP request failed with status {response.status}"
289-
290-
raise RequestError(error_message, None)
291-
292206
def _get_command_type_from_path(self, path: str, method: str) -> CommandType:
    """
    Determine the command type based on the API path and HTTP method.

    Matching is case-insensitive: the path is lowercased and the method
    uppercased before comparison.

    Args:
        path: API endpoint path.
        method: HTTP method name.

    Returns:
        For paths containing "/statements":
            EXECUTE_STATEMENT for a POST whose path ends with "/statements",
            OTHER for paths containing "/cancel",
            CLOSE_OPERATION for DELETE, GET_OPERATION_STATUS for GET.
        For paths containing "/sessions": CLOSE_SESSION for DELETE.
        OTHER in every remaining case.
    """
    route = path.lower()
    verb = method.upper()

    if "/statements" in route:
        if verb == "POST" and route.endswith("/statements"):
            return CommandType.EXECUTE_STATEMENT
        if "/cancel" in route:
            return CommandType.OTHER
        if verb == "DELETE":
            return CommandType.CLOSE_OPERATION
        if verb == "GET":
            return CommandType.GET_OPERATION_STATUS
        return CommandType.OTHER

    if "/sessions" in route and verb == "DELETE":
        return CommandType.CLOSE_SESSION

    return CommandType.OTHER

0 commit comments

Comments
 (0)