diff --git a/databricks/sdk/config.py b/databricks/sdk/config.py
index fa61bdbb9..369613a41 100644
--- a/databricks/sdk/config.py
+++ b/databricks/sdk/config.py
@@ -4,6 +4,7 @@
 import logging
 import os
 import pathlib
+import re
 import sys
 import urllib.parse
 from enum import Enum
@@ -137,6 +138,10 @@ class Config:
     scopes: str = ConfigAttribute()
     authorization_details: str = ConfigAttribute()
 
+    # Controls whether the offline_access scope is requested during U2M OAuth authentication.
+    # offline_access is requested by default, causing a refresh token to be included in the OAuth token.
+    disable_oauth_refresh_token: bool = ConfigAttribute(env="DATABRICKS_DISABLE_OAUTH_REFRESH_TOKEN")
+
     files_ext_client_download_streaming_chunk_size: int = 2 * 1024 * 1024  # 2 MiB
 
     # When downloading a file, the maximum number of attempts to retry downloading the whole file. Default is no limit.
@@ -265,6 +270,7 @@ def __init__(
             self._known_file_config_loader()
             self._fix_host_if_needed()
             self._validate()
+            self._sort_scopes()
             self.init_auth()
             self._init_product(product, product_version)
         except ValueError as e:
@@ -666,6 +672,16 @@ def _validate(self):
         names = " and ".join(sorted(auths_used))
         raise ValueError(f"validate: more than one authorization method configured: {names}")
 
+    def _sort_scopes(self):
+        """Sort scopes in-place for better de-duplication in the refresh token cache.
+        Delimiter is set to a single whitespace after sorting."""
+        if self.scopes and isinstance(self.scopes, str):
+            # Split on whitespaces and commas, sort, and rejoin
+            parsed = [s for s in re.split(r"[\s,]+", self.scopes) if s]
+            if parsed:
+                parsed.sort()
+                self.scopes = " ".join(parsed)
+
     def init_auth(self):
         try:
             self._header_factory = self._credentials_strategy(self)
@@ -685,6 +701,33 @@ def _init_product(self, product, product_version):
         else:
             self._product_info = None
 
+    def get_scopes(self) -> List[str]:
+        """Get OAuth scopes with proper defaulting.
+
+        Returns ["all-apis"] if no scopes configured.
+        This is the single source of truth for scope defaulting across all OAuth methods.
+
+        Parses string scopes by splitting on whitespaces and commas.
+
+        Returns:
+            List of scope strings.
+        """
+        if self.scopes and isinstance(self.scopes, str):
+            parsed = [s for s in re.split(r"[\s,]+", self.scopes) if s]
+            if not parsed:  # Empty string case
+                return ["all-apis"]
+            return parsed
+        return ["all-apis"]
+
+    def get_scopes_as_string(self) -> str:
+        """Get OAuth scopes as a space-separated string.
+
+        Always derived from get_scopes(), so comma- or whitespace-delimited
+        input is normalized to single spaces and blank input falls back to
+        "all-apis".
+        """
+        return " ".join(self.get_scopes())
+
     def __repr__(self):
         return f"<{self.debug_string()}>"
 
diff --git a/databricks/sdk/credentials_provider.py b/databricks/sdk/credentials_provider.py
index 926c50a05..23b691203 100644
--- a/databricks/sdk/credentials_provider.py
+++ b/databricks/sdk/credentials_provider.py
@@ -225,7 +225,7 @@ def oauth_service_principal(cfg: "Config") -> Optional[CredentialsProvider]:
         client_id=cfg.client_id,
         client_secret=cfg.client_secret,
         token_url=oidc.token_endpoint,
-        scopes=cfg.scopes or "all-apis",
+        scopes=cfg.get_scopes_as_string(),
         use_header=True,
         disable_async=cfg.disable_async_token_refresh,
         authorization_details=cfg.authorization_details,
@@ -387,6 +387,7 @@ def oidc_credentials_provider(cfg, id_token_source: oidc.IdTokenSource) -> Optio
         account_id=cfg.account_id,
         id_token_source=id_token_source,
         disable_async=cfg.disable_async_token_refresh,
+        scopes=cfg.get_scopes_as_string(),
     )
 
     def refreshed_headers() -> Dict[str, str]:
@@ -450,7 +451,7 @@ def token_source_for(audience: str) -> oauth.TokenSource:
                 "subject_token": id_token,
                 "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
             },
-            scopes=cfg.scopes or "all-apis",
+            scopes=cfg.get_scopes_as_string(),
             use_params=True,
             disable_async=cfg.disable_async_token_refresh,
             authorization_details=cfg.authorization_details,
diff --git a/databricks/sdk/oidc.py b/databricks/sdk/oidc.py
index b8641a45d..6fd273f2a 100644
--- a/databricks/sdk/oidc.py
+++ b/databricks/sdk/oidc.py
@@ -156,6 +156,7 @@ def __init__(
         account_id: Optional[str] = None,
         audience: Optional[str] = None,
         disable_async: bool = False,
+        scopes: Optional[str] = None,
     ):
         self._host = host
         self._id_token_source = id_token_source
@@ -164,6 +165,7 @@ def __init__(
         self._account_id = account_id
         self._audience = audience
         self._disable_async = disable_async
+        self._scopes = scopes
 
     def token(self) -> oauth.Token:
         """Get a token by exchanging the ID token.
@@ -202,7 +204,7 @@ def _exchange_id_token(self, id_token: IdToken) -> oauth.Token:
                 "subject_token": id_token.jwt,
                 "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
             },
-            scopes="all-apis",
+            scopes=self._scopes or "all-apis",
             use_params=True,
             disable_async=self._disable_async,
         )
diff --git a/tests/test_config.py b/tests/test_config.py
index 00e7540d9..99dd15505 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -12,7 +12,7 @@
                                    with_user_agent_extra)
 from databricks.sdk.version import __version__
 
-from .conftest import noop_credentials, set_az_path
+from .conftest import noop_credentials, set_az_path, set_home
 
 __tests__ = os.path.dirname(__file__)
 
@@ -453,3 +453,42 @@ def test_no_org_id_header_on_regular_workspace(requests_mock):
 
     # Verify the X-Databricks-Org-Id header was NOT added
     assert "X-Databricks-Org-Id" not in requests_mock.last_request.headers
+
+
+def test_disable_oauth_refresh_token_from_env(monkeypatch, mocker):
+    mocker.patch("databricks.sdk.config.Config.init_auth")
+    monkeypatch.setenv("DATABRICKS_DISABLE_OAUTH_REFRESH_TOKEN", "true")
+    config = Config(host="https://test.databricks.com")
+    assert config.disable_oauth_refresh_token is True
+
+
+def test_disable_oauth_refresh_token_unset_by_default(mocker):
+    mocker.patch("databricks.sdk.config.Config.init_auth")
+    config = Config(host="https://test.databricks.com")
+    assert config.disable_oauth_refresh_token is None  # ConfigAttribute returns None when not set
+
+
+def test_config_file_scopes_empty_defaults_to_all_apis(monkeypatch, mocker):
+    """Test that empty scopes in config file defaults to all-apis."""
+    mocker.patch("databricks.sdk.config.Config.init_auth")
+    set_home(monkeypatch, "/testdata")
+    config = Config(profile="scope-empty")
+    assert config.get_scopes() == ["all-apis"]
+
+
+def test_config_file_scopes_single(monkeypatch, mocker):
+    """Test single scope from config file."""
+    mocker.patch("databricks.sdk.config.Config.init_auth")
+    set_home(monkeypatch, "/testdata")
+    config = Config(profile="scope-single")
+    assert config.get_scopes() == ["clusters"]
+
+
+def test_config_file_scopes_multiple_sorted(monkeypatch, mocker):
+    """Test multiple scopes from config file are sorted."""
+    mocker.patch("databricks.sdk.config.Config.init_auth")
+    set_home(monkeypatch, "/testdata")
+    config = Config(profile="scope-multiple")
+    # Should be sorted alphabetically
+    expected = ["clusters", "files:read", "iam:read", "jobs", "mlflow", "model-serving:read", "pipelines"]
+    assert config.get_scopes() == expected
diff --git a/tests/test_notebook_oauth.py b/tests/test_notebook_oauth.py
index 55e5237d6..9e70272ca 100644
--- a/tests/test_notebook_oauth.py
+++ b/tests/test_notebook_oauth.py
@@ -174,7 +174,7 @@ def test_config_authenticate_integration(
 
 @pytest.mark.parametrize(
     "scopes_input,expected_scopes",
-    [(["sql", "offline_access"], "sql offline_access")],
+    [(["sql", "offline_access"], "offline_access sql")],
 )
 def test_workspace_client_integration(
     mock_runtime_env, mock_runtime_native_auth, mock_pat_exchange, scopes_input, expected_scopes
diff --git a/tests/test_oauth_scopes_integration.py b/tests/test_oauth_scopes_integration.py
new file mode 100644
index 000000000..88818c27b
--- /dev/null
+++ b/tests/test_oauth_scopes_integration.py
@@ -0,0 +1,136 @@
+"""Integration tests for OAuth scopes support.
+
+These tests verify that scopes correctly flow through to token endpoints
+across all OAuth authentication methods (M2M, U2M, WIF/OIDC).
+"""
+
+from typing import Optional
+from urllib.parse import parse_qs
+
+import pytest
+
+from databricks.sdk.config import Config
+
+# --- Helper Functions ---
+
+
+def get_scope_from_request(request_text: str) -> Optional[str]:
+    """Extract and return the scope value from a URL-encoded request body."""
+    params = parse_qs(request_text)
+    scope_list = params.get("scope")
+    return scope_list[0] if scope_list else None
+
+
+def get_grant_type_from_request(request_text: str) -> Optional[str]:
+    """Extract and return the grant_type value from a URL-encoded request body."""
+    params = parse_qs(request_text)
+    grant_type_list = params.get("grant_type")
+    return grant_type_list[0] if grant_type_list else None
+
+
+# --- M2M (Machine-to-Machine) Integration Tests ---
+
+
+@pytest.mark.parametrize(
+    "scopes_input,expected_scope",
+    [
+        (None, "all-apis"),
+        ("unity-catalog:read", "unity-catalog:read"),
+        ("jobs:read, clusters, mlflow:read", "clusters jobs:read mlflow:read"),
+    ],
+    ids=[
+        "default_scope",
+        "single_custom_scope",
+        "multiple_scopes_sorted",
+    ],
+)
+def test_m2m_scopes(requests_mock, scopes_input, expected_scope):
+    """Test M2M authentication sends correct scopes to token endpoint."""
+    # Mock the well-known endpoint
+    requests_mock.get(
+        "https://test.databricks.com/oidc/.well-known/oauth-authorization-server",
+        json={
+            "authorization_endpoint": "https://test.databricks.com/oidc/v1/authorize",
+            "token_endpoint": "https://test.databricks.com/oidc/v1/token",
+        },
+    )
+
+    # Mock the token endpoint
+    token_mock = requests_mock.post(
+        "https://test.databricks.com/oidc/v1/token",
+        json={"access_token": "test-token", "token_type": "Bearer", "expires_in": 3600},
+    )
+
+    # Create config with M2M auth
+    config = Config(
+        host="https://test.databricks.com",
+        client_id="test-client-id",
+        client_secret="test-client-secret",
+        auth_type="oauth-m2m",
+        scopes=scopes_input,
+    )
+
+    # Authenticate (triggers token request)
+    headers = config.authenticate()
+
+    # Verify scope was sent correctly
+    assert token_mock.called
+    assert get_scope_from_request(token_mock.last_request.text) == expected_scope
+    assert headers["Authorization"] == "Bearer test-token"
+
+
+# --- WIF/OIDC Integration Tests ---
+
+
+@pytest.mark.parametrize(
+    "scopes_input,expected_scope",
+    [
+        (None, "all-apis"),
+        ("unity-catalog:read, clusters", "clusters unity-catalog:read"),
+        ("jobs:read", "jobs:read"),
+    ],
+    ids=[
+        "default_scope",
+        "multiple_scopes",
+        "single_scope",
+    ],
+)
+def test_oidc_scopes(requests_mock, tmp_path, scopes_input, expected_scope):
+    """Test OIDC token exchange sends correct scopes to token endpoint."""
+    # Create a temporary OIDC token file
+    oidc_token_file = tmp_path / "oidc_token"
+    oidc_token_file.write_text("mock-id-token")
+
+    # Mock the well-known endpoint
+    requests_mock.get(
+        "https://test.databricks.com/oidc/.well-known/oauth-authorization-server",
+        json={
+            "authorization_endpoint": "https://test.databricks.com/oidc/v1/authorize",
+            "token_endpoint": "https://test.databricks.com/oidc/v1/token",
+        },
+    )
+
+    # Mock the token exchange endpoint
+    token_mock = requests_mock.post(
+        "https://test.databricks.com/oidc/v1/token",
+        json={"access_token": "test-token", "token_type": "Bearer", "expires_in": 3600},
+    )
+
+    # Create config with OIDC auth
+    config = Config(
+        host="https://test.databricks.com",
+        oidc_token_filepath=str(oidc_token_file),
+        auth_type="file-oidc",
+        scopes=scopes_input,
+    )
+
+    # Authenticate (triggers token exchange)
+    headers = config.authenticate()
+
+    # Verify scope and grant_type were sent correctly
+    assert token_mock.called
+    assert get_scope_from_request(token_mock.last_request.text) == expected_scope
+    assert (
+        get_grant_type_from_request(token_mock.last_request.text) == "urn:ietf:params:oauth:grant-type:token-exchange"
+    )
+    assert headers["Authorization"] == "Bearer test-token"
diff --git a/tests/testdata/.databrickscfg b/tests/testdata/.databrickscfg
index 2759b6c1b..2ffb627ae 100644
--- a/tests/testdata/.databrickscfg
+++ b/tests/testdata/.databrickscfg
@@ -38,4 +38,15 @@ google_credentials = paw48590aw8e09t8apu
 
 [pat.with.dot]
 host = https://dbc-XXXXXXXX-YYYY.cloud.databricks.com/
-token = PT0+IC9kZXYvdXJhbmRvbSA8PT0KYFZ
\ No newline at end of file
+token = PT0+IC9kZXYvdXJhbmRvbSA8PT0KYFZ
+
+[scope-empty]
+host = https://example.cloud.databricks.com
+
+[scope-single]
+host = https://example.cloud.databricks.com
+scopes = clusters
+
+[scope-multiple]
+host = https://example.cloud.databricks.com
+scopes = clusters, jobs, pipelines, iam:read, files:read, mlflow, model-serving:read