diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md
index c1829661d..d7e5ff961 100644
--- a/NEXT_CHANGELOG.md
+++ b/NEXT_CHANGELOG.md
@@ -3,10 +3,12 @@
 ## Release v0.77.0
 
 ### New Features and Improvements
 
+* Add `get_async_open_ai_client()` method to `ServingEndpointsAPI` for native `AsyncOpenAI` client support ([#847](https://github.com/databricks/databricks-sdk-py/issues/847)).
+
 ### Security
 
 ### Bug Fixes
 
+* Fix `get_langchain_chat_open_ai_client()` returning 401 on async operations by adding `http_async_client` ([#1173](https://github.com/databricks/databricks-sdk-py/issues/1173)).
+
 ### Documentation
 
diff --git a/databricks/sdk/mixins/open_ai_client.py b/databricks/sdk/mixins/open_ai_client.py
index 4ab08ee5a..4ea267d5c 100644
--- a/databricks/sdk/mixins/open_ai_client.py
+++ b/databricks/sdk/mixins/open_ai_client.py
@@ -8,28 +8,59 @@
                                              ServingEndpointsAPI)
 
 
+def _get_bearer_auth(authenticate_func):
+    """Create an httpx Auth class that uses Databricks authentication.
+
+    This auth class works with both httpx.Client and httpx.AsyncClient.
+    The auth_flow generator pattern is automatically wrapped by httpx for
+    both sync and async operations.
+    """
+    import httpx
+
+    class BearerAuth(httpx.Auth):
+        def __init__(self, get_headers_func):
+            self.get_headers_func = get_headers_func
+
+        def auth_flow(self, request: httpx.Request):
+            auth_headers = self.get_headers_func()
+            request.headers["Authorization"] = auth_headers["Authorization"]
+            yield request
+
+    return BearerAuth(authenticate_func)
+
+
 class ServingEndpointsExt(ServingEndpointsAPI):
-    # Using the HTTP Client to pass in the databricks authorization
-    # This method will be called on every invocation, so when using with model serving will always get the refreshed token
-    def _get_authorized_http_client(self):
-        import httpx
+    _OPENAI_RESERVED_PARAMS = {"base_url", "api_key", "http_client"}
 
-        class BearerAuth(httpx.Auth):
+    def _check_reserved_openai_params(self, kwargs):
+        conflicting_params = self._OPENAI_RESERVED_PARAMS.intersection(kwargs.keys())
+        if conflicting_params:
+            raise ValueError(
+                f"Cannot override reserved Databricks parameters: {', '.join(sorted(conflicting_params))}. "
+                f"These parameters are automatically configured for Databricks Model Serving."
+            )
 
-            def __init__(self, get_headers_func):
-                self.get_headers_func = get_headers_func
+    def _build_openai_client_params(self, http_client, kwargs):
+        client_params = {
+            "base_url": self._api._cfg.host + "/serving-endpoints",
+            "api_key": "no-token",
+            "http_client": http_client,
+        }
+        client_params.update(kwargs)
+        return client_params
 
-            def auth_flow(self, request: httpx.Request) -> httpx.Request:
-                auth_headers = self.get_headers_func()
-                request.headers["Authorization"] = auth_headers["Authorization"]
-                yield request
+    def _get_authorized_http_client(self):
+        import httpx
 
-        databricks_token_auth = BearerAuth(self._api._cfg.authenticate)
+        databricks_token_auth = _get_bearer_auth(self._api._cfg.authenticate)
+        return httpx.Client(auth=databricks_token_auth)
 
-        # Create an HTTP client with Bearer Token authentication
-        http_client = httpx.Client(auth=databricks_token_auth)
-        return http_client
+    def _get_authorized_async_http_client(self):
+        import httpx
+
+        databricks_token_auth = _get_bearer_auth(self._api._cfg.authenticate)
+        return httpx.AsyncClient(auth=databricks_token_auth)
 
     def get_open_ai_client(self, **kwargs):
         """Create an OpenAI client configured for Databricks Model Serving.
@@ -70,36 +101,87 @@ def get_open_ai_client(self, **kwargs):
             from openai import OpenAI
         except Exception:
             raise ImportError(
-                "Open AI is not installed. Please install the Databricks SDK with the following command `pip install databricks-sdk[openai]`"
+                "OpenAI is not installed. Please install the Databricks SDK with the following command `pip install databricks-sdk[openai]`"
             )
 
-        # Check for reserved parameters that should not be overridden
-        reserved_params = {"base_url", "api_key", "http_client"}
-        conflicting_params = reserved_params.intersection(kwargs.keys())
-        if conflicting_params:
-            raise ValueError(
-                f"Cannot override reserved Databricks parameters: {', '.join(sorted(conflicting_params))}. "
-                f"These parameters are automatically configured for Databricks Model Serving."
-            )
+        self._check_reserved_openai_params(kwargs)
+        client_params = self._build_openai_client_params(self._get_authorized_http_client(), kwargs)
+        return OpenAI(**client_params)
 
-        # Default parameters that are required for Databricks integration
-        client_params = {
-            "base_url": self._api._cfg.host + "/serving-endpoints",
-            "api_key": "no-token",  # Passing in a placeholder to pass validations, this will not be used
-            "http_client": self._get_authorized_http_client(),
-        }
+    def get_async_open_ai_client(self, **kwargs):
+        """Create an AsyncOpenAI client configured for Databricks Model Serving.
 
-        # Update with any additional parameters passed by the user
-        client_params.update(kwargs)
+        Returns an AsyncOpenAI client instance that is pre-configured to send requests to
+        Databricks Model Serving endpoints. The client uses Databricks authentication
+        to query endpoints within the workspace associated with the current WorkspaceClient
+        instance.
 
-        return OpenAI(**client_params)
+        This client is suitable for async/await patterns and concurrent API calls.
+
+        Args:
+            **kwargs: Additional parameters to pass to the AsyncOpenAI client constructor.
+                Common parameters include:
+                - timeout (float): Request timeout in seconds (e.g., 30.0)
+                - max_retries (int): Maximum number of retries for failed requests (e.g., 3)
+                - default_headers (dict): Additional headers to include with requests
+                - default_query (dict): Additional query parameters to include with requests
+
+                Any parameter accepted by the AsyncOpenAI client constructor can be passed here,
+                except for the following parameters which are reserved for Databricks integration:
+                base_url, api_key, http_client
+
+        Returns:
+            AsyncOpenAI: An AsyncOpenAI client instance configured for Databricks Model Serving.
+
+        Raises:
+            ImportError: If the OpenAI library is not installed.
+            ValueError: If any reserved Databricks parameters are provided in kwargs.
+
+        Example:
+            >>> client = workspace_client.serving_endpoints.get_async_open_ai_client()
+            >>> # With custom timeout and retries
+            >>> client = workspace_client.serving_endpoints.get_async_open_ai_client(
+            ...     timeout=30.0,
+            ...     max_retries=5
+            ... )
+            >>> # Use with async/await
+            >>> response = await client.chat.completions.create(
+            ...     model="databricks-meta-llama-3-1-70b-instruct",
+            ...     messages=[{"role": "user", "content": "Hello!"}]
+            ... )
+        """
+        try:
+            from openai import AsyncOpenAI
+        except Exception:
+            raise ImportError(
+                "OpenAI is not installed. Please install the Databricks SDK with the following command `pip install databricks-sdk[openai]`"
+            )
+
+        self._check_reserved_openai_params(kwargs)
+        client_params = self._build_openai_client_params(self._get_authorized_async_http_client(), kwargs)
+        return AsyncOpenAI(**client_params)
 
     def get_langchain_chat_open_ai_client(self, model):
+        """Create a LangChain ChatOpenAI client configured for Databricks Model Serving.
+
+        Returns a ChatOpenAI instance that is pre-configured to send requests to
+        Databricks Model Serving endpoints. The client uses Databricks authentication
+        for both synchronous and asynchronous operations.
+
+        Args:
+            model: The name of the model serving endpoint to use.
+
+        Returns:
+            ChatOpenAI: A LangChain ChatOpenAI client instance configured for Databricks.
+
+        Raises:
+            ImportError: If langchain-openai is not installed.
+        """
         try:
             from langchain_openai import ChatOpenAI
         except Exception:
             raise ImportError(
-                "Langchain Open AI is not installed. Please install the Databricks SDK with the following command `pip install databricks-sdk[openai]` and ensure you are using python>3.7"
+                "Langchain OpenAI is not installed. Please install the Databricks SDK with the following command `pip install databricks-sdk[openai]` and ensure you are using python>3.7"
             )
 
         return ChatOpenAI(
@@ -107,6 +189,7 @@ def get_langchain_chat_open_ai_client(self, model):
             openai_api_base=self._api._cfg.host + "/serving-endpoints",
             api_key="no-token",  # Passing in a placeholder to pass validations, this will not be used
             http_client=self._get_authorized_http_client(),
+            http_async_client=self._get_authorized_async_http_client(),
         )
 
     def http_request(
diff --git a/tests/test_open_ai_mixin.py b/tests/test_open_ai_mixin.py
index dfc248d0a..d43fd9a44 100644
--- a/tests/test_open_ai_mixin.py
+++ b/tests/test_open_ai_mixin.py
@@ -90,6 +90,8 @@ def test_langchain_open_ai_client(monkeypatch):
 
     assert client.openai_api_base == "https://test_host/serving-endpoints"
     assert client.model_name == "databricks-meta-llama-3-1-70b-instruct"
+    assert client.http_client.auth is not None
+    assert client.http_async_client.auth is not None
 
 
 def test_http_request(w, requests_mock):
@@ -115,3 +117,50 @@ def test_http_request(w, requests_mock):
     assert requests_mock.called
     assert response.status_code == 200  # Verify the response status
     assert response.text == "The request was successful"  # Ensure the response body matches the mocked data
+
+
+def test_async_open_ai_client(monkeypatch):
+    from openai import AsyncOpenAI
+
+    from databricks.sdk import WorkspaceClient
+
+    monkeypatch.setenv("DATABRICKS_HOST", "test_host")
+    monkeypatch.setenv("DATABRICKS_TOKEN", "test_token")
+    w = WorkspaceClient(config=Config())
+    client = w.serving_endpoints.get_async_open_ai_client()
+
+    assert isinstance(client, AsyncOpenAI)
+    assert client.base_url == "https://test_host/serving-endpoints/"
+    assert client.api_key == "no-token"
+
+
+def test_async_open_ai_client_with_custom_params(monkeypatch):
+    from databricks.sdk import WorkspaceClient
+
+    monkeypatch.setenv("DATABRICKS_HOST", "test_host")
+    monkeypatch.setenv("DATABRICKS_TOKEN", "test_token")
+    w = WorkspaceClient(config=Config())
+
+    client = w.serving_endpoints.get_async_open_ai_client(timeout=30.0, max_retries=3)
+
+    assert client.base_url == "https://test_host/serving-endpoints/"
+    assert client.api_key == "no-token"
+    assert client.timeout == 30.0
+    assert client.max_retries == 3
+
+
+def test_async_open_ai_client_prevents_reserved_param_override(monkeypatch):
+    from databricks.sdk import WorkspaceClient
+
+    monkeypatch.setenv("DATABRICKS_HOST", "test_host")
+    monkeypatch.setenv("DATABRICKS_TOKEN", "test_token")
+    w = WorkspaceClient(config=Config())
+
+    with pytest.raises(ValueError, match="Cannot override reserved Databricks parameters: base_url"):
+        w.serving_endpoints.get_async_open_ai_client(base_url="https://custom-host")
+
+    with pytest.raises(ValueError, match="Cannot override reserved Databricks parameters: api_key"):
+        w.serving_endpoints.get_async_open_ai_client(api_key="custom-key")
+
+    with pytest.raises(ValueError, match="Cannot override reserved Databricks parameters: http_client"):
+        w.serving_endpoints.get_async_open_ai_client(http_client=None)