From 6f8da14508bfefeea6df445e305773304a1572a9 Mon Sep 17 00:00:00 2001 From: emmyzhou-db Date: Wed, 30 Apr 2025 15:34:24 +0000 Subject: [PATCH 1/6] Add DatabricksOAuthTokenSource --- .../oauth/DatabricksOAuthTokenSource.java | 200 +++++++++++ .../oauth/DatabricksOAuthTokenSourceTest.java | 333 ++++++++++++++++++ 2 files changed, 533 insertions(+) create mode 100644 databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java create mode 100644 databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSourceTest.java diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java new file mode 100644 index 000000000..307fc2cb2 --- /dev/null +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java @@ -0,0 +1,200 @@ +package com.databricks.sdk.core.oauth; + +import com.databricks.sdk.core.DatabricksException; +import com.databricks.sdk.core.http.FormRequest; +import com.databricks.sdk.core.http.HttpClient; +import com.databricks.sdk.core.http.Response; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.base.Strings; +import java.io.IOException; +import java.time.LocalDateTime; +import java.util.HashMap; +import java.util.Map; + +/** + * Implementation of TokenSource that handles OAuth token exchange for Databricks authentication. + * This class manages the OAuth token exchange flow using ID tokens to obtain access tokens. + */ +public class DatabricksOAuthTokenSource implements TokenSource { + /** OAuth client ID used for token exchange */ + private final String clientId; + /** Databricks account ID, used as audience if provided */ + private final String accountId; + /** OpenID Connect endpoints configuration */ + private final OpenIDConnectEndpoints endpoints; + /** Custom audience value for token exchange */ + private final String audience; + /** Source of ID tokens used in token exchange */ + private final IDTokenSource idTokenSource; + /** HTTP client for making token exchange requests */ + private final HttpClient httpClient; + + private DatabricksOAuthTokenSource(Builder builder) { + this.clientId = builder.clientId; + this.accountId = builder.accountId; + this.endpoints = builder.endpoints; + this.audience = builder.audience; + this.idTokenSource = builder.idTokenSource; + this.httpClient = builder.httpClient; + } + + /** + * Builder class for constructing DatabricksOAuthTokenSource instances. Provides a fluent + * interface for setting required and optional parameters. + */ + public static class Builder { + private final String clientId; + private final String host; + private final OpenIDConnectEndpoints endpoints; + private final IDTokenSource idTokenSource; + private final HttpClient httpClient; + private String accountId; + private String audience; + + /** + * Validates that a value is non-empty and non-null for required fields. + * + * @param value The value to validate + * @param fieldName The name of the field being validated + * @throws IllegalArgumentException if validation fails + */ + private static void validate(Object value, String fieldName) { + if (value == null) { + throw new IllegalArgumentException(fieldName + " must be non-null"); + } + if (value instanceof String && ((String) value).isEmpty()) { + throw new IllegalArgumentException(fieldName + " must be non-empty"); + } + } + + /** + * Creates a new Builder with required parameters. + * + * @param clientId OAuth client ID + * @param host Databricks host URL + * @param endpoints OpenID Connect endpoints configuration + * @param idTokenSource Source of ID tokens + * @param httpClient HTTP client for making requests + */ + public Builder( + String clientId, + String host, + OpenIDConnectEndpoints endpoints, + IDTokenSource idTokenSource, + HttpClient httpClient) { + validate(clientId, "ClientID"); + validate(host, "Host"); + validate(endpoints, "Endpoints"); + validate(idTokenSource, "IDTokenSource"); + validate(httpClient, "HttpClient"); + + this.clientId = clientId; + this.host = host; + this.endpoints = endpoints; + this.idTokenSource = idTokenSource; + this.httpClient = httpClient; + } + + /** + * Sets the Databricks account ID. + * + * @param accountId The account ID + * @return This builder instance + */ + public Builder accountId(String accountId) { + validate(accountId, "AccountID"); + this.accountId = accountId; + return this; + } + + /** + * Sets a custom audience value for token exchange. + * + * @param audience The audience value + * @return This builder instance + */ + public Builder audience(String audience) { + validate(audience, "Audience"); + this.audience = audience; + return this; + } + + /** + * Builds a new DatabricksOAuthTokenSource instance. + * + * @return A new DatabricksOAuthTokenSource + */ + public DatabricksOAuthTokenSource build() { + return new DatabricksOAuthTokenSource(this); + } + } + + /** + * Retrieves an OAuth token by exchanging an ID token. Implements the OAuth token exchange flow to + * obtain an access token. + * + * @return A Token containing the access token and related information + * @throws DatabricksException if token exchange fails + */ + @Override + public Token getToken() { + String effectiveAudience = determineAudience(); + IDToken idToken = idTokenSource.getIDToken(effectiveAudience); + + Map params = new HashMap<>(); + params.put("grant_type", "urn:ietf:params:oauth:grant-type:token-exchange"); + params.put("subject_token", idToken.getValue()); + params.put("subject_token_type", "urn:ietf:params:oauth:token-type:jwt"); + params.put("scope", "all-apis"); + params.put("client_id", clientId); + + Response rawResponse; + try { + rawResponse = httpClient.execute(new FormRequest(endpoints.getTokenEndpoint(), params)); + } catch (IOException e) { + throw new DatabricksException( + String.format( + "Failed to exchange ID token for access token at %s: %s", + endpoints.getTokenEndpoint(), e.getMessage()), + e); + } + + OAuthResponse response; + try { + response = new ObjectMapper().readValue(rawResponse.getBody(), OAuthResponse.class); + } catch (IOException e) { + throw new DatabricksException( + String.format( + "Failed to parse OAuth response from token endpoint %s: %s", + endpoints.getTokenEndpoint(), e.getMessage())); + } + + if (response.getErrorCode() != null) { + throw new IllegalArgumentException( + String.format( + "Token exchange failed with error: %s - %s", + response.getErrorCode(), response.getErrorSummary())); + } + LocalDateTime expiry = LocalDateTime.now().plusSeconds(response.getExpiresIn()); + return new Token( + response.getAccessToken(), response.getTokenType(), response.getRefreshToken(), expiry); + } + + /** + * Determines the appropriate audience value for token exchange. Uses the following precedence: 1. + * Custom audience if provided 2. Account ID if provided 3. Token endpoint URL as fallback + * + * @return The determined audience value + */ + private String determineAudience() { + if (!Strings.isNullOrEmpty(audience)) { + return audience; + } + + if (!Strings.isNullOrEmpty(accountId)) { + return accountId; + } + + return endpoints.getTokenEndpoint(); + } +} diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSourceTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSourceTest.java new file mode 100644 index 000000000..3ffc9e71a --- /dev/null +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSourceTest.java @@ -0,0 +1,333 @@ +package com.databricks.sdk.core.oauth; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; + +import com.databricks.sdk.core.DatabricksException; +import com.databricks.sdk.core.http.FormRequest; +import com.databricks.sdk.core.http.HttpClient; +import com.databricks.sdk.core.http.Response; +import com.fasterxml.jackson.databind.ObjectMapper; +import java.io.IOException; +import java.net.URL; +import java.util.HashMap; +import java.util.Map; +import java.util.stream.Stream; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.ArgumentCaptor; +import org.mockito.Mockito; + +class DatabricksOAuthTokenSourceTest { + // Test constants + private static final String TOKEN = "test-access-token"; + private static final String TOKEN_TYPE = "Bearer"; + private static final String REFRESH_TOKEN = "test-refresh-token"; + private static final int EXPIRES_IN = 3600; + + private static final String TEST_HOST = "https://test.databricks.com"; + private static final String TEST_TOKEN_ENDPOINT = TEST_HOST + "/oidc/v1/token"; + private static final String TEST_AUTHORIZATION_ENDPOINT = TEST_HOST + "/authorize"; + private static final String TEST_CLIENT_ID = "test-client-id"; + private static final String TEST_ID_TOKEN = "test-id-token"; + private static final String TEST_AUDIENCE = "test-audience"; + private static final String TEST_ACCOUNT_ID = "test-account-id"; + + private IDTokenSource mockIdTokenSource; + + @BeforeEach + void setUp() { + mockIdTokenSource = Mockito.mock(IDTokenSource.class); + IDToken idToken = new IDToken(TEST_ID_TOKEN); + when(mockIdTokenSource.getIDToken(any())).thenReturn(idToken); + } + + /** + * Test case data for parameterized token source tests. Each case defines a specific OAuth token + * exchange scenario. + */ + private static class TestCase { + final String name; // Descriptive name of the test case + final String audience; // Custom audience value if provided + final String accountId; // Account ID if provided + final String expectedAudience; // Expected audience used in token exchange + final boolean expectError; // Whether this case should result in an error + final int statusCode; // HTTP status code for the response + final String responseBody; // Response body from the token endpoint + + TestCase( + String name, + String audience, + String accountId, + String expectedAudience, + boolean expectError, + int statusCode, + String responseBody) { + this.name = name; + this.audience = audience; + this.accountId = accountId; + this.expectedAudience = expectedAudience; + this.expectError = expectError; + this.statusCode = statusCode; + this.responseBody = responseBody; + } + + @Override + public String toString() { + return name; + } + } + + /** + * Provides test cases for OAuth token exchange scenarios. Includes success cases with different + * audience configurations and various error cases. + */ + private static Stream provideTestCases() { + // Success response with valid token data + Map successResponse = new HashMap<>(); + successResponse.put("access_token", TOKEN); + successResponse.put("token_type", TOKEN_TYPE); + successResponse.put("refresh_token", REFRESH_TOKEN); + successResponse.put("expires_in", EXPIRES_IN); + + // Error response for invalid requests + Map errorResponse = new HashMap<>(); + errorResponse.put("error", "invalid_request"); + errorResponse.put("error_description", "Invalid client ID"); + + ObjectMapper mapper = new ObjectMapper(); + String successJson; + String errorJson; + try { + successJson = mapper.writeValueAsString(successResponse); + errorJson = mapper.writeValueAsString(errorResponse); + } catch (IOException e) { + throw new RuntimeException(e); + } + + return Stream.of( + // Success cases with different audience configurations + new TestCase( + "Default audience from token endpoint", + null, + null, + TEST_TOKEN_ENDPOINT, + false, + 200, + successJson), + new TestCase( + "Custom audience provided", + TEST_AUDIENCE, + null, + TEST_AUDIENCE, + false, + 200, + successJson), + new TestCase( + "Custom audience takes precedence over account ID", + TEST_AUDIENCE, + TEST_ACCOUNT_ID, + TEST_AUDIENCE, + false, + 200, + successJson), + new TestCase( + "Account ID used as audience when no custom audience", + null, + TEST_ACCOUNT_ID, + TEST_ACCOUNT_ID, + false, + 200, + successJson), + // Error cases + new TestCase( + "Invalid request returns 400", null, null, TEST_TOKEN_ENDPOINT, true, 400, errorJson), + new TestCase( + "Network error during token exchange", null, null, TEST_TOKEN_ENDPOINT, true, 0, null), + new TestCase( + "Invalid JSON response from server", + null, + null, + TEST_TOKEN_ENDPOINT, + true, + 200, + "invalid json")); + } + + /** + * Tests OAuth token exchange with various configurations and error scenarios. Verifies correct + * audience selection, token exchange, and error handling. + */ + @ParameterizedTest(name = "testTokenSource: {arguments}") + @MethodSource("provideTestCases") + void testTokenSource(TestCase testCase) throws IOException { + // Mock HTTP client with test case specific behavior + HttpClient mockHttpClient = Mockito.mock(HttpClient.class); + if (testCase.expectError) { + if (testCase.statusCode == 0) { + when(mockHttpClient.execute(any())).thenThrow(new IOException("Network error")); + } else { + when(mockHttpClient.execute(any())) + .thenReturn( + new Response( + testCase.responseBody, testCase.statusCode, "Bad Request", new URL(TEST_HOST))); + } + } else { + when(mockHttpClient.execute(any())) + .thenReturn(new Response(testCase.responseBody, new URL(TEST_HOST))); + } + + // Create token source with test configuration + OpenIDConnectEndpoints endpoints = + new OpenIDConnectEndpoints(TEST_TOKEN_ENDPOINT, TEST_AUTHORIZATION_ENDPOINT); + + DatabricksOAuthTokenSource.Builder builder = + new DatabricksOAuthTokenSource.Builder( + TEST_CLIENT_ID, TEST_HOST, endpoints, mockIdTokenSource, mockHttpClient); + + if (testCase.audience != null) { + builder.audience(testCase.audience); + } + if (testCase.accountId != null) { + builder.accountId(testCase.accountId); + } + + DatabricksOAuthTokenSource tokenSource = builder.build(); + + if (testCase.expectError) { + if (testCase.statusCode == 400) { + assertThrows(IllegalArgumentException.class, () -> tokenSource.getToken()); + } else { + assertThrows(DatabricksException.class, () -> tokenSource.getToken()); + } + } else { + // Verify successful token exchange + Token token = tokenSource.getToken(); + assertEquals(TOKEN, token.getAccessToken()); + assertEquals(TOKEN_TYPE, token.getTokenType()); + assertEquals(REFRESH_TOKEN, token.getRefreshToken()); + assertFalse(token.isExpired()); + + // Verify correct audience was used + verify(mockIdTokenSource).getIDToken(testCase.expectedAudience); + + // Verify token exchange request + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(FormRequest.class); + verify(mockHttpClient).execute(requestCaptor.capture()); + + FormRequest capturedRequest = requestCaptor.getValue(); + assertEquals(TEST_TOKEN_ENDPOINT, capturedRequest.getUrl()); + + // Verify request parameters + String body = capturedRequest.getBodyString(); + assertTrue(body.contains("client_id=" + TEST_CLIENT_ID)); + assertTrue(body.contains("subject_token=" + TEST_ID_TOKEN)); + assertTrue( + body.contains("subject_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Ajwt")); + assertTrue( + body.contains("grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Atoken-exchange")); + assertTrue(body.contains("scope=all-apis")); + } + } + + /** + * Tests validation of required fields in the token source builder. Verifies that null or empty + * values for required fields throw IllegalArgumentException. + */ + @Test + void testConstructorValidation() { + // Test null client ID + assertThrows( + IllegalArgumentException.class, + () -> { + new DatabricksOAuthTokenSource.Builder( + null, + TEST_HOST, + new OpenIDConnectEndpoints(TEST_TOKEN_ENDPOINT, TEST_AUTHORIZATION_ENDPOINT), + mockIdTokenSource, + Mockito.mock(HttpClient.class)) + .build(); + }); + + // Test empty client ID + assertThrows( + IllegalArgumentException.class, + () -> { + new DatabricksOAuthTokenSource.Builder( + "", + TEST_HOST, + new OpenIDConnectEndpoints(TEST_TOKEN_ENDPOINT, TEST_AUTHORIZATION_ENDPOINT), + mockIdTokenSource, + Mockito.mock(HttpClient.class)) + .build(); + }); + + // Test null host + assertThrows( + IllegalArgumentException.class, + () -> { + new DatabricksOAuthTokenSource.Builder( + TEST_CLIENT_ID, + null, + new OpenIDConnectEndpoints(TEST_TOKEN_ENDPOINT, TEST_AUTHORIZATION_ENDPOINT), + mockIdTokenSource, + Mockito.mock(HttpClient.class)) + .build(); + }); + + // Test empty host + assertThrows( + IllegalArgumentException.class, + () -> { + new DatabricksOAuthTokenSource.Builder( + TEST_CLIENT_ID, + "", + new OpenIDConnectEndpoints(TEST_TOKEN_ENDPOINT, TEST_AUTHORIZATION_ENDPOINT), + mockIdTokenSource, + Mockito.mock(HttpClient.class)) + .build(); + }); + + // Test null endpoints + assertThrows( + IllegalArgumentException.class, + () -> { + new DatabricksOAuthTokenSource.Builder( + TEST_CLIENT_ID, + TEST_HOST, + null, + mockIdTokenSource, + Mockito.mock(HttpClient.class)) + .build(); + }); + + // Test null IDTokenSource + assertThrows( + IllegalArgumentException.class, + () -> { + new DatabricksOAuthTokenSource.Builder( + TEST_CLIENT_ID, + TEST_HOST, + new OpenIDConnectEndpoints(TEST_TOKEN_ENDPOINT, TEST_AUTHORIZATION_ENDPOINT), + null, + Mockito.mock(HttpClient.class)) + .build(); + }); + + // Test null HttpClient + assertThrows( + IllegalArgumentException.class, + () -> { + new DatabricksOAuthTokenSource.Builder( + TEST_CLIENT_ID, + TEST_HOST, + new OpenIDConnectEndpoints(TEST_TOKEN_ENDPOINT, TEST_AUTHORIZATION_ENDPOINT), + mockIdTokenSource, + null) + .build(); + }); + } +} From 66490d41893874516e2593b420b4d8cbe63e4428 Mon Sep 17 00:00:00 2001 From: emmyzhou-db Date: Thu, 1 May 2025 15:19:58 +0000 Subject: [PATCH 2/6] Fix comments and tests --- .../oauth/DatabricksOAuthTokenSource.java | 68 ++++++++++++------- .../oauth/DatabricksOAuthTokenSourceTest.java | 24 ++++--- 2 files changed, 56 insertions(+), 36 deletions(-) diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java index 307fc2cb2..b29e5fc0c 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java @@ -16,19 +16,30 @@ * This class manages the OAuth token exchange flow using ID tokens to obtain access tokens. */ public class DatabricksOAuthTokenSource implements TokenSource { - /** OAuth client ID used for token exchange */ + /** OAuth client ID used for token exchange. */ private final String clientId; - /** Databricks account ID, used as audience if provided */ + /** Databricks account ID, used as audience if provided. */ private final String accountId; - /** OpenID Connect endpoints configuration */ + /** OpenID Connect endpoints configuration. */ private final OpenIDConnectEndpoints endpoints; - /** Custom audience value for token exchange */ + /** Custom audience value for token exchange. */ private final String audience; - /** Source of ID tokens used in token exchange */ + /** Source of ID tokens used in token exchange. */ private final IDTokenSource idTokenSource; - /** HTTP client for making token exchange requests */ + /** HTTP client for making token exchange requests. */ private final HttpClient httpClient; + private static final String GRANT_TYPE = "urn:ietf:params:oauth:grant-type:token-exchange"; + private static final String SUBJECT_TOKEN_TYPE = "urn:ietf:params:oauth:token-type:jwt"; + private static final String SCOPE = "all-apis"; + private static final String GRANT_TYPE_PARAM = "grant_type"; + private static final String SUBJECT_TOKEN_PARAM = "subject_token"; + private static final String SUBJECT_TOKEN_TYPE_PARAM = "subject_token_type"; + private static final String SCOPE_PARAM = "scope"; + private static final String CLIENT_ID_PARAM = "client_id"; + + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + private DatabricksOAuthTokenSource(Builder builder) { this.clientId = builder.clientId; this.accountId = builder.accountId; @@ -52,11 +63,12 @@ public static class Builder { private String audience; /** - * Validates that a value is non-empty and non-null for required fields. + * Validates that a value is non-null for required fields. If the value is a string, it also + * checks that it is non-empty. * - * @param value The value to validate - * @param fieldName The name of the field being validated - * @throws IllegalArgumentException if validation fails + * @param value The value to validate. + * @param fieldName The name of the field being validated. + * @throws IllegalArgumentException if validation fails. */ private static void validate(Object value, String fieldName) { if (value == null) { @@ -70,11 +82,12 @@ private static void validate(Object value, String fieldName) { /** * Creates a new Builder with required parameters. * - * @param clientId OAuth client ID - * @param host Databricks host URL - * @param endpoints OpenID Connect endpoints configuration - * @param idTokenSource Source of ID tokens - * @param httpClient HTTP client for making requests + * @param clientId OAuth client ID. + * @param host Databricks host URL. + * @param endpoints OpenID Connect endpoints configuration. + * @param idTokenSource Source of ID tokens. + * @param httpClient HTTP client for making requests. + * @throws IllegalArgumentException if any required parameter is null or empty. */ public Builder( String clientId, @@ -98,8 +111,9 @@ public Builder( /** * Sets the Databricks account ID. * - * @param accountId The account ID - * @return This builder instance + * @param accountId The account ID. + * @return This builder instance. + * @throws IllegalArgumentException if the account ID is null or empty. */ public Builder accountId(String accountId) { validate(accountId, "AccountID"); @@ -122,7 +136,8 @@ public Builder audience(String audience) { /** * Builds a new DatabricksOAuthTokenSource instance. * - * @return A new DatabricksOAuthTokenSource + * @return A new DatabricksOAuthTokenSource. + * @throws IllegalArgumentException if any required parameters are null or empty. */ public DatabricksOAuthTokenSource build() { return new DatabricksOAuthTokenSource(this); @@ -133,8 +148,9 @@ public DatabricksOAuthTokenSource build() { * Retrieves an OAuth token by exchanging an ID token. Implements the OAuth token exchange flow to * obtain an access token. * - * @return A Token containing the access token and related information - * @throws DatabricksException if token exchange fails + * @return A Token containing the access token and related information. + * @throws DatabricksException when the token exchange fails. + * @throws IllegalArgumentException when there is an error code in the response. */ @Override public Token getToken() { @@ -142,11 +158,11 @@ public Token getToken() { IDToken idToken = idTokenSource.getIDToken(effectiveAudience); Map params = new HashMap<>(); - params.put("grant_type", "urn:ietf:params:oauth:grant-type:token-exchange"); - params.put("subject_token", idToken.getValue()); - params.put("subject_token_type", "urn:ietf:params:oauth:token-type:jwt"); - params.put("scope", "all-apis"); - params.put("client_id", clientId); + params.put(GRANT_TYPE_PARAM, GRANT_TYPE); + params.put(SUBJECT_TOKEN_PARAM, idToken.getValue()); + params.put(SUBJECT_TOKEN_TYPE_PARAM, SUBJECT_TOKEN_TYPE); + params.put(SCOPE_PARAM, SCOPE); + params.put(CLIENT_ID_PARAM, clientId); Response rawResponse; try { @@ -161,7 +177,7 @@ public Token getToken() { OAuthResponse response; try { - response = new ObjectMapper().readValue(rawResponse.getBody(), OAuthResponse.class); + response = OBJECT_MAPPER.readValue(rawResponse.getBody(), OAuthResponse.class); } catch (IOException e) { throw new DatabricksException( String.format( diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSourceTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSourceTest.java index 3ffc9e71a..599ed4a6b 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSourceTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSourceTest.java @@ -56,7 +56,7 @@ private static class TestCase { final String expectedAudience; // Expected audience used in token exchange final boolean expectError; // Whether this case should result in an error final int statusCode; // HTTP status code for the response - final String responseBody; // Response body from the token endpoint + final Object responseBody; // Response body from the token endpoint TestCase( String name, @@ -65,7 +65,7 @@ private static class TestCase { String expectedAudience, boolean expectError, int statusCode, - String responseBody) { + Object responseBody) { this.name = name; this.audience = audience; this.accountId = accountId; @@ -99,10 +99,8 @@ private static Stream provideTestCases() { errorResponse.put("error_description", "Invalid client ID"); ObjectMapper mapper = new ObjectMapper(); - String successJson; String errorJson; try { - successJson = mapper.writeValueAsString(successResponse); errorJson = mapper.writeValueAsString(errorResponse); } catch (IOException e) { throw new RuntimeException(e); @@ -117,7 +115,7 @@ private static Stream provideTestCases() { TEST_TOKEN_ENDPOINT, false, 200, - successJson), + successResponse), new TestCase( "Custom audience provided", TEST_AUDIENCE, @@ -125,7 +123,7 @@ private static Stream provideTestCases() { TEST_AUDIENCE, false, 200, - successJson), + successResponse), new TestCase( "Custom audience takes precedence over account ID", TEST_AUDIENCE, @@ -133,7 +131,7 @@ private static Stream provideTestCases() { TEST_AUDIENCE, false, 200, - successJson), + successResponse), new TestCase( "Account ID used as audience when no custom audience", null, @@ -141,7 +139,7 @@ private static Stream provideTestCases() { TEST_ACCOUNT_ID, false, 200, - successJson), + successResponse), // Error cases new TestCase( "Invalid request returns 400", null, null, TEST_TOKEN_ENDPOINT, true, 400, errorJson), @@ -166,6 +164,8 @@ private static Stream provideTestCases() { void testTokenSource(TestCase testCase) throws IOException { // Mock HTTP client with test case specific behavior HttpClient mockHttpClient = Mockito.mock(HttpClient.class); + ObjectMapper mapper = new ObjectMapper(); + if (testCase.expectError) { if (testCase.statusCode == 0) { when(mockHttpClient.execute(any())).thenThrow(new IOException("Network error")); @@ -173,11 +173,15 @@ void testTokenSource(TestCase testCase) throws IOException { when(mockHttpClient.execute(any())) .thenReturn( new Response( - testCase.responseBody, testCase.statusCode, "Bad Request", new URL(TEST_HOST))); + testCase.responseBody.toString(), + testCase.statusCode, + "Bad Request", + new URL(TEST_HOST))); } } else { + String responseJson = mapper.writeValueAsString(testCase.responseBody); when(mockHttpClient.execute(any())) - .thenReturn(new Response(testCase.responseBody, new URL(TEST_HOST))); + .thenReturn(new Response(responseJson, new URL(TEST_HOST))); } // Create token source with test configuration From de918d3cb7026606ba62514240581411249d4938 Mon Sep 17 00:00:00 2001 From: emmyzhou-db Date: Fri, 2 May 2025 10:13:00 +0000 Subject: [PATCH 3/6] Fix comments --- .../databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java | 1 + 1 file changed, 1 insertion(+) diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java index b29e5fc0c..75184b411 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java @@ -126,6 +126,7 @@ public Builder accountId(String accountId) { * * @param audience The audience value * @return This builder instance + * @throws IllegalArgumentException if the audience is null or empty. */ public Builder audience(String audience) { validate(audience, "Audience"); From 1680580a841c6ef629c213dfebbbe17e4680a84f Mon Sep 17 00:00:00 2001 From: emmyzhou-db Date: Mon, 5 May 2025 16:56:47 +0000 Subject: [PATCH 4/6] Fix argument validation and add logging --- .../oauth/DatabricksOAuthTokenSource.java | 81 ++-- .../oauth/DatabricksOAuthTokenSourceTest.java | 401 ++++++++---------- 2 files changed, 231 insertions(+), 251 deletions(-) diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java index 75184b411..20fda32f5 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java @@ -10,14 +10,20 @@ import java.time.LocalDateTime; import java.util.HashMap; import java.util.Map; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * Implementation of TokenSource that handles OAuth token exchange for Databricks authentication. * This class manages the OAuth token exchange flow using ID tokens to obtain access tokens. */ public class DatabricksOAuthTokenSource implements TokenSource { + private static final Logger LOG = LoggerFactory.getLogger(DatabricksOAuthTokenSource.class); + /** OAuth client ID used for token exchange. */ private final String clientId; + /** Databricks host URL. */ + private final String host; /** Databricks account ID, used as audience if provided. */ private final String accountId; /** OpenID Connect endpoints configuration. */ @@ -42,6 +48,7 @@ public class DatabricksOAuthTokenSource implements TokenSource { private DatabricksOAuthTokenSource(Builder builder) { this.clientId = builder.clientId; + this.host = builder.host; this.accountId = builder.accountId; this.endpoints = builder.endpoints; this.audience = builder.audience; @@ -62,23 +69,6 @@ public static class Builder { private String accountId; private String audience; - /** - * Validates that a value is non-null for required fields. If the value is a string, it also - * checks that it is non-empty. - * - * @param value The value to validate. - * @param fieldName The name of the field being validated. - * @throws IllegalArgumentException if validation fails. - */ - private static void validate(Object value, String fieldName) { - if (value == null) { - throw new IllegalArgumentException(fieldName + " must be non-null"); - } - if (value instanceof String && ((String) value).isEmpty()) { - throw new IllegalArgumentException(fieldName + " must be non-empty"); - } - } - /** * Creates a new Builder with required parameters. * @@ -87,7 +77,6 @@ private static void validate(Object value, String fieldName) { * @param endpoints OpenID Connect endpoints configuration. * @param idTokenSource Source of ID tokens. * @param httpClient HTTP client for making requests. - * @throws IllegalArgumentException if any required parameter is null or empty. */ public Builder( String clientId, @@ -95,12 +84,6 @@ public Builder( OpenIDConnectEndpoints endpoints, IDTokenSource idTokenSource, HttpClient httpClient) { - validate(clientId, "ClientID"); - validate(host, "Host"); - validate(endpoints, "Endpoints"); - validate(idTokenSource, "IDTokenSource"); - validate(httpClient, "HttpClient"); - this.clientId = clientId; this.host = host; this.endpoints = endpoints; @@ -113,10 +96,8 @@ public Builder( * * @param accountId The account ID. * @return This builder instance. - * @throws IllegalArgumentException if the account ID is null or empty. */ public Builder accountId(String accountId) { - validate(accountId, "AccountID"); this.accountId = accountId; return this; } @@ -126,10 +107,8 @@ public Builder accountId(String accountId) { * * @param audience The audience value * @return This builder instance - * @throws IllegalArgumentException if the audience is null or empty. */ public Builder audience(String audience) { - validate(audience, "Audience"); this.audience = audience; return this; } @@ -138,23 +117,51 @@ public Builder audience(String audience) { * Builds a new DatabricksOAuthTokenSource instance. * * @return A new DatabricksOAuthTokenSource. - * @throws IllegalArgumentException if any required parameters are null or empty. */ public DatabricksOAuthTokenSource build() { return new DatabricksOAuthTokenSource(this); } } + /** + * Validates that a value is non-null for required fields. If the value is a string, it also + * checks that it is non-empty. + * + * @param value The value to validate. + * @param fieldName The name of the field being validated. + * @return true if validation passes, false otherwise + */ + private static boolean validate(Object value, String fieldName) { + if (value == null) { + return false; + } + if (value instanceof String && ((String) value).isEmpty()) { + return false; + } + return true; + } + /** * Retrieves an OAuth token by exchanging an ID token. Implements the OAuth token exchange flow to * obtain an access token. * * @return A Token containing the access token and related information. * @throws DatabricksException when the token exchange fails. - * @throws IllegalArgumentException when there is an error code in the response. + * @throws IllegalArgumentException when there is an error code in the response or when required + * parameters are missing. */ @Override public Token getToken() { + // Validate all required parameters + if (!validate(clientId, "ClientID") + || !validate(host, "Host") + || !validate(endpoints, "Endpoints") + || !validate(idTokenSource, "IDTokenSource") + || !validate(httpClient, "HttpClient")) { + LOG.error("Missing required parameters for token exchange"); + throw new IllegalArgumentException("Missing required parameters for token exchange."); + } + String effectiveAudience = determineAudience(); IDToken idToken = idTokenSource.getIDToken(effectiveAudience); @@ -169,6 +176,11 @@ public Token getToken() { try { rawResponse = httpClient.execute(new FormRequest(endpoints.getTokenEndpoint(), params)); } catch (IOException e) { + LOG.error( + "Failed to exchange ID token for access token at {}: {}", + endpoints.getTokenEndpoint(), + e.getMessage(), + e); throw new DatabricksException( String.format( "Failed to exchange ID token for access token at %s: %s", @@ -180,6 +192,11 @@ public Token getToken() { try { response = OBJECT_MAPPER.readValue(rawResponse.getBody(), OAuthResponse.class); } catch (IOException e) { + LOG.error( + "Failed to parse OAuth response from token endpoint {}: {}", + endpoints.getTokenEndpoint(), + e.getMessage(), + e); throw new DatabricksException( String.format( "Failed to parse OAuth response from token endpoint %s: %s", @@ -187,6 +204,10 @@ public Token getToken() { } if (response.getErrorCode() != null) { + LOG.error( + "Token exchange failed with error: {} - {}", + response.getErrorCode(), + response.getErrorSummary()); throw new IllegalArgumentException( String.format( "Token exchange failed with error: %s - %s", diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSourceTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSourceTest.java index 599ed4a6b..973504a16 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSourceTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSourceTest.java @@ -10,6 +10,7 @@ import com.databricks.sdk.core.http.Response; import com.fasterxml.jackson.databind.ObjectMapper; import java.io.IOException; +import java.net.MalformedURLException; import java.net.URL; import java.util.HashMap; import java.util.Map; @@ -18,7 +19,6 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.MethodSource; -import org.mockito.ArgumentCaptor; import org.mockito.Mockito; class DatabricksOAuthTokenSourceTest { @@ -54,25 +54,24 @@ private static class TestCase { final String audience; // Custom audience value if provided final String accountId; // Account ID if provided final String expectedAudience; // Expected audience used in token exchange - final boolean expectError; // Whether this case should result in an error - final int statusCode; // HTTP status code for the response - final Object responseBody; // Response body from the token endpoint + final HttpClient mockHttpClient; // Pre-configured mock HTTP client + final Class expectedException; // Expected exception type if any TestCase( String name, String audience, String accountId, String expectedAudience, - boolean expectError, int statusCode, - Object responseBody) { + Object responseBody, + HttpClient mockHttpClient, + Class expectedException) { this.name = name; this.audience = audience; this.accountId = accountId; this.expectedAudience = expectedAudience; - this.expectError = expectError; - this.statusCode = statusCode; - this.responseBody = responseBody; + this.mockHttpClient = mockHttpClient; + this.expectedException = expectedException; } @Override @@ -86,73 +85,124 @@ public String toString() { * audience configurations and various error cases. */ private static Stream provideTestCases() { - // Success response with valid token data - Map successResponse = new HashMap<>(); - successResponse.put("access_token", TOKEN); - successResponse.put("token_type", TOKEN_TYPE); - successResponse.put("refresh_token", REFRESH_TOKEN); - successResponse.put("expires_in", EXPIRES_IN); + try { + // Success response with valid token data + Map successResponse = new HashMap<>(); + successResponse.put("access_token", TOKEN); + successResponse.put("token_type", TOKEN_TYPE); + successResponse.put("refresh_token", REFRESH_TOKEN); + successResponse.put("expires_in", EXPIRES_IN); + + // Error response for invalid requests + Map errorResponse = new HashMap<>(); + errorResponse.put("error", "invalid_request"); + errorResponse.put("error_description", "Invalid client ID"); + + ObjectMapper mapper = new ObjectMapper(); + final String errorJson = mapper.writeValueAsString(errorResponse); + final String successJson = mapper.writeValueAsString(successResponse); - // Error response for invalid requests - Map errorResponse = new HashMap<>(); - errorResponse.put("error", "invalid_request"); - errorResponse.put("error_description", "Invalid client ID"); + // Create the expected request that will be used in all test cases + Map formParams = new HashMap<>(); + formParams.put("client_id", TEST_CLIENT_ID); + formParams.put("subject_token", TEST_ID_TOKEN); + formParams.put("subject_token_type", "urn:ietf:params:oauth:token-type:jwt"); + formParams.put("grant_type", "urn:ietf:params:oauth:grant-type:token-exchange"); + formParams.put("scope", "all-apis"); + FormRequest expectedRequest = new FormRequest(TEST_TOKEN_ENDPOINT, formParams); - ObjectMapper mapper = new ObjectMapper(); - String errorJson; + return Stream.of( + // Success cases with different audience configurations + new TestCase( + "Default audience from token endpoint", + null, + null, + TEST_TOKEN_ENDPOINT, + 200, + successResponse, + createMockHttpClient(expectedRequest, 200, successJson), + null), + new TestCase( + "Custom audience provided", + TEST_AUDIENCE, + null, + TEST_AUDIENCE, + 200, + successResponse, + createMockHttpClient(expectedRequest, 200, successJson), + null), + new TestCase( + "Custom audience takes precedence over account ID", + TEST_AUDIENCE, + TEST_ACCOUNT_ID, + TEST_AUDIENCE, + 200, + successResponse, + createMockHttpClient(expectedRequest, 200, successJson), + null), + new TestCase( + "Account ID used as audience when no custom audience", + null, + TEST_ACCOUNT_ID, + TEST_ACCOUNT_ID, + 200, + successResponse, + createMockHttpClient(expectedRequest, 200, successJson), + null), + // Error cases + new TestCase( + "Invalid request returns 400", + null, + null, + TEST_TOKEN_ENDPOINT, + 400, + errorJson, + createMockHttpClient(expectedRequest, 400, errorJson), + IllegalArgumentException.class), + new TestCase( + "Network error during token exchange", + null, + null, + TEST_TOKEN_ENDPOINT, + 0, + null, + createMockHttpClientWithError(expectedRequest), + DatabricksException.class), + new TestCase( + "Invalid JSON response from server", + null, + null, + TEST_TOKEN_ENDPOINT, + 200, + "invalid json", + createMockHttpClient(expectedRequest, 200, "invalid json"), + DatabricksException.class)); + } catch (IOException e) { + throw new RuntimeException("Failed to create test cases", e); + } + } + + private static HttpClient createMockHttpClient( + FormRequest expectedRequest, int statusCode, String responseBody) { try { - errorJson = mapper.writeValueAsString(errorResponse); + HttpClient mockHttpClient = Mockito.mock(HttpClient.class); + String statusMessage = statusCode == 200 ? "OK" : "Bad Request"; + when(mockHttpClient.execute(expectedRequest)) + .thenReturn(new Response(responseBody, statusCode, statusMessage, new URL(TEST_HOST))); + return mockHttpClient; } catch (IOException e) { - throw new RuntimeException(e); + throw new RuntimeException("Failed to create mock HTTP client", e); } + } - return Stream.of( - // Success cases with different audience configurations - new TestCase( - "Default audience from token endpoint", - null, - null, - TEST_TOKEN_ENDPOINT, - false, - 200, - successResponse), - new TestCase( - "Custom audience provided", - TEST_AUDIENCE, - null, - TEST_AUDIENCE, - false, - 200, - successResponse), - new TestCase( - "Custom audience takes precedence over account ID", - TEST_AUDIENCE, - TEST_ACCOUNT_ID, - TEST_AUDIENCE, - false, - 200, - successResponse), - new TestCase( - "Account ID used as audience when no custom audience", - null, - TEST_ACCOUNT_ID, - TEST_ACCOUNT_ID, - false, - 200, - successResponse), - // Error cases - new TestCase( - "Invalid request returns 400", null, null, TEST_TOKEN_ENDPOINT, true, 400, errorJson), - new TestCase( - "Network error during token exchange", null, null, TEST_TOKEN_ENDPOINT, true, 0, null), - new TestCase( - "Invalid JSON response from server", - null, - null, - TEST_TOKEN_ENDPOINT, - true, - 200, - "invalid json")); + private static HttpClient createMockHttpClientWithError(FormRequest expectedRequest) { + try { + HttpClient mockHttpClient = Mockito.mock(HttpClient.class); + when(mockHttpClient.execute(expectedRequest)).thenThrow(new IOException("Network error")); + return mockHttpClient; + } catch (IOException e) { + throw new RuntimeException("Failed to create mock HTTP client with error", e); + } } /** @@ -161,177 +211,86 @@ private static Stream provideTestCases() { */ @ParameterizedTest(name = "testTokenSource: {arguments}") @MethodSource("provideTestCases") - void testTokenSource(TestCase testCase) throws IOException { - // Mock HTTP client with test case specific behavior - HttpClient mockHttpClient = Mockito.mock(HttpClient.class); - ObjectMapper mapper = new ObjectMapper(); - - if (testCase.expectError) { - if (testCase.statusCode == 0) { - when(mockHttpClient.execute(any())).thenThrow(new IOException("Network error")); - } else { - when(mockHttpClient.execute(any())) - .thenReturn( - new Response( - testCase.responseBody.toString(), - testCase.statusCode, - "Bad Request", - new URL(TEST_HOST))); - } - } else { - String responseJson = mapper.writeValueAsString(testCase.responseBody); - when(mockHttpClient.execute(any())) - .thenReturn(new Response(responseJson, new URL(TEST_HOST))); - } - - // Create token source with test configuration - OpenIDConnectEndpoints endpoints = - new OpenIDConnectEndpoints(TEST_TOKEN_ENDPOINT, TEST_AUTHORIZATION_ENDPOINT); + void testTokenSource(TestCase testCase) { + try { + // Create token source with test configuration + OpenIDConnectEndpoints endpoints = + new OpenIDConnectEndpoints(TEST_TOKEN_ENDPOINT, TEST_AUTHORIZATION_ENDPOINT); - DatabricksOAuthTokenSource.Builder builder = - new DatabricksOAuthTokenSource.Builder( - TEST_CLIENT_ID, TEST_HOST, endpoints, mockIdTokenSource, mockHttpClient); + DatabricksOAuthTokenSource.Builder builder = + new DatabricksOAuthTokenSource.Builder( + TEST_CLIENT_ID, TEST_HOST, endpoints, mockIdTokenSource, testCase.mockHttpClient); - if (testCase.audience != null) { - builder.audience(testCase.audience); - } - if (testCase.accountId != null) { - builder.accountId(testCase.accountId); - } + builder.audience(testCase.audience).accountId(testCase.accountId); - DatabricksOAuthTokenSource tokenSource = builder.build(); + DatabricksOAuthTokenSource tokenSource = builder.build(); - if (testCase.expectError) { - if (testCase.statusCode == 400) { - assertThrows(IllegalArgumentException.class, () -> tokenSource.getToken()); + if (testCase.expectedException != null) { + assertThrows(testCase.expectedException, () -> tokenSource.getToken()); } else { - assertThrows(DatabricksException.class, () -> tokenSource.getToken()); - } - } else { - // Verify successful token exchange - Token token = tokenSource.getToken(); - assertEquals(TOKEN, token.getAccessToken()); - assertEquals(TOKEN_TYPE, token.getTokenType()); - assertEquals(REFRESH_TOKEN, token.getRefreshToken()); - assertFalse(token.isExpired()); + // Verify successful token exchange + Token token = tokenSource.getToken(); + assertEquals(TOKEN, token.getAccessToken()); + assertEquals(TOKEN_TYPE, token.getTokenType()); + assertEquals(REFRESH_TOKEN, token.getRefreshToken()); + assertFalse(token.isExpired()); - // Verify correct audience was used - verify(mockIdTokenSource).getIDToken(testCase.expectedAudience); - - // Verify token exchange request - ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(FormRequest.class); - verify(mockHttpClient).execute(requestCaptor.capture()); - - FormRequest capturedRequest = requestCaptor.getValue(); - assertEquals(TEST_TOKEN_ENDPOINT, capturedRequest.getUrl()); - - // Verify request parameters - String body = capturedRequest.getBodyString(); - assertTrue(body.contains("client_id=" + TEST_CLIENT_ID)); - assertTrue(body.contains("subject_token=" + TEST_ID_TOKEN)); - assertTrue( - body.contains("subject_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Ajwt")); - assertTrue( - body.contains("grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Atoken-exchange")); - assertTrue(body.contains("scope=all-apis")); + // Verify correct audience was used + verify(mockIdTokenSource).getIDToken(testCase.expectedAudience); + } + } catch (IOException e) { + throw new RuntimeException("Test failed", e); } } /** - * Tests validation of required fields in the token source builder. Verifies that null or empty - * values for required fields throw IllegalArgumentException. + * Tests validation of required fields in the token source. Verifies that null or empty values for + * required fields cause getToken() to throw IllegalArgumentException. */ @Test - void testConstructorValidation() { + void testParameterValidation() { + OpenIDConnectEndpoints validEndpoints; + try { + validEndpoints = new OpenIDConnectEndpoints(TEST_TOKEN_ENDPOINT, TEST_AUTHORIZATION_ENDPOINT); + } catch (MalformedURLException e) { + fail("Failed to create valid endpoints: " + e.getMessage()); + return; + } + HttpClient validHttpClient = Mockito.mock(HttpClient.class); + // Test null client ID - assertThrows( - IllegalArgumentException.class, - () -> { - new DatabricksOAuthTokenSource.Builder( - null, - TEST_HOST, - new OpenIDConnectEndpoints(TEST_TOKEN_ENDPOINT, TEST_AUTHORIZATION_ENDPOINT), - mockIdTokenSource, - Mockito.mock(HttpClient.class)) - .build(); - }); + final DatabricksOAuthTokenSource tokenSource1 = + new DatabricksOAuthTokenSource.Builder( + null, TEST_HOST, validEndpoints, mockIdTokenSource, validHttpClient) + .build(); + assertThrows(IllegalArgumentException.class, () -> tokenSource1.getToken()); // Test empty client ID - assertThrows( - IllegalArgumentException.class, - () -> { - new DatabricksOAuthTokenSource.Builder( - "", - TEST_HOST, - new OpenIDConnectEndpoints(TEST_TOKEN_ENDPOINT, TEST_AUTHORIZATION_ENDPOINT), - mockIdTokenSource, - Mockito.mock(HttpClient.class)) - .build(); - }); - - // Test null host - assertThrows( - IllegalArgumentException.class, - () -> { - new DatabricksOAuthTokenSource.Builder( - TEST_CLIENT_ID, - null, - new OpenIDConnectEndpoints(TEST_TOKEN_ENDPOINT, TEST_AUTHORIZATION_ENDPOINT), - mockIdTokenSource, - Mockito.mock(HttpClient.class)) - .build(); - }); - - // Test empty host - assertThrows( - IllegalArgumentException.class, - () -> { - new DatabricksOAuthTokenSource.Builder( - TEST_CLIENT_ID, - "", - new OpenIDConnectEndpoints(TEST_TOKEN_ENDPOINT, TEST_AUTHORIZATION_ENDPOINT), - mockIdTokenSource, - Mockito.mock(HttpClient.class)) - .build(); - }); + final DatabricksOAuthTokenSource tokenSource2 = + new DatabricksOAuthTokenSource.Builder( + "", TEST_HOST, validEndpoints, mockIdTokenSource, validHttpClient) + .build(); + assertThrows(IllegalArgumentException.class, () -> tokenSource2.getToken()); // Test null endpoints - assertThrows( - IllegalArgumentException.class, - () -> { - new DatabricksOAuthTokenSource.Builder( - TEST_CLIENT_ID, - TEST_HOST, - null, - mockIdTokenSource, - Mockito.mock(HttpClient.class)) - .build(); - }); + final DatabricksOAuthTokenSource tokenSource3 = + new DatabricksOAuthTokenSource.Builder( + TEST_CLIENT_ID, TEST_HOST, null, mockIdTokenSource, validHttpClient) + .build(); + assertThrows(IllegalArgumentException.class, () -> tokenSource3.getToken()); // Test null IDTokenSource - assertThrows( - IllegalArgumentException.class, - () -> { - new DatabricksOAuthTokenSource.Builder( - TEST_CLIENT_ID, - TEST_HOST, - new OpenIDConnectEndpoints(TEST_TOKEN_ENDPOINT, TEST_AUTHORIZATION_ENDPOINT), - null, - Mockito.mock(HttpClient.class)) - .build(); - }); + final DatabricksOAuthTokenSource tokenSource4 = + new DatabricksOAuthTokenSource.Builder( + TEST_CLIENT_ID, TEST_HOST, validEndpoints, null, validHttpClient) + .build(); + assertThrows(IllegalArgumentException.class, () -> tokenSource4.getToken()); // Test null HttpClient - assertThrows( - IllegalArgumentException.class, - () -> { - new DatabricksOAuthTokenSource.Builder( - TEST_CLIENT_ID, - TEST_HOST, - new OpenIDConnectEndpoints(TEST_TOKEN_ENDPOINT, TEST_AUTHORIZATION_ENDPOINT), - mockIdTokenSource, - null) - .build(); - }); + final DatabricksOAuthTokenSource tokenSource5 = + new DatabricksOAuthTokenSource.Builder( + TEST_CLIENT_ID, TEST_HOST, validEndpoints, mockIdTokenSource, null) + .build(); + assertThrows(IllegalArgumentException.class, () -> tokenSource5.getToken()); } } From 73ffd88d5ad2e69094a2944ff0c9e5360948523a Mon Sep 17 00:00:00 2001 From: emmyzhou-db Date: Tue, 6 May 2025 03:54:37 +0000 Subject: [PATCH 5/6] Minor comment change --- .../databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java index 20fda32f5..041026198 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java @@ -57,8 +57,8 @@ private DatabricksOAuthTokenSource(Builder builder) { } /** - * Builder class for constructing DatabricksOAuthTokenSource instances. Provides a fluent - * interface for setting required and optional parameters. + * Builder class for constructing DatabricksOAuthTokenSource instances. Provides a flexible way to + * set required and optional parameters. */ public static class Builder { private final String clientId; From d8438944b3e1fecdf544a5be514e10dc07e4a5a2 Mon Sep 17 00:00:00 2001 From: emmyzhou-db Date: Tue, 6 May 2025 08:53:51 +0000 Subject: [PATCH 6/6] Fix validation and tests --- .../oauth/DatabricksOAuthTokenSource.java | 26 +-- .../oauth/DatabricksOAuthTokenSourceTest.java | 174 ++++++++++++++---- 2 files changed, 147 insertions(+), 53 deletions(-) diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java index 041026198..c8ac65ba2 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java @@ -129,16 +129,19 @@ public DatabricksOAuthTokenSource build() { * * @param value The value to validate. * @param fieldName The name of the field being validated. - * @return true if validation passes, false otherwise + * @throws IllegalArgumentException when the value is null or an empty string. */ - private static boolean validate(Object value, String fieldName) { + private static void validate(Object value, String fieldName) { if (value == null) { - return false; + LOG.error("Required parameter '{}' is null", fieldName); + throw new IllegalArgumentException( + String.format("Required parameter '%s' cannot be null", fieldName)); } if (value instanceof String && ((String) value).isEmpty()) { - return false; + LOG.error("Required parameter '{}' is empty", fieldName); + throw new IllegalArgumentException( + String.format("Required parameter '%s' cannot be empty", fieldName)); } - return true; } /** @@ -153,14 +156,11 @@ private static boolean validate(Object value, String fieldName) { @Override public Token getToken() { // Validate all required parameters - if (!validate(clientId, "ClientID") - || !validate(host, "Host") - || !validate(endpoints, "Endpoints") - || !validate(idTokenSource, "IDTokenSource") - || !validate(httpClient, "HttpClient")) { - LOG.error("Missing required parameters for token exchange"); - throw new IllegalArgumentException("Missing required parameters for token exchange."); - } + validate(clientId, "ClientID"); + validate(host, "Host"); + validate(endpoints, "Endpoints"); + validate(idTokenSource, "IDTokenSource"); + validate(httpClient, "HttpClient"); String effectiveAudience = determineAudience(); IDToken idToken = idTokenSource.getIDToken(effectiveAudience); diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSourceTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSourceTest.java index 973504a16..8d7da8d3a 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSourceTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSourceTest.java @@ -16,7 +16,6 @@ import java.util.Map; import java.util.stream.Stream; import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.MethodSource; import org.mockito.Mockito; @@ -36,6 +35,10 @@ class DatabricksOAuthTokenSourceTest { private static final String TEST_AUDIENCE = "test-audience"; private static final String TEST_ACCOUNT_ID = "test-account-id"; + // Error message constants + private static final String ERROR_NULL = "Required parameter '%s' cannot be null"; + private static final String ERROR_EMPTY = "Required parameter '%s' cannot be empty"; + private IDTokenSource mockIdTokenSource; @BeforeEach @@ -244,53 +247,144 @@ void testTokenSource(TestCase testCase) { } /** - * Tests validation of required fields in the token source. Verifies that null or empty values for - * required fields cause getToken() to throw IllegalArgumentException. + * Test case data for parameter validation tests. Each case defines a specific validation + * scenario. */ - @Test - void testParameterValidation() { - OpenIDConnectEndpoints validEndpoints; - try { - validEndpoints = new OpenIDConnectEndpoints(TEST_TOKEN_ENDPOINT, TEST_AUTHORIZATION_ENDPOINT); - } catch (MalformedURLException e) { - fail("Failed to create valid endpoints: " + e.getMessage()); - return; + private static class ValidationTestCase { + final String name; + final String clientId; + final String host; + final OpenIDConnectEndpoints endpoints; + final IDTokenSource idTokenSource; + final HttpClient httpClient; + final String expectedFieldName; + final boolean isNullTest; + + ValidationTestCase( + String name, + String clientId, + String host, + OpenIDConnectEndpoints endpoints, + IDTokenSource idTokenSource, + HttpClient httpClient, + String expectedFieldName, + boolean isNullTest) { + this.name = name; + this.clientId = clientId; + this.host = host; + this.endpoints = endpoints; + this.idTokenSource = idTokenSource; + this.httpClient = httpClient; + this.expectedFieldName = expectedFieldName; + this.isNullTest = isNullTest; } - HttpClient validHttpClient = Mockito.mock(HttpClient.class); - // Test null client ID - final DatabricksOAuthTokenSource tokenSource1 = - new DatabricksOAuthTokenSource.Builder( - null, TEST_HOST, validEndpoints, mockIdTokenSource, validHttpClient) - .build(); - assertThrows(IllegalArgumentException.class, () -> tokenSource1.getToken()); + @Override + public String toString() { + return name; + } + } - // Test empty client ID - final DatabricksOAuthTokenSource tokenSource2 = - new DatabricksOAuthTokenSource.Builder( - "", TEST_HOST, validEndpoints, mockIdTokenSource, validHttpClient) - .build(); - assertThrows(IllegalArgumentException.class, () -> tokenSource2.getToken()); + private static Stream provideValidationTestCases() + throws MalformedURLException { + OpenIDConnectEndpoints validEndpoints = + new OpenIDConnectEndpoints(TEST_TOKEN_ENDPOINT, TEST_AUTHORIZATION_ENDPOINT); + HttpClient validHttpClient = Mockito.mock(HttpClient.class); + IDTokenSource validIdTokenSource = Mockito.mock(IDTokenSource.class); - // Test null endpoints - final DatabricksOAuthTokenSource tokenSource3 = - new DatabricksOAuthTokenSource.Builder( - TEST_CLIENT_ID, TEST_HOST, null, mockIdTokenSource, validHttpClient) - .build(); - assertThrows(IllegalArgumentException.class, () -> tokenSource3.getToken()); + return Stream.of( + // Client ID validation + new ValidationTestCase( + "Null client ID", + null, + TEST_HOST, + validEndpoints, + validIdTokenSource, + validHttpClient, + "ClientID", + true), + new ValidationTestCase( + "Empty client ID", + "", + TEST_HOST, + validEndpoints, + validIdTokenSource, + validHttpClient, + "ClientID", + false), + // Host validation + new ValidationTestCase( + "Null host", + TEST_CLIENT_ID, + null, + validEndpoints, + validIdTokenSource, + validHttpClient, + "Host", + true), + new ValidationTestCase( + "Empty host", + TEST_CLIENT_ID, + "", + validEndpoints, + validIdTokenSource, + validHttpClient, + "Host", + false), + // Endpoints validation + new ValidationTestCase( + "Null endpoints", + TEST_CLIENT_ID, + TEST_HOST, + null, + validIdTokenSource, + validHttpClient, + "Endpoints", + true), + // IDTokenSource validation + new ValidationTestCase( + "Null IDTokenSource", + TEST_CLIENT_ID, + TEST_HOST, + validEndpoints, + null, + validHttpClient, + "IDTokenSource", + true), + // HttpClient validation + new ValidationTestCase( + "Null HttpClient", + TEST_CLIENT_ID, + TEST_HOST, + validEndpoints, + validIdTokenSource, + null, + "HttpClient", + true)); + } - // Test null IDTokenSource - final DatabricksOAuthTokenSource tokenSource4 = + /** + * Tests validation of required fields in the token source using parameterized test cases. + * Verifies that null or empty values for required fields cause getToken() to throw + * IllegalArgumentException with specific error messages. + */ + @ParameterizedTest(name = "testParameterValidation: {0}") + @MethodSource("provideValidationTestCases") + void testParameterValidation(ValidationTestCase testCase) { + DatabricksOAuthTokenSource tokenSource = new DatabricksOAuthTokenSource.Builder( - TEST_CLIENT_ID, TEST_HOST, validEndpoints, null, validHttpClient) + testCase.clientId, + testCase.host, + testCase.endpoints, + testCase.idTokenSource, + testCase.httpClient) .build(); - assertThrows(IllegalArgumentException.class, () -> tokenSource4.getToken()); - // Test null HttpClient - final DatabricksOAuthTokenSource tokenSource5 = - new DatabricksOAuthTokenSource.Builder( - TEST_CLIENT_ID, TEST_HOST, validEndpoints, mockIdTokenSource, null) - .build(); - assertThrows(IllegalArgumentException.class, () -> tokenSource5.getToken()); + IllegalArgumentException exception = + assertThrows(IllegalArgumentException.class, () -> tokenSource.getToken()); + + String expectedMessage = + String.format(testCase.isNullTest ? ERROR_NULL : ERROR_EMPTY, testCase.expectedFieldName); + assertEquals(expectedMessage, exception.getMessage()); } }