diff --git a/sdk/ai/azure-ai-projects/azure/ai/projects/_patch.py b/sdk/ai/azure-ai-projects/azure/ai/projects/_patch.py index abe4ae47086d..8bbe0f6d1e3e 100644 --- a/sdk/ai/azure-ai-projects/azure/ai/projects/_patch.py +++ b/sdk/ai/azure-ai-projects/azure/ai/projects/_patch.py @@ -9,6 +9,7 @@ """ import os import logging +import httpx from typing import List, Any from openai import OpenAI from azure.core.tracing.decorator import distributed_trace @@ -106,8 +107,6 @@ def get_openai_client(self, **kwargs: Any) -> "OpenAI": # type: ignore[name-def :return: An authenticated OpenAI client :rtype: ~openai.OpenAI - :raises ~azure.core.exceptions.ModuleNotFoundError: if the ``openai`` package - is not installed. :raises ~azure.core.exceptions.HttpResponseError: """ @@ -118,107 +117,14 @@ def get_openai_client(self, **kwargs: Any) -> "OpenAI": # type: ignore[name-def base_url, ) - http_client = None - kwargs = kwargs.copy() if kwargs else {} - if self._console_logging_enabled: - try: - import httpx - except ModuleNotFoundError as e: - raise ModuleNotFoundError("Failed to import httpx. Please install it using 'pip install httpx'") from e - - class OpenAILoggingTransport(httpx.HTTPTransport): - - def _sanitize_auth_header(self, headers) -> None: - """Sanitize authorization header by redacting sensitive information. - - :param headers: Dictionary of HTTP headers to sanitize - :type headers: dict - """ - - if "authorization" in headers: - auth_value = headers["authorization"] - if len(auth_value) >= 7: - headers["authorization"] = auth_value[:7] + "" - else: - headers["authorization"] = "" - - def handle_request(self, request: httpx.Request) -> httpx.Response: - """ - Log HTTP request and response details to console, in a nicely formatted way, - for OpenAI / Azure OpenAI clients. - - :param request: The HTTP request to handle and log - :type request: httpx.Request - - :return: The HTTP response received - :rtype: httpx.Response - """ - - print(f"\n==> Request:\n{request.method} {request.url}") - headers = dict(request.headers) - self._sanitize_auth_header(headers) - print("Headers:") - for key, value in sorted(headers.items()): - print(f" {key}: {value}") - - self._log_request_body(request) - - response = super().handle_request(request) - - print(f"\n<== Response:\n{response.status_code} {response.reason_phrase}") - print("Headers:") - for key, value in sorted(dict(response.headers).items()): - print(f" {key}: {value}") - - content = response.read() - if content is None or content == b"": - print("Body: [No content]") - else: - try: - print(f"Body:\n {content.decode('utf-8')}") - except Exception: # pylint: disable=broad-exception-caught - print(f"Body (raw):\n {content!r}") - print("\n") - - return response - - def _log_request_body(self, request: httpx.Request) -> None: - """Log request body content safely, handling binary data and streaming content. - - :param request: The HTTP request object containing the body to log - :type request: httpx.Request - """ - - # Check content-type header to identify file uploads - content_type = request.headers.get("content-type", "").lower() - if "multipart/form-data" in content_type: - print("Body: [Multipart form data - file upload, not logged]") - return - - # Safely check if content exists without accessing it - if not hasattr(request, "content"): - print("Body: [No content attribute]") - return - - # Very careful content access - wrap in try-catch immediately - try: - content = request.content - except Exception as access_error: # pylint: disable=broad-exception-caught - print(f"Body: [Cannot access content: {access_error}]") - return - - if content is None or content == b"": - print("Body: [No content]") - return - - try: - print(f"Body:\n {content.decode('utf-8')}") - except Exception: # pylint: disable=broad-exception-caught - print(f"Body (raw):\n {content!r}") - + if "http_client" in kwargs: + http_client = kwargs.pop("http_client") + elif self._console_logging_enabled: http_client = httpx.Client(transport=OpenAILoggingTransport()) + else: + http_client = None default_headers = dict[str, str](kwargs.pop("default_headers", None) or {}) @@ -256,6 +162,107 @@ def _create_openai_client(**kwargs) -> OpenAI: return client +class OpenAILoggingTransport(httpx.HTTPTransport): + """Custom HTTP transport that logs OpenAI API requests and responses to the console. + + This transport wraps httpx.HTTPTransport to intercept all HTTP traffic and print + detailed request/response information for debugging purposes. It automatically + redacts sensitive authorization headers and handles various content types including + multipart form data (file uploads). + + Used internally by AIProjectClient when console logging is enabled via the + AZURE_AI_PROJECTS_CONSOLE_LOGGING environment variable. + """ + + def _sanitize_auth_header(self, headers) -> None: + """Sanitize authorization header by redacting sensitive information. + + :param headers: Dictionary of HTTP headers to sanitize + :type headers: dict + """ + + if "authorization" in headers: + auth_value = headers["authorization"] + if len(auth_value) >= 7: + headers["authorization"] = auth_value[:7] + "" + else: + headers["authorization"] = "" + + def handle_request(self, request: httpx.Request) -> httpx.Response: + """ + Log HTTP request and response details to console, in a nicely formatted way, + for OpenAI / Azure OpenAI clients. + + :param request: The HTTP request to handle and log + :type request: httpx.Request + + :return: The HTTP response received + :rtype: httpx.Response + """ + + print(f"\n==> Request:\n{request.method} {request.url}") + headers = dict(request.headers) + self._sanitize_auth_header(headers) + print("Headers:") + for key, value in sorted(headers.items()): + print(f" {key}: {value}") + + self._log_request_body(request) + + response = super().handle_request(request) + + print(f"\n<== Response:\n{response.status_code} {response.reason_phrase}") + print("Headers:") + for key, value in sorted(dict(response.headers).items()): + print(f" {key}: {value}") + + content = response.read() + if content is None or content == b"": + print("Body: [No content]") + else: + try: + print(f"Body:\n {content.decode('utf-8')}") + except Exception: # pylint: disable=broad-exception-caught + print(f"Body (raw):\n {content!r}") + print("\n") + + return response + + def _log_request_body(self, request: httpx.Request) -> None: + """Log request body content safely, handling binary data and streaming content. + + :param request: The HTTP request object containing the body to log + :type request: httpx.Request + """ + + # Check content-type header to identify file uploads + content_type = request.headers.get("content-type", "").lower() + if "multipart/form-data" in content_type: + print("Body: [Multipart form data - file upload, not logged]") + return + + # Safely check if content exists without accessing it + if not hasattr(request, "content"): + print("Body: [No content attribute]") + return + + # Very careful content access - wrap in try-catch immediately + try: + content = request.content + except Exception as access_error: # pylint: disable=broad-exception-caught + print(f"Body: [Cannot access content: {access_error}]") + return + + if content is None or content == b"": + print("Body: [No content]") + return + + try: + print(f"Body:\n {content.decode('utf-8')}") + except Exception: # pylint: disable=broad-exception-caught + print(f"Body (raw):\n {content!r}") + + __all__: List[str] = [ "AIProjectClient", ] # Add all objects you want publicly available to users at this package level diff --git a/sdk/ai/azure-ai-projects/azure/ai/projects/aio/_patch.py b/sdk/ai/azure-ai-projects/azure/ai/projects/aio/_patch.py index e9d163cff7e4..b7608b41e380 100644 --- a/sdk/ai/azure-ai-projects/azure/ai/projects/aio/_patch.py +++ b/sdk/ai/azure-ai-projects/azure/ai/projects/aio/_patch.py @@ -9,6 +9,7 @@ """ import os import logging +import httpx from typing import List, Any from openai import AsyncOpenAI from azure.core.tracing.decorator import distributed_trace @@ -106,8 +107,6 @@ def get_openai_client(self, **kwargs: Any) -> "AsyncOpenAI": # type: ignore[nam :return: An authenticated AsyncOpenAI client :rtype: ~openai.AsyncOpenAI - :raises ~azure.core.exceptions.ModuleNotFoundError: if the ``openai`` package - is not installed. :raises ~azure.core.exceptions.HttpResponseError: """ @@ -118,107 +117,14 @@ def get_openai_client(self, **kwargs: Any) -> "AsyncOpenAI": # type: ignore[nam base_url, ) - http_client = None - kwargs = kwargs.copy() if kwargs else {} - if self._console_logging_enabled: - try: - import httpx - except ModuleNotFoundError as e: - raise ModuleNotFoundError("Failed to import httpx. Please install it using 'pip install httpx'") from e - - class OpenAILoggingTransport(httpx.AsyncHTTPTransport): - - def _sanitize_auth_header(self, headers): - """Sanitize authorization header by redacting sensitive information. - - :param headers: Dictionary of HTTP headers to sanitize - :type headers: dict - """ - - if "authorization" in headers: - auth_value = headers["authorization"] - if len(auth_value) >= 7: - headers["authorization"] = auth_value[:7] + "" - else: - headers["authorization"] = "" - - async def handle_async_request(self, request: httpx.Request) -> httpx.Response: - """ - Log HTTP request and response details to console, in a nicely formatted way, - for OpenAI / Azure OpenAI clients. - - :param request: The HTTP request to handle and log - :type request: httpx.Request - - :return: The HTTP response received - :rtype: httpx.Response - """ - - print(f"\n==> Request:\n{request.method} {request.url}") - headers = dict(request.headers) - self._sanitize_auth_header(headers) - print("Headers:") - for key, value in sorted(headers.items()): - print(f" {key}: {value}") - - self._log_request_body(request) - - response = await super().handle_async_request(request) - - print(f"\n<== Response:\n{response.status_code} {response.reason_phrase}") - print("Headers:") - for key, value in sorted(dict(response.headers).items()): - print(f" {key}: {value}") - - content = await response.aread() - if content is None or content == b"": - print("Body: [No content]") - else: - try: - print(f"Body:\n {content.decode('utf-8')}") - except Exception: # pylint: disable=broad-exception-caught - print(f"Body (raw):\n {content!r}") - print("\n") - - return response - - def _log_request_body(self, request: httpx.Request) -> None: - """Log request body content safely, handling binary data and streaming content. - - :param request: The HTTP request object containing the body to log - :type request: httpx.Request - """ - - # Check content-type header to identify file uploads - content_type = request.headers.get("content-type", "").lower() - if "multipart/form-data" in content_type: - print("Body: [Multipart form data - file upload, not logged]") - return - - # Safely check if content exists without accessing it - if not hasattr(request, "content"): - print("Body: [No content attribute]") - return - - # Very careful content access - wrap in try-catch immediately - try: - content = request.content - except Exception as access_error: # pylint: disable=broad-exception-caught - print(f"Body: [Cannot access content: {access_error}]") - return - - if content is None or content == b"": - print("Body: [No content]") - return - - try: - print(f"Body:\n {content.decode('utf-8')}") - except Exception: # pylint: disable=broad-exception-caught - print(f"Body (raw):\n {content!r}") - + if "http_client" in kwargs: + http_client = kwargs.pop("http_client") + elif self._console_logging_enabled: http_client = httpx.AsyncClient(transport=OpenAILoggingTransport()) + else: + http_client = None default_headers = dict[str, str](kwargs.pop("default_headers", None) or {}) @@ -256,6 +162,107 @@ def _create_openai_client(**kwargs) -> AsyncOpenAI: return client +class OpenAILoggingTransport(httpx.AsyncHTTPTransport): + """Custom HTTP async transport that logs OpenAI API requests and responses to the console. + + This transport wraps httpx.AsyncHTTPTransport to intercept all HTTP traffic and print + detailed request/response information for debugging purposes. It automatically + redacts sensitive authorization headers and handles various content types including + multipart form data (file uploads). + + Used internally by AIProjectClient when console logging is enabled via the + AZURE_AI_PROJECTS_CONSOLE_LOGGING environment variable. + """ + + def _sanitize_auth_header(self, headers): + """Sanitize authorization header by redacting sensitive information. + + :param headers: Dictionary of HTTP headers to sanitize + :type headers: dict + """ + + if "authorization" in headers: + auth_value = headers["authorization"] + if len(auth_value) >= 7: + headers["authorization"] = auth_value[:7] + "" + else: + headers["authorization"] = "" + + async def handle_async_request(self, request: httpx.Request) -> httpx.Response: + """ + Log HTTP request and response details to console, in a nicely formatted way, + for OpenAI / Azure OpenAI clients. + + :param request: The HTTP request to handle and log + :type request: httpx.Request + + :return: The HTTP response received + :rtype: httpx.Response + """ + + print(f"\n==> Request:\n{request.method} {request.url}") + headers = dict(request.headers) + self._sanitize_auth_header(headers) + print("Headers:") + for key, value in sorted(headers.items()): + print(f" {key}: {value}") + + self._log_request_body(request) + + response = await super().handle_async_request(request) + + print(f"\n<== Response:\n{response.status_code} {response.reason_phrase}") + print("Headers:") + for key, value in sorted(dict(response.headers).items()): + print(f" {key}: {value}") + + content = await response.aread() + if content is None or content == b"": + print("Body: [No content]") + else: + try: + print(f"Body:\n {content.decode('utf-8')}") + except Exception: # pylint: disable=broad-exception-caught + print(f"Body (raw):\n {content!r}") + print("\n") + + return response + + def _log_request_body(self, request: httpx.Request) -> None: + """Log request body content safely, handling binary data and streaming content. + + :param request: The HTTP request object containing the body to log + :type request: httpx.Request + """ + + # Check content-type header to identify file uploads + content_type = request.headers.get("content-type", "").lower() + if "multipart/form-data" in content_type: + print("Body: [Multipart form data - file upload, not logged]") + return + + # Safely check if content exists without accessing it + if not hasattr(request, "content"): + print("Body: [No content attribute]") + return + + # Very careful content access - wrap in try-catch immediately + try: + content = request.content + except Exception as access_error: # pylint: disable=broad-exception-caught + print(f"Body: [Cannot access content: {access_error}]") + return + + if content is None or content == b"": + print("Body: [No content]") + return + + try: + print(f"Body:\n {content.decode('utf-8')}") + except Exception: # pylint: disable=broad-exception-caught + print(f"Body (raw):\n {content!r}") + + __all__: List[str] = ["AIProjectClient"] # Add all objects you want publicly available to users at this package level diff --git a/sdk/ai/azure-ai-projects/tests/responses/test_responses_with_http_client_override.py b/sdk/ai/azure-ai-projects/tests/responses/test_responses_with_http_client_override.py new file mode 100644 index 000000000000..46da53d9994f --- /dev/null +++ b/sdk/ai/azure-ai-projects/tests/responses/test_responses_with_http_client_override.py @@ -0,0 +1,104 @@ +# pylint: disable=line-too-long,useless-suppression +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +""" +Tests to verify that a custom http_client can be passed to get_openai_client() +and that the returned OpenAI client uses it instead of the default one. +""" + +import os +import pytest +import httpx +from typing import Any +from azure.core.credentials import TokenCredential +from azure.ai.projects import AIProjectClient + + +class DummyTokenCredential(TokenCredential): + """A dummy credential that returns None for testing purposes.""" + + def get_token(self, *scopes: str, **kwargs: Any): # type: ignore[override] + return None + + +@pytest.fixture(autouse=True) +def patch_openai(monkeypatch): + """Ensure no real network/token calls are made during the test.""" + monkeypatch.setattr("azure.ai.projects._patch.get_bearer_token_provider", lambda *_, **__: "token-provider") + + +@pytest.mark.skipif( + os.environ.get("AZURE_AI_PROJECTS_CONSOLE_LOGGING", "false").lower() == "true", + reason="Test skipped because AZURE_AI_PROJECTS_CONSOLE_LOGGING is set to 'true'", +) +class TestResponsesWithHttpClientOverride: + """Tests for custom http_client override in get_openai_client().""" + + def test_custom_http_client_is_used(self): + """ + Test that a custom http_client passed to get_openai_client() is actually used + by the returned OpenAI client when making API calls. + """ + # Track whether our custom http_client was invoked + request_intercepted = {"called": False, "request": None} + + class TrackingTransport(httpx.BaseTransport): + """Custom transport that tracks requests and returns mock responses.""" + + def handle_request(self, request: httpx.Request) -> httpx.Response: + # Mark that our custom transport was called + request_intercepted["called"] = True + request_intercepted["request"] = request + + # Return a mock response for the OpenAI responses.create() call + return httpx.Response( + 200, + request=request, + json={ + "id": "resp_test_123", + "output": [ + { + "type": "message", + "id": "msg_test_123", + "role": "assistant", + "content": [ + { + "type": "output_text", + "text": "This is a test response from the mock.", + } + ], + } + ], + }, + ) + + # Create a custom http_client with our tracking transport + custom_http_client = httpx.Client(transport=TrackingTransport()) + + # Create the AIProjectClient + project_client = AIProjectClient( + endpoint="https://example.com/api/projects/test", + credential=DummyTokenCredential(), + ) + + # Get an OpenAI client with our custom http_client + openai_client = project_client.get_openai_client(http_client=custom_http_client) + + # Make an API call + response = openai_client.responses.create( + model="gpt-4o", + input="Test input", + ) + + # Verify the custom http_client was used + assert request_intercepted["called"], "Custom http_client was not used for the request" + assert request_intercepted["request"] is not None, "Request was not captured" + + # Verify the request was made to the expected endpoint + assert "/openai/v1/responses" in str(request_intercepted["request"].url) + + # Verify we got a valid response + assert response.id == "resp_test_123" + assert response.output_text == "This is a test response from the mock." diff --git a/sdk/ai/azure-ai-projects/tests/responses/test_responses_with_http_client_override_async.py b/sdk/ai/azure-ai-projects/tests/responses/test_responses_with_http_client_override_async.py new file mode 100644 index 000000000000..ea6ab052613b --- /dev/null +++ b/sdk/ai/azure-ai-projects/tests/responses/test_responses_with_http_client_override_async.py @@ -0,0 +1,108 @@ +# pylint: disable=line-too-long,useless-suppression +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +""" +Tests to verify that a custom http_client can be passed to get_openai_client() +and that the returned AsyncOpenAI client uses it instead of the default one. +""" + +import os +import pytest +import httpx +from typing import Any +from azure.core.credentials_async import AsyncTokenCredential +from azure.ai.projects.aio import AIProjectClient + + +class DummyAsyncTokenCredential(AsyncTokenCredential): + """A dummy async credential that returns None for testing purposes.""" + + async def get_token(self, *scopes: str, **kwargs: Any): # type: ignore[override] + return None + + async def close(self) -> None: + pass + + +@pytest.fixture(autouse=True) +def patch_openai(monkeypatch): + """Ensure no real network/token calls are made during the test.""" + monkeypatch.setattr("azure.ai.projects.aio._patch.get_bearer_token_provider", lambda *_, **__: "token-provider") + + +@pytest.mark.skipif( + os.environ.get("AZURE_AI_PROJECTS_CONSOLE_LOGGING", "false").lower() == "true", + reason="Test skipped because AZURE_AI_PROJECTS_CONSOLE_LOGGING is set to 'true'", +) +class TestResponsesWithHttpClientOverrideAsync: + """Tests for custom http_client override in async get_openai_client().""" + + @pytest.mark.asyncio + async def test_custom_http_client_is_used(self): + """ + Test that a custom http_client passed to get_openai_client() is actually used + by the returned AsyncOpenAI client when making API calls. + """ + # Track whether our custom http_client was invoked + request_intercepted = {"called": False, "request": None} + + class TrackingTransport(httpx.AsyncBaseTransport): + """Custom async transport that tracks requests and returns mock responses.""" + + async def handle_async_request(self, request: httpx.Request) -> httpx.Response: + # Mark that our custom transport was called + request_intercepted["called"] = True + request_intercepted["request"] = request + + # Return a mock response for the OpenAI responses.create() call + return httpx.Response( + 200, + request=request, + json={ + "id": "resp_test_123", + "output": [ + { + "type": "message", + "id": "msg_test_123", + "role": "assistant", + "content": [ + { + "type": "output_text", + "text": "This is a test response from the mock.", + } + ], + } + ], + }, + ) + + # Create a custom http_client with our tracking transport + custom_http_client = httpx.AsyncClient(transport=TrackingTransport()) + + # Create the AIProjectClient + project_client = AIProjectClient( + endpoint="https://example.com/api/projects/test", + credential=DummyAsyncTokenCredential(), + ) + + # Get an AsyncOpenAI client with our custom http_client + openai_client = project_client.get_openai_client(http_client=custom_http_client) + + # Make an API call + response = await openai_client.responses.create( + model="gpt-4o", + input="Test input", + ) + + # Verify the custom http_client was used + assert request_intercepted["called"], "Custom http_client was not used for the request" + assert request_intercepted["request"] is not None, "Request was not captured" + + # Verify the request was made to the expected endpoint + assert "/openai/v1/responses" in str(request_intercepted["request"].url) + + # Verify we got a valid response + assert response.id == "resp_test_123" + assert response.output_text == "This is a test response from the mock."