Skip to content

Commit 4c5bce1

Browse files
committed
Enhance authentication providers by implementing CredentialsProvider interface, adding auth_type and __call__ methods for AccessTokenAuthProvider, DatabricksOAuthProvider, and ExternalAuthProvider.
1 parent f1346b0 commit 4c5bce1

File tree

1 file changed

+26
-3
lines changed

1 file changed

+26
-3
lines changed

src/databricks/sql/auth/authenticators.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,17 +41,25 @@ def __call__(self, *args, **kwargs) -> HeaderFactory:
4141

4242
# Private API: this is an evolving interface and it will change in the future.
4343
# Please must not depend on it in your applications.
44-
class AccessTokenAuthProvider(AuthProvider):
44+
class AccessTokenAuthProvider(AuthProvider, CredentialsProvider):
4545
def __init__(self, access_token: str):
4646
self.__authorization_header_value = "Bearer {}".format(access_token)
4747

4848
def add_headers(self, request_headers: Dict[str, str]):
4949
request_headers["Authorization"] = self.__authorization_header_value
5050

51+
def auth_type(self) -> str:
52+
return "access-token"
53+
54+
def __call__(self, *args, **kwargs) -> HeaderFactory:
55+
def get_headers():
56+
return {"Authorization": self.__authorization_header_value}
57+
return get_headers
58+
5159

5260
# Private API: this is an evolving interface and it will change in the future.
5361
# Please must not depend on it in your applications.
54-
class DatabricksOAuthProvider(AuthProvider):
62+
class DatabricksOAuthProvider(AuthProvider, CredentialsProvider):
5563
SCOPE_DELIM = " "
5664

5765
def __init__(
@@ -93,6 +101,15 @@ def add_headers(self, request_headers: Dict[str, str]):
93101
self._update_token_if_expired()
94102
request_headers["Authorization"] = f"Bearer {self._access_token}"
95103

104+
def auth_type(self) -> str:
105+
return "databricks-oauth"
106+
107+
def __call__(self, *args, **kwargs) -> HeaderFactory:
108+
def get_headers():
109+
self._update_token_if_expired()
110+
return {"Authorization": f"Bearer {self._access_token}"}
111+
return get_headers
112+
96113
def _initial_get_token(self):
97114
try:
98115
if self._access_token is None or self._refresh_token is None:
@@ -144,11 +161,17 @@ def _update_token_if_expired(self):
144161
raise e
145162

146163

147-
class ExternalAuthProvider(AuthProvider):
164+
class ExternalAuthProvider(AuthProvider, CredentialsProvider):
148165
def __init__(self, credentials_provider: CredentialsProvider) -> None:
149166
self._header_factory = credentials_provider()
150167

151168
def add_headers(self, request_headers: Dict[str, str]):
152169
headers = self._header_factory()
153170
for k, v in headers.items():
154171
request_headers[k] = v
172+
173+
def auth_type(self) -> str:
174+
return "external-auth"
175+
176+
def __call__(self, *args, **kwargs) -> HeaderFactory:
177+
return self._header_factory

0 commit comments

Comments
 (0)