Skip to content

Commit 11d23fd

Browse files
committed
Implement WIF support
1 parent 289fd2a commit 11d23fd

File tree

7 files changed

+275
-18
lines changed

7 files changed

+275
-18
lines changed

databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksConfig.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,12 @@ public class DatabricksConfig {
141141

142142
private DatabricksEnvironment databricksEnvironment;
143143

144+
/**
145+
* When using Workload Identity Federation, the audience to specify when fetching an ID token from the ID token supplier.
146+
*/
147+
@ConfigAttribute(env = "TOKEN_AUDIENCE")
148+
private String tokenAudience;
149+
144150
public Environment getEnv() {
145151
return env;
146152
}
@@ -512,6 +518,15 @@ public DatabricksConfig setHttpClient(HttpClient httpClient) {
512518
return this;
513519
}
514520

521+
public String getTokenAudience() {
522+
return tokenAudience;
523+
}
524+
525+
public DatabricksConfig setTokenAudience(String tokenAudience) {
526+
this.tokenAudience = tokenAudience;
527+
return this;
528+
}
529+
515530
public boolean isAzure() {
516531
if (azureWorkspaceResourceId != null) {
517532
return true;

databricks-sdk-java/src/main/java/com/databricks/sdk/core/DefaultCredentialsProvider.java

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
package com.databricks.sdk.core;
22

3-
import com.databricks.sdk.core.oauth.AzureGithubOidcCredentialsProvider;
4-
import com.databricks.sdk.core.oauth.AzureServicePrincipalCredentialsProvider;
5-
import com.databricks.sdk.core.oauth.ExternalBrowserCredentialsProvider;
6-
import com.databricks.sdk.core.oauth.OAuthM2MServicePrincipalCredentialsProvider;
3+
import com.databricks.sdk.core.oauth.*;
4+
75
import java.util.ArrayList;
86
import java.util.Arrays;
97
import java.util.List;
@@ -18,6 +16,7 @@ public class DefaultCredentialsProvider implements CredentialsProvider {
1816
PatCredentialsProvider.class,
1917
BasicCredentialsProvider.class,
2018
OAuthM2MServicePrincipalCredentialsProvider.class,
19+
DatabricksWIFCredentialsProvider.class,
2120
AzureGithubOidcCredentialsProvider.class,
2221
AzureServicePrincipalCredentialsProvider.class,
2322
AzureCliCredentialsProvider.class,

databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProvider.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ private static RefreshableTokenSource tokenSourceFor(DatabricksConfig config, St
6363
.withClientId(config.getAzureClientId())
6464
.withClientSecret(config.getAzureClientSecret())
6565
.withTokenUrl(tokenUrl)
66-
.withEndpointParameters(endpointParams)
66+
.withEndpointParametersSupplier(() -> endpointParams)
6767
.withAuthParameterPosition(AuthParameterPosition.BODY)
6868
.build();
6969
}

databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/ClientCredentials.java

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import com.databricks.sdk.core.commons.CommonsHttpClient;
44
import com.databricks.sdk.core.http.HttpClient;
55
import java.util.*;
6+
import java.util.function.Function;
7+
import java.util.function.Supplier;
68

79
/**
810
* An implementation of RefreshableTokenSource implementing the client_credentials OAuth grant type.
@@ -18,7 +20,7 @@ public static class Builder {
1820
private String clientSecret;
1921
private String tokenUrl;
2022
private HttpClient hc = new CommonsHttpClient.Builder().withTimeoutSeconds(30).build();
21-
private Map<String, String> endpointParams = Collections.emptyMap();
23+
private Supplier<Map<String, String>> endpointParamsSupplier = null;
2224
private List<String> scopes = Collections.emptyList();
2325
private AuthParameterPosition position = AuthParameterPosition.BODY;
2426

@@ -32,13 +34,13 @@ public Builder withClientSecret(String clientSecret) {
3234
return this;
3335
}
3436

35-
public Builder withTokenUrl(String tokenUrl) {
36-
this.tokenUrl = tokenUrl;
37-
return this;
37+
public Builder withEndpointParametersSupplier(Supplier<Map<String, String>> endpointParamsSupplier) {
38+
this.endpointParamsSupplier = endpointParamsSupplier;
39+
return this;
3840
}
3941

40-
public Builder withEndpointParameters(Map<String, String> params) {
41-
this.endpointParams = params;
42+
public Builder withTokenUrl(String tokenUrl) {
43+
this.tokenUrl = tokenUrl;
4244
return this;
4345
}
4446

@@ -59,34 +61,33 @@ public Builder withHttpClient(HttpClient hc) {
5961

6062
public ClientCredentials build() {
6163
Objects.requireNonNull(this.clientId, "clientId must be specified");
62-
Objects.requireNonNull(this.clientSecret, "clientSecret must be specified");
6364
Objects.requireNonNull(this.tokenUrl, "tokenUrl must be specified");
6465
return new ClientCredentials(
65-
hc, clientId, clientSecret, tokenUrl, endpointParams, scopes, position);
66+
hc, clientId, clientSecret, tokenUrl, endpointParamsSupplier, scopes, position);
6667
}
6768
}
6869

6970
private HttpClient hc;
7071
private String clientId;
7172
private String clientSecret;
7273
private String tokenUrl;
73-
private Map<String, String> endpointParams;
7474
private List<String> scopes;
7575
private AuthParameterPosition position;
76+
private Supplier<Map<String, String>> endpointParamsSupplier;
7677

7778
private ClientCredentials(
7879
HttpClient hc,
7980
String clientId,
8081
String clientSecret,
8182
String tokenUrl,
82-
Map<String, String> endpointParams,
83+
Supplier<Map<String, String>> endpointParamsSupplier,
8384
List<String> scopes,
8485
AuthParameterPosition position) {
8586
this.hc = hc;
8687
this.clientId = clientId;
8788
this.clientSecret = clientSecret;
8889
this.tokenUrl = tokenUrl;
89-
this.endpointParams = endpointParams;
90+
this.endpointParamsSupplier = endpointParamsSupplier;
9091
this.scopes = scopes;
9192
this.position = position;
9293
}
@@ -98,8 +99,8 @@ protected Token refresh() {
9899
if (scopes != null) {
99100
params.put("scope", String.join(" ", scopes));
100101
}
101-
if (endpointParams != null) {
102-
params.putAll(endpointParams);
102+
if (endpointParamsSupplier != null) {
103+
params.putAll(endpointParamsSupplier.get());
103104
}
104105
return retrieveToken(hc, clientId, clientSecret, tokenUrl, params, new HashMap<>(), position);
105106
}
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
package com.databricks.sdk.core.oauth;
2+
3+
import com.databricks.sdk.core.CredentialsProvider;
4+
import com.databricks.sdk.core.DatabricksConfig;
5+
import com.databricks.sdk.core.DatabricksException;
6+
import com.databricks.sdk.core.HeaderFactory;
7+
import com.fasterxml.jackson.databind.ObjectMapper;
8+
import com.google.common.collect.ImmutableMap;
9+
10+
import java.io.IOException;
11+
import java.util.Collections;
12+
import java.util.HashMap;
13+
import java.util.Map;
14+
15+
/**
16+
* TODO: Add description
17+
*
18+
*/
19+
public class DatabricksWIFCredentialsProvider implements CredentialsProvider {
20+
private final ObjectMapper mapper = new ObjectMapper();
21+
22+
@Override
23+
public String authType() {
24+
return "databricks-wif";
25+
}
26+
27+
@Override
28+
public HeaderFactory configure(DatabricksConfig config) throws DatabricksException{
29+
GitHubOidcTokenSupplier idTokenProvider = new GitHubOidcTokenSupplier(config);
30+
31+
if (!idTokenProvider.enabled() || config.getHost() == null || config.getClientId() == null) {
32+
return null;
33+
}
34+
35+
String endpointUrl;
36+
37+
try {
38+
endpointUrl = config.getOidcEndpoints().getTokenEndpoint();
39+
} catch (IOException e) {
40+
throw new DatabricksException("Unable to fetch OIDC endpoint: " + e.getMessage(), e);
41+
}
42+
43+
ClientCredentials clientCredentials = new ClientCredentials.Builder()
44+
.withHttpClient(config.getHttpClient())
45+
.withClientId(config.getClientId())
46+
.withTokenUrl(endpointUrl)
47+
.withScopes(Collections.singletonList("all-apis"))
48+
.withAuthParameterPosition(AuthParameterPosition.HEADER)
49+
.withEndpointParametersSupplier(
50+
() -> new ImmutableMap.Builder<String, String>()
51+
.put("subject_token_type", "urn:ietf:params:oauth:token-type:jwt")
52+
.put("subject_token", idTokenProvider.getOidcToken())
53+
.put("grant_type", "urn:ietf:params:oauth:grant-type:token-exchange")
54+
.build()
55+
)
56+
.build();
57+
58+
return () -> {
59+
Map<String, String> headers = new HashMap<>();
60+
headers.put("Authorization", "Bearer " + clientCredentials.getToken().getAccessToken());
61+
return headers;
62+
};
63+
}
64+
}
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
package com.databricks.sdk.core.oauth;
2+
3+
import com.databricks.sdk.core.DatabricksConfig;
4+
import com.databricks.sdk.core.DatabricksException;
5+
import com.databricks.sdk.core.http.Request;
6+
import com.databricks.sdk.core.http.Response;
7+
import com.fasterxml.jackson.databind.ObjectMapper;
8+
import com.fasterxml.jackson.databind.node.ObjectNode;
9+
10+
import java.io.IOException;
11+
12+
public class GitHubOidcTokenSupplier {
13+
14+
private final ObjectMapper mapper = new ObjectMapper();
15+
16+
private final DatabricksConfig config;
17+
18+
public GitHubOidcTokenSupplier(DatabricksConfig config) {
19+
this.config = config;
20+
}
21+
22+
23+
/**
24+
* Checks if the required parameters are present to request a GitHub's OIDC token.
25+
*/
26+
public Boolean enabled() {
27+
return config.getActionsIdTokenRequestUrl() != null
28+
&& config.getActionsIdTokenRequestToken() != null;
29+
}
30+
31+
/**
32+
* Requests a GitHub's OIDC token.
33+
*
34+
* @return A GitHub OIDC token.
35+
*/
36+
public String getOidcToken() {
37+
if (!enabled()) {
38+
throw new DatabricksException("Failed to request ID token: missing required parameters");
39+
}
40+
41+
String requestUrl =
42+
config.getActionsIdTokenRequestUrl();
43+
if (config.getTokenAudience() != null) {
44+
requestUrl += "&audience=" + config.getTokenAudience();
45+
}
46+
47+
Request req =
48+
new Request("GET", requestUrl)
49+
.withHeader("Authorization", "Bearer " + config.getActionsIdTokenRequestToken());
50+
51+
Response resp;
52+
try {
53+
resp = config.getHttpClient().execute(req);
54+
} catch (IOException e) {
55+
throw new DatabricksException(
56+
"Failed to request ID token from " + requestUrl + ":" + e.getMessage(), e);
57+
}
58+
59+
if (resp.getStatusCode() != 200) {
60+
throw new DatabricksException(
61+
"Failed to request ID token: status code "
62+
+ resp.getStatusCode()
63+
+ ", response body: "
64+
+ resp.getBody().toString());
65+
}
66+
67+
ObjectNode jsonResp;
68+
try {
69+
jsonResp = mapper.readValue(resp.getBody(), ObjectNode.class);
70+
} catch (IOException e) {
71+
throw new DatabricksException(
72+
"Failed to request ID token: corrupted token: " + e.getMessage());
73+
}
74+
75+
return jsonResp.get("value").textValue();
76+
}
77+
}
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
package com.databricks.sdk.integration;
2+
3+
import com.databricks.sdk.AccountClient;
4+
import com.databricks.sdk.WorkspaceClient;
5+
import com.databricks.sdk.core.DatabricksConfig;
6+
import com.databricks.sdk.integration.framework.CollectionUtils;
7+
import com.databricks.sdk.integration.framework.EnvContext;
8+
import com.databricks.sdk.integration.framework.EnvOrSkip;
9+
import com.databricks.sdk.integration.framework.EnvTest;
10+
import com.databricks.sdk.service.iam.*;
11+
import com.databricks.sdk.service.oauth2.CreateAccountFederationPolicyRequest;
12+
import com.databricks.sdk.service.oauth2.CreateServicePrincipalFederationPolicyRequest;
13+
import com.databricks.sdk.service.oauth2.FederationPolicy;
14+
import com.databricks.sdk.service.oauth2.OidcFederationPolicy;
15+
import org.junit.jupiter.api.Test;
16+
import org.junit.jupiter.api.extension.ExtendWith;
17+
18+
import java.util.Collections;
19+
import java.util.UUID;
20+
21+
@EnvContext("ucacct")
22+
@ExtendWith(EnvTest.class)
23+
public class DatabricksWifIT {
24+
// This test cannot run on local machines. We use ACTIONS_ID_TOKEN_REQUEST_URL
25+
// to determine whether we are running in the GitHub Actions,
26+
// and we skip the test if we are not.
27+
@Test
28+
void workspace(AccountClient a,
29+
@EnvOrSkip("TEST_WORKSPACE_ID") String workspaceId,
30+
@EnvOrSkip("TEST_WORKSPACE_URL") String workspaceUrl,
31+
@EnvOrSkip("ACTIONS_ID_TOKEN_REQUEST_URL") String userId) {
32+
String spName = "java-sdk-sp" + UUID.randomUUID();
33+
34+
// Create SP with access to the workspace
35+
ServicePrincipal sp = a.servicePrincipals().create(new ServicePrincipal()
36+
.setActive(true).setDisplayName(spName));
37+
38+
a.workspaceAssignment().update(new UpdateWorkspaceAssignments()
39+
.setWorkspaceId(Long.valueOf(workspaceId))
40+
.setPrincipalId(Long.valueOf(sp.getId()))
41+
.setPermissions(Collections.singleton(WorkspacePermission.ADMIN)));
42+
43+
// Setup Federation Policy
44+
OidcFederationPolicy policy = new OidcFederationPolicy()
45+
.setIssuer("https://token.actions.githubusercontent.com")
46+
.setSubject("repo:databricks-eng/eng-dev-ecosystem:environment:integration-tests")
47+
.setAudiences(Collections.singleton("https://github.com/databricks-eng"));
48+
49+
a.servicePrincipalFederationPolicy().create(new CreateServicePrincipalFederationPolicyRequest()
50+
.setServicePrincipalId(Long.valueOf(sp.getId()))
51+
.setPolicy(new FederationPolicy().setOidcPolicy(policy)));
52+
53+
// Test WIF login
54+
DatabricksConfig config = new DatabricksConfig().setHost(workspaceUrl)
55+
.setClientId(sp.getApplicationId())
56+
.setAuthType("databricks-wif")
57+
.setTokenAudience("https://github.com/databricks-eng");
58+
59+
WorkspaceClient ws = new WorkspaceClient(config);
60+
61+
ws.currentUser().me();
62+
}
63+
64+
// This test cannot run on local machines. We use ACTIONS_ID_TOKEN_REQUEST_URL
65+
// to determine whether we are running in the GitHub Actions,
66+
// and we skip the test if we are not.
67+
@Test
68+
void account(AccountClient a,
69+
@EnvOrSkip("ACTIONS_ID_TOKEN_REQUEST_URL") String userId) {
70+
String spName = "java-sdk-sp" + UUID.randomUUID();
71+
72+
// Create SP
73+
ServicePrincipal sp = a.servicePrincipals().create(new ServicePrincipal()
74+
.setActive(true).setDisplayName(spName))
75+
.setRoles(Collections.singleton(new ComplexValue().setValue("account_admin")));
76+
77+
78+
// Setup Federation Policy
79+
OidcFederationPolicy policy = new OidcFederationPolicy()
80+
.setIssuer("https://token.actions.githubusercontent.com")
81+
.setSubject("repo:databricks-eng/eng-dev-ecosystem:environment:integration-tests")
82+
.setAudiences(Collections.singleton("https://github.com/databricks-eng"));
83+
84+
a.servicePrincipalFederationPolicy().create(new CreateServicePrincipalFederationPolicyRequest()
85+
.setServicePrincipalId(Long.valueOf(sp.getId()))
86+
.setPolicy(new FederationPolicy().setOidcPolicy(policy)));
87+
88+
// Test WIF login
89+
DatabricksConfig config = new DatabricksConfig().setHost(a.config().getHost())
90+
.setAccountId(a.config().getAccountId())
91+
.setClientId(sp.getApplicationId())
92+
.setAuthType("databricks-wif")
93+
.setTokenAudience("https://github.com/databricks-eng");
94+
95+
AccountClient ac = new AccountClient(config);
96+
97+
Iterable<Group> groups = ac.groups().list(new ListAccountGroupsRequest());
98+
groups.iterator().next();
99+
}
100+
}
101+

0 commit comments

Comments
 (0)