From 3afdfa8e3d1f087a887e7514607f876fea06ebb6 Mon Sep 17 00:00:00 2001 From: emmyzhou-db Date: Mon, 19 May 2025 12:10:22 +0000 Subject: [PATCH 1/2] Implement direct dataplane access --- .../sdk/core/oauth/DataPlaneTokenSource.java | 102 ++++++++++ .../sdk/core/oauth/EndpointTokenSource.java | 92 +++++++++ .../sdk/core/oauth/TokenEndpointClient.java | 91 +++++++++ .../core/oauth/DataPlaneTokenSourceTest.java | 180 +++++++++++++++++ .../core/oauth/EndpointTokenSourceTest.java | 191 ++++++++++++++++++ .../core/oauth/TokenEndpointClientTest.java | 171 ++++++++++++++++ 6 files changed, 827 insertions(+) create mode 100644 databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DataPlaneTokenSource.java create mode 100644 databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/EndpointTokenSource.java create mode 100644 databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/TokenEndpointClient.java create mode 100644 databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DataPlaneTokenSourceTest.java create mode 100644 databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/EndpointTokenSourceTest.java create mode 100644 databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/TokenEndpointClientTest.java diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DataPlaneTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DataPlaneTokenSource.java new file mode 100644 index 000000000..b12a92dd2 --- /dev/null +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DataPlaneTokenSource.java @@ -0,0 +1,102 @@ +package com.databricks.sdk.core.oauth; + +import com.databricks.sdk.core.http.HttpClient; +import java.util.Objects; +import java.util.concurrent.ConcurrentHashMap; + +/** + * Manages and provides Databricks data plane tokens. This class is responsible for acquiring and + * caching OAuth tokens that are specific to a particular Databricks data plane service endpoint and + * a set of authorization details. It utilizes a {@link DatabricksOAuthTokenSource} for obtaining + * control plane tokens, which may then be exchanged or used to authorize requests for data plane + * tokens. Cached {@link EndpointTokenSource} instances are used to efficiently reuse tokens for + * repeated requests to the same endpoint with the same authorization context. + */ +public class DataPlaneTokenSource { + private final HttpClient httpClient; + private final DatabricksOAuthTokenSource cpTokenSource; + private final ConcurrentHashMap sourcesCache; + + /** + * Caching key for {@link EndpointTokenSource}, based on endpoint and authorization details. This + * is a value object that uniquely identifies a token source configuration. + */ + private static final class TokenSourceKey { + /** The target service endpoint URL. */ + private final String endpoint; + + /** Specific authorization details for the endpoint. */ + private final String authDetails; + + /** + * Constructs a TokenSourceKey. + * + * @param endpoint The target service endpoint URL. + * @param authDetails Specific authorization details. + */ + public TokenSourceKey(String endpoint, String authDetails) { + this.endpoint = endpoint; + this.authDetails = authDetails; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + TokenSourceKey that = (TokenSourceKey) o; + return Objects.equals(endpoint, that.endpoint) + && Objects.equals(authDetails, that.authDetails); + } + + @Override + public int hashCode() { + return Objects.hash(endpoint, authDetails); + } + } + + /** + * Constructs a DataPlaneTokenSource. + * + * @param httpClient The {@link HttpClient} for token requests. + * @param cpTokenSource The {@link DatabricksOAuthTokenSource} for control plane tokens. + * @throws NullPointerException if either parameter is null + */ + public DataPlaneTokenSource(HttpClient httpClient, DatabricksOAuthTokenSource cpTokenSource) { + this.httpClient = Objects.requireNonNull(httpClient, "HTTP client cannot be null"); + this.cpTokenSource = + Objects.requireNonNull(cpTokenSource, "Control plane token source cannot be null"); + this.sourcesCache = new ConcurrentHashMap<>(); + } + + /** + * Retrieves a token for the specified endpoint and authorization details. It uses a cached {@link + * EndpointTokenSource} if available, otherwise creates and caches a new one. + * + * @param endpoint The target data plane service endpoint. + * @param authDetails Authorization details for the endpoint. + * @return The dataplane {@link Token}. + * @throws NullPointerException if either parameter is null + * @throws IllegalArgumentException if either parameter is empty + */ + public Token getToken(String endpoint, String authDetails) { + Objects.requireNonNull(endpoint, "Data plane endpoint URL cannot be null"); + Objects.requireNonNull(authDetails, "Authorization details cannot be null"); + if (endpoint.isEmpty()) { + throw new IllegalArgumentException("Data plane endpoint URL cannot be empty"); + } + if (authDetails.isEmpty()) { + throw new IllegalArgumentException("Authorization details cannot be empty"); + } + TokenSourceKey key = new TokenSourceKey(endpoint, authDetails); + + EndpointTokenSource specificSource = + sourcesCache.computeIfAbsent( + key, k -> new EndpointTokenSource(this.cpTokenSource, k.authDetails, this.httpClient)); + + return specificSource.getToken(); + } +} diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/EndpointTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/EndpointTokenSource.java new file mode 100644 index 000000000..c54e7f6c0 --- /dev/null +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/EndpointTokenSource.java @@ -0,0 +1,92 @@ +package com.databricks.sdk.core.oauth; + +import com.databricks.sdk.core.DatabricksException; +import com.databricks.sdk.core.http.HttpClient; +import java.time.LocalDateTime; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Represents a token source that exchanges a control plane token for an endpoint-specific dataplane + * token. It utilizes an underlying {@link DatabricksOAuthTokenSource} to obtain the initial control + * plane token. + */ +public class EndpointTokenSource extends RefreshableTokenSource { + private static final Logger LOG = LoggerFactory.getLogger(EndpointTokenSource.class); + private static final String JWT_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:jwt-bearer"; + private static final String GRANT_TYPE_PARAM = "grant_type"; + private static final String AUTHORIZATION_DETAILS_PARAM = "authorization_details"; + private static final String ASSERTION_PARAM = "assertion"; + private static final String TOKEN_ENDPOINT = "/oidc/v1/token"; + + private final DatabricksOAuthTokenSource cpTokenSource; + private final String authDetails; + private final HttpClient httpClient; + + /** + * Constructs a new EndpointTokenSource. + * + * @param cpTokenSource The {@link DatabricksOAuthTokenSource} used to obtain the control plane + * token. + * @param authDetails The authorization details required for the token exchange. + * @param httpClient The {@link HttpClient} used to make the token exchange request. + * @throws IllegalArgumentException if authDetails is empty. + * @throws NullPointerException if any of the parameters are null. + */ + public EndpointTokenSource( + DatabricksOAuthTokenSource cpTokenSource, String authDetails, HttpClient httpClient) { + this.cpTokenSource = + Objects.requireNonNull(cpTokenSource, "Control plane token source cannot be null"); + this.authDetails = Objects.requireNonNull(authDetails, "Authorization details cannot be null"); + if (authDetails.isEmpty()) { + throw new IllegalArgumentException("Authorization details cannot be empty"); + } + this.httpClient = Objects.requireNonNull(httpClient, "HTTP client cannot be null"); + } + + /** + * Fetches an endpoint-specific dataplane token by exchanging a control plane token. + * + *

This method first obtains a control plane token from the configured {@code cpTokenSource}. + * It then uses this token as an assertion along with the provided {@code authDetails} to request + * a new, more scoped dataplane token from the Databricks OAuth token endpoint ({@value + * #TOKEN_ENDPOINT}). + * + * @return A new {@link Token} containing the exchanged dataplane access token, its type, any + * accompanying refresh token, and its expiry time. + * @throws DatabricksException if the token exchange with the OAuth endpoint fails. + * @throws IllegalArgumentException if the token endpoint url is empty. + * @throws NullPointerException if any of the parameters are null. + */ + @Override + protected Token refresh() { + Token cpToken = cpTokenSource.getToken(); + + Map params = new HashMap<>(); + params.put(GRANT_TYPE_PARAM, JWT_GRANT_TYPE); + params.put(AUTHORIZATION_DETAILS_PARAM, authDetails); + params.put(ASSERTION_PARAM, cpToken.getAccessToken()); + + OAuthResponse oauthResponse; + try { + oauthResponse = TokenEndpointClient.requestToken(this.httpClient, TOKEN_ENDPOINT, params); + } catch (DatabricksException | IllegalArgumentException | NullPointerException e) { + LOG.error( + "Failed to exchange control plane token for dataplane token at endpoint {}: {}", + TOKEN_ENDPOINT, + e.getMessage(), + e); + throw e; + } + + LocalDateTime expiry = LocalDateTime.now().plusSeconds(oauthResponse.getExpiresIn()); + return new Token( + oauthResponse.getAccessToken(), + oauthResponse.getTokenType(), + oauthResponse.getRefreshToken(), + expiry); + } +} diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/TokenEndpointClient.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/TokenEndpointClient.java new file mode 100644 index 000000000..69883dd24 --- /dev/null +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/TokenEndpointClient.java @@ -0,0 +1,91 @@ +package com.databricks.sdk.core.oauth; + +import com.databricks.sdk.core.DatabricksException; +import com.databricks.sdk.core.http.FormRequest; +import com.databricks.sdk.core.http.HttpClient; +import com.databricks.sdk.core.http.Response; +import com.fasterxml.jackson.databind.ObjectMapper; +import java.io.IOException; +import java.util.Map; +import java.util.Objects; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Client for interacting with an OAuth token endpoint. + * + *

This class provides a method to request an OAuth token from a specified token endpoint URL + * using the provided HTTP client and request parameters. It handles the HTTP request and parses the + * JSON response into an {@link OAuthResponse} object. + */ +public final class TokenEndpointClient { + private static final Logger LOG = LoggerFactory.getLogger(TokenEndpointClient.class); + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + + private TokenEndpointClient() {} + + /** + * Requests an OAuth token from the specified token endpoint. + * + * @param httpClient The {@link HttpClient} to use for making the request. + * @param tokenEndpointUrl The URL of the token endpoint. + * @param params A map of parameters to include in the token request. + * @return An {@link OAuthResponse} containing the token information. + * @throws DatabricksException if an error occurs during the token request or response parsing. + * @throws IllegalArgumentException if the token endpoint URL is empty. + * @throws NullPointerException if any of the parameters are null. + */ + public static OAuthResponse requestToken( + HttpClient httpClient, String tokenEndpointUrl, Map params) + throws DatabricksException { + Objects.requireNonNull(httpClient, "HttpClient cannot be null"); + Objects.requireNonNull(params, "Request parameters map cannot be null"); + Objects.requireNonNull(tokenEndpointUrl, "Token endpoint URL cannot be null"); + + if (tokenEndpointUrl.isEmpty()) { + throw new IllegalArgumentException("Token endpoint URL cannot be empty"); + } + + Response rawResponse; + try { + LOG.debug("Requesting token from endpoint: {}", tokenEndpointUrl); + rawResponse = httpClient.execute(new FormRequest(tokenEndpointUrl, params)); + } catch (IOException e) { + LOG.error("Failed to request token from {}: {}", tokenEndpointUrl, e.getMessage(), e); + throw new DatabricksException( + String.format("Failed to request token from %s: %s", tokenEndpointUrl, e.getMessage()), + e); + } + + OAuthResponse response; + try { + response = OBJECT_MAPPER.readValue(rawResponse.getBody(), OAuthResponse.class); + } catch (IOException e) { + LOG.error( + "Failed to parse OAuth response from token endpoint {}: {}", + tokenEndpointUrl, + e.getMessage(), + e); + throw new DatabricksException( + String.format( + "Failed to parse OAuth response from token endpoint %s: %s", + tokenEndpointUrl, e.getMessage()), + e); + } + + if (response.getErrorCode() != null) { + String errorSummary = + response.getErrorSummary() != null ? response.getErrorSummary() : "No summary provided."; + LOG.error( + "Token request to {} failed with error: {} - {}", + tokenEndpointUrl, + response.getErrorCode(), + errorSummary); + throw new DatabricksException( + String.format( + "Token request failed with error: %s - %s", response.getErrorCode(), errorSummary)); + } + LOG.debug("Successfully obtained token response from {}", tokenEndpointUrl); + return response; + } +} diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DataPlaneTokenSourceTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DataPlaneTokenSourceTest.java new file mode 100644 index 000000000..91418798e --- /dev/null +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DataPlaneTokenSourceTest.java @@ -0,0 +1,180 @@ +package com.databricks.sdk.core.oauth; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; + +import com.databricks.sdk.core.http.HttpClient; +import com.databricks.sdk.core.http.Response; +import java.io.IOException; +import java.net.URL; +import java.time.LocalDateTime; +import java.util.stream.Stream; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +public class DataPlaneTokenSourceTest { + private static final String TEST_ENDPOINT_1 = "https://endpoint1.databricks.com/"; + private static final String TEST_ENDPOINT_2 = "https://endpoint2.databricks.com/"; + private static final String TEST_AUTH_DETAILS_1 = "{\"aud\":\"aud1\"}"; + private static final String TEST_AUTH_DETAILS_2 = "{\"aud\":\"aud2\"}"; + private static final String TEST_CP_TOKEN = "cp-access-token"; + private static final String TEST_TOKEN_TYPE = "Bearer"; + private static final String TEST_REFRESH_TOKEN = "refresh-token"; + private static final int TEST_EXPIRES_IN = 3600; + + private static Stream provideDataPlaneTokenScenarios() throws Exception { + // Mock DatabricksOAuthTokenSource for control plane token + Token cpToken = + new Token(TEST_CP_TOKEN, TEST_TOKEN_TYPE, null, LocalDateTime.now().plusSeconds(600)); + DatabricksOAuthTokenSource mockCpTokenSource = mock(DatabricksOAuthTokenSource.class); + when(mockCpTokenSource.getToken()).thenReturn(cpToken); + + // --- Mock HttpClient for different scenarios --- + // Success JSON for endpoint1/auth1 + String successJson1 = + "{" + + "\"access_token\":\"dp-access-token1\"," + + "\"token_type\":\"Bearer\"," + + "\"refresh_token\":\"refresh-token\"," + + "\"expires_in\":3600" + + "}"; + HttpClient mockSuccessClient1 = mock(HttpClient.class); + when(mockSuccessClient1.execute(any())) + .thenReturn(new Response(successJson1, 200, "OK", new URL(TEST_ENDPOINT_1))); + + // Success JSON for endpoint2/auth2 + String successJson2 = + "{" + + "\"access_token\":\"dp-access-token2\"," + + "\"token_type\":\"Bearer\"," + + "\"refresh_token\":\"refresh-token\"," + + "\"expires_in\":3600" + + "}"; + HttpClient mockSuccessClient2 = mock(HttpClient.class); + when(mockSuccessClient2.execute(any())) + .thenReturn(new Response(successJson2, 200, "OK", new URL(TEST_ENDPOINT_2))); + + // Error response JSON + String errorJson = + "{" + "\"error\":\"invalid_request\"," + "\"error_description\":\"Bad request\"" + "}"; + HttpClient mockErrorClient = mock(HttpClient.class); + when(mockErrorClient.execute(any())) + .thenReturn(new Response(errorJson, 400, "Bad Request", new URL(TEST_ENDPOINT_1))); + + // IOException scenario + HttpClient mockIOExceptionClient = mock(HttpClient.class); + when(mockIOExceptionClient.execute(any())).thenThrow(new IOException("Network error")); + + // For null cpTokenSource + DatabricksOAuthTokenSource nullCpTokenSource = null; + + // For null httpClient + HttpClient nullHttpClient = null; + + // For null/empty endpoint or authDetails + return Stream.of( + Arguments.of( + "Success: endpoint1/auth1", + TEST_ENDPOINT_1, + TEST_AUTH_DETAILS_1, + mockSuccessClient1, + mockCpTokenSource, + new Token( + "dp-access-token1", + TEST_TOKEN_TYPE, + TEST_REFRESH_TOKEN, + LocalDateTime.now().plusSeconds(TEST_EXPIRES_IN)), + null // No exception + ), + Arguments.of( + "Success: endpoint2/auth2 (different cache key)", + TEST_ENDPOINT_2, + TEST_AUTH_DETAILS_2, + mockSuccessClient2, + mockCpTokenSource, + new Token( + "dp-access-token2", + TEST_TOKEN_TYPE, + TEST_REFRESH_TOKEN, + LocalDateTime.now().plusSeconds(TEST_EXPIRES_IN)), + null), + Arguments.of( + "Error response from endpoint", + TEST_ENDPOINT_1, + TEST_AUTH_DETAILS_1, + mockErrorClient, + mockCpTokenSource, + null, + com.databricks.sdk.core.DatabricksException.class), + Arguments.of( + "IOException from HttpClient", + TEST_ENDPOINT_1, + TEST_AUTH_DETAILS_1, + mockIOExceptionClient, + mockCpTokenSource, + null, + com.databricks.sdk.core.DatabricksException.class), + Arguments.of( + "Null cpTokenSource", + TEST_ENDPOINT_1, + TEST_AUTH_DETAILS_1, + mockSuccessClient1, + nullCpTokenSource, + null, + NullPointerException.class), + Arguments.of( + "Null httpClient", + TEST_ENDPOINT_1, + TEST_AUTH_DETAILS_1, + nullHttpClient, + mockCpTokenSource, + null, + NullPointerException.class), + Arguments.of( + "Null endpoint", + null, + TEST_AUTH_DETAILS_1, + mockSuccessClient1, + mockCpTokenSource, + null, + NullPointerException.class), + Arguments.of( + "Null authDetails", + TEST_ENDPOINT_1, + null, + mockSuccessClient1, + mockCpTokenSource, + null, + NullPointerException.class)); + } + + @ParameterizedTest(name = "{0}") + @MethodSource("provideDataPlaneTokenScenarios") + void testDataPlaneTokenSource( + String testName, + String endpoint, + String authDetails, + HttpClient httpClient, + DatabricksOAuthTokenSource cpTokenSource, + Token expectedToken, + Class expectedException) { + if (expectedException != null) { + assertThrows( + expectedException, + () -> { + DataPlaneTokenSource source = new DataPlaneTokenSource(httpClient, cpTokenSource); + source.getToken(endpoint, authDetails); + }); + } else { + DataPlaneTokenSource source = new DataPlaneTokenSource(httpClient, cpTokenSource); + Token token = source.getToken(endpoint, authDetails); + assertNotNull(token); + assertEquals(expectedToken.getAccessToken(), token.getAccessToken()); + assertEquals(expectedToken.getTokenType(), token.getTokenType()); + assertEquals(expectedToken.getRefreshToken(), token.getRefreshToken()); + assertTrue(token.isValid()); + } + } +} diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/EndpointTokenSourceTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/EndpointTokenSourceTest.java new file mode 100644 index 000000000..549077690 --- /dev/null +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/EndpointTokenSourceTest.java @@ -0,0 +1,191 @@ +package com.databricks.sdk.core.oauth; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; + +import com.databricks.sdk.core.DatabricksException; +import com.databricks.sdk.core.http.HttpClient; +import com.databricks.sdk.core.http.Response; +import java.io.IOException; +import java.net.URL; +import java.time.LocalDateTime; +import java.util.stream.Stream; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +class EndpointTokenSourceTest { + private static final String TEST_AUTH_DETAILS = "{\"aud\":\"test-audience\"}"; + private static final String TEST_CP_TOKEN = "cp-access-token"; + private static final String TEST_DP_TOKEN = "dp-access-token"; + private static final String TEST_TOKEN_TYPE = "Bearer"; + private static final String TEST_REFRESH_TOKEN = "refresh-token"; + private static final int TEST_EXPIRES_IN = 3600; + + private static Stream provideEndpointTokenScenarios() throws Exception { + // Success response JSON + String successJson = + "{" + + "\"access_token\":\"" + + TEST_DP_TOKEN + + "\"," + + "\"token_type\":\"" + + TEST_TOKEN_TYPE + + "\"," + + "\"expires_in\":" + + TEST_EXPIRES_IN + + "," + + "\"refresh_token\":\"" + + TEST_REFRESH_TOKEN + + "\"}"; + // Error response JSON + String errorJson = + "{" + + "\"error\":\"invalid_client\"," + + "\"error_description\":\"Client authentication failed\"}"; + // Malformed JSON + String malformedJson = "{not valid json}"; + + // Mock DatabricksOAuthTokenSource for control plane token + Token cpToken = new Token(TEST_CP_TOKEN, TEST_TOKEN_TYPE, LocalDateTime.now().plusMinutes(10)); + DatabricksOAuthTokenSource mockCpTokenSource = mock(DatabricksOAuthTokenSource.class); + when(mockCpTokenSource.getToken()).thenReturn(cpToken); + + // Mock HttpClient for success + HttpClient mockSuccessClient = mock(HttpClient.class); + when(mockSuccessClient.execute(any())) + .thenReturn(new Response(successJson, 200, "OK", new URL("https://test.databricks.com/"))); + + // Mock HttpClient for error response + HttpClient mockErrorClient = mock(HttpClient.class); + when(mockErrorClient.execute(any())) + .thenReturn( + new Response(errorJson, 400, "Bad Request", new URL("https://test.databricks.com/"))); + + // Mock HttpClient for malformed JSON + HttpClient mockMalformedClient = mock(HttpClient.class); + when(mockMalformedClient.execute(any())) + .thenReturn( + new Response(malformedJson, 200, "OK", new URL("https://test.databricks.com/"))); + + // Mock HttpClient for IOException + HttpClient mockIOExceptionClient = mock(HttpClient.class); + when(mockIOExceptionClient.execute(any())).thenThrow(new IOException("Network error")); + + return Stream.of( + Arguments.of( + "Success response", + mockCpTokenSource, + TEST_AUTH_DETAILS, + mockSuccessClient, + null, // No exception expected + TEST_DP_TOKEN, + TEST_TOKEN_TYPE, + TEST_REFRESH_TOKEN, + TEST_EXPIRES_IN), + Arguments.of( + "OAuth error response", + mockCpTokenSource, + TEST_AUTH_DETAILS, + mockErrorClient, + DatabricksException.class, + null, + null, + null, + 0), + Arguments.of( + "Malformed JSON response", + mockCpTokenSource, + TEST_AUTH_DETAILS, + mockMalformedClient, + DatabricksException.class, + null, + null, + null, + 0), + Arguments.of( + "IOException from HttpClient", + mockCpTokenSource, + TEST_AUTH_DETAILS, + mockIOExceptionClient, + DatabricksException.class, + null, + null, + null, + 0), + Arguments.of( + "Null cpTokenSource", + null, + TEST_AUTH_DETAILS, + mockSuccessClient, + NullPointerException.class, + null, + null, + null, + 0), + Arguments.of( + "Null authDetails", + mockCpTokenSource, + null, + mockSuccessClient, + NullPointerException.class, + null, + null, + null, + 0), + Arguments.of( + "Empty authDetails", + mockCpTokenSource, + "", + mockSuccessClient, + IllegalArgumentException.class, + null, + null, + null, + 0), + Arguments.of( + "Null httpClient", + mockCpTokenSource, + TEST_AUTH_DETAILS, + null, + NullPointerException.class, + null, + null, + null, + 0)); + } + + @ParameterizedTest(name = "{0}") + @MethodSource("provideEndpointTokenScenarios") + void testEndpointTokenSource( + String testName, + DatabricksOAuthTokenSource cpTokenSource, + String authDetails, + HttpClient httpClient, + Class expectedException, + String expectedAccessToken, + String expectedTokenType, + String expectedRefreshToken, + int expectedExpiresIn) { + if (expectedException != null) { + assertThrows( + expectedException, + () -> { + EndpointTokenSource source = + new EndpointTokenSource(cpTokenSource, authDetails, httpClient); + source.getToken(); + }); + } else { + EndpointTokenSource source = new EndpointTokenSource(cpTokenSource, authDetails, httpClient); + Token token = source.getToken(); + assertNotNull(token); + assertEquals(expectedAccessToken, token.getAccessToken()); + assertEquals(expectedTokenType, token.getTokenType()); + assertEquals(expectedRefreshToken, token.getRefreshToken()); + // Allow a few seconds of clock skew for expiry + assertTrue(token.isValid()); + assertTrue(token.getAccessToken().length() > 0); + } + } +} diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/TokenEndpointClientTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/TokenEndpointClientTest.java new file mode 100644 index 000000000..581c90143 --- /dev/null +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/TokenEndpointClientTest.java @@ -0,0 +1,171 @@ +package com.databricks.sdk.core.oauth; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; + +import com.databricks.sdk.core.DatabricksException; +import com.databricks.sdk.core.http.FormRequest; +import com.databricks.sdk.core.http.HttpClient; +import com.databricks.sdk.core.http.Response; +import java.io.IOException; +import java.net.URL; +import java.util.HashMap; +import java.util.Map; +import java.util.stream.Stream; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +class TokenEndpointClientTest { + private static final String TOKEN_ENDPOINT_URL = "https://test.databricks.com/oauth/token"; + private static final Map PARAMS = new HashMap<>(); + + private static Stream provideTokenScenarios() throws Exception { + // Success response JSON + String successJson = + "{" + + "\"access_token\":\"test-access-token\"," + + "\"token_type\":\"Bearer\"," + + "\"expires_in\":3600," + + "\"refresh_token\":\"test-refresh-token\"}"; + // Error response JSON + String errorJson = + "{" + + "\"error\":\"invalid_client\"," + + "\"error_description\":\"Client authentication failed\"}"; + // Malformed JSON + String malformedJson = "{not valid json}"; + + // Mock HttpClient for success + HttpClient mockSuccessClient = mock(HttpClient.class); + when(mockSuccessClient.execute(any(FormRequest.class))) + .thenReturn(new Response(successJson, 200, "OK", new URL("https://test.databricks.com/"))); + + // Mock HttpClient for error response + HttpClient mockErrorClient = mock(HttpClient.class); + when(mockErrorClient.execute(any(FormRequest.class))) + .thenReturn( + new Response(errorJson, 400, "Bad Request", new URL("https://test.databricks.com/"))); + + // Mock HttpClient for malformed JSON + HttpClient mockMalformedClient = mock(HttpClient.class); + when(mockMalformedClient.execute(any(FormRequest.class))) + .thenReturn( + new Response(malformedJson, 200, "OK", new URL("https://test.databricks.com/"))); + + // Mock HttpClient for IOException + HttpClient mockIOExceptionClient = mock(HttpClient.class); + when(mockIOExceptionClient.execute(any(FormRequest.class))) + .thenThrow(new IOException("Network error")); + + return Stream.of( + Arguments.of( + "Success response", + mockSuccessClient, + TOKEN_ENDPOINT_URL, + PARAMS, + null, // No exception expected + "test-access-token", + "Bearer", + 3600, + "test-refresh-token"), + Arguments.of( + "OAuth error response", + mockErrorClient, + TOKEN_ENDPOINT_URL, + PARAMS, + DatabricksException.class, + null, + null, + 0, + null), + Arguments.of( + "Malformed JSON response", + mockMalformedClient, + TOKEN_ENDPOINT_URL, + PARAMS, + DatabricksException.class, + null, + null, + 0, + null), + Arguments.of( + "IOException from HttpClient", + mockIOExceptionClient, + TOKEN_ENDPOINT_URL, + PARAMS, + DatabricksException.class, + null, + null, + 0, + null), + Arguments.of( + "Null HttpClient", + null, + TOKEN_ENDPOINT_URL, + PARAMS, + NullPointerException.class, + null, + null, + 0, + null), + Arguments.of( + "Null tokenEndpointUrl", + mockSuccessClient, + null, + PARAMS, + NullPointerException.class, + null, + null, + 0, + null), + Arguments.of( + "Empty tokenEndpointUrl", + mockSuccessClient, + "", + PARAMS, + IllegalArgumentException.class, + null, + null, + 0, + null), + Arguments.of( + "Null params", + mockSuccessClient, + TOKEN_ENDPOINT_URL, + null, + NullPointerException.class, + null, + null, + 0, + null)); + } + + @ParameterizedTest(name = "{0}") + @MethodSource("provideTokenScenarios") + void testRequestToken( + String testName, + HttpClient httpClient, + String tokenEndpointUrl, + Map params, + Class expectedException, + String expectedAccessToken, + String expectedTokenType, + int expectedExpiresIn, + String expectedRefreshToken) { + if (expectedException != null) { + assertThrows( + expectedException, + () -> TokenEndpointClient.requestToken(httpClient, tokenEndpointUrl, params)); + } else { + OAuthResponse response = + TokenEndpointClient.requestToken(httpClient, tokenEndpointUrl, params); + assertNotNull(response); + assertEquals(expectedAccessToken, response.getAccessToken()); + assertEquals(expectedTokenType, response.getTokenType()); + assertEquals(expectedExpiresIn, response.getExpiresIn()); + assertEquals(expectedRefreshToken, response.getRefreshToken()); + } + } +} From 34f6f4871bfe8cfb07868dea2c2b264c9139ab5c Mon Sep 17 00:00:00 2001 From: emmyzhou-db Date: Mon, 19 May 2025 12:12:13 +0000 Subject: [PATCH 2/2] Small refactor --- .../sdk/core/DefaultCredentialsProvider.java | 13 +- .../oauth/DatabricksOAuthTokenSource.java | 89 +--- .../oauth/DatabricksOAuthTokenSourceTest.java | 395 ++++++++---------- 3 files changed, 195 insertions(+), 302 deletions(-) diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DefaultCredentialsProvider.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DefaultCredentialsProvider.java index 0e4723f36..f72aa435b 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DefaultCredentialsProvider.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DefaultCredentialsProvider.java @@ -34,14 +34,6 @@ public NamedIDTokenSource(String name, IDTokenSource idTokenSource) { this.name = name; this.idTokenSource = idTokenSource; } - - public String getName() { - return name; - } - - public IDTokenSource getIdTokenSource() { - return idTokenSource; - } } public DefaultCredentialsProvider() {} @@ -143,14 +135,13 @@ private void addOIDCCredentialsProviders(DatabricksConfig config) { config.getClientId(), config.getHost(), endpoints, - namedIdTokenSource.getIdTokenSource(), + namedIdTokenSource.idTokenSource, config.getHttpClient()) .audience(config.getTokenAudience()) .accountId(config.isAccountClient() ? config.getAccountId() : null) .build(); - providers.add( - new TokenSourceCredentialsProvider(oauthTokenSource, namedIdTokenSource.getName())); + providers.add(new TokenSourceCredentialsProvider(oauthTokenSource, namedIdTokenSource.name)); } } diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java index e642159c0..f16ae2aed 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java @@ -1,15 +1,12 @@ package com.databricks.sdk.core.oauth; import com.databricks.sdk.core.DatabricksException; -import com.databricks.sdk.core.http.FormRequest; import com.databricks.sdk.core.http.HttpClient; -import com.databricks.sdk.core.http.Response; -import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.base.Strings; -import java.io.IOException; import java.time.LocalDateTime; import java.util.HashMap; import java.util.Map; +import java.util.Objects; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -44,8 +41,6 @@ public class DatabricksOAuthTokenSource extends RefreshableTokenSource { private static final String SCOPE_PARAM = "scope"; private static final String CLIENT_ID_PARAM = "client_id"; - private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); - private DatabricksOAuthTokenSource(Builder builder) { this.clientId = builder.clientId; this.host = builder.host; @@ -123,44 +118,29 @@ public DatabricksOAuthTokenSource build() { } } - /** - * Validates that a value is non-null for required fields. If the value is a string, it also - * checks that it is non-empty. - * - * @param value The value to validate. - * @param fieldName The name of the field being validated. - * @throws IllegalArgumentException when the value is null or an empty string. - */ - private static void validate(Object value, String fieldName) { - if (value == null) { - LOG.error("Required parameter '{}' is null", fieldName); - throw new IllegalArgumentException( - String.format("Required parameter '%s' cannot be null", fieldName)); - } - if (value instanceof String && ((String) value).isEmpty()) { - LOG.error("Required parameter '{}' is empty", fieldName); - throw new IllegalArgumentException( - String.format("Required parameter '%s' cannot be empty", fieldName)); - } - } - /** * Retrieves an OAuth token by exchanging an ID token. Implements the OAuth token exchange flow to * obtain an access token. * * @return A Token containing the access token and related information. * @throws DatabricksException when the token exchange fails. - * @throws IllegalArgumentException when there is an error code in the response or when required - * parameters are missing. + * @throws IllegalArgumentException when the required string parameters are empty. + * @throws NullPointerException when any of the required parameters are null. */ @Override public Token refresh() { - // Validate all required parameters - validate(clientId, "ClientID"); - validate(host, "Host"); - validate(endpoints, "Endpoints"); - validate(idTokenSource, "IDTokenSource"); - validate(httpClient, "HttpClient"); + Objects.requireNonNull(clientId, "ClientID cannot be null"); + Objects.requireNonNull(host, "Host cannot be null"); + Objects.requireNonNull(endpoints, "Endpoints cannot be null"); + Objects.requireNonNull(idTokenSource, "IDTokenSource cannot be null"); + Objects.requireNonNull(httpClient, "HttpClient cannot be null"); + + if (clientId.isEmpty()) { + throw new IllegalArgumentException("ClientID cannot be empty"); + } + if (host.isEmpty()) { + throw new IllegalArgumentException("Host cannot be empty"); + } String effectiveAudience = determineAudience(); IDToken idToken = idTokenSource.getIDToken(effectiveAudience); @@ -172,47 +152,20 @@ public Token refresh() { params.put(SCOPE_PARAM, SCOPE); params.put(CLIENT_ID_PARAM, clientId); - Response rawResponse; - try { - rawResponse = httpClient.execute(new FormRequest(endpoints.getTokenEndpoint(), params)); - } catch (IOException e) { - LOG.error( - "Failed to exchange ID token for access token at {}: {}", - endpoints.getTokenEndpoint(), - e.getMessage(), - e); - throw new DatabricksException( - String.format( - "Failed to exchange ID token for access token at %s: %s", - endpoints.getTokenEndpoint(), e.getMessage()), - e); - } - OAuthResponse response; try { - response = OBJECT_MAPPER.readValue(rawResponse.getBody(), OAuthResponse.class); - } catch (IOException e) { + response = + TokenEndpointClient.requestToken(this.httpClient, endpoints.getTokenEndpoint(), params); + } catch (DatabricksException e) { LOG.error( - "Failed to parse OAuth response from token endpoint {}: {}", + "OAuth token exchange failed for client ID '{}' at {}: {}", + this.clientId, endpoints.getTokenEndpoint(), e.getMessage(), e); - throw new DatabricksException( - String.format( - "Failed to parse OAuth response from token endpoint %s: %s", - endpoints.getTokenEndpoint(), e.getMessage())); + throw e; } - if (response.getErrorCode() != null) { - LOG.error( - "Token exchange failed with error: {} - {}", - response.getErrorCode(), - response.getErrorSummary()); - throw new IllegalArgumentException( - String.format( - "Token exchange failed with error: %s - %s", - response.getErrorCode(), response.getErrorSummary())); - } LocalDateTime expiry = LocalDateTime.now().plusSeconds(response.getExpiresIn()); return new Token( response.getAccessToken(), response.getTokenType(), response.getRefreshToken(), expiry); diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSourceTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSourceTest.java index 8d7da8d3a..8217179f2 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSourceTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSourceTest.java @@ -15,7 +15,6 @@ import java.util.HashMap; import java.util.Map; import java.util.stream.Stream; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.MethodSource; import org.mockito.Mockito; @@ -35,45 +34,42 @@ class DatabricksOAuthTokenSourceTest { private static final String TEST_AUDIENCE = "test-audience"; private static final String TEST_ACCOUNT_ID = "test-account-id"; - // Error message constants - private static final String ERROR_NULL = "Required parameter '%s' cannot be null"; - private static final String ERROR_EMPTY = "Required parameter '%s' cannot be empty"; - - private IDTokenSource mockIdTokenSource; - - @BeforeEach - void setUp() { - mockIdTokenSource = Mockito.mock(IDTokenSource.class); - IDToken idToken = new IDToken(TEST_ID_TOKEN); - when(mockIdTokenSource.getIDToken(any())).thenReturn(idToken); - } - /** * Test case data for parameterized token source tests. Each case defines a specific OAuth token * exchange scenario. */ private static class TestCase { final String name; // Descriptive name of the test case + final String clientId; // Client ID to use + final String host; // Host to use + final OpenIDConnectEndpoints endpoints; // OIDC endpoints + final IDTokenSource idTokenSource; // ID token source + final HttpClient httpClient; // HTTP client final String audience; // Custom audience value if provided final String accountId; // Account ID if provided final String expectedAudience; // Expected audience used in token exchange - final HttpClient mockHttpClient; // Pre-configured mock HTTP client final Class expectedException; // Expected exception type if any TestCase( String name, + String clientId, + String host, + OpenIDConnectEndpoints endpoints, + IDTokenSource idTokenSource, + HttpClient httpClient, String audience, String accountId, String expectedAudience, - int statusCode, - Object responseBody, - HttpClient mockHttpClient, Class expectedException) { this.name = name; + this.clientId = clientId; + this.host = host; + this.endpoints = endpoints; + this.idTokenSource = idTokenSource; + this.httpClient = httpClient; this.audience = audience; this.accountId = accountId; this.expectedAudience = expectedAudience; - this.mockHttpClient = mockHttpClient; this.expectedException = expectedException; } @@ -87,20 +83,27 @@ public String toString() { * Provides test cases for OAuth token exchange scenarios. Includes success cases with different * audience configurations and various error cases. */ - private static Stream provideTestCases() { - try { - // Success response with valid token data - Map successResponse = new HashMap<>(); - successResponse.put("access_token", TOKEN); - successResponse.put("token_type", TOKEN_TYPE); - successResponse.put("refresh_token", REFRESH_TOKEN); - successResponse.put("expires_in", EXPIRES_IN); + private static Stream provideTestCases() throws MalformedURLException { + // Create valid components for reuse + OpenIDConnectEndpoints testEndpoints = + new OpenIDConnectEndpoints(TEST_TOKEN_ENDPOINT, TEST_AUTHORIZATION_ENDPOINT); + IDTokenSource testIdTokenSource = Mockito.mock(IDTokenSource.class); + IDToken idToken = new IDToken(TEST_ID_TOKEN); + when(testIdTokenSource.getIDToken(any())).thenReturn(idToken); + + // Create success response for token exchange tests + Map successResponse = new HashMap<>(); + successResponse.put("access_token", TOKEN); + successResponse.put("token_type", TOKEN_TYPE); + successResponse.put("refresh_token", REFRESH_TOKEN); + successResponse.put("expires_in", EXPIRES_IN); - // Error response for invalid requests - Map errorResponse = new HashMap<>(); - errorResponse.put("error", "invalid_request"); - errorResponse.put("error_description", "Invalid client ID"); + // Create error response for invalid requests + Map errorResponse = new HashMap<>(); + errorResponse.put("error", "invalid_request"); + errorResponse.put("error_description", "Invalid client ID"); + try { ObjectMapper mapper = new ObjectMapper(); final String errorJson = mapper.writeValueAsString(errorResponse); final String successJson = mapper.writeValueAsString(successResponse); @@ -115,71 +118,162 @@ private static Stream provideTestCases() { FormRequest expectedRequest = new FormRequest(TEST_TOKEN_ENDPOINT, formParams); return Stream.of( - // Success cases with different audience configurations + // Token exchange test cases new TestCase( "Default audience from token endpoint", + TEST_CLIENT_ID, + TEST_HOST, + testEndpoints, + testIdTokenSource, + createMockHttpClient(expectedRequest, 200, successJson), null, null, TEST_TOKEN_ENDPOINT, - 200, - successResponse, - createMockHttpClient(expectedRequest, 200, successJson), null), new TestCase( "Custom audience provided", + TEST_CLIENT_ID, + TEST_HOST, + testEndpoints, + testIdTokenSource, + createMockHttpClient(expectedRequest, 200, successJson), TEST_AUDIENCE, null, TEST_AUDIENCE, - 200, - successResponse, - createMockHttpClient(expectedRequest, 200, successJson), null), new TestCase( "Custom audience takes precedence over account ID", + TEST_CLIENT_ID, + TEST_HOST, + testEndpoints, + testIdTokenSource, + createMockHttpClient(expectedRequest, 200, successJson), TEST_AUDIENCE, TEST_ACCOUNT_ID, TEST_AUDIENCE, - 200, - successResponse, - createMockHttpClient(expectedRequest, 200, successJson), null), new TestCase( "Account ID used as audience when no custom audience", + TEST_CLIENT_ID, + TEST_HOST, + testEndpoints, + testIdTokenSource, + createMockHttpClient(expectedRequest, 200, successJson), null, TEST_ACCOUNT_ID, TEST_ACCOUNT_ID, - 200, - successResponse, - createMockHttpClient(expectedRequest, 200, successJson), null), - // Error cases new TestCase( "Invalid request returns 400", + TEST_CLIENT_ID, + TEST_HOST, + testEndpoints, + testIdTokenSource, + createMockHttpClient(expectedRequest, 400, errorJson), null, null, TEST_TOKEN_ENDPOINT, - 400, - errorJson, - createMockHttpClient(expectedRequest, 400, errorJson), - IllegalArgumentException.class), + DatabricksException.class), new TestCase( "Network error during token exchange", + TEST_CLIENT_ID, + TEST_HOST, + testEndpoints, + testIdTokenSource, + createMockHttpClientWithError(expectedRequest), null, null, TEST_TOKEN_ENDPOINT, - 0, - null, - createMockHttpClientWithError(expectedRequest), DatabricksException.class), new TestCase( "Invalid JSON response from server", + TEST_CLIENT_ID, + TEST_HOST, + testEndpoints, + testIdTokenSource, + createMockHttpClient(expectedRequest, 200, "invalid json"), null, null, TEST_TOKEN_ENDPOINT, - 200, - "invalid json", - createMockHttpClient(expectedRequest, 200, "invalid json"), - DatabricksException.class)); + DatabricksException.class), + // Parameter validation test cases + new TestCase( + "Null client ID", + null, + TEST_HOST, + testEndpoints, + testIdTokenSource, + createMockHttpClient(expectedRequest, 200, successJson), + null, + null, + null, + NullPointerException.class), + new TestCase( + "Empty client ID", + "", + TEST_HOST, + testEndpoints, + testIdTokenSource, + createMockHttpClient(expectedRequest, 200, successJson), + null, + null, + null, + IllegalArgumentException.class), + new TestCase( + "Null host", + TEST_CLIENT_ID, + null, + testEndpoints, + testIdTokenSource, + createMockHttpClient(expectedRequest, 200, successJson), + null, + null, + null, + NullPointerException.class), + new TestCase( + "Empty host", + TEST_CLIENT_ID, + "", + testEndpoints, + testIdTokenSource, + createMockHttpClient(expectedRequest, 200, successJson), + null, + null, + null, + IllegalArgumentException.class), + new TestCase( + "Null endpoints", + TEST_CLIENT_ID, + TEST_HOST, + null, + testIdTokenSource, + createMockHttpClient(expectedRequest, 200, successJson), + null, + null, + null, + NullPointerException.class), + new TestCase( + "Null IDTokenSource", + TEST_CLIENT_ID, + TEST_HOST, + testEndpoints, + null, + createMockHttpClient(expectedRequest, 200, successJson), + null, + null, + null, + NullPointerException.class), + new TestCase( + "Null HttpClient", + TEST_CLIENT_ID, + TEST_HOST, + testEndpoints, + testIdTokenSource, + null, + null, + null, + null, + NullPointerException.class)); } catch (IOException e) { throw new RuntimeException("Failed to create test cases", e); } @@ -212,179 +306,34 @@ private static HttpClient createMockHttpClientWithError(FormRequest expectedRequ * Tests OAuth token exchange with various configurations and error scenarios. Verifies correct * audience selection, token exchange, and error handling. */ - @ParameterizedTest(name = "testTokenSource: {arguments}") + @ParameterizedTest(name = "{0}") @MethodSource("provideTestCases") void testTokenSource(TestCase testCase) { - try { - // Create token source with test configuration - OpenIDConnectEndpoints endpoints = - new OpenIDConnectEndpoints(TEST_TOKEN_ENDPOINT, TEST_AUTHORIZATION_ENDPOINT); - - DatabricksOAuthTokenSource.Builder builder = - new DatabricksOAuthTokenSource.Builder( - TEST_CLIENT_ID, TEST_HOST, endpoints, mockIdTokenSource, testCase.mockHttpClient); - - builder.audience(testCase.audience).accountId(testCase.accountId); - - DatabricksOAuthTokenSource tokenSource = builder.build(); - - if (testCase.expectedException != null) { - assertThrows(testCase.expectedException, () -> tokenSource.getToken()); - } else { - // Verify successful token exchange - Token token = tokenSource.getToken(); - assertEquals(TOKEN, token.getAccessToken()); - assertEquals(TOKEN_TYPE, token.getTokenType()); - assertEquals(REFRESH_TOKEN, token.getRefreshToken()); - assertFalse(token.isExpired()); + DatabricksOAuthTokenSource.Builder builder = + new DatabricksOAuthTokenSource.Builder( + testCase.clientId, + testCase.host, + testCase.endpoints, + testCase.idTokenSource, + testCase.httpClient); - // Verify correct audience was used - verify(mockIdTokenSource).getIDToken(testCase.expectedAudience); - } - } catch (IOException e) { - throw new RuntimeException("Test failed", e); - } - } + builder.audience(testCase.audience); + builder.accountId(testCase.accountId); - /** - * Test case data for parameter validation tests. Each case defines a specific validation - * scenario. - */ - private static class ValidationTestCase { - final String name; - final String clientId; - final String host; - final OpenIDConnectEndpoints endpoints; - final IDTokenSource idTokenSource; - final HttpClient httpClient; - final String expectedFieldName; - final boolean isNullTest; + DatabricksOAuthTokenSource tokenSource = builder.build(); - ValidationTestCase( - String name, - String clientId, - String host, - OpenIDConnectEndpoints endpoints, - IDTokenSource idTokenSource, - HttpClient httpClient, - String expectedFieldName, - boolean isNullTest) { - this.name = name; - this.clientId = clientId; - this.host = host; - this.endpoints = endpoints; - this.idTokenSource = idTokenSource; - this.httpClient = httpClient; - this.expectedFieldName = expectedFieldName; - this.isNullTest = isNullTest; - } + if (testCase.expectedException != null) { + assertThrows(testCase.expectedException, () -> tokenSource.getToken()); + } else { + // Verify successful token exchange + Token token = tokenSource.getToken(); + assertEquals(TOKEN, token.getAccessToken()); + assertEquals(TOKEN_TYPE, token.getTokenType()); + assertEquals(REFRESH_TOKEN, token.getRefreshToken()); + assertFalse(token.isExpired()); - @Override - public String toString() { - return name; + // Verify correct audience was used + verify(testCase.idTokenSource, atLeastOnce()).getIDToken(testCase.expectedAudience); } } - - private static Stream provideValidationTestCases() - throws MalformedURLException { - OpenIDConnectEndpoints validEndpoints = - new OpenIDConnectEndpoints(TEST_TOKEN_ENDPOINT, TEST_AUTHORIZATION_ENDPOINT); - HttpClient validHttpClient = Mockito.mock(HttpClient.class); - IDTokenSource validIdTokenSource = Mockito.mock(IDTokenSource.class); - - return Stream.of( - // Client ID validation - new ValidationTestCase( - "Null client ID", - null, - TEST_HOST, - validEndpoints, - validIdTokenSource, - validHttpClient, - "ClientID", - true), - new ValidationTestCase( - "Empty client ID", - "", - TEST_HOST, - validEndpoints, - validIdTokenSource, - validHttpClient, - "ClientID", - false), - // Host validation - new ValidationTestCase( - "Null host", - TEST_CLIENT_ID, - null, - validEndpoints, - validIdTokenSource, - validHttpClient, - "Host", - true), - new ValidationTestCase( - "Empty host", - TEST_CLIENT_ID, - "", - validEndpoints, - validIdTokenSource, - validHttpClient, - "Host", - false), - // Endpoints validation - new ValidationTestCase( - "Null endpoints", - TEST_CLIENT_ID, - TEST_HOST, - null, - validIdTokenSource, - validHttpClient, - "Endpoints", - true), - // IDTokenSource validation - new ValidationTestCase( - "Null IDTokenSource", - TEST_CLIENT_ID, - TEST_HOST, - validEndpoints, - null, - validHttpClient, - "IDTokenSource", - true), - // HttpClient validation - new ValidationTestCase( - "Null HttpClient", - TEST_CLIENT_ID, - TEST_HOST, - validEndpoints, - validIdTokenSource, - null, - "HttpClient", - true)); - } - - /** - * Tests validation of required fields in the token source using parameterized test cases. - * Verifies that null or empty values for required fields cause getToken() to throw - * IllegalArgumentException with specific error messages. - */ - @ParameterizedTest(name = "testParameterValidation: {0}") - @MethodSource("provideValidationTestCases") - void testParameterValidation(ValidationTestCase testCase) { - DatabricksOAuthTokenSource tokenSource = - new DatabricksOAuthTokenSource.Builder( - testCase.clientId, - testCase.host, - testCase.endpoints, - testCase.idTokenSource, - testCase.httpClient) - .build(); - - IllegalArgumentException exception = - assertThrows(IllegalArgumentException.class, () -> tokenSource.getToken()); - - String expectedMessage = - String.format(testCase.isNullTest ? ERROR_NULL : ERROR_EMPTY, testCase.expectedFieldName); - assertEquals(expectedMessage, exception.getMessage()); - } }