@@ -7,7 +7,8 @@
 import json
 import logging
 import mimetypes
-from typing import Any, AsyncGenerator, Optional, Protocol, Type, TypedDict, TypeVar, Union, cast
+from contextlib import asynccontextmanager
+from typing import Any, AsyncGenerator, AsyncIterator, Optional, Protocol, Type, TypedDict, TypeVar, Union, cast
 
 import openai
 from openai.types.chat.parsed_chat_completion import ParsedChatCompletion
@@ -55,16 +56,39 @@ class OpenAIConfig(TypedDict, total=False):
         model_id: str
         params: Optional[dict[str, Any]]
 
-    def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config: Unpack[OpenAIConfig]) -> None:
+    def __init__(
+        self,
+        client: Optional[Client] = None,
+        client_args: Optional[dict[str, Any]] = None,
+        **model_config: Unpack[OpenAIConfig],
+    ) -> None:
         """Initialize provider instance.
 
         Args:
-            client_args: Arguments for the OpenAI client.
+            client: Pre-configured OpenAI-compatible client to reuse across requests.
+                When provided, this client will be reused for all requests and will NOT be closed
+                by the model. The caller is responsible for managing the client lifecycle.
+                This is useful for:
+                - Injecting custom client wrappers (e.g., GuardrailsAsyncOpenAI)
+                - Reusing connection pools within a single event loop/worker
+                - Centralizing observability, retries, and networking policy
+                - Pointing to custom model gateways
+                Note: The client should not be shared across different asyncio event loops.
+            client_args: Arguments for the OpenAI client (legacy approach).
                 For a complete list of supported arguments, see https://pypi.org/project/openai/.
             **model_config: Configuration options for the OpenAI model.
+
+        Raises:
+            ValueError: If both `client` and `client_args` are provided.
         """
         validate_config_keys(model_config, self.OpenAIConfig)
         self.config = dict(model_config)
+
+        # Validate that only one client configuration method is provided
+        if client is not None and client_args is not None and len(client_args) > 0:
+            raise ValueError("Only one of 'client' or 'client_args' should be provided, not both.")
+
+        self._custom_client = client
         self.client_args = client_args or {}
 
         logger.debug("config=<%s> | initializing", self.config)
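
A minimal usage sketch of the two construction paths this `__init__` now supports, assuming the enclosing class is `OpenAIModel` (as the `OpenAIConfig` TypedDict suggests); the model ID and gateway URL are illustrative:

import openai

# Legacy path: the model builds a fresh AsyncOpenAI client from these args on every request.
model = OpenAIModel(client_args={"api_key": "sk-..."}, model_id="gpt-4o")

# Injection path: the caller constructs and owns the client; the model never closes it.
custom_client = openai.AsyncOpenAI(api_key="sk-...", base_url="https://gateway.example.com/v1")
model = OpenAIModel(client=custom_client, model_id="gpt-4o")

# Passing both `client` and a non-empty `client_args` raises ValueError.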
@@ -422,6 +446,34 @@ def format_chunk(self, event: dict[str, Any], **kwargs: Any) -> StreamEvent:
             case _:
                 raise RuntimeError(f"chunk_type=<{event['chunk_type']} | unknown type")
 
+    @asynccontextmanager
+    async def _get_client(self) -> AsyncIterator[Any]:
+        """Get an OpenAI client for making requests.
+
+        This context manager handles client lifecycle management:
+        - If an injected client was provided during initialization, it yields that client
+          without closing it (caller manages lifecycle).
+        - Otherwise, creates a new AsyncOpenAI client from client_args and automatically
+          closes it when the context exits.
+
+        Note: We create a new client per request to avoid connection sharing in the underlying
+        httpx client, as the asyncio event loop does not allow connections to be shared.
+        For more details, see https://github.com/encode/httpx/discussions/2959.
+
+        Yields:
+            Client: An OpenAI-compatible client instance.
+        """
+        if self._custom_client is not None:
+            # Use the injected client (caller manages lifecycle)
+            yield self._custom_client
+        else:
+            # Create a new client from client_args
+            # We initialize an OpenAI context on every request so as to avoid connection sharing in the underlying
+            # httpx client. The asyncio event loop does not allow connections to be shared. For more details, please
+            # refer to https://github.com/encode/httpx/discussions/2959.
+            async with openai.AsyncOpenAI(**self.client_args) as client:
+                yield client
+
     @override
     async def stream(
         self,
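
The `_get_client` helper is an instance of a borrow-or-own resource pattern: a single `@asynccontextmanager` either lends out a caller-owned client untouched or creates and disposes an ephemeral one. A self-contained sketch of the same shape (the class and method names here are invented for illustration):

from contextlib import asynccontextmanager
from typing import Any, AsyncIterator, Optional

import openai


class ClientSource:
    """Borrow an injected client, or own a per-request one."""

    def __init__(self, client: Optional[Any] = None, client_args: Optional[dict[str, Any]] = None) -> None:
        self._custom_client = client
        self._client_args = client_args or {}

    @asynccontextmanager
    async def acquire(self) -> AsyncIterator[Any]:
        if self._custom_client is not None:
            # Borrowed: yielded as-is; exiting the context does not close it.
            yield self._custom_client
        else:
            # Owned: created per request and closed on exit, so httpx connections
            # are never reused across asyncio event loops.
            async with openai.AsyncOpenAI(**self._client_args) as client:
                yield client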
@@ -457,7 +509,7 @@ async def stream(
         # We initialize an OpenAI context on every request so as to avoid connection sharing in the underlying httpx
         # client. The asyncio event loop does not allow connections to be shared. For more details, please refer to
         # https://github.com/encode/httpx/discussions/2959.
-        async with openai.AsyncOpenAI(**self.client_args) as client:
+        async with self._get_client() as client:
             try:
                 response = await client.chat.completions.create(**request)
             except openai.BadRequestError as e:
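
The docstring's event-loop warning, made concrete: an injected client binds its httpx connection pool to the loop it first awaits on, so it should be created inside the running loop and not reused across separate `asyncio.run` calls. A hedged sketch, again assuming the class is `OpenAIModel`, with the actual requests elided:

import asyncio

import openai


async def main() -> None:
    # Created inside the running loop; reused for every request made within it.
    async with openai.AsyncOpenAI() as client:
        model = OpenAIModel(client=client, model_id="gpt-4o")
        ...  # stream requests through `model` here


asyncio.run(main())  # a second asyncio.run(...) should construct a new client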
@@ -576,7 +628,7 @@ async def structured_output(
         # We initialize an OpenAI context on every request so as to avoid connection sharing in the underlying httpx
         # client. The asyncio event loop does not allow connections to be shared. For more details, please refer to
         # https://github.com/encode/httpx/discussions/2959.
-        async with openai.AsyncOpenAI(**self.client_args) as client:
+        async with self._get_client() as client:
             try:
                 response: ParsedChatCompletion = await client.beta.chat.completions.parse(
                     model=self.get_config()["model_id"],
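
For reference, `client.beta.chat.completions.parse` accepts a Pydantic model as `response_format` and returns a `ParsedChatCompletion` whose parsed object lives on the message. A minimal sketch of the call now routed through `_get_client` (the `Weather` model and prompt are invented):

from typing import Optional

import openai
from pydantic import BaseModel


class Weather(BaseModel):
    city: str
    temperature_c: float


async def demo() -> Optional[Weather]:
    async with openai.AsyncOpenAI() as client:
        response = await client.beta.chat.completions.parse(
            model="gpt-4o",
            messages=[{"role": "user", "content": "What's the weather in Paris?"}],
            response_format=Weather,
        )
    # .parsed is None if the model refused or parsing failed.
    return response.choices[0].message.parsed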