diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md
index 395a1a32c..f51d9f9b5 100644
--- a/NEXT_CHANGELOG.md
+++ b/NEXT_CHANGELOG.md
@@ -3,6 +3,10 @@
 ## Release v0.48.0
 
 ### New Features and Improvements
+ * Introduce support for Databricks Workload Identity Federation in GitHub workflows ([423](https://github.com/databricks/databricks-sdk-java/pull/423)).
+   See README.md for instructions.
+ * [Breaking] Users running their workflows in GitHub Actions that use cloud-native authentication and also have the `DATABRICKS_CLIENT_ID` and `DATABRICKS_HOST`
+   environment variables set may see their authentication start failing due to the order in which the SDK tries different authentication methods.
 
 ### Bug Fixes
 
diff --git a/README.md b/README.md
index 904a4d545..8a8d2858a 100644
--- a/README.md
+++ b/README.md
@@ -116,18 +116,18 @@ Depending on the Databricks authentication method, the SDK uses the following in
 
 ### Databricks native authentication
 
-By default, the Databricks SDK for Java initially tries [Databricks token authentication](https://docs.databricks.com/dev-tools/api/latest/authentication.html) (`auth_type='pat'` argument). If the SDK is unsuccessful, it then tries Databricks basic (username/password) authentication (`auth_type="basic"` argument).
+By default, the Databricks SDK for Java initially tries [Databricks token authentication](https://docs.databricks.com/dev-tools/api/latest/authentication.html) (`auth_type='pat'` argument). If the SDK is unsuccessful, it then tries Databricks Workload Identity Federation (WIF) authentication using OIDC (`auth_type="github-oidc"` argument).
 
 - For Databricks token authentication, you must provide `host` and `token`; or their environment variable or `.databrickscfg` file field equivalents.
-- For Databricks basic authentication, you must provide `host`, `username`, and `password` _(for AWS workspace-level operations)_; or `host`, `account_id`, `username`, and `password` _(for AWS, Azure, or GCP account-level operations)_; or their environment variable or `.databrickscfg` file field equivalents.
+- For Databricks OIDC authentication, you must provide `host` and `client_id`, and optionally `token_audience`, either directly, through the corresponding environment variables, or in your `.databrickscfg` configuration file.
 
 | Argument | Description | Environment variable |
 |--------------|-------------|-------------------|
 | `host` | _(String)_ The Databricks host URL for either the Databricks workspace endpoint or the Databricks accounts endpoint. | `DATABRICKS_HOST` |
 | `account_id` | _(String)_ The Databricks account ID for the Databricks accounts endpoint. Only has effect when `Host` is either `https://accounts.cloud.databricks.com/` _(AWS)_, `https://accounts.azuredatabricks.net/` _(Azure)_, or `https://accounts.gcp.databricks.com/` _(GCP)_. | `DATABRICKS_ACCOUNT_ID` |
 | `token` | _(String)_ The Databricks personal access token (PAT) _(AWS, Azure, and GCP)_ or Azure Active Directory (Azure AD) token _(Azure)_. | `DATABRICKS_TOKEN` |
-| `username` | _(String)_ The Databricks username part of basic authentication. Only possible when `Host` is `*.cloud.databricks.com` _(AWS)_. | `DATABRICKS_USERNAME` |
-| `password` | _(String)_ The Databricks password part of basic authentication. Only possible when `Host` is `*.cloud.databricks.com` _(AWS)_. | `DATABRICKS_PASSWORD` |
+| `client_id` | _(String)_ The Databricks service principal's application ID. | `DATABRICKS_CLIENT_ID` |
+| `token_audience` | _(String)_ When using Workload Identity Federation, the audience to specify when fetching an ID token from the ID token supplier. | `TOKEN_AUDIENCE` |
 
 For example, to use Databricks token authentication:
 
diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksConfig.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksConfig.java
index fa89f5041..fcb79c87b 100644
--- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksConfig.java
+++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksConfig.java
@@ -141,6 +141,13 @@ public class DatabricksConfig {
 
   private DatabricksEnvironment databricksEnvironment;
 
+  /**
+   * When using Workload Identity Federation, the audience to specify when fetching an ID token
+   * from the ID token supplier.
+   */
+  @ConfigAttribute(env = "TOKEN_AUDIENCE")
+  private String tokenAudience;
+
   public Environment getEnv() {
     return env;
   }
@@ -512,6 +519,15 @@ public DatabricksConfig setHttpClient(HttpClient httpClient) {
     return this;
   }
 
+  public String getTokenAudience() {
+    return tokenAudience;
+  }
+
+  public DatabricksConfig setTokenAudience(String tokenAudience) {
+    this.tokenAudience = tokenAudience;
+    return this;
+  }
+
   public boolean isAzure() {
     if (azureWorkspaceResourceId != null) {
       return true;
diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DefaultCredentialsProvider.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DefaultCredentialsProvider.java
index f16cded39..b8f4d7867 100644
--- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DefaultCredentialsProvider.java
+++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DefaultCredentialsProvider.java
@@ -1,9 +1,6 @@
 package com.databricks.sdk.core;
 
-import com.databricks.sdk.core.oauth.AzureGithubOidcCredentialsProvider;
-import com.databricks.sdk.core.oauth.AzureServicePrincipalCredentialsProvider;
-import com.databricks.sdk.core.oauth.ExternalBrowserCredentialsProvider;
-import com.databricks.sdk.core.oauth.OAuthM2MServicePrincipalCredentialsProvider;
+import com.databricks.sdk.core.oauth.*;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
@@ -18,6 +15,7 @@ public class DefaultCredentialsProvider implements CredentialsProvider {
           PatCredentialsProvider.class,
           BasicCredentialsProvider.class,
           OAuthM2MServicePrincipalCredentialsProvider.class,
+          GithubOidcCredentialsProvider.class,
           AzureGithubOidcCredentialsProvider.class,
           AzureServicePrincipalCredentialsProvider.class,
           AzureCliCredentialsProvider.class,
diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProvider.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProvider.java
index ca85c2031..432046777 100644
--- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProvider.java
+++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProvider.java
@@ -63,7 +63,7 @@ private static RefreshableTokenSource tokenSourceFor(DatabricksConfig config, St
         .withClientId(config.getAzureClientId())
         .withClientSecret(config.getAzureClientSecret())
         .withTokenUrl(tokenUrl)
-        .withEndpointParameters(endpointParams)
+        .withEndpointParametersSupplier(() -> endpointParams)
         .withAuthParameterPosition(AuthParameterPosition.BODY)
         .build();
   }
diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/ClientCredentials.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/ClientCredentials.java
index 7709a7b10..1c4b7d6de 100644
--- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/ClientCredentials.java
+++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/ClientCredentials.java
@@ -3,6 +3,7 @@
 import com.databricks.sdk.core.commons.CommonsHttpClient;
 import com.databricks.sdk.core.http.HttpClient;
 import java.util.*;
+import java.util.function.Supplier;
 
 /**
  * An implementation of RefreshableTokenSource implementing the client_credentials OAuth grant type.
@@ -18,7 +19,11 @@ public static class Builder {
     private String clientSecret;
     private String tokenUrl;
     private HttpClient hc = new CommonsHttpClient.Builder().withTimeoutSeconds(30).build();
-    private Map<String, String> endpointParams = Collections.emptyMap();
+
+    // Endpoint parameters can include tokens with expiration which
+    // may need to be refreshed. This supplier will be called each time
+    // the credentials are refreshed.
+    private Supplier<Map<String, String>> endpointParamsSupplier = null;
     private List<String> scopes = Collections.emptyList();
     private AuthParameterPosition position = AuthParameterPosition.BODY;
 
@@ -32,13 +37,14 @@ public Builder withClientSecret(String clientSecret) {
       return this;
     }
 
-    public Builder withTokenUrl(String tokenUrl) {
-      this.tokenUrl = tokenUrl;
+    public Builder withEndpointParametersSupplier(
+        Supplier<Map<String, String>> endpointParamsSupplier) {
+      this.endpointParamsSupplier = endpointParamsSupplier;
       return this;
     }
 
-    public Builder withEndpointParameters(Map<String, String> params) {
-      this.endpointParams = params;
+    public Builder withTokenUrl(String tokenUrl) {
+      this.tokenUrl = tokenUrl;
       return this;
     }
 
@@ -59,10 +65,9 @@ public Builder withHttpClient(HttpClient hc) {
 
     public ClientCredentials build() {
       Objects.requireNonNull(this.clientId, "clientId must be specified");
-      Objects.requireNonNull(this.clientSecret, "clientSecret must be specified");
       Objects.requireNonNull(this.tokenUrl, "tokenUrl must be specified");
       return new ClientCredentials(
-          hc, clientId, clientSecret, tokenUrl, endpointParams, scopes, position);
+          hc, clientId, clientSecret, tokenUrl, endpointParamsSupplier, scopes, position);
     }
   }
 
@@ -70,23 +75,23 @@ public ClientCredentials build() {
   private String clientId;
   private String clientSecret;
   private String tokenUrl;
-  private Map<String, String> endpointParams;
   private List<String> scopes;
   private AuthParameterPosition position;
+  private Supplier<Map<String, String>> endpointParamsSupplier;
 
   private ClientCredentials(
       HttpClient hc,
       String clientId,
       String clientSecret,
       String tokenUrl,
-      Map<String, String> endpointParams,
+      Supplier<Map<String, String>> endpointParamsSupplier,
       List<String> scopes,
       AuthParameterPosition position) {
     this.hc = hc;
     this.clientId = clientId;
     this.clientSecret = clientSecret;
     this.tokenUrl = tokenUrl;
-    this.endpointParams = endpointParams;
+    this.endpointParamsSupplier = endpointParamsSupplier;
     this.scopes = scopes;
     this.position = position;
   }
@@ -98,8 +103,8 @@ protected Token refresh() {
     if (scopes != null) {
       params.put("scope", String.join(" ", scopes));
     }
-    if (endpointParams != null) {
-      params.putAll(endpointParams);
+    if (endpointParamsSupplier != null) {
+      params.putAll(endpointParamsSupplier.get());
     }
     return retrieveToken(hc, clientId, clientSecret, tokenUrl, params, new HashMap<>(), position);
   }
diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/GitHubOidcTokenSupplier.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/GitHubOidcTokenSupplier.java
new file mode 100644
index 000000000..523c0df1f
--- /dev/null
+++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/GitHubOidcTokenSupplier.java
@@ -0,0 +1,79 @@
+package com.databricks.sdk.core.oauth;
+
+import com.databricks.sdk.core.DatabricksException;
+import com.databricks.sdk.core.http.HttpClient;
+import com.databricks.sdk.core.http.Request;
+import com.databricks.sdk.core.http.Response;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.fasterxml.jackson.databind.node.ObjectNode;
+import java.io.IOException;
+
+public class GitHubOidcTokenSupplier {
+
+  private final ObjectMapper mapper = new ObjectMapper();
+  private final HttpClient httpClient;
+  private final String idTokenRequestUrl;
+  private final String idTokenRequestToken;
+  private final String tokenAudience;
+
+  public GitHubOidcTokenSupplier(
+      HttpClient httpClient,
+      String idTokenRequestUrl,
+      String idTokenRequestToken,
+      String tokenAudience) {
+    this.httpClient = httpClient;
+    this.idTokenRequestUrl = idTokenRequestUrl;
+    this.idTokenRequestToken = idTokenRequestToken;
+    this.tokenAudience = tokenAudience;
+  }
+
+  /** Checks if the required parameters are present to request a GitHub OIDC token. */
+  public boolean enabled() {
+    return idTokenRequestUrl != null && idTokenRequestToken != null;
+  }
+
+  /**
+   * Requests a GitHub OIDC token.
+   *
+   * @return A GitHub OIDC token.
+   */
+  public String getOidcToken() {
+    if (!enabled()) {
+      throw new DatabricksException("Failed to request ID token: missing required parameters");
+    }
+
+    String requestUrl = idTokenRequestUrl;
+    if (tokenAudience != null) {
+      requestUrl += "&audience=" + tokenAudience;
+    }
+
+    Request req =
+        new Request("GET", requestUrl).withHeader("Authorization", "Bearer " + idTokenRequestToken);
+
+    Response resp;
+    try {
+      resp = httpClient.execute(req);
+    } catch (IOException e) {
+      throw new DatabricksException(
+          "Failed to request ID token from " + requestUrl + ": " + e.getMessage(), e);
+    }
+
+    if (resp.getStatusCode() != 200) {
+      throw new DatabricksException(
+          "Failed to request ID token: status code "
+              + resp.getStatusCode()
+              + ", response body: "
+              + resp.getBody().toString());
+    }
+
+    ObjectNode jsonResp;
+    try {
+      jsonResp = mapper.readValue(resp.getBody(), ObjectNode.class);
+    } catch (IOException e) {
+      throw new DatabricksException(
+          "Failed to request ID token: corrupted token: " + e.getMessage());
+    }
+
+    return jsonResp.get("value").textValue();
+  }
+}
diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/GithubOidcCredentialsProvider.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/GithubOidcCredentialsProvider.java
new file mode 100644
index 000000000..eeb70797e
--- /dev/null
+++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/GithubOidcCredentialsProvider.java
@@ -0,0 +1,67 @@
+package com.databricks.sdk.core.oauth;
+
+import com.databricks.sdk.core.CredentialsProvider;
+import com.databricks.sdk.core.DatabricksConfig;
+import com.databricks.sdk.core.DatabricksException;
+import com.databricks.sdk.core.HeaderFactory;
+import com.google.common.collect.ImmutableMap;
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * GithubOidcCredentialsProvider uses a token supplier to get a GitHub OIDC JWT token and exchanges
+ * it for a Databricks token.
+ */
+public class GithubOidcCredentialsProvider implements CredentialsProvider {
+
+  @Override
+  public String authType() {
+    return "github-oidc";
+  }
+
+  @Override
+  public HeaderFactory configure(DatabricksConfig config) throws DatabricksException {
+    GitHubOidcTokenSupplier idTokenProvider =
+        new GitHubOidcTokenSupplier(
+            config.getHttpClient(),
+            config.getActionsIdTokenRequestUrl(),
+            config.getActionsIdTokenRequestToken(),
+            config.getTokenAudience());
+
+    if (!idTokenProvider.enabled() || config.getHost() == null || config.getClientId() == null) {
+      return null;
+    }
+
+    String endpointUrl;
+
+    try {
+      endpointUrl = config.getOidcEndpoints().getTokenEndpoint();
+    } catch (IOException e) {
+      throw new DatabricksException("Unable to fetch OIDC endpoint: " + e.getMessage(), e);
+    }
+
+    ClientCredentials clientCredentials =
+        new ClientCredentials.Builder()
+            .withHttpClient(config.getHttpClient())
+            .withClientId(config.getClientId())
+            .withTokenUrl(endpointUrl)
+            .withScopes(Collections.singletonList("all-apis"))
+            .withAuthParameterPosition(AuthParameterPosition.HEADER)
+            .withEndpointParametersSupplier(
+                () ->
+                    new ImmutableMap.Builder<String, String>()
+                        .put("subject_token_type", "urn:ietf:params:oauth:token-type:jwt")
+                        .put("subject_token", idTokenProvider.getOidcToken())
+                        .put("grant_type", "urn:ietf:params:oauth:grant-type:token-exchange")
+                        .build())
+            .build();
+
+    return () -> {
+      Map<String, String> headers = new HashMap<>();
+      headers.put("Authorization", "Bearer " + clientCredentials.getToken().getAccessToken());
+      return headers;
+    };
+  }
+}
diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/integration/DatabricksWifIT.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/integration/DatabricksWifIT.java
new file mode 100644
index 000000000..b50f2267d
--- /dev/null
+++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/integration/DatabricksWifIT.java
@@ -0,0 +1,112 @@
+package com.databricks.sdk.integration;
+
+import com.databricks.sdk.AccountClient;
+import com.databricks.sdk.WorkspaceClient;
+import com.databricks.sdk.core.DatabricksConfig;
+import com.databricks.sdk.integration.framework.EnvContext;
+import com.databricks.sdk.integration.framework.EnvOrSkip;
+import com.databricks.sdk.integration.framework.EnvTest;
+import com.databricks.sdk.service.iam.*;
+import com.databricks.sdk.service.oauth2.CreateServicePrincipalFederationPolicyRequest;
+import com.databricks.sdk.service.oauth2.FederationPolicy;
+import com.databricks.sdk.service.oauth2.OidcFederationPolicy;
+import java.util.Collections;
+import java.util.UUID;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtendWith;
+
+@EnvContext("ucacct")
+@ExtendWith(EnvTest.class)
+public class DatabricksWifIT {
+  // This test cannot run on local machines. We use ACTIONS_ID_TOKEN_REQUEST_URL
+  // to determine whether we are running in GitHub Actions,
+  // and we skip the test if we are not.
+  @Test
+  void workspace(
+      AccountClient a,
+      @EnvOrSkip("TEST_WORKSPACE_ID") String workspaceId,
+      @EnvOrSkip("TEST_WORKSPACE_URL") String workspaceUrl,
+      @EnvOrSkip("ACTIONS_ID_TOKEN_REQUEST_URL") String actionsIdTokenRequestUrl) {
+    String spName = "java-sdk-sp" + UUID.randomUUID();
+
+    // Create SP with access to the workspace
+    ServicePrincipal sp =
+        a.servicePrincipals().create(new ServicePrincipal().setActive(true).setDisplayName(spName));
+
+    a.workspaceAssignment()
+        .update(
+            new UpdateWorkspaceAssignments()
+                .setWorkspaceId(Long.valueOf(workspaceId))
+                .setPrincipalId(Long.valueOf(sp.getId()))
+                .setPermissions(Collections.singleton(WorkspacePermission.ADMIN)));
+
+    // Setup Federation Policy
+    OidcFederationPolicy policy =
+        new OidcFederationPolicy()
+            .setIssuer("https://token.actions.githubusercontent.com")
+            .setSubject("repo:databricks-eng/eng-dev-ecosystem:environment:integration-tests")
+            .setAudiences(Collections.singleton("https://github.com/databricks-eng"));
+
+    a.servicePrincipalFederationPolicy()
+        .create(
+            new CreateServicePrincipalFederationPolicyRequest()
+                .setServicePrincipalId(Long.valueOf(sp.getId()))
+                .setPolicy(new FederationPolicy().setOidcPolicy(policy)));
+
+    // Test WIF login
+    DatabricksConfig config =
+        new DatabricksConfig()
+            .setHost(workspaceUrl)
+            .setClientId(sp.getApplicationId())
+            .setAuthType("github-oidc")
+            .setTokenAudience("https://github.com/databricks-eng");
+
+    WorkspaceClient ws = new WorkspaceClient(config);
+
+    ws.currentUser().me();
+  }
+
+  // This test cannot run on local machines. We use ACTIONS_ID_TOKEN_REQUEST_URL
+  // to determine whether we are running in GitHub Actions,
+  // and we skip the test if we are not.
+  @Test
+  void account(
+      AccountClient a, @EnvOrSkip("ACTIONS_ID_TOKEN_REQUEST_URL") String actionsIdTokenRequestUrl) {
+    String spName = "java-sdk-sp" + UUID.randomUUID();
+
+    // Create SP
+    ServicePrincipal sp =
+        a.servicePrincipals()
+            .create(
+                new ServicePrincipal()
+                    .setActive(true)
+                    .setDisplayName(spName)
+                    .setRoles(Collections.singleton(new ComplexValue().setValue("account_admin"))));
+
+    // Setup Federation Policy
+    OidcFederationPolicy policy =
+        new OidcFederationPolicy()
+            .setIssuer("https://token.actions.githubusercontent.com")
+            .setSubject("repo:databricks-eng/eng-dev-ecosystem:environment:integration-tests")
+            .setAudiences(Collections.singleton("https://github.com/databricks-eng"));
+
+    a.servicePrincipalFederationPolicy()
+        .create(
+            new CreateServicePrincipalFederationPolicyRequest()
+                .setServicePrincipalId(Long.valueOf(sp.getId()))
+                .setPolicy(new FederationPolicy().setOidcPolicy(policy)));
+
+    // Test WIF login
+    DatabricksConfig config =
+        new DatabricksConfig()
+            .setHost(a.config().getHost())
+            .setAccountId(a.config().getAccountId())
+            .setClientId(sp.getApplicationId())
+            .setAuthType("github-oidc")
+            .setTokenAudience("https://github.com/databricks-eng");
+
+    AccountClient ac = new AccountClient(config);
+
+    Iterable<Group> groups = ac.groups().list(new ListAccountGroupsRequest());
+    groups.iterator().next();
+  }
+}
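
For reference, a minimal sketch of how a GitHub Actions workflow consumer would use the new provider, derived from the README table and the integration test above. The host, client ID, and audience values are placeholders; in a real workflow they would typically come from the `DATABRICKS_HOST`, `DATABRICKS_CLIENT_ID`, and `TOKEN_AUDIENCE` environment variables rather than being hardcoded.

```java
import com.databricks.sdk.WorkspaceClient;
import com.databricks.sdk.core.DatabricksConfig;

public class WifAuthExample {
  public static void main(String[] args) {
    // Select the github-oidc auth type explicitly: the SDK fetches a GitHub Actions
    // ID token and exchanges it for a Databricks token via the service principal's
    // federation policy.
    DatabricksConfig config =
        new DatabricksConfig()
            .setHost("https://my-workspace.cloud.databricks.com") // placeholder host
            .setClientId("<service-principal-application-id>") // placeholder client ID
            .setAuthType("github-oidc")
            .setTokenAudience("https://github.com/my-org"); // optional audience

    WorkspaceClient ws = new WorkspaceClient(config);
    System.out.println(ws.currentUser().me().getUserName());
  }
}
```

Leaving out `setAuthType` also works, since `GithubOidcCredentialsProvider` is part of the default provider chain, but then the ordering caveat noted in the changelog entry applies.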