Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions databricks/sdk/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
import os
import pathlib
import re
import sys
import urllib.parse
from enum import Enum
Expand Down Expand Up @@ -137,6 +138,10 @@ class Config:
scopes: str = ConfigAttribute()
authorization_details: str = ConfigAttribute()

# Controls whether the offline_access scope is requested during U2M OAuth authentication.
# By default offline_access is requested, so the issued OAuth token includes a refresh token.
disable_oauth_refresh_token: bool = ConfigAttribute(env="DATABRICKS_DISABLE_OAUTH_REFRESH_TOKEN")

files_ext_client_download_streaming_chunk_size: int = 2 * 1024 * 1024 # 2 MiB

# When downloading a file, the maximum number of attempts to retry downloading the whole file. Default is no limit.
Expand Down Expand Up @@ -265,6 +270,7 @@ def __init__(
self._known_file_config_loader()
self._fix_host_if_needed()
self._validate()
self._sort_scopes()
self.init_auth()
self._init_product(product, product_version)
except ValueError as e:
Expand Down Expand Up @@ -666,6 +672,16 @@ def _validate(self):
names = " and ".join(sorted(auths_used))
raise ValueError(f"validate: more than one authorization method configured: {names}")

def _sort_scopes(self):
    """Normalize ``self.scopes`` in place into a canonical, space-separated form.

    The configured value is split on whitespace and commas, repeated scope
    names are dropped, and the remaining names are sorted and re-joined with a
    single space. Equivalent scope sets therefore always produce the same
    string, which gives better de-duplication in the refresh token cache.

    ``self.scopes`` is left untouched when it is unset, is not a string, or
    contains nothing but delimiters.
    """
    raw = self.scopes
    if not raw or not isinstance(raw, str):
        return
    # A set drops repeated names: "a a b" and "a b" request the same access,
    # so they must normalize to the same string.
    unique = {name for name in re.split(r"[\s,]+", raw) if name}
    if unique:
        self.scopes = " ".join(sorted(unique))

def init_auth(self):
try:
self._header_factory = self._credentials_strategy(self)
Expand All @@ -685,6 +701,33 @@ def _init_product(self, product, product_version):
else:
self._product_info = None

def get_scopes(self) -> List[str]:
    """Return the configured OAuth scopes as a list, with defaulting applied.

    This is the single source of truth for scope defaulting across all OAuth
    methods. A string value is split on whitespace and commas; empty pieces
    are discarded.

    Returns:
        The parsed scope names in their configured order, or ``["all-apis"]``
        when ``self.scopes`` is unset, is not a string, or yields no names.
    """
    configured = self.scopes
    if not configured or not isinstance(configured, str):
        return ["all-apis"]
    # filter(None, ...) discards the empty strings re.split emits around
    # leading/trailing delimiters.
    names = list(filter(None, re.split(r"[\s,]+", configured)))
    return names or ["all-apis"]

def get_scopes_as_string(self) -> str:
    """Return the effective OAuth scopes as a single space-separated string.

    The result is always built from ``get_scopes()``, so it agrees with the
    list form in every case: comma-separated input comes back space-separated,
    and a value made only of delimiters falls back to the default.

    Returns:
        The scope names joined by single spaces, or ``"all-apis"`` if no
        scopes are configured.
    """
    # Do not echo self.scopes directly: a raw value may still contain commas
    # or consist solely of delimiters, which must not reach the token endpoint.
    return " ".join(self.get_scopes())

def __repr__(self):
    """Return the output of ``debug_string()`` wrapped in angle brackets."""
    summary = self.debug_string()
    return f"<{summary}>"

Expand Down
5 changes: 3 additions & 2 deletions databricks/sdk/credentials_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def oauth_service_principal(cfg: "Config") -> Optional[CredentialsProvider]:
client_id=cfg.client_id,
client_secret=cfg.client_secret,
token_url=oidc.token_endpoint,
scopes=cfg.scopes or "all-apis",
scopes=cfg.get_scopes_as_string(),
use_header=True,
disable_async=cfg.disable_async_token_refresh,
authorization_details=cfg.authorization_details,
Expand Down Expand Up @@ -387,6 +387,7 @@ def oidc_credentials_provider(cfg, id_token_source: oidc.IdTokenSource) -> Optio
account_id=cfg.account_id,
id_token_source=id_token_source,
disable_async=cfg.disable_async_token_refresh,
scopes=cfg.get_scopes_as_string(),
)

def refreshed_headers() -> Dict[str, str]:
Expand Down Expand Up @@ -450,7 +451,7 @@ def token_source_for(audience: str) -> oauth.TokenSource:
"subject_token": id_token,
"grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
},
scopes=cfg.scopes or "all-apis",
scopes=cfg.get_scopes_as_string(),
use_params=True,
disable_async=cfg.disable_async_token_refresh,
authorization_details=cfg.authorization_details,
Expand Down
4 changes: 3 additions & 1 deletion databricks/sdk/oidc.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ def __init__(
account_id: Optional[str] = None,
audience: Optional[str] = None,
disable_async: bool = False,
scopes: Optional[str] = None,
):
self._host = host
self._id_token_source = id_token_source
Expand All @@ -164,6 +165,7 @@ def __init__(
self._account_id = account_id
self._audience = audience
self._disable_async = disable_async
self._scopes = scopes

def token(self) -> oauth.Token:
"""Get a token by exchanging the ID token.
Expand Down Expand Up @@ -202,7 +204,7 @@ def _exchange_id_token(self, id_token: IdToken) -> oauth.Token:
"subject_token": id_token.jwt,
"grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
},
scopes="all-apis",
scopes=self._scopes,
use_params=True,
disable_async=self._disable_async,
)
Expand Down
41 changes: 40 additions & 1 deletion tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
with_user_agent_extra)
from databricks.sdk.version import __version__

from .conftest import noop_credentials, set_az_path
from .conftest import noop_credentials, set_az_path, set_home

__tests__ = os.path.dirname(__file__)

Expand Down Expand Up @@ -453,3 +453,42 @@ def test_no_org_id_header_on_regular_workspace(requests_mock):

# Verify the X-Databricks-Org-Id header was NOT added
assert "X-Databricks-Org-Id" not in requests_mock.last_request.headers


def test_disable_oauth_refresh_token_from_env(monkeypatch, mocker):
    """DATABRICKS_DISABLE_OAUTH_REFRESH_TOKEN=true is read into the config attribute as True."""
    monkeypatch.setenv("DATABRICKS_DISABLE_OAUTH_REFRESH_TOKEN", "true")
    # Replace init_auth with a mock so building the Config does not run it.
    mocker.patch("databricks.sdk.config.Config.init_auth")

    cfg = Config(host="https://test.databricks.com")

    assert cfg.disable_oauth_refresh_token is True


def test_disable_oauth_refresh_token_defaults_to_false(mocker):
    """With nothing configured the attribute is unset: it reads as None (falsy), not False."""
    # Replace init_auth with a mock so building the Config does not run it.
    mocker.patch("databricks.sdk.config.Config.init_auth")

    cfg = Config(host="https://test.databricks.com")

    # ConfigAttribute returns None when not set
    assert cfg.disable_oauth_refresh_token is None


def test_config_file_scopes_empty_defaults_to_all_apis(monkeypatch, mocker):
    """A profile without a scopes entry resolves to the all-apis default."""
    set_home(monkeypatch, "/testdata")
    # Replace init_auth with a mock so building the Config does not run it.
    mocker.patch("databricks.sdk.config.Config.init_auth")

    cfg = Config(profile="scope-empty")

    assert cfg.get_scopes() == ["all-apis"]


def test_config_file_scopes_single(monkeypatch, mocker):
    """A profile with exactly one scope yields a one-element list."""
    set_home(monkeypatch, "/testdata")
    # Replace init_auth with a mock so building the Config does not run it.
    mocker.patch("databricks.sdk.config.Config.init_auth")

    cfg = Config(profile="scope-single")

    assert cfg.get_scopes() == ["clusters"]


def test_config_file_scopes_multiple_sorted(monkeypatch, mocker):
    """A comma-separated scope list from the config file comes back sorted."""
    set_home(monkeypatch, "/testdata")
    # Replace init_auth with a mock so building the Config does not run it.
    mocker.patch("databricks.sdk.config.Config.init_auth")

    cfg = Config(profile="scope-multiple")

    # The profile lists the scopes unsorted; the config is expected to return
    # them in alphabetical order.
    assert cfg.get_scopes() == [
        "clusters",
        "files:read",
        "iam:read",
        "jobs",
        "mlflow",
        "model-serving:read",
        "pipelines",
    ]
2 changes: 1 addition & 1 deletion tests/test_notebook_oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def test_config_authenticate_integration(

@pytest.mark.parametrize(
"scopes_input,expected_scopes",
[(["sql", "offline_access"], "sql offline_access")],
[(["sql", "offline_access"], "offline_access sql")],
)
def test_workspace_client_integration(
mock_runtime_env, mock_runtime_native_auth, mock_pat_exchange, scopes_input, expected_scopes
Expand Down
136 changes: 136 additions & 0 deletions tests/test_oauth_scopes_integration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
"""Integration tests for OAuth scopes support.

These tests verify that the configured scopes are sent to the token endpoint
for the M2M and WIF/OIDC authentication methods.
"""

from typing import Optional
from urllib.parse import parse_qs

import pytest

from databricks.sdk.config import Config

# --- Helper Functions ---


def get_scope_from_request(request_text: str) -> Optional[str]:
    """Return the first ``scope`` value of a URL-encoded request body.

    Returns None when the body has no non-blank ``scope`` parameter.
    """
    values = parse_qs(request_text).get("scope")
    if not values:
        return None
    return values[0]


def get_grant_type_from_request(request_text: str) -> Optional[str]:
    """Return the first ``grant_type`` value of a URL-encoded request body.

    Returns None when the body has no non-blank ``grant_type`` parameter.
    """
    # Fall back to a one-element [None] so the index below covers the missing case too.
    values = parse_qs(request_text).get("grant_type") or [None]
    return values[0]


# --- M2M (Machine-to-Machine) Integration Tests ---


@pytest.mark.parametrize(
    "scopes_input,expected_scope",
    [
        (None, "all-apis"),
        ("unity-catalog:read", "unity-catalog:read"),
        ("jobs:read, clusters, mlflow:read", "clusters jobs:read mlflow:read"),
    ],
    ids=[
        "default_scope",
        "single_custom_scope",
        "multiple_scopes_sorted",
    ],
)
def test_m2m_scopes(requests_mock, scopes_input, expected_scope):
    """Check that M2M authentication forwards the expected scope string to the token endpoint."""
    # Discovery document that points at the mocked token endpoint below.
    requests_mock.get(
        "https://test.databricks.com/oidc/.well-known/oauth-authorization-server",
        json={
            "authorization_endpoint": "https://test.databricks.com/oidc/v1/authorize",
            "token_endpoint": "https://test.databricks.com/oidc/v1/token",
        },
    )

    # Token endpoint stub; the request it records is inspected after authenticating.
    token_endpoint = requests_mock.post(
        "https://test.databricks.com/oidc/v1/token",
        json={"access_token": "test-token", "token_type": "Bearer", "expires_in": 3600},
    )

    # Config set up for M2M auth with the parametrized scopes.
    cfg = Config(
        host="https://test.databricks.com",
        client_id="test-client-id",
        client_secret="test-client-secret",
        auth_type="oauth-m2m",
        scopes=scopes_input,
    )

    # Authenticating triggers the token request.
    auth_headers = cfg.authenticate()

    assert token_endpoint.called
    sent_scope = get_scope_from_request(token_endpoint.last_request.text)
    assert sent_scope == expected_scope
    assert auth_headers["Authorization"] == "Bearer test-token"


# --- WIF/OIDC Integration Tests ---


@pytest.mark.parametrize(
    "scopes_input,expected_scope",
    [
        (None, "all-apis"),
        ("unity-catalog:read, clusters", "clusters unity-catalog:read"),
        ("jobs:read", "jobs:read"),
    ],
    ids=[
        "default_scope",
        "multiple_scopes",
        "single_scope",
    ],
)
def test_oidc_scopes(requests_mock, tmp_path, scopes_input, expected_scope):
    """Check that the OIDC token exchange forwards the expected scope string and grant type."""
    # ID token file that the file-oidc auth type reads from.
    id_token_path = tmp_path / "oidc_token"
    id_token_path.write_text("mock-id-token")

    # Discovery document that points at the mocked token endpoint below.
    requests_mock.get(
        "https://test.databricks.com/oidc/.well-known/oauth-authorization-server",
        json={
            "authorization_endpoint": "https://test.databricks.com/oidc/v1/authorize",
            "token_endpoint": "https://test.databricks.com/oidc/v1/token",
        },
    )

    # Token exchange stub; the request it records is inspected after authenticating.
    token_endpoint = requests_mock.post(
        "https://test.databricks.com/oidc/v1/token",
        json={"access_token": "test-token", "token_type": "Bearer", "expires_in": 3600},
    )

    # Config set up for file-based OIDC auth with the parametrized scopes.
    cfg = Config(
        host="https://test.databricks.com",
        oidc_token_filepath=str(id_token_path),
        auth_type="file-oidc",
        scopes=scopes_input,
    )

    # Authenticating triggers the token exchange.
    auth_headers = cfg.authenticate()

    assert token_endpoint.called
    sent_scope = get_scope_from_request(token_endpoint.last_request.text)
    assert sent_scope == expected_scope
    sent_grant_type = get_grant_type_from_request(token_endpoint.last_request.text)
    assert sent_grant_type == "urn:ietf:params:oauth:grant-type:token-exchange"
    assert auth_headers["Authorization"] == "Bearer test-token"
13 changes: 12 additions & 1 deletion tests/testdata/.databrickscfg
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,15 @@ google_credentials = paw48590aw8e09t8apu

[pat.with.dot]
host = https://dbc-XXXXXXXX-YYYY.cloud.databricks.com/
token = PT0+IC9kZXYvdXJhbmRvbSA8PT0KYFZ
token = PT0+IC9kZXYvdXJhbmRvbSA8PT0KYFZ

[scope-empty]
host = https://example.cloud.databricks.com

[scope-single]
host = https://example.cloud.databricks.com
scopes = clusters

[scope-multiple]
host = https://example.cloud.databricks.com
scopes = clusters, jobs, pipelines, iam:read, files:read, mlflow, model-serving:read
Loading