diff --git a/google/genai/_interactions/__init__.py b/google/genai/_interactions/__init__.py index 02d45e4d6..03f065d92 100644 --- a/google/genai/_interactions/__init__.py +++ b/google/genai/_interactions/__init__.py @@ -53,6 +53,7 @@ ) from ._base_client import DefaultHttpxClient, DefaultAioHttpClient, DefaultAsyncHttpxClient from ._utils._logs import setup_logging as _setup_logging +from ._client_adapter import GeminiNextGenAPIClientAdapter, AsyncGeminiNextGenAPIClientAdapter __all__ = [ "types", @@ -96,6 +97,8 @@ "DefaultHttpxClient", "DefaultAsyncHttpxClient", "DefaultAioHttpClient", + "AsyncGeminiNextGenAPIClientAdapter", + "GeminiNextGenAPIClientAdapter" ] if not _t.TYPE_CHECKING: diff --git a/google/genai/_interactions/_client.py b/google/genai/_interactions/_client.py index 4a67b877e..64bc99587 100644 --- a/google/genai/_interactions/_client.py +++ b/google/genai/_interactions/_client.py @@ -37,6 +37,7 @@ ) from ._utils import is_given, get_async_library from ._compat import cached_property +from ._models import FinalRequestOptions from ._version import __version__ from ._streaming import Stream as Stream, AsyncStream as AsyncStream from ._exceptions import APIStatusError @@ -45,6 +46,7 @@ SyncAPIClient, AsyncAPIClient, ) +from ._client_adapter import GeminiNextGenAPIClientAdapter, AsyncGeminiNextGenAPIClientAdapter if TYPE_CHECKING: from .resources import interactions @@ -66,6 +68,7 @@ class GeminiNextGenAPIClient(SyncAPIClient): # client options api_key: str | None api_version: str + client_adapter: GeminiNextGenAPIClientAdapter | None def __init__( self, @@ -81,6 +84,7 @@ def __init__( # We provide a `DefaultHttpxClient` class that you can pass to retain the default values we use for `limits`, `timeout` & `follow_redirects`. # See the [httpx documentation](https://www.python-httpx.org/api/#client) for more details. http_client: httpx.Client | None = None, + client_adapter: GeminiNextGenAPIClientAdapter | None = None, # Enable or disable schema validation for data returned by the API. # When enabled an error APIResponseValidationError is raised # if the API responds with invalid data for the expected schema. @@ -108,6 +112,8 @@ def __init__( if base_url is None: base_url = f"https://generativelanguage.googleapis.com" + self.client_adapter = client_adapter + super().__init__( version=__version__, base_url=base_url, @@ -159,13 +165,35 @@ def default_headers(self) -> dict[str, str | Omit]: @override def _validate_headers(self, headers: Headers, custom_headers: Headers) -> None: - if headers.get("x-goog-api-key") or isinstance(custom_headers.get("x-goog-api-key"), Omit): + if headers.get("Authorization") or custom_headers.get("Authorization") or isinstance(custom_headers.get("Authorization"), Omit): + return + if self.api_key and headers.get("x-goog-api-key"): + return + if custom_headers.get("x-goog-api-key") or isinstance(custom_headers.get("x-goog-api-key"), Omit): return raise TypeError( '"Could not resolve authentication method. Expected the api_key to be set. Or for the `x-goog-api-key` headers to be explicitly omitted"' ) - + + @override + def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions: + if not self.client_adapter or not self.client_adapter.is_vertex_ai(): + return options + + headers = options.headers or {} + has_auth = headers.get("Authorization") or headers.get("x-goog-api-key") # pytype: disable=attribute-error + if has_auth: + return options + + adapted_headers = self.client_adapter.get_auth_headers() + if adapted_headers: + options.headers = { + **adapted_headers, + **headers + } + return options + def copy( self, *, @@ -179,6 +207,7 @@ def copy( set_default_headers: Mapping[str, str] | None = None, default_query: Mapping[str, object] | None = None, set_default_query: Mapping[str, object] | None = None, + client_adapter: GeminiNextGenAPIClientAdapter | None = None, _extra_kwargs: Mapping[str, Any] = {}, ) -> Self: """ @@ -212,6 +241,7 @@ def copy( max_retries=max_retries if is_given(max_retries) else self.max_retries, default_headers=headers, default_query=params, + client_adapter=self.client_adapter or client_adapter, **_extra_kwargs, ) @@ -260,6 +290,7 @@ class AsyncGeminiNextGenAPIClient(AsyncAPIClient): # client options api_key: str | None api_version: str + client_adapter: AsyncGeminiNextGenAPIClientAdapter | None def __init__( self, @@ -275,6 +306,7 @@ def __init__( # We provide a `DefaultAsyncHttpxClient` class that you can pass to retain the default values we use for `limits`, `timeout` & `follow_redirects`. # See the [httpx documentation](https://www.python-httpx.org/api/#asyncclient) for more details. http_client: httpx.AsyncClient | None = None, + client_adapter: AsyncGeminiNextGenAPIClientAdapter | None = None, # Enable or disable schema validation for data returned by the API. # When enabled an error APIResponseValidationError is raised # if the API responds with invalid data for the expected schema. @@ -302,6 +334,8 @@ def __init__( if base_url is None: base_url = f"https://generativelanguage.googleapis.com" + self.client_adapter = client_adapter + super().__init__( version=__version__, base_url=base_url, @@ -353,12 +387,34 @@ def default_headers(self) -> dict[str, str | Omit]: @override def _validate_headers(self, headers: Headers, custom_headers: Headers) -> None: - if headers.get("x-goog-api-key") or isinstance(custom_headers.get("x-goog-api-key"), Omit): + if headers.get("Authorization") or custom_headers.get("Authorization") or isinstance(custom_headers.get("Authorization"), Omit): + return + if self.api_key and headers.get("x-goog-api-key"): + return + if custom_headers.get("x-goog-api-key") or isinstance(custom_headers.get("x-goog-api-key"), Omit): return raise TypeError( '"Could not resolve authentication method. Expected the api_key to be set. Or for the `x-goog-api-key` headers to be explicitly omitted"' ) + + @override + async def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions: + if not self.client_adapter or not self.client_adapter.is_vertex_ai(): + return options + + headers = options.headers or {} + has_auth = headers.get("Authorization") or headers.get("x-goog-api-key") # pytype: disable=attribute-error + if has_auth: + return options + + adapted_headers = await self.client_adapter.async_get_auth_headers() + if adapted_headers: + options.headers = { + **adapted_headers, + **headers + } + return options def copy( self, @@ -373,6 +429,7 @@ def copy( set_default_headers: Mapping[str, str] | None = None, default_query: Mapping[str, object] | None = None, set_default_query: Mapping[str, object] | None = None, + client_adapter: AsyncGeminiNextGenAPIClientAdapter | None = None, _extra_kwargs: Mapping[str, Any] = {}, ) -> Self: """ @@ -406,6 +463,7 @@ def copy( max_retries=max_retries if is_given(max_retries) else self.max_retries, default_headers=headers, default_query=params, + client_adapter=self.client_adapter or client_adapter, **_extra_kwargs, ) diff --git a/google/genai/_interactions/_client_adapter.py b/google/genai/_interactions/_client_adapter.py new file mode 100644 index 000000000..32c4750c2 --- /dev/null +++ b/google/genai/_interactions/_client_adapter.py @@ -0,0 +1,48 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import annotations + +from abc import ABC, abstractmethod + +__all__ = [ + "GeminiNextGenAPIClientAdapter", + "AsyncGeminiNextGenAPIClientAdapter" +] + +class BaseGeminiNextGenAPIClientAdapter(ABC): + @abstractmethod + def is_vertex_ai(self) -> bool: + ... + + @abstractmethod + def get_project(self) -> str | None: + ... + + @abstractmethod + def get_location(self) -> str | None: + ... + + +class AsyncGeminiNextGenAPIClientAdapter(BaseGeminiNextGenAPIClientAdapter): + @abstractmethod + async def async_get_auth_headers(self) -> dict[str, str] | None: + ... + + +class GeminiNextGenAPIClientAdapter(BaseGeminiNextGenAPIClientAdapter): + @abstractmethod + def get_auth_headers(self) -> dict[str, str] | None: + ... diff --git a/google/genai/client.py b/google/genai/client.py index 14373ac31..cec5cef4c 100644 --- a/google/genai/client.py +++ b/google/genai/client.py @@ -16,7 +16,7 @@ import asyncio import os from types import TracebackType -from typing import Optional, Union, cast +from typing import Optional, Union import google.auth import pydantic @@ -43,13 +43,65 @@ from . import _common -from ._interactions import AsyncGeminiNextGenAPIClient, DEFAULT_MAX_RETRIES, DefaultAioHttpClient, GeminiNextGenAPIClient -from ._interactions._models import FinalRequestOptions -from ._interactions._types import Headers -from ._interactions._utils import is_given +from ._interactions import AsyncGeminiNextGenAPIClient, DEFAULT_MAX_RETRIES, GeminiNextGenAPIClient +from . import _interactions + from ._interactions.resources import AsyncInteractionsResource as AsyncNextGenInteractionsResource, InteractionsResource as NextGenInteractionsResource _interactions_experimental_warned = False +class AsyncGeminiNextGenAPIClientAdapter(_interactions.AsyncGeminiNextGenAPIClientAdapter): + """Adapter for the Gemini NextGen API Client.""" + def __init__(self, api_client: BaseApiClient): + self._api_client = api_client + + def is_vertex_ai(self) -> bool: + return self._api_client.vertexai or False + + def get_project(self) -> str | None: + return self._api_client.project + + def get_location(self) -> str | None: + return self._api_client.location + + async def async_get_auth_headers(self) -> dict[str, str]: + if self._api_client.api_key: + return {"x-goog-api-key": self._api_client.api_key} + access_token = await self._api_client._async_access_token() + headers = { + "Authorization": f"Bearer {access_token}", + } + if creds := self._api_client._credentials: + if creds.quota_project_id: + headers["x-goog-user-project"] = creds.quota_project_id + return headers + + +class GeminiNextGenAPIClientAdapter(_interactions.GeminiNextGenAPIClientAdapter): + """Adapter for the Gemini NextGen API Client.""" + def __init__(self, api_client: BaseApiClient): + self._api_client = api_client + + def is_vertex_ai(self) -> bool: + return self._api_client.vertexai or False + + def get_project(self) -> str | None: + return self._api_client.project + + def get_location(self) -> str | None: + return self._api_client.location + + def get_auth_headers(self) -> dict[str, str]: + if self._api_client.api_key: + return {"x-goog-api-key": self._api_client.api_key} + access_token = self._api_client._access_token() + headers = { + "Authorization": f"Bearer {access_token}", + } + if creds := self._api_client._credentials: + if creds.quota_project_id: + headers["x-goog-user-project"] = creds.quota_project_id + return headers + class AsyncClient: """Client for making asynchronous (non-blocking) requests.""" @@ -122,6 +174,7 @@ def _nextgen_client(self) -> AsyncGeminiNextGenAPIClient: # uSDk expects ms, nextgen uses a httpx Timeout -> expects seconds. timeout=http_opts.timeout / 1000 if http_opts.timeout else None, max_retries=max_retries, + client_adapter=AsyncGeminiNextGenAPIClientAdapter(self._api_client) ) client = self._nextgen_client_instance @@ -130,30 +183,6 @@ def _nextgen_client(self) -> AsyncGeminiNextGenAPIClient: client._vertex_project = self._api_client.project client._vertex_location = self._api_client.location - async def prepare_options(options: FinalRequestOptions) -> FinalRequestOptions: - headers = {} - if is_given(options.headers): - headers = {**options.headers} - - headers['Authorization'] = f'Bearer {await self._api_client._async_access_token()}' - if ( - self._api_client._credentials - and self._api_client._credentials.quota_project_id - ): - headers['x-goog-user-project'] = ( - self._api_client._credentials.quota_project_id - ) - options.headers = headers - - return options - - if self._api_client.project or self._api_client.location: - client._prepare_options = prepare_options # type: ignore[method-assign] - - def validate_headers(headers: Headers, custom_headers: Headers) -> None: - return - - client._validate_headers = validate_headers # type: ignore[method-assign] return self._nextgen_client_instance @property @@ -279,6 +308,7 @@ class DebugConfig(pydantic.BaseModel): ) + class Client: """Client for making synchronous requests. @@ -492,39 +522,15 @@ def _nextgen_client(self) -> GeminiNextGenAPIClient: # uSDk expects ms, nextgen uses a httpx Timeout -> expects seconds. timeout=http_opts.timeout / 1000 if http_opts.timeout else None, max_retries=max_retries, + client_adapter=GeminiNextGenAPIClientAdapter(self._api_client), ) client = self._nextgen_client_instance - if self.vertexai: + if self._api_client.vertexai: client._is_vertex = True client._vertex_project = self._api_client.project client._vertex_location = self._api_client.location - def prepare_options(options: FinalRequestOptions) -> FinalRequestOptions: - headers = {} - if is_given(options.headers): - headers = {**options.headers} - options.headers = headers - - headers['Authorization'] = f'Bearer {self._api_client._access_token()}' - if ( - self._api_client._credentials - and self._api_client._credentials.quota_project_id - ): - headers['x-goog-user-project'] = ( - self._api_client._credentials.quota_project_id - ) - - return options - - if self._api_client.project or self._api_client.location: - client._prepare_options = prepare_options # type: ignore[method-assign] - - def validate_headers(headers: Headers, custom_headers: Headers) -> None: - return - - client._validate_headers = validate_headers # type: ignore[method-assign] - return self._nextgen_client_instance @property diff --git a/google/genai/tests/interactions/test_auth.py b/google/genai/tests/interactions/test_auth.py index 7740f9fce..5c92a225e 100644 --- a/google/genai/tests/interactions/test_auth.py +++ b/google/genai/tests/interactions/test_auth.py @@ -214,7 +214,7 @@ def get_token(): headers = mock_send.call_args_list[i][0][0].headers assert headers['authorization'] == f'Bearer {token_values[i]}' -@pytest.mark.xfail(reason="extra_headers don't override default auth") + def test_interactions_vertex_extra_headers_override(): from ..._api_client import BaseApiClient from httpx import Client as HTTPClient @@ -358,8 +358,6 @@ async def test_async_interactions_vertex_auth_header(): @pytest.mark.asyncio async def test_async_interactions_vertex_key_no_auth_header(): from ..._api_client import BaseApiClient - from ..._api_client import AsyncHttpxClient - creds = mock.Mock() client = Client(vertexai=True, api_key='test-api-key') with ( @@ -436,7 +434,6 @@ async def get_token(): headers = mock_send.call_args_list[i][0][0].headers assert headers['authorization'] == f'Bearer {token_values[i]}' -@pytest.mark.xfail(reason="extra_headers don't override default auth") @pytest.mark.asyncio async def test_async_interactions_vertex_extra_headers_override(): from ..._api_client import BaseApiClient diff --git a/google/genai/tests/interactions/test_integration.py b/google/genai/tests/interactions/test_integration.py index e2427776b..13b76f5c0 100644 --- a/google/genai/tests/interactions/test_integration.py +++ b/google/genai/tests/interactions/test_integration.py @@ -55,6 +55,7 @@ def test_client_timeout(): http_client=mock.ANY, timeout=5.0, max_retries=mock.ANY, + client_adapter=mock.ANY, ) @@ -79,4 +80,5 @@ async def test_async_client_timeout(): http_client=mock.ANY, timeout=5.0, max_retries=mock.ANY, + client_adapter=mock.ANY, )