From 09796a28feee31ff1bea2b336a4cc46c4063ce24 Mon Sep 17 00:00:00 2001 From: Hector Castejon Diaz Date: Tue, 25 Mar 2025 11:16:17 +0100 Subject: [PATCH] Implement WIF support --- NEXT_CHANGELOG.md | 4 + .../databricks/sdk/core/DatabricksConfig.java | 16 +++ .../sdk/core/DefaultCredentialsProvider.java | 6 +- ...reServicePrincipalCredentialsProvider.java | 2 +- .../sdk/core/oauth/ClientCredentials.java | 28 +++-- .../DatabricksWifCredentialsProvider.java | 62 ++++++++++ .../core/oauth/GitHubOidcTokenSupplier.java | 72 +++++++++++ .../sdk/integration/DatabricksWifIT.java | 112 ++++++++++++++++++ 8 files changed, 285 insertions(+), 17 deletions(-) create mode 100644 databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksWifCredentialsProvider.java create mode 100644 databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/GitHubOidcTokenSupplier.java create mode 100644 databricks-sdk-java/src/test/java/com/databricks/sdk/integration/DatabricksWifIT.java diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md index 3637e7375..1c5eae287 100644 --- a/NEXT_CHANGELOG.md +++ b/NEXT_CHANGELOG.md @@ -3,6 +3,10 @@ ## Release v0.43.0 ### New Features and Improvements +* Introduce support for Databricks Workload Identity Federation in GitHub workflows ([423](https://github.com/databricks/databricks-sdk-java/pull/423)). + See README.md for instructions. +* [Breaking] Users running their workflows in GitHub Actions, which use Cloud native authentication and also have a `DATABRICKS_CLIENT_ID` and `DATABRICKS_HOST` + environment variables set may see their authentication start failing due to the order in which the SDK tries different authentication methods. ### Bug Fixes diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksConfig.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksConfig.java index 2943ce82e..cfe6a104f 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksConfig.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksConfig.java @@ -141,6 +141,13 @@ public class DatabricksConfig { private DatabricksEnvironment databricksEnvironment; + /** + * When using Workload Identity Federation, the audience to specify when fetching an ID token from + * the ID token supplier. + */ + @ConfigAttribute(env = "TOKEN_AUDIENCE") + private String tokenAudience; + public Environment getEnv() { return env; } @@ -512,6 +519,15 @@ public DatabricksConfig setHttpClient(HttpClient httpClient) { return this; } + public String getTokenAudience() { + return tokenAudience; + } + + public DatabricksConfig setTokenAudience(String tokenAudience) { + this.tokenAudience = tokenAudience; + return this; + } + public boolean isAzure() { if (azureWorkspaceResourceId != null) { return true; diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DefaultCredentialsProvider.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DefaultCredentialsProvider.java index f16cded39..58ad5c21e 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DefaultCredentialsProvider.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DefaultCredentialsProvider.java @@ -1,9 +1,6 @@ package com.databricks.sdk.core; -import com.databricks.sdk.core.oauth.AzureGithubOidcCredentialsProvider; -import com.databricks.sdk.core.oauth.AzureServicePrincipalCredentialsProvider; -import com.databricks.sdk.core.oauth.ExternalBrowserCredentialsProvider; -import com.databricks.sdk.core.oauth.OAuthM2MServicePrincipalCredentialsProvider; +import com.databricks.sdk.core.oauth.*; import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ -18,6 +15,7 @@ public class DefaultCredentialsProvider implements CredentialsProvider { PatCredentialsProvider.class, BasicCredentialsProvider.class, OAuthM2MServicePrincipalCredentialsProvider.class, + DatabricksWifCredentialsProvider.class, AzureGithubOidcCredentialsProvider.class, AzureServicePrincipalCredentialsProvider.class, AzureCliCredentialsProvider.class, diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProvider.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProvider.java index ca85c2031..432046777 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProvider.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProvider.java @@ -63,7 +63,7 @@ private static RefreshableTokenSource tokenSourceFor(DatabricksConfig config, St .withClientId(config.getAzureClientId()) .withClientSecret(config.getAzureClientSecret()) .withTokenUrl(tokenUrl) - .withEndpointParameters(endpointParams) + .withEndpointParametersSupplier(() -> endpointParams) .withAuthParameterPosition(AuthParameterPosition.BODY) .build(); } diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/ClientCredentials.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/ClientCredentials.java index 7709a7b10..327d8fe18 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/ClientCredentials.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/ClientCredentials.java @@ -3,6 +3,7 @@ import com.databricks.sdk.core.commons.CommonsHttpClient; import com.databricks.sdk.core.http.HttpClient; import java.util.*; +import java.util.function.Supplier; /** * An implementation of RefreshableTokenSource implementing the client_credentials OAuth grant type. @@ -18,7 +19,10 @@ public static class Builder { private String clientSecret; private String tokenUrl; private HttpClient hc = new CommonsHttpClient.Builder().withTimeoutSeconds(30).build(); - private Map endpointParams = Collections.emptyMap(); + // Endpoint parameters can include tokens with expiration which + // may need to be refreshed. This supplier will be called each time + // that the credentials are refreshed. + private Supplier> endpointParamsSupplier = null; private List scopes = Collections.emptyList(); private AuthParameterPosition position = AuthParameterPosition.BODY; @@ -32,13 +36,14 @@ public Builder withClientSecret(String clientSecret) { return this; } - public Builder withTokenUrl(String tokenUrl) { - this.tokenUrl = tokenUrl; + public Builder withEndpointParametersSupplier( + Supplier> endpointParamsSupplier) { + this.endpointParamsSupplier = endpointParamsSupplier; return this; } - public Builder withEndpointParameters(Map params) { - this.endpointParams = params; + public Builder withTokenUrl(String tokenUrl) { + this.tokenUrl = tokenUrl; return this; } @@ -59,10 +64,9 @@ public Builder withHttpClient(HttpClient hc) { public ClientCredentials build() { Objects.requireNonNull(this.clientId, "clientId must be specified"); - Objects.requireNonNull(this.clientSecret, "clientSecret must be specified"); Objects.requireNonNull(this.tokenUrl, "tokenUrl must be specified"); return new ClientCredentials( - hc, clientId, clientSecret, tokenUrl, endpointParams, scopes, position); + hc, clientId, clientSecret, tokenUrl, endpointParamsSupplier, scopes, position); } } @@ -70,23 +74,23 @@ public ClientCredentials build() { private String clientId; private String clientSecret; private String tokenUrl; - private Map endpointParams; private List scopes; private AuthParameterPosition position; + private Supplier> endpointParamsSupplier; private ClientCredentials( HttpClient hc, String clientId, String clientSecret, String tokenUrl, - Map endpointParams, + Supplier> endpointParamsSupplier, List scopes, AuthParameterPosition position) { this.hc = hc; this.clientId = clientId; this.clientSecret = clientSecret; this.tokenUrl = tokenUrl; - this.endpointParams = endpointParams; + this.endpointParamsSupplier = endpointParamsSupplier; this.scopes = scopes; this.position = position; } @@ -98,8 +102,8 @@ protected Token refresh() { if (scopes != null) { params.put("scope", String.join(" ", scopes)); } - if (endpointParams != null) { - params.putAll(endpointParams); + if (endpointParamsSupplier != null) { + params.putAll(endpointParamsSupplier.get()); } return retrieveToken(hc, clientId, clientSecret, tokenUrl, params, new HashMap<>(), position); } diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksWifCredentialsProvider.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksWifCredentialsProvider.java new file mode 100644 index 000000000..f5381cde3 --- /dev/null +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksWifCredentialsProvider.java @@ -0,0 +1,62 @@ +package com.databricks.sdk.core.oauth; + +import com.databricks.sdk.core.CredentialsProvider; +import com.databricks.sdk.core.DatabricksConfig; +import com.databricks.sdk.core.DatabricksException; +import com.databricks.sdk.core.HeaderFactory; +import com.google.common.collect.ImmutableMap; +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +/** + * DatabricksWIFCredentials uses a Token Supplier to get a JWT Token and exchanges it for a + * Databricks Token. Supported suppliers: - GitHub OIDC + */ +public class DatabricksWifCredentialsProvider implements CredentialsProvider { + + @Override + public String authType() { + return "databricks-wif"; + } + + @Override + public HeaderFactory configure(DatabricksConfig config) throws DatabricksException { + GitHubOidcTokenSupplier idTokenProvider = new GitHubOidcTokenSupplier(config); + + if (!idTokenProvider.enabled() || config.getHost() == null || config.getClientId() == null) { + return null; + } + + String endpointUrl; + + try { + endpointUrl = config.getOidcEndpoints().getTokenEndpoint(); + } catch (IOException e) { + throw new DatabricksException("Unable to fetch OIDC endpoint: " + e.getMessage(), e); + } + + ClientCredentials clientCredentials = + new ClientCredentials.Builder() + .withHttpClient(config.getHttpClient()) + .withClientId(config.getClientId()) + .withTokenUrl(endpointUrl) + .withScopes(Collections.singletonList("all-apis")) + .withAuthParameterPosition(AuthParameterPosition.HEADER) + .withEndpointParametersSupplier( + () -> + new ImmutableMap.Builder() + .put("subject_token_type", "urn:ietf:params:oauth:token-type:jwt") + .put("subject_token", idTokenProvider.getOidcToken()) + .put("grant_type", "urn:ietf:params:oauth:grant-type:token-exchange") + .build()) + .build(); + + return () -> { + Map headers = new HashMap<>(); + headers.put("Authorization", "Bearer " + clientCredentials.getToken().getAccessToken()); + return headers; + }; + } +} diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/GitHubOidcTokenSupplier.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/GitHubOidcTokenSupplier.java new file mode 100644 index 000000000..b463bfaee --- /dev/null +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/GitHubOidcTokenSupplier.java @@ -0,0 +1,72 @@ +package com.databricks.sdk.core.oauth; + +import com.databricks.sdk.core.DatabricksConfig; +import com.databricks.sdk.core.DatabricksException; +import com.databricks.sdk.core.http.Request; +import com.databricks.sdk.core.http.Response; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ObjectNode; +import java.io.IOException; + +public class GitHubOidcTokenSupplier { + + private final ObjectMapper mapper = new ObjectMapper(); + + private final DatabricksConfig config; + + public GitHubOidcTokenSupplier(DatabricksConfig config) { + this.config = config; + } + + /** Checks if the required parameters are present to request a GitHub's OIDC token. */ + public Boolean enabled() { + return config.getActionsIdTokenRequestUrl() != null + && config.getActionsIdTokenRequestToken() != null; + } + + /** + * Requests a GitHub's OIDC token. + * + * @return A GitHub OIDC token. + */ + public String getOidcToken() { + if (!enabled()) { + throw new DatabricksException("Failed to request ID token: missing required parameters"); + } + + String requestUrl = config.getActionsIdTokenRequestUrl(); + if (config.getTokenAudience() != null) { + requestUrl += "&audience=" + config.getTokenAudience(); + } + + Request req = + new Request("GET", requestUrl) + .withHeader("Authorization", "Bearer " + config.getActionsIdTokenRequestToken()); + + Response resp; + try { + resp = config.getHttpClient().execute(req); + } catch (IOException e) { + throw new DatabricksException( + "Failed to request ID token from " + requestUrl + ":" + e.getMessage(), e); + } + + if (resp.getStatusCode() != 200) { + throw new DatabricksException( + "Failed to request ID token: status code " + + resp.getStatusCode() + + ", response body: " + + resp.getBody().toString()); + } + + ObjectNode jsonResp; + try { + jsonResp = mapper.readValue(resp.getBody(), ObjectNode.class); + } catch (IOException e) { + throw new DatabricksException( + "Failed to request ID token: corrupted token: " + e.getMessage()); + } + + return jsonResp.get("value").textValue(); + } +} diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/integration/DatabricksWifIT.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/integration/DatabricksWifIT.java new file mode 100644 index 000000000..05bbf185f --- /dev/null +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/integration/DatabricksWifIT.java @@ -0,0 +1,112 @@ +package com.databricks.sdk.integration; + +import com.databricks.sdk.AccountClient; +import com.databricks.sdk.WorkspaceClient; +import com.databricks.sdk.core.DatabricksConfig; +import com.databricks.sdk.integration.framework.EnvContext; +import com.databricks.sdk.integration.framework.EnvOrSkip; +import com.databricks.sdk.integration.framework.EnvTest; +import com.databricks.sdk.service.iam.*; +import com.databricks.sdk.service.oauth2.CreateServicePrincipalFederationPolicyRequest; +import com.databricks.sdk.service.oauth2.FederationPolicy; +import com.databricks.sdk.service.oauth2.OidcFederationPolicy; +import java.util.Collections; +import java.util.UUID; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; + +@EnvContext("ucacct") +@ExtendWith(EnvTest.class) +public class DatabricksWifIT { + // This test cannot run on local machines. We use ACTIONS_ID_TOKEN_REQUEST_URL + // to determine whether we are running in the GitHub Actions, + // and we skip the test if we are not. + @Test + void workspace( + AccountClient a, + @EnvOrSkip("TEST_WORKSPACE_ID") String workspaceId, + @EnvOrSkip("TEST_WORKSPACE_URL") String workspaceUrl, + @EnvOrSkip("ACTIONS_ID_TOKEN_REQUEST_URL") String userId) { + String spName = "java-sdk-sp" + UUID.randomUUID(); + + // Create SP with access to the workspace + ServicePrincipal sp = + a.servicePrincipals().create(new ServicePrincipal().setActive(true).setDisplayName(spName)); + + a.workspaceAssignment() + .update( + new UpdateWorkspaceAssignments() + .setWorkspaceId(Long.valueOf(workspaceId)) + .setPrincipalId(Long.valueOf(sp.getId())) + .setPermissions(Collections.singleton(WorkspacePermission.ADMIN))); + + // Setup Federation Policy + OidcFederationPolicy policy = + new OidcFederationPolicy() + .setIssuer("https://token.actions.githubusercontent.com") + .setSubject("repo:databricks-eng/eng-dev-ecosystem:environment:integration-tests") + .setAudiences(Collections.singleton("https://github.com/databricks-eng")); + + a.servicePrincipalFederationPolicy() + .create( + new CreateServicePrincipalFederationPolicyRequest() + .setServicePrincipalId(Long.valueOf(sp.getId())) + .setPolicy(new FederationPolicy().setOidcPolicy(policy))); + + // Test WIF login + DatabricksConfig config = + new DatabricksConfig() + .setHost(workspaceUrl) + .setClientId(sp.getApplicationId()) + .setAuthType("databricks-wif") + .setTokenAudience("https://github.com/databricks-eng"); + + WorkspaceClient ws = new WorkspaceClient(config); + + ws.currentUser().me(); + } + + // This test cannot run on local machines. We use ACTIONS_ID_TOKEN_REQUEST_URL + // to determine whether we are running in the GitHub Actions, + // and we skip the test if we are not. + @Test + void account(AccountClient a, @EnvOrSkip("ACTIONS_ID_TOKEN_REQUEST_URL") String userId) { + String spName = "java-sdk-sp" + UUID.randomUUID(); + + // Create SP + ServicePrincipal sp = + a.servicePrincipals() + .create( + new ServicePrincipal() + .setActive(true) + .setDisplayName(spName) + .setRoles(Collections.singleton(new ComplexValue().setValue("account_admin")))); + + // Setup Federation Policy + OidcFederationPolicy policy = + new OidcFederationPolicy() + .setIssuer("https://token.actions.githubusercontent.com") + .setSubject("repo:databricks-eng/eng-dev-ecosystem:environment:integration-tests") + .setAudiences(Collections.singleton("https://github.com/databricks-eng")); + + a.servicePrincipalFederationPolicy() + .create( + new CreateServicePrincipalFederationPolicyRequest() + .setServicePrincipalId(Long.valueOf(sp.getId())) + .setPolicy(new FederationPolicy().setOidcPolicy(policy))); + + // Test WIF login + DatabricksConfig config = + new DatabricksConfig() + .setHost(a.config().getHost()) + .setAccountId(a.config().getAccountId()) + .setClientId(sp.getApplicationId()) + .setAuthType("databricks-wif") + .setTokenAudience("https://github.com/databricks-eng"); + + AccountClient ac = new AccountClient(config); + + Iterable groups = ac.groups().list(new ListAccountGroupsRequest()); + groups.iterator().next(); + } +}