Skip to content

Commit efe5aa4

Browse files
Add DatabricksOAuthTokenSource (#439)
## What changes are proposed in this pull request? - Implemented DatabricksOAuthTokenSource class that handles the OAuth token exchange flow - This class manages the OAuth token exchange flow using ID tokens to obtain access tokens - The token exchange mechanism is essential for implementing OIDC-based authentication in Databricks ## How is this tested? The implementation is tested through unit tests that: - Mock the ID Token Source to simulate the token exchange flow - Mock the HTTP client to verify token exchange requests and responses - Test the complete token exchange flow including audience selection and error handling - Validate proper parameter handling and error responses - All tests are automated and part of the unit test suite. No manual testing was required. NO_CHANGELOG=true --------- Co-authored-by: Parth Bansal <parth.bansal@databricks.com>
1 parent 91fe8a5 commit efe5aa4

File tree

2 files changed

+628
-0
lines changed

2 files changed

+628
-0
lines changed
Lines changed: 238 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,238 @@
1+
package com.databricks.sdk.core.oauth;
2+
3+
import com.databricks.sdk.core.DatabricksException;
4+
import com.databricks.sdk.core.http.FormRequest;
5+
import com.databricks.sdk.core.http.HttpClient;
6+
import com.databricks.sdk.core.http.Response;
7+
import com.fasterxml.jackson.databind.ObjectMapper;
8+
import com.google.common.base.Strings;
9+
import java.io.IOException;
10+
import java.time.LocalDateTime;
11+
import java.util.HashMap;
12+
import java.util.Map;
13+
import org.slf4j.Logger;
14+
import org.slf4j.LoggerFactory;
15+
16+
/**
17+
* Implementation of TokenSource that handles OAuth token exchange for Databricks authentication.
18+
* This class manages the OAuth token exchange flow using ID tokens to obtain access tokens.
19+
*/
20+
public class DatabricksOAuthTokenSource implements TokenSource {
21+
private static final Logger LOG = LoggerFactory.getLogger(DatabricksOAuthTokenSource.class);
22+
23+
/** OAuth client ID used for token exchange. */
24+
private final String clientId;
25+
/** Databricks host URL. */
26+
private final String host;
27+
/** Databricks account ID, used as audience if provided. */
28+
private final String accountId;
29+
/** OpenID Connect endpoints configuration. */
30+
private final OpenIDConnectEndpoints endpoints;
31+
/** Custom audience value for token exchange. */
32+
private final String audience;
33+
/** Source of ID tokens used in token exchange. */
34+
private final IDTokenSource idTokenSource;
35+
/** HTTP client for making token exchange requests. */
36+
private final HttpClient httpClient;
37+
38+
private static final String GRANT_TYPE = "urn:ietf:params:oauth:grant-type:token-exchange";
39+
private static final String SUBJECT_TOKEN_TYPE = "urn:ietf:params:oauth:token-type:jwt";
40+
private static final String SCOPE = "all-apis";
41+
private static final String GRANT_TYPE_PARAM = "grant_type";
42+
private static final String SUBJECT_TOKEN_PARAM = "subject_token";
43+
private static final String SUBJECT_TOKEN_TYPE_PARAM = "subject_token_type";
44+
private static final String SCOPE_PARAM = "scope";
45+
private static final String CLIENT_ID_PARAM = "client_id";
46+
47+
private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
48+
49+
private DatabricksOAuthTokenSource(Builder builder) {
50+
this.clientId = builder.clientId;
51+
this.host = builder.host;
52+
this.accountId = builder.accountId;
53+
this.endpoints = builder.endpoints;
54+
this.audience = builder.audience;
55+
this.idTokenSource = builder.idTokenSource;
56+
this.httpClient = builder.httpClient;
57+
}
58+
59+
/**
60+
* Builder class for constructing DatabricksOAuthTokenSource instances. Provides a flexible way to
61+
* set required and optional parameters.
62+
*/
63+
public static class Builder {
64+
private final String clientId;
65+
private final String host;
66+
private final OpenIDConnectEndpoints endpoints;
67+
private final IDTokenSource idTokenSource;
68+
private final HttpClient httpClient;
69+
private String accountId;
70+
private String audience;
71+
72+
/**
73+
* Creates a new Builder with required parameters.
74+
*
75+
* @param clientId OAuth client ID.
76+
* @param host Databricks host URL.
77+
* @param endpoints OpenID Connect endpoints configuration.
78+
* @param idTokenSource Source of ID tokens.
79+
* @param httpClient HTTP client for making requests.
80+
*/
81+
public Builder(
82+
String clientId,
83+
String host,
84+
OpenIDConnectEndpoints endpoints,
85+
IDTokenSource idTokenSource,
86+
HttpClient httpClient) {
87+
this.clientId = clientId;
88+
this.host = host;
89+
this.endpoints = endpoints;
90+
this.idTokenSource = idTokenSource;
91+
this.httpClient = httpClient;
92+
}
93+
94+
/**
95+
* Sets the Databricks account ID.
96+
*
97+
* @param accountId The account ID.
98+
* @return This builder instance.
99+
*/
100+
public Builder accountId(String accountId) {
101+
this.accountId = accountId;
102+
return this;
103+
}
104+
105+
/**
106+
* Sets a custom audience value for token exchange.
107+
*
108+
* @param audience The audience value
109+
* @return This builder instance
110+
*/
111+
public Builder audience(String audience) {
112+
this.audience = audience;
113+
return this;
114+
}
115+
116+
/**
117+
* Builds a new DatabricksOAuthTokenSource instance.
118+
*
119+
* @return A new DatabricksOAuthTokenSource.
120+
*/
121+
public DatabricksOAuthTokenSource build() {
122+
return new DatabricksOAuthTokenSource(this);
123+
}
124+
}
125+
126+
/**
127+
* Validates that a value is non-null for required fields. If the value is a string, it also
128+
* checks that it is non-empty.
129+
*
130+
* @param value The value to validate.
131+
* @param fieldName The name of the field being validated.
132+
* @throws IllegalArgumentException when the value is null or an empty string.
133+
*/
134+
private static void validate(Object value, String fieldName) {
135+
if (value == null) {
136+
LOG.error("Required parameter '{}' is null", fieldName);
137+
throw new IllegalArgumentException(
138+
String.format("Required parameter '%s' cannot be null", fieldName));
139+
}
140+
if (value instanceof String && ((String) value).isEmpty()) {
141+
LOG.error("Required parameter '{}' is empty", fieldName);
142+
throw new IllegalArgumentException(
143+
String.format("Required parameter '%s' cannot be empty", fieldName));
144+
}
145+
}
146+
147+
/**
148+
* Retrieves an OAuth token by exchanging an ID token. Implements the OAuth token exchange flow to
149+
* obtain an access token.
150+
*
151+
* @return A Token containing the access token and related information.
152+
* @throws DatabricksException when the token exchange fails.
153+
* @throws IllegalArgumentException when there is an error code in the response or when required
154+
* parameters are missing.
155+
*/
156+
@Override
157+
public Token getToken() {
158+
// Validate all required parameters
159+
validate(clientId, "ClientID");
160+
validate(host, "Host");
161+
validate(endpoints, "Endpoints");
162+
validate(idTokenSource, "IDTokenSource");
163+
validate(httpClient, "HttpClient");
164+
165+
String effectiveAudience = determineAudience();
166+
IDToken idToken = idTokenSource.getIDToken(effectiveAudience);
167+
168+
Map<String, String> params = new HashMap<>();
169+
params.put(GRANT_TYPE_PARAM, GRANT_TYPE);
170+
params.put(SUBJECT_TOKEN_PARAM, idToken.getValue());
171+
params.put(SUBJECT_TOKEN_TYPE_PARAM, SUBJECT_TOKEN_TYPE);
172+
params.put(SCOPE_PARAM, SCOPE);
173+
params.put(CLIENT_ID_PARAM, clientId);
174+
175+
Response rawResponse;
176+
try {
177+
rawResponse = httpClient.execute(new FormRequest(endpoints.getTokenEndpoint(), params));
178+
} catch (IOException e) {
179+
LOG.error(
180+
"Failed to exchange ID token for access token at {}: {}",
181+
endpoints.getTokenEndpoint(),
182+
e.getMessage(),
183+
e);
184+
throw new DatabricksException(
185+
String.format(
186+
"Failed to exchange ID token for access token at %s: %s",
187+
endpoints.getTokenEndpoint(), e.getMessage()),
188+
e);
189+
}
190+
191+
OAuthResponse response;
192+
try {
193+
response = OBJECT_MAPPER.readValue(rawResponse.getBody(), OAuthResponse.class);
194+
} catch (IOException e) {
195+
LOG.error(
196+
"Failed to parse OAuth response from token endpoint {}: {}",
197+
endpoints.getTokenEndpoint(),
198+
e.getMessage(),
199+
e);
200+
throw new DatabricksException(
201+
String.format(
202+
"Failed to parse OAuth response from token endpoint %s: %s",
203+
endpoints.getTokenEndpoint(), e.getMessage()));
204+
}
205+
206+
if (response.getErrorCode() != null) {
207+
LOG.error(
208+
"Token exchange failed with error: {} - {}",
209+
response.getErrorCode(),
210+
response.getErrorSummary());
211+
throw new IllegalArgumentException(
212+
String.format(
213+
"Token exchange failed with error: %s - %s",
214+
response.getErrorCode(), response.getErrorSummary()));
215+
}
216+
LocalDateTime expiry = LocalDateTime.now().plusSeconds(response.getExpiresIn());
217+
return new Token(
218+
response.getAccessToken(), response.getTokenType(), response.getRefreshToken(), expiry);
219+
}
220+
221+
/**
222+
* Determines the appropriate audience value for token exchange. Uses the following precedence: 1.
223+
* Custom audience if provided 2. Account ID if provided 3. Token endpoint URL as fallback
224+
*
225+
* @return The determined audience value
226+
*/
227+
private String determineAudience() {
228+
if (!Strings.isNullOrEmpty(audience)) {
229+
return audience;
230+
}
231+
232+
if (!Strings.isNullOrEmpty(accountId)) {
233+
return accountId;
234+
}
235+
236+
return endpoints.getTokenEndpoint();
237+
}
238+
}

0 commit comments

Comments
 (0)