Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NEXT_CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

### New Features and Improvements
* Add new auth type (`runtime-oauth`) for notebooks: Introduce a new authentication mechanism that allows notebooks to authenticate using OAuth tokens
* Add support for SPOG hosts with experimental flag

### Security

Expand Down
1 change: 1 addition & 0 deletions databricks/sdk/__init__.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

55 changes: 50 additions & 5 deletions databricks/sdk/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pathlib
import sys
import urllib.parse
from enum import Enum
from typing import Dict, Iterable, List, Optional

import requests
Expand All @@ -19,11 +20,19 @@
DatabricksEnvironment, get_environment_for_hostname)
from .oauth import (OidcEndpoints, Token, get_account_endpoints,
get_azure_entra_id_workspace_endpoints,
get_workspace_endpoints)
get_unified_endpoints, get_workspace_endpoints)

logger = logging.getLogger("databricks.sdk")


class HostType(Enum):
"""Enum representing the type of Databricks host."""

ACCOUNTS = "accounts"
WORKSPACE = "workspace"
UNIFIED = "unified"


class ConfigAttribute:
"""Configuration attribute metadata and descriptor protocols."""

Expand Down Expand Up @@ -61,6 +70,10 @@ def with_user_agent_extra(key: str, value: str):
class Config:
host: str = ConfigAttribute(env="DATABRICKS_HOST")
account_id: str = ConfigAttribute(env="DATABRICKS_ACCOUNT_ID")
workspace_id: str = ConfigAttribute(env="DATABRICKS_WORKSPACE_ID")

# Experimental flag to indicate if the host is a unified host (supports both workspace and account APIs)
experimental_is_unified_host: bool = ConfigAttribute(env="DATABRICKS_EXPERIMENTAL_IS_UNIFIED_HOST")

# PAT token.
token: str = ConfigAttribute(env="DATABRICKS_TOKEN", auth="pat", sensitive=True)
Expand Down Expand Up @@ -339,10 +352,27 @@ def is_aws(self) -> bool:
return self.environment.cloud == Cloud.AWS

@property
def is_account_client(self) -> bool:
def host_type(self) -> HostType:
"""Determine the type of host based on the configuration.

Returns the HostType which can be ACCOUNTS, WORKSPACE, or UNIFIED.
"""
if not self.host:
return False
return self.host.startswith("https://accounts.") or self.host.startswith("https://accounts-dod.")
logger.debug(f"Host type: {HostType.WORKSPACE.value} (no host configured)")
return HostType.WORKSPACE

# Check if explicitly marked as unified host
if self.experimental_is_unified_host is True:
logger.debug(f"Host type: {HostType.UNIFIED.value} (experimental flag set)")
return HostType.UNIFIED

# Check for accounts host pattern
if self.host.startswith("https://accounts.") or self.host.startswith("https://accounts-dod."):
logger.debug(f"Host type: {HostType.ACCOUNTS.value} (accounts URL pattern)")
return HostType.ACCOUNTS

logger.debug(f"Host type: {HostType.WORKSPACE.value} (default)")
return HostType.WORKSPACE

@property
def arm_environment(self) -> AzureEnvironment:
Expand Down Expand Up @@ -394,8 +424,23 @@ def oidc_endpoints(self) -> Optional[OidcEndpoints]:
return None
if self.is_azure and self.azure_client_id:
return get_azure_entra_id_workspace_endpoints(self.host)
if self.is_account_client and self.account_id:

# Handle unified hosts
if self.host_type == HostType.UNIFIED:
if self.workspace_id:
return get_unified_endpoints(self.host, self.workspace_id)
elif self.account_id:
return get_account_endpoints(self.host, self.account_id)
else:
raise ValueError(
"Unified host requires either workspace_id (for workspace client) or account_id (for account client)"
)

# Handle traditional account hosts
if self.host_type == HostType.ACCOUNTS and self.account_id:
return get_account_endpoints(self.host, self.account_id)

# Default to workspace endpoints
return get_workspace_endpoints(self.host)

def debug_string(self) -> str:
Expand Down
15 changes: 10 additions & 5 deletions databricks/sdk/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,20 @@ class ApiClient:

def __init__(self, cfg: Config):
self._cfg = cfg

# Create header factory that includes both auth and org ID headers
def combined_header_factory():
headers = cfg.authenticate()
# Add X-Databricks-Org-Id header for workspace clients on unified hosts
if cfg.workspace_id and cfg.host_type.value == "unified":
headers["X-Databricks-Org-Id"] = cfg.workspace_id
return headers

self._api_client = _BaseClient(
debug_truncate_bytes=cfg.debug_truncate_bytes,
retry_timeout_seconds=cfg.retry_timeout_seconds,
user_agent_base=cfg.user_agent,
header_factory=cfg.authenticate,
header_factory=combined_header_factory,
max_connection_pools=cfg.max_connection_pools,
max_connections_per_pool=cfg.max_connections_per_pool,
pool_block=True,
Expand All @@ -39,10 +48,6 @@ def __init__(self, cfg: Config):
def account_id(self) -> str:
return self._cfg.account_id

@property
def is_account_client(self) -> bool:
return self._cfg.is_account_client

def get_oauth_token(self, auth_details: str) -> Token:
if not self._cfg.auth_type:
self._cfg.authenticate()
Expand Down
11 changes: 6 additions & 5 deletions databricks/sdk/credentials_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from google.oauth2 import service_account # type: ignore

from . import azure, oauth, oidc, oidc_token_supplier
from .config import HostType

CredentialsProvider = Callable[[], Dict[str, str]]

Expand Down Expand Up @@ -422,9 +423,9 @@ def _oidc_credentials_provider(

# Determine the audience for token exchange
audience = cfg.token_audience
if audience is None and cfg.is_account_client:
if audience is None and cfg.host_type == HostType.ACCOUNTS:
audience = cfg.account_id
if audience is None and not cfg.is_account_client:
if audience is None and cfg.host_type != HostType.ACCOUNTS:
audience = cfg.oidc_endpoints.token_endpoint

# Try to get an OIDC token. If no supplier returns a token, we cannot use this authentication mode.
Expand Down Expand Up @@ -581,7 +582,7 @@ def token() -> oauth.Token:
def refreshed_headers() -> Dict[str, str]:
credentials.refresh(request)
headers = {"Authorization": f"Bearer {credentials.token}"}
if cfg.is_account_client:
if cfg.host_type == HostType.ACCOUNTS:
gcp_credentials.refresh(request)
headers["X-Databricks-GCP-SA-Access-Token"] = gcp_credentials.token
return headers
Expand Down Expand Up @@ -622,7 +623,7 @@ def token() -> oauth.Token:
def refreshed_headers() -> Dict[str, str]:
id_creds.refresh(request)
headers = {"Authorization": f"Bearer {id_creds.token}"}
if cfg.is_account_client:
if cfg.host_type == HostType.ACCOUNTS:
gcp_impersonated_credentials.refresh(request)
headers["X-Databricks-GCP-SA-Access-Token"] = gcp_impersonated_credentials.token
return headers
Expand Down Expand Up @@ -844,7 +845,7 @@ class DatabricksCliTokenSource(CliTokenSource):

def __init__(self, cfg: "Config"):
args = ["auth", "token", "--host", cfg.host]
if cfg.is_account_client:
if cfg.host_type == HostType.ACCOUNTS:
args += ["--account-id", cfg.account_id]

cli_path = cfg.databricks_cli_path
Expand Down
13 changes: 13 additions & 0 deletions databricks/sdk/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,19 @@ def get_workspace_endpoints(host: str, client: _BaseClient = _BaseClient()) -> O
return OidcEndpoints.from_dict(resp)


def get_unified_endpoints(host: str, workspace_id: str, client: _BaseClient = _BaseClient()) -> OidcEndpoints:
"""
Get the OIDC endpoints for a unified host with a specific workspace.
:param host: The Databricks unified host.
:param workspace_id: The workspace ID.
:return: The OIDC endpoints for the workspace on the unified host.
"""
host = _fix_host_if_needed(host)
oidc = f"{host}/oidc/unified/{workspace_id}/.well-known/oauth-authorization-server"
resp = client.do("GET", oidc)
return OidcEndpoints.from_dict(resp)


def get_azure_entra_id_workspace_endpoints(
host: str,
) -> Optional[OidcEndpoints]:
Expand Down
5 changes: 3 additions & 2 deletions tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pytest

from databricks.sdk import AccountClient, FilesAPI, FilesExt, WorkspaceClient
from databricks.sdk.config import HostType
from databricks.sdk.service.catalog import VolumeType


Expand Down Expand Up @@ -63,7 +64,7 @@ def a(env_or_skip) -> AccountClient:
_load_debug_env_if_runs_from_ide("account")
env_or_skip("CLOUD_ENV")
account_client = AccountClient()
if not account_client.config.is_account_client:
if account_client.config.host_type != HostType.ACCOUNTS:
pytest.skip("not Databricks Account client")
return account_client

Expand All @@ -73,7 +74,7 @@ def ucacct(env_or_skip) -> AccountClient:
_load_debug_env_if_runs_from_ide("ucacct")
env_or_skip("CLOUD_ENV")
account_client = AccountClient()
if not account_client.config.is_account_client:
if account_client.config.host_type != HostType.ACCOUNTS:
pytest.skip("not Databricks Account client")
if "TEST_METASTORE_ID" not in os.environ:
pytest.skip("not in Unity Catalog Workspace test env")
Expand Down
1 change: 1 addition & 0 deletions tests/integration/test_clusters.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from databricks.sdk.core import DatabricksError
from databricks.sdk.service.compute import EventType
from databricks.sdk import WorkspaceClient


def test_smallest_node_type(w):
Expand Down
16 changes: 16 additions & 0 deletions tests/integration/test_spog.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from databricks.sdk import WorkspaceClient

def test_smallest_node_type_spog_with_profile():
w = WorkspaceClient(
profile="spog-test")
node_type_id = w.clusters.select_node_type(local_disk=True)
assert node_type_id is not None


def test_smallest_node_type_spog_without_profile():
w = WorkspaceClient(
host="https://db-deco-test.databricks.com",
is_unified_host=True,
)
node_type_id = w.clusters.select_node_type(local_disk=True)
assert node_type_id is not None
122 changes: 121 additions & 1 deletion tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pytest

from databricks.sdk import oauth, useragent
from databricks.sdk.config import Config, with_product, with_user_agent_extra
from databricks.sdk.config import Config, HostType, with_product, with_user_agent_extra
from databricks.sdk.version import __version__

from .conftest import noop_credentials, set_az_path
Expand Down Expand Up @@ -260,3 +260,123 @@ def test_oauth_token_reuses_existing_provider(mocker):
# Both calls should work and use the same provider instance
assert token1 == token2 == mock_token
assert mock_oauth_provider.oauth_token.call_count == 2


def test_host_type_workspace():
"""Test that a regular workspace host is identified correctly."""
config = Config(host="https://test.databricks.com", token="test-token")
assert config.host_type == HostType.WORKSPACE


def test_host_type_accounts():
"""Test that an accounts host is identified correctly."""
config = Config(host="https://accounts.cloud.databricks.com", account_id="test-account", token="test-token")
assert config.host_type == HostType.ACCOUNTS


def test_host_type_accounts_dod():
"""Test that an accounts-dod host is identified correctly."""
config = Config(host="https://accounts-dod.cloud.databricks.us", account_id="test-account", token="test-token")
assert config.host_type == HostType.ACCOUNTS


def test_host_type_unified():
"""Test that a unified host is identified when experimental flag is set."""
config = Config(
host="https://unified.databricks.com",
workspace_id="test-workspace",
experimental_is_unified_host=True,
token="test-token",
)
assert config.host_type == HostType.UNIFIED


def test_oidc_endpoints_unified_workspace(mocker, requests_mock):
"""Test that oidc_endpoints returns unified endpoints for workspace on unified host."""
requests_mock.get(
"https://unified.databricks.com/oidc/unified/test-workspace/.well-known/oauth-authorization-server",
json={
"authorization_endpoint": "https://unified.databricks.com/oidc/unified/test-workspace/v1/authorize",
"token_endpoint": "https://unified.databricks.com/oidc/unified/test-workspace/v1/token",
},
)

config = Config(
host="https://unified.databricks.com",
workspace_id="test-workspace",
experimental_is_unified_host=True,
token="test-token",
)

endpoints = config.oidc_endpoints
assert endpoints is not None
assert "unified/test-workspace" in endpoints.authorization_endpoint
assert "unified/test-workspace" in endpoints.token_endpoint


def test_oidc_endpoints_unified_account(mocker, requests_mock):
"""Test that oidc_endpoints returns account endpoints for account on unified host."""
requests_mock.get(
"https://unified.databricks.com/oidc/accounts/test-account/.well-known/oauth-authorization-server",
json={
"authorization_endpoint": "https://unified.databricks.com/oidc/accounts/test-account/v1/authorize",
"token_endpoint": "https://unified.databricks.com/oidc/accounts/test-account/v1/token",
},
)

config = Config(
host="https://unified.databricks.com",
account_id="test-account",
experimental_is_unified_host=True,
token="test-token",
)

endpoints = config.oidc_endpoints
assert endpoints is not None
assert "accounts/test-account" in endpoints.authorization_endpoint
assert "accounts/test-account" in endpoints.token_endpoint


def test_oidc_endpoints_unified_missing_ids():
"""Test that oidc_endpoints raises error when unified host lacks required IDs."""
config = Config(host="https://unified.databricks.com", experimental_is_unified_host=True, token="test-token")

with pytest.raises(ValueError) as exc_info:
_ = config.oidc_endpoints

assert "Unified host requires either workspace_id" in str(exc_info.value)


def test_workspace_org_id_header_on_unified_host(requests_mock):
"""Test that X-Databricks-Org-Id header is added for workspace clients on unified hosts."""
from databricks.sdk.core import ApiClient

requests_mock.get("https://unified.databricks.com/api/2.0/test", json={"result": "success"})

config = Config(
host="https://unified.databricks.com",
workspace_id="test-workspace-123",
experimental_is_unified_host=True,
token="test-token",
)

api_client = ApiClient(config)
api_client.do("GET", "/api/2.0/test")

# Verify the request was made with the X-Databricks-Org-Id header
assert requests_mock.last_request.headers.get("X-Databricks-Org-Id") == "test-workspace-123"


def test_no_org_id_header_on_regular_workspace(requests_mock):
"""Test that X-Databricks-Org-Id header is NOT added for regular workspace hosts."""
from databricks.sdk.core import ApiClient

requests_mock.get("https://test.databricks.com/api/2.0/test", json={"result": "success"})

config = Config(host="https://test.databricks.com", token="test-token")

api_client = ApiClient(config)
api_client.do("GET", "/api/2.0/test")

# Verify the X-Databricks-Org-Id header was NOT added
assert "X-Databricks-Org-Id" not in requests_mock.last_request.headers
Loading
Loading