22import logging
33from typing import Optional , List
44from urllib .parse import urlparse
5- from databricks .sql .common .http import DatabricksHttpClient , HttpMethod
5+ from databricks .sql .auth .retry import DatabricksRetryPolicy
6+ from databricks .sql .common .http import HttpMethod
67
78logger = logging .getLogger (__name__ )
89
@@ -36,6 +37,21 @@ def __init__(
3637 tls_client_cert_file : Optional [str ] = None ,
3738 oauth_persistence = None ,
3839 credentials_provider = None ,
40+ # HTTP client configuration parameters
41+ ssl_options = None , # SSLOptions type
42+ socket_timeout : Optional [float ] = None ,
43+ retry_stop_after_attempts_count : Optional [int ] = None ,
44+ retry_delay_min : Optional [float ] = None ,
45+ retry_delay_max : Optional [float ] = None ,
46+ retry_stop_after_attempts_duration : Optional [float ] = None ,
47+ retry_delay_default : Optional [float ] = None ,
48+ retry_dangerous_codes : Optional [List [int ]] = None ,
49+ http_proxy : Optional [str ] = None ,
50+ proxy_username : Optional [str ] = None ,
51+ proxy_password : Optional [str ] = None ,
52+ pool_connections : Optional [int ] = None ,
53+ pool_maxsize : Optional [int ] = None ,
54+ user_agent : Optional [str ] = None ,
3955 ):
4056 self .hostname = hostname
4157 self .access_token = access_token
@@ -52,6 +68,24 @@ def __init__(
5268 self .oauth_persistence = oauth_persistence
5369 self .credentials_provider = credentials_provider
5470
71+ # HTTP client configuration
72+ self .ssl_options = ssl_options
73+ self .socket_timeout = socket_timeout
74+ self .retry_stop_after_attempts_count = retry_stop_after_attempts_count or 5
75+ self .retry_delay_min = retry_delay_min or 1.0
76+ self .retry_delay_max = retry_delay_max or 10.0
77+ self .retry_stop_after_attempts_duration = (
78+ retry_stop_after_attempts_duration or 300.0
79+ )
80+ self .retry_delay_default = retry_delay_default or 5.0
81+ self .retry_dangerous_codes = retry_dangerous_codes or []
82+ self .http_proxy = http_proxy
83+ self .proxy_username = proxy_username
84+ self .proxy_password = proxy_password
85+ self .pool_connections = pool_connections or 10
86+ self .pool_maxsize = pool_maxsize or 20
87+ self .user_agent = user_agent
88+
5589
5690def get_effective_azure_login_app_id (hostname ) -> str :
5791 """
@@ -69,7 +103,7 @@ def get_effective_azure_login_app_id(hostname) -> str:
69103 return AzureAppId .PROD .value [1 ]
70104
71105
72- def get_azure_tenant_id_from_host (host : str , http_client = None ) -> str :
106+ def get_azure_tenant_id_from_host (host : str , http_client ) -> str :
73107 """
74108 Load the Azure tenant ID from the Azure Databricks login page.
75109
@@ -78,23 +112,24 @@ def get_azure_tenant_id_from_host(host: str, http_client=None) -> str:
78112 the Azure login page, and the tenant ID is extracted from the redirect URL.
79113 """
80114
81- if http_client is None :
82- http_client = DatabricksHttpClient .get_instance ()
83-
84115 login_url = f"{ host } /aad/auth"
85116 logger .debug ("Loading tenant ID from %s" , login_url )
86- with http_client .execute (HttpMethod .GET , login_url , allow_redirects = False ) as resp :
87- if resp .status_code // 100 != 3 :
117+
118+ with http_client .request_context (
119+ HttpMethod .GET , login_url , allow_redirects = False
120+ ) as resp :
121+ if resp .status // 100 != 3 :
88122 raise ValueError (
89- f"Failed to get tenant ID from { login_url } : expected status code 3xx, got { resp .status_code } "
123+ f"Failed to get tenant ID from { login_url } : expected status code 3xx, got { resp .status } "
90124 )
91- entra_id_endpoint = resp .headers .get ("Location" )
125+ entra_id_endpoint = dict ( resp .headers ) .get ("Location" )
92126 if entra_id_endpoint is None :
93127 raise ValueError (f"No Location header in response from { login_url } " )
94- # The Location header has the following form: https://login.microsoftonline.com/<tenant-id>/oauth2/authorize?...
95- # The domain may change depending on the Azure cloud (e.g. login.microsoftonline.us for US Government cloud).
96- url = urlparse (entra_id_endpoint )
97- path_segments = url .path .split ("/" )
98- if len (path_segments ) < 2 :
99- raise ValueError (f"Invalid path in Location header: { url .path } " )
100- return path_segments [1 ]
128+
129+ # The Location header has the following form: https://login.microsoftonline.com/<tenant-id>/oauth2/authorize?...
130+ # The domain may change depending on the Azure cloud (e.g. login.microsoftonline.us for US Government cloud).
131+ url = urlparse (entra_id_endpoint )
132+ path_segments = url .path .split ("/" )
133+ if len (path_segments ) < 2 :
134+ raise ValueError (f"Invalid path in Location header: { url .path } " )
135+ return path_segments [1 ]
0 commit comments