Skip to content

Commit 09796a2

Browse files
committed
Implement WIF support
1 parent 289fd2a commit 09796a2

File tree

8 files changed

+285
-17
lines changed

8 files changed

+285
-17
lines changed

NEXT_CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@
33
## Release v0.43.0
44

55
### New Features and Improvements
6+
* Introduce support for Databricks Workload Identity Federation in GitHub workflows ([423](https://github.com/databricks/databricks-sdk-java/pull/423)).
7+
See README.md for instructions.
8+
* [Breaking] Users running their workflows in GitHub Actions, which use Cloud native authentication and also have a `DATABRICKS_CLIENT_ID` and `DATABRICKS_HOST`
9+
environment variables set may see their authentication start failing due to the order in which the SDK tries different authentication methods.
610

711
### Bug Fixes
812

databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksConfig.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,13 @@ public class DatabricksConfig {
141141

142142
private DatabricksEnvironment databricksEnvironment;
143143

144+
/**
145+
* When using Workload Identity Federation, the audience to specify when fetching an ID token from
146+
* the ID token supplier.
147+
*/
148+
@ConfigAttribute(env = "TOKEN_AUDIENCE")
149+
private String tokenAudience;
150+
144151
public Environment getEnv() {
145152
return env;
146153
}
@@ -512,6 +519,15 @@ public DatabricksConfig setHttpClient(HttpClient httpClient) {
512519
return this;
513520
}
514521

522+
public String getTokenAudience() {
523+
return tokenAudience;
524+
}
525+
526+
public DatabricksConfig setTokenAudience(String tokenAudience) {
527+
this.tokenAudience = tokenAudience;
528+
return this;
529+
}
530+
515531
public boolean isAzure() {
516532
if (azureWorkspaceResourceId != null) {
517533
return true;

databricks-sdk-java/src/main/java/com/databricks/sdk/core/DefaultCredentialsProvider.java

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
package com.databricks.sdk.core;
22

3-
import com.databricks.sdk.core.oauth.AzureGithubOidcCredentialsProvider;
4-
import com.databricks.sdk.core.oauth.AzureServicePrincipalCredentialsProvider;
5-
import com.databricks.sdk.core.oauth.ExternalBrowserCredentialsProvider;
6-
import com.databricks.sdk.core.oauth.OAuthM2MServicePrincipalCredentialsProvider;
3+
import com.databricks.sdk.core.oauth.*;
74
import java.util.ArrayList;
85
import java.util.Arrays;
96
import java.util.List;
@@ -18,6 +15,7 @@ public class DefaultCredentialsProvider implements CredentialsProvider {
1815
PatCredentialsProvider.class,
1916
BasicCredentialsProvider.class,
2017
OAuthM2MServicePrincipalCredentialsProvider.class,
18+
DatabricksWifCredentialsProvider.class,
2119
AzureGithubOidcCredentialsProvider.class,
2220
AzureServicePrincipalCredentialsProvider.class,
2321
AzureCliCredentialsProvider.class,

databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProvider.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ private static RefreshableTokenSource tokenSourceFor(DatabricksConfig config, St
6363
.withClientId(config.getAzureClientId())
6464
.withClientSecret(config.getAzureClientSecret())
6565
.withTokenUrl(tokenUrl)
66-
.withEndpointParameters(endpointParams)
66+
.withEndpointParametersSupplier(() -> endpointParams)
6767
.withAuthParameterPosition(AuthParameterPosition.BODY)
6868
.build();
6969
}

databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/ClientCredentials.java

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import com.databricks.sdk.core.commons.CommonsHttpClient;
44
import com.databricks.sdk.core.http.HttpClient;
55
import java.util.*;
6+
import java.util.function.Supplier;
67

78
/**
89
* An implementation of RefreshableTokenSource implementing the client_credentials OAuth grant type.
@@ -18,7 +19,10 @@ public static class Builder {
1819
private String clientSecret;
1920
private String tokenUrl;
2021
private HttpClient hc = new CommonsHttpClient.Builder().withTimeoutSeconds(30).build();
21-
private Map<String, String> endpointParams = Collections.emptyMap();
22+
// Endpoint parameters can include tokens with expiration which
23+
// may need to be refreshed. This supplier will be called each time
24+
// that the credentials are refreshed.
25+
private Supplier<Map<String, String>> endpointParamsSupplier = null;
2226
private List<String> scopes = Collections.emptyList();
2327
private AuthParameterPosition position = AuthParameterPosition.BODY;
2428

@@ -32,13 +36,14 @@ public Builder withClientSecret(String clientSecret) {
3236
return this;
3337
}
3438

35-
public Builder withTokenUrl(String tokenUrl) {
36-
this.tokenUrl = tokenUrl;
39+
public Builder withEndpointParametersSupplier(
40+
Supplier<Map<String, String>> endpointParamsSupplier) {
41+
this.endpointParamsSupplier = endpointParamsSupplier;
3742
return this;
3843
}
3944

40-
public Builder withEndpointParameters(Map<String, String> params) {
41-
this.endpointParams = params;
45+
public Builder withTokenUrl(String tokenUrl) {
46+
this.tokenUrl = tokenUrl;
4247
return this;
4348
}
4449

@@ -59,34 +64,33 @@ public Builder withHttpClient(HttpClient hc) {
5964

6065
public ClientCredentials build() {
6166
Objects.requireNonNull(this.clientId, "clientId must be specified");
62-
Objects.requireNonNull(this.clientSecret, "clientSecret must be specified");
6367
Objects.requireNonNull(this.tokenUrl, "tokenUrl must be specified");
6468
return new ClientCredentials(
65-
hc, clientId, clientSecret, tokenUrl, endpointParams, scopes, position);
69+
hc, clientId, clientSecret, tokenUrl, endpointParamsSupplier, scopes, position);
6670
}
6771
}
6872

6973
private HttpClient hc;
7074
private String clientId;
7175
private String clientSecret;
7276
private String tokenUrl;
73-
private Map<String, String> endpointParams;
7477
private List<String> scopes;
7578
private AuthParameterPosition position;
79+
private Supplier<Map<String, String>> endpointParamsSupplier;
7680

7781
private ClientCredentials(
7882
HttpClient hc,
7983
String clientId,
8084
String clientSecret,
8185
String tokenUrl,
82-
Map<String, String> endpointParams,
86+
Supplier<Map<String, String>> endpointParamsSupplier,
8387
List<String> scopes,
8488
AuthParameterPosition position) {
8589
this.hc = hc;
8690
this.clientId = clientId;
8791
this.clientSecret = clientSecret;
8892
this.tokenUrl = tokenUrl;
89-
this.endpointParams = endpointParams;
93+
this.endpointParamsSupplier = endpointParamsSupplier;
9094
this.scopes = scopes;
9195
this.position = position;
9296
}
@@ -98,8 +102,8 @@ protected Token refresh() {
98102
if (scopes != null) {
99103
params.put("scope", String.join(" ", scopes));
100104
}
101-
if (endpointParams != null) {
102-
params.putAll(endpointParams);
105+
if (endpointParamsSupplier != null) {
106+
params.putAll(endpointParamsSupplier.get());
103107
}
104108
return retrieveToken(hc, clientId, clientSecret, tokenUrl, params, new HashMap<>(), position);
105109
}
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
package com.databricks.sdk.core.oauth;
2+
3+
import com.databricks.sdk.core.CredentialsProvider;
4+
import com.databricks.sdk.core.DatabricksConfig;
5+
import com.databricks.sdk.core.DatabricksException;
6+
import com.databricks.sdk.core.HeaderFactory;
7+
import com.google.common.collect.ImmutableMap;
8+
import java.io.IOException;
9+
import java.util.Collections;
10+
import java.util.HashMap;
11+
import java.util.Map;
12+
13+
/**
14+
* DatabricksWIFCredentials uses a Token Supplier to get a JWT Token and exchanges it for a
15+
* Databricks Token. Supported suppliers: - GitHub OIDC
16+
*/
17+
public class DatabricksWifCredentialsProvider implements CredentialsProvider {
18+
19+
@Override
20+
public String authType() {
21+
return "databricks-wif";
22+
}
23+
24+
@Override
25+
public HeaderFactory configure(DatabricksConfig config) throws DatabricksException {
26+
GitHubOidcTokenSupplier idTokenProvider = new GitHubOidcTokenSupplier(config);
27+
28+
if (!idTokenProvider.enabled() || config.getHost() == null || config.getClientId() == null) {
29+
return null;
30+
}
31+
32+
String endpointUrl;
33+
34+
try {
35+
endpointUrl = config.getOidcEndpoints().getTokenEndpoint();
36+
} catch (IOException e) {
37+
throw new DatabricksException("Unable to fetch OIDC endpoint: " + e.getMessage(), e);
38+
}
39+
40+
ClientCredentials clientCredentials =
41+
new ClientCredentials.Builder()
42+
.withHttpClient(config.getHttpClient())
43+
.withClientId(config.getClientId())
44+
.withTokenUrl(endpointUrl)
45+
.withScopes(Collections.singletonList("all-apis"))
46+
.withAuthParameterPosition(AuthParameterPosition.HEADER)
47+
.withEndpointParametersSupplier(
48+
() ->
49+
new ImmutableMap.Builder<String, String>()
50+
.put("subject_token_type", "urn:ietf:params:oauth:token-type:jwt")
51+
.put("subject_token", idTokenProvider.getOidcToken())
52+
.put("grant_type", "urn:ietf:params:oauth:grant-type:token-exchange")
53+
.build())
54+
.build();
55+
56+
return () -> {
57+
Map<String, String> headers = new HashMap<>();
58+
headers.put("Authorization", "Bearer " + clientCredentials.getToken().getAccessToken());
59+
return headers;
60+
};
61+
}
62+
}
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
package com.databricks.sdk.core.oauth;
2+
3+
import com.databricks.sdk.core.DatabricksConfig;
4+
import com.databricks.sdk.core.DatabricksException;
5+
import com.databricks.sdk.core.http.Request;
6+
import com.databricks.sdk.core.http.Response;
7+
import com.fasterxml.jackson.databind.ObjectMapper;
8+
import com.fasterxml.jackson.databind.node.ObjectNode;
9+
import java.io.IOException;
10+
11+
public class GitHubOidcTokenSupplier {
12+
13+
private final ObjectMapper mapper = new ObjectMapper();
14+
15+
private final DatabricksConfig config;
16+
17+
public GitHubOidcTokenSupplier(DatabricksConfig config) {
18+
this.config = config;
19+
}
20+
21+
/** Checks if the required parameters are present to request a GitHub's OIDC token. */
22+
public Boolean enabled() {
23+
return config.getActionsIdTokenRequestUrl() != null
24+
&& config.getActionsIdTokenRequestToken() != null;
25+
}
26+
27+
/**
28+
* Requests a GitHub's OIDC token.
29+
*
30+
* @return A GitHub OIDC token.
31+
*/
32+
public String getOidcToken() {
33+
if (!enabled()) {
34+
throw new DatabricksException("Failed to request ID token: missing required parameters");
35+
}
36+
37+
String requestUrl = config.getActionsIdTokenRequestUrl();
38+
if (config.getTokenAudience() != null) {
39+
requestUrl += "&audience=" + config.getTokenAudience();
40+
}
41+
42+
Request req =
43+
new Request("GET", requestUrl)
44+
.withHeader("Authorization", "Bearer " + config.getActionsIdTokenRequestToken());
45+
46+
Response resp;
47+
try {
48+
resp = config.getHttpClient().execute(req);
49+
} catch (IOException e) {
50+
throw new DatabricksException(
51+
"Failed to request ID token from " + requestUrl + ":" + e.getMessage(), e);
52+
}
53+
54+
if (resp.getStatusCode() != 200) {
55+
throw new DatabricksException(
56+
"Failed to request ID token: status code "
57+
+ resp.getStatusCode()
58+
+ ", response body: "
59+
+ resp.getBody().toString());
60+
}
61+
62+
ObjectNode jsonResp;
63+
try {
64+
jsonResp = mapper.readValue(resp.getBody(), ObjectNode.class);
65+
} catch (IOException e) {
66+
throw new DatabricksException(
67+
"Failed to request ID token: corrupted token: " + e.getMessage());
68+
}
69+
70+
return jsonResp.get("value").textValue();
71+
}
72+
}

0 commit comments

Comments
 (0)