diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java index c8ac65ba2..e642159c0 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java @@ -17,7 +17,7 @@ * Implementation of TokenSource that handles OAuth token exchange for Databricks authentication. * This class manages the OAuth token exchange flow using ID tokens to obtain access tokens. */ -public class DatabricksOAuthTokenSource implements TokenSource { +public class DatabricksOAuthTokenSource extends RefreshableTokenSource { private static final Logger LOG = LoggerFactory.getLogger(DatabricksOAuthTokenSource.class); /** OAuth client ID used for token exchange. */ @@ -154,7 +154,7 @@ private static void validate(Object value, String fieldName) { * parameters are missing. */ @Override - public Token getToken() { + public Token refresh() { // Validate all required parameters validate(clientId, "ClientID"); validate(host, "Host"); diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/TokenSourceCredentialsProvider.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/TokenSourceCredentialsProvider.java index 1428b4a5f..5b098d076 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/TokenSourceCredentialsProvider.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/TokenSourceCredentialsProvider.java @@ -41,12 +41,14 @@ public TokenSourceCredentialsProvider(TokenSource tokenSource, String authType) @Override public HeaderFactory configure(DatabricksConfig config) { try { - // Validate that we can get a token before returning the HeaderFactory - String accessToken = tokenSource.getToken().getAccessToken(); + // Validate that we can get a token before returning a HeaderFactory + tokenSource.getToken().getAccessToken(); return () -> { Map headers = new HashMap<>(); - headers.put("Authorization", "Bearer " + accessToken); + // Some TokenSource implementations cache tokens internally, so an additional getToken() + // call is not costly + headers.put("Authorization", "Bearer " + tokenSource.getToken().getAccessToken()); return headers; }; } catch (Exception e) { diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/TokenSourceCredentialsProviderTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/TokenSourceCredentialsProviderTest.java index 0b7661409..14eb3fa40 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/TokenSourceCredentialsProviderTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/TokenSourceCredentialsProviderTest.java @@ -40,7 +40,8 @@ void testTokenScenarios( Map headers = headerFactory.headers(); assertEquals(expectedAuthHeader, headers.get("Authorization")); } - verify(mockTokenSource).getToken(); + + verify(mockTokenSource, atLeastOnce()).getToken(); assertEquals(TEST_AUTH_TYPE, provider.authType()); }