From 6f8da14508bfefeea6df445e305773304a1572a9 Mon Sep 17 00:00:00 2001 From: emmyzhou-db Date: Wed, 30 Apr 2025 15:34:24 +0000 Subject: [PATCH 1/5] Add DatabricksOAuthTokenSource --- .../oauth/DatabricksOAuthTokenSource.java | 200 +++++++++++ .../oauth/DatabricksOAuthTokenSourceTest.java | 333 ++++++++++++++++++ 2 files changed, 533 insertions(+) create mode 100644 databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java create mode 100644 databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSourceTest.java diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java new file mode 100644 index 000000000..307fc2cb2 --- /dev/null +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java @@ -0,0 +1,200 @@ +package com.databricks.sdk.core.oauth; + +import com.databricks.sdk.core.DatabricksException; +import com.databricks.sdk.core.http.FormRequest; +import com.databricks.sdk.core.http.HttpClient; +import com.databricks.sdk.core.http.Response; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.base.Strings; +import java.io.IOException; +import java.time.LocalDateTime; +import java.util.HashMap; +import java.util.Map; + +/** + * Implementation of TokenSource that handles OAuth token exchange for Databricks authentication. + * This class manages the OAuth token exchange flow using ID tokens to obtain access tokens. + */ +public class DatabricksOAuthTokenSource implements TokenSource { + /** OAuth client ID used for token exchange */ + private final String clientId; + /** Databricks account ID, used as audience if provided */ + private final String accountId; + /** OpenID Connect endpoints configuration */ + private final OpenIDConnectEndpoints endpoints; + /** Custom audience value for token exchange */ + private final String audience; + /** Source of ID tokens used in token exchange */ + private final IDTokenSource idTokenSource; + /** HTTP client for making token exchange requests */ + private final HttpClient httpClient; + + private DatabricksOAuthTokenSource(Builder builder) { + this.clientId = builder.clientId; + this.accountId = builder.accountId; + this.endpoints = builder.endpoints; + this.audience = builder.audience; + this.idTokenSource = builder.idTokenSource; + this.httpClient = builder.httpClient; + } + + /** + * Builder class for constructing DatabricksOAuthTokenSource instances. Provides a fluent + * interface for setting required and optional parameters. + */ + public static class Builder { + private final String clientId; + private final String host; + private final OpenIDConnectEndpoints endpoints; + private final IDTokenSource idTokenSource; + private final HttpClient httpClient; + private String accountId; + private String audience; + + /** + * Validates that a value is non-empty and non-null for required fields. + * + * @param value The value to validate + * @param fieldName The name of the field being validated + * @throws IllegalArgumentException if validation fails + */ + private static void validate(Object value, String fieldName) { + if (value == null) { + throw new IllegalArgumentException(fieldName + " must be non-null"); + } + if (value instanceof String && ((String) value).isEmpty()) { + throw new IllegalArgumentException(fieldName + " must be non-empty"); + } + } + + /** + * Creates a new Builder with required parameters. + * + * @param clientId OAuth client ID + * @param host Databricks host URL + * @param endpoints OpenID Connect endpoints configuration + * @param idTokenSource Source of ID tokens + * @param httpClient HTTP client for making requests + */ + public Builder( + String clientId, + String host, + OpenIDConnectEndpoints endpoints, + IDTokenSource idTokenSource, + HttpClient httpClient) { + validate(clientId, "ClientID"); + validate(host, "Host"); + validate(endpoints, "Endpoints"); + validate(idTokenSource, "IDTokenSource"); + validate(httpClient, "HttpClient"); + + this.clientId = clientId; + this.host = host; + this.endpoints = endpoints; + this.idTokenSource = idTokenSource; + this.httpClient = httpClient; + } + + /** + * Sets the Databricks account ID. + * + * @param accountId The account ID + * @return This builder instance + */ + public Builder accountId(String accountId) { + validate(accountId, "AccountID"); + this.accountId = accountId; + return this; + } + + /** + * Sets a custom audience value for token exchange. + * + * @param audience The audience value + * @return This builder instance + */ + public Builder audience(String audience) { + validate(audience, "Audience"); + this.audience = audience; + return this; + } + + /** + * Builds a new DatabricksOAuthTokenSource instance. + * + * @return A new DatabricksOAuthTokenSource + */ + public DatabricksOAuthTokenSource build() { + return new DatabricksOAuthTokenSource(this); + } + } + + /** + * Retrieves an OAuth token by exchanging an ID token. Implements the OAuth token exchange flow to + * obtain an access token. + * + * @return A Token containing the access token and related information + * @throws DatabricksException if token exchange fails + */ + @Override + public Token getToken() { + String effectiveAudience = determineAudience(); + IDToken idToken = idTokenSource.getIDToken(effectiveAudience); + + Map params = new HashMap<>(); + params.put("grant_type", "urn:ietf:params:oauth:grant-type:token-exchange"); + params.put("subject_token", idToken.getValue()); + params.put("subject_token_type", "urn:ietf:params:oauth:token-type:jwt"); + params.put("scope", "all-apis"); + params.put("client_id", clientId); + + Response rawResponse; + try { + rawResponse = httpClient.execute(new FormRequest(endpoints.getTokenEndpoint(), params)); + } catch (IOException e) { + throw new DatabricksException( + String.format( + "Failed to exchange ID token for access token at %s: %s", + endpoints.getTokenEndpoint(), e.getMessage()), + e); + } + + OAuthResponse response; + try { + response = new ObjectMapper().readValue(rawResponse.getBody(), OAuthResponse.class); + } catch (IOException e) { + throw new DatabricksException( + String.format( + "Failed to parse OAuth response from token endpoint %s: %s", + endpoints.getTokenEndpoint(), e.getMessage())); + } + + if (response.getErrorCode() != null) { + throw new IllegalArgumentException( + String.format( + "Token exchange failed with error: %s - %s", + response.getErrorCode(), response.getErrorSummary())); + } + LocalDateTime expiry = LocalDateTime.now().plusSeconds(response.getExpiresIn()); + return new Token( + response.getAccessToken(), response.getTokenType(), response.getRefreshToken(), expiry); + } + + /** + * Determines the appropriate audience value for token exchange. Uses the following precedence: 1. + * Custom audience if provided 2. Account ID if provided 3. Token endpoint URL as fallback + * + * @return The determined audience value + */ + private String determineAudience() { + if (!Strings.isNullOrEmpty(audience)) { + return audience; + } + + if (!Strings.isNullOrEmpty(accountId)) { + return accountId; + } + + return endpoints.getTokenEndpoint(); + } +} diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSourceTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSourceTest.java new file mode 100644 index 000000000..3ffc9e71a --- /dev/null +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSourceTest.java @@ -0,0 +1,333 @@ +package com.databricks.sdk.core.oauth; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; + +import com.databricks.sdk.core.DatabricksException; +import com.databricks.sdk.core.http.FormRequest; +import com.databricks.sdk.core.http.HttpClient; +import com.databricks.sdk.core.http.Response; +import com.fasterxml.jackson.databind.ObjectMapper; +import java.io.IOException; +import java.net.URL; +import java.util.HashMap; +import java.util.Map; +import java.util.stream.Stream; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.ArgumentCaptor; +import org.mockito.Mockito; + +class DatabricksOAuthTokenSourceTest { + // Test constants + private static final String TOKEN = "test-access-token"; + private static final String TOKEN_TYPE = "Bearer"; + private static final String REFRESH_TOKEN = "test-refresh-token"; + private static final int EXPIRES_IN = 3600; + + private static final String TEST_HOST = "https://test.databricks.com"; + private static final String TEST_TOKEN_ENDPOINT = TEST_HOST + "/oidc/v1/token"; + private static final String TEST_AUTHORIZATION_ENDPOINT = TEST_HOST + "/authorize"; + private static final String TEST_CLIENT_ID = "test-client-id"; + private static final String TEST_ID_TOKEN = "test-id-token"; + private static final String TEST_AUDIENCE = "test-audience"; + private static final String TEST_ACCOUNT_ID = "test-account-id"; + + private IDTokenSource mockIdTokenSource; + + @BeforeEach + void setUp() { + mockIdTokenSource = Mockito.mock(IDTokenSource.class); + IDToken idToken = new IDToken(TEST_ID_TOKEN); + when(mockIdTokenSource.getIDToken(any())).thenReturn(idToken); + } + + /** + * Test case data for parameterized token source tests. Each case defines a specific OAuth token + * exchange scenario. + */ + private static class TestCase { + final String name; // Descriptive name of the test case + final String audience; // Custom audience value if provided + final String accountId; // Account ID if provided + final String expectedAudience; // Expected audience used in token exchange + final boolean expectError; // Whether this case should result in an error + final int statusCode; // HTTP status code for the response + final String responseBody; // Response body from the token endpoint + + TestCase( + String name, + String audience, + String accountId, + String expectedAudience, + boolean expectError, + int statusCode, + String responseBody) { + this.name = name; + this.audience = audience; + this.accountId = accountId; + this.expectedAudience = expectedAudience; + this.expectError = expectError; + this.statusCode = statusCode; + this.responseBody = responseBody; + } + + @Override + public String toString() { + return name; + } + } + + /** + * Provides test cases for OAuth token exchange scenarios. Includes success cases with different + * audience configurations and various error cases. + */ + private static Stream provideTestCases() { + // Success response with valid token data + Map successResponse = new HashMap<>(); + successResponse.put("access_token", TOKEN); + successResponse.put("token_type", TOKEN_TYPE); + successResponse.put("refresh_token", REFRESH_TOKEN); + successResponse.put("expires_in", EXPIRES_IN); + + // Error response for invalid requests + Map errorResponse = new HashMap<>(); + errorResponse.put("error", "invalid_request"); + errorResponse.put("error_description", "Invalid client ID"); + + ObjectMapper mapper = new ObjectMapper(); + String successJson; + String errorJson; + try { + successJson = mapper.writeValueAsString(successResponse); + errorJson = mapper.writeValueAsString(errorResponse); + } catch (IOException e) { + throw new RuntimeException(e); + } + + return Stream.of( + // Success cases with different audience configurations + new TestCase( + "Default audience from token endpoint", + null, + null, + TEST_TOKEN_ENDPOINT, + false, + 200, + successJson), + new TestCase( + "Custom audience provided", + TEST_AUDIENCE, + null, + TEST_AUDIENCE, + false, + 200, + successJson), + new TestCase( + "Custom audience takes precedence over account ID", + TEST_AUDIENCE, + TEST_ACCOUNT_ID, + TEST_AUDIENCE, + false, + 200, + successJson), + new TestCase( + "Account ID used as audience when no custom audience", + null, + TEST_ACCOUNT_ID, + TEST_ACCOUNT_ID, + false, + 200, + successJson), + // Error cases + new TestCase( + "Invalid request returns 400", null, null, TEST_TOKEN_ENDPOINT, true, 400, errorJson), + new TestCase( + "Network error during token exchange", null, null, TEST_TOKEN_ENDPOINT, true, 0, null), + new TestCase( + "Invalid JSON response from server", + null, + null, + TEST_TOKEN_ENDPOINT, + true, + 200, + "invalid json")); + } + + /** + * Tests OAuth token exchange with various configurations and error scenarios. Verifies correct + * audience selection, token exchange, and error handling. + */ + @ParameterizedTest(name = "testTokenSource: {arguments}") + @MethodSource("provideTestCases") + void testTokenSource(TestCase testCase) throws IOException { + // Mock HTTP client with test case specific behavior + HttpClient mockHttpClient = Mockito.mock(HttpClient.class); + if (testCase.expectError) { + if (testCase.statusCode == 0) { + when(mockHttpClient.execute(any())).thenThrow(new IOException("Network error")); + } else { + when(mockHttpClient.execute(any())) + .thenReturn( + new Response( + testCase.responseBody, testCase.statusCode, "Bad Request", new URL(TEST_HOST))); + } + } else { + when(mockHttpClient.execute(any())) + .thenReturn(new Response(testCase.responseBody, new URL(TEST_HOST))); + } + + // Create token source with test configuration + OpenIDConnectEndpoints endpoints = + new OpenIDConnectEndpoints(TEST_TOKEN_ENDPOINT, TEST_AUTHORIZATION_ENDPOINT); + + DatabricksOAuthTokenSource.Builder builder = + new DatabricksOAuthTokenSource.Builder( + TEST_CLIENT_ID, TEST_HOST, endpoints, mockIdTokenSource, mockHttpClient); + + if (testCase.audience != null) { + builder.audience(testCase.audience); + } + if (testCase.accountId != null) { + builder.accountId(testCase.accountId); + } + + DatabricksOAuthTokenSource tokenSource = builder.build(); + + if (testCase.expectError) { + if (testCase.statusCode == 400) { + assertThrows(IllegalArgumentException.class, () -> tokenSource.getToken()); + } else { + assertThrows(DatabricksException.class, () -> tokenSource.getToken()); + } + } else { + // Verify successful token exchange + Token token = tokenSource.getToken(); + assertEquals(TOKEN, token.getAccessToken()); + assertEquals(TOKEN_TYPE, token.getTokenType()); + assertEquals(REFRESH_TOKEN, token.getRefreshToken()); + assertFalse(token.isExpired()); + + // Verify correct audience was used + verify(mockIdTokenSource).getIDToken(testCase.expectedAudience); + + // Verify token exchange request + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(FormRequest.class); + verify(mockHttpClient).execute(requestCaptor.capture()); + + FormRequest capturedRequest = requestCaptor.getValue(); + assertEquals(TEST_TOKEN_ENDPOINT, capturedRequest.getUrl()); + + // Verify request parameters + String body = capturedRequest.getBodyString(); + assertTrue(body.contains("client_id=" + TEST_CLIENT_ID)); + assertTrue(body.contains("subject_token=" + TEST_ID_TOKEN)); + assertTrue( + body.contains("subject_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Ajwt")); + assertTrue( + body.contains("grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Atoken-exchange")); + assertTrue(body.contains("scope=all-apis")); + } + } + + /** + * Tests validation of required fields in the token source builder. Verifies that null or empty + * values for required fields throw IllegalArgumentException. + */ + @Test + void testConstructorValidation() { + // Test null client ID + assertThrows( + IllegalArgumentException.class, + () -> { + new DatabricksOAuthTokenSource.Builder( + null, + TEST_HOST, + new OpenIDConnectEndpoints(TEST_TOKEN_ENDPOINT, TEST_AUTHORIZATION_ENDPOINT), + mockIdTokenSource, + Mockito.mock(HttpClient.class)) + .build(); + }); + + // Test empty client ID + assertThrows( + IllegalArgumentException.class, + () -> { + new DatabricksOAuthTokenSource.Builder( + "", + TEST_HOST, + new OpenIDConnectEndpoints(TEST_TOKEN_ENDPOINT, TEST_AUTHORIZATION_ENDPOINT), + mockIdTokenSource, + Mockito.mock(HttpClient.class)) + .build(); + }); + + // Test null host + assertThrows( + IllegalArgumentException.class, + () -> { + new DatabricksOAuthTokenSource.Builder( + TEST_CLIENT_ID, + null, + new OpenIDConnectEndpoints(TEST_TOKEN_ENDPOINT, TEST_AUTHORIZATION_ENDPOINT), + mockIdTokenSource, + Mockito.mock(HttpClient.class)) + .build(); + }); + + // Test empty host + assertThrows( + IllegalArgumentException.class, + () -> { + new DatabricksOAuthTokenSource.Builder( + TEST_CLIENT_ID, + "", + new OpenIDConnectEndpoints(TEST_TOKEN_ENDPOINT, TEST_AUTHORIZATION_ENDPOINT), + mockIdTokenSource, + Mockito.mock(HttpClient.class)) + .build(); + }); + + // Test null endpoints + assertThrows( + IllegalArgumentException.class, + () -> { + new DatabricksOAuthTokenSource.Builder( + TEST_CLIENT_ID, + TEST_HOST, + null, + mockIdTokenSource, + Mockito.mock(HttpClient.class)) + .build(); + }); + + // Test null IDTokenSource + assertThrows( + IllegalArgumentException.class, + () -> { + new DatabricksOAuthTokenSource.Builder( + TEST_CLIENT_ID, + TEST_HOST, + new OpenIDConnectEndpoints(TEST_TOKEN_ENDPOINT, TEST_AUTHORIZATION_ENDPOINT), + null, + Mockito.mock(HttpClient.class)) + .build(); + }); + + // Test null HttpClient + assertThrows( + IllegalArgumentException.class, + () -> { + new DatabricksOAuthTokenSource.Builder( + TEST_CLIENT_ID, + TEST_HOST, + new OpenIDConnectEndpoints(TEST_TOKEN_ENDPOINT, TEST_AUTHORIZATION_ENDPOINT), + mockIdTokenSource, + null) + .build(); + }); + } +} From d210a605526d2ebf76288e0b833fe358921bc6f6 Mon Sep 17 00:00:00 2001 From: emmyzhou-db Date: Thu, 1 May 2025 12:53:49 +0000 Subject: [PATCH 2/5] Add GithubIDTokenSource --- .../sdk/core/oauth/GithubIDTokenSource.java | 99 +++++++++++++++++++ 1 file changed, 99 insertions(+) create mode 100644 databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/GithubIDTokenSource.java diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/GithubIDTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/GithubIDTokenSource.java new file mode 100644 index 000000000..9f0d3349d --- /dev/null +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/GithubIDTokenSource.java @@ -0,0 +1,99 @@ +package com.databricks.sdk.core.oauth; + + +import com.databricks.sdk.core.DatabricksException; +import com.databricks.sdk.core.http.HttpClient; +import com.databricks.sdk.core.http.Request; +import com.databricks.sdk.core.http.Response; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.google.common.base.Strings; +import java.io.IOException; + + +/** +* GithubIDTokenSource retrieves JWT Tokens from GitHub Actions. +*/ +public class GithubIDTokenSource implements IDTokenSource { + private final String actionsIDTokenRequestURL; + private final String actionsIDTokenRequestToken; + private final HttpClient httpClient; + private final ObjectMapper mapper = new ObjectMapper(); + + + /** + * Constructs a new GithubIDTokenSource. + * + * @param actionsIDTokenRequestURL The URL to request the ID token from GitHub Actions. + * @param actionsIDTokenRequestToken The token used to authenticate the request. + * @param httpClient The HTTP client to use for making requests. + */ + public GithubIDTokenSource( + String actionsIDTokenRequestURL, + String actionsIDTokenRequestToken, + HttpClient httpClient) { + this.actionsIDTokenRequestURL = actionsIDTokenRequestURL; + this.actionsIDTokenRequestToken = actionsIDTokenRequestToken; + this.httpClient = httpClient; + } + + + @Override + public IDToken getIDToken(String audience) { + if (Strings.isNullOrEmpty(actionsIDTokenRequestURL)) { + throw new DatabricksException("missing ActionsIDTokenRequestURL"); + } + if (Strings.isNullOrEmpty(actionsIDTokenRequestToken)) { + throw new DatabricksException("missing ActionsIDTokenRequestToken"); + } + + + String requestUrl = actionsIDTokenRequestURL; + if (!Strings.isNullOrEmpty(audience)) { + requestUrl = String.format("%s&audience=%s", requestUrl, audience); + } + + + Request req = + new Request("GET", requestUrl) + .withHeader("Authorization", "Bearer " + actionsIDTokenRequestToken); + + + Response resp; + try { + resp = httpClient.execute(req); + } catch (IOException e) { + throw new DatabricksException( + "Failed to request ID token from " + requestUrl + ": " + e.getMessage(), e); + } + + + if (resp.getStatusCode() != 200) { + throw new DatabricksException( + "Failed to request ID token: status code " + + resp.getStatusCode() + + ", response body: " + + resp.getBody().toString()); + } + + + ObjectNode jsonResp; + try { + jsonResp = mapper.readValue(resp.getBody(), ObjectNode.class); + } catch (IOException e) { + throw new DatabricksException("Failed to request ID token: corrupted token: " + e.getMessage()); + } + + + String tokenValue = jsonResp.get("value").textValue(); + if (Strings.isNullOrEmpty(tokenValue)) { + throw new DatabricksException("Received empty ID token from GitHub Actions"); + } + + + return new IDToken(tokenValue); + } +} + + + From c66199dd44735d29a99995b2c11109f4f59bea5e Mon Sep 17 00:00:00 2001 From: emmyzhou-db Date: Fri, 2 May 2025 10:11:30 +0000 Subject: [PATCH 3/5] Add unit tests --- .../sdk/core/oauth/GithubIDTokenSource.java | 20 ++- .../core/oauth/GithubIDTokenSourceTest.java | 144 ++++++++++++++++++ 2 files changed, 156 insertions(+), 8 deletions(-) create mode 100644 databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/GithubIDTokenSourceTest.java diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/GithubIDTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/GithubIDTokenSource.java index 9f0d3349d..7bd7ff4a1 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/GithubIDTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/GithubIDTokenSource.java @@ -32,6 +32,15 @@ public GithubIDTokenSource( String actionsIDTokenRequestURL, String actionsIDTokenRequestToken, HttpClient httpClient) { + if (Strings.isNullOrEmpty(actionsIDTokenRequestURL)) { + throw new DatabricksException("Missing ActionsIDTokenRequestURL"); + } + if (Strings.isNullOrEmpty(actionsIDTokenRequestToken)) { + throw new DatabricksException("Missing ActionsIDTokenRequestToken"); + } + if (httpClient == null) { + throw new DatabricksException("HttpClient cannot be null"); + } this.actionsIDTokenRequestURL = actionsIDTokenRequestURL; this.actionsIDTokenRequestToken = actionsIDTokenRequestToken; this.httpClient = httpClient; @@ -40,14 +49,6 @@ public GithubIDTokenSource( @Override public IDToken getIDToken(String audience) { - if (Strings.isNullOrEmpty(actionsIDTokenRequestURL)) { - throw new DatabricksException("missing ActionsIDTokenRequestURL"); - } - if (Strings.isNullOrEmpty(actionsIDTokenRequestToken)) { - throw new DatabricksException("missing ActionsIDTokenRequestToken"); - } - - String requestUrl = actionsIDTokenRequestURL; if (!Strings.isNullOrEmpty(audience)) { requestUrl = String.format("%s&audience=%s", requestUrl, audience); @@ -84,6 +85,9 @@ public IDToken getIDToken(String audience) { throw new DatabricksException("Failed to request ID token: corrupted token: " + e.getMessage()); } + if (!jsonResp.has("value")) { + throw new DatabricksException("ID token response missing 'value' field"); + } String tokenValue = jsonResp.get("value").textValue(); if (Strings.isNullOrEmpty(tokenValue)) { diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/GithubIDTokenSourceTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/GithubIDTokenSourceTest.java new file mode 100644 index 000000000..d2a56185d --- /dev/null +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/GithubIDTokenSourceTest.java @@ -0,0 +1,144 @@ +package com.databricks.sdk.core.oauth; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; + +import com.databricks.sdk.core.DatabricksException; +import com.databricks.sdk.core.http.HttpClient; +import com.databricks.sdk.core.http.Request; +import com.databricks.sdk.core.http.Response; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ObjectNode; +import java.io.IOException; +import java.net.MalformedURLException; +import java.net.URL; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +public class GithubIDTokenSourceTest { + private static final String TEST_REQUEST_URL = "https://github.com/token"; + private static final String TEST_REQUEST_TOKEN = "test-request-token"; + private static final String TEST_ID_TOKEN = "test-id-token"; + private static final String TEST_AUDIENCE = "test-audience"; + + @Mock + private HttpClient mockHttpClient; + + private GithubIDTokenSource tokenSource; + private ObjectMapper mapper; + + @BeforeEach + void setUp() throws IOException { + MockitoAnnotations.openMocks(this); + mapper = new ObjectMapper(); + tokenSource = new GithubIDTokenSource(TEST_REQUEST_URL, TEST_REQUEST_TOKEN, mockHttpClient); + } + + @Test + void testSuccessfulTokenRetrieval() throws IOException { + // Prepare mock response + ObjectNode responseJson = mapper.createObjectNode(); + responseJson.put("value", TEST_ID_TOKEN); + Response mockResponse = makeResponse(responseJson.toString(), 200); + when(mockHttpClient.execute(any(Request.class))).thenReturn(mockResponse); + + // Test token retrieval + IDToken token = tokenSource.getIDToken(TEST_AUDIENCE); + + assertNotNull(token); + assertEquals(TEST_ID_TOKEN, token.getValue()); + + // Verify the request was made with correct parameters + verify(mockHttpClient).execute(argThat(request -> { + return request.getMethod().equals("GET") && + request.getUrl().startsWith(TEST_REQUEST_URL) && + request.getUrl().contains("audience=" + TEST_AUDIENCE) && + request.getHeaders().get("Authorization").equals("Bearer " + TEST_REQUEST_TOKEN); + })); + } + + @Test + void testSuccessfulTokenRetrievalWithoutAudience() throws IOException { + // Prepare mock response + ObjectNode responseJson = mapper.createObjectNode(); + responseJson.put("value", TEST_ID_TOKEN); + Response mockResponse = makeResponse(responseJson.toString(), 200); + when(mockHttpClient.execute(any(Request.class))).thenReturn(mockResponse); + + // Test token retrieval without audience + IDToken token = tokenSource.getIDToken(""); + + assertNotNull(token); + assertEquals(TEST_ID_TOKEN, token.getValue()); + + // Verify the request was made with correct parameters + verify(mockHttpClient).execute(argThat(request -> { + return request.getMethod().equals("GET") && + request.getUrl().equals(TEST_REQUEST_URL) && + request.getHeaders().get("Authorization").equals("Bearer " + TEST_REQUEST_TOKEN); + })); + } + + @Test + void testMissingRequestURL() { + assertThrows(DatabricksException.class, + () -> new GithubIDTokenSource(null, TEST_REQUEST_TOKEN, mockHttpClient)); + } + + @Test + void testMissingRequestToken() { + assertThrows(DatabricksException.class, + () -> new GithubIDTokenSource(TEST_REQUEST_URL, null, mockHttpClient)); + } + + @Test + void testHttpClientError() throws IOException { + when(mockHttpClient.execute(any(Request.class))) + .thenThrow(new IOException("Network error")); + + assertThrows(DatabricksException.class, () -> tokenSource.getIDToken(TEST_AUDIENCE)); + } + + @Test + void testNonSuccessStatusCode() throws IOException { + Response errorResponse = makeResponse("Error response", 400); + when(mockHttpClient.execute(any(Request.class))).thenReturn(errorResponse); + + assertThrows(DatabricksException.class, () -> tokenSource.getIDToken(TEST_AUDIENCE)); + } + + @Test + void testInvalidJsonResponse() throws IOException { + Response invalidResponse = makeResponse("invalid json", 200); + when(mockHttpClient.execute(any(Request.class))).thenReturn(invalidResponse); + + assertThrows(DatabricksException.class, () -> tokenSource.getIDToken(TEST_AUDIENCE)); + } + + @Test + void testMissingTokenValue() throws IOException { + ObjectNode responseJson = mapper.createObjectNode(); + // Deliberately omit the 'value' field + Response invalidResponse = makeResponse(responseJson.toString(), 200); + when(mockHttpClient.execute(any(Request.class))).thenReturn(invalidResponse); + + assertThrows(DatabricksException.class, () -> tokenSource.getIDToken(TEST_AUDIENCE)); + } + + @Test + void testEmptyTokenValue() throws IOException { + ObjectNode responseJson = mapper.createObjectNode(); + responseJson.put("value", ""); + Response invalidResponse = makeResponse(responseJson.toString(), 200); + when(mockHttpClient.execute(any(Request.class))).thenReturn(invalidResponse); + + assertThrows(DatabricksException.class, () -> tokenSource.getIDToken(TEST_AUDIENCE)); + } + + private static Response makeResponse(String body, int status) throws MalformedURLException { + return new Response(body, status, "status", new URL("https://databricks.com/")); + } +} From 10f1a70cc12d9ebdb3fb19f006dbc854533b3adb Mon Sep 17 00:00:00 2001 From: emmyzhou-db Date: Tue, 6 May 2025 13:35:07 +0000 Subject: [PATCH 4/5] Add TokenSourceCredentialsProvider --- .../databricks/sdk/core/DatabricksConfig.java | 20 +- .../sdk/core/DefaultCredentialsProvider.java | 99 ++++--- .../sdk/core/oauth/GithubIDTokenSource.java | 166 ++++++------ .../oauth/TokenSourceCredentialsProvider.java | 44 +++ .../core/oauth/GithubIDTokenSourceTest.java | 254 +++++++++--------- 5 files changed, 336 insertions(+), 247 deletions(-) create mode 100644 databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/TokenSourceCredentialsProvider.java diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksConfig.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksConfig.java index fa89f5041..7aa475258 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksConfig.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksConfig.java @@ -35,6 +35,9 @@ public class DatabricksConfig { @ConfigAttribute(env = "DATABRICKS_SCOPES", auth = "oauth") private List scopes; + @ConfigAttribute(env = "DATABRICKS_TOKEN_AUDIENCE", auth = "oauth") + private String audience; + @ConfigAttribute(env = "DATABRICKS_REDIRECT_URL", auth = "oauth") private String redirectUrl; @@ -302,6 +305,15 @@ public DatabricksConfig setClientSecret(String clientSecret) { return this; } + public String getAudience() { + return audience; + } + + public DatabricksConfig setAudience(String audience) { + this.audience = audience; + return this; + } + public String getOAuthRedirectUrl() { return redirectUrl; } @@ -374,13 +386,17 @@ public DatabricksConfig setAzureUseMsi(boolean azureUseMsi) { return this; } - /** @deprecated Use {@link #getAzureUseMsi()} instead. */ + /** + * @deprecated Use {@link #getAzureUseMsi()} instead. + */ @Deprecated() public boolean getAzureUseMSI() { return azureUseMsi; } - /** @deprecated Use {@link #setAzureUseMsi(boolean)} instead. */ + /** + * @deprecated Use {@link #setAzureUseMsi(boolean)} instead. + */ @Deprecated public DatabricksConfig setAzureUseMSI(boolean azureUseMsi) { this.azureUseMsi = azureUseMsi; diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DefaultCredentialsProvider.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DefaultCredentialsProvider.java index f16cded39..19ebad345 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DefaultCredentialsProvider.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DefaultCredentialsProvider.java @@ -2,58 +2,38 @@ import com.databricks.sdk.core.oauth.AzureGithubOidcCredentialsProvider; import com.databricks.sdk.core.oauth.AzureServicePrincipalCredentialsProvider; +import com.databricks.sdk.core.oauth.DatabricksOAuthTokenSource; import com.databricks.sdk.core.oauth.ExternalBrowserCredentialsProvider; +import com.databricks.sdk.core.oauth.GithubIDTokenSource; +import com.databricks.sdk.core.oauth.IDTokenSource; import com.databricks.sdk.core.oauth.OAuthM2MServicePrincipalCredentialsProvider; +import com.databricks.sdk.core.oauth.OpenIDConnectEndpoints; +import com.databricks.sdk.core.oauth.TokenSourceCredentialsProvider; import java.util.ArrayList; -import java.util.Arrays; +import java.util.HashMap; import java.util.List; +import java.util.Map; import org.slf4j.Logger; import org.slf4j.LoggerFactory; public class DefaultCredentialsProvider implements CredentialsProvider { private static final Logger LOG = LoggerFactory.getLogger(DefaultCredentialsProvider.class); - private static final List> providerClasses = - Arrays.asList( - PatCredentialsProvider.class, - BasicCredentialsProvider.class, - OAuthM2MServicePrincipalCredentialsProvider.class, - AzureGithubOidcCredentialsProvider.class, - AzureServicePrincipalCredentialsProvider.class, - AzureCliCredentialsProvider.class, - ExternalBrowserCredentialsProvider.class, - DatabricksCliCredentialsProvider.class, - NotebookNativeCredentialsProvider.class, - GoogleCredentialsCredentialsProvider.class, - GoogleIdCredentialsProvider.class); - - private final List providers; + private List providers = new ArrayList<>(); private String authType = "default"; + public DefaultCredentialsProvider() {} + + @Override public String authType() { return authType; } - public DefaultCredentialsProvider() { - providers = new ArrayList<>(); - for (Class clazz : providerClasses) { - try { - providers.add((CredentialsProvider) clazz.newInstance()); - } catch (NoClassDefFoundError | InstantiationException | IllegalAccessException e) { - LOG.warn( - "Failed to instantiate credentials provider: " - + clazz.getName() - + ", skipping. Cause: " - + e.getClass().getCanonicalName() - + ": " - + e.getMessage()); - } - } - } - @Override public synchronized HeaderFactory configure(DatabricksConfig config) { + addDefaultCredentialsProviders(config); + for (CredentialsProvider provider : providers) { if (config.getAuthType() != null && !config.getAuthType().isEmpty() @@ -82,4 +62,57 @@ public synchronized HeaderFactory configure(DatabricksConfig config) { + authFlowUrl + " to configure credentials for your preferred authentication method"); } + + private void addDefaultCredentialsProviders(DatabricksConfig config) { + providers.add(new PatCredentialsProvider()); + providers.add(new BasicCredentialsProvider()); + providers.add(new OAuthM2MServicePrincipalCredentialsProvider()); + + addOIDCTokenCredentialsProviders(config); + + providers.add(new AzureGithubOidcCredentialsProvider()); + providers.add(new AzureServicePrincipalCredentialsProvider()); + providers.add(new AzureCliCredentialsProvider()); + providers.add(new ExternalBrowserCredentialsProvider()); + providers.add(new DatabricksCliCredentialsProvider()); + providers.add(new NotebookNativeCredentialsProvider()); + providers.add(new GoogleCredentialsCredentialsProvider()); + providers.add(new GoogleIdCredentialsProvider()); + } + + private void addOIDCTokenCredentialsProviders(DatabricksConfig config) { + OpenIDConnectEndpoints endpoints = null; + try { + endpoints = config.getOidcEndpoints(); + } catch (Exception e) { + LOG.error("Error getting OIDC endpoints", e); + } + + Map namedIdTokenSources = new HashMap<>(); + namedIdTokenSources.put( + "github-oidc", + new GithubIDTokenSource( + config.getActionsIdTokenRequestUrl(), + config.getActionsIdTokenRequestToken(), + config.getHttpClient())); + // Add new providers to the map as needed + + for (Map.Entry entry : namedIdTokenSources.entrySet()) { + String name = entry.getKey(); + IDTokenSource idTokenSource = entry.getValue(); + + DatabricksOAuthTokenSource oauthTokenSource = + new DatabricksOAuthTokenSource.Builder( + config.getClientId(), + config.getHost(), + endpoints, + idTokenSource, + config.getHttpClient()) + .audience(config.getAudience()) + .accountId(config.isAccountClient() ? config.getAccountId() : null) + .build(); + + providers.add(new TokenSourceCredentialsProvider(oauthTokenSource, name)); + } + } } diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/GithubIDTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/GithubIDTokenSource.java index 7bd7ff4a1..362719bda 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/GithubIDTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/GithubIDTokenSource.java @@ -1,6 +1,5 @@ package com.databricks.sdk.core.oauth; - import com.databricks.sdk.core.DatabricksException; import com.databricks.sdk.core.http.HttpClient; import com.databricks.sdk.core.http.Request; @@ -10,94 +9,81 @@ import com.google.common.base.Strings; import java.io.IOException; - -/** -* GithubIDTokenSource retrieves JWT Tokens from GitHub Actions. -*/ +/** GithubIDTokenSource retrieves JWT Tokens from GitHub Actions. */ public class GithubIDTokenSource implements IDTokenSource { - private final String actionsIDTokenRequestURL; - private final String actionsIDTokenRequestToken; - private final HttpClient httpClient; - private final ObjectMapper mapper = new ObjectMapper(); - - - /** - * Constructs a new GithubIDTokenSource. - * - * @param actionsIDTokenRequestURL The URL to request the ID token from GitHub Actions. - * @param actionsIDTokenRequestToken The token used to authenticate the request. - * @param httpClient The HTTP client to use for making requests. - */ - public GithubIDTokenSource( - String actionsIDTokenRequestURL, - String actionsIDTokenRequestToken, - HttpClient httpClient) { - if (Strings.isNullOrEmpty(actionsIDTokenRequestURL)) { - throw new DatabricksException("Missing ActionsIDTokenRequestURL"); - } - if (Strings.isNullOrEmpty(actionsIDTokenRequestToken)) { - throw new DatabricksException("Missing ActionsIDTokenRequestToken"); - } - if (httpClient == null) { - throw new DatabricksException("HttpClient cannot be null"); - } - this.actionsIDTokenRequestURL = actionsIDTokenRequestURL; - this.actionsIDTokenRequestToken = actionsIDTokenRequestToken; - this.httpClient = httpClient; - } - - - @Override - public IDToken getIDToken(String audience) { - String requestUrl = actionsIDTokenRequestURL; - if (!Strings.isNullOrEmpty(audience)) { - requestUrl = String.format("%s&audience=%s", requestUrl, audience); - } - - - Request req = - new Request("GET", requestUrl) - .withHeader("Authorization", "Bearer " + actionsIDTokenRequestToken); - - - Response resp; - try { - resp = httpClient.execute(req); - } catch (IOException e) { - throw new DatabricksException( - "Failed to request ID token from " + requestUrl + ": " + e.getMessage(), e); - } - - - if (resp.getStatusCode() != 200) { - throw new DatabricksException( - "Failed to request ID token: status code " - + resp.getStatusCode() - + ", response body: " - + resp.getBody().toString()); - } - - - ObjectNode jsonResp; - try { - jsonResp = mapper.readValue(resp.getBody(), ObjectNode.class); - } catch (IOException e) { - throw new DatabricksException("Failed to request ID token: corrupted token: " + e.getMessage()); - } - - if (!jsonResp.has("value")) { - throw new DatabricksException("ID token response missing 'value' field"); - } - - String tokenValue = jsonResp.get("value").textValue(); - if (Strings.isNullOrEmpty(tokenValue)) { - throw new DatabricksException("Received empty ID token from GitHub Actions"); - } - - - return new IDToken(tokenValue); - } + private final String actionsIDTokenRequestURL; + private final String actionsIDTokenRequestToken; + private final HttpClient httpClient; + private final ObjectMapper mapper = new ObjectMapper(); + + /** + * Constructs a new GithubIDTokenSource. + * + * @param actionsIDTokenRequestURL The URL to request the ID token from GitHub Actions. + * @param actionsIDTokenRequestToken The token used to authenticate the request. + * @param httpClient The HTTP client to use for making requests. + */ + public GithubIDTokenSource( + String actionsIDTokenRequestURL, String actionsIDTokenRequestToken, HttpClient httpClient) { + this.actionsIDTokenRequestURL = actionsIDTokenRequestURL; + this.actionsIDTokenRequestToken = actionsIDTokenRequestToken; + this.httpClient = httpClient; + } + + @Override + public IDToken getIDToken(String audience) { + if (Strings.isNullOrEmpty(actionsIDTokenRequestURL)) { + throw new DatabricksException("Missing ActionsIDTokenRequestURL"); + } + if (Strings.isNullOrEmpty(actionsIDTokenRequestToken)) { + throw new DatabricksException("Missing ActionsIDTokenRequestToken"); + } + if (httpClient == null) { + throw new DatabricksException("HttpClient cannot be null"); + } + + String requestUrl = actionsIDTokenRequestURL; + if (!Strings.isNullOrEmpty(audience)) { + requestUrl = String.format("%s&audience=%s", requestUrl, audience); + } + + Request req = + new Request("GET", requestUrl) + .withHeader("Authorization", "Bearer " + actionsIDTokenRequestToken); + + Response resp; + try { + resp = httpClient.execute(req); + } catch (IOException e) { + throw new DatabricksException( + "Failed to request ID token from " + requestUrl + ": " + e.getMessage(), e); + } + + if (resp.getStatusCode() != 200) { + throw new DatabricksException( + "Failed to request ID token: status code " + + resp.getStatusCode() + + ", response body: " + + resp.getBody().toString()); + } + + ObjectNode jsonResp; + try { + jsonResp = mapper.readValue(resp.getBody(), ObjectNode.class); + } catch (IOException e) { + throw new DatabricksException( + "Failed to request ID token: corrupted token: " + e.getMessage()); + } + + if (!jsonResp.has("value")) { + throw new DatabricksException("ID token response missing 'value' field"); + } + + String tokenValue = jsonResp.get("value").textValue(); + if (Strings.isNullOrEmpty(tokenValue)) { + throw new DatabricksException("Received empty ID token from GitHub Actions"); + } + + return new IDToken(tokenValue); + } } - - - diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/TokenSourceCredentialsProvider.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/TokenSourceCredentialsProvider.java new file mode 100644 index 000000000..233d839db --- /dev/null +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/TokenSourceCredentialsProvider.java @@ -0,0 +1,44 @@ +package com.databricks.sdk.core.oauth; + +import com.databricks.sdk.core.CredentialsProvider; +import com.databricks.sdk.core.DatabricksConfig; +import com.databricks.sdk.core.HeaderFactory; +import java.util.HashMap; +import java.util.Map; + +/** Base class for token-based credentials providers. */ +public class TokenSourceCredentialsProvider implements CredentialsProvider { + private final TokenSource tokenSource; + private final String authType; + + /** + * Creates a new TokenSourceCredentialsProvider with the specified token source and auth type. + * + * @param tokenSource The token source to use for token exchange + * @param authType The authentication type string + */ + public TokenSourceCredentialsProvider(TokenSource tokenSource, String authType) { + this.tokenSource = tokenSource; + this.authType = authType; + } + + @Override + public HeaderFactory configure(DatabricksConfig config) { + + return () -> { + Map headers = new HashMap<>(); + try { + String accessToken = tokenSource.getToken().getAccessToken(); + headers.put("Authorization", "Bearer " + accessToken); + return headers; + } catch (Exception e) { + return null; + } + }; + } + + @Override + public String authType() { + return authType; + } +} diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/GithubIDTokenSourceTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/GithubIDTokenSourceTest.java index d2a56185d..e53aaefa8 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/GithubIDTokenSourceTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/GithubIDTokenSourceTest.java @@ -13,132 +13,142 @@ import java.io.IOException; import java.net.MalformedURLException; import java.net.URL; +import java.util.stream.Stream; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import org.mockito.Mock; import org.mockito.MockitoAnnotations; public class GithubIDTokenSourceTest { - private static final String TEST_REQUEST_URL = "https://github.com/token"; - private static final String TEST_REQUEST_TOKEN = "test-request-token"; - private static final String TEST_ID_TOKEN = "test-id-token"; - private static final String TEST_AUDIENCE = "test-audience"; - - @Mock - private HttpClient mockHttpClient; - - private GithubIDTokenSource tokenSource; - private ObjectMapper mapper; - - @BeforeEach - void setUp() throws IOException { - MockitoAnnotations.openMocks(this); - mapper = new ObjectMapper(); - tokenSource = new GithubIDTokenSource(TEST_REQUEST_URL, TEST_REQUEST_TOKEN, mockHttpClient); - } - - @Test - void testSuccessfulTokenRetrieval() throws IOException { - // Prepare mock response - ObjectNode responseJson = mapper.createObjectNode(); - responseJson.put("value", TEST_ID_TOKEN); - Response mockResponse = makeResponse(responseJson.toString(), 200); - when(mockHttpClient.execute(any(Request.class))).thenReturn(mockResponse); - - // Test token retrieval - IDToken token = tokenSource.getIDToken(TEST_AUDIENCE); - - assertNotNull(token); - assertEquals(TEST_ID_TOKEN, token.getValue()); - - // Verify the request was made with correct parameters - verify(mockHttpClient).execute(argThat(request -> { - return request.getMethod().equals("GET") && - request.getUrl().startsWith(TEST_REQUEST_URL) && - request.getUrl().contains("audience=" + TEST_AUDIENCE) && - request.getHeaders().get("Authorization").equals("Bearer " + TEST_REQUEST_TOKEN); - })); - } - - @Test - void testSuccessfulTokenRetrievalWithoutAudience() throws IOException { - // Prepare mock response - ObjectNode responseJson = mapper.createObjectNode(); - responseJson.put("value", TEST_ID_TOKEN); - Response mockResponse = makeResponse(responseJson.toString(), 200); - when(mockHttpClient.execute(any(Request.class))).thenReturn(mockResponse); - - // Test token retrieval without audience - IDToken token = tokenSource.getIDToken(""); - - assertNotNull(token); - assertEquals(TEST_ID_TOKEN, token.getValue()); - - // Verify the request was made with correct parameters - verify(mockHttpClient).execute(argThat(request -> { - return request.getMethod().equals("GET") && - request.getUrl().equals(TEST_REQUEST_URL) && - request.getHeaders().get("Authorization").equals("Bearer " + TEST_REQUEST_TOKEN); - })); - } - - @Test - void testMissingRequestURL() { - assertThrows(DatabricksException.class, - () -> new GithubIDTokenSource(null, TEST_REQUEST_TOKEN, mockHttpClient)); - } - - @Test - void testMissingRequestToken() { - assertThrows(DatabricksException.class, - () -> new GithubIDTokenSource(TEST_REQUEST_URL, null, mockHttpClient)); - } - - @Test - void testHttpClientError() throws IOException { - when(mockHttpClient.execute(any(Request.class))) - .thenThrow(new IOException("Network error")); - - assertThrows(DatabricksException.class, () -> tokenSource.getIDToken(TEST_AUDIENCE)); - } - - @Test - void testNonSuccessStatusCode() throws IOException { - Response errorResponse = makeResponse("Error response", 400); - when(mockHttpClient.execute(any(Request.class))).thenReturn(errorResponse); - - assertThrows(DatabricksException.class, () -> tokenSource.getIDToken(TEST_AUDIENCE)); - } - - @Test - void testInvalidJsonResponse() throws IOException { - Response invalidResponse = makeResponse("invalid json", 200); - when(mockHttpClient.execute(any(Request.class))).thenReturn(invalidResponse); - - assertThrows(DatabricksException.class, () -> tokenSource.getIDToken(TEST_AUDIENCE)); - } - - @Test - void testMissingTokenValue() throws IOException { - ObjectNode responseJson = mapper.createObjectNode(); - // Deliberately omit the 'value' field - Response invalidResponse = makeResponse(responseJson.toString(), 200); - when(mockHttpClient.execute(any(Request.class))).thenReturn(invalidResponse); - - assertThrows(DatabricksException.class, () -> tokenSource.getIDToken(TEST_AUDIENCE)); - } - - @Test - void testEmptyTokenValue() throws IOException { - ObjectNode responseJson = mapper.createObjectNode(); - responseJson.put("value", ""); - Response invalidResponse = makeResponse(responseJson.toString(), 200); - when(mockHttpClient.execute(any(Request.class))).thenReturn(invalidResponse); - - assertThrows(DatabricksException.class, () -> tokenSource.getIDToken(TEST_AUDIENCE)); - } - - private static Response makeResponse(String body, int status) throws MalformedURLException { - return new Response(body, status, "status", new URL("https://databricks.com/")); - } + private static final String TEST_REQUEST_URL = "https://github.com/token"; + private static final String TEST_REQUEST_TOKEN = "test-request-token"; + private static final String TEST_ID_TOKEN = "test-id-token"; + private static final String TEST_AUDIENCE = "test-audience"; + + @Mock private static HttpClient mockHttpClient; + + private GithubIDTokenSource tokenSource; + private ObjectMapper mapper; + + @BeforeEach + void setUp() throws IOException { + MockitoAnnotations.openMocks(this); + mapper = new ObjectMapper(); + tokenSource = new GithubIDTokenSource(TEST_REQUEST_URL, TEST_REQUEST_TOKEN, mockHttpClient); + } + + @Test + void testSuccessfulTokenRetrieval() throws IOException { + // Prepare mock response + ObjectNode responseJson = mapper.createObjectNode(); + responseJson.put("value", TEST_ID_TOKEN); + Response mockResponse = makeResponse(responseJson.toString(), 200); + when(mockHttpClient.execute(any(Request.class))).thenReturn(mockResponse); + + // Test token retrieval + IDToken token = tokenSource.getIDToken(TEST_AUDIENCE); + + assertNotNull(token); + assertEquals(TEST_ID_TOKEN, token.getValue()); + + // Verify the request was made with correct parameters + verify(mockHttpClient) + .execute( + argThat( + request -> { + return request.getMethod().equals("GET") + && request.getUrl().startsWith(TEST_REQUEST_URL) + && request.getUrl().contains("audience=" + TEST_AUDIENCE) + && request + .getHeaders() + .get("Authorization") + .equals("Bearer " + TEST_REQUEST_TOKEN); + })); + } + + @Test + void testSuccessfulTokenRetrievalWithoutAudience() throws IOException { + // Prepare mock response + ObjectNode responseJson = mapper.createObjectNode(); + responseJson.put("value", TEST_ID_TOKEN); + Response mockResponse = makeResponse(responseJson.toString(), 200); + when(mockHttpClient.execute(any(Request.class))).thenReturn(mockResponse); + + // Test token retrieval without audience + IDToken token = tokenSource.getIDToken(""); + + assertNotNull(token); + assertEquals(TEST_ID_TOKEN, token.getValue()); + + // Verify the request was made with correct parameters + verify(mockHttpClient) + .execute( + argThat( + request -> { + return request.getMethod().equals("GET") + && request.getUrl().equals(TEST_REQUEST_URL) + && request + .getHeaders() + .get("Authorization") + .equals("Bearer " + TEST_REQUEST_TOKEN); + })); + } + + private static Stream provideInvalidConstructorParameters() { + return Stream.of( + Arguments.of("Missing Request URL", null, TEST_REQUEST_TOKEN, mockHttpClient), + Arguments.of("Missing Request Token", TEST_REQUEST_URL, null, mockHttpClient), + Arguments.of("Null HttpClient", TEST_REQUEST_URL, TEST_REQUEST_TOKEN, null)); + } + + @ParameterizedTest(name = "{0}") + @MethodSource("provideInvalidConstructorParameters") + void testInvalidConstructorParameters( + String testName, String requestUrl, String requestToken, HttpClient httpClient) { + GithubIDTokenSource invalidSource = + new GithubIDTokenSource(requestUrl, requestToken, httpClient); + assertThrows(DatabricksException.class, () -> invalidSource.getIDToken(TEST_AUDIENCE)); + } + + private static Stream provideHttpErrorScenarios() throws IOException { + HttpClient httpClientError = mock(HttpClient.class); + when(httpClientError.execute(any(Request.class))).thenThrow(new IOException("Network error")); + + HttpClient nonSuccessClient = mock(HttpClient.class); + when(nonSuccessClient.execute(any(Request.class))) + .thenReturn(makeResponse("Error response", 400)); + + HttpClient invalidJsonClient = mock(HttpClient.class); + when(invalidJsonClient.execute(any(Request.class))) + .thenReturn(makeResponse("Invalid json", 200)); + + HttpClient missingTokenClient = mock(HttpClient.class); + when(missingTokenClient.execute(any(Request.class))).thenReturn(makeResponse("{}", 200)); + + HttpClient emptyTokenClient = mock(HttpClient.class); + when(emptyTokenClient.execute(any(Request.class))) + .thenReturn(makeResponse("{\"value\":\"\"}", 200)); + + return Stream.of( + Arguments.of("HTTP Client Error", httpClientError), + Arguments.of("Non-Success Status Code", nonSuccessClient), + Arguments.of("Invalid JSON Response", invalidJsonClient), + Arguments.of("Missing Token Value", missingTokenClient), + Arguments.of("Empty Token Value", emptyTokenClient)); + } + + @ParameterizedTest(name = "{0}") + @MethodSource("provideHttpErrorScenarios") + void testHttpErrorScenarios(String testName, HttpClient httpClient) { + GithubIDTokenSource source = + new GithubIDTokenSource(TEST_REQUEST_URL, TEST_REQUEST_TOKEN, httpClient); + assertThrows(DatabricksException.class, () -> source.getIDToken(TEST_AUDIENCE)); + } + + private static Response makeResponse(String body, int status) throws MalformedURLException { + return new Response(body, status, "status", new URL("https://databricks.com/")); + } } From 6ed6754308b234f28793da1f8e2dc288f58a03b1 Mon Sep 17 00:00:00 2001 From: emmyzhou-db Date: Tue, 6 May 2025 16:59:09 +0200 Subject: [PATCH 5/5] Emmyzhou db/db oauth token source (#443) ## What changes are proposed in this pull request? Merging emmyzhou-db/db-oauth-token-source into emmy-zhou/github-oidc-token-source. NO_CHANGELOG=true --- .../oauth/DatabricksOAuthTokenSource.java | 134 +++-- .../oauth/DatabricksOAuthTokenSourceTest.java | 501 ++++++++++-------- 2 files changed, 365 insertions(+), 270 deletions(-) diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java index 307fc2cb2..c8ac65ba2 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java @@ -10,27 +10,45 @@ import java.time.LocalDateTime; import java.util.HashMap; import java.util.Map; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * Implementation of TokenSource that handles OAuth token exchange for Databricks authentication. * This class manages the OAuth token exchange flow using ID tokens to obtain access tokens. */ public class DatabricksOAuthTokenSource implements TokenSource { - /** OAuth client ID used for token exchange */ + private static final Logger LOG = LoggerFactory.getLogger(DatabricksOAuthTokenSource.class); + + /** OAuth client ID used for token exchange. */ private final String clientId; - /** Databricks account ID, used as audience if provided */ + /** Databricks host URL. */ + private final String host; + /** Databricks account ID, used as audience if provided. */ private final String accountId; - /** OpenID Connect endpoints configuration */ + /** OpenID Connect endpoints configuration. */ private final OpenIDConnectEndpoints endpoints; - /** Custom audience value for token exchange */ + /** Custom audience value for token exchange. */ private final String audience; - /** Source of ID tokens used in token exchange */ + /** Source of ID tokens used in token exchange. */ private final IDTokenSource idTokenSource; - /** HTTP client for making token exchange requests */ + /** HTTP client for making token exchange requests. */ private final HttpClient httpClient; + private static final String GRANT_TYPE = "urn:ietf:params:oauth:grant-type:token-exchange"; + private static final String SUBJECT_TOKEN_TYPE = "urn:ietf:params:oauth:token-type:jwt"; + private static final String SCOPE = "all-apis"; + private static final String GRANT_TYPE_PARAM = "grant_type"; + private static final String SUBJECT_TOKEN_PARAM = "subject_token"; + private static final String SUBJECT_TOKEN_TYPE_PARAM = "subject_token_type"; + private static final String SCOPE_PARAM = "scope"; + private static final String CLIENT_ID_PARAM = "client_id"; + + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + private DatabricksOAuthTokenSource(Builder builder) { this.clientId = builder.clientId; + this.host = builder.host; this.accountId = builder.accountId; this.endpoints = builder.endpoints; this.audience = builder.audience; @@ -39,8 +57,8 @@ private DatabricksOAuthTokenSource(Builder builder) { } /** - * Builder class for constructing DatabricksOAuthTokenSource instances. Provides a fluent - * interface for setting required and optional parameters. + * Builder class for constructing DatabricksOAuthTokenSource instances. Provides a flexible way to + * set required and optional parameters. */ public static class Builder { private final String clientId; @@ -51,30 +69,14 @@ public static class Builder { private String accountId; private String audience; - /** - * Validates that a value is non-empty and non-null for required fields. - * - * @param value The value to validate - * @param fieldName The name of the field being validated - * @throws IllegalArgumentException if validation fails - */ - private static void validate(Object value, String fieldName) { - if (value == null) { - throw new IllegalArgumentException(fieldName + " must be non-null"); - } - if (value instanceof String && ((String) value).isEmpty()) { - throw new IllegalArgumentException(fieldName + " must be non-empty"); - } - } - /** * Creates a new Builder with required parameters. * - * @param clientId OAuth client ID - * @param host Databricks host URL - * @param endpoints OpenID Connect endpoints configuration - * @param idTokenSource Source of ID tokens - * @param httpClient HTTP client for making requests + * @param clientId OAuth client ID. + * @param host Databricks host URL. + * @param endpoints OpenID Connect endpoints configuration. + * @param idTokenSource Source of ID tokens. + * @param httpClient HTTP client for making requests. */ public Builder( String clientId, @@ -82,12 +84,6 @@ public Builder( OpenIDConnectEndpoints endpoints, IDTokenSource idTokenSource, HttpClient httpClient) { - validate(clientId, "ClientID"); - validate(host, "Host"); - validate(endpoints, "Endpoints"); - validate(idTokenSource, "IDTokenSource"); - validate(httpClient, "HttpClient"); - this.clientId = clientId; this.host = host; this.endpoints = endpoints; @@ -98,11 +94,10 @@ public Builder( /** * Sets the Databricks account ID. * - * @param accountId The account ID - * @return This builder instance + * @param accountId The account ID. + * @return This builder instance. */ public Builder accountId(String accountId) { - validate(accountId, "AccountID"); this.accountId = accountId; return this; } @@ -114,7 +109,6 @@ public Builder accountId(String accountId) { * @return This builder instance */ public Builder audience(String audience) { - validate(audience, "Audience"); this.audience = audience; return this; } @@ -122,36 +116,71 @@ public Builder audience(String audience) { /** * Builds a new DatabricksOAuthTokenSource instance. * - * @return A new DatabricksOAuthTokenSource + * @return A new DatabricksOAuthTokenSource. */ public DatabricksOAuthTokenSource build() { return new DatabricksOAuthTokenSource(this); } } + /** + * Validates that a value is non-null for required fields. If the value is a string, it also + * checks that it is non-empty. + * + * @param value The value to validate. + * @param fieldName The name of the field being validated. + * @throws IllegalArgumentException when the value is null or an empty string. + */ + private static void validate(Object value, String fieldName) { + if (value == null) { + LOG.error("Required parameter '{}' is null", fieldName); + throw new IllegalArgumentException( + String.format("Required parameter '%s' cannot be null", fieldName)); + } + if (value instanceof String && ((String) value).isEmpty()) { + LOG.error("Required parameter '{}' is empty", fieldName); + throw new IllegalArgumentException( + String.format("Required parameter '%s' cannot be empty", fieldName)); + } + } + /** * Retrieves an OAuth token by exchanging an ID token. Implements the OAuth token exchange flow to * obtain an access token. * - * @return A Token containing the access token and related information - * @throws DatabricksException if token exchange fails + * @return A Token containing the access token and related information. + * @throws DatabricksException when the token exchange fails. + * @throws IllegalArgumentException when there is an error code in the response or when required + * parameters are missing. */ @Override public Token getToken() { + // Validate all required parameters + validate(clientId, "ClientID"); + validate(host, "Host"); + validate(endpoints, "Endpoints"); + validate(idTokenSource, "IDTokenSource"); + validate(httpClient, "HttpClient"); + String effectiveAudience = determineAudience(); IDToken idToken = idTokenSource.getIDToken(effectiveAudience); Map params = new HashMap<>(); - params.put("grant_type", "urn:ietf:params:oauth:grant-type:token-exchange"); - params.put("subject_token", idToken.getValue()); - params.put("subject_token_type", "urn:ietf:params:oauth:token-type:jwt"); - params.put("scope", "all-apis"); - params.put("client_id", clientId); + params.put(GRANT_TYPE_PARAM, GRANT_TYPE); + params.put(SUBJECT_TOKEN_PARAM, idToken.getValue()); + params.put(SUBJECT_TOKEN_TYPE_PARAM, SUBJECT_TOKEN_TYPE); + params.put(SCOPE_PARAM, SCOPE); + params.put(CLIENT_ID_PARAM, clientId); Response rawResponse; try { rawResponse = httpClient.execute(new FormRequest(endpoints.getTokenEndpoint(), params)); } catch (IOException e) { + LOG.error( + "Failed to exchange ID token for access token at {}: {}", + endpoints.getTokenEndpoint(), + e.getMessage(), + e); throw new DatabricksException( String.format( "Failed to exchange ID token for access token at %s: %s", @@ -161,8 +190,13 @@ public Token getToken() { OAuthResponse response; try { - response = new ObjectMapper().readValue(rawResponse.getBody(), OAuthResponse.class); + response = OBJECT_MAPPER.readValue(rawResponse.getBody(), OAuthResponse.class); } catch (IOException e) { + LOG.error( + "Failed to parse OAuth response from token endpoint {}: {}", + endpoints.getTokenEndpoint(), + e.getMessage(), + e); throw new DatabricksException( String.format( "Failed to parse OAuth response from token endpoint %s: %s", @@ -170,6 +204,10 @@ public Token getToken() { } if (response.getErrorCode() != null) { + LOG.error( + "Token exchange failed with error: {} - {}", + response.getErrorCode(), + response.getErrorSummary()); throw new IllegalArgumentException( String.format( "Token exchange failed with error: %s - %s", diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSourceTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSourceTest.java index 3ffc9e71a..8d7da8d3a 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSourceTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSourceTest.java @@ -10,15 +10,14 @@ import com.databricks.sdk.core.http.Response; import com.fasterxml.jackson.databind.ObjectMapper; import java.io.IOException; +import java.net.MalformedURLException; import java.net.URL; import java.util.HashMap; import java.util.Map; import java.util.stream.Stream; import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.MethodSource; -import org.mockito.ArgumentCaptor; import org.mockito.Mockito; class DatabricksOAuthTokenSourceTest { @@ -36,6 +35,10 @@ class DatabricksOAuthTokenSourceTest { private static final String TEST_AUDIENCE = "test-audience"; private static final String TEST_ACCOUNT_ID = "test-account-id"; + // Error message constants + private static final String ERROR_NULL = "Required parameter '%s' cannot be null"; + private static final String ERROR_EMPTY = "Required parameter '%s' cannot be empty"; + private IDTokenSource mockIdTokenSource; @BeforeEach @@ -54,25 +57,24 @@ private static class TestCase { final String audience; // Custom audience value if provided final String accountId; // Account ID if provided final String expectedAudience; // Expected audience used in token exchange - final boolean expectError; // Whether this case should result in an error - final int statusCode; // HTTP status code for the response - final String responseBody; // Response body from the token endpoint + final HttpClient mockHttpClient; // Pre-configured mock HTTP client + final Class expectedException; // Expected exception type if any TestCase( String name, String audience, String accountId, String expectedAudience, - boolean expectError, int statusCode, - String responseBody) { + Object responseBody, + HttpClient mockHttpClient, + Class expectedException) { this.name = name; this.audience = audience; this.accountId = accountId; this.expectedAudience = expectedAudience; - this.expectError = expectError; - this.statusCode = statusCode; - this.responseBody = responseBody; + this.mockHttpClient = mockHttpClient; + this.expectedException = expectedException; } @Override @@ -86,75 +88,124 @@ public String toString() { * audience configurations and various error cases. */ private static Stream provideTestCases() { - // Success response with valid token data - Map successResponse = new HashMap<>(); - successResponse.put("access_token", TOKEN); - successResponse.put("token_type", TOKEN_TYPE); - successResponse.put("refresh_token", REFRESH_TOKEN); - successResponse.put("expires_in", EXPIRES_IN); + try { + // Success response with valid token data + Map successResponse = new HashMap<>(); + successResponse.put("access_token", TOKEN); + successResponse.put("token_type", TOKEN_TYPE); + successResponse.put("refresh_token", REFRESH_TOKEN); + successResponse.put("expires_in", EXPIRES_IN); + + // Error response for invalid requests + Map errorResponse = new HashMap<>(); + errorResponse.put("error", "invalid_request"); + errorResponse.put("error_description", "Invalid client ID"); + + ObjectMapper mapper = new ObjectMapper(); + final String errorJson = mapper.writeValueAsString(errorResponse); + final String successJson = mapper.writeValueAsString(successResponse); - // Error response for invalid requests - Map errorResponse = new HashMap<>(); - errorResponse.put("error", "invalid_request"); - errorResponse.put("error_description", "Invalid client ID"); + // Create the expected request that will be used in all test cases + Map formParams = new HashMap<>(); + formParams.put("client_id", TEST_CLIENT_ID); + formParams.put("subject_token", TEST_ID_TOKEN); + formParams.put("subject_token_type", "urn:ietf:params:oauth:token-type:jwt"); + formParams.put("grant_type", "urn:ietf:params:oauth:grant-type:token-exchange"); + formParams.put("scope", "all-apis"); + FormRequest expectedRequest = new FormRequest(TEST_TOKEN_ENDPOINT, formParams); + + return Stream.of( + // Success cases with different audience configurations + new TestCase( + "Default audience from token endpoint", + null, + null, + TEST_TOKEN_ENDPOINT, + 200, + successResponse, + createMockHttpClient(expectedRequest, 200, successJson), + null), + new TestCase( + "Custom audience provided", + TEST_AUDIENCE, + null, + TEST_AUDIENCE, + 200, + successResponse, + createMockHttpClient(expectedRequest, 200, successJson), + null), + new TestCase( + "Custom audience takes precedence over account ID", + TEST_AUDIENCE, + TEST_ACCOUNT_ID, + TEST_AUDIENCE, + 200, + successResponse, + createMockHttpClient(expectedRequest, 200, successJson), + null), + new TestCase( + "Account ID used as audience when no custom audience", + null, + TEST_ACCOUNT_ID, + TEST_ACCOUNT_ID, + 200, + successResponse, + createMockHttpClient(expectedRequest, 200, successJson), + null), + // Error cases + new TestCase( + "Invalid request returns 400", + null, + null, + TEST_TOKEN_ENDPOINT, + 400, + errorJson, + createMockHttpClient(expectedRequest, 400, errorJson), + IllegalArgumentException.class), + new TestCase( + "Network error during token exchange", + null, + null, + TEST_TOKEN_ENDPOINT, + 0, + null, + createMockHttpClientWithError(expectedRequest), + DatabricksException.class), + new TestCase( + "Invalid JSON response from server", + null, + null, + TEST_TOKEN_ENDPOINT, + 200, + "invalid json", + createMockHttpClient(expectedRequest, 200, "invalid json"), + DatabricksException.class)); + } catch (IOException e) { + throw new RuntimeException("Failed to create test cases", e); + } + } - ObjectMapper mapper = new ObjectMapper(); - String successJson; - String errorJson; + private static HttpClient createMockHttpClient( + FormRequest expectedRequest, int statusCode, String responseBody) { try { - successJson = mapper.writeValueAsString(successResponse); - errorJson = mapper.writeValueAsString(errorResponse); + HttpClient mockHttpClient = Mockito.mock(HttpClient.class); + String statusMessage = statusCode == 200 ? "OK" : "Bad Request"; + when(mockHttpClient.execute(expectedRequest)) + .thenReturn(new Response(responseBody, statusCode, statusMessage, new URL(TEST_HOST))); + return mockHttpClient; } catch (IOException e) { - throw new RuntimeException(e); + throw new RuntimeException("Failed to create mock HTTP client", e); } + } - return Stream.of( - // Success cases with different audience configurations - new TestCase( - "Default audience from token endpoint", - null, - null, - TEST_TOKEN_ENDPOINT, - false, - 200, - successJson), - new TestCase( - "Custom audience provided", - TEST_AUDIENCE, - null, - TEST_AUDIENCE, - false, - 200, - successJson), - new TestCase( - "Custom audience takes precedence over account ID", - TEST_AUDIENCE, - TEST_ACCOUNT_ID, - TEST_AUDIENCE, - false, - 200, - successJson), - new TestCase( - "Account ID used as audience when no custom audience", - null, - TEST_ACCOUNT_ID, - TEST_ACCOUNT_ID, - false, - 200, - successJson), - // Error cases - new TestCase( - "Invalid request returns 400", null, null, TEST_TOKEN_ENDPOINT, true, 400, errorJson), - new TestCase( - "Network error during token exchange", null, null, TEST_TOKEN_ENDPOINT, true, 0, null), - new TestCase( - "Invalid JSON response from server", - null, - null, - TEST_TOKEN_ENDPOINT, - true, - 200, - "invalid json")); + private static HttpClient createMockHttpClientWithError(FormRequest expectedRequest) { + try { + HttpClient mockHttpClient = Mockito.mock(HttpClient.class); + when(mockHttpClient.execute(expectedRequest)).thenThrow(new IOException("Network error")); + return mockHttpClient; + } catch (IOException e) { + throw new RuntimeException("Failed to create mock HTTP client with error", e); + } } /** @@ -163,171 +214,177 @@ private static Stream provideTestCases() { */ @ParameterizedTest(name = "testTokenSource: {arguments}") @MethodSource("provideTestCases") - void testTokenSource(TestCase testCase) throws IOException { - // Mock HTTP client with test case specific behavior - HttpClient mockHttpClient = Mockito.mock(HttpClient.class); - if (testCase.expectError) { - if (testCase.statusCode == 0) { - when(mockHttpClient.execute(any())).thenThrow(new IOException("Network error")); - } else { - when(mockHttpClient.execute(any())) - .thenReturn( - new Response( - testCase.responseBody, testCase.statusCode, "Bad Request", new URL(TEST_HOST))); - } - } else { - when(mockHttpClient.execute(any())) - .thenReturn(new Response(testCase.responseBody, new URL(TEST_HOST))); - } - - // Create token source with test configuration - OpenIDConnectEndpoints endpoints = - new OpenIDConnectEndpoints(TEST_TOKEN_ENDPOINT, TEST_AUTHORIZATION_ENDPOINT); + void testTokenSource(TestCase testCase) { + try { + // Create token source with test configuration + OpenIDConnectEndpoints endpoints = + new OpenIDConnectEndpoints(TEST_TOKEN_ENDPOINT, TEST_AUTHORIZATION_ENDPOINT); - DatabricksOAuthTokenSource.Builder builder = - new DatabricksOAuthTokenSource.Builder( - TEST_CLIENT_ID, TEST_HOST, endpoints, mockIdTokenSource, mockHttpClient); + DatabricksOAuthTokenSource.Builder builder = + new DatabricksOAuthTokenSource.Builder( + TEST_CLIENT_ID, TEST_HOST, endpoints, mockIdTokenSource, testCase.mockHttpClient); - if (testCase.audience != null) { - builder.audience(testCase.audience); - } - if (testCase.accountId != null) { - builder.accountId(testCase.accountId); - } + builder.audience(testCase.audience).accountId(testCase.accountId); - DatabricksOAuthTokenSource tokenSource = builder.build(); + DatabricksOAuthTokenSource tokenSource = builder.build(); - if (testCase.expectError) { - if (testCase.statusCode == 400) { - assertThrows(IllegalArgumentException.class, () -> tokenSource.getToken()); + if (testCase.expectedException != null) { + assertThrows(testCase.expectedException, () -> tokenSource.getToken()); } else { - assertThrows(DatabricksException.class, () -> tokenSource.getToken()); - } - } else { - // Verify successful token exchange - Token token = tokenSource.getToken(); - assertEquals(TOKEN, token.getAccessToken()); - assertEquals(TOKEN_TYPE, token.getTokenType()); - assertEquals(REFRESH_TOKEN, token.getRefreshToken()); - assertFalse(token.isExpired()); - - // Verify correct audience was used - verify(mockIdTokenSource).getIDToken(testCase.expectedAudience); + // Verify successful token exchange + Token token = tokenSource.getToken(); + assertEquals(TOKEN, token.getAccessToken()); + assertEquals(TOKEN_TYPE, token.getTokenType()); + assertEquals(REFRESH_TOKEN, token.getRefreshToken()); + assertFalse(token.isExpired()); - // Verify token exchange request - ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(FormRequest.class); - verify(mockHttpClient).execute(requestCaptor.capture()); - - FormRequest capturedRequest = requestCaptor.getValue(); - assertEquals(TEST_TOKEN_ENDPOINT, capturedRequest.getUrl()); - - // Verify request parameters - String body = capturedRequest.getBodyString(); - assertTrue(body.contains("client_id=" + TEST_CLIENT_ID)); - assertTrue(body.contains("subject_token=" + TEST_ID_TOKEN)); - assertTrue( - body.contains("subject_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Ajwt")); - assertTrue( - body.contains("grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Atoken-exchange")); - assertTrue(body.contains("scope=all-apis")); + // Verify correct audience was used + verify(mockIdTokenSource).getIDToken(testCase.expectedAudience); + } + } catch (IOException e) { + throw new RuntimeException("Test failed", e); } } /** - * Tests validation of required fields in the token source builder. Verifies that null or empty - * values for required fields throw IllegalArgumentException. + * Test case data for parameter validation tests. Each case defines a specific validation + * scenario. */ - @Test - void testConstructorValidation() { - // Test null client ID - assertThrows( - IllegalArgumentException.class, - () -> { - new DatabricksOAuthTokenSource.Builder( - null, - TEST_HOST, - new OpenIDConnectEndpoints(TEST_TOKEN_ENDPOINT, TEST_AUTHORIZATION_ENDPOINT), - mockIdTokenSource, - Mockito.mock(HttpClient.class)) - .build(); - }); + private static class ValidationTestCase { + final String name; + final String clientId; + final String host; + final OpenIDConnectEndpoints endpoints; + final IDTokenSource idTokenSource; + final HttpClient httpClient; + final String expectedFieldName; + final boolean isNullTest; - // Test empty client ID - assertThrows( - IllegalArgumentException.class, - () -> { - new DatabricksOAuthTokenSource.Builder( - "", - TEST_HOST, - new OpenIDConnectEndpoints(TEST_TOKEN_ENDPOINT, TEST_AUTHORIZATION_ENDPOINT), - mockIdTokenSource, - Mockito.mock(HttpClient.class)) - .build(); - }); + ValidationTestCase( + String name, + String clientId, + String host, + OpenIDConnectEndpoints endpoints, + IDTokenSource idTokenSource, + HttpClient httpClient, + String expectedFieldName, + boolean isNullTest) { + this.name = name; + this.clientId = clientId; + this.host = host; + this.endpoints = endpoints; + this.idTokenSource = idTokenSource; + this.httpClient = httpClient; + this.expectedFieldName = expectedFieldName; + this.isNullTest = isNullTest; + } - // Test null host - assertThrows( - IllegalArgumentException.class, - () -> { - new DatabricksOAuthTokenSource.Builder( - TEST_CLIENT_ID, - null, - new OpenIDConnectEndpoints(TEST_TOKEN_ENDPOINT, TEST_AUTHORIZATION_ENDPOINT), - mockIdTokenSource, - Mockito.mock(HttpClient.class)) - .build(); - }); + @Override + public String toString() { + return name; + } + } - // Test empty host - assertThrows( - IllegalArgumentException.class, - () -> { - new DatabricksOAuthTokenSource.Builder( - TEST_CLIENT_ID, - "", - new OpenIDConnectEndpoints(TEST_TOKEN_ENDPOINT, TEST_AUTHORIZATION_ENDPOINT), - mockIdTokenSource, - Mockito.mock(HttpClient.class)) - .build(); - }); + private static Stream provideValidationTestCases() + throws MalformedURLException { + OpenIDConnectEndpoints validEndpoints = + new OpenIDConnectEndpoints(TEST_TOKEN_ENDPOINT, TEST_AUTHORIZATION_ENDPOINT); + HttpClient validHttpClient = Mockito.mock(HttpClient.class); + IDTokenSource validIdTokenSource = Mockito.mock(IDTokenSource.class); - // Test null endpoints - assertThrows( - IllegalArgumentException.class, - () -> { - new DatabricksOAuthTokenSource.Builder( - TEST_CLIENT_ID, - TEST_HOST, - null, - mockIdTokenSource, - Mockito.mock(HttpClient.class)) - .build(); - }); + return Stream.of( + // Client ID validation + new ValidationTestCase( + "Null client ID", + null, + TEST_HOST, + validEndpoints, + validIdTokenSource, + validHttpClient, + "ClientID", + true), + new ValidationTestCase( + "Empty client ID", + "", + TEST_HOST, + validEndpoints, + validIdTokenSource, + validHttpClient, + "ClientID", + false), + // Host validation + new ValidationTestCase( + "Null host", + TEST_CLIENT_ID, + null, + validEndpoints, + validIdTokenSource, + validHttpClient, + "Host", + true), + new ValidationTestCase( + "Empty host", + TEST_CLIENT_ID, + "", + validEndpoints, + validIdTokenSource, + validHttpClient, + "Host", + false), + // Endpoints validation + new ValidationTestCase( + "Null endpoints", + TEST_CLIENT_ID, + TEST_HOST, + null, + validIdTokenSource, + validHttpClient, + "Endpoints", + true), + // IDTokenSource validation + new ValidationTestCase( + "Null IDTokenSource", + TEST_CLIENT_ID, + TEST_HOST, + validEndpoints, + null, + validHttpClient, + "IDTokenSource", + true), + // HttpClient validation + new ValidationTestCase( + "Null HttpClient", + TEST_CLIENT_ID, + TEST_HOST, + validEndpoints, + validIdTokenSource, + null, + "HttpClient", + true)); + } - // Test null IDTokenSource - assertThrows( - IllegalArgumentException.class, - () -> { - new DatabricksOAuthTokenSource.Builder( - TEST_CLIENT_ID, - TEST_HOST, - new OpenIDConnectEndpoints(TEST_TOKEN_ENDPOINT, TEST_AUTHORIZATION_ENDPOINT), - null, - Mockito.mock(HttpClient.class)) - .build(); - }); + /** + * Tests validation of required fields in the token source using parameterized test cases. + * Verifies that null or empty values for required fields cause getToken() to throw + * IllegalArgumentException with specific error messages. + */ + @ParameterizedTest(name = "testParameterValidation: {0}") + @MethodSource("provideValidationTestCases") + void testParameterValidation(ValidationTestCase testCase) { + DatabricksOAuthTokenSource tokenSource = + new DatabricksOAuthTokenSource.Builder( + testCase.clientId, + testCase.host, + testCase.endpoints, + testCase.idTokenSource, + testCase.httpClient) + .build(); - // Test null HttpClient - assertThrows( - IllegalArgumentException.class, - () -> { - new DatabricksOAuthTokenSource.Builder( - TEST_CLIENT_ID, - TEST_HOST, - new OpenIDConnectEndpoints(TEST_TOKEN_ENDPOINT, TEST_AUTHORIZATION_ENDPOINT), - mockIdTokenSource, - null) - .build(); - }); + IllegalArgumentException exception = + assertThrows(IllegalArgumentException.class, () -> tokenSource.getToken()); + + String expectedMessage = + String.format(testCase.isNullTest ? ERROR_NULL : ERROR_EMPTY, testCase.expectedFieldName); + assertEquals(expectedMessage, exception.getMessage()); } }