Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions NEXT_CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
## Release v0.77.0

### New Features and Improvements
* Add `get_async_open_ai_client()` method to `ServingEndpointsAPI` for native `AsyncOpenAI` client support ([#847](https://github.com/databricks/databricks-sdk-py/issues/847)).

### Security

### Bug Fixes
* Fix `get_langchain_chat_open_ai_client()` returning 401 on async operations by adding `http_async_client` ([#1173](https://github.com/databricks/databricks-sdk-py/issues/1173)).

### Documentation

Expand Down
151 changes: 117 additions & 34 deletions databricks/sdk/mixins/open_ai_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,28 +8,59 @@
ServingEndpointsAPI)


def _get_bearer_auth(authenticate_func):
    """Build an httpx ``Auth`` object backed by Databricks authentication.

    The returned instance works with both ``httpx.Client`` and
    ``httpx.AsyncClient``: httpx automatically adapts the generator-based
    ``auth_flow`` for sync and async transports. ``authenticate_func`` is
    invoked on every request, so model-serving calls always carry a freshly
    refreshed token.
    """
    import httpx

    class _DatabricksBearerAuth(httpx.Auth):
        def __init__(self, header_provider):
            # Callable returning the current auth headers (re-evaluated per request).
            self._header_provider = header_provider

        def auth_flow(self, request: httpx.Request):
            current_headers = self._header_provider()
            request.headers["Authorization"] = current_headers["Authorization"]
            yield request

    return _DatabricksBearerAuth(authenticate_func)


class ServingEndpointsExt(ServingEndpointsAPI):
    """Serving-endpoints API extended with OpenAI / LangChain client helpers."""

    # Parameters the SDK configures itself when building OpenAI clients;
    # callers must not override them (enforced by _check_reserved_openai_params).
    _OPENAI_RESERVED_PARAMS = {"base_url", "api_key", "http_client"}
def _check_reserved_openai_params(self, kwargs):
conflicting_params = self._OPENAI_RESERVED_PARAMS.intersection(kwargs.keys())
if conflicting_params:
raise ValueError(
f"Cannot override reserved Databricks parameters: {', '.join(sorted(conflicting_params))}. "
f"These parameters are automatically configured for Databricks Model Serving."
)

def __init__(self, get_headers_func):
self.get_headers_func = get_headers_func
def _build_openai_client_params(self, http_client, kwargs):
client_params = {
"base_url": self._api._cfg.host + "/serving-endpoints",
"api_key": "no-token",
"http_client": http_client,
}
client_params.update(kwargs)
return client_params

# Using the HTTP client to pass in the Databricks authorization.
# The auth callable runs on every invocation, so model-serving requests
# always carry a refreshed token.
def _get_authorized_http_client(self):
    """Return a sync ``httpx.Client`` that injects a fresh Databricks bearer token per request."""
    import httpx

    databricks_token_auth = _get_bearer_auth(self._api._cfg.authenticate)
    return httpx.Client(auth=databricks_token_auth)

def _get_authorized_async_http_client(self):
    """Return an ``httpx.AsyncClient`` that injects a fresh Databricks bearer token per request."""
    import httpx

    databricks_token_auth = _get_bearer_auth(self._api._cfg.authenticate)
    return httpx.AsyncClient(auth=databricks_token_auth)

def get_open_ai_client(self, **kwargs):
    """Create an OpenAI client configured for Databricks Model Serving.

    Returns an OpenAI client instance that is pre-configured to send requests
    to Databricks Model Serving endpoints, authenticated with the credentials
    of the current workspace client.

    Args:
        **kwargs: Additional parameters to pass to the OpenAI client
            constructor (e.g. ``timeout``, ``max_retries``,
            ``default_headers``, ``default_query``). The parameters
            ``base_url``, ``api_key`` and ``http_client`` are reserved for
            Databricks integration and may not be overridden.

    Returns:
        OpenAI: An OpenAI client instance configured for Databricks Model Serving.

    Raises:
        ImportError: If the OpenAI library is not installed.
        ValueError: If any reserved Databricks parameters are provided in kwargs.
    """
    try:
        from openai import OpenAI
    except Exception:
        raise ImportError(
            "OpenAI is not installed. Please install the Databricks SDK with the following command `pip install databricks-sdk[openai]`"
        )

    self._check_reserved_openai_params(kwargs)
    client_params = self._build_openai_client_params(self._get_authorized_http_client(), kwargs)
    return OpenAI(**client_params)
def get_async_open_ai_client(self, **kwargs):
    """Create an AsyncOpenAI client configured for Databricks Model Serving.

    Returns an AsyncOpenAI client instance that is pre-configured to send requests to
    Databricks Model Serving endpoints. The client uses Databricks authentication
    to query endpoints within the workspace associated with the current WorkspaceClient
    instance.

    This client is suitable for async/await patterns and concurrent API calls.

    Args:
        **kwargs: Additional parameters to pass to the AsyncOpenAI client constructor.
            Common parameters include:
            - timeout (float): Request timeout in seconds (e.g., 30.0)
            - max_retries (int): Maximum number of retries for failed requests (e.g., 3)
            - default_headers (dict): Additional headers to include with requests
            - default_query (dict): Additional query parameters to include with requests

            Any parameter accepted by the AsyncOpenAI client constructor can be passed here,
            except for the following parameters which are reserved for Databricks integration:
            base_url, api_key, http_client

    Returns:
        AsyncOpenAI: An AsyncOpenAI client instance configured for Databricks Model Serving.

    Raises:
        ImportError: If the OpenAI library is not installed.
        ValueError: If any reserved Databricks parameters are provided in kwargs.

    Example:
        >>> client = workspace_client.serving_endpoints.get_async_open_ai_client()
        >>> # With custom timeout and retries
        >>> client = workspace_client.serving_endpoints.get_async_open_ai_client(
        ...     timeout=30.0,
        ...     max_retries=5
        ... )
        >>> # Use with async/await
        >>> response = await client.chat.completions.create(
        ...     model="databricks-meta-llama-3-1-70b-instruct",
        ...     messages=[{"role": "user", "content": "Hello!"}]
        ... )
    """
    try:
        from openai import AsyncOpenAI
    except Exception:
        raise ImportError(
            "OpenAI is not installed. Please install the Databricks SDK with the following command `pip install databricks-sdk[openai]`"
        )

    self._check_reserved_openai_params(kwargs)
    client_params = self._build_openai_client_params(self._get_authorized_async_http_client(), kwargs)
    return AsyncOpenAI(**client_params)

def get_langchain_chat_open_ai_client(self, model):
    """Create a LangChain ChatOpenAI client configured for Databricks Model Serving.

    Returns a ChatOpenAI instance that is pre-configured to send requests to
    Databricks Model Serving endpoints. The client uses Databricks authentication
    for both synchronous and asynchronous operations.

    Args:
        model: The name of the model serving endpoint to use.

    Returns:
        ChatOpenAI: A LangChain ChatOpenAI client instance configured for Databricks.

    Raises:
        ImportError: If langchain-openai is not installed.
    """
    try:
        from langchain_openai import ChatOpenAI
    except Exception:
        raise ImportError(
            "Langchain OpenAI is not installed. Please install the Databricks SDK with the following command `pip install databricks-sdk[openai]` and ensure you are using python>3.7"
        )

    return ChatOpenAI(
        model=model,
        openai_api_base=self._api._cfg.host + "/serving-endpoints",
        api_key="no-token",  # Placeholder to pass validation; real auth is in the http clients.
        http_client=self._get_authorized_http_client(),
        # Async client is required so async LangChain operations also authenticate
        # (fixes 401s on async calls; see issue #1173).
        http_async_client=self._get_authorized_async_http_client(),
    )

def http_request(
Expand Down
49 changes: 49 additions & 0 deletions tests/test_open_ai_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ def test_langchain_open_ai_client(monkeypatch):

assert client.openai_api_base == "https://test_host/serving-endpoints"
assert client.model_name == "databricks-meta-llama-3-1-70b-instruct"
assert client.http_client.auth is not None
assert client.http_async_client.auth is not None


def test_http_request(w, requests_mock):
Expand All @@ -115,3 +117,50 @@ def test_http_request(w, requests_mock):
assert requests_mock.called
assert response.status_code == 200 # Verify the response status
assert response.text == "The request was successful" # Ensure the response body matches the mocked data


def test_async_open_ai_client(monkeypatch):
    """The async client helper returns a correctly configured AsyncOpenAI instance."""
    from openai import AsyncOpenAI

    from databricks.sdk import WorkspaceClient

    monkeypatch.setenv("DATABRICKS_HOST", "test_host")
    monkeypatch.setenv("DATABRICKS_TOKEN", "test_token")
    workspace = WorkspaceClient(config=Config())

    openai_client = workspace.serving_endpoints.get_async_open_ai_client()

    assert isinstance(openai_client, AsyncOpenAI)
    assert openai_client.api_key == "no-token"
    assert openai_client.base_url == "https://test_host/serving-endpoints/"


def test_async_open_ai_client_with_custom_params(monkeypatch):
    """Caller-supplied kwargs are forwarded to the AsyncOpenAI constructor."""
    from databricks.sdk import WorkspaceClient

    monkeypatch.setenv("DATABRICKS_HOST", "test_host")
    monkeypatch.setenv("DATABRICKS_TOKEN", "test_token")
    workspace = WorkspaceClient(config=Config())

    openai_client = workspace.serving_endpoints.get_async_open_ai_client(timeout=30.0, max_retries=3)

    assert openai_client.timeout == 30.0
    assert openai_client.max_retries == 3
    assert openai_client.base_url == "https://test_host/serving-endpoints/"
    assert openai_client.api_key == "no-token"


def test_async_open_ai_client_prevents_reserved_param_override(monkeypatch):
    """Each reserved Databricks parameter is rejected with a descriptive ValueError."""
    from databricks.sdk import WorkspaceClient

    monkeypatch.setenv("DATABRICKS_HOST", "test_host")
    monkeypatch.setenv("DATABRICKS_TOKEN", "test_token")
    workspace = WorkspaceClient(config=Config())

    reserved_attempts = (
        ({"base_url": "https://custom-host"}, "base_url"),
        ({"api_key": "custom-key"}, "api_key"),
        ({"http_client": None}, "http_client"),
    )
    for reserved_kwargs, param_name in reserved_attempts:
        with pytest.raises(ValueError, match=f"Cannot override reserved Databricks parameters: {param_name}"):
            workspace.serving_endpoints.get_async_open_ai_client(**reserved_kwargs)