Skip to content

Commit 18b702b

Browse files
committed
custom scopes support in m2m and wif
1 parent 1742e6a commit 18b702b

File tree

3 files changed

+142
-3
lines changed

3 files changed

+142
-3
lines changed

databricks/sdk/credentials_provider.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ def oauth_service_principal(cfg: "Config") -> Optional[CredentialsProvider]:
225225
client_id=cfg.client_id,
226226
client_secret=cfg.client_secret,
227227
token_url=oidc.token_endpoint,
228-
scopes=cfg.scopes or "all-apis",
228+
scopes=cfg.get_scopes_as_string(),
229229
use_header=True,
230230
disable_async=cfg.disable_async_token_refresh,
231231
authorization_details=cfg.authorization_details,
@@ -387,6 +387,7 @@ def oidc_credentials_provider(cfg, id_token_source: oidc.IdTokenSource) -> Optio
387387
account_id=cfg.account_id,
388388
id_token_source=id_token_source,
389389
disable_async=cfg.disable_async_token_refresh,
390+
scopes=cfg.get_scopes_as_string(),
390391
)
391392

392393
def refreshed_headers() -> Dict[str, str]:
@@ -450,7 +451,7 @@ def token_source_for(audience: str) -> oauth.TokenSource:
450451
"subject_token": id_token,
451452
"grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
452453
},
453-
scopes=cfg.scopes or "all-apis",
454+
scopes=cfg.get_scopes_as_string(),
454455
use_params=True,
455456
disable_async=cfg.disable_async_token_refresh,
456457
authorization_details=cfg.authorization_details,

databricks/sdk/oidc.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ def __init__(
156156
account_id: Optional[str] = None,
157157
audience: Optional[str] = None,
158158
disable_async: bool = False,
159+
scopes: Optional[str] = None,
159160
):
160161
self._host = host
161162
self._id_token_source = id_token_source
@@ -164,6 +165,7 @@ def __init__(
164165
self._account_id = account_id
165166
self._audience = audience
166167
self._disable_async = disable_async
168+
self._scopes = scopes
167169

168170
def token(self) -> oauth.Token:
169171
"""Get a token by exchanging the ID token.
@@ -202,7 +204,7 @@ def _exchange_id_token(self, id_token: IdToken) -> oauth.Token:
202204
"subject_token": id_token.jwt,
203205
"grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
204206
},
205-
scopes="all-apis",
207+
scopes=self._scopes,
206208
use_params=True,
207209
disable_async=self._disable_async,
208210
)
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
"""Integration tests for OAuth scopes support.
2+
3+
These tests verify that scopes correctly flow through to token endpoints
4+
across all OAuth authentication methods (M2M, U2M, WIF/OIDC).
5+
"""
6+
7+
from typing import Optional
8+
from urllib.parse import parse_qs
9+
10+
import pytest
11+
12+
from databricks.sdk.config import Config
13+
14+
# --- Helper Functions ---
15+
16+
17+
def get_scope_from_request(request_text: str) -> Optional[str]:
18+
"""Extract and return the scope value from a URL-encoded request body."""
19+
params = parse_qs(request_text)
20+
scope_list = params.get("scope")
21+
return scope_list[0] if scope_list else None
22+
23+
24+
def get_grant_type_from_request(request_text: str) -> Optional[str]:
25+
"""Extract and return the grant_type value from a URL-encoded request body."""
26+
params = parse_qs(request_text)
27+
grant_type_list = params.get("grant_type")
28+
return grant_type_list[0] if grant_type_list else None
29+
30+
31+
# --- M2M (Machine-to-Machine) Integration Tests ---
32+
33+
34+
@pytest.mark.parametrize(
35+
"scopes_input,expected_scope",
36+
[
37+
(None, "all-apis"),
38+
("unity-catalog:read", "unity-catalog:read"),
39+
("jobs:read, clusters, mlflow:read", "clusters jobs:read mlflow:read"),
40+
],
41+
ids=[
42+
"default_scope",
43+
"single_custom_scope",
44+
"multiple_scopes_sorted",
45+
],
46+
)
47+
def test_m2m_scopes(requests_mock, scopes_input, expected_scope):
48+
"""Test M2M authentication sends correct scopes to token endpoint."""
49+
# Mock the well-known endpoint
50+
requests_mock.get(
51+
"https://test.databricks.com/oidc/.well-known/oauth-authorization-server",
52+
json={
53+
"authorization_endpoint": "https://test.databricks.com/oidc/v1/authorize",
54+
"token_endpoint": "https://test.databricks.com/oidc/v1/token",
55+
},
56+
)
57+
58+
# Mock the token endpoint
59+
token_mock = requests_mock.post(
60+
"https://test.databricks.com/oidc/v1/token",
61+
json={"access_token": "test-token", "token_type": "Bearer", "expires_in": 3600},
62+
)
63+
64+
# Create config with M2M auth
65+
config = Config(
66+
host="https://test.databricks.com",
67+
client_id="test-client-id",
68+
client_secret="test-client-secret",
69+
auth_type="oauth-m2m",
70+
scopes=scopes_input,
71+
)
72+
73+
# Authenticate (triggers token request)
74+
headers = config.authenticate()
75+
76+
# Verify scope was sent correctly
77+
assert token_mock.called
78+
assert get_scope_from_request(token_mock.last_request.text) == expected_scope
79+
assert headers["Authorization"] == "Bearer test-token"
80+
81+
82+
# --- WIF/OIDC Integration Tests ---
83+
84+
85+
@pytest.mark.parametrize(
86+
"scopes_input,expected_scope",
87+
[
88+
(None, "all-apis"),
89+
("unity-catalog:read, clusters", "clusters unity-catalog:read"),
90+
("jobs:read", "jobs:read"),
91+
],
92+
ids=[
93+
"default_scope",
94+
"multiple_scopes",
95+
"single_scope",
96+
],
97+
)
98+
def test_oidc_scopes(requests_mock, tmp_path, scopes_input, expected_scope):
99+
"""Test OIDC token exchange sends correct scopes to token endpoint."""
100+
# Create a temporary OIDC token file
101+
oidc_token_file = tmp_path / "oidc_token"
102+
oidc_token_file.write_text("mock-id-token")
103+
104+
# Mock the well-known endpoint
105+
requests_mock.get(
106+
"https://test.databricks.com/oidc/.well-known/oauth-authorization-server",
107+
json={
108+
"authorization_endpoint": "https://test.databricks.com/oidc/v1/authorize",
109+
"token_endpoint": "https://test.databricks.com/oidc/v1/token",
110+
},
111+
)
112+
113+
# Mock the token exchange endpoint
114+
token_mock = requests_mock.post(
115+
"https://test.databricks.com/oidc/v1/token",
116+
json={"access_token": "test-token", "token_type": "Bearer", "expires_in": 3600},
117+
)
118+
119+
# Create config with OIDC auth
120+
config = Config(
121+
host="https://test.databricks.com",
122+
oidc_token_filepath=str(oidc_token_file),
123+
auth_type="file-oidc",
124+
scopes=scopes_input,
125+
)
126+
127+
# Authenticate (triggers token exchange)
128+
headers = config.authenticate()
129+
130+
# Verify scope and grant_type were sent correctly
131+
assert token_mock.called
132+
assert get_scope_from_request(token_mock.last_request.text) == expected_scope
133+
assert (
134+
get_grant_type_from_request(token_mock.last_request.text) == "urn:ietf:params:oauth:grant-type:token-exchange"
135+
)
136+
assert headers["Authorization"] == "Bearer test-token"

0 commit comments

Comments
 (0)