11import json
22import logging
33import ssl
4- import urllib .parse
54import urllib .request
65from typing import Dict , Any , Optional , List , Tuple , Union
76from urllib .parse import urljoin
87
9- from urllib3 import HTTPConnectionPool , HTTPSConnectionPool , ProxyManager
10- from urllib3 .util import make_headers
8+ import requests
9+ from requests .adapters import HTTPAdapter
10+ from requests .exceptions import RequestException , HTTPError , ConnectionError
1111from urllib3 .exceptions import MaxRetryError
1212
1313from databricks .sql .auth .authenticators import AuthProvider
2323logger = logging .getLogger (__name__ )
2424
2525
26- class SeaHttpClient :
26+ class SSLContextAdapter (HTTPAdapter ):
27+ """
28+ An HTTP adapter that uses a custom SSLContext to handle advanced SSL settings,
29+ including client certificate key passwords.
2730 """
28- HTTP client for Statement Execution API (SEA).
2931
30- This client uses urllib3 for robust HTTP communication with retry policies
31- and connection pooling, similar to the Thrift HTTP client but simplified.
32+ def __init__ (self , ssl_options : SSLOptions , ** kwargs ):
33+ self .ssl_context = self ._create_ssl_context (ssl_options )
34+ super ().__init__ (** kwargs )
35+
36+ def _create_ssl_context (self , ssl_options : SSLOptions ) -> ssl .SSLContext :
37+ """
38+ Build a custom SSLContext based on the provided SSLOptions.
39+ """
40+ context = ssl .create_default_context (ssl .Purpose .CLIENT_AUTH )
41+ if not ssl_options .tls_verify :
42+ context .check_hostname = False
43+ context .verify_mode = ssl .CERT_NONE
44+ elif ssl_options .tls_trusted_ca_file :
45+ context .load_verify_locations (cafile = ssl_options .tls_trusted_ca_file )
46+ if ssl_options .tls_client_cert_file :
47+ context .load_cert_chain (
48+ certfile = ssl_options .tls_client_cert_file ,
49+ keyfile = ssl_options .tls_client_cert_key_file ,
50+ password = ssl_options .tls_client_cert_key_password ,
51+ )
52+ return context
53+
54+ def init_poolmanager (self , * args , ** kwargs ):
55+ kwargs ["ssl_context" ] = self .ssl_context
56+ return super ().init_poolmanager (* args , ** kwargs )
57+
58+
59+ class SeaHttpClient :
60+ """
61+ HTTP client for Statement Execution API (SEA), using the requests library.
3262 """
3363
3464 retry_policy : Union [DatabricksRetryPolicy , int ]
35- _pool : Optional [Union [HTTPConnectionPool , HTTPSConnectionPool ]]
36- proxy_uri : Optional [str ]
37- realhost : Optional [str ]
38- realport : Optional [int ]
39- proxy_auth : Optional [Dict [str , str ]]
65+ _session : requests .Session
4066
4167 def __init__ (
4268 self ,
@@ -48,39 +74,16 @@ def __init__(
4874 ssl_options : SSLOptions ,
4975 ** kwargs ,
5076 ):
51- """
52- Initialize the SEA HTTP client.
53-
54- Args:
55- server_hostname: Hostname of the Databricks server
56- port: Port number for the connection
57- http_path: HTTP path for the connection
58- http_headers: List of HTTP headers to include in requests
59- auth_provider: Authentication provider
60- ssl_options: SSL configuration options
61- **kwargs: Additional keyword arguments including retry policy settings
62- """
63-
6477 self .server_hostname = server_hostname
6578 self .port = port or 443
66- self .http_path = http_path
6779 self .auth_provider = auth_provider
6880 self .ssl_options = ssl_options
69-
70- # Build base URL
71- self .base_url = f"https://{ server_hostname } :{ self .port } "
72-
73- # Parse URL for proxy handling
74- parsed_url = urllib .parse .urlparse (self .base_url )
75- self .scheme = parsed_url .scheme
76- self .host = parsed_url .hostname
77- self .port = parsed_url .port or (443 if self .scheme == "https" else 80 )
78-
79- # Setup headers
81+ self .scheme = "https"
82+ self .base_url = f"{ self .scheme } ://{ server_hostname } :{ self .port } "
83+ self ._session = requests .Session ()
8084 self .headers : Dict [str , str ] = dict (http_headers )
8185 self .headers .update ({"Content-Type" : "application/json" })
82-
83- # Extract retry policy settings
86+ self ._session .headers .update (self .headers )
8487 self ._retry_delay_min = kwargs .get ("_retry_delay_min" , 1.0 )
8588 self ._retry_delay_max = kwargs .get ("_retry_delay_max" , 60.0 )
8689 self ._retry_stop_after_attempts_count = kwargs .get (
@@ -91,23 +94,36 @@ def __init__(
9194 )
9295 self ._retry_delay_default = kwargs .get ("_retry_delay_default" , 5.0 )
9396 self .force_dangerous_codes = kwargs .get ("_retry_dangerous_codes" , [])
94-
95- # Connection pooling settings
9697 self .max_connections = kwargs .get ("max_connections" , 10 )
97-
98- # Setup retry policy
98+ self ._configure_proxies ()
9999 self .enable_v3_retries = kwargs .get ("_enable_v3_retries" , True )
100+ self ._configure_retries_and_ssl (** kwargs )
101+
102+ def _configure_proxies (self ):
103+ try :
104+ proxy = urllib .request .getproxies ().get (self .scheme )
105+ except (KeyError , AttributeError ):
106+ proxy = None
107+ else :
108+ if self .server_hostname and urllib .request .proxy_bypass (
109+ self .server_hostname
110+ ):
111+ proxy = None
112+ if proxy :
113+ self ._session .proxies = {"http" : proxy , "https" : proxy }
100114
115+ def _configure_retries_and_ssl (self , ** kwargs ):
101116 if self .enable_v3_retries :
102117 urllib3_kwargs = {"allowed_methods" : ["GET" , "POST" , "DELETE" ]}
103118 _max_redirects = kwargs .get ("_retry_max_redirects" )
104119 if _max_redirects :
105120 if _max_redirects > self ._retry_stop_after_attempts_count :
106121 logger .warning (
107- "_retry_max_redirects > _retry_stop_after_attempts_count so it will have no affect!"
122+ "_retry_max_redirects > _retry_stop_after_attempts_count "
123+ "so it will have no effect!"
108124 )
109125 urllib3_kwargs ["redirect" ] = _max_redirects
110-
126+ self . _session . max_redirects = _max_redirects
111127 self .retry_policy = DatabricksRetryPolicy (
112128 delay_min = self ._retry_delay_min ,
113129 delay_max = self ._retry_delay_max ,
@@ -117,104 +133,34 @@ def __init__(
117133 force_dangerous_codes = self .force_dangerous_codes ,
118134 urllib3_kwargs = urllib3_kwargs ,
119135 )
136+ retry_strategy = self .retry_policy
120137 else :
121- # Legacy behavior - no automatic retries
122138 logger .warning (
123139 "Legacy retry behavior is enabled for this connection."
124140 " This behaviour is not supported for the SEA backend."
125141 )
126142 self .retry_policy = 0
127-
128- # Handle proxy settings
129- try :
130- proxy = urllib .request .getproxies ().get (self .scheme )
131- except (KeyError , AttributeError ):
132- proxy = None
133- else :
134- if self .host and urllib .request .proxy_bypass (self .host ):
135- proxy = None
136-
137- if proxy :
138- parsed_proxy = urllib .parse .urlparse (proxy )
139- self .realhost = self .host
140- self .realport = self .port
141- self .proxy_uri = proxy
142- self .host = parsed_proxy .hostname
143- self .port = parsed_proxy .port or (443 if self .scheme == "https" else 80 )
144- self .proxy_auth = self ._basic_proxy_auth_headers (parsed_proxy )
145- else :
146- self .realhost = None
147- self .realport = None
148- self .proxy_auth = None
149- self .proxy_uri = None
150-
151- # Initialize connection pool
152- self ._pool = None
153- self ._open ()
154-
155- def _basic_proxy_auth_headers (self , proxy_parsed ) -> Optional [Dict [str , str ]]:
156- """Create basic auth headers for proxy if credentials are provided."""
157- if proxy_parsed is None or not proxy_parsed .username :
158- return None
159- ap = f"{ urllib .parse .unquote (proxy_parsed .username )} :{ urllib .parse .unquote (proxy_parsed .password )} "
160- return make_headers (proxy_basic_auth = ap )
161-
162- def _open (self ):
163- """Initialize the connection pool."""
164- pool_kwargs = {"maxsize" : self .max_connections }
165-
166- if self .scheme == "http" :
167- pool_class = HTTPConnectionPool
168- else : # https
169- pool_class = HTTPSConnectionPool
170- pool_kwargs .update (
171- {
172- "cert_reqs" : ssl .CERT_REQUIRED
173- if self .ssl_options .tls_verify
174- else ssl .CERT_NONE ,
175- "ca_certs" : self .ssl_options .tls_trusted_ca_file ,
176- "cert_file" : self .ssl_options .tls_client_cert_file ,
177- "key_file" : self .ssl_options .tls_client_cert_key_file ,
178- "key_password" : self .ssl_options .tls_client_cert_key_password ,
179- }
180- )
181-
182- if self .using_proxy ():
183- proxy_manager = ProxyManager (
184- self .proxy_uri ,
185- num_pools = 1 ,
186- proxy_headers = self .proxy_auth ,
187- )
188- self ._pool = proxy_manager .connection_from_host (
189- host = self .realhost ,
190- port = self .realport ,
191- scheme = self .scheme ,
192- pool_kwargs = pool_kwargs ,
193- )
194- else :
195- self ._pool = pool_class (self .host , self .port , ** pool_kwargs )
143+ retry_strategy = 0
144+ adapter = SSLContextAdapter (
145+ ssl_options = self .ssl_options ,
146+ pool_connections = self .max_connections ,
147+ max_retries = retry_strategy ,
148+ )
149+ self ._session .mount ("https://" , adapter )
150+ self ._session .mount ("http://" , adapter )
196151
197152 def close (self ):
198- """Close the connection pool."""
199- if self ._pool :
200- self ._pool .clear ()
201-
202- def using_proxy (self ) -> bool :
203- """Check if proxy is being used (for compatibility with Thrift client)."""
204- return self .realhost is not None
153+ self ._session .close ()
205154
206155 def set_retry_command_type (self , command_type : CommandType ):
207- """Set the command type for retry policy decision making."""
208156 if isinstance (self .retry_policy , DatabricksRetryPolicy ):
209157 self .retry_policy .command_type = command_type
210158
211159 def start_retry_timer (self ):
212- """Start the retry timer for duration-based retry limits."""
213160 if isinstance (self .retry_policy , DatabricksRetryPolicy ):
214161 self .retry_policy .start_retry_timer ()
215162
216163 def _get_auth_headers (self ) -> Dict [str , str ]:
217- """Get authentication headers from the auth provider."""
218164 headers : Dict [str , str ] = {}
219165 self .auth_provider .add_headers (headers )
220166 return headers
@@ -225,91 +171,51 @@ def _make_request(
225171 path : str ,
226172 data : Optional [Dict [str , Any ]] = None ,
227173 ) -> Dict [str , Any ]:
228- """
229- Make an HTTP request to the SEA endpoint.
230-
231- Args:
232- method: HTTP method (GET, POST, DELETE)
233- path: API endpoint path
234- data: Request payload data
235-
236- Returns:
237- Dict[str, Any]: Response data parsed from JSON
238-
239- Raises:
240- RequestError: If the request fails after retries
241- """
242-
243- # Prepare headers
244- headers = {** self .headers , ** self ._get_auth_headers ()}
245-
246- # Prepare request body
247- body = json .dumps (data ).encode ("utf-8" ) if data else b""
248- if body :
249- headers ["Content-Length" ] = str (len (body ))
250-
251- # Set command type for retry policy
174+ full_url = urljoin (self .base_url , path )
175+ auth_headers = self ._get_auth_headers ()
252176 command_type = self ._get_command_type_from_path (path , method )
253177 self .set_retry_command_type (command_type )
254178 self .start_retry_timer ()
255-
256- logger .debug (f"Making { method } request to { path } " )
257-
258- # When v3 retries are enabled, urllib3 handles retries internally via DatabricksRetryPolicy
259- # When disabled, we let exceptions bubble up (similar to Thrift backend approach)
260- if self ._pool is None :
261- raise RequestError ("Connection pool not initialized" , None )
262-
179+ logger .debug (f"Making { method } request to { full_url } " )
263180 try :
264- response = self ._pool .request (
181+ with self ._session .request (
265182 method = method .upper (),
266- url = path ,
267- body = body ,
268- headers = headers ,
269- preload_content = False ,
270- retries = self .retry_policy ,
271- )
272- except MaxRetryError as e :
273- # urllib3 MaxRetryError should bubble up for redirect tests to catch
274- logger .error (f"SEA HTTP request failed with MaxRetryError: { e } " )
275- raise
276- except Exception as e :
277- logger .error (f"SEA HTTP request failed with exception: { e } " )
278- error_message = f"Error during request to server. { e } "
279- # Construct RequestError with proper 3-argument format (message, context, error)
183+ url = full_url ,
184+ json = data ,
185+ headers = auth_headers ,
186+ ) as response :
187+ logger .debug (f"Response status: { response .status_code } " )
188+ response .raise_for_status ()
189+ return response .json ()
190+ except requests .exceptions .ConnectionError as e :
191+ # Check if the first argument of the ConnectionError is a MaxRetryError
192+ if e .args and isinstance (e .args [0 ], MaxRetryError ):
193+ # We want to raise the original MaxRetryError, not the wrapper
194+ original_error = e .args [0 ]
195+ logger .error (
196+ f"SEA HTTP request failed with MaxRetryError: { original_error } "
197+ )
198+ raise original_error
199+ else :
200+ logger .error (f"SEA HTTP request failed with ConnectionError: { e } " )
201+ raise RequestError ("Error during request to server." , None , None , e )
202+ except RequestException as e :
203+ error_message = f"Error during request to server: { e } "
280204 raise RequestError (error_message , None , None , e )
281205
282- logger .debug (f"Response status: { response .status } " )
283-
284- # Handle successful responses
285- if 200 <= response .status < 300 :
286- return response .json ()
287-
288- error_message = f"SEA HTTP request failed with status { response .status } "
289-
290- raise RequestError (error_message , None )
291-
292206 def _get_command_type_from_path (self , path : str , method : str ) -> CommandType :
293- """
294- Determine the command type based on the API path and method.
295-
296- This helps the retry policy make appropriate decisions for different
297- types of SEA operations.
298- """
299207 path = path .lower ()
300208 method = method .upper ()
301-
302209 if "/statements" in path :
303210 if method == "POST" and path .endswith ("/statements" ):
304211 return CommandType .EXECUTE_STATEMENT
305212 elif "/cancel" in path :
306- return CommandType .OTHER # Cancel operation
213+ return CommandType .OTHER
307214 elif method == "DELETE" :
308215 return CommandType .CLOSE_OPERATION
309216 elif method == "GET" :
310217 return CommandType .GET_OPERATION_STATUS
311218 elif "/sessions" in path :
312219 if method == "DELETE" :
313220 return CommandType .CLOSE_SESSION
314-
315221 return CommandType .OTHER
0 commit comments