
Commit 3c6a6a9

Merge branch 'main' into api-check
2 parents: 548caba + c1c24ef

9 files changed (+267, -20 lines)

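At a glance: this merge bumps GitHub Actions versions (checkout v5 → v6, upload-artifact v4 → v6, download-artifact v5 → v7), threads an optional CODECOV_TOKEN secret through the reusable test-lint workflow, relaxes the pre-commit and pytest-asyncio upper bounds in pyproject.toml, and adds optional injection of a pre-configured client to the Gemini and OpenAI model providers, with new Gemini tests covering the injected-client path.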

.github/workflows/integration-test.yml

Lines changed: 1 addition & 1 deletion
```diff
@@ -52,7 +52,7 @@ jobs:
           aws-region: us-east-1
           mask-aws-account-id: true
       - name: Checkout head commit
-        uses: actions/checkout@v5
+        uses: actions/checkout@v6
        with:
          ref: ${{ github.event.pull_request.head.sha }} # Pull the commit from the forked repo
          persist-credentials: false # Don't persist credentials for subsequent actions
```

.github/workflows/pr-and-push.yml

Lines changed: 3 additions & 0 deletions
```diff
@@ -17,6 +17,8 @@ jobs:
       contents: read
     with:
       ref: ${{ github.event.pull_request.head.sha }}
+    secrets:
+      CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
 
   check-api:
     runs-on: ubuntu-latest
@@ -35,3 +37,4 @@ jobs:
             echo "Breaking API changes detected"
             exit 1
           fi
+
```
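A job that calls a reusable workflow does not expose repository secrets implicitly, so the new `secrets:` block explicitly forwards CODECOV_TOKEN to the called workflow; the matching optional-secret declaration is added in test-lint.yml below.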

.github/workflows/pypi-publish-on-release.yml

Lines changed: 5 additions & 3 deletions
```diff
@@ -12,6 +12,8 @@ jobs:
       contents: read
     with:
       ref: ${{ github.event.release.target_commitish }}
+    secrets:
+      CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
 
   build:
     name: Build distribution 📦
@@ -22,7 +24,7 @@ jobs:
     runs-on: ubuntu-latest
 
     steps:
-      - uses: actions/checkout@v5
+      - uses: actions/checkout@v6
        with:
          persist-credentials: false
 
@@ -52,7 +54,7 @@ jobs:
          hatch build
 
       - name: Store the distribution packages
-        uses: actions/upload-artifact@v4
+        uses: actions/upload-artifact@v6
        with:
          name: python-package-distributions
          path: dist/
@@ -74,7 +76,7 @@ jobs:
 
     steps:
       - name: Download all the dists
-        uses: actions/download-artifact@v5
+        uses: actions/download-artifact@v7
        with:
          name: python-package-distributions
          path: dist/
```

.github/workflows/test-lint.yml

Lines changed: 5 additions & 2 deletions
```diff
@@ -6,6 +6,9 @@ on:
       ref:
         required: true
         type: string
+    secrets:
+      CODECOV_TOKEN:
+        required: false
 
 jobs:
   unit-test:
@@ -51,7 +54,7 @@ jobs:
       LOG_LEVEL: DEBUG
     steps:
       - name: Checkout code
-        uses: actions/checkout@v5
+        uses: actions/checkout@v6
        with:
          ref: ${{ inputs.ref }} # Explicitly define which commit to check out
          persist-credentials: false # Don't persist credentials for subsequent actions
@@ -92,7 +95,7 @@ jobs:
       contents: read
     steps:
       - name: Checkout code
-        uses: actions/checkout@v5
+        uses: actions/checkout@v6
        with:
          ref: ${{ inputs.ref }}
          persist-credentials: false
```

pyproject.toml

Lines changed: 4 additions & 4 deletions
```diff
@@ -87,10 +87,10 @@ dev = [
     "hatch>=1.0.0,<2.0.0",
     "moto>=5.1.0,<6.0.0",
     "mypy>=1.15.0,<2.0.0",
-    "pre-commit>=3.2.0,<4.4.0",
+    "pre-commit>=3.2.0,<4.6.0",
     "pytest>=8.0.0,<9.0.0",
     "pytest-cov>=7.0.0,<8.0.0",
-    "pytest-asyncio>=1.0.0,<1.3.0",
+    "pytest-asyncio>=1.0.0,<1.4.0",
     "pytest-xdist>=3.0.0,<4.0.0",
     "ruff>=0.13.0,<0.14.0",
 ]
@@ -144,7 +144,7 @@ extra-args = ["-n", "auto", "-vv"]
 dependencies = [
     "pytest>=8.0.0,<9.0.0",
     "pytest-cov>=7.0.0,<8.0.0",
-    "pytest-asyncio>=1.0.0,<1.3.0",
+    "pytest-asyncio>=1.0.0,<1.4.0",
     "pytest-xdist>=3.0.0,<4.0.0",
     "moto>=5.1.0,<6.0.0",
 ]
@@ -166,7 +166,7 @@ features = ["all"]
 dependencies = [
     "commitizen>=4.4.0,<5.0.0",
     "hatch>=1.0.0,<2.0.0",
-    "pre-commit>=3.2.0,<4.4.0",
+    "pre-commit>=3.2.0,<4.6.0",
 ]
 
 
```
src/strands/models/gemini.py

Lines changed: 40 additions & 4 deletions
```diff
@@ -54,27 +54,44 @@ class GeminiConfig(TypedDict, total=False):
     def __init__(
         self,
         *,
+        client: Optional[genai.Client] = None,
         client_args: Optional[dict[str, Any]] = None,
         **model_config: Unpack[GeminiConfig],
     ) -> None:
         """Initialize provider instance.
 
         Args:
+            client: Pre-configured Gemini client to reuse across requests.
+                When provided, this client will be reused for all requests and will NOT be closed
+                by the model. The caller is responsible for managing the client lifecycle.
+                This is useful for:
+                - Injecting custom client wrappers
+                - Reusing connection pools within a single event loop/worker
+                - Centralizing observability, retries, and networking policy
+                Note: The client should not be shared across different asyncio event loops.
             client_args: Arguments for the underlying Gemini client (e.g., api_key).
                 For a complete list of supported arguments, see https://googleapis.github.io/python-genai/.
             **model_config: Configuration options for the Gemini model.
+
+        Raises:
+            ValueError: If both `client` and `client_args` are provided.
         """
         validate_config_keys(model_config, GeminiModel.GeminiConfig)
         self.config = GeminiModel.GeminiConfig(**model_config)
 
+        # Validate that only one client configuration method is provided
+        if client is not None and client_args is not None and len(client_args) > 0:
+            raise ValueError("Only one of 'client' or 'client_args' should be provided, not both.")
+
+        self._custom_client = client
+        self.client_args = client_args or {}
+
         # Validate gemini_tools if provided
         if "gemini_tools" in self.config:
             self._validate_gemini_tools(self.config["gemini_tools"])
 
         logger.debug("config=<%s> | initializing", self.config)
 
-        self.client_args = client_args or {}
-
     @override
     def update_config(self, **model_config: Unpack[GeminiConfig]) -> None:  # type: ignore[override]
         """Update the Gemini model configuration with the provided arguments.
@@ -97,6 +114,24 @@ def get_config(self) -> GeminiConfig:
         """
         return self.config
 
+    def _get_client(self) -> genai.Client:
+        """Get a Gemini client for making requests.
+
+        This method handles client lifecycle management:
+        - If an injected client was provided during initialization, it returns that client
+          without managing its lifecycle (caller is responsible for cleanup).
+        - Otherwise, creates a new genai.Client from client_args.
+
+        Returns:
+            genai.Client: A Gemini client instance.
+        """
+        if self._custom_client is not None:
+            # Use the injected client (caller manages lifecycle)
+            return self._custom_client
+        else:
+            # Create a new client from client_args
+            return genai.Client(**self.client_args)
+
     def _format_request_content_part(self, content: ContentBlock) -> genai.types.Part:
         """Format content block into a Gemini part instance.
 
@@ -382,7 +417,8 @@ async def stream(
         """
         request = self._format_request(messages, tool_specs, system_prompt, self.config.get("params"))
 
-        client = genai.Client(**self.client_args).aio
+        client = self._get_client().aio
+
         try:
             response = await client.models.generate_content_stream(**request)
 
@@ -465,7 +501,7 @@ async def structured_output(
             "response_schema": output_model.model_json_schema(),
         }
         request = self._format_request(prompt, None, system_prompt, params)
-        client = genai.Client(**self.client_args).aio
+        client = self._get_client().aio
         response = await client.models.generate_content(**request)
         yield {"output": output_model.model_validate(response.parsed)}
 
```
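For orientation, a minimal sketch of the two construction paths added above (the model id and API key are illustrative placeholders, not from this diff):

```python
from google import genai

from strands.models.gemini import GeminiModel

# Existing path: the model builds a fresh genai.Client from client_args per request.
model = GeminiModel(client_args={"api_key": "YOUR_KEY"}, model_id="gemini-2.5-flash")

# New path: inject a pre-configured client; the model reuses it on every request
# and never closes it, so the caller owns its lifecycle.
shared_client = genai.Client(api_key="YOUR_KEY")
model = GeminiModel(client=shared_client, model_id="gemini-2.5-flash")

# Supplying both configuration methods fails fast:
# GeminiModel(client=shared_client, client_args={"api_key": "YOUR_KEY"})  # raises ValueError
```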

src/strands/models/openai.py

Lines changed: 57 additions & 5 deletions
```diff
@@ -7,7 +7,8 @@
 import json
 import logging
 import mimetypes
-from typing import Any, AsyncGenerator, Optional, Protocol, Type, TypedDict, TypeVar, Union, cast
+from contextlib import asynccontextmanager
+from typing import Any, AsyncGenerator, AsyncIterator, Optional, Protocol, Type, TypedDict, TypeVar, Union, cast
 
 import openai
 from openai.types.chat.parsed_chat_completion import ParsedChatCompletion
@@ -55,16 +56,39 @@ class OpenAIConfig(TypedDict, total=False):
         model_id: str
         params: Optional[dict[str, Any]]
 
-    def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config: Unpack[OpenAIConfig]) -> None:
+    def __init__(
+        self,
+        client: Optional[Client] = None,
+        client_args: Optional[dict[str, Any]] = None,
+        **model_config: Unpack[OpenAIConfig],
+    ) -> None:
         """Initialize provider instance.
 
         Args:
-            client_args: Arguments for the OpenAI client.
+            client: Pre-configured OpenAI-compatible client to reuse across requests.
+                When provided, this client will be reused for all requests and will NOT be closed
+                by the model. The caller is responsible for managing the client lifecycle.
+                This is useful for:
+                - Injecting custom client wrappers (e.g., GuardrailsAsyncOpenAI)
+                - Reusing connection pools within a single event loop/worker
+                - Centralizing observability, retries, and networking policy
+                - Pointing to custom model gateways
+                Note: The client should not be shared across different asyncio event loops.
+            client_args: Arguments for the OpenAI client (legacy approach).
                 For a complete list of supported arguments, see https://pypi.org/project/openai/.
             **model_config: Configuration options for the OpenAI model.
+
+        Raises:
+            ValueError: If both `client` and `client_args` are provided.
         """
         validate_config_keys(model_config, self.OpenAIConfig)
         self.config = dict(model_config)
+
+        # Validate that only one client configuration method is provided
+        if client is not None and client_args is not None and len(client_args) > 0:
+            raise ValueError("Only one of 'client' or 'client_args' should be provided, not both.")
+
+        self._custom_client = client
         self.client_args = client_args or {}
 
         logger.debug("config=<%s> | initializing", self.config)
@@ -422,6 +446,34 @@ def format_chunk(self, event: dict[str, Any], **kwargs: Any) -> StreamEvent:
             case _:
                 raise RuntimeError(f"chunk_type=<{event['chunk_type']} | unknown type")
 
+    @asynccontextmanager
+    async def _get_client(self) -> AsyncIterator[Any]:
+        """Get an OpenAI client for making requests.
+
+        This context manager handles client lifecycle management:
+        - If an injected client was provided during initialization, it yields that client
+          without closing it (caller manages lifecycle).
+        - Otherwise, creates a new AsyncOpenAI client from client_args and automatically
+          closes it when the context exits.
+
+        Note: We create a new client per request to avoid connection sharing in the underlying
+        httpx client, as the asyncio event loop does not allow connections to be shared.
+        For more details, see https://github.com/encode/httpx/discussions/2959.
+
+        Yields:
+            Client: An OpenAI-compatible client instance.
+        """
+        if self._custom_client is not None:
+            # Use the injected client (caller manages lifecycle)
+            yield self._custom_client
+        else:
+            # Create a new client from client_args
+            # We initialize an OpenAI context on every request so as to avoid connection sharing in the underlying
+            # httpx client. The asyncio event loop does not allow connections to be shared. For more details, please
+            # refer to https://github.com/encode/httpx/discussions/2959.
+            async with openai.AsyncOpenAI(**self.client_args) as client:
+                yield client
+
     @override
     async def stream(
         self,
@@ -457,7 +509,7 @@ async def stream(
         # We initialize an OpenAI context on every request so as to avoid connection sharing in the underlying httpx
         # client. The asyncio event loop does not allow connections to be shared. For more details, please refer to
         # https://github.com/encode/httpx/discussions/2959.
-        async with openai.AsyncOpenAI(**self.client_args) as client:
+        async with self._get_client() as client:
             try:
                 response = await client.chat.completions.create(**request)
             except openai.BadRequestError as e:
@@ -576,7 +628,7 @@ async def structured_output(
         # We initialize an OpenAI context on every request so as to avoid connection sharing in the underlying httpx
         # client. The asyncio event loop does not allow connections to be shared. For more details, please refer to
         # https://github.com/encode/httpx/discussions/2959.
-        async with openai.AsyncOpenAI(**self.client_args) as client:
+        async with self._get_client() as client:
             try:
                 response: ParsedChatCompletion = await client.beta.chat.completions.parse(
                     model=self.get_config()["model_id"],
```
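Unlike the Gemini counterpart, `_get_client` here is an async context manager: a model-owned `AsyncOpenAI` client must be closed after each request, while an injected client must be left open. A minimal caller-side sketch (the gateway URL, key, and model id are illustrative placeholders, not from this diff):

```python
import openai

from strands.models.openai import OpenAIModel

# Injected client: reused across requests and never closed by the model;
# close it yourself (await shared.close()) when you are done with it.
shared = openai.AsyncOpenAI(base_url="https://my-gateway.example/v1", api_key="YOUR_KEY")
model = OpenAIModel(client=shared, model_id="gpt-4o")

# Legacy path: a fresh AsyncOpenAI client is created and closed per request,
# which avoids httpx connection sharing across asyncio event loops.
model = OpenAIModel(client_args={"api_key": "YOUR_KEY"}, model_id="gpt-4o")
```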

tests/strands/models/test_gemini.py

Lines changed: 74 additions & 0 deletions
```diff
@@ -720,3 +720,77 @@ async def test_stream_handles_non_json_error(gemini_client, model, messages, caplog):
 
     assert "Gemini API returned non-JSON error" in caplog.text
     assert f"error_message=<{error_message}>" in caplog.text
+
+
+@pytest.mark.asyncio
+async def test_stream_with_injected_client(model_id, agenerator, alist):
+    """Test that stream works with an injected client and doesn't close it."""
+    # Create a mock injected client
+    mock_injected_client = unittest.mock.Mock()
+    mock_injected_client.aio = unittest.mock.AsyncMock()
+
+    mock_injected_client.aio.models.generate_content_stream.return_value = agenerator(
+        [
+            genai.types.GenerateContentResponse(
+                candidates=[
+                    genai.types.Candidate(
+                        content=genai.types.Content(
+                            parts=[genai.types.Part(text="Hello")],
+                        ),
+                        finish_reason="STOP",
+                    ),
+                ],
+                usage_metadata=genai.types.GenerateContentResponseUsageMetadata(
+                    prompt_token_count=1,
+                    total_token_count=3,
+                ),
+            ),
+        ]
+    )
+
+    # Create model with injected client
+    model = GeminiModel(client=mock_injected_client, model_id=model_id)
+
+    messages = [{"role": "user", "content": [{"text": "test"}]}]
+    response = model.stream(messages)
+    tru_events = await alist(response)
+
+    # Verify events were generated
+    assert len(tru_events) > 0
+
+    # Verify the injected client was used
+    mock_injected_client.aio.models.generate_content_stream.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_structured_output_with_injected_client(model_id, weather_output, alist):
+    """Test that structured_output works with an injected client and doesn't close it."""
+    # Create a mock injected client
+    mock_injected_client = unittest.mock.Mock()
+    mock_injected_client.aio = unittest.mock.AsyncMock()
+
+    mock_injected_client.aio.models.generate_content.return_value = unittest.mock.Mock(
+        parsed=weather_output.model_dump()
+    )
+
+    # Create model with injected client
+    model = GeminiModel(client=mock_injected_client, model_id=model_id)
+
+    messages = [{"role": "user", "content": [{"text": "Generate weather"}]}]
+    stream = model.structured_output(type(weather_output), messages)
+    events = await alist(stream)
+
+    # Verify output was generated
+    assert len(events) == 1
+    assert events[0] == {"output": weather_output}
+
+    # Verify the injected client was used
+    mock_injected_client.aio.models.generate_content.assert_called_once()
+
+
+def test_init_with_both_client_and_client_args_raises_error():
+    """Test that providing both client and client_args raises ValueError."""
+    mock_client = unittest.mock.Mock()
+
+    with pytest.raises(ValueError, match="Only one of 'client' or 'client_args' should be provided"):
+        GeminiModel(client=mock_client, client_args={"api_key": "test"}, model_id="test-model")
```
