diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md index c6b4db8a5..bf926b054 100644 --- a/NEXT_CHANGELOG.md +++ b/NEXT_CHANGELOG.md @@ -4,6 +4,7 @@ ### New Features and Improvements * Add new auth type (`runtime-oauth`) for notebooks: Introduce a new authentication mechanism that allows notebooks to authenticate using OAuth tokens +* Add support for SPOG hosts with experimental flag ### Security diff --git a/databricks/sdk/__init__.py b/databricks/sdk/__init__.py index 8067265a3..d8637be72 100755 --- a/databricks/sdk/__init__.py +++ b/databricks/sdk/__init__.py @@ -222,6 +222,7 @@ def __init__( config: Optional[client.Config] = None, scopes: Optional[List[str]] = None, authorization_details: Optional[List[AuthorizationDetail]] = None, + is_unified_host: Optional[bool] = False, ): if not config: config = client.Config( diff --git a/databricks/sdk/config.py b/databricks/sdk/config.py index bbb490ac7..70cc421f7 100644 --- a/databricks/sdk/config.py +++ b/databricks/sdk/config.py @@ -6,6 +6,7 @@ import pathlib import sys import urllib.parse +from enum import Enum from typing import Dict, Iterable, List, Optional import requests @@ -19,11 +20,19 @@ DatabricksEnvironment, get_environment_for_hostname) from .oauth import (OidcEndpoints, Token, get_account_endpoints, get_azure_entra_id_workspace_endpoints, - get_workspace_endpoints) + get_unified_endpoints, get_workspace_endpoints) logger = logging.getLogger("databricks.sdk") +class HostType(Enum): + """Enum representing the type of Databricks host.""" + + ACCOUNTS = "accounts" + WORKSPACE = "workspace" + UNIFIED = "unified" + + class ConfigAttribute: """Configuration attribute metadata and descriptor protocols.""" @@ -61,6 +70,10 @@ def with_user_agent_extra(key: str, value: str): class Config: host: str = ConfigAttribute(env="DATABRICKS_HOST") account_id: str = ConfigAttribute(env="DATABRICKS_ACCOUNT_ID") + workspace_id: str = ConfigAttribute(env="DATABRICKS_WORKSPACE_ID") + + # Experimental flag to indicate if the host is a unified host (supports both workspace and account APIs) + experimental_is_unified_host: bool = ConfigAttribute(env="DATABRICKS_EXPERIMENTAL_IS_UNIFIED_HOST") # PAT token. token: str = ConfigAttribute(env="DATABRICKS_TOKEN", auth="pat", sensitive=True) @@ -339,10 +352,27 @@ def is_aws(self) -> bool: return self.environment.cloud == Cloud.AWS @property - def is_account_client(self) -> bool: + def host_type(self) -> HostType: + """Determine the type of host based on the configuration. + + Returns the HostType which can be ACCOUNTS, WORKSPACE, or UNIFIED. + """ if not self.host: - return False - return self.host.startswith("https://accounts.") or self.host.startswith("https://accounts-dod.") + logger.debug(f"Host type: {HostType.WORKSPACE.value} (no host configured)") + return HostType.WORKSPACE + + # Check if explicitly marked as unified host + if self.experimental_is_unified_host is True: + logger.debug(f"Host type: {HostType.UNIFIED.value} (experimental flag set)") + return HostType.UNIFIED + + # Check for accounts host pattern + if self.host.startswith("https://accounts.") or self.host.startswith("https://accounts-dod."): + logger.debug(f"Host type: {HostType.ACCOUNTS.value} (accounts URL pattern)") + return HostType.ACCOUNTS + + logger.debug(f"Host type: {HostType.WORKSPACE.value} (default)") + return HostType.WORKSPACE @property def arm_environment(self) -> AzureEnvironment: @@ -394,8 +424,23 @@ def oidc_endpoints(self) -> Optional[OidcEndpoints]: return None if self.is_azure and self.azure_client_id: return get_azure_entra_id_workspace_endpoints(self.host) - if self.is_account_client and self.account_id: + + # Handle unified hosts + if self.host_type == HostType.UNIFIED: + if self.workspace_id: + return get_unified_endpoints(self.host, self.workspace_id) + elif self.account_id: + return get_account_endpoints(self.host, self.account_id) + else: + raise ValueError( + "Unified host requires either workspace_id (for workspace client) or account_id (for account client)" + ) + + # Handle traditional account hosts + if self.host_type == HostType.ACCOUNTS and self.account_id: return get_account_endpoints(self.host, self.account_id) + + # Default to workspace endpoints return get_workspace_endpoints(self.host) def debug_string(self) -> str: diff --git a/databricks/sdk/core.py b/databricks/sdk/core.py index 92e3dbf89..876ecf5c6 100644 --- a/databricks/sdk/core.py +++ b/databricks/sdk/core.py @@ -22,11 +22,20 @@ class ApiClient: def __init__(self, cfg: Config): self._cfg = cfg + + # Create header factory that includes both auth and org ID headers + def combined_header_factory(): + headers = cfg.authenticate() + # Add X-Databricks-Org-Id header for workspace clients on unified hosts + if cfg.workspace_id and cfg.host_type.value == "unified": + headers["X-Databricks-Org-Id"] = cfg.workspace_id + return headers + self._api_client = _BaseClient( debug_truncate_bytes=cfg.debug_truncate_bytes, retry_timeout_seconds=cfg.retry_timeout_seconds, user_agent_base=cfg.user_agent, - header_factory=cfg.authenticate, + header_factory=combined_header_factory, max_connection_pools=cfg.max_connection_pools, max_connections_per_pool=cfg.max_connections_per_pool, pool_block=True, @@ -39,10 +48,6 @@ def __init__(self, cfg: Config): def account_id(self) -> str: return self._cfg.account_id - @property - def is_account_client(self) -> bool: - return self._cfg.is_account_client - def get_oauth_token(self, auth_details: str) -> Token: if not self._cfg.auth_type: self._cfg.authenticate() diff --git a/databricks/sdk/credentials_provider.py b/databricks/sdk/credentials_provider.py index 926c50a05..c21fc46d5 100644 --- a/databricks/sdk/credentials_provider.py +++ b/databricks/sdk/credentials_provider.py @@ -21,6 +21,7 @@ from google.oauth2 import service_account # type: ignore from . import azure, oauth, oidc, oidc_token_supplier +from .config import HostType CredentialsProvider = Callable[[], Dict[str, str]] @@ -422,9 +423,9 @@ def _oidc_credentials_provider( # Determine the audience for token exchange audience = cfg.token_audience - if audience is None and cfg.is_account_client: + if audience is None and cfg.host_type == HostType.ACCOUNTS: audience = cfg.account_id - if audience is None and not cfg.is_account_client: + if audience is None and cfg.host_type != HostType.ACCOUNTS: audience = cfg.oidc_endpoints.token_endpoint # Try to get an OIDC token. If no supplier returns a token, we cannot use this authentication mode. @@ -581,7 +582,7 @@ def token() -> oauth.Token: def refreshed_headers() -> Dict[str, str]: credentials.refresh(request) headers = {"Authorization": f"Bearer {credentials.token}"} - if cfg.is_account_client: + if cfg.host_type == HostType.ACCOUNTS: gcp_credentials.refresh(request) headers["X-Databricks-GCP-SA-Access-Token"] = gcp_credentials.token return headers @@ -622,7 +623,7 @@ def token() -> oauth.Token: def refreshed_headers() -> Dict[str, str]: id_creds.refresh(request) headers = {"Authorization": f"Bearer {id_creds.token}"} - if cfg.is_account_client: + if cfg.host_type == HostType.ACCOUNTS: gcp_impersonated_credentials.refresh(request) headers["X-Databricks-GCP-SA-Access-Token"] = gcp_impersonated_credentials.token return headers @@ -844,7 +845,7 @@ class DatabricksCliTokenSource(CliTokenSource): def __init__(self, cfg: "Config"): args = ["auth", "token", "--host", cfg.host] - if cfg.is_account_client: + if cfg.host_type == HostType.ACCOUNTS: args += ["--account-id", cfg.account_id] cli_path = cfg.databricks_cli_path diff --git a/databricks/sdk/oauth.py b/databricks/sdk/oauth.py index 72681669f..b842c1a17 100644 --- a/databricks/sdk/oauth.py +++ b/databricks/sdk/oauth.py @@ -418,6 +418,19 @@ def get_workspace_endpoints(host: str, client: _BaseClient = _BaseClient()) -> O return OidcEndpoints.from_dict(resp) +def get_unified_endpoints(host: str, workspace_id: str, client: _BaseClient = _BaseClient()) -> OidcEndpoints: + """ + Get the OIDC endpoints for a unified host with a specific workspace. + :param host: The Databricks unified host. + :param workspace_id: The workspace ID. + :return: The OIDC endpoints for the workspace on the unified host. + """ + host = _fix_host_if_needed(host) + oidc = f"{host}/oidc/unified/{workspace_id}/.well-known/oauth-authorization-server" + resp = client.do("GET", oidc) + return OidcEndpoints.from_dict(resp) + + def get_azure_entra_id_workspace_endpoints( host: str, ) -> Optional[OidcEndpoints]: diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 55114bd84..f532a57e3 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -8,6 +8,7 @@ import pytest from databricks.sdk import AccountClient, FilesAPI, FilesExt, WorkspaceClient +from databricks.sdk.config import HostType from databricks.sdk.service.catalog import VolumeType @@ -63,7 +64,7 @@ def a(env_or_skip) -> AccountClient: _load_debug_env_if_runs_from_ide("account") env_or_skip("CLOUD_ENV") account_client = AccountClient() - if not account_client.config.is_account_client: + if account_client.config.host_type != HostType.ACCOUNTS: pytest.skip("not Databricks Account client") return account_client @@ -73,7 +74,7 @@ def ucacct(env_or_skip) -> AccountClient: _load_debug_env_if_runs_from_ide("ucacct") env_or_skip("CLOUD_ENV") account_client = AccountClient() - if not account_client.config.is_account_client: + if account_client.config.host_type != HostType.ACCOUNTS: pytest.skip("not Databricks Account client") if "TEST_METASTORE_ID" not in os.environ: pytest.skip("not in Unity Catalog Workspace test env") diff --git a/tests/integration/test_clusters.py b/tests/integration/test_clusters.py index dd388d2ed..cddcd01b1 100644 --- a/tests/integration/test_clusters.py +++ b/tests/integration/test_clusters.py @@ -5,6 +5,7 @@ from databricks.sdk.core import DatabricksError from databricks.sdk.service.compute import EventType +from databricks.sdk import WorkspaceClient def test_smallest_node_type(w): diff --git a/tests/integration/test_spog.py b/tests/integration/test_spog.py new file mode 100644 index 000000000..d4215b051 --- /dev/null +++ b/tests/integration/test_spog.py @@ -0,0 +1,16 @@ +from databricks.sdk import WorkspaceClient + +def test_smallest_node_type_spog_with_profile(): + w = WorkspaceClient( + profile="spog-test") + node_type_id = w.clusters.select_node_type(local_disk=True) + assert node_type_id is not None + + +def test_smallest_node_type_spog_without_profile(): + w = WorkspaceClient( + host="https://db-deco-test.databricks.com", + is_unified_host=True, + ) + node_type_id = w.clusters.select_node_type(local_disk=True) + assert node_type_id is not None \ No newline at end of file diff --git a/tests/test_config.py b/tests/test_config.py index 59fbf8712..4001ac411 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -8,7 +8,7 @@ import pytest from databricks.sdk import oauth, useragent -from databricks.sdk.config import Config, with_product, with_user_agent_extra +from databricks.sdk.config import Config, HostType, with_product, with_user_agent_extra from databricks.sdk.version import __version__ from .conftest import noop_credentials, set_az_path @@ -260,3 +260,123 @@ def test_oauth_token_reuses_existing_provider(mocker): # Both calls should work and use the same provider instance assert token1 == token2 == mock_token assert mock_oauth_provider.oauth_token.call_count == 2 + + +def test_host_type_workspace(): + """Test that a regular workspace host is identified correctly.""" + config = Config(host="https://test.databricks.com", token="test-token") + assert config.host_type == HostType.WORKSPACE + + +def test_host_type_accounts(): + """Test that an accounts host is identified correctly.""" + config = Config(host="https://accounts.cloud.databricks.com", account_id="test-account", token="test-token") + assert config.host_type == HostType.ACCOUNTS + + +def test_host_type_accounts_dod(): + """Test that an accounts-dod host is identified correctly.""" + config = Config(host="https://accounts-dod.cloud.databricks.us", account_id="test-account", token="test-token") + assert config.host_type == HostType.ACCOUNTS + + +def test_host_type_unified(): + """Test that a unified host is identified when experimental flag is set.""" + config = Config( + host="https://unified.databricks.com", + workspace_id="test-workspace", + experimental_is_unified_host=True, + token="test-token", + ) + assert config.host_type == HostType.UNIFIED + + +def test_oidc_endpoints_unified_workspace(mocker, requests_mock): + """Test that oidc_endpoints returns unified endpoints for workspace on unified host.""" + requests_mock.get( + "https://unified.databricks.com/oidc/unified/test-workspace/.well-known/oauth-authorization-server", + json={ + "authorization_endpoint": "https://unified.databricks.com/oidc/unified/test-workspace/v1/authorize", + "token_endpoint": "https://unified.databricks.com/oidc/unified/test-workspace/v1/token", + }, + ) + + config = Config( + host="https://unified.databricks.com", + workspace_id="test-workspace", + experimental_is_unified_host=True, + token="test-token", + ) + + endpoints = config.oidc_endpoints + assert endpoints is not None + assert "unified/test-workspace" in endpoints.authorization_endpoint + assert "unified/test-workspace" in endpoints.token_endpoint + + +def test_oidc_endpoints_unified_account(mocker, requests_mock): + """Test that oidc_endpoints returns account endpoints for account on unified host.""" + requests_mock.get( + "https://unified.databricks.com/oidc/accounts/test-account/.well-known/oauth-authorization-server", + json={ + "authorization_endpoint": "https://unified.databricks.com/oidc/accounts/test-account/v1/authorize", + "token_endpoint": "https://unified.databricks.com/oidc/accounts/test-account/v1/token", + }, + ) + + config = Config( + host="https://unified.databricks.com", + account_id="test-account", + experimental_is_unified_host=True, + token="test-token", + ) + + endpoints = config.oidc_endpoints + assert endpoints is not None + assert "accounts/test-account" in endpoints.authorization_endpoint + assert "accounts/test-account" in endpoints.token_endpoint + + +def test_oidc_endpoints_unified_missing_ids(): + """Test that oidc_endpoints raises error when unified host lacks required IDs.""" + config = Config(host="https://unified.databricks.com", experimental_is_unified_host=True, token="test-token") + + with pytest.raises(ValueError) as exc_info: + _ = config.oidc_endpoints + + assert "Unified host requires either workspace_id" in str(exc_info.value) + + +def test_workspace_org_id_header_on_unified_host(requests_mock): + """Test that X-Databricks-Org-Id header is added for workspace clients on unified hosts.""" + from databricks.sdk.core import ApiClient + + requests_mock.get("https://unified.databricks.com/api/2.0/test", json={"result": "success"}) + + config = Config( + host="https://unified.databricks.com", + workspace_id="test-workspace-123", + experimental_is_unified_host=True, + token="test-token", + ) + + api_client = ApiClient(config) + api_client.do("GET", "/api/2.0/test") + + # Verify the request was made with the X-Databricks-Org-Id header + assert requests_mock.last_request.headers.get("X-Databricks-Org-Id") == "test-workspace-123" + + +def test_no_org_id_header_on_regular_workspace(requests_mock): + """Test that X-Databricks-Org-Id header is NOT added for regular workspace hosts.""" + from databricks.sdk.core import ApiClient + + requests_mock.get("https://test.databricks.com/api/2.0/test", json={"result": "success"}) + + config = Config(host="https://test.databricks.com", token="test-token") + + api_client = ApiClient(config) + api_client.do("GET", "/api/2.0/test") + + # Verify the X-Databricks-Org-Id header was NOT added + assert "X-Databricks-Org-Id" not in requests_mock.last_request.headers diff --git a/tests/test_core.py b/tests/test_core.py index cc8ed921d..ac5f36927 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -9,6 +9,7 @@ import pytest from databricks.sdk import WorkspaceClient, errors, useragent +from databricks.sdk.config import HostType from databricks.sdk.core import ApiClient, Config, DatabricksError from databricks.sdk.credentials_provider import (CliTokenSource, CredentialsProvider, @@ -251,17 +252,17 @@ def refresh(self): def test_config_accounts_aws_is_accounts_host(config): config.host = "https://accounts.cloud.databricks.com" - assert config.is_account_client + assert config.host_type == HostType.ACCOUNTS def test_config_accounts_dod_is_accounts_host(config): config.host = "https://accounts-dod.cloud.databricks.us" - assert config.is_account_client + assert config.host_type == HostType.ACCOUNTS def test_config_workspace_is_not_accounts_host(config): config.host = "https://westeurope.azuredatabricks.net" - assert not config.is_account_client + assert config.host_type == HostType.WORKSPACE # This test uses the fake file system to avoid interference from local default profile.