diff --git a/databricks/sdk/__init__.py b/databricks/sdk/__init__.py index dcf680a8c..e4485a10a 100644 --- a/databricks/sdk/__init__.py +++ b/databricks/sdk/__init__.py @@ -253,7 +253,7 @@ def __init__( product=product, product_version=product_version, token_audience=token_audience, - scopes=" ".join(scopes) if scopes else None, + scopes=scopes, authorization_details=( json.dumps([detail.as_dict() for detail in authorization_details]) if authorization_details diff --git a/databricks/sdk/config.py b/databricks/sdk/config.py index fa61bdbb9..bfc241855 100644 --- a/databricks/sdk/config.py +++ b/databricks/sdk/config.py @@ -46,11 +46,13 @@ class ConfigAttribute: # name and transform are discovered from Config.__new__ name: str = None transform: type = str + _custom_transform = None - def __init__(self, env: str = None, auth: str = None, sensitive: bool = False): + def __init__(self, env: str = None, auth: str = None, sensitive: bool = False, transform=None): self.env = env self.auth = auth self.sensitive = sensitive + self._custom_transform = transform def __get__(self, cfg: "Config", owner): if not cfg: @@ -64,6 +66,19 @@ def __repr__(self) -> str: return f"" +def _parse_scopes(value): + """Parse scopes into a deduplicated, sorted list.""" + if value is None: + return None + if isinstance(value, list): + result = sorted(set(s for s in value if s)) + return result if result else None + if isinstance(value, str): + parsed = sorted(set(s.strip() for s in value.split(",") if s.strip())) + return parsed if parsed else None + return None + + def with_product(product: str, product_version: str): """[INTERNAL API] Change the product name and version used in the User-Agent header.""" useragent.with_product(product, product_version) @@ -133,10 +148,14 @@ class Config: disable_experimental_files_api_client: bool = ConfigAttribute( env="DATABRICKS_DISABLE_EXPERIMENTAL_FILES_API_CLIENT" ) - # TODO: Expose these via environment variables too. - scopes: str = ConfigAttribute() + + scopes: list = ConfigAttribute(transform=_parse_scopes) authorization_details: str = ConfigAttribute() + # disable_oauth_refresh_token controls whether a refresh token should be requested + # during the U2M authentication flow (default to false). + disable_oauth_refresh_token: bool = ConfigAttribute(env="DATABRICKS_DISABLE_OAUTH_REFRESH_TOKEN") + files_ext_client_download_streaming_chunk_size: int = 2 * 1024 * 1024 # 2 MiB # When downloading a file, the maximum number of attempts to retry downloading the whole file. Default is no limit. @@ -553,7 +572,7 @@ def attributes(cls) -> Iterable[ConfigAttribute]: if type(v) != ConfigAttribute: continue v.name = name - v.transform = anno.get(name, str) + v.transform = v._custom_transform if v._custom_transform else anno.get(name, str) attrs.append(v) cls._attributes = attrs return cls._attributes @@ -685,6 +704,21 @@ def _init_product(self, product, product_version): else: self._product_info = None + def get_scopes(self) -> list: + """Get OAuth scopes with proper defaulting. + + Returns ["all-apis"] if no scopes configured. + This is the single source of truth for scope defaulting across all OAuth methods. + """ + return self.scopes if self.scopes else ["all-apis"] + + def get_scopes_as_string(self) -> str: + """Get OAuth scopes as a space-separated string. + + Returns "all-apis" if no scopes configured. + """ + return " ".join(self.get_scopes()) + def __repr__(self): return f"<{self.debug_string()}>" diff --git a/databricks/sdk/credentials_provider.py b/databricks/sdk/credentials_provider.py index 926c50a05..ebf6fa2bd 100644 --- a/databricks/sdk/credentials_provider.py +++ b/databricks/sdk/credentials_provider.py @@ -198,7 +198,7 @@ def get_notebook_pat_token() -> Optional[str]: token_source = oauth.PATOAuthTokenExchange( get_original_token=get_notebook_pat_token, host=cfg.host, - scopes=cfg.scopes, + scopes=cfg.get_scopes_as_string(), authorization_details=cfg.authorization_details, ) @@ -329,7 +329,7 @@ def token_source_for(resource: str) -> oauth.TokenSource: endpoint_params={"resource": resource}, use_params=True, disable_async=cfg.disable_async_token_refresh, - scopes=cfg.scopes, + scopes=cfg.get_scopes_as_string(), authorization_details=cfg.authorization_details, ) @@ -533,7 +533,7 @@ def github_oidc_azure(cfg: "Config") -> Optional[CredentialsProvider]: }, use_params=True, disable_async=cfg.disable_async_token_refresh, - scopes=cfg.scopes, + scopes=cfg.get_scopes_as_string(), authorization_details=cfg.authorization_details, ) diff --git a/tests/test_config.py b/tests/test_config.py index 00e7540d9..99dd15505 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -12,7 +12,7 @@ with_user_agent_extra) from databricks.sdk.version import __version__ -from .conftest import noop_credentials, set_az_path +from .conftest import noop_credentials, set_az_path, set_home __tests__ = os.path.dirname(__file__) @@ -453,3 +453,42 @@ def test_no_org_id_header_on_regular_workspace(requests_mock): # Verify the X-Databricks-Org-Id header was NOT added assert "X-Databricks-Org-Id" not in requests_mock.last_request.headers + + +def test_disable_oauth_refresh_token_from_env(monkeypatch, mocker): + mocker.patch("databricks.sdk.config.Config.init_auth") + monkeypatch.setenv("DATABRICKS_DISABLE_OAUTH_REFRESH_TOKEN", "true") + config = Config(host="https://test.databricks.com") + assert config.disable_oauth_refresh_token is True + + +def test_disable_oauth_refresh_token_defaults_to_false(mocker): + mocker.patch("databricks.sdk.config.Config.init_auth") + config = Config(host="https://test.databricks.com") + assert config.disable_oauth_refresh_token is None # ConfigAttribute returns None when not set + + +def test_config_file_scopes_empty_defaults_to_all_apis(monkeypatch, mocker): + """Test that empty scopes in config file defaults to all-apis.""" + mocker.patch("databricks.sdk.config.Config.init_auth") + set_home(monkeypatch, "/testdata") + config = Config(profile="scope-empty") + assert config.get_scopes() == ["all-apis"] + + +def test_config_file_scopes_single(monkeypatch, mocker): + """Test single scope from config file.""" + mocker.patch("databricks.sdk.config.Config.init_auth") + set_home(monkeypatch, "/testdata") + config = Config(profile="scope-single") + assert config.get_scopes() == ["clusters"] + + +def test_config_file_scopes_multiple_sorted(monkeypatch, mocker): + """Test multiple scopes from config file are sorted.""" + mocker.patch("databricks.sdk.config.Config.init_auth") + set_home(monkeypatch, "/testdata") + config = Config(profile="scope-multiple") + # Should be sorted alphabetically + expected = ["clusters", "files:read", "iam:read", "jobs", "mlflow", "model-serving:read", "pipelines"] + assert config.get_scopes() == expected diff --git a/tests/test_notebook_oauth.py b/tests/test_notebook_oauth.py index 55e5237d6..0338f7345 100644 --- a/tests/test_notebook_oauth.py +++ b/tests/test_notebook_oauth.py @@ -78,10 +78,10 @@ def credentials_provider() -> Dict[str, str]: @pytest.mark.parametrize( "scopes,auth_details", [ - ("sql offline_access", None), - ("sql offline_access", '{"type": "databricks_resource"}'), + ("sql, offline_access", None), + ("sql, offline_access", '{"type": "databricks_resource"}'), ("sql", None), - ("sql offline_access all-apis", None), + ("sql, offline_access, all-apis", None), ], ) def test_runtime_oauth_success_scenarios( @@ -117,7 +117,7 @@ def test_runtime_oauth_missing_scopes(mock_runtime_env, mock_runtime_native_auth def test_runtime_oauth_priority_over_native_auth(mock_runtime_env, mock_runtime_native_auth, mock_pat_exchange): """Test that runtime-oauth is prioritized over runtime-native-auth.""" - cfg = Config(host="https://test.cloud.databricks.com", scopes="sql offline_access") + cfg = Config(host="https://test.cloud.databricks.com", scopes="sql, offline_access") default_creds = DefaultCredentials() creds_provider = default_creds(cfg) @@ -141,7 +141,7 @@ def test_fallback_to_native_auth_without_scopes(mock_runtime_env, mock_runtime_n def test_explicit_runtime_oauth_auth_type(mock_runtime_env, mock_runtime_native_auth, mock_pat_exchange): """Test that runtime-oauth is used when explicitly specified as auth_type.""" - cfg = Config(host="https://test.cloud.databricks.com", scopes="sql offline_access", auth_type="runtime-oauth") + cfg = Config(host="https://test.cloud.databricks.com", scopes="sql, offline_access", auth_type="runtime-oauth") default_creds = DefaultCredentials() creds_provider = default_creds(cfg) @@ -164,7 +164,7 @@ def test_config_authenticate_integration( """Test Config.authenticate() integration with runtime-oauth and fallback.""" cfg_kwargs = {"host": "https://test.cloud.databricks.com"} if has_scopes: - cfg_kwargs["scopes"] = "sql offline_access" + cfg_kwargs["scopes"] = "sql, offline_access" cfg = Config(**cfg_kwargs) headers = cfg.authenticate() @@ -174,7 +174,7 @@ def test_config_authenticate_integration( @pytest.mark.parametrize( "scopes_input,expected_scopes", - [(["sql", "offline_access"], "sql offline_access")], + [(["sql", "offline_access"], ["offline_access", "sql"])], ) def test_workspace_client_integration( mock_runtime_env, mock_runtime_native_auth, mock_pat_exchange, scopes_input, expected_scopes diff --git a/tests/testdata/.databrickscfg b/tests/testdata/.databrickscfg index 2759b6c1b..2ffb627ae 100644 --- a/tests/testdata/.databrickscfg +++ b/tests/testdata/.databrickscfg @@ -38,4 +38,15 @@ google_credentials = paw48590aw8e09t8apu [pat.with.dot] host = https://dbc-XXXXXXXX-YYYY.cloud.databricks.com/ -token = PT0+IC9kZXYvdXJhbmRvbSA8PT0KYFZ \ No newline at end of file +token = PT0+IC9kZXYvdXJhbmRvbSA8PT0KYFZ + +[scope-empty] +host = https://example.cloud.databricks.com + +[scope-single] +host = https://example.cloud.databricks.com +scopes = clusters + +[scope-multiple] +host = https://example.cloud.databricks.com +scopes = clusters, jobs, pipelines, iam:read, files:read, mlflow, model-serving:read