From 4abf7808a21f3292d240d424b21f93ba2362aada Mon Sep 17 00:00:00 2001 From: emmyzhou-db Date: Tue, 13 May 2025 09:08:37 +0000 Subject: [PATCH 1/5] Add CachedTokenSource --- .../sdk/core/oauth/CachedTokenSource.java | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) create mode 100644 databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/CachedTokenSource.java diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/CachedTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/CachedTokenSource.java new file mode 100644 index 000000000..99db022d6 --- /dev/null +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/CachedTokenSource.java @@ -0,0 +1,18 @@ +package com.databricks.sdk.core.oauth; + +public class CachedTokenSource implements TokenSource { + private final TokenSource tokenSource; + private Token cachedToken; + + public CachedTokenSource(TokenSource tokenSource) { + this.tokenSource = tokenSource; + } + + @Override + public synchronized Token getToken() { + if (cachedToken == null || !cachedToken.isValid()) { + cachedToken = tokenSource.getToken(); + } + return cachedToken; + } +} From e0a9622dce354922b6b626db2439c143097b912f Mon Sep 17 00:00:00 2001 From: emmyzhou-db Date: Wed, 14 May 2025 11:59:13 +0000 Subject: [PATCH 2/5] Add TokenEndpointClient --- .../sdk/core/oauth/DataPlaneTokenSource.java | 17 ++++ .../sdk/core/oauth/EndpointTokenSource.java | 81 +++++++++++++++++++ .../sdk/core/oauth/TokenEndpointClient.java | 78 ++++++++++++++++++ 3 files changed, 176 insertions(+) create mode 100644 databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DataPlaneTokenSource.java create mode 100644 databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/EndpointTokenSource.java create mode 100644 databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/TokenEndpointClient.java diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DataPlaneTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DataPlaneTokenSource.java new file mode 100644 index 000000000..5b5363a57 --- /dev/null +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DataPlaneTokenSource.java @@ -0,0 +1,17 @@ +package com.databricks.sdk.core.oauth; + +import com.databricks.sdk.core.http.HttpClient; +import java.util.concurrent.ConcurrentHashMap; + +public class DataPlaneTokenSource { + private final HttpClient httpClient; + private DatabricksOAuthTokenSource cpTokenSource; + private ConcurrentHashMap tokenCache; + + public DataPlaneTokenSource(HttpClient httpClient) { + this.httpClient = httpClient; + // It's good practice to initialize collections, even if empty initially. + this.tokenCache = new ConcurrentHashMap<>(); + } + +} diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/EndpointTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/EndpointTokenSource.java new file mode 100644 index 000000000..9f42a492b --- /dev/null +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/EndpointTokenSource.java @@ -0,0 +1,81 @@ +package com.databricks.sdk.core.oauth; + +import com.databricks.sdk.core.DatabricksException; +import com.databricks.sdk.core.http.HttpClient; +import java.time.LocalDateTime; +import java.util.HashMap; +import java.util.Map; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class EndpointTokenSource extends RefreshableTokenSource { + private static final Logger LOG = LoggerFactory.getLogger(EndpointTokenSource.class); + private static final String JWT_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:jwt-bearer"; + private static final String GRANT_TYPE_PARAM = "grant_type"; + private static final String AUTHORIZATION_DETAILS_PARAM = "authorization_details"; + private static final String ASSERTION_PARAM = "assertion"; + + private final DatabricksOAuthTokenSource cpTokenSource; + private final String authDetails; + private final String tokenEndpoint; + private final HttpClient httpClient; + + public EndpointTokenSource( + DatabricksOAuthTokenSource cpTokenSource, + String authDetails, + HttpClient httpClient, + String tokenEndpoint) { + + validate(cpTokenSource, "ControlPlaneTokenSource"); + validate(authDetails, "AuthDetails"); + validate(httpClient, "HttpClient"); + validate(tokenEndpoint, "TokenEndpoint"); + + this.cpTokenSource = cpTokenSource; + this.authDetails = authDetails; + this.tokenEndpoint = tokenEndpoint; + this.httpClient = httpClient; + } + + private static void validate(Object value, String fieldName) { + if (value == null) { + LOG.error("Required parameter '{}' is null", fieldName); + throw new IllegalArgumentException( + String.format("Required parameter '%s' cannot be null", fieldName)); + } + if (value instanceof String && ((String) value).isEmpty()) { + LOG.error("Required parameter '{}' is empty", fieldName); + throw new IllegalArgumentException( + String.format("Required parameter '%s' cannot be empty", fieldName)); + } + } + + @Override + protected Token refresh() { + Token cpToken = cpTokenSource.getToken(); + + Map params = new HashMap<>(); + params.put(GRANT_TYPE_PARAM, JWT_GRANT_TYPE); + params.put(AUTHORIZATION_DETAILS_PARAM, authDetails); + params.put(ASSERTION_PARAM, cpToken.getAccessToken()); + + OAuthResponse oauthResponse; + try { + oauthResponse = TokenEndpointClient.requestToken(this.httpClient, this.tokenEndpoint, params); + } catch (DatabricksException e) { + LOG.error( + "Failed to fetch token for endpoint source using {}: {}", + this.tokenEndpoint, + e.getMessage(), + e); + throw e; + } + + LocalDateTime expiry = LocalDateTime.now().plusSeconds(oauthResponse.getExpiresIn()); + return new Token( + oauthResponse.getAccessToken(), + oauthResponse.getTokenType(), + oauthResponse.getRefreshToken(), + expiry); + } +} diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/TokenEndpointClient.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/TokenEndpointClient.java new file mode 100644 index 000000000..dc09246da --- /dev/null +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/TokenEndpointClient.java @@ -0,0 +1,78 @@ +package com.databricks.sdk.core.oauth; + +import com.databricks.sdk.core.DatabricksException; +import com.databricks.sdk.core.http.FormRequest; +import com.databricks.sdk.core.http.HttpClient; +import com.databricks.sdk.core.http.Response; +import com.fasterxml.jackson.databind.ObjectMapper; +import java.io.IOException; +import java.util.Map; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public final class TokenEndpointClient { + private static final Logger LOG = LoggerFactory.getLogger(TokenEndpointClient.class); + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + + private TokenEndpointClient() {} + + public static OAuthResponse requestToken( + HttpClient httpClient, String tokenEndpointUrl, Map params) + throws DatabricksException { + if (httpClient == null) { + LOG.error("HttpClient cannot be null for requestToken"); + throw new IllegalArgumentException("HttpClient cannot be null"); + } + if (tokenEndpointUrl == null || tokenEndpointUrl.isEmpty()) { + LOG.error("Token endpoint URL cannot be null or empty"); + throw new IllegalArgumentException("Token endpoint URL cannot be null or empty"); + } + if (params == null) { + LOG.error("Request parameters map cannot be null"); + throw new IllegalArgumentException("Request parameters map cannot be null"); + } + + Response rawResponse; + try { + LOG.debug("Requesting token from endpoint: {} via static client method", tokenEndpointUrl); + rawResponse = httpClient.execute(new FormRequest(tokenEndpointUrl, params)); + } catch (IOException e) { + LOG.error( + "Failed to request token from {}: {}", tokenEndpointUrl, e.getMessage(), e); + throw new DatabricksException( + String.format("Failed to request token from %s: %s", tokenEndpointUrl, e.getMessage()), + e); + } + + OAuthResponse response; + try { + response = OBJECT_MAPPER.readValue(rawResponse.getBody(), OAuthResponse.class); + } catch (IOException e) { + LOG.error( + "Failed to parse OAuth response from token endpoint {}: {}", + tokenEndpointUrl, + e.getMessage(), + e); + throw new DatabricksException( + String.format( + "Failed to parse OAuth response from token endpoint %s: %s", + tokenEndpointUrl, e.getMessage()), + e); + } + + if (response.getErrorCode() != null) { + String errorSummary = response.getErrorSummary() != null ? response.getErrorSummary() : "No summary provided."; + LOG.error( + "Token request to {} failed with error: {} - {}", + tokenEndpointUrl, + response.getErrorCode(), + errorSummary); + throw new DatabricksException( + String.format( + "Token request failed with error: %s - %s", + response.getErrorCode(), errorSummary)); + } + LOG.debug("Successfully obtained token response from {}", tokenEndpointUrl); + return response; + } +} \ No newline at end of file From e12cb62a8ebd81cdec35eb2fedd7a2f1666f9b05 Mon Sep 17 00:00:00 2001 From: emmyzhou-db Date: Wed, 14 May 2025 15:52:43 +0000 Subject: [PATCH 3/5] Add DataPlaneTokenSource and EndpointTokenSource --- .../sdk/core/DefaultCredentialsProvider.java | 13 +- .../sdk/core/oauth/CachedTokenSource.java | 18 --- .../sdk/core/oauth/DataPlaneTokenSource.java | 85 +++++++++- .../oauth/DatabricksOAuthTokenSource.java | 47 +----- .../sdk/core/oauth/EndpointTokenSource.java | 151 +++++++++++------- .../sdk/core/oauth/TokenEndpointClient.java | 134 +++++++++------- .../oauth/DatabricksOAuthTokenSourceTest.java | 2 +- 7 files changed, 254 insertions(+), 196 deletions(-) delete mode 100644 databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/CachedTokenSource.java diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DefaultCredentialsProvider.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DefaultCredentialsProvider.java index c3fc3b1e4..2824b026d 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DefaultCredentialsProvider.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DefaultCredentialsProvider.java @@ -33,14 +33,6 @@ public NamedIDTokenSource(String name, IDTokenSource idTokenSource) { this.name = name; this.idTokenSource = idTokenSource; } - - public String getName() { - return name; - } - - public IDTokenSource getIdTokenSource() { - return idTokenSource; - } } public DefaultCredentialsProvider() {} @@ -128,14 +120,13 @@ private void addOIDCCredentialsProviders(DatabricksConfig config) { config.getClientId(), config.getHost(), endpoints, - namedIdTokenSource.getIdTokenSource(), + namedIdTokenSource.idTokenSource, config.getHttpClient()) .audience(config.getTokenAudience()) .accountId(config.isAccountClient() ? config.getAccountId() : null) .build(); - providers.add( - new TokenSourceCredentialsProvider(oauthTokenSource, namedIdTokenSource.getName())); + providers.add(new TokenSourceCredentialsProvider(oauthTokenSource, namedIdTokenSource.name)); } } diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/CachedTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/CachedTokenSource.java deleted file mode 100644 index 99db022d6..000000000 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/CachedTokenSource.java +++ /dev/null @@ -1,18 +0,0 @@ -package com.databricks.sdk.core.oauth; - -public class CachedTokenSource implements TokenSource { - private final TokenSource tokenSource; - private Token cachedToken; - - public CachedTokenSource(TokenSource tokenSource) { - this.tokenSource = tokenSource; - } - - @Override - public synchronized Token getToken() { - if (cachedToken == null || !cachedToken.isValid()) { - cachedToken = tokenSource.getToken(); - } - return cachedToken; - } -} diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DataPlaneTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DataPlaneTokenSource.java index 5b5363a57..fee01b7dc 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DataPlaneTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DataPlaneTokenSource.java @@ -1,17 +1,86 @@ package com.databricks.sdk.core.oauth; import com.databricks.sdk.core.http.HttpClient; +import java.util.Objects; import java.util.concurrent.ConcurrentHashMap; +/** + * Manages and provides Databricks data plane tokens. This class is responsible for acquiring and + * caching OAuth tokens that are specific to a particular Databricks data plane service endpoint and + * a set of authorization details. It utilizes a {@link DatabricksOAuthTokenSource} for obtaining + * control plane tokens, which may then be exchanged or used to authorize requests for data plane + * tokens. Cached {@link EndpointTokenSource} instances are used to efficiently reuse tokens for + * repeated requests to the same endpoint with the same authorization context. + */ public class DataPlaneTokenSource { - private final HttpClient httpClient; - private DatabricksOAuthTokenSource cpTokenSource; - private ConcurrentHashMap tokenCache; - - public DataPlaneTokenSource(HttpClient httpClient) { - this.httpClient = httpClient; - // It's good practice to initialize collections, even if empty initially. - this.tokenCache = new ConcurrentHashMap<>(); + private final HttpClient httpClient; + private final DatabricksOAuthTokenSource cpTokenSource; + private ConcurrentHashMap sourcesCache; + + /** Caching key for {@link EndpointTokenSource}, based on endpoint and authorization details. */ + private static final class TokenSourceKey { + /** The target service endpoint URL. */ + private final String endpoint; + /** Specific authorization details (e.g., scope) for the endpoint. */ + private final String authDetails; + + /** + * Constructs a TokenSourceKey. + * + * @param endpoint The target service endpoint URL. + * @param authDetails Specific authorization details. + */ + public TokenSourceKey(String endpoint, String authDetails) { + this.endpoint = endpoint; + this.authDetails = authDetails; } + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + TokenSourceKey that = (TokenSourceKey) o; + return Objects.equals(endpoint, that.endpoint) + && Objects.equals(authDetails, that.authDetails); + } + + @Override + public int hashCode() { + return Objects.hash(endpoint, authDetails); + } + } + + /** + * Constructs a DataPlaneTokenSource. + * + * @param httpClient The {@link HttpClient} for token requests. + * @param cpTokenSource The {@link DatabricksOAuthTokenSource} for control plane tokens. + */ + public DataPlaneTokenSource(HttpClient httpClient, DatabricksOAuthTokenSource cpTokenSource) { + this.httpClient = httpClient; + this.cpTokenSource = cpTokenSource; + this.sourcesCache = new ConcurrentHashMap<>(); + } + + /** + * Retrieves a token for the specified endpoint and authorization details. It uses a cached {@link + * EndpointTokenSource} if available, otherwise creates and caches a new one. + * + * @param endpoint The target data plane service endpoint. + * @param authDetails Authorization details for the endpoint. + * @return The dataplane {@link Token}. + */ + public Token getToken(String endpoint, String authDetails) { + TokenSourceKey key = new TokenSourceKey(endpoint, authDetails); + + EndpointTokenSource specificSource = + sourcesCache.computeIfAbsent( + key, k -> new EndpointTokenSource(this.cpTokenSource, k.authDetails, this.httpClient)); + + return specificSource.getToken(); + } } diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java index e642159c0..2a3192b39 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java @@ -1,12 +1,8 @@ package com.databricks.sdk.core.oauth; import com.databricks.sdk.core.DatabricksException; -import com.databricks.sdk.core.http.FormRequest; import com.databricks.sdk.core.http.HttpClient; -import com.databricks.sdk.core.http.Response; -import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.base.Strings; -import java.io.IOException; import java.time.LocalDateTime; import java.util.HashMap; import java.util.Map; @@ -44,8 +40,6 @@ public class DatabricksOAuthTokenSource extends RefreshableTokenSource { private static final String SCOPE_PARAM = "scope"; private static final String CLIENT_ID_PARAM = "client_id"; - private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); - private DatabricksOAuthTokenSource(Builder builder) { this.clientId = builder.clientId; this.host = builder.host; @@ -172,47 +166,22 @@ public Token refresh() { params.put(SCOPE_PARAM, SCOPE); params.put(CLIENT_ID_PARAM, clientId); - Response rawResponse; - try { - rawResponse = httpClient.execute(new FormRequest(endpoints.getTokenEndpoint(), params)); - } catch (IOException e) { - LOG.error( - "Failed to exchange ID token for access token at {}: {}", - endpoints.getTokenEndpoint(), - e.getMessage(), - e); - throw new DatabricksException( - String.format( - "Failed to exchange ID token for access token at %s: %s", - endpoints.getTokenEndpoint(), e.getMessage()), - e); - } - OAuthResponse response; try { - response = OBJECT_MAPPER.readValue(rawResponse.getBody(), OAuthResponse.class); - } catch (IOException e) { + response = + TokenEndpointClient.requestToken(this.httpClient, endpoints.getTokenEndpoint(), params); + } catch (DatabricksException e) { + // Log specifically for DatabricksOAuthTokenSource context if needed, + // or rely on TokenEndpointClient's logging and just rethrow. LOG.error( - "Failed to parse OAuth response from token endpoint {}: {}", + "OAuth token exchange failed for client ID '{}' at {}: {}", + this.clientId, endpoints.getTokenEndpoint(), e.getMessage(), e); - throw new DatabricksException( - String.format( - "Failed to parse OAuth response from token endpoint %s: %s", - endpoints.getTokenEndpoint(), e.getMessage())); + throw e; // Re-throw the exception from TokenEndpointClient } - if (response.getErrorCode() != null) { - LOG.error( - "Token exchange failed with error: {} - {}", - response.getErrorCode(), - response.getErrorSummary()); - throw new IllegalArgumentException( - String.format( - "Token exchange failed with error: %s - %s", - response.getErrorCode(), response.getErrorSummary())); - } LocalDateTime expiry = LocalDateTime.now().plusSeconds(response.getExpiresIn()); return new Token( response.getAccessToken(), response.getTokenType(), response.getRefreshToken(), expiry); diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/EndpointTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/EndpointTokenSource.java index 9f42a492b..7e8702021 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/EndpointTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/EndpointTokenSource.java @@ -8,74 +8,105 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +/** + * Represents a token source that exchanges a control plane token for an endpoint-specific dataplane + * token. It utilizes an underlying {@link DatabricksOAuthTokenSource} to obtain the initial control + * plane token. + */ public class EndpointTokenSource extends RefreshableTokenSource { - private static final Logger LOG = LoggerFactory.getLogger(EndpointTokenSource.class); - private static final String JWT_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:jwt-bearer"; - private static final String GRANT_TYPE_PARAM = "grant_type"; - private static final String AUTHORIZATION_DETAILS_PARAM = "authorization_details"; - private static final String ASSERTION_PARAM = "assertion"; + private static final Logger LOG = LoggerFactory.getLogger(EndpointTokenSource.class); + private static final String JWT_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:jwt-bearer"; + private static final String GRANT_TYPE_PARAM = "grant_type"; + private static final String AUTHORIZATION_DETAILS_PARAM = "authorization_details"; + private static final String ASSERTION_PARAM = "assertion"; + private static final String TOKEN_ENDPOINT = "/oidc/v1/token"; - private final DatabricksOAuthTokenSource cpTokenSource; - private final String authDetails; - private final String tokenEndpoint; - private final HttpClient httpClient; + private final DatabricksOAuthTokenSource cpTokenSource; + private final String authDetails; + private final HttpClient httpClient; - public EndpointTokenSource( - DatabricksOAuthTokenSource cpTokenSource, - String authDetails, - HttpClient httpClient, - String tokenEndpoint) { + /** + * Constructs a new EndpointTokenSource. + * + * @param cpTokenSource The {@link DatabricksOAuthTokenSource} used to obtain the control plane + * token. Cannot be null. + * @param authDetails The authorization details required for the token exchange. Cannot be null or + * empty. + * @param httpClient The {@link HttpClient} used to make the token exchange request. Cannot be + * null. + * @throws IllegalArgumentException if any of the parameters are null or invalid. + */ + public EndpointTokenSource( + DatabricksOAuthTokenSource cpTokenSource, String authDetails, HttpClient httpClient) { - validate(cpTokenSource, "ControlPlaneTokenSource"); - validate(authDetails, "AuthDetails"); - validate(httpClient, "HttpClient"); - validate(tokenEndpoint, "TokenEndpoint"); + validate(cpTokenSource, "ControlPlaneTokenSource"); + validate(authDetails, "AuthDetails"); + validate(httpClient, "HttpClient"); - this.cpTokenSource = cpTokenSource; - this.authDetails = authDetails; - this.tokenEndpoint = tokenEndpoint; - this.httpClient = httpClient; - } + this.cpTokenSource = cpTokenSource; + this.authDetails = authDetails; + this.httpClient = httpClient; + } - private static void validate(Object value, String fieldName) { - if (value == null) { - LOG.error("Required parameter '{}' is null", fieldName); - throw new IllegalArgumentException( - String.format("Required parameter '%s' cannot be null", fieldName)); - } - if (value instanceof String && ((String) value).isEmpty()) { - LOG.error("Required parameter '{}' is empty", fieldName); - throw new IllegalArgumentException( - String.format("Required parameter '%s' cannot be empty", fieldName)); - } + /** + * Validates that a given parameter value is not null, and if it's a String, not empty. This is a + * helper method used to ensure constructor and method arguments meet preconditions. + * + * @param value The parameter value to validate. + * @param fieldName The name of the parameter, used for logging and error messages. + * @throws IllegalArgumentException if the value is null, or if it is a String and is empty. + */ + private static void validate(Object value, String fieldName) { + if (value == null) { + LOG.error("Required parameter '{}' is null", fieldName); + throw new IllegalArgumentException( + String.format("Required parameter '%s' cannot be null", fieldName)); + } + if (value instanceof String && ((String) value).isEmpty()) { + LOG.error("Required parameter '{}' is empty", fieldName); + throw new IllegalArgumentException( + String.format("Required parameter '%s' cannot be empty", fieldName)); } + } - @Override - protected Token refresh() { - Token cpToken = cpTokenSource.getToken(); - - Map params = new HashMap<>(); - params.put(GRANT_TYPE_PARAM, JWT_GRANT_TYPE); - params.put(AUTHORIZATION_DETAILS_PARAM, authDetails); - params.put(ASSERTION_PARAM, cpToken.getAccessToken()); - - OAuthResponse oauthResponse; - try { - oauthResponse = TokenEndpointClient.requestToken(this.httpClient, this.tokenEndpoint, params); - } catch (DatabricksException e) { - LOG.error( - "Failed to fetch token for endpoint source using {}: {}", - this.tokenEndpoint, - e.getMessage(), - e); - throw e; - } + /** + * Fetches an endpoint-specific dataplane token by exchanging a control plane token. + * + *

This method first obtains a control plane token from the configured {@code cpTokenSource}. + * It then uses this token (as an assertion) along with the provided {@code authDetails} to + * request a new, more scoped dataplane token from the Databricks OAuth token endpoint ({@value + * #TOKEN_ENDPOINT}). + * + * @return A new {@link Token} containing the exchanged dataplane access token, its type, any + * accompanying refresh token, and its expiry time. + * @throws DatabricksException if the token exchange with the OAuth endpoint fails. + */ + @Override + protected Token refresh() { + Token cpToken = cpTokenSource.getToken(); - LocalDateTime expiry = LocalDateTime.now().plusSeconds(oauthResponse.getExpiresIn()); - return new Token( - oauthResponse.getAccessToken(), - oauthResponse.getTokenType(), - oauthResponse.getRefreshToken(), - expiry); + Map params = new HashMap<>(); + params.put(GRANT_TYPE_PARAM, JWT_GRANT_TYPE); + params.put(AUTHORIZATION_DETAILS_PARAM, authDetails); + params.put(ASSERTION_PARAM, cpToken.getAccessToken()); + + OAuthResponse oauthResponse; + try { + oauthResponse = TokenEndpointClient.requestToken(this.httpClient, TOKEN_ENDPOINT, params); + } catch (DatabricksException e) { + LOG.error( + "Failed to fetch token for endpoint source using {}: {}", + TOKEN_ENDPOINT, + e.getMessage(), + e); + throw e; } + + LocalDateTime expiry = LocalDateTime.now().plusSeconds(oauthResponse.getExpiresIn()); + return new Token( + oauthResponse.getAccessToken(), + oauthResponse.getTokenType(), + oauthResponse.getRefreshToken(), + expiry); + } } diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/TokenEndpointClient.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/TokenEndpointClient.java index dc09246da..42dfe4210 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/TokenEndpointClient.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/TokenEndpointClient.java @@ -10,69 +10,85 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +/** + * Client for interacting with an OAuth token endpoint. + * + *

This class provides a method to request an OAuth token from a specified token endpoint URL + * using the provided HTTP client and request parameters. It handles the HTTP request and parses the + * JSON response into an {@link OAuthResponse} object. + */ public final class TokenEndpointClient { - private static final Logger LOG = LoggerFactory.getLogger(TokenEndpointClient.class); - private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + private static final Logger LOG = LoggerFactory.getLogger(TokenEndpointClient.class); + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); - private TokenEndpointClient() {} + private TokenEndpointClient() {} - public static OAuthResponse requestToken( - HttpClient httpClient, String tokenEndpointUrl, Map params) - throws DatabricksException { - if (httpClient == null) { - LOG.error("HttpClient cannot be null for requestToken"); - throw new IllegalArgumentException("HttpClient cannot be null"); - } - if (tokenEndpointUrl == null || tokenEndpointUrl.isEmpty()) { - LOG.error("Token endpoint URL cannot be null or empty"); - throw new IllegalArgumentException("Token endpoint URL cannot be null or empty"); - } - if (params == null) { - LOG.error("Request parameters map cannot be null"); - throw new IllegalArgumentException("Request parameters map cannot be null"); - } + /** + * Requests an OAuth token from the specified token endpoint. + * + * @param httpClient The {@link HttpClient} to use for making the request. + * @param tokenEndpointUrl The URL of the token endpoint. + * @param params A map of parameters to include in the token request. + * @return An {@link OAuthResponse} containing the token information. + * @throws DatabricksException if an error occurs during the token request or response parsing. + * @throws IllegalArgumentException if any of the required parameters are null or empty. + */ + public static OAuthResponse requestToken( + HttpClient httpClient, String tokenEndpointUrl, Map params) + throws DatabricksException { + if (httpClient == null) { + LOG.error("HttpClient cannot be null for requestToken"); + throw new IllegalArgumentException("HttpClient cannot be null"); + } + if (tokenEndpointUrl == null || tokenEndpointUrl.isEmpty()) { + LOG.error("Token endpoint URL cannot be null or empty"); + throw new IllegalArgumentException("Token endpoint URL cannot be null or empty"); + } + if (params == null) { + LOG.error("Request parameters map cannot be null"); + throw new IllegalArgumentException("Request parameters map cannot be null"); + } - Response rawResponse; - try { - LOG.debug("Requesting token from endpoint: {} via static client method", tokenEndpointUrl); - rawResponse = httpClient.execute(new FormRequest(tokenEndpointUrl, params)); - } catch (IOException e) { - LOG.error( - "Failed to request token from {}: {}", tokenEndpointUrl, e.getMessage(), e); - throw new DatabricksException( - String.format("Failed to request token from %s: %s", tokenEndpointUrl, e.getMessage()), - e); - } + Response rawResponse; + try { + LOG.debug("Requesting token from endpoint: {} via static client method", tokenEndpointUrl); + rawResponse = httpClient.execute(new FormRequest(tokenEndpointUrl, params)); + } catch (IOException e) { + LOG.error("Failed to request token from {}: {}", tokenEndpointUrl, e.getMessage(), e); + throw new DatabricksException( + String.format("Failed to request token from %s: %s", tokenEndpointUrl, e.getMessage()), + e); + } - OAuthResponse response; - try { - response = OBJECT_MAPPER.readValue(rawResponse.getBody(), OAuthResponse.class); - } catch (IOException e) { - LOG.error( - "Failed to parse OAuth response from token endpoint {}: {}", - tokenEndpointUrl, - e.getMessage(), - e); - throw new DatabricksException( - String.format( - "Failed to parse OAuth response from token endpoint %s: %s", - tokenEndpointUrl, e.getMessage()), - e); - } + OAuthResponse response; + try { + response = OBJECT_MAPPER.readValue(rawResponse.getBody(), OAuthResponse.class); + } catch (IOException e) { + LOG.error( + "Failed to parse OAuth response from token endpoint {}: {}", + tokenEndpointUrl, + e.getMessage(), + e); + throw new DatabricksException( + String.format( + "Failed to parse OAuth response from token endpoint %s: %s", + tokenEndpointUrl, e.getMessage()), + e); + } - if (response.getErrorCode() != null) { - String errorSummary = response.getErrorSummary() != null ? response.getErrorSummary() : "No summary provided."; - LOG.error( - "Token request to {} failed with error: {} - {}", - tokenEndpointUrl, - response.getErrorCode(), - errorSummary); - throw new DatabricksException( - String.format( - "Token request failed with error: %s - %s", - response.getErrorCode(), errorSummary)); - } - LOG.debug("Successfully obtained token response from {}", tokenEndpointUrl); - return response; + if (response.getErrorCode() != null) { + String errorSummary = + response.getErrorSummary() != null ? response.getErrorSummary() : "No summary provided."; + LOG.error( + "Token request to {} failed with error: {} - {}", + tokenEndpointUrl, + response.getErrorCode(), + errorSummary); + throw new DatabricksException( + String.format( + "Token request failed with error: %s - %s", response.getErrorCode(), errorSummary)); } -} \ No newline at end of file + LOG.debug("Successfully obtained token response from {}", tokenEndpointUrl); + return response; + } +} diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSourceTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSourceTest.java index 8d7da8d3a..6f1897fda 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSourceTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSourceTest.java @@ -161,7 +161,7 @@ private static Stream provideTestCases() { 400, errorJson, createMockHttpClient(expectedRequest, 400, errorJson), - IllegalArgumentException.class), + DatabricksException.class), new TestCase( "Network error during token exchange", null, From dedebecc208aadba2477191c42d704ad4727990c Mon Sep 17 00:00:00 2001 From: emmyzhou-db Date: Fri, 16 May 2025 18:11:15 +0000 Subject: [PATCH 4/5] Add unit tests --- .../sdk/core/oauth/DataPlaneTokenSource.java | 7 + .../sdk/core/oauth/EndpointTokenSource.java | 49 ++--- .../sdk/core/oauth/TokenEndpointClient.java | 22 +-- .../core/oauth/DataPlaneTokenSourceTest.java | 182 +++++++++++++++++ .../core/oauth/EndpointTokenSourceTest.java | 187 ++++++++++++++++++ .../core/oauth/TokenEndpointClientTest.java | 175 ++++++++++++++++ 6 files changed, 575 insertions(+), 47 deletions(-) create mode 100644 databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DataPlaneTokenSourceTest.java create mode 100644 databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/EndpointTokenSourceTest.java create mode 100644 databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/TokenEndpointClientTest.java diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DataPlaneTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DataPlaneTokenSource.java index fee01b7dc..31d67149c 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DataPlaneTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DataPlaneTokenSource.java @@ -1,6 +1,7 @@ package com.databricks.sdk.core.oauth; import com.databricks.sdk.core.http.HttpClient; +import com.google.common.base.Strings; import java.util.Objects; import java.util.concurrent.ConcurrentHashMap; @@ -75,6 +76,12 @@ public DataPlaneTokenSource(HttpClient httpClient, DatabricksOAuthTokenSource cp * @return The dataplane {@link Token}. */ public Token getToken(String endpoint, String authDetails) { + if (Strings.isNullOrEmpty(endpoint)) { + throw new IllegalArgumentException("Endpoint must not be null or empty"); + } + if (Strings.isNullOrEmpty(authDetails)) { + throw new IllegalArgumentException("AuthDetails must not be null or empty"); + } TokenSourceKey key = new TokenSourceKey(endpoint, authDetails); EndpointTokenSource specificSource = diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/EndpointTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/EndpointTokenSource.java index 7e8702021..0eed6a43c 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/EndpointTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/EndpointTokenSource.java @@ -5,6 +5,7 @@ import java.time.LocalDateTime; import java.util.HashMap; import java.util.Map; +import java.util.Objects; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -29,44 +30,21 @@ public class EndpointTokenSource extends RefreshableTokenSource { * Constructs a new EndpointTokenSource. * * @param cpTokenSource The {@link DatabricksOAuthTokenSource} used to obtain the control plane - * token. Cannot be null. - * @param authDetails The authorization details required for the token exchange. Cannot be null or - * empty. - * @param httpClient The {@link HttpClient} used to make the token exchange request. Cannot be - * null. - * @throws IllegalArgumentException if any of the parameters are null or invalid. + * token. + * @param authDetails The authorization details required for the token exchange. + * @param httpClient The {@link HttpClient} used to make the token exchange request. + * @throws IllegalArgumentException if authDetails is empty. + * @throws NullPointerException if any of the parameters are null. */ public EndpointTokenSource( DatabricksOAuthTokenSource cpTokenSource, String authDetails, HttpClient httpClient) { - validate(cpTokenSource, "ControlPlaneTokenSource"); - validate(authDetails, "AuthDetails"); - validate(httpClient, "HttpClient"); - - this.cpTokenSource = cpTokenSource; - this.authDetails = authDetails; - this.httpClient = httpClient; - } - - /** - * Validates that a given parameter value is not null, and if it's a String, not empty. This is a - * helper method used to ensure constructor and method arguments meet preconditions. - * - * @param value The parameter value to validate. - * @param fieldName The name of the parameter, used for logging and error messages. - * @throws IllegalArgumentException if the value is null, or if it is a String and is empty. - */ - private static void validate(Object value, String fieldName) { - if (value == null) { - LOG.error("Required parameter '{}' is null", fieldName); - throw new IllegalArgumentException( - String.format("Required parameter '%s' cannot be null", fieldName)); - } - if (value instanceof String && ((String) value).isEmpty()) { - LOG.error("Required parameter '{}' is empty", fieldName); - throw new IllegalArgumentException( - String.format("Required parameter '%s' cannot be empty", fieldName)); + this.cpTokenSource = Objects.requireNonNull(cpTokenSource, "ControlPlaneTokenSource cannot be null"); + this.authDetails = Objects.requireNonNull(authDetails, "AuthDetails cannot be null"); + if (authDetails.isEmpty()) { + throw new IllegalArgumentException("AuthDetails cannot be empty"); } + this.httpClient = Objects.requireNonNull(httpClient, "HttpClient cannot be null"); } /** @@ -80,6 +58,7 @@ private static void validate(Object value, String fieldName) { * @return A new {@link Token} containing the exchanged dataplane access token, its type, any * accompanying refresh token, and its expiry time. * @throws DatabricksException if the token exchange with the OAuth endpoint fails. + * @throws IllegalArgumentException if the control pl */ @Override protected Token refresh() { @@ -93,9 +72,9 @@ protected Token refresh() { OAuthResponse oauthResponse; try { oauthResponse = TokenEndpointClient.requestToken(this.httpClient, TOKEN_ENDPOINT, params); - } catch (DatabricksException e) { + } catch (DatabricksException | IllegalArgumentException e) { LOG.error( - "Failed to fetch token for endpoint source using {}: {}", + "Failed to fetch dataplane token for endpoint source using {}: {}", TOKEN_ENDPOINT, e.getMessage(), e); diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/TokenEndpointClient.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/TokenEndpointClient.java index 42dfe4210..1e3425b4c 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/TokenEndpointClient.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/TokenEndpointClient.java @@ -5,8 +5,10 @@ import com.databricks.sdk.core.http.HttpClient; import com.databricks.sdk.core.http.Response; import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.base.Strings; import java.io.IOException; import java.util.Map; +import java.util.Objects; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -36,22 +38,18 @@ private TokenEndpointClient() {} public static OAuthResponse requestToken( HttpClient httpClient, String tokenEndpointUrl, Map params) throws DatabricksException { - if (httpClient == null) { - LOG.error("HttpClient cannot be null for requestToken"); - throw new IllegalArgumentException("HttpClient cannot be null"); - } - if (tokenEndpointUrl == null || tokenEndpointUrl.isEmpty()) { - LOG.error("Token endpoint URL cannot be null or empty"); - throw new IllegalArgumentException("Token endpoint URL cannot be null or empty"); - } - if (params == null) { - LOG.error("Request parameters map cannot be null"); - throw new IllegalArgumentException("Request parameters map cannot be null"); + Objects.requireNonNull(httpClient, "HttpClient cannot be null"); + Objects.requireNonNull(params, "Request parameters map cannot be null"); + Objects.requireNonNull(tokenEndpointUrl, "Token endpoint URL cannot be null"); + + if (tokenEndpointUrl.isEmpty()) { + LOG.error("Token endpoint URL cannot be empty"); + throw new IllegalArgumentException("Token endpoint URL cannot be empty"); } Response rawResponse; try { - LOG.debug("Requesting token from endpoint: {} via static client method", tokenEndpointUrl); + LOG.debug("Requesting token from endpoint: {}", tokenEndpointUrl); rawResponse = httpClient.execute(new FormRequest(tokenEndpointUrl, params)); } catch (IOException e) { LOG.error("Failed to request token from {}: {}", tokenEndpointUrl, e.getMessage(), e); diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DataPlaneTokenSourceTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DataPlaneTokenSourceTest.java new file mode 100644 index 000000000..90e9a1ee6 --- /dev/null +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DataPlaneTokenSourceTest.java @@ -0,0 +1,182 @@ +package com.databricks.sdk.core.oauth; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; + +import com.databricks.sdk.core.http.HttpClient; +import com.databricks.sdk.core.http.Response; +import java.io.IOException; +import java.net.URL; +import java.time.LocalDateTime; +import java.util.stream.Stream; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.Mockito; + +public class DataPlaneTokenSourceTest { + private static final String TEST_ENDPOINT_1 = "https://endpoint1.databricks.com/"; + private static final String TEST_ENDPOINT_2 = "https://endpoint2.databricks.com/"; + private static final String TEST_AUTH_DETAILS_1 = "{\"aud\":\"aud1\"}"; + private static final String TEST_AUTH_DETAILS_2 = "{\"aud\":\"aud2\"}"; + private static final String TEST_CP_TOKEN = "cp-access-token"; + private static final String TEST_TOKEN_TYPE = "Bearer"; + private static final String TEST_REFRESH_TOKEN = "refresh-token"; + private static final int TEST_EXPIRES_IN = 3600; + + private static Stream provideDataPlaneTokenScenarios() throws Exception { + // Mock DatabricksOAuthTokenSource for control plane token + Token cpToken = new Token(TEST_CP_TOKEN, TEST_TOKEN_TYPE, null, LocalDateTime.now().plusSeconds(600)); + DatabricksOAuthTokenSource mockCpTokenSource = mock(DatabricksOAuthTokenSource.class); + when(mockCpTokenSource.getToken()).thenReturn(cpToken); + + // --- Mock HttpClient for different scenarios --- + // Success JSON for endpoint1/auth1 + String successJson1 = "{" + + "\"access_token\":\"dp-access-token1\"," + + "\"token_type\":\"Bearer\"," + + "\"refresh_token\":\"refresh-token\"," + + "\"expires_in\":3600" + + "}"; + HttpClient mockSuccessClient1 = mock(HttpClient.class); + when(mockSuccessClient1.execute(any())).thenReturn( + new Response(successJson1, 200, "OK", new URL(TEST_ENDPOINT_1)) + ); + + // Success JSON for endpoint2/auth2 + String successJson2 = "{" + + "\"access_token\":\"dp-access-token2\"," + + "\"token_type\":\"Bearer\"," + + "\"refresh_token\":\"refresh-token\"," + + "\"expires_in\":3600" + + "}"; + HttpClient mockSuccessClient2 = mock(HttpClient.class); + when(mockSuccessClient2.execute(any())).thenReturn( + new Response(successJson2, 200, "OK", new URL(TEST_ENDPOINT_2)) + ); + + // Error response JSON + String errorJson = "{" + + "\"error\":\"invalid_request\"," + + "\"error_description\":\"Bad request\"" + + "}"; + HttpClient mockErrorClient = mock(HttpClient.class); + when(mockErrorClient.execute(any())).thenReturn( + new Response(errorJson, 400, "Bad Request", new URL(TEST_ENDPOINT_1)) + ); + + // IOException scenario + HttpClient mockIOExceptionClient = mock(HttpClient.class); + when(mockIOExceptionClient.execute(any())).thenThrow(new IOException("Network error")); + + // For null cpTokenSource + DatabricksOAuthTokenSource nullCpTokenSource = null; + + // For null httpClient + HttpClient nullHttpClient = null; + + // For null/empty endpoint or authDetails + return Stream.of( + Arguments.of( + "Success: endpoint1/auth1", + TEST_ENDPOINT_1, + TEST_AUTH_DETAILS_1, + mockSuccessClient1, + mockCpTokenSource, + new Token("dp-access-token1", TEST_TOKEN_TYPE, TEST_REFRESH_TOKEN, LocalDateTime.now().plusSeconds(TEST_EXPIRES_IN)), + null // No exception + ), + Arguments.of( + "Success: endpoint2/auth2 (different cache key)", + TEST_ENDPOINT_2, + TEST_AUTH_DETAILS_2, + mockSuccessClient2, + mockCpTokenSource, + new Token("dp-access-token2", TEST_TOKEN_TYPE, TEST_REFRESH_TOKEN, LocalDateTime.now().plusSeconds(TEST_EXPIRES_IN)), + null + ), + Arguments.of( + "Error response from endpoint", + TEST_ENDPOINT_1, + TEST_AUTH_DETAILS_1, + mockErrorClient, + mockCpTokenSource, + null, + com.databricks.sdk.core.DatabricksException.class + ), + Arguments.of( + "IOException from HttpClient", + TEST_ENDPOINT_1, + TEST_AUTH_DETAILS_1, + mockIOExceptionClient, + mockCpTokenSource, + null, + com.databricks.sdk.core.DatabricksException.class + ), + Arguments.of( + "Null cpTokenSource", + TEST_ENDPOINT_1, + TEST_AUTH_DETAILS_1, + mockSuccessClient1, + nullCpTokenSource, + null, + IllegalArgumentException.class + ), + Arguments.of( + "Null httpClient", + TEST_ENDPOINT_1, + TEST_AUTH_DETAILS_1, + nullHttpClient, + mockCpTokenSource, + null, + IllegalArgumentException.class + ), + Arguments.of( + "Null endpoint", + null, + TEST_AUTH_DETAILS_1, + mockSuccessClient1, + mockCpTokenSource, + null, + IllegalArgumentException.class + ), + Arguments.of( + "Null authDetails", + TEST_ENDPOINT_1, + null, + mockSuccessClient1, + mockCpTokenSource, + null, + IllegalArgumentException.class + ) + ); + } + + @ParameterizedTest(name = "{0}") + @MethodSource("provideDataPlaneTokenScenarios") + void testDataPlaneTokenSource( + String testName, + String endpoint, + String authDetails, + HttpClient httpClient, + DatabricksOAuthTokenSource cpTokenSource, + Token expectedToken, + Class expectedException + ) { + if (expectedException != null) { + assertThrows(expectedException, () -> { + DataPlaneTokenSource source = new DataPlaneTokenSource(httpClient, cpTokenSource); + source.getToken(endpoint, authDetails); + }); + } else { + DataPlaneTokenSource source = new DataPlaneTokenSource(httpClient, cpTokenSource); + Token token = source.getToken(endpoint, authDetails); + assertNotNull(token); + assertEquals(expectedToken.getAccessToken(), token.getAccessToken()); + assertEquals(expectedToken.getTokenType(), token.getTokenType()); + assertEquals(expectedToken.getRefreshToken(), token.getRefreshToken()); + assertTrue(token.isValid()); + } + } +} diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/EndpointTokenSourceTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/EndpointTokenSourceTest.java new file mode 100644 index 000000000..3e842fd37 --- /dev/null +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/EndpointTokenSourceTest.java @@ -0,0 +1,187 @@ +package com.databricks.sdk.core.oauth; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; + +import com.databricks.sdk.core.DatabricksException; +import com.databricks.sdk.core.http.HttpClient; +import com.databricks.sdk.core.http.Response; +import java.io.IOException; +import java.net.URL; +import java.time.LocalDateTime; +import java.util.HashMap; +import java.util.Map; +import java.util.stream.Stream; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.Mockito; + +class EndpointTokenSourceTest { + private static final String TEST_AUTH_DETAILS = "{\"aud\":\"test-audience\"}"; + private static final String TEST_CP_TOKEN = "cp-access-token"; + private static final String TEST_DP_TOKEN = "dp-access-token"; + private static final String TEST_TOKEN_TYPE = "Bearer"; + private static final String TEST_REFRESH_TOKEN = "refresh-token"; + private static final int TEST_EXPIRES_IN = 3600; + private static final String TOKEN_ENDPOINT = "/oidc/v1/token"; + + private static Stream provideEndpointTokenScenarios() throws Exception { + // Success response JSON + String successJson = "{" + + "\"access_token\":\"" + TEST_DP_TOKEN + "\"," + + "\"token_type\":\"" + TEST_TOKEN_TYPE + "\"," + + "\"expires_in\":" + TEST_EXPIRES_IN + "," + + "\"refresh_token\":\"" + TEST_REFRESH_TOKEN + "\"}"; + // Error response JSON + String errorJson = "{" + + "\"error\":\"invalid_client\"," + + "\"error_description\":\"Client authentication failed\"}"; + // Malformed JSON + String malformedJson = "{not valid json}"; + + // Mock DatabricksOAuthTokenSource for control plane token + Token cpToken = new Token(TEST_CP_TOKEN, TEST_TOKEN_TYPE, LocalDateTime.now().plusMinutes(10)); + DatabricksOAuthTokenSource mockCpTokenSource = mock(DatabricksOAuthTokenSource.class); + when(mockCpTokenSource.getToken()).thenReturn(cpToken); + + // Mock HttpClient for success + HttpClient mockSuccessClient = mock(HttpClient.class); + when(mockSuccessClient.execute(any())).thenReturn(new Response(successJson, 200, "OK", new URL("https://test.databricks.com/"))); + + // Mock HttpClient for error response + HttpClient mockErrorClient = mock(HttpClient.class); + when(mockErrorClient.execute(any())).thenReturn(new Response(errorJson, 400, "Bad Request", new URL("https://test.databricks.com/"))); + + // Mock HttpClient for malformed JSON + HttpClient mockMalformedClient = mock(HttpClient.class); + when(mockMalformedClient.execute(any())).thenReturn(new Response(malformedJson, 200, "OK", new URL("https://test.databricks.com/"))); + + // Mock HttpClient for IOException + HttpClient mockIOExceptionClient = mock(HttpClient.class); + when(mockIOExceptionClient.execute(any())).thenThrow(new IOException("Network error")); + + return Stream.of( + Arguments.of( + "Success response", + mockCpTokenSource, + TEST_AUTH_DETAILS, + mockSuccessClient, + null, // No exception expected + TEST_DP_TOKEN, + TEST_TOKEN_TYPE, + TEST_REFRESH_TOKEN, + TEST_EXPIRES_IN + ), + Arguments.of( + "OAuth error response", + mockCpTokenSource, + TEST_AUTH_DETAILS, + mockErrorClient, + DatabricksException.class, + null, + null, + null, + 0 + ), + Arguments.of( + "Malformed JSON response", + mockCpTokenSource, + TEST_AUTH_DETAILS, + mockMalformedClient, + DatabricksException.class, + null, + null, + null, + 0 + ), + Arguments.of( + "IOException from HttpClient", + mockCpTokenSource, + TEST_AUTH_DETAILS, + mockIOExceptionClient, + DatabricksException.class, + null, + null, + null, + 0 + ), + Arguments.of( + "Null cpTokenSource", + null, + TEST_AUTH_DETAILS, + mockSuccessClient, + IllegalArgumentException.class, + null, + null, + null, + 0 + ), + Arguments.of( + "Null authDetails", + mockCpTokenSource, + null, + mockSuccessClient, + IllegalArgumentException.class, + null, + null, + null, + 0 + ), + Arguments.of( + "Empty authDetails", + mockCpTokenSource, + "", + mockSuccessClient, + IllegalArgumentException.class, + null, + null, + null, + 0 + ), + Arguments.of( + "Null httpClient", + mockCpTokenSource, + TEST_AUTH_DETAILS, + null, + IllegalArgumentException.class, + null, + null, + null, + 0 + ) + ); + } + + @ParameterizedTest(name = "{0}") + @MethodSource("provideEndpointTokenScenarios") + void testEndpointTokenSource( + String testName, + DatabricksOAuthTokenSource cpTokenSource, + String authDetails, + HttpClient httpClient, + Class expectedException, + String expectedAccessToken, + String expectedTokenType, + String expectedRefreshToken, + int expectedExpiresIn + ) { + if (expectedException != null) { + assertThrows(expectedException, () -> { + EndpointTokenSource source = new EndpointTokenSource(cpTokenSource, authDetails, httpClient); + source.getToken(); + }); + } else { + EndpointTokenSource source = new EndpointTokenSource(cpTokenSource, authDetails, httpClient); + Token token = source.getToken(); + assertNotNull(token); + assertEquals(expectedAccessToken, token.getAccessToken()); + assertEquals(expectedTokenType, token.getTokenType()); + assertEquals(expectedRefreshToken, token.getRefreshToken()); + // Allow a few seconds of clock skew for expiry + assertTrue(token.isValid()); + assertTrue(token.getAccessToken().length() > 0); + } + } +} diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/TokenEndpointClientTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/TokenEndpointClientTest.java new file mode 100644 index 000000000..0b07d117b --- /dev/null +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/TokenEndpointClientTest.java @@ -0,0 +1,175 @@ +package com.databricks.sdk.core.oauth; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; + +import com.databricks.sdk.core.DatabricksException; +import com.databricks.sdk.core.http.FormRequest; +import com.databricks.sdk.core.http.HttpClient; +import com.databricks.sdk.core.http.Response; +import java.io.IOException; +import java.net.URL; +import java.util.HashMap; +import java.util.Map; +import java.util.stream.Stream; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +class TokenEndpointClientTest { + private static final String TOKEN_ENDPOINT_URL = "https://test.databricks.com/oauth/token"; + private static final Map PARAMS = new HashMap<>(); + + private static Stream provideTokenScenarios() throws Exception { + // Success response JSON + String successJson = "{" + + "\"access_token\":\"test-access-token\"," + + "\"token_type\":\"Bearer\"," + + "\"expires_in\":3600," + + "\"refresh_token\":\"test-refresh-token\"}"; + // Error response JSON + String errorJson = "{" + + "\"error\":\"invalid_client\"," + + "\"error_description\":\"Client authentication failed\"}"; + // Malformed JSON + String malformedJson = "{not valid json}"; + + // Mock HttpClient for success + HttpClient mockSuccessClient = mock(HttpClient.class); + when(mockSuccessClient.execute(any(FormRequest.class))) + .thenReturn(new Response(successJson, 200, "OK", new URL("https://test.databricks.com/"))); + + // Mock HttpClient for error response + HttpClient mockErrorClient = mock(HttpClient.class); + when(mockErrorClient.execute(any(FormRequest.class))) + .thenReturn(new Response(errorJson, 400, "Bad Request", new URL("https://test.databricks.com/"))); + + // Mock HttpClient for malformed JSON + HttpClient mockMalformedClient = mock(HttpClient.class); + when(mockMalformedClient.execute(any(FormRequest.class))) + .thenReturn(new Response(malformedJson, 200, "OK", new URL("https://test.databricks.com/"))); + + // Mock HttpClient for IOException + HttpClient mockIOExceptionClient = mock(HttpClient.class); + when(mockIOExceptionClient.execute(any(FormRequest.class))) + .thenThrow(new IOException("Network error")); + + return Stream.of( + Arguments.of( + "Success response", + mockSuccessClient, + TOKEN_ENDPOINT_URL, + PARAMS, + null, // No exception expected + "test-access-token", + "Bearer", + 3600, + "test-refresh-token" + ), + Arguments.of( + "OAuth error response", + mockErrorClient, + TOKEN_ENDPOINT_URL, + PARAMS, + DatabricksException.class, + null, + null, + 0, + null + ), + Arguments.of( + "Malformed JSON response", + mockMalformedClient, + TOKEN_ENDPOINT_URL, + PARAMS, + DatabricksException.class, + null, + null, + 0, + null + ), + Arguments.of( + "IOException from HttpClient", + mockIOExceptionClient, + TOKEN_ENDPOINT_URL, + PARAMS, + DatabricksException.class, + null, + null, + 0, + null + ), + Arguments.of( + "Null HttpClient", + null, + TOKEN_ENDPOINT_URL, + PARAMS, + IllegalArgumentException.class, + null, + null, + 0, + null + ), + Arguments.of( + "Null tokenEndpointUrl", + mockSuccessClient, + null, + PARAMS, + IllegalArgumentException.class, + null, + null, + 0, + null + ), + Arguments.of( + "Empty tokenEndpointUrl", + mockSuccessClient, + "", + PARAMS, + IllegalArgumentException.class, + null, + null, + 0, + null + ), + Arguments.of( + "Null params", + mockSuccessClient, + TOKEN_ENDPOINT_URL, + null, + IllegalArgumentException.class, + null, + null, + 0, + null + ) + ); + } + + @ParameterizedTest(name = "{0}") + @MethodSource("provideTokenScenarios") + void testRequestToken( + String testName, + HttpClient httpClient, + String tokenEndpointUrl, + Map params, + Class expectedException, + String expectedAccessToken, + String expectedTokenType, + int expectedExpiresIn, + String expectedRefreshToken + ) { + if (expectedException != null) { + assertThrows(expectedException, () -> + TokenEndpointClient.requestToken(httpClient, tokenEndpointUrl, params)); + } else { + OAuthResponse response = TokenEndpointClient.requestToken(httpClient, tokenEndpointUrl, params); + assertNotNull(response); + assertEquals(expectedAccessToken, response.getAccessToken()); + assertEquals(expectedTokenType, response.getTokenType()); + assertEquals(expectedExpiresIn, response.getExpiresIn()); + assertEquals(expectedRefreshToken, response.getRefreshToken()); + } + } +} From f43141f6114a5d7bd4d5dcee4976b819f51a8af6 Mon Sep 17 00:00:00 2001 From: emmyzhou-db Date: Mon, 19 May 2025 11:10:31 +0000 Subject: [PATCH 5/5] Add unit tests and refactor --- .../sdk/core/oauth/DataPlaneTokenSource.java | 29 +- .../oauth/DatabricksOAuthTokenSource.java | 48 +-- .../sdk/core/oauth/EndpointTokenSource.java | 21 +- .../sdk/core/oauth/TokenEndpointClient.java | 7 +- .../core/oauth/DataPlaneTokenSourceTest.java | 302 +++++++------- .../oauth/DatabricksOAuthTokenSourceTest.java | 393 ++++++++---------- .../core/oauth/EndpointTokenSourceTest.java | 86 ++-- .../core/oauth/TokenEndpointClientTest.java | 66 ++- 8 files changed, 446 insertions(+), 506 deletions(-) diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DataPlaneTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DataPlaneTokenSource.java index 31d67149c..b12a92dd2 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DataPlaneTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DataPlaneTokenSource.java @@ -1,7 +1,6 @@ package com.databricks.sdk.core.oauth; import com.databricks.sdk.core.http.HttpClient; -import com.google.common.base.Strings; import java.util.Objects; import java.util.concurrent.ConcurrentHashMap; @@ -16,13 +15,17 @@ public class DataPlaneTokenSource { private final HttpClient httpClient; private final DatabricksOAuthTokenSource cpTokenSource; - private ConcurrentHashMap sourcesCache; + private final ConcurrentHashMap sourcesCache; - /** Caching key for {@link EndpointTokenSource}, based on endpoint and authorization details. */ + /** + * Caching key for {@link EndpointTokenSource}, based on endpoint and authorization details. This + * is a value object that uniquely identifies a token source configuration. + */ private static final class TokenSourceKey { /** The target service endpoint URL. */ private final String endpoint; - /** Specific authorization details (e.g., scope) for the endpoint. */ + + /** Specific authorization details for the endpoint. */ private final String authDetails; /** @@ -60,10 +63,12 @@ public int hashCode() { * * @param httpClient The {@link HttpClient} for token requests. * @param cpTokenSource The {@link DatabricksOAuthTokenSource} for control plane tokens. + * @throws NullPointerException if either parameter is null */ public DataPlaneTokenSource(HttpClient httpClient, DatabricksOAuthTokenSource cpTokenSource) { - this.httpClient = httpClient; - this.cpTokenSource = cpTokenSource; + this.httpClient = Objects.requireNonNull(httpClient, "HTTP client cannot be null"); + this.cpTokenSource = + Objects.requireNonNull(cpTokenSource, "Control plane token source cannot be null"); this.sourcesCache = new ConcurrentHashMap<>(); } @@ -74,13 +79,17 @@ public DataPlaneTokenSource(HttpClient httpClient, DatabricksOAuthTokenSource cp * @param endpoint The target data plane service endpoint. * @param authDetails Authorization details for the endpoint. * @return The dataplane {@link Token}. + * @throws NullPointerException if either parameter is null + * @throws IllegalArgumentException if either parameter is empty */ public Token getToken(String endpoint, String authDetails) { - if (Strings.isNullOrEmpty(endpoint)) { - throw new IllegalArgumentException("Endpoint must not be null or empty"); + Objects.requireNonNull(endpoint, "Data plane endpoint URL cannot be null"); + Objects.requireNonNull(authDetails, "Authorization details cannot be null"); + if (endpoint.isEmpty()) { + throw new IllegalArgumentException("Data plane endpoint URL cannot be empty"); } - if (Strings.isNullOrEmpty(authDetails)) { - throw new IllegalArgumentException("AuthDetails must not be null or empty"); + if (authDetails.isEmpty()) { + throw new IllegalArgumentException("Authorization details cannot be empty"); } TokenSourceKey key = new TokenSourceKey(endpoint, authDetails); diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java index 2a3192b39..f16ae2aed 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java @@ -6,6 +6,7 @@ import java.time.LocalDateTime; import java.util.HashMap; import java.util.Map; +import java.util.Objects; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -117,44 +118,29 @@ public DatabricksOAuthTokenSource build() { } } - /** - * Validates that a value is non-null for required fields. If the value is a string, it also - * checks that it is non-empty. - * - * @param value The value to validate. - * @param fieldName The name of the field being validated. - * @throws IllegalArgumentException when the value is null or an empty string. - */ - private static void validate(Object value, String fieldName) { - if (value == null) { - LOG.error("Required parameter '{}' is null", fieldName); - throw new IllegalArgumentException( - String.format("Required parameter '%s' cannot be null", fieldName)); - } - if (value instanceof String && ((String) value).isEmpty()) { - LOG.error("Required parameter '{}' is empty", fieldName); - throw new IllegalArgumentException( - String.format("Required parameter '%s' cannot be empty", fieldName)); - } - } - /** * Retrieves an OAuth token by exchanging an ID token. Implements the OAuth token exchange flow to * obtain an access token. * * @return A Token containing the access token and related information. * @throws DatabricksException when the token exchange fails. - * @throws IllegalArgumentException when there is an error code in the response or when required - * parameters are missing. + * @throws IllegalArgumentException when the required string parameters are empty. + * @throws NullPointerException when any of the required parameters are null. */ @Override public Token refresh() { - // Validate all required parameters - validate(clientId, "ClientID"); - validate(host, "Host"); - validate(endpoints, "Endpoints"); - validate(idTokenSource, "IDTokenSource"); - validate(httpClient, "HttpClient"); + Objects.requireNonNull(clientId, "ClientID cannot be null"); + Objects.requireNonNull(host, "Host cannot be null"); + Objects.requireNonNull(endpoints, "Endpoints cannot be null"); + Objects.requireNonNull(idTokenSource, "IDTokenSource cannot be null"); + Objects.requireNonNull(httpClient, "HttpClient cannot be null"); + + if (clientId.isEmpty()) { + throw new IllegalArgumentException("ClientID cannot be empty"); + } + if (host.isEmpty()) { + throw new IllegalArgumentException("Host cannot be empty"); + } String effectiveAudience = determineAudience(); IDToken idToken = idTokenSource.getIDToken(effectiveAudience); @@ -171,15 +157,13 @@ public Token refresh() { response = TokenEndpointClient.requestToken(this.httpClient, endpoints.getTokenEndpoint(), params); } catch (DatabricksException e) { - // Log specifically for DatabricksOAuthTokenSource context if needed, - // or rely on TokenEndpointClient's logging and just rethrow. LOG.error( "OAuth token exchange failed for client ID '{}' at {}: {}", this.clientId, endpoints.getTokenEndpoint(), e.getMessage(), e); - throw e; // Re-throw the exception from TokenEndpointClient + throw e; } LocalDateTime expiry = LocalDateTime.now().plusSeconds(response.getExpiresIn()); diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/EndpointTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/EndpointTokenSource.java index 0eed6a43c..c54e7f6c0 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/EndpointTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/EndpointTokenSource.java @@ -38,27 +38,28 @@ public class EndpointTokenSource extends RefreshableTokenSource { */ public EndpointTokenSource( DatabricksOAuthTokenSource cpTokenSource, String authDetails, HttpClient httpClient) { - - this.cpTokenSource = Objects.requireNonNull(cpTokenSource, "ControlPlaneTokenSource cannot be null"); - this.authDetails = Objects.requireNonNull(authDetails, "AuthDetails cannot be null"); + this.cpTokenSource = + Objects.requireNonNull(cpTokenSource, "Control plane token source cannot be null"); + this.authDetails = Objects.requireNonNull(authDetails, "Authorization details cannot be null"); if (authDetails.isEmpty()) { - throw new IllegalArgumentException("AuthDetails cannot be empty"); + throw new IllegalArgumentException("Authorization details cannot be empty"); } - this.httpClient = Objects.requireNonNull(httpClient, "HttpClient cannot be null"); + this.httpClient = Objects.requireNonNull(httpClient, "HTTP client cannot be null"); } /** * Fetches an endpoint-specific dataplane token by exchanging a control plane token. * *

This method first obtains a control plane token from the configured {@code cpTokenSource}. - * It then uses this token (as an assertion) along with the provided {@code authDetails} to - * request a new, more scoped dataplane token from the Databricks OAuth token endpoint ({@value + * It then uses this token as an assertion along with the provided {@code authDetails} to request + * a new, more scoped dataplane token from the Databricks OAuth token endpoint ({@value * #TOKEN_ENDPOINT}). * * @return A new {@link Token} containing the exchanged dataplane access token, its type, any * accompanying refresh token, and its expiry time. * @throws DatabricksException if the token exchange with the OAuth endpoint fails. - * @throws IllegalArgumentException if the control pl + * @throws IllegalArgumentException if the token endpoint url is empty. + * @throws NullPointerException if any of the parameters are null. */ @Override protected Token refresh() { @@ -72,9 +73,9 @@ protected Token refresh() { OAuthResponse oauthResponse; try { oauthResponse = TokenEndpointClient.requestToken(this.httpClient, TOKEN_ENDPOINT, params); - } catch (DatabricksException | IllegalArgumentException e) { + } catch (DatabricksException | IllegalArgumentException | NullPointerException e) { LOG.error( - "Failed to fetch dataplane token for endpoint source using {}: {}", + "Failed to exchange control plane token for dataplane token at endpoint {}: {}", TOKEN_ENDPOINT, e.getMessage(), e); diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/TokenEndpointClient.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/TokenEndpointClient.java index 1e3425b4c..69883dd24 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/TokenEndpointClient.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/TokenEndpointClient.java @@ -5,7 +5,6 @@ import com.databricks.sdk.core.http.HttpClient; import com.databricks.sdk.core.http.Response; import com.fasterxml.jackson.databind.ObjectMapper; -import com.google.common.base.Strings; import java.io.IOException; import java.util.Map; import java.util.Objects; @@ -33,7 +32,8 @@ private TokenEndpointClient() {} * @param params A map of parameters to include in the token request. * @return An {@link OAuthResponse} containing the token information. * @throws DatabricksException if an error occurs during the token request or response parsing. - * @throws IllegalArgumentException if any of the required parameters are null or empty. + * @throws IllegalArgumentException if the token endpoint URL is empty. + * @throws NullPointerException if any of the parameters are null. */ public static OAuthResponse requestToken( HttpClient httpClient, String tokenEndpointUrl, Map params) @@ -41,9 +41,8 @@ public static OAuthResponse requestToken( Objects.requireNonNull(httpClient, "HttpClient cannot be null"); Objects.requireNonNull(params, "Request parameters map cannot be null"); Objects.requireNonNull(tokenEndpointUrl, "Token endpoint URL cannot be null"); - + if (tokenEndpointUrl.isEmpty()) { - LOG.error("Token endpoint URL cannot be empty"); throw new IllegalArgumentException("Token endpoint URL cannot be empty"); } diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DataPlaneTokenSourceTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DataPlaneTokenSourceTest.java index 90e9a1ee6..91418798e 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DataPlaneTokenSourceTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DataPlaneTokenSourceTest.java @@ -13,170 +13,168 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; -import org.mockito.Mockito; public class DataPlaneTokenSourceTest { - private static final String TEST_ENDPOINT_1 = "https://endpoint1.databricks.com/"; - private static final String TEST_ENDPOINT_2 = "https://endpoint2.databricks.com/"; - private static final String TEST_AUTH_DETAILS_1 = "{\"aud\":\"aud1\"}"; - private static final String TEST_AUTH_DETAILS_2 = "{\"aud\":\"aud2\"}"; - private static final String TEST_CP_TOKEN = "cp-access-token"; - private static final String TEST_TOKEN_TYPE = "Bearer"; - private static final String TEST_REFRESH_TOKEN = "refresh-token"; - private static final int TEST_EXPIRES_IN = 3600; + private static final String TEST_ENDPOINT_1 = "https://endpoint1.databricks.com/"; + private static final String TEST_ENDPOINT_2 = "https://endpoint2.databricks.com/"; + private static final String TEST_AUTH_DETAILS_1 = "{\"aud\":\"aud1\"}"; + private static final String TEST_AUTH_DETAILS_2 = "{\"aud\":\"aud2\"}"; + private static final String TEST_CP_TOKEN = "cp-access-token"; + private static final String TEST_TOKEN_TYPE = "Bearer"; + private static final String TEST_REFRESH_TOKEN = "refresh-token"; + private static final int TEST_EXPIRES_IN = 3600; - private static Stream provideDataPlaneTokenScenarios() throws Exception { - // Mock DatabricksOAuthTokenSource for control plane token - Token cpToken = new Token(TEST_CP_TOKEN, TEST_TOKEN_TYPE, null, LocalDateTime.now().plusSeconds(600)); - DatabricksOAuthTokenSource mockCpTokenSource = mock(DatabricksOAuthTokenSource.class); - when(mockCpTokenSource.getToken()).thenReturn(cpToken); + private static Stream provideDataPlaneTokenScenarios() throws Exception { + // Mock DatabricksOAuthTokenSource for control plane token + Token cpToken = + new Token(TEST_CP_TOKEN, TEST_TOKEN_TYPE, null, LocalDateTime.now().plusSeconds(600)); + DatabricksOAuthTokenSource mockCpTokenSource = mock(DatabricksOAuthTokenSource.class); + when(mockCpTokenSource.getToken()).thenReturn(cpToken); - // --- Mock HttpClient for different scenarios --- - // Success JSON for endpoint1/auth1 - String successJson1 = "{" + - "\"access_token\":\"dp-access-token1\"," + - "\"token_type\":\"Bearer\"," + - "\"refresh_token\":\"refresh-token\"," + - "\"expires_in\":3600" + - "}"; - HttpClient mockSuccessClient1 = mock(HttpClient.class); - when(mockSuccessClient1.execute(any())).thenReturn( - new Response(successJson1, 200, "OK", new URL(TEST_ENDPOINT_1)) - ); + // --- Mock HttpClient for different scenarios --- + // Success JSON for endpoint1/auth1 + String successJson1 = + "{" + + "\"access_token\":\"dp-access-token1\"," + + "\"token_type\":\"Bearer\"," + + "\"refresh_token\":\"refresh-token\"," + + "\"expires_in\":3600" + + "}"; + HttpClient mockSuccessClient1 = mock(HttpClient.class); + when(mockSuccessClient1.execute(any())) + .thenReturn(new Response(successJson1, 200, "OK", new URL(TEST_ENDPOINT_1))); - // Success JSON for endpoint2/auth2 - String successJson2 = "{" + - "\"access_token\":\"dp-access-token2\"," + - "\"token_type\":\"Bearer\"," + - "\"refresh_token\":\"refresh-token\"," + - "\"expires_in\":3600" + - "}"; - HttpClient mockSuccessClient2 = mock(HttpClient.class); - when(mockSuccessClient2.execute(any())).thenReturn( - new Response(successJson2, 200, "OK", new URL(TEST_ENDPOINT_2)) - ); + // Success JSON for endpoint2/auth2 + String successJson2 = + "{" + + "\"access_token\":\"dp-access-token2\"," + + "\"token_type\":\"Bearer\"," + + "\"refresh_token\":\"refresh-token\"," + + "\"expires_in\":3600" + + "}"; + HttpClient mockSuccessClient2 = mock(HttpClient.class); + when(mockSuccessClient2.execute(any())) + .thenReturn(new Response(successJson2, 200, "OK", new URL(TEST_ENDPOINT_2))); - // Error response JSON - String errorJson = "{" + - "\"error\":\"invalid_request\"," + - "\"error_description\":\"Bad request\"" + - "}"; - HttpClient mockErrorClient = mock(HttpClient.class); - when(mockErrorClient.execute(any())).thenReturn( - new Response(errorJson, 400, "Bad Request", new URL(TEST_ENDPOINT_1)) - ); + // Error response JSON + String errorJson = + "{" + "\"error\":\"invalid_request\"," + "\"error_description\":\"Bad request\"" + "}"; + HttpClient mockErrorClient = mock(HttpClient.class); + when(mockErrorClient.execute(any())) + .thenReturn(new Response(errorJson, 400, "Bad Request", new URL(TEST_ENDPOINT_1))); - // IOException scenario - HttpClient mockIOExceptionClient = mock(HttpClient.class); - when(mockIOExceptionClient.execute(any())).thenThrow(new IOException("Network error")); + // IOException scenario + HttpClient mockIOExceptionClient = mock(HttpClient.class); + when(mockIOExceptionClient.execute(any())).thenThrow(new IOException("Network error")); - // For null cpTokenSource - DatabricksOAuthTokenSource nullCpTokenSource = null; + // For null cpTokenSource + DatabricksOAuthTokenSource nullCpTokenSource = null; - // For null httpClient - HttpClient nullHttpClient = null; + // For null httpClient + HttpClient nullHttpClient = null; - // For null/empty endpoint or authDetails - return Stream.of( - Arguments.of( - "Success: endpoint1/auth1", - TEST_ENDPOINT_1, - TEST_AUTH_DETAILS_1, - mockSuccessClient1, - mockCpTokenSource, - new Token("dp-access-token1", TEST_TOKEN_TYPE, TEST_REFRESH_TOKEN, LocalDateTime.now().plusSeconds(TEST_EXPIRES_IN)), - null // No exception + // For null/empty endpoint or authDetails + return Stream.of( + Arguments.of( + "Success: endpoint1/auth1", + TEST_ENDPOINT_1, + TEST_AUTH_DETAILS_1, + mockSuccessClient1, + mockCpTokenSource, + new Token( + "dp-access-token1", + TEST_TOKEN_TYPE, + TEST_REFRESH_TOKEN, + LocalDateTime.now().plusSeconds(TEST_EXPIRES_IN)), + null // No exception ), - Arguments.of( - "Success: endpoint2/auth2 (different cache key)", - TEST_ENDPOINT_2, - TEST_AUTH_DETAILS_2, - mockSuccessClient2, - mockCpTokenSource, - new Token("dp-access-token2", TEST_TOKEN_TYPE, TEST_REFRESH_TOKEN, LocalDateTime.now().plusSeconds(TEST_EXPIRES_IN)), - null - ), - Arguments.of( - "Error response from endpoint", - TEST_ENDPOINT_1, - TEST_AUTH_DETAILS_1, - mockErrorClient, - mockCpTokenSource, - null, - com.databricks.sdk.core.DatabricksException.class - ), - Arguments.of( - "IOException from HttpClient", - TEST_ENDPOINT_1, - TEST_AUTH_DETAILS_1, - mockIOExceptionClient, - mockCpTokenSource, - null, - com.databricks.sdk.core.DatabricksException.class - ), - Arguments.of( - "Null cpTokenSource", - TEST_ENDPOINT_1, - TEST_AUTH_DETAILS_1, - mockSuccessClient1, - nullCpTokenSource, - null, - IllegalArgumentException.class - ), - Arguments.of( - "Null httpClient", - TEST_ENDPOINT_1, - TEST_AUTH_DETAILS_1, - nullHttpClient, - mockCpTokenSource, - null, - IllegalArgumentException.class - ), - Arguments.of( - "Null endpoint", - null, - TEST_AUTH_DETAILS_1, - mockSuccessClient1, - mockCpTokenSource, - null, - IllegalArgumentException.class - ), - Arguments.of( - "Null authDetails", - TEST_ENDPOINT_1, - null, - mockSuccessClient1, - mockCpTokenSource, - null, - IllegalArgumentException.class - ) - ); - } + Arguments.of( + "Success: endpoint2/auth2 (different cache key)", + TEST_ENDPOINT_2, + TEST_AUTH_DETAILS_2, + mockSuccessClient2, + mockCpTokenSource, + new Token( + "dp-access-token2", + TEST_TOKEN_TYPE, + TEST_REFRESH_TOKEN, + LocalDateTime.now().plusSeconds(TEST_EXPIRES_IN)), + null), + Arguments.of( + "Error response from endpoint", + TEST_ENDPOINT_1, + TEST_AUTH_DETAILS_1, + mockErrorClient, + mockCpTokenSource, + null, + com.databricks.sdk.core.DatabricksException.class), + Arguments.of( + "IOException from HttpClient", + TEST_ENDPOINT_1, + TEST_AUTH_DETAILS_1, + mockIOExceptionClient, + mockCpTokenSource, + null, + com.databricks.sdk.core.DatabricksException.class), + Arguments.of( + "Null cpTokenSource", + TEST_ENDPOINT_1, + TEST_AUTH_DETAILS_1, + mockSuccessClient1, + nullCpTokenSource, + null, + NullPointerException.class), + Arguments.of( + "Null httpClient", + TEST_ENDPOINT_1, + TEST_AUTH_DETAILS_1, + nullHttpClient, + mockCpTokenSource, + null, + NullPointerException.class), + Arguments.of( + "Null endpoint", + null, + TEST_AUTH_DETAILS_1, + mockSuccessClient1, + mockCpTokenSource, + null, + NullPointerException.class), + Arguments.of( + "Null authDetails", + TEST_ENDPOINT_1, + null, + mockSuccessClient1, + mockCpTokenSource, + null, + NullPointerException.class)); + } - @ParameterizedTest(name = "{0}") - @MethodSource("provideDataPlaneTokenScenarios") - void testDataPlaneTokenSource( - String testName, - String endpoint, - String authDetails, - HttpClient httpClient, - DatabricksOAuthTokenSource cpTokenSource, - Token expectedToken, - Class expectedException - ) { - if (expectedException != null) { - assertThrows(expectedException, () -> { - DataPlaneTokenSource source = new DataPlaneTokenSource(httpClient, cpTokenSource); - source.getToken(endpoint, authDetails); - }); - } else { + @ParameterizedTest(name = "{0}") + @MethodSource("provideDataPlaneTokenScenarios") + void testDataPlaneTokenSource( + String testName, + String endpoint, + String authDetails, + HttpClient httpClient, + DatabricksOAuthTokenSource cpTokenSource, + Token expectedToken, + Class expectedException) { + if (expectedException != null) { + assertThrows( + expectedException, + () -> { DataPlaneTokenSource source = new DataPlaneTokenSource(httpClient, cpTokenSource); - Token token = source.getToken(endpoint, authDetails); - assertNotNull(token); - assertEquals(expectedToken.getAccessToken(), token.getAccessToken()); - assertEquals(expectedToken.getTokenType(), token.getTokenType()); - assertEquals(expectedToken.getRefreshToken(), token.getRefreshToken()); - assertTrue(token.isValid()); - } + source.getToken(endpoint, authDetails); + }); + } else { + DataPlaneTokenSource source = new DataPlaneTokenSource(httpClient, cpTokenSource); + Token token = source.getToken(endpoint, authDetails); + assertNotNull(token); + assertEquals(expectedToken.getAccessToken(), token.getAccessToken()); + assertEquals(expectedToken.getTokenType(), token.getTokenType()); + assertEquals(expectedToken.getRefreshToken(), token.getRefreshToken()); + assertTrue(token.isValid()); } + } } diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSourceTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSourceTest.java index 6f1897fda..8217179f2 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSourceTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSourceTest.java @@ -15,7 +15,6 @@ import java.util.HashMap; import java.util.Map; import java.util.stream.Stream; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.MethodSource; import org.mockito.Mockito; @@ -35,45 +34,42 @@ class DatabricksOAuthTokenSourceTest { private static final String TEST_AUDIENCE = "test-audience"; private static final String TEST_ACCOUNT_ID = "test-account-id"; - // Error message constants - private static final String ERROR_NULL = "Required parameter '%s' cannot be null"; - private static final String ERROR_EMPTY = "Required parameter '%s' cannot be empty"; - - private IDTokenSource mockIdTokenSource; - - @BeforeEach - void setUp() { - mockIdTokenSource = Mockito.mock(IDTokenSource.class); - IDToken idToken = new IDToken(TEST_ID_TOKEN); - when(mockIdTokenSource.getIDToken(any())).thenReturn(idToken); - } - /** * Test case data for parameterized token source tests. Each case defines a specific OAuth token * exchange scenario. */ private static class TestCase { final String name; // Descriptive name of the test case + final String clientId; // Client ID to use + final String host; // Host to use + final OpenIDConnectEndpoints endpoints; // OIDC endpoints + final IDTokenSource idTokenSource; // ID token source + final HttpClient httpClient; // HTTP client final String audience; // Custom audience value if provided final String accountId; // Account ID if provided final String expectedAudience; // Expected audience used in token exchange - final HttpClient mockHttpClient; // Pre-configured mock HTTP client final Class expectedException; // Expected exception type if any TestCase( String name, + String clientId, + String host, + OpenIDConnectEndpoints endpoints, + IDTokenSource idTokenSource, + HttpClient httpClient, String audience, String accountId, String expectedAudience, - int statusCode, - Object responseBody, - HttpClient mockHttpClient, Class expectedException) { this.name = name; + this.clientId = clientId; + this.host = host; + this.endpoints = endpoints; + this.idTokenSource = idTokenSource; + this.httpClient = httpClient; this.audience = audience; this.accountId = accountId; this.expectedAudience = expectedAudience; - this.mockHttpClient = mockHttpClient; this.expectedException = expectedException; } @@ -87,20 +83,27 @@ public String toString() { * Provides test cases for OAuth token exchange scenarios. Includes success cases with different * audience configurations and various error cases. */ - private static Stream provideTestCases() { - try { - // Success response with valid token data - Map successResponse = new HashMap<>(); - successResponse.put("access_token", TOKEN); - successResponse.put("token_type", TOKEN_TYPE); - successResponse.put("refresh_token", REFRESH_TOKEN); - successResponse.put("expires_in", EXPIRES_IN); + private static Stream provideTestCases() throws MalformedURLException { + // Create valid components for reuse + OpenIDConnectEndpoints testEndpoints = + new OpenIDConnectEndpoints(TEST_TOKEN_ENDPOINT, TEST_AUTHORIZATION_ENDPOINT); + IDTokenSource testIdTokenSource = Mockito.mock(IDTokenSource.class); + IDToken idToken = new IDToken(TEST_ID_TOKEN); + when(testIdTokenSource.getIDToken(any())).thenReturn(idToken); + + // Create success response for token exchange tests + Map successResponse = new HashMap<>(); + successResponse.put("access_token", TOKEN); + successResponse.put("token_type", TOKEN_TYPE); + successResponse.put("refresh_token", REFRESH_TOKEN); + successResponse.put("expires_in", EXPIRES_IN); - // Error response for invalid requests - Map errorResponse = new HashMap<>(); - errorResponse.put("error", "invalid_request"); - errorResponse.put("error_description", "Invalid client ID"); + // Create error response for invalid requests + Map errorResponse = new HashMap<>(); + errorResponse.put("error", "invalid_request"); + errorResponse.put("error_description", "Invalid client ID"); + try { ObjectMapper mapper = new ObjectMapper(); final String errorJson = mapper.writeValueAsString(errorResponse); final String successJson = mapper.writeValueAsString(successResponse); @@ -115,71 +118,162 @@ private static Stream provideTestCases() { FormRequest expectedRequest = new FormRequest(TEST_TOKEN_ENDPOINT, formParams); return Stream.of( - // Success cases with different audience configurations + // Token exchange test cases new TestCase( "Default audience from token endpoint", + TEST_CLIENT_ID, + TEST_HOST, + testEndpoints, + testIdTokenSource, + createMockHttpClient(expectedRequest, 200, successJson), null, null, TEST_TOKEN_ENDPOINT, - 200, - successResponse, - createMockHttpClient(expectedRequest, 200, successJson), null), new TestCase( "Custom audience provided", + TEST_CLIENT_ID, + TEST_HOST, + testEndpoints, + testIdTokenSource, + createMockHttpClient(expectedRequest, 200, successJson), TEST_AUDIENCE, null, TEST_AUDIENCE, - 200, - successResponse, - createMockHttpClient(expectedRequest, 200, successJson), null), new TestCase( "Custom audience takes precedence over account ID", + TEST_CLIENT_ID, + TEST_HOST, + testEndpoints, + testIdTokenSource, + createMockHttpClient(expectedRequest, 200, successJson), TEST_AUDIENCE, TEST_ACCOUNT_ID, TEST_AUDIENCE, - 200, - successResponse, - createMockHttpClient(expectedRequest, 200, successJson), null), new TestCase( "Account ID used as audience when no custom audience", + TEST_CLIENT_ID, + TEST_HOST, + testEndpoints, + testIdTokenSource, + createMockHttpClient(expectedRequest, 200, successJson), null, TEST_ACCOUNT_ID, TEST_ACCOUNT_ID, - 200, - successResponse, - createMockHttpClient(expectedRequest, 200, successJson), null), - // Error cases new TestCase( "Invalid request returns 400", + TEST_CLIENT_ID, + TEST_HOST, + testEndpoints, + testIdTokenSource, + createMockHttpClient(expectedRequest, 400, errorJson), null, null, TEST_TOKEN_ENDPOINT, - 400, - errorJson, - createMockHttpClient(expectedRequest, 400, errorJson), DatabricksException.class), new TestCase( "Network error during token exchange", + TEST_CLIENT_ID, + TEST_HOST, + testEndpoints, + testIdTokenSource, + createMockHttpClientWithError(expectedRequest), null, null, TEST_TOKEN_ENDPOINT, - 0, - null, - createMockHttpClientWithError(expectedRequest), DatabricksException.class), new TestCase( "Invalid JSON response from server", + TEST_CLIENT_ID, + TEST_HOST, + testEndpoints, + testIdTokenSource, + createMockHttpClient(expectedRequest, 200, "invalid json"), null, null, TEST_TOKEN_ENDPOINT, - 200, - "invalid json", - createMockHttpClient(expectedRequest, 200, "invalid json"), - DatabricksException.class)); + DatabricksException.class), + // Parameter validation test cases + new TestCase( + "Null client ID", + null, + TEST_HOST, + testEndpoints, + testIdTokenSource, + createMockHttpClient(expectedRequest, 200, successJson), + null, + null, + null, + NullPointerException.class), + new TestCase( + "Empty client ID", + "", + TEST_HOST, + testEndpoints, + testIdTokenSource, + createMockHttpClient(expectedRequest, 200, successJson), + null, + null, + null, + IllegalArgumentException.class), + new TestCase( + "Null host", + TEST_CLIENT_ID, + null, + testEndpoints, + testIdTokenSource, + createMockHttpClient(expectedRequest, 200, successJson), + null, + null, + null, + NullPointerException.class), + new TestCase( + "Empty host", + TEST_CLIENT_ID, + "", + testEndpoints, + testIdTokenSource, + createMockHttpClient(expectedRequest, 200, successJson), + null, + null, + null, + IllegalArgumentException.class), + new TestCase( + "Null endpoints", + TEST_CLIENT_ID, + TEST_HOST, + null, + testIdTokenSource, + createMockHttpClient(expectedRequest, 200, successJson), + null, + null, + null, + NullPointerException.class), + new TestCase( + "Null IDTokenSource", + TEST_CLIENT_ID, + TEST_HOST, + testEndpoints, + null, + createMockHttpClient(expectedRequest, 200, successJson), + null, + null, + null, + NullPointerException.class), + new TestCase( + "Null HttpClient", + TEST_CLIENT_ID, + TEST_HOST, + testEndpoints, + testIdTokenSource, + null, + null, + null, + null, + NullPointerException.class)); } catch (IOException e) { throw new RuntimeException("Failed to create test cases", e); } @@ -212,179 +306,34 @@ private static HttpClient createMockHttpClientWithError(FormRequest expectedRequ * Tests OAuth token exchange with various configurations and error scenarios. Verifies correct * audience selection, token exchange, and error handling. */ - @ParameterizedTest(name = "testTokenSource: {arguments}") + @ParameterizedTest(name = "{0}") @MethodSource("provideTestCases") void testTokenSource(TestCase testCase) { - try { - // Create token source with test configuration - OpenIDConnectEndpoints endpoints = - new OpenIDConnectEndpoints(TEST_TOKEN_ENDPOINT, TEST_AUTHORIZATION_ENDPOINT); - - DatabricksOAuthTokenSource.Builder builder = - new DatabricksOAuthTokenSource.Builder( - TEST_CLIENT_ID, TEST_HOST, endpoints, mockIdTokenSource, testCase.mockHttpClient); - - builder.audience(testCase.audience).accountId(testCase.accountId); - - DatabricksOAuthTokenSource tokenSource = builder.build(); + DatabricksOAuthTokenSource.Builder builder = + new DatabricksOAuthTokenSource.Builder( + testCase.clientId, + testCase.host, + testCase.endpoints, + testCase.idTokenSource, + testCase.httpClient); - if (testCase.expectedException != null) { - assertThrows(testCase.expectedException, () -> tokenSource.getToken()); - } else { - // Verify successful token exchange - Token token = tokenSource.getToken(); - assertEquals(TOKEN, token.getAccessToken()); - assertEquals(TOKEN_TYPE, token.getTokenType()); - assertEquals(REFRESH_TOKEN, token.getRefreshToken()); - assertFalse(token.isExpired()); + builder.audience(testCase.audience); + builder.accountId(testCase.accountId); - // Verify correct audience was used - verify(mockIdTokenSource).getIDToken(testCase.expectedAudience); - } - } catch (IOException e) { - throw new RuntimeException("Test failed", e); - } - } - - /** - * Test case data for parameter validation tests. Each case defines a specific validation - * scenario. - */ - private static class ValidationTestCase { - final String name; - final String clientId; - final String host; - final OpenIDConnectEndpoints endpoints; - final IDTokenSource idTokenSource; - final HttpClient httpClient; - final String expectedFieldName; - final boolean isNullTest; + DatabricksOAuthTokenSource tokenSource = builder.build(); - ValidationTestCase( - String name, - String clientId, - String host, - OpenIDConnectEndpoints endpoints, - IDTokenSource idTokenSource, - HttpClient httpClient, - String expectedFieldName, - boolean isNullTest) { - this.name = name; - this.clientId = clientId; - this.host = host; - this.endpoints = endpoints; - this.idTokenSource = idTokenSource; - this.httpClient = httpClient; - this.expectedFieldName = expectedFieldName; - this.isNullTest = isNullTest; - } + if (testCase.expectedException != null) { + assertThrows(testCase.expectedException, () -> tokenSource.getToken()); + } else { + // Verify successful token exchange + Token token = tokenSource.getToken(); + assertEquals(TOKEN, token.getAccessToken()); + assertEquals(TOKEN_TYPE, token.getTokenType()); + assertEquals(REFRESH_TOKEN, token.getRefreshToken()); + assertFalse(token.isExpired()); - @Override - public String toString() { - return name; + // Verify correct audience was used + verify(testCase.idTokenSource, atLeastOnce()).getIDToken(testCase.expectedAudience); } } - - private static Stream provideValidationTestCases() - throws MalformedURLException { - OpenIDConnectEndpoints validEndpoints = - new OpenIDConnectEndpoints(TEST_TOKEN_ENDPOINT, TEST_AUTHORIZATION_ENDPOINT); - HttpClient validHttpClient = Mockito.mock(HttpClient.class); - IDTokenSource validIdTokenSource = Mockito.mock(IDTokenSource.class); - - return Stream.of( - // Client ID validation - new ValidationTestCase( - "Null client ID", - null, - TEST_HOST, - validEndpoints, - validIdTokenSource, - validHttpClient, - "ClientID", - true), - new ValidationTestCase( - "Empty client ID", - "", - TEST_HOST, - validEndpoints, - validIdTokenSource, - validHttpClient, - "ClientID", - false), - // Host validation - new ValidationTestCase( - "Null host", - TEST_CLIENT_ID, - null, - validEndpoints, - validIdTokenSource, - validHttpClient, - "Host", - true), - new ValidationTestCase( - "Empty host", - TEST_CLIENT_ID, - "", - validEndpoints, - validIdTokenSource, - validHttpClient, - "Host", - false), - // Endpoints validation - new ValidationTestCase( - "Null endpoints", - TEST_CLIENT_ID, - TEST_HOST, - null, - validIdTokenSource, - validHttpClient, - "Endpoints", - true), - // IDTokenSource validation - new ValidationTestCase( - "Null IDTokenSource", - TEST_CLIENT_ID, - TEST_HOST, - validEndpoints, - null, - validHttpClient, - "IDTokenSource", - true), - // HttpClient validation - new ValidationTestCase( - "Null HttpClient", - TEST_CLIENT_ID, - TEST_HOST, - validEndpoints, - validIdTokenSource, - null, - "HttpClient", - true)); - } - - /** - * Tests validation of required fields in the token source using parameterized test cases. - * Verifies that null or empty values for required fields cause getToken() to throw - * IllegalArgumentException with specific error messages. - */ - @ParameterizedTest(name = "testParameterValidation: {0}") - @MethodSource("provideValidationTestCases") - void testParameterValidation(ValidationTestCase testCase) { - DatabricksOAuthTokenSource tokenSource = - new DatabricksOAuthTokenSource.Builder( - testCase.clientId, - testCase.host, - testCase.endpoints, - testCase.idTokenSource, - testCase.httpClient) - .build(); - - IllegalArgumentException exception = - assertThrows(IllegalArgumentException.class, () -> tokenSource.getToken()); - - String expectedMessage = - String.format(testCase.isNullTest ? ERROR_NULL : ERROR_EMPTY, testCase.expectedFieldName); - assertEquals(expectedMessage, exception.getMessage()); - } } diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/EndpointTokenSourceTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/EndpointTokenSourceTest.java index 3e842fd37..549077690 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/EndpointTokenSourceTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/EndpointTokenSourceTest.java @@ -10,13 +10,10 @@ import java.io.IOException; import java.net.URL; import java.time.LocalDateTime; -import java.util.HashMap; -import java.util.Map; import java.util.stream.Stream; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; -import org.mockito.Mockito; class EndpointTokenSourceTest { private static final String TEST_AUTH_DETAILS = "{\"aud\":\"test-audience\"}"; @@ -25,19 +22,28 @@ class EndpointTokenSourceTest { private static final String TEST_TOKEN_TYPE = "Bearer"; private static final String TEST_REFRESH_TOKEN = "refresh-token"; private static final int TEST_EXPIRES_IN = 3600; - private static final String TOKEN_ENDPOINT = "/oidc/v1/token"; private static Stream provideEndpointTokenScenarios() throws Exception { // Success response JSON - String successJson = "{" + - "\"access_token\":\"" + TEST_DP_TOKEN + "\"," + - "\"token_type\":\"" + TEST_TOKEN_TYPE + "\"," + - "\"expires_in\":" + TEST_EXPIRES_IN + "," + - "\"refresh_token\":\"" + TEST_REFRESH_TOKEN + "\"}"; + String successJson = + "{" + + "\"access_token\":\"" + + TEST_DP_TOKEN + + "\"," + + "\"token_type\":\"" + + TEST_TOKEN_TYPE + + "\"," + + "\"expires_in\":" + + TEST_EXPIRES_IN + + "," + + "\"refresh_token\":\"" + + TEST_REFRESH_TOKEN + + "\"}"; // Error response JSON - String errorJson = "{" + - "\"error\":\"invalid_client\"," + - "\"error_description\":\"Client authentication failed\"}"; + String errorJson = + "{" + + "\"error\":\"invalid_client\"," + + "\"error_description\":\"Client authentication failed\"}"; // Malformed JSON String malformedJson = "{not valid json}"; @@ -48,15 +54,20 @@ private static Stream provideEndpointTokenScenarios() throws Exceptio // Mock HttpClient for success HttpClient mockSuccessClient = mock(HttpClient.class); - when(mockSuccessClient.execute(any())).thenReturn(new Response(successJson, 200, "OK", new URL("https://test.databricks.com/"))); + when(mockSuccessClient.execute(any())) + .thenReturn(new Response(successJson, 200, "OK", new URL("https://test.databricks.com/"))); // Mock HttpClient for error response HttpClient mockErrorClient = mock(HttpClient.class); - when(mockErrorClient.execute(any())).thenReturn(new Response(errorJson, 400, "Bad Request", new URL("https://test.databricks.com/"))); + when(mockErrorClient.execute(any())) + .thenReturn( + new Response(errorJson, 400, "Bad Request", new URL("https://test.databricks.com/"))); // Mock HttpClient for malformed JSON HttpClient mockMalformedClient = mock(HttpClient.class); - when(mockMalformedClient.execute(any())).thenReturn(new Response(malformedJson, 200, "OK", new URL("https://test.databricks.com/"))); + when(mockMalformedClient.execute(any())) + .thenReturn( + new Response(malformedJson, 200, "OK", new URL("https://test.databricks.com/"))); // Mock HttpClient for IOException HttpClient mockIOExceptionClient = mock(HttpClient.class); @@ -72,8 +83,7 @@ private static Stream provideEndpointTokenScenarios() throws Exceptio TEST_DP_TOKEN, TEST_TOKEN_TYPE, TEST_REFRESH_TOKEN, - TEST_EXPIRES_IN - ), + TEST_EXPIRES_IN), Arguments.of( "OAuth error response", mockCpTokenSource, @@ -83,8 +93,7 @@ private static Stream provideEndpointTokenScenarios() throws Exceptio null, null, null, - 0 - ), + 0), Arguments.of( "Malformed JSON response", mockCpTokenSource, @@ -94,8 +103,7 @@ private static Stream provideEndpointTokenScenarios() throws Exceptio null, null, null, - 0 - ), + 0), Arguments.of( "IOException from HttpClient", mockCpTokenSource, @@ -105,30 +113,27 @@ private static Stream provideEndpointTokenScenarios() throws Exceptio null, null, null, - 0 - ), + 0), Arguments.of( "Null cpTokenSource", null, TEST_AUTH_DETAILS, mockSuccessClient, - IllegalArgumentException.class, + NullPointerException.class, null, null, null, - 0 - ), + 0), Arguments.of( "Null authDetails", mockCpTokenSource, null, mockSuccessClient, - IllegalArgumentException.class, + NullPointerException.class, null, null, null, - 0 - ), + 0), Arguments.of( "Empty authDetails", mockCpTokenSource, @@ -138,20 +143,17 @@ private static Stream provideEndpointTokenScenarios() throws Exceptio null, null, null, - 0 - ), + 0), Arguments.of( "Null httpClient", mockCpTokenSource, TEST_AUTH_DETAILS, null, - IllegalArgumentException.class, + NullPointerException.class, null, null, null, - 0 - ) - ); + 0)); } @ParameterizedTest(name = "{0}") @@ -165,13 +167,15 @@ void testEndpointTokenSource( String expectedAccessToken, String expectedTokenType, String expectedRefreshToken, - int expectedExpiresIn - ) { + int expectedExpiresIn) { if (expectedException != null) { - assertThrows(expectedException, () -> { - EndpointTokenSource source = new EndpointTokenSource(cpTokenSource, authDetails, httpClient); - source.getToken(); - }); + assertThrows( + expectedException, + () -> { + EndpointTokenSource source = + new EndpointTokenSource(cpTokenSource, authDetails, httpClient); + source.getToken(); + }); } else { EndpointTokenSource source = new EndpointTokenSource(cpTokenSource, authDetails, httpClient); Token token = source.getToken(); diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/TokenEndpointClientTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/TokenEndpointClientTest.java index 0b07d117b..581c90143 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/TokenEndpointClientTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/TokenEndpointClientTest.java @@ -23,15 +23,17 @@ class TokenEndpointClientTest { private static Stream provideTokenScenarios() throws Exception { // Success response JSON - String successJson = "{" + - "\"access_token\":\"test-access-token\"," + - "\"token_type\":\"Bearer\"," + - "\"expires_in\":3600," + - "\"refresh_token\":\"test-refresh-token\"}"; + String successJson = + "{" + + "\"access_token\":\"test-access-token\"," + + "\"token_type\":\"Bearer\"," + + "\"expires_in\":3600," + + "\"refresh_token\":\"test-refresh-token\"}"; // Error response JSON - String errorJson = "{" + - "\"error\":\"invalid_client\"," + - "\"error_description\":\"Client authentication failed\"}"; + String errorJson = + "{" + + "\"error\":\"invalid_client\"," + + "\"error_description\":\"Client authentication failed\"}"; // Malformed JSON String malformedJson = "{not valid json}"; @@ -43,12 +45,14 @@ private static Stream provideTokenScenarios() throws Exception { // Mock HttpClient for error response HttpClient mockErrorClient = mock(HttpClient.class); when(mockErrorClient.execute(any(FormRequest.class))) - .thenReturn(new Response(errorJson, 400, "Bad Request", new URL("https://test.databricks.com/"))); + .thenReturn( + new Response(errorJson, 400, "Bad Request", new URL("https://test.databricks.com/"))); // Mock HttpClient for malformed JSON HttpClient mockMalformedClient = mock(HttpClient.class); when(mockMalformedClient.execute(any(FormRequest.class))) - .thenReturn(new Response(malformedJson, 200, "OK", new URL("https://test.databricks.com/"))); + .thenReturn( + new Response(malformedJson, 200, "OK", new URL("https://test.databricks.com/"))); // Mock HttpClient for IOException HttpClient mockIOExceptionClient = mock(HttpClient.class); @@ -65,8 +69,7 @@ private static Stream provideTokenScenarios() throws Exception { "test-access-token", "Bearer", 3600, - "test-refresh-token" - ), + "test-refresh-token"), Arguments.of( "OAuth error response", mockErrorClient, @@ -76,8 +79,7 @@ private static Stream provideTokenScenarios() throws Exception { null, null, 0, - null - ), + null), Arguments.of( "Malformed JSON response", mockMalformedClient, @@ -87,8 +89,7 @@ private static Stream provideTokenScenarios() throws Exception { null, null, 0, - null - ), + null), Arguments.of( "IOException from HttpClient", mockIOExceptionClient, @@ -98,30 +99,27 @@ private static Stream provideTokenScenarios() throws Exception { null, null, 0, - null - ), + null), Arguments.of( "Null HttpClient", null, TOKEN_ENDPOINT_URL, PARAMS, - IllegalArgumentException.class, + NullPointerException.class, null, null, 0, - null - ), + null), Arguments.of( "Null tokenEndpointUrl", mockSuccessClient, null, PARAMS, - IllegalArgumentException.class, + NullPointerException.class, null, null, 0, - null - ), + null), Arguments.of( "Empty tokenEndpointUrl", mockSuccessClient, @@ -131,20 +129,17 @@ private static Stream provideTokenScenarios() throws Exception { null, null, 0, - null - ), + null), Arguments.of( "Null params", mockSuccessClient, TOKEN_ENDPOINT_URL, null, - IllegalArgumentException.class, + NullPointerException.class, null, null, 0, - null - ) - ); + null)); } @ParameterizedTest(name = "{0}") @@ -158,13 +153,14 @@ void testRequestToken( String expectedAccessToken, String expectedTokenType, int expectedExpiresIn, - String expectedRefreshToken - ) { + String expectedRefreshToken) { if (expectedException != null) { - assertThrows(expectedException, () -> - TokenEndpointClient.requestToken(httpClient, tokenEndpointUrl, params)); + assertThrows( + expectedException, + () -> TokenEndpointClient.requestToken(httpClient, tokenEndpointUrl, params)); } else { - OAuthResponse response = TokenEndpointClient.requestToken(httpClient, tokenEndpointUrl, params); + OAuthResponse response = + TokenEndpointClient.requestToken(httpClient, tokenEndpointUrl, params); assertNotNull(response); assertEquals(expectedAccessToken, response.getAccessToken()); assertEquals(expectedTokenType, response.getTokenType());