56 changes: 51 additions & 5 deletions src/google/adk/models/google_llm.py
@@ -19,15 +19,16 @@
import copy
from functools import cached_property
import logging
from typing import Any
from typing import AsyncGenerator
from typing import cast
from typing import Optional
from typing import TYPE_CHECKING
from typing import Union

from google.genai import Client
from google.genai import types
from google.genai.errors import ClientError
from pydantic import Field
from typing_extensions import override

from ..utils._client_labels_utils import get_client_labels
@@ -40,8 +41,6 @@
from .llm_response import LlmResponse

if TYPE_CHECKING:
  from google.genai import Client

  from .llm_request import LlmRequest

logger = logging.getLogger('google_adk.' + __name__)
@@ -86,6 +85,8 @@ class Gemini(BaseLlm):
    model: The name of the Gemini model.
    use_interactions_api: Whether to use the interactions API for model
      invocation.
    custom_api_client: Custom client for standard API calls.
    custom_live_api_client: Custom client for Live API streaming.
  """

  model: str = 'gemini-2.5-flash'
@@ -127,6 +128,49 @@ class Gemini(BaseLlm):
  ```
  """

  custom_api_client: Optional[Client] = Field(
      default=None, exclude=True, frozen=True, repr=False
  )
  """Custom API client for generate_content operations.

  Allows injecting a custom Google GenAI Client instance to override the
  default api_client. Useful for testing, custom authentication, or using
  different configurations. When set, this client is used for all
  generate_content_async and interactions API calls.

  Sample:
  ```python
  from google.genai import Client

  custom_client = Client(api_key="custom_key")
  agent = Agent(
      model=Gemini(custom_api_client=custom_client)
  )
  ```
  """

  custom_live_api_client: Optional[Client] = Field(
      default=None, exclude=True, frozen=True, repr=False
  )
  """Custom client for Live API (bi-directional streaming) operations.

  Allows injecting a custom Google GenAI Client for ADK Live streaming. When
  set, this client is used for all live.connect() calls. The client should be
  configured with the appropriate API version for your backend (v1beta1 for
  Vertex AI, v1alpha for the Gemini API).

  Sample:
  ```python
  from google.genai import Client, types

  live_client = Client(
      http_options=types.HttpOptions(api_version="v1beta1")
  )
  agent = Agent(
      model=Gemini(custom_live_api_client=live_client)
  )
  ```
  """

  @classmethod
  @override
  def supported_models(cls) -> list[str]:
@@ -298,7 +342,8 @@ def api_client(self) -> Client:
    Returns:
      The api client.
    """
    from google.genai import Client
    if self.custom_api_client:
      return self.custom_api_client

    return Client(
        http_options=types.HttpOptions(
@@ -335,7 +380,8 @@ def _live_api_version(self) -> str:

  @cached_property
  def _live_api_client(self) -> Client:
    from google.genai import Client
    if self.custom_live_api_client:
      return self.custom_live_api_client

    return Client(
        http_options=types.HttpOptions(
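Taken together, the two new fields let a caller preconfigure both transport paths. A minimal end-to-end sketch of how they might be wired up, assuming a Vertex AI backend — the project, location, and agent name below are placeholders, not values from this PR:

```python
from google.genai import Client, types

from google.adk.agents import Agent
from google.adk.models.google_llm import Gemini

# Client for standard generate_content calls (placeholder project/location).
api_client = Client(
    vertexai=True, project="my-project", location="us-central1"
)

# Live API client pinned to v1beta1, the version the Vertex AI Live
# backend expects (v1alpha for the Gemini API).
live_client = Client(
    vertexai=True,
    project="my-project",
    location="us-central1",
    http_options=types.HttpOptions(api_version="v1beta1"),
)

agent = Agent(
    name="assistant",  # placeholder agent name
    model=Gemini(
        model="gemini-2.5-flash",
        custom_api_client=api_client,
        custom_live_api_client=live_client,
    ),
)
```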
73 changes: 73 additions & 0 deletions tests/unittests/models/test_google_llm.py
@@ -2139,3 +2139,76 @@ async def __aexit__(self, *args):
  # Verify the final speech_config is still None
  assert config_arg.speech_config is None
  assert isinstance(connection, GeminiLlmConnection)


@pytest.mark.asyncio
async def test_custom_api_client_is_used_for_generate_content(
    llm_request, generate_content_response
):
  """Test that custom_api_client is used when provided."""
  from google.genai import Client

  # Create a mock custom client with proper spec
  custom_client = mock.MagicMock(spec=Client)

  # Create a mock coroutine that returns the generate_content_response
  async def mock_coro():
    return generate_content_response

  custom_client.aio.models.generate_content.return_value = mock_coro()

  # Create Gemini instance with custom_api_client
  gemini_llm = Gemini(model="gemini-1.5-flash", custom_api_client=custom_client)

  # Execute generate_content_async
  responses = [
      resp
      async for resp in gemini_llm.generate_content_async(
          llm_request, stream=False
      )
  ]

  # Verify that the custom client was used
  assert len(responses) == 1
  assert isinstance(responses[0], LlmResponse)
  custom_client.aio.models.generate_content.assert_called_once()

  # Verify that api_client property returns the custom client
  assert gemini_llm.api_client is custom_client
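One caveat in the test above: assigning `mock_coro()` to `return_value` stores a single coroutine object, which can only be awaited once; a second call to the mocked method would fail with a "cannot reuse already awaited coroutine" error. If a test needed repeated calls, an async `side_effect` is the usual pattern — a minimal sketch reusing the same mock and fixture names as the test:

```python
# An async side_effect builds a fresh coroutine per call, so the mocked
# method can be awaited any number of times.
async def _fake_generate_content(*args, **kwargs):
  return generate_content_response

custom_client.aio.models.generate_content.side_effect = _fake_generate_content
```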


@pytest.mark.asyncio
async def test_custom_live_api_client_is_used_for_connect(llm_request):
  """Test that custom_live_api_client is used when provided."""
  from google.genai import Client

  # Create a mock custom live client with proper spec
  custom_live_client = mock.MagicMock(spec=Client)
  mock_live_session = mock.AsyncMock()

  class MockLiveConnect:

    async def __aenter__(self):
      return mock_live_session

    async def __aexit__(self, *args):
      pass

  custom_live_client.aio.live.connect.return_value = MockLiveConnect()

  # Setup live connect config
  llm_request.live_connect_config = types.LiveConnectConfig()

  # Create Gemini instance with custom_live_api_client
  gemini_llm = Gemini(
      model="gemini-1.5-flash", custom_live_api_client=custom_live_client
  )

  # Execute connect
  async with gemini_llm.connect(llm_request) as connection:
    # Verify that the custom live client was used
    custom_live_client.aio.live.connect.assert_called_once()
    assert isinstance(connection, GeminiLlmConnection)

  # Verify that _live_api_client property returns the custom client
  assert gemini_llm._live_api_client is custom_live_client
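Because `api_client` and `_live_api_client` are both `cached_property`s and the new fields are declared `frozen=True`, client resolution is fixed at first access. A sketch of a test pinning that down — the test name and body are illustrative, not part of this PR:

```python
def test_api_client_is_cached():
  """Sketch: api_client resolves the custom client once and caches it."""
  from google.genai import Client

  custom_client = mock.MagicMock(spec=Client)
  gemini_llm = Gemini(model="gemini-1.5-flash", custom_api_client=custom_client)

  # cached_property: repeated accesses return the same object.
  assert gemini_llm.api_client is gemini_llm.api_client
  assert gemini_llm.api_client is custom_client
```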