22import logging
33from typing import Optional , List
44from urllib .parse import urlparse
5- from databricks .sql .common .http import DatabricksHttpClient , HttpMethod
65
76logger = logging .getLogger (__name__ )
87
@@ -36,6 +35,21 @@ def __init__(
3635 tls_client_cert_file : Optional [str ] = None ,
3736 oauth_persistence = None ,
3837 credentials_provider = None ,
38+ # HTTP client configuration parameters
39+ ssl_options = None , # SSLOptions type
40+ socket_timeout : Optional [float ] = None ,
41+ retry_stop_after_attempts_count : Optional [int ] = None ,
42+ retry_delay_min : Optional [float ] = None ,
43+ retry_delay_max : Optional [float ] = None ,
44+ retry_stop_after_attempts_duration : Optional [float ] = None ,
45+ retry_delay_default : Optional [float ] = None ,
46+ retry_dangerous_codes : Optional [List [int ]] = None ,
47+ http_proxy : Optional [str ] = None ,
48+ proxy_username : Optional [str ] = None ,
49+ proxy_password : Optional [str ] = None ,
50+ pool_connections : Optional [int ] = None ,
51+ pool_maxsize : Optional [int ] = None ,
52+ user_agent : Optional [str ] = None ,
3953 ):
4054 self .hostname = hostname
4155 self .access_token = access_token
@@ -51,6 +65,22 @@ def __init__(
5165 self .tls_client_cert_file = tls_client_cert_file
5266 self .oauth_persistence = oauth_persistence
5367 self .credentials_provider = credentials_provider
68+
69+ # HTTP client configuration
70+ self .ssl_options = ssl_options
71+ self .socket_timeout = socket_timeout
72+ self .retry_stop_after_attempts_count = retry_stop_after_attempts_count or 30
73+ self .retry_delay_min = retry_delay_min or 1.0
74+ self .retry_delay_max = retry_delay_max or 60.0
75+ self .retry_stop_after_attempts_duration = retry_stop_after_attempts_duration or 900.0
76+ self .retry_delay_default = retry_delay_default or 5.0
77+ self .retry_dangerous_codes = retry_dangerous_codes or []
78+ self .http_proxy = http_proxy
79+ self .proxy_username = proxy_username
80+ self .proxy_password = proxy_password
81+ self .pool_connections = pool_connections or 10
82+ self .pool_maxsize = pool_maxsize or 20
83+ self .user_agent = user_agent
5484
5585
5686def get_effective_azure_login_app_id (hostname ) -> str :
@@ -69,7 +99,7 @@ def get_effective_azure_login_app_id(hostname) -> str:
6999 return AzureAppId .PROD .value [1 ]
70100
71101
72- def get_azure_tenant_id_from_host (host : str , http_client = None ) -> str :
102+ def get_azure_tenant_id_from_host (host : str , http_client ) -> str :
73103 """
74104 Load the Azure tenant ID from the Azure Databricks login page.
75105
@@ -78,23 +108,22 @@ def get_azure_tenant_id_from_host(host: str, http_client=None) -> str:
78108 the Azure login page, and the tenant ID is extracted from the redirect URL.
79109 """
80110
81- if http_client is None :
82- http_client = DatabricksHttpClient .get_instance ()
83-
84111 login_url = f"{ host } /aad/auth"
85112 logger .debug ("Loading tenant ID from %s" , login_url )
86- with http_client .execute (HttpMethod .GET , login_url , allow_redirects = False ) as resp :
87- if resp .status_code // 100 != 3 :
113+
114+ with http_client .request_context ('GET' , login_url , allow_redirects = False ) as resp :
115+ if resp .status // 100 != 3 :
88116 raise ValueError (
89- f"Failed to get tenant ID from { login_url } : expected status code 3xx, got { resp .status_code } "
117+ f"Failed to get tenant ID from { login_url } : expected status code 3xx, got { resp .status } "
90118 )
91- entra_id_endpoint = resp .headers .get ("Location" )
119+ entra_id_endpoint = dict ( resp .headers ) .get ("Location" )
92120 if entra_id_endpoint is None :
93121 raise ValueError (f"No Location header in response from { login_url } " )
94- # The Location header has the following form: https://login.microsoftonline.com/<tenant-id>/oauth2/authorize?...
95- # The domain may change depending on the Azure cloud (e.g. login.microsoftonline.us for US Government cloud).
96- url = urlparse (entra_id_endpoint )
97- path_segments = url .path .split ("/" )
98- if len (path_segments ) < 2 :
99- raise ValueError (f"Invalid path in Location header: { url .path } " )
100- return path_segments [1 ]
122+
123+ # The Location header has the following form: https://login.microsoftonline.com/<tenant-id>/oauth2/authorize?...
124+ # The domain may change depending on the Azure cloud (e.g. login.microsoftonline.us for US Government cloud).
125+ url = urlparse (entra_id_endpoint )
126+ path_segments = url .path .split ("/" )
127+ if len (path_segments ) < 2 :
128+ raise ValueError (f"Invalid path in Location header: { url .path } " )
129+ return path_segments [1 ]
0 commit comments