Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
* Implementation of TokenSource that handles OAuth token exchange for Databricks authentication.
* This class manages the OAuth token exchange flow using ID tokens to obtain access tokens.
*/
public class DatabricksOAuthTokenSource implements TokenSource {
public class DatabricksOAuthTokenSource extends RefreshableTokenSource {
private static final Logger LOG = LoggerFactory.getLogger(DatabricksOAuthTokenSource.class);

/** OAuth client ID used for token exchange. */
Expand Down Expand Up @@ -154,7 +154,7 @@ private static void validate(Object value, String fieldName) {
* parameters are missing.
*/
@Override
public Token getToken() {
public Token refresh() {
// Validate all required parameters
validate(clientId, "ClientID");
validate(host, "Host");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,14 @@ public TokenSourceCredentialsProvider(TokenSource tokenSource, String authType)
@Override
public HeaderFactory configure(DatabricksConfig config) {
try {
// Validate that we can get a token before returning the HeaderFactory
String accessToken = tokenSource.getToken().getAccessToken();
// Validate that we can get a token before returning a HeaderFactory
tokenSource.getToken().getAccessToken();

return () -> {
Map<String, String> headers = new HashMap<>();
headers.put("Authorization", "Bearer " + accessToken);
// Some TokenSource implementations cache tokens internally, so an additional getToken()
// call is not costly
headers.put("Authorization", "Bearer " + tokenSource.getToken().getAccessToken());
return headers;
};
} catch (Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ void testTokenScenarios(
Map<String, String> headers = headerFactory.headers();
assertEquals(expectedAuthHeader, headers.get("Authorization"));
}
verify(mockTokenSource).getToken();

verify(mockTokenSource, atLeastOnce()).getToken();
assertEquals(TEST_AUTH_TYPE, provider.authType());
}

Expand Down
Loading