Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 6 additions & 11 deletions python/packages/core/agent_framework/azure/_chat_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
import json
import logging
import sys
from collections.abc import Mapping
from collections.abc import Awaitable, Callable, Mapping
from typing import Any, Generic, TypedDict

from azure.core.credentials import TokenCredential
from openai.lib.azure import AsyncAzureADTokenProvider, AsyncAzureOpenAI
from openai import AsyncOpenAI
from openai.types.chat.chat_completion import Choice
from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice
from pydantic import ValidationError
Expand Down Expand Up @@ -152,13 +152,12 @@ def __init__(
deployment_name: str | None = None,
endpoint: str | None = None,
base_url: str | None = None,
api_version: str | None = None,
ad_token: str | None = None,
ad_token_provider: AsyncAzureADTokenProvider | None = None,
ad_token_provider: Callable[[], str | Awaitable[str]] | None = None,
token_endpoint: str | None = None,
credential: TokenCredential | None = None,
default_headers: Mapping[str, str] | None = None,
async_client: AsyncAzureOpenAI | None = None,
async_client: AsyncOpenAI | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
instruction_role: str | None = None,
Expand All @@ -176,11 +175,9 @@ def __init__(
in the env vars or .env file.
Can also be set via environment variable AZURE_OPENAI_ENDPOINT.
base_url: The deployment base URL. If provided will override the value
in the env vars or .env file.
in the env vars or .env file. For standard Azure endpoints, /openai/v1/ is
appended automatically.
Can also be set via environment variable AZURE_OPENAI_BASE_URL.
api_version: The deployment API version. If provided will override the value
in the env vars or .env file.
Can also be set via environment variable AZURE_OPENAI_API_VERSION.
ad_token: The Azure Active Directory token.
ad_token_provider: The Azure Active Directory token provider.
token_endpoint: The token endpoint to request an Azure token.
Expand Down Expand Up @@ -236,7 +233,6 @@ class MyOptions(AzureOpenAIChatOptions, total=False):
base_url=base_url, # type: ignore
endpoint=endpoint, # type: ignore
chat_deployment_name=deployment_name,
api_version=api_version,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
token_endpoint=token_endpoint,
Expand All @@ -254,7 +250,6 @@ class MyOptions(AzureOpenAIChatOptions, total=False):
deployment_name=azure_openai_settings.chat_deployment_name,
endpoint=azure_openai_settings.endpoint,
base_url=azure_openai_settings.base_url,
api_version=azure_openai_settings.api_version, # type: ignore
api_key=azure_openai_settings.api_key.get_secret_value() if azure_openai_settings.api_key else None,
ad_token=ad_token,
ad_token_provider=ad_token_provider,
Expand Down
35 changes: 9 additions & 26 deletions python/packages/core/agent_framework/azure/_responses_client.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
# Copyright (c) Microsoft. All rights reserved.

import sys
from collections.abc import Mapping
from collections.abc import Awaitable, Callable, Mapping
from typing import TYPE_CHECKING, Any, Generic, TypedDict
from urllib.parse import urljoin

from azure.core.credentials import TokenCredential
from openai.lib.azure import AsyncAzureADTokenProvider, AsyncAzureOpenAI
from openai import AsyncOpenAI
from pydantic import ValidationError

from agent_framework import use_chat_middleware, use_function_invocation
Expand Down Expand Up @@ -59,13 +58,12 @@ def __init__(
deployment_name: str | None = None,
endpoint: str | None = None,
base_url: str | None = None,
api_version: str | None = None,
ad_token: str | None = None,
ad_token_provider: AsyncAzureADTokenProvider | None = None,
ad_token_provider: Callable[[], str | Awaitable[str]] | None = None,
token_endpoint: str | None = None,
credential: TokenCredential | None = None,
default_headers: Mapping[str, str] | None = None,
async_client: AsyncAzureOpenAI | None = None,
async_client: AsyncOpenAI | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
instruction_role: str | None = None,
Expand All @@ -83,11 +81,9 @@ def __init__(
in the env vars or .env file.
Can also be set via environment variable AZURE_OPENAI_ENDPOINT.
base_url: The deployment base URL. If provided will override the value
in the env vars or .env file. Currently, the base_url must end with "/openai/v1/".
in the env vars or .env file. For standard Azure endpoints, /openai/v1/ is
appended automatically.
Can also be set via environment variable AZURE_OPENAI_BASE_URL.
api_version: The deployment API version. If provided will override the value
in the env vars or .env file. Currently, the api_version must be "preview".
Can also be set via environment variable AZURE_OPENAI_API_VERSION.
ad_token: The Azure Active Directory token.
ad_token_provider: The Azure Active Directory token provider.
token_endpoint: The token endpoint to request an Azure token.
Expand Down Expand Up @@ -142,22 +138,10 @@ class MyOptions(AzureOpenAIResponsesOptions, total=False):
base_url=base_url, # type: ignore
endpoint=endpoint, # type: ignore
responses_deployment_name=deployment_name,
api_version=api_version,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
token_endpoint=token_endpoint,
default_api_version="preview",
)
# TODO(peterychang): This is a temporary hack to ensure that the base_url is set correctly
# while this feature is in preview.
# But we should only do this if we're on azure. Private deployments may not need this.
if (
not azure_openai_settings.base_url
and azure_openai_settings.endpoint
and azure_openai_settings.endpoint.host
and azure_openai_settings.endpoint.host.endswith(".openai.azure.com")
):
azure_openai_settings.base_url = urljoin(str(azure_openai_settings.endpoint), "/openai/v1/") # type: ignore
except ValidationError as exc:
raise ServiceInitializationError(f"Failed to validate settings: {exc}") from exc

Expand All @@ -171,7 +155,6 @@ class MyOptions(AzureOpenAIResponsesOptions, total=False):
deployment_name=azure_openai_settings.responses_deployment_name,
endpoint=azure_openai_settings.endpoint,
base_url=azure_openai_settings.base_url,
api_version=azure_openai_settings.api_version, # type: ignore
api_key=azure_openai_settings.api_key.get_secret_value() if azure_openai_settings.api_key else None,
ad_token=ad_token,
ad_token_provider=ad_token_provider,
Expand All @@ -183,8 +166,8 @@ class MyOptions(AzureOpenAIResponsesOptions, total=False):
)

@override
def _check_model_presence(self, run_options: dict[str, Any]) -> None:
if not run_options.get("model"):
def _check_model_presence(self, options: dict[str, Any]) -> None:
if not options.get("model"):
if not self.model_id:
raise ValueError("deployment_name must be a non-empty string")
run_options["model"] = self.model_id
options["model"] = self.model_id
85 changes: 55 additions & 30 deletions python/packages/core/agent_framework/azure/_shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
from collections.abc import Awaitable, Callable, Mapping
from copy import copy
from typing import Any, ClassVar, Final
from urllib.parse import urljoin

from azure.core.credentials import TokenCredential
from openai.lib.azure import AsyncAzureOpenAI
from openai import AsyncOpenAI
from pydantic import SecretStr, model_validator

from .._pydantic import AFBaseSettings, HTTPsUrl
Expand Down Expand Up @@ -142,6 +143,29 @@ def _validate_fields(self) -> Self:
return self


def _construct_v1_base_url(endpoint: HTTPsUrl | None, base_url: HTTPsUrl | None) -> str | None:
"""Construct the v1 API base URL from endpoint if not explicitly provided.

For standard Azure OpenAI endpoints, automatically appends /openai/v1/ path.
Custom/private deployments can provide their own base_url.

Args:
endpoint: The Azure OpenAI endpoint URL.
base_url: Explicit base URL if provided by user.

Returns:
The base URL to use, or None if neither endpoint nor base_url is valid.
"""
if base_url:
return str(base_url)

# Standard Azure OpenAI endpoints
if endpoint and endpoint.host and endpoint.host.endswith((".openai.azure.com", ".services.ai.azure.com")):
return urljoin(str(endpoint), "/openai/v1/")

return None


class AzureOpenAIConfigMixin(OpenAIBase):
"""Internal class for configuring a connection to an Azure OpenAI service."""

Expand All @@ -153,31 +177,27 @@ def __init__(
deployment_name: str,
endpoint: HTTPsUrl | None = None,
base_url: HTTPsUrl | None = None,
api_version: str = DEFAULT_AZURE_API_VERSION,
api_key: str | None = None,
ad_token: str | None = None,
ad_token_provider: Callable[[], str | Awaitable[str]] | None = None,
token_endpoint: str | None = None,
credential: TokenCredential | None = None,
default_headers: Mapping[str, str] | None = None,
client: AsyncAzureOpenAI | None = None,
client: AsyncOpenAI | None = None,
instruction_role: str | None = None,
**kwargs: Any,
) -> None:
"""Internal class for configuring a connection to an Azure OpenAI service.

The `validate_call` decorator is used with a configuration that allows arbitrary types.
This is necessary for types like `HTTPsUrl` and `OpenAIModelTypes`.

Args:
deployment_name: Name of the deployment.
endpoint: The specific endpoint URL for the deployment.
base_url: The base URL for Azure services.
api_version: Azure API version. Defaults to the defined DEFAULT_AZURE_API_VERSION.
api_key: API key for Azure services.
base_url: The base URL for Azure services. If not provided and endpoint is a
standard Azure OpenAI endpoint, /openai/v1/ will be appended automatically.
api_key: API key for Azure services. Can also be a token provider callable.
ad_token: Azure AD token for authentication.
ad_token_provider: A callable or coroutine function providing Azure AD tokens.
token_endpoint: Azure AD token endpoint use to get the token.
token_endpoint: Azure AD token endpoint used to get the token.
credential: Azure credential for authentication.
default_headers: Default headers for HTTP requests.
client: An existing client to use.
Expand All @@ -191,7 +211,11 @@ def __init__(
if APP_INFO:
merged_headers.update(APP_INFO)
merged_headers = prepend_agent_framework_to_user_agent(merged_headers)

if not client:
# Construct v1 base URL from endpoint if not explicitly provided
v1_base_url = _construct_v1_base_url(endpoint, base_url)

# If the client is None, the api_key is none, the ad_token is none, and the ad_token_provider is none,
# then we will attempt to get the ad_token using the default endpoint specified in the Azure OpenAI
# settings.
Expand All @@ -203,35 +227,36 @@ def __init__(
"Please provide either api_key, ad_token or ad_token_provider or a client."
)

if not endpoint and not base_url:
raise ServiceInitializationError("Please provide an endpoint or a base_url")
if not v1_base_url:
raise ServiceInitializationError(
"Please provide an endpoint or a base_url. "
"For standard Azure OpenAI endpoints (*.openai.azure.com and *.services.ai.azure.com), "
"the v1 API path will be appended automatically; for non-standard or private deployments, "
"you must provide a base_url that already includes the desired API path."
)

# Determine the effective api_key for AsyncOpenAI
effective_api_key: str | Callable[[], str | Awaitable[str]] | None = None
if api_key:
effective_api_key = api_key
elif ad_token_provider:
effective_api_key = ad_token_provider
elif ad_token:
effective_api_key = ad_token

args: dict[str, Any] = {
"base_url": v1_base_url,
"api_key": effective_api_key,
"default_headers": merged_headers,
}
if api_version:
args["api_version"] = api_version
if ad_token:
args["azure_ad_token"] = ad_token
if ad_token_provider:
args["azure_ad_token_provider"] = ad_token_provider
if api_key:
args["api_key"] = api_key
if base_url:
args["base_url"] = str(base_url)
if endpoint and not base_url:
args["azure_endpoint"] = str(endpoint)
if deployment_name:
args["azure_deployment"] = deployment_name
if "websocket_base_url" in kwargs:
args["websocket_base_url"] = kwargs.pop("websocket_base_url")

client = AsyncAzureOpenAI(**args)
client = AsyncOpenAI(**args)

# Store configuration as instance attributes for serialization
self.endpoint = str(endpoint)
self.base_url = str(base_url)
self.api_version = api_version
self.endpoint = str(endpoint) if endpoint else None
self.base_url = str(base_url) if base_url else None
self.deployment_name = deployment_name
self.instruction_role = instruction_role
# Store default_headers but filter out USER_AGENT_KEY for serialization
Expand Down
2 changes: 1 addition & 1 deletion python/packages/core/tests/azure/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def azure_openai_unit_test_env(monkeypatch, exclude_list, override_env_param_dic
override_env_param_dict = {}

env_vars = {
"AZURE_OPENAI_ENDPOINT": "https://test-endpoint.com",
"AZURE_OPENAI_ENDPOINT": "https://test-endpoint.openai.azure.com",
"AZURE_OPENAI_CHAT_DEPLOYMENT_NAME": "test_chat_deployment",
"AZURE_OPENAI_RESPONSES_DEPLOYMENT_NAME": "test_chat_deployment",
"AZURE_OPENAI_TEXT_DEPLOYMENT_NAME": "test_text_deployment",
Expand Down
Loading
Loading