diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md index a835be1fa..8c1b31ab6 100644 --- a/NEXT_CHANGELOG.md +++ b/NEXT_CHANGELOG.md @@ -3,7 +3,8 @@ ## Release v0.46.0 ### New Features and Improvements - + * Added `TokenCache` to `ExternalBrowserCredentialsProvider` to reduce number of authentications needed for U2M OAuth. + ### Bug Fixes ### Documentation diff --git a/databricks-sdk-java/pom.xml b/databricks-sdk-java/pom.xml index 69c5d2f28..2ef261200 100644 --- a/databricks-sdk-java/pom.xml +++ b/databricks-sdk-java/pom.xml @@ -97,5 +97,11 @@ google-auth-library-oauth2-http 1.20.0 + + + com.fasterxml.jackson.datatype + jackson-datatype-jsr310 + ${jackson.version} + diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksConfig.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksConfig.java index 2943ce82e..fa89f5041 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksConfig.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksConfig.java @@ -669,4 +669,14 @@ public DatabricksConfig newWithWorkspaceHost(String host) { "headerFactory")); return clone(fieldsToSkip).setHost(host); } + + /** + * Gets the default OAuth redirect URL. If one is not provided explicitly, uses + * http://localhost:8080/callback + * + * @return The OAuth redirect URL to use + */ + public String getEffectiveOAuthRedirectUrl() { + return redirectUrl != null ? redirectUrl : "http://localhost:8080/callback"; + } } diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/ExternalBrowserCredentialsProvider.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/ExternalBrowserCredentialsProvider.java index f054f06ed..b8aa4c66f 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/ExternalBrowserCredentialsProvider.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/ExternalBrowserCredentialsProvider.java @@ -5,12 +5,38 @@ import com.databricks.sdk.core.DatabricksException; import com.databricks.sdk.core.HeaderFactory; import java.io.IOException; +import java.nio.file.Path; +import java.util.Objects; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * A {@code CredentialsProvider} which implements the Authorization Code + PKCE flow by opening a - * browser for the user to authorize the application. + * browser for the user to authorize the application. Uses a specified TokenCache or creates a + * default one if none is provided. */ public class ExternalBrowserCredentialsProvider implements CredentialsProvider { + private static final Logger LOGGER = + LoggerFactory.getLogger(ExternalBrowserCredentialsProvider.class); + + private TokenCache tokenCache; + + /** + * Creates a new ExternalBrowserCredentialsProvider with the specified TokenCache. + * + * @param tokenCache the TokenCache to use for caching tokens + */ + public ExternalBrowserCredentialsProvider(TokenCache tokenCache) { + this.tokenCache = tokenCache; + } + + /** + * Creates a new ExternalBrowserCredentialsProvider with a default TokenCache. A FileTokenCache + * will be created when credentials are configured. + */ + public ExternalBrowserCredentialsProvider() { + this(null); + } @Override public String authType() { @@ -19,16 +45,87 @@ public String authType() { @Override public HeaderFactory configure(DatabricksConfig config) { - if (config.getHost() == null || config.getAuthType() != "external-browser") { + if (config.getHost() == null || !Objects.equals(config.getAuthType(), "external-browser")) { return null; } + + // Use the utility class to resolve client ID and client secret + String clientId = OAuthClientUtils.resolveClientId(config); + String clientSecret = OAuthClientUtils.resolveClientSecret(config); + try { - OAuthClient client = new OAuthClient(config); - Consent consent = client.initiateConsent(); - SessionCredentials creds = consent.launchExternalBrowser(); - return creds.configure(config); + if (tokenCache == null) { + // Create a default FileTokenCache based on config + Path cachePath = + TokenCacheUtils.getCacheFilePath(config.getHost(), clientId, config.getScopes()); + tokenCache = new FileTokenCache(cachePath); + } + + // First try to use the cached token if available (will return null if disabled) + Token cachedToken = tokenCache.load(); + if (cachedToken != null && cachedToken.getRefreshToken() != null) { + LOGGER.debug("Found cached token for {}:{}", config.getHost(), clientId); + + try { + // Create SessionCredentials with the cached token and try to refresh if needed + SessionCredentials cachedCreds = + new SessionCredentials.Builder() + .withToken(cachedToken) + .withHttpClient(config.getHttpClient()) + .withClientId(clientId) + .withClientSecret(clientSecret) + .withTokenUrl(config.getOidcEndpoints().getTokenEndpoint()) + .withRedirectUrl(config.getEffectiveOAuthRedirectUrl()) + .withTokenCache(tokenCache) + .build(); + + LOGGER.debug("Using cached token, will immediately refresh"); + cachedCreds.token = cachedCreds.refresh(); + return cachedCreds.configure(config); + } catch (Exception e) { + // If token refresh fails, log and continue to browser auth + LOGGER.info("Token refresh failed: {}, falling back to browser auth", e.getMessage()); + } + } + + // If no cached token or refresh failed, perform browser auth + SessionCredentials credentials = + performBrowserAuth(config, clientId, clientSecret, tokenCache); + tokenCache.save(credentials.getToken()); + return credentials.configure(config); } catch (IOException | DatabricksException e) { + LOGGER.error("Failed to authenticate: {}", e.getMessage()); return null; } } + + SessionCredentials performBrowserAuth( + DatabricksConfig config, String clientId, String clientSecret, TokenCache tokenCache) + throws IOException { + LOGGER.debug("Performing browser authentication"); + OAuthClient client = + new OAuthClient.Builder() + .withHttpClient(config.getHttpClient()) + .withClientId(clientId) + .withClientSecret(clientSecret) + .withHost(config.getHost()) + .withRedirectUrl(config.getEffectiveOAuthRedirectUrl()) + .withScopes(config.getScopes()) + .build(); + Consent consent = client.initiateConsent(); + + // Use the existing browser flow to get credentials + SessionCredentials credentials = consent.launchExternalBrowser(); + + // Create a new SessionCredentials with the same token but with our token cache + return new SessionCredentials.Builder() + .withToken(credentials.getToken()) + .withHttpClient(config.getHttpClient()) + .withClientId(config.getClientId()) + .withClientSecret(config.getClientSecret()) + .withTokenUrl(config.getOidcEndpoints().getTokenEndpoint()) + .withRedirectUrl(config.getEffectiveOAuthRedirectUrl()) + .withTokenCache(tokenCache) + .build(); + } } diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/FileTokenCache.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/FileTokenCache.java new file mode 100644 index 000000000..ff62ca835 --- /dev/null +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/FileTokenCache.java @@ -0,0 +1,77 @@ +package com.databricks.sdk.core.oauth; + +import com.databricks.sdk.core.utils.SerDeUtils; +import com.fasterxml.jackson.databind.ObjectMapper; +import java.io.File; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Objects; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** A TokenCache implementation that stores tokens as plain files. */ +public class FileTokenCache implements TokenCache { + private static final Logger LOGGER = LoggerFactory.getLogger(FileTokenCache.class); + + private final Path cacheFile; + private final ObjectMapper mapper; + + /** + * Constructs a new SimpleFileTokenCache instance. + * + * @param cacheFilePath The path where the token cache will be stored + */ + public FileTokenCache(Path cacheFilePath) { + Objects.requireNonNull(cacheFilePath, "cacheFilePath must be defined"); + + this.cacheFile = cacheFilePath; + this.mapper = SerDeUtils.createMapper(); + } + + @Override + public void save(Token token) { + try { + Files.createDirectories(cacheFile.getParent()); + + // Serialize token to JSON + String json = mapper.writeValueAsString(token); + byte[] dataToWrite = json.getBytes(StandardCharsets.UTF_8); + + Files.write(cacheFile, dataToWrite); + // Set file permissions to be readable only by the owner (equivalent to 0600) + File file = cacheFile.toFile(); + file.setReadable(false, false); + file.setReadable(true, true); + file.setWritable(false, false); + file.setWritable(true, true); + + LOGGER.debug("Successfully saved token to cache: {}", cacheFile); + } catch (Exception e) { + LOGGER.warn("Failed to save token to cache: {}", cacheFile, e); + } + } + + @Override + public Token load() { + try { + if (!Files.exists(cacheFile)) { + LOGGER.debug("No token cache file found at: {}", cacheFile); + return null; + } + + byte[] fileContent = Files.readAllBytes(cacheFile); + + // Deserialize token from JSON + String json = new String(fileContent, StandardCharsets.UTF_8); + Token token = mapper.readValue(json, Token.class); + LOGGER.debug("Successfully loaded token from cache: {}", cacheFile); + return token; + } catch (Exception e) { + // If there's any issue loading the token, return null + // to allow a fresh token to be obtained + LOGGER.warn("Failed to load token from cache: {}", e.getMessage()); + return null; + } + } +} diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/OAuthClient.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/OAuthClient.java index 6f4b25996..cf65ba71a 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/OAuthClient.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/OAuthClient.java @@ -85,20 +85,6 @@ public OAuthClient build() throws IOException { private final boolean isAws; private final boolean isAzure; - public OAuthClient(DatabricksConfig config) throws IOException { - this( - new Builder() - .withHttpClient(config.getHttpClient()) - .withClientId(config.getClientId()) - .withClientSecret(config.getClientSecret()) - .withHost(config.getHost()) - .withRedirectUrl( - config.getOAuthRedirectUrl() != null - ? config.getOAuthRedirectUrl() - : "http://localhost:8080/callback") - .withScopes(config.getScopes())); - } - private OAuthClient(Builder b) throws IOException { this.clientId = Objects.requireNonNull(b.clientId); this.clientSecret = b.clientSecret; diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/OAuthClientUtils.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/OAuthClientUtils.java new file mode 100644 index 000000000..5908eff79 --- /dev/null +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/OAuthClientUtils.java @@ -0,0 +1,42 @@ +package com.databricks.sdk.core.oauth; + +import com.databricks.sdk.core.DatabricksConfig; + +/** Utility methods for OAuth client credentials resolution. */ +public class OAuthClientUtils { + + /** Default client ID to use when no client ID is specified. */ + private static final String DEFAULT_CLIENT_ID = "databricks-cli"; + + /** + * Resolves the OAuth client ID from the configuration. Prioritizes regular OAuth client ID, then + * Azure client ID, and falls back to default client ID. + * + * @param config The Databricks configuration + * @return The resolved client ID + */ + public static String resolveClientId(DatabricksConfig config) { + if (config.getClientId() != null) { + return config.getClientId(); + } else if (config.getAzureClientId() != null) { + return config.getAzureClientId(); + } + return DEFAULT_CLIENT_ID; + } + + /** + * Resolves the OAuth client secret from the configuration. Prioritizes regular OAuth client + * secret, then Azure client secret. + * + * @param config The Databricks configuration + * @return The resolved client secret, or null if not present + */ + public static String resolveClientSecret(DatabricksConfig config) { + if (config.getClientSecret() != null) { + return config.getClientSecret(); + } else if (config.getAzureClientSecret() != null) { + return config.getAzureClientSecret(); + } + return null; + } +} diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/SessionCredentials.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/SessionCredentials.java index aee173ea4..9114b6d6c 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/SessionCredentials.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/SessionCredentials.java @@ -9,6 +9,8 @@ import java.util.HashMap; import java.util.Map; import org.apache.http.HttpHeaders; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * An implementation of RefreshableTokenSource implementing the refresh_token OAuth grant type. @@ -20,6 +22,7 @@ public class SessionCredentials extends RefreshableTokenSource implements CredentialsProvider, Serializable { private static final long serialVersionUID = 3083941540130596650L; + private static final Logger LOGGER = LoggerFactory.getLogger(SessionCredentials.class); @Override public String authType() { @@ -43,6 +46,7 @@ static class Builder { private String redirectUrl; private String clientId; private String clientSecret; + private TokenCache tokenCache; public Builder withHttpClient(HttpClient hc) { this.hc = hc; @@ -74,6 +78,11 @@ public Builder withClientSecret(String clientSecret) { return this; } + public Builder withTokenCache(TokenCache tokenCache) { + this.tokenCache = tokenCache; + return this; + } + public SessionCredentials build() { return new SessionCredentials(this); } @@ -84,6 +93,7 @@ public SessionCredentials build() { private final String redirectUrl; private final String clientId; private final String clientSecret; + private final TokenCache tokenCache; private SessionCredentials(Builder b) { super(b.token); @@ -92,6 +102,7 @@ private SessionCredentials(Builder b) { this.redirectUrl = b.redirectUrl; this.clientId = b.clientId; this.clientSecret = b.clientSecret; + this.tokenCache = b.tokenCache; } @Override @@ -113,7 +124,15 @@ protected Token refresh() { // cross-origin requests headers.put("Origin", redirectUrl); } - return retrieveToken( - hc, clientId, clientSecret, tokenUrl, params, headers, AuthParameterPosition.BODY); + Token newToken = + retrieveToken( + hc, clientId, clientSecret, tokenUrl, params, headers, AuthParameterPosition.BODY); + + // Save the refreshed token directly to cache + if (tokenCache != null) { + tokenCache.save(newToken); + LOGGER.debug("Saved refreshed token to cache"); + } + return newToken; } } diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/Token.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/Token.java index 153abe066..f0fd72f68 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/Token.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/Token.java @@ -2,6 +2,7 @@ import com.databricks.sdk.core.utils.ClockSupplier; import com.databricks.sdk.core.utils.SystemClockSupplier; +import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import java.time.LocalDateTime; import java.time.temporal.ChronoUnit; @@ -37,7 +38,12 @@ public Token( } /** Constructor for refreshable tokens. */ - public Token(String accessToken, String tokenType, String refreshToken, LocalDateTime expiry) { + @JsonCreator + public Token( + @JsonProperty("accessToken") String accessToken, + @JsonProperty("tokenType") String tokenType, + @JsonProperty("refreshToken") String refreshToken, + @JsonProperty("expiry") LocalDateTime expiry) { this(accessToken, tokenType, refreshToken, expiry, new SystemClockSupplier()); } diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/TokenCache.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/TokenCache.java new file mode 100644 index 000000000..ed8fb533b --- /dev/null +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/TokenCache.java @@ -0,0 +1,21 @@ +package com.databricks.sdk.core.oauth; + +/** + * TokenCache interface for storing and retrieving OAuth tokens. Implementations can use different + * storage mechanisms and security approaches. + */ +public interface TokenCache { + /** + * Saves a Token to the cache. + * + * @param token The Token to save + */ + void save(Token token); + + /** + * Loads a Token from the cache. + * + * @return The Token from the cache or null if the cache doesn't exist or is invalid + */ + Token load(); +} diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/TokenCacheUtils.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/TokenCacheUtils.java new file mode 100644 index 000000000..aaaeca4c8 --- /dev/null +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/TokenCacheUtils.java @@ -0,0 +1,49 @@ +package com.databricks.sdk.core.oauth; + +import java.nio.charset.StandardCharsets; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; +import java.util.List; + +/** Utility methods for TokenCache implementations. */ +public class TokenCacheUtils { + // Base path for token cache files + private static final String BASE_PATH = ".config/databricks-sdk-java/oauth"; + + /** + * Returns the path to the cache file for the given configuration. The filename is based on a hash + * of the host, client ID, and scopes. + * + * @param host The Databricks host URL + * @param clientId The OAuth client ID + * @param scopes The OAuth scopes requested + * @return The path to the token cache file + */ + public static Path getCacheFilePath(String host, String clientId, List scopes) { + try { + // Create SHA-256 hash of host, client_id, and scopes + MessageDigest hash = MessageDigest.getInstance("SHA-256"); + for (String chunk : new String[] {host, clientId, String.join(",", scopes)}) { + hash.update(chunk.getBytes(StandardCharsets.UTF_8)); + } + + // Convert hash bytes to hexadecimal string + StringBuilder hexString = new StringBuilder(); + for (byte b : hash.digest()) { + String hex = Integer.toHexString(0xff & b); + if (hex.length() == 1) { + hexString.append('0'); + } + hexString.append(hex); + } + + String userHome = System.getProperty("user.home"); + Path basePath = Paths.get(userHome, BASE_PATH); + return basePath.resolve(hexString.toString()); + } catch (NoSuchAlgorithmException e) { + throw new RuntimeException("Failed to create hash for token cache filename", e); + } + } +} diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/utils/SerDeUtils.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/utils/SerDeUtils.java index 916eafd23..9730b135e 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/utils/SerDeUtils.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/utils/SerDeUtils.java @@ -4,12 +4,14 @@ import com.fasterxml.jackson.databind.DeserializationFeature; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.SerializationFeature; +import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule; /** Utilities for serialization and deserialization in the Databricks Java SDK. */ public class SerDeUtils { public static ObjectMapper createMapper() { ObjectMapper mapper = new ObjectMapper(); mapper + .registerModule(new JavaTimeModule()) .configure(SerializationFeature.WRITE_DATES_AS_TIMESTAMPS, false) .configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false) .configure(DeserializationFeature.ACCEPT_EMPTY_STRING_AS_NULL_OBJECT, true) diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/ExternalBrowserCredentialsProviderTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/ExternalBrowserCredentialsProviderTest.java index 1551045a7..1714b731c 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/ExternalBrowserCredentialsProviderTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/ExternalBrowserCredentialsProviderTest.java @@ -6,6 +6,7 @@ import com.databricks.sdk.core.DatabricksConfig; import com.databricks.sdk.core.DatabricksException; import com.databricks.sdk.core.FixtureServer; +import com.databricks.sdk.core.HeaderFactory; import com.databricks.sdk.core.commons.CommonsHttpClient; import com.databricks.sdk.core.http.HttpClient; import com.databricks.sdk.core.http.Request; @@ -17,6 +18,7 @@ import java.util.HashMap; import java.util.Map; import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; import org.mockito.Mockito; public class ExternalBrowserCredentialsProviderTest { @@ -40,7 +42,15 @@ void clientAndConsentTest() throws IOException { assertEquals("tokenEndPointFromServer", config.getOidcEndpoints().getTokenEndpoint()); - OAuthClient testClient = new OAuthClient(config); + OAuthClient testClient = + new OAuthClient.Builder() + .withHttpClient(config.getHttpClient()) + .withClientId(config.getClientId()) + .withClientSecret(config.getClientSecret()) + .withHost(config.getHost()) + .withRedirectUrl(config.getEffectiveOAuthRedirectUrl()) + .withScopes(config.getScopes()) + .build(); assertEquals("test-client-id", testClient.getClientId()); Consent testConsent = testClient.initiateConsent(); @@ -77,7 +87,15 @@ void clientAndConsentTestWithCustomRedirectUrl() throws IOException { assertEquals("tokenEndPointFromServer", config.getOidcEndpoints().getTokenEndpoint()); - OAuthClient testClient = new OAuthClient(config); + OAuthClient testClient = + new OAuthClient.Builder() + .withHttpClient(config.getHttpClient()) + .withClientId(config.getClientId()) + .withClientSecret(config.getClientSecret()) + .withHost(config.getHost()) + .withRedirectUrl(config.getEffectiveOAuthRedirectUrl()) + .withScopes(config.getScopes()) + .build(); assertEquals("test-client-id", testClient.getClientId()); Consent testConsent = testClient.initiateConsent(); @@ -201,4 +219,297 @@ void sessionCredentials() throws IOException { assertEquals("accessTokenFromServer", token.getAccessToken()); assertEquals("refreshTokenFromServer", token.getRefreshToken()); } + + // Token caching tests + + @Test + void cacheWithValidTokenTest() throws IOException { + // Create mock HTTP client for token refresh + HttpClient mockHttpClient = Mockito.mock(HttpClient.class); + String refreshResponse = + "{\"access_token\": \"refreshed_access_token\", \"token_type\": \"Bearer\", \"expires_in\": \"3600\", \"refresh_token\": \"new_refresh_token\"}"; + URL url = new URL("https://test.databricks.com/"); + Mockito.doReturn(new Response(refreshResponse, url)) + .when(mockHttpClient) + .execute(any(Request.class)); + + // Create an valid token with valid refresh token + LocalDateTime futureTime = LocalDateTime.now().plusHours(1); + Token validToken = new Token("valid_access_token", "Bearer", "valid_refresh_token", futureTime); + + // Create mock token cache that returns the valid token + TokenCache mockTokenCache = Mockito.mock(TokenCache.class); + Mockito.doReturn(validToken).when(mockTokenCache).load(); + + // Create config with HTTP client and mock token cache + DatabricksConfig config = + new DatabricksConfig() + .setAuthType("external-browser") + .setHost("https://test.databricks.com") + .setClientId("test-client-id") + .setHttpClient(mockHttpClient); + + // We need to provide OIDC endpoints for token refresh + OpenIDConnectEndpoints endpoints = + new OpenIDConnectEndpoints( + "https://test.databricks.com/token", "https://test.databricks.com/authorize"); + + // Create our provider with the mock token cache and mock the browser auth method + ExternalBrowserCredentialsProvider provider = + Mockito.spy(new ExternalBrowserCredentialsProvider(mockTokenCache)); + + // Spy on the config to inject the endpoints + DatabricksConfig spyConfig = Mockito.spy(config); + Mockito.doReturn(endpoints).when(spyConfig).getOidcEndpoints(); + + // Configure provider + HeaderFactory headerFactory = provider.configure(spyConfig); + + // Verify headers contain the refreshed token even though the cached token is valid + Map headers = headerFactory.headers(); + assertEquals("Bearer refreshed_access_token", headers.get("Authorization")); + + // Verify token was loaded from cache + Mockito.verify(mockTokenCache, Mockito.times(1)).load(); + + // Verify HTTP call was made to refresh the token + Mockito.verify(mockHttpClient, Mockito.times(1)).execute(any(Request.class)); + + // Verify performBrowserAuth was NOT called since refresh succeeded + Mockito.verify(provider, Mockito.never()) + .performBrowserAuth( + any(DatabricksConfig.class), + any(String.class), + any(String.class), + any(TokenCache.class)); + + // Verify token was saved back to cache + Mockito.verify(mockTokenCache, Mockito.times(1)).save(any(Token.class)); + + // Capture the token that was saved to cache to verify it's the refreshed token + ArgumentCaptor tokenCaptor = ArgumentCaptor.forClass(Token.class); + Mockito.verify(mockTokenCache).save(tokenCaptor.capture()); + Token savedToken = tokenCaptor.getValue(); + + // Verify the saved token contains the refreshed values from the HTTP response + assertEquals( + "refreshed_access_token", + savedToken.getAccessToken(), + "Should save refreshed access token to cache"); + assertEquals( + "new_refresh_token", + savedToken.getRefreshToken(), + "Should save new refresh token to cache"); + } + + @Test + void cacheWithInvalidAccessTokenValidRefreshTest() throws IOException { + // Create mock HTTP client for token refresh + HttpClient mockHttpClient = Mockito.mock(HttpClient.class); + String refreshResponse = + "{\"access_token\": \"refreshed_access_token\", \"token_type\": \"Bearer\", \"expires_in\": \"3600\", \"refresh_token\": \"new_refresh_token\"}"; + URL url = new URL("https://test.databricks.com/"); + Mockito.doReturn(new Response(refreshResponse, url)) + .when(mockHttpClient) + .execute(any(Request.class)); + + // Create an expired token with valid refresh token + LocalDateTime pastTime = LocalDateTime.now().minusHours(1); + Token expiredToken = + new Token("expired_access_token", "Bearer", "valid_refresh_token", pastTime); + + // Create mock token cache that returns the expired token + TokenCache mockTokenCache = Mockito.mock(TokenCache.class); + Mockito.doReturn(expiredToken).when(mockTokenCache).load(); + + // Create config with HTTP client and mock token cache + DatabricksConfig config = + new DatabricksConfig() + .setAuthType("external-browser") + .setHost("https://test.databricks.com") + .setClientId("test-client-id") + .setHttpClient(mockHttpClient); + + // We need to provide OIDC endpoints for token refresh + OpenIDConnectEndpoints endpoints = + new OpenIDConnectEndpoints( + "https://test.databricks.com/token", "https://test.databricks.com/authorize"); + + // Create our provider with the mock token cache + ExternalBrowserCredentialsProvider provider = + Mockito.spy(new ExternalBrowserCredentialsProvider(mockTokenCache)); + + // Spy on the config to inject the endpoints + DatabricksConfig spyConfig = Mockito.spy(config); + Mockito.doReturn(endpoints).when(spyConfig).getOidcEndpoints(); + + // Configure provider + HeaderFactory headerFactory = provider.configure(spyConfig); + + // Verify headers contain the refreshed token, not the browser auth token or expired token + Map headers = headerFactory.headers(); + assertEquals("Bearer refreshed_access_token", headers.get("Authorization")); + + // Verify token was loaded from cache + Mockito.verify(mockTokenCache, Mockito.times(1)).load(); + + // Verify HTTP call was made to refresh the token + Mockito.verify(mockHttpClient, Mockito.times(1)).execute(any(Request.class)); + + // Verify performBrowserAuth was NOT called since refresh succeeded + Mockito.verify(provider, Mockito.never()) + .performBrowserAuth( + any(DatabricksConfig.class), + any(String.class), + any(String.class), + any(TokenCache.class)); + + // Verify token was saved back to cache + Mockito.verify(mockTokenCache, Mockito.times(1)).save(any(Token.class)); + + // Capture the token that was saved to cache to verify it's the refreshed token + ArgumentCaptor tokenCaptor = ArgumentCaptor.forClass(Token.class); + Mockito.verify(mockTokenCache).save(tokenCaptor.capture()); + Token savedToken = tokenCaptor.getValue(); + + // Verify the saved token contains the refreshed values from the HTTP response + assertEquals( + "refreshed_access_token", + savedToken.getAccessToken(), + "Should save refreshed access token to cache"); + assertEquals( + "new_refresh_token", + savedToken.getRefreshToken(), + "Should save new refresh token to cache"); + } + + @Test + void cacheWithInvalidAccessTokenRefreshFailingTest() throws IOException { + // Create HTTP client that fails when refreshing token + HttpClient mockHttpClient = Mockito.mock(HttpClient.class); + Mockito.doThrow(new IOException("Failed to refresh token")) + .when(mockHttpClient) + .execute(any(Request.class)); + + // Create an expired token with invalid refresh token + LocalDateTime pastTime = LocalDateTime.now().minusHours(1); + Token expiredToken = + new Token("expired_access_token", "Bearer", "invalid_refresh_token", pastTime); + + // Create mock token cache that returns the expired token + TokenCache mockTokenCache = Mockito.mock(TokenCache.class); + Mockito.doReturn(expiredToken).when(mockTokenCache).load(); + + // Setup browser auth result (should be used as fallback) + Token browserAuthToken = + new Token( + "browser_access_token", + "Bearer", + "browser_refresh_token", + LocalDateTime.now().plusHours(1)); + + SessionCredentials browserAuthCreds = + new SessionCredentials.Builder() + .withToken(browserAuthToken) + .withClientId("test-client-id") + .withTokenUrl("https://test-token-url") + .build(); + + // Create config with failing HTTP client and mock token cache + DatabricksConfig config = + new DatabricksConfig() + .setAuthType("external-browser") + .setHost("https://test.databricks.com") + .setClientId("test-client-id") + .setHttpClient(mockHttpClient); + + // We need to provide OIDC endpoints for token refresh attempt + OpenIDConnectEndpoints endpoints = + new OpenIDConnectEndpoints( + "https://test.databricks.com/token", "https://test.databricks.com/authorize"); + + // Create our provider and mock the browser auth method + ExternalBrowserCredentialsProvider provider = + Mockito.spy(new ExternalBrowserCredentialsProvider(mockTokenCache)); + Mockito.doReturn(browserAuthCreds) + .when(provider) + .performBrowserAuth(any(DatabricksConfig.class), any(), any(), any(TokenCache.class)); + + // Spy on the config to inject the endpoints + DatabricksConfig spyConfig = Mockito.spy(config); + Mockito.doReturn(endpoints).when(spyConfig).getOidcEndpoints(); + + // Configure provider + HeaderFactory headerFactory = provider.configure(spyConfig); + + // Verify headers contain the browser auth token (fallback) + Map headers = headerFactory.headers(); + assertEquals("Bearer browser_access_token", headers.get("Authorization")); + + // Verify token was loaded from cache + Mockito.verify(mockTokenCache, Mockito.times(1)).load(); + + // Verify performBrowserAuth was called since refresh failed + Mockito.verify(provider, Mockito.times(1)) + .performBrowserAuth(any(DatabricksConfig.class), any(), any(), any(TokenCache.class)); + + // Verify token was saved after browser auth (for the new token) + Mockito.verify(mockTokenCache, Mockito.times(1)).save(any(Token.class)); + } + + @Test + void cacheWithInvalidTokensTest() throws IOException { + // Create completely invalid token (no refresh token) + LocalDateTime pastTime = LocalDateTime.now().minusHours(1); + Token invalidToken = new Token("expired_access_token", "Bearer", null, pastTime); + + // Create mock token cache that returns the invalid token + TokenCache mockTokenCache = Mockito.mock(TokenCache.class); + Mockito.doReturn(invalidToken).when(mockTokenCache).load(); + + // Setup browser auth result (should be used as fallback) + Token browserAuthToken = + new Token( + "browser_access_token", + "Bearer", + "browser_refresh_token", + LocalDateTime.now().plusHours(1)); + + SessionCredentials browserAuthCreds = + new SessionCredentials.Builder() + .withToken(browserAuthToken) + .withClientId("test-client-id") + .withTokenUrl("https://test-token-url") + .build(); + + // Create simple config + DatabricksConfig config = + new DatabricksConfig() + .setAuthType("external-browser") + .setHost("https://test.databricks.com") + .setClientId("test-client-id"); + + // Create our provider and mock the browser auth method + ExternalBrowserCredentialsProvider provider = + Mockito.spy(new ExternalBrowserCredentialsProvider(mockTokenCache)); + Mockito.doReturn(browserAuthCreds) + .when(provider) + .performBrowserAuth(any(DatabricksConfig.class), any(), any(), any(TokenCache.class)); + + // Configure provider + HeaderFactory headerFactory = provider.configure(config); + // Verify headers contain the browser auth token (fallback) + Map headers = headerFactory.headers(); + assertEquals("Bearer browser_access_token", headers.get("Authorization")); + + // Verify token was loaded from cache + Mockito.verify(mockTokenCache, Mockito.times(1)).load(); + + // Verify performBrowserAuth was called since we had an invalid token + Mockito.verify(provider, Mockito.times(1)) + .performBrowserAuth(any(DatabricksConfig.class), any(), any(), any(TokenCache.class)); + + // Verify token was saved after browser auth (for the new token) + Mockito.verify(mockTokenCache, Mockito.times(1)).save(any(Token.class)); + } } diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/FileTokenCacheTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/FileTokenCacheTest.java new file mode 100644 index 000000000..ede6cfd11 --- /dev/null +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/FileTokenCacheTest.java @@ -0,0 +1,128 @@ +package com.databricks.sdk.core.oauth; + +import static org.junit.jupiter.api.Assertions.*; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.time.LocalDateTime; +import java.util.Arrays; +import java.util.List; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +/** Tests for the FileTokenCache implementation of TokenCache. */ +public class FileTokenCacheTest { + private static final String TEST_HOST = "https://test-host.cloud.databricks.com"; + private static final String TEST_CLIENT_ID = "test-client-id"; + private static final List TEST_SCOPES = + Arrays.asList("offline_access", "clusters", "sql"); + private Path cacheFile; + private FileTokenCache tokenCache; + + @BeforeEach + void setUp() { + cacheFile = TokenCacheUtils.getCacheFilePath(TEST_HOST, TEST_CLIENT_ID, TEST_SCOPES); + tokenCache = new FileTokenCache(cacheFile); + } + + @AfterEach + void tearDown() throws IOException { + Files.deleteIfExists(cacheFile); + } + + @Test + void testEmptyCache() { + // When no cache file exists + assertNull(tokenCache.load(), "Loading from non-existent cache should return null"); + } + + @Test + void testSaveAndLoadToken() { + // Given a token + LocalDateTime expiry = LocalDateTime.now().plusHours(1); + Token token = new Token("access-token", "Bearer", "refresh-token", expiry); + + // When saving and loading the token + tokenCache.save(token); + Token loadedToken = tokenCache.load(); + + // Then the loaded token should match the original + assertNotNull(loadedToken, "Loaded token should not be null"); + assertEquals("access-token", loadedToken.getAccessToken()); + assertEquals("Bearer", loadedToken.getTokenType()); + assertEquals("refresh-token", loadedToken.getRefreshToken()); + assertFalse(loadedToken.isExpired(), "Token should not be expired"); + } + + @Test + void testTokenExpiry() { + // Create an expired token + LocalDateTime pastTime = LocalDateTime.now().minusHours(1); + Token expiredToken = new Token("access-token", "Bearer", "refresh-token", pastTime); + + // Verify it's marked as expired + assertTrue(expiredToken.isExpired(), "Token should be expired"); + + // Create a valid token + LocalDateTime futureTime = LocalDateTime.now().plusMinutes(30); + Token validToken = new Token("access-token", "Bearer", "refresh-token", futureTime); + + // Verify it's not marked as expired + assertFalse(validToken.isExpired(), "Token should not be expired"); + } + + @Test + void testNullPathRejection() { + // FileTokenCache should reject null path + assertThrows( + NullPointerException.class, + () -> new FileTokenCache(null), + "Should throw NullPointerException for null path"); + } + + @Test + void testOverwriteToken() { + // Given two tokens saved in sequence + Token token1 = new Token("token1", "Bearer", "refresh1", LocalDateTime.now().plusHours(1)); + Token token2 = new Token("token2", "Bearer", "refresh2", LocalDateTime.now().plusHours(2)); + + tokenCache.save(token1); + tokenCache.save(token2); + + // When loading from cache + Token loadedToken = tokenCache.load(); + + // Then the second token should be loaded + assertNotNull(loadedToken, "Loaded token should not be null"); + assertEquals("token2", loadedToken.getAccessToken()); + assertEquals("refresh2", loadedToken.getRefreshToken()); + } + + @Test + void testWithCustomPath(@TempDir Path tempDir) { + // Given a token cache with a custom path + Path tempPath = tempDir.resolve("custom-token-cache"); + FileTokenCache cache = new FileTokenCache(tempPath); + + // And a token + Token testToken = + new Token( + "test-access-token", "Bearer", "test-refresh-token", LocalDateTime.now().plusHours(1)); + + // When saving and loading + cache.save(testToken); + Token loadedToken = cache.load(); + + // Then the token should be loaded from the custom path + assertNotNull(loadedToken, "Should load token from custom cache path"); + assertEquals("test-access-token", loadedToken.getAccessToken()); + assertEquals("Bearer", loadedToken.getTokenType()); + assertEquals("test-refresh-token", loadedToken.getRefreshToken()); + + // Verify the file exists + assertTrue(Files.exists(tempPath), "Cache file should exist at custom path"); + } +} diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/OAuthClientUtilsTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/OAuthClientUtilsTest.java new file mode 100644 index 000000000..b9af3928f --- /dev/null +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/OAuthClientUtilsTest.java @@ -0,0 +1,45 @@ +package com.databricks.sdk.core.oauth; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; + +import com.databricks.sdk.core.DatabricksConfig; +import org.junit.jupiter.api.Test; + +public class OAuthClientUtilsTest { + + @Test + void resolveClientIdTest() { + // Test with regular client ID + DatabricksConfig config = + new DatabricksConfig().setClientId("test-client-id").setAzureClientId("azure-client-id"); + assertEquals("test-client-id", OAuthClientUtils.resolveClientId(config)); + + // Test with only Azure client ID + config = new DatabricksConfig().setClientId(null).setAzureClientId("azure-client-id"); + assertEquals("azure-client-id", OAuthClientUtils.resolveClientId(config)); + + // Test with no client IDs + config = new DatabricksConfig().setClientId(null).setAzureClientId(null); + assertEquals("databricks-cli", OAuthClientUtils.resolveClientId(config)); + } + + @Test + void resolveClientSecretTest() { + // Test with regular client secret + DatabricksConfig config = + new DatabricksConfig() + .setClientSecret("test-client-secret") + .setAzureClientSecret("azure-client-secret"); + assertEquals("test-client-secret", OAuthClientUtils.resolveClientSecret(config)); + + // Test with only Azure client secret + config = + new DatabricksConfig().setClientSecret(null).setAzureClientSecret("azure-client-secret"); + assertEquals("azure-client-secret", OAuthClientUtils.resolveClientSecret(config)); + + // Test with no client secrets + config = new DatabricksConfig().setClientSecret(null).setAzureClientSecret(null); + assertNull(OAuthClientUtils.resolveClientSecret(config)); + } +}