Skip to content

Commit a96dc34

Browse files
[DECO-2483] Handle Azure authentication when WorkspaceResourceID is provided (#145)
## Changes Handle Azure authentication when WorkspaceResourceID is provided Get token for the correct subscription ## Tests * Created Unit tests * Manually listed workspace cluster in the following scenarios: * User with wrong default tenant. No WorkspaceResourceID provided: Fail (expected). WARN log emitted. * User with wrong default tenant. WorkspaceResourceID provided: Succeed * User with no subscription. No WorkspaceResourceID provided: Succeed. WARN log emitted. * User with no subscription. WorkspaceResourceID provided: Succeed (fallback mode, expected). * Run integration tests https://github.com/databricks/eng-dev-ecosystem/actions/runs/6038942050
1 parent 4a40b53 commit a96dc34

File tree

4 files changed

+206
-6
lines changed

4 files changed

+206
-6
lines changed

databricks-sdk-java/src/main/java/com/databricks/sdk/core/AzureCliCredentialsProvider.java

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,45 @@ public CliTokenSource tokenSourceFor(DatabricksConfig config, String resource) {
2424
new ArrayList<>(
2525
Arrays.asList(
2626
"az", "account", "get-access-token", "--resource", resource, "--output", "json"));
27-
return new CliTokenSource(cmd, "tokenType", "accessToken", "expiresOn", config::getAllEnv);
27+
Optional<String> subscription = getSubscription(config);
28+
if (subscription.isPresent()) {
29+
// This will fail if the user has access to the workspace, but not to the subscription
30+
// itself.
31+
// In such case, we fall back to not using the subscription.
32+
List<String> extendedCmd = new ArrayList<>(cmd);
33+
extendedCmd.addAll(Arrays.asList("--subscription", subscription.get()));
34+
try {
35+
return getToken(config, extendedCmd);
36+
} catch (DatabricksException ex) {
37+
LOG.warn("Failed to get token for subscription. Using resource only token.");
38+
}
39+
} else {
40+
LOG.warn(
41+
"azure_workspace_resource_id field not provided. "
42+
+ "It is recommended to specify this field in the Databricks configuration to avoid authentication errors.");
43+
}
44+
45+
return getToken(config, cmd);
46+
}
47+
48+
protected CliTokenSource getToken(DatabricksConfig config, List<String> cmd) {
49+
CliTokenSource token =
50+
new CliTokenSource(cmd, "tokenType", "accessToken", "expiresOn", config::getAllEnv);
51+
token.getToken(); // We need this to check if the CLI is installed and to validate the config.
52+
return token;
53+
}
54+
55+
private Optional<String> getSubscription(DatabricksConfig config) {
56+
String resourceId = config.getAzureWorkspaceResourceId();
57+
if (resourceId == null || resourceId.equals("")) {
58+
return Optional.empty();
59+
}
60+
String[] components = resourceId.split("/");
61+
if (components.length < 3) {
62+
LOG.warn("Invalid azure workspace resource ID");
63+
return Optional.empty();
64+
}
65+
return Optional.of(components[2]);
2866
}
2967

3068
@Override
@@ -37,11 +75,10 @@ public HeaderFactory configure(DatabricksConfig config) {
3775
ensureHostPresent(config, mapper);
3876
String resource = config.getEffectiveAzureLoginAppId();
3977
CliTokenSource tokenSource = tokenSourceFor(config, resource);
40-
CliTokenSource mgmtTokenSource =
41-
tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint());
42-
tokenSource.getToken(); // We need this for checking if Azure CLI is installed.
78+
CliTokenSource mgmtTokenSource;
4379
try {
44-
mgmtTokenSource.getToken();
80+
mgmtTokenSource =
81+
tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint());
4582
} catch (Exception e) {
4683
LOG.debug("Not including service management token in headers", e);
4784
mgmtTokenSource = null;

databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksConfig.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ public class DatabricksConfig {
7474
sensitive = true)
7575
private String googleCredentials;
7676

77-
/** Azure Resource Manager ID for Azure Databricks workspace, which is exhanged for a Host */
77+
/** Azure Resource Manager ID for Azure Databricks workspace, which is exchanged for a Host */
7878
@ConfigAttribute(
7979
value = "azure_workspace_resource_id",
8080
env = "DATABRICKS_AZURE_RESOURCE_ID",
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
package com.databricks.sdk.core;
2+
3+
import static com.databricks.sdk.core.AzureEnvironment.ARM_DATABRICKS_RESOURCE_ID;
4+
import static org.junit.jupiter.api.Assertions.*;
5+
import static org.mockito.ArgumentMatchers.*;
6+
import static org.mockito.Mockito.times;
7+
8+
import com.databricks.sdk.core.oauth.Token;
9+
import com.databricks.sdk.core.oauth.TokenSource;
10+
import java.time.LocalDateTime;
11+
import java.util.Arrays;
12+
import java.util.List;
13+
import org.junit.jupiter.api.Test;
14+
import org.mockito.ArgumentCaptor;
15+
import org.mockito.Mockito;
16+
17+
class AzureCliCredentialsProviderTest {
18+
19+
private static final String WORKSPACE_RESOURCE_ID =
20+
"/subscriptions/2a2345f8/resourceGroups/deco-rg/providers/Microsoft.Databricks/workspaces/deco-ws";
21+
private static final String SUBSCRIPTION = "2a2345f8";
22+
private static final String TOKEN = "t-123";
23+
private static final String TOKEN_TYPE = "token-type";
24+
25+
private static CliTokenSource mockTokenSource() {
26+
CliTokenSource tokenSource = Mockito.mock(CliTokenSource.class);
27+
Mockito.when(tokenSource.getToken())
28+
.thenReturn(new Token(TOKEN, TOKEN_TYPE, LocalDateTime.now()));
29+
return tokenSource;
30+
}
31+
32+
private static AzureCliCredentialsProvider getAzureCliCredentialsProvider(
33+
TokenSource tokenSource) {
34+
35+
AzureCliCredentialsProvider provider = Mockito.spy(new AzureCliCredentialsProvider());
36+
Mockito.doReturn(tokenSource).when(provider).getToken(any(), anyList());
37+
38+
return provider;
39+
}
40+
41+
@Test
42+
void testWorkSpaceIDUsage() {
43+
AzureCliCredentialsProvider provider = getAzureCliCredentialsProvider(mockTokenSource());
44+
DatabricksConfig config =
45+
new DatabricksConfig()
46+
.setHost(".azuredatabricks.")
47+
.setCredentialsProvider(provider)
48+
.setAzureWorkspaceResourceId(WORKSPACE_RESOURCE_ID);
49+
ArgumentCaptor<List<String>> argument = ArgumentCaptor.forClass(List.class);
50+
51+
HeaderFactory header = provider.configure(config);
52+
53+
String token = header.headers().get("Authorization");
54+
assertEquals(token, TOKEN_TYPE + " " + TOKEN);
55+
Mockito.verify(provider, times(2)).getToken(any(), argument.capture());
56+
57+
List<String> value = argument.getValue();
58+
value = value.subList(value.size() - 2, value.size());
59+
List<String> expected = Arrays.asList("--subscription", SUBSCRIPTION);
60+
assertEquals(expected, value);
61+
}
62+
63+
@Test
64+
void testFallbackWhenTailsToGetTokenForSubscription() {
65+
CliTokenSource tokenSource = mockTokenSource();
66+
67+
AzureCliCredentialsProvider provider = Mockito.spy(new AzureCliCredentialsProvider());
68+
Mockito.doThrow(new DatabricksException("error")).when(provider).getToken(any(), anyList());
69+
Mockito.doReturn(tokenSource).when(provider).getToken(any(), anyList());
70+
71+
DatabricksConfig config =
72+
new DatabricksConfig()
73+
.setHost(".azuredatabricks.")
74+
.setCredentialsProvider(provider)
75+
.setAzureWorkspaceResourceId(WORKSPACE_RESOURCE_ID);
76+
77+
HeaderFactory header = provider.configure(config);
78+
79+
String token = header.headers().get("Authorization");
80+
assertEquals(token, TOKEN_TYPE + " " + TOKEN);
81+
82+
Mockito.verify(provider, times(1)).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID));
83+
}
84+
85+
@Test
86+
void testGetTokenWithoutWorkspaceResourceID() {
87+
AzureCliCredentialsProvider provider = getAzureCliCredentialsProvider(mockTokenSource());
88+
DatabricksConfig config =
89+
new DatabricksConfig().setHost(".azuredatabricks.").setCredentialsProvider(provider);
90+
91+
ArgumentCaptor<List<String>> argument = ArgumentCaptor.forClass(List.class);
92+
93+
HeaderFactory header = provider.configure(config);
94+
95+
String token = header.headers().get("Authorization");
96+
assertEquals(token, TOKEN_TYPE + " " + TOKEN);
97+
Mockito.verify(provider, times(2)).getToken(any(), argument.capture());
98+
99+
List<String> value = argument.getValue();
100+
assertFalse(value.contains("--subscription"));
101+
assertFalse(value.contains(SUBSCRIPTION));
102+
}
103+
}
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
package com.databricks.sdk.core.oauth;
2+
3+
import static com.databricks.sdk.core.AzureEnvironment.ARM_DATABRICKS_RESOURCE_ID;
4+
import static org.junit.jupiter.api.Assertions.*;
5+
import static org.mockito.ArgumentMatchers.any;
6+
import static org.mockito.ArgumentMatchers.eq;
7+
import static org.mockito.Mockito.times;
8+
9+
import com.databricks.sdk.core.*;
10+
import java.time.LocalDateTime;
11+
import java.time.temporal.IsoFields;
12+
import org.junit.jupiter.api.Test;
13+
import org.mockito.Mockito;
14+
15+
class AzureServicePrincipalCredentialsProviderTest {
16+
private static final String TOKEN = "t-123";
17+
private static final String TOKEN_TYPE = "token-type";
18+
public static final String PUBLIC_MANAGEMENT_ENDPOINT = "https://management.core.windows.net/";
19+
20+
private static RefreshableTokenSource mockTokenSource() {
21+
RefreshableTokenSource tokenSource = Mockito.mock(RefreshableTokenSource.class);
22+
Mockito.when(tokenSource.getToken())
23+
.thenReturn(
24+
new Token(TOKEN, TOKEN_TYPE, LocalDateTime.now().plus(1, IsoFields.WEEK_BASED_YEARS)));
25+
return tokenSource;
26+
}
27+
28+
private static AzureServicePrincipalCredentialsProvider
29+
getAzureServicePrincipalCredentialsProvider(RefreshableTokenSource tokenSource) {
30+
AzureServicePrincipalCredentialsProvider provider =
31+
Mockito.spy(new AzureServicePrincipalCredentialsProvider());
32+
Mockito.doReturn(tokenSource)
33+
.when(provider)
34+
.tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID));
35+
Mockito.doReturn(tokenSource)
36+
.when(provider)
37+
.tokenSourceFor(any(), eq(PUBLIC_MANAGEMENT_ENDPOINT));
38+
return provider;
39+
}
40+
41+
@Test
42+
void testGetToken() {
43+
AzureServicePrincipalCredentialsProvider provider =
44+
getAzureServicePrincipalCredentialsProvider(mockTokenSource());
45+
DatabricksConfig config =
46+
new DatabricksConfig()
47+
.setHost(".azuredatabricks.")
48+
.setCredentialsProvider(provider)
49+
.setAzureClientId("clientID")
50+
.setAzureClientSecret("clientSecret")
51+
.setAzureTenantId("tenantID");
52+
53+
HeaderFactory header = provider.configure(config);
54+
55+
String token = header.headers().get("Authorization");
56+
assertEquals(token, "Bearer " + TOKEN);
57+
Mockito.verify(provider, times(1)).tokenSourceFor(any(), eq(ARM_DATABRICKS_RESOURCE_ID));
58+
Mockito.verify(provider, times(1)).tokenSourceFor(any(), eq(PUBLIC_MANAGEMENT_ENDPOINT));
59+
}
60+
}

0 commit comments

Comments
 (0)