Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,27 +10,45 @@
import java.time.LocalDateTime;
import java.util.HashMap;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* Implementation of TokenSource that handles OAuth token exchange for Databricks authentication.
* This class manages the OAuth token exchange flow using ID tokens to obtain access tokens.
*/
public class DatabricksOAuthTokenSource implements TokenSource {
/** OAuth client ID used for token exchange */
private static final Logger LOG = LoggerFactory.getLogger(DatabricksOAuthTokenSource.class);

/** OAuth client ID used for token exchange. */
private final String clientId;
/** Databricks account ID, used as audience if provided */
/** Databricks host URL. */
private final String host;
/** Databricks account ID, used as audience if provided. */
private final String accountId;
/** OpenID Connect endpoints configuration */
/** OpenID Connect endpoints configuration. */
private final OpenIDConnectEndpoints endpoints;
/** Custom audience value for token exchange */
/** Custom audience value for token exchange. */
private final String audience;
/** Source of ID tokens used in token exchange */
/** Source of ID tokens used in token exchange. */
private final IDTokenSource idTokenSource;
/** HTTP client for making token exchange requests */
/** HTTP client for making token exchange requests. */
private final HttpClient httpClient;

private static final String GRANT_TYPE = "urn:ietf:params:oauth:grant-type:token-exchange";
private static final String SUBJECT_TOKEN_TYPE = "urn:ietf:params:oauth:token-type:jwt";
private static final String SCOPE = "all-apis";
private static final String GRANT_TYPE_PARAM = "grant_type";
private static final String SUBJECT_TOKEN_PARAM = "subject_token";
private static final String SUBJECT_TOKEN_TYPE_PARAM = "subject_token_type";
private static final String SCOPE_PARAM = "scope";
private static final String CLIENT_ID_PARAM = "client_id";

private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();

private DatabricksOAuthTokenSource(Builder builder) {
this.clientId = builder.clientId;
this.host = builder.host;
this.accountId = builder.accountId;
this.endpoints = builder.endpoints;
this.audience = builder.audience;
Expand All @@ -39,8 +57,8 @@ private DatabricksOAuthTokenSource(Builder builder) {
}

/**
* Builder class for constructing DatabricksOAuthTokenSource instances. Provides a fluent
* interface for setting required and optional parameters.
* Builder class for constructing DatabricksOAuthTokenSource instances. Provides a flexible way to
* set required and optional parameters.
*/
public static class Builder {
private final String clientId;
Expand All @@ -51,43 +69,21 @@ public static class Builder {
private String accountId;
private String audience;

/**
* Validates that a value is non-empty and non-null for required fields.
*
* @param value The value to validate
* @param fieldName The name of the field being validated
* @throws IllegalArgumentException if validation fails
*/
private static void validate(Object value, String fieldName) {
if (value == null) {
throw new IllegalArgumentException(fieldName + " must be non-null");
}
if (value instanceof String && ((String) value).isEmpty()) {
throw new IllegalArgumentException(fieldName + " must be non-empty");
}
}

/**
* Creates a new Builder with required parameters.
*
* @param clientId OAuth client ID
* @param host Databricks host URL
* @param endpoints OpenID Connect endpoints configuration
* @param idTokenSource Source of ID tokens
* @param httpClient HTTP client for making requests
* @param clientId OAuth client ID.
* @param host Databricks host URL.
* @param endpoints OpenID Connect endpoints configuration.
* @param idTokenSource Source of ID tokens.
* @param httpClient HTTP client for making requests.
*/
public Builder(
String clientId,
String host,
OpenIDConnectEndpoints endpoints,
IDTokenSource idTokenSource,
HttpClient httpClient) {
validate(clientId, "ClientID");
validate(host, "Host");
validate(endpoints, "Endpoints");
validate(idTokenSource, "IDTokenSource");
validate(httpClient, "HttpClient");

this.clientId = clientId;
this.host = host;
this.endpoints = endpoints;
Expand All @@ -98,11 +94,10 @@ public Builder(
/**
* Sets the Databricks account ID.
*
* @param accountId The account ID
* @return This builder instance
* @param accountId The account ID.
* @return This builder instance.
*/
public Builder accountId(String accountId) {
validate(accountId, "AccountID");
this.accountId = accountId;
return this;
}
Expand All @@ -114,44 +109,78 @@ public Builder accountId(String accountId) {
* @return This builder instance
*/
public Builder audience(String audience) {
validate(audience, "Audience");
this.audience = audience;
return this;
}

/**
* Builds a new DatabricksOAuthTokenSource instance.
*
* @return A new DatabricksOAuthTokenSource
* @return A new DatabricksOAuthTokenSource.
*/
public DatabricksOAuthTokenSource build() {
return new DatabricksOAuthTokenSource(this);
}
}

/**
* Validates that a value is non-null for required fields. If the value is a string, it also
* checks that it is non-empty.
*
* @param value The value to validate.
* @param fieldName The name of the field being validated.
* @throws IllegalArgumentException when the value is null or an empty string.
*/
private static void validate(Object value, String fieldName) {
if (value == null) {
LOG.error("Required parameter '{}' is null", fieldName);
throw new IllegalArgumentException(
String.format("Required parameter '%s' cannot be null", fieldName));
}
if (value instanceof String && ((String) value).isEmpty()) {
LOG.error("Required parameter '{}' is empty", fieldName);
throw new IllegalArgumentException(
String.format("Required parameter '%s' cannot be empty", fieldName));
}
}

/**
* Retrieves an OAuth token by exchanging an ID token. Implements the OAuth token exchange flow to
* obtain an access token.
*
* @return A Token containing the access token and related information
* @throws DatabricksException if token exchange fails
* @return A Token containing the access token and related information.
* @throws DatabricksException when the token exchange fails.
* @throws IllegalArgumentException when there is an error code in the response or when required
* parameters are missing.
*/
@Override
public Token getToken() {
// Validate all required parameters
validate(clientId, "ClientID");
validate(host, "Host");
validate(endpoints, "Endpoints");
validate(idTokenSource, "IDTokenSource");
validate(httpClient, "HttpClient");

String effectiveAudience = determineAudience();
IDToken idToken = idTokenSource.getIDToken(effectiveAudience);

Map<String, String> params = new HashMap<>();
params.put("grant_type", "urn:ietf:params:oauth:grant-type:token-exchange");
params.put("subject_token", idToken.getValue());
params.put("subject_token_type", "urn:ietf:params:oauth:token-type:jwt");
params.put("scope", "all-apis");
params.put("client_id", clientId);
params.put(GRANT_TYPE_PARAM, GRANT_TYPE);
params.put(SUBJECT_TOKEN_PARAM, idToken.getValue());
params.put(SUBJECT_TOKEN_TYPE_PARAM, SUBJECT_TOKEN_TYPE);
params.put(SCOPE_PARAM, SCOPE);
params.put(CLIENT_ID_PARAM, clientId);

Response rawResponse;
try {
rawResponse = httpClient.execute(new FormRequest(endpoints.getTokenEndpoint(), params));
} catch (IOException e) {
LOG.error(
"Failed to exchange ID token for access token at {}: {}",
endpoints.getTokenEndpoint(),
e.getMessage(),
e);
throw new DatabricksException(
String.format(
"Failed to exchange ID token for access token at %s: %s",
Expand All @@ -161,15 +190,24 @@ public Token getToken() {

OAuthResponse response;
try {
response = new ObjectMapper().readValue(rawResponse.getBody(), OAuthResponse.class);
response = OBJECT_MAPPER.readValue(rawResponse.getBody(), OAuthResponse.class);
} catch (IOException e) {
LOG.error(
"Failed to parse OAuth response from token endpoint {}: {}",
endpoints.getTokenEndpoint(),
e.getMessage(),
e);
throw new DatabricksException(
String.format(
"Failed to parse OAuth response from token endpoint %s: %s",
endpoints.getTokenEndpoint(), e.getMessage()));
}

if (response.getErrorCode() != null) {
LOG.error(
"Token exchange failed with error: {} - {}",
response.getErrorCode(),
response.getErrorSummary());
throw new IllegalArgumentException(
String.format(
"Token exchange failed with error: %s - %s",
Expand Down
Loading
Loading