diff --git a/contributing/samples/gepa/experiment.py b/contributing/samples/gepa/experiment.py index f3751206a8..2710c3894c 100644 --- a/contributing/samples/gepa/experiment.py +++ b/contributing/samples/gepa/experiment.py @@ -43,7 +43,6 @@ from tau_bench.types import EnvRunResult from tau_bench.types import RunConfig import tau_bench_agent as tau_bench_agent_lib - import utils diff --git a/contributing/samples/gepa/run_experiment.py b/contributing/samples/gepa/run_experiment.py index d857da9635..e31db15788 100644 --- a/contributing/samples/gepa/run_experiment.py +++ b/contributing/samples/gepa/run_experiment.py @@ -25,7 +25,6 @@ from absl import flags import experiment from google.genai import types - import utils _OUTPUT_DIR = flags.DEFINE_string( diff --git a/pyproject.toml b/pyproject.toml index 51712fa12d..f57e1140e4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -211,6 +211,7 @@ known_third_party = ["google.adk"] [tool.pytest.ini_options] testpaths = ["tests"] +pythonpath = "src" asyncio_default_fixture_loop_scope = "function" asyncio_mode = "auto" diff --git a/src/google/adk/auth/auth_handler.py b/src/google/adk/auth/auth_handler.py index ec7c75716c..e18026d508 100644 --- a/src/google/adk/auth/auth_handler.py +++ b/src/google/adk/auth/auth_handler.py @@ -7,11 +7,6 @@ # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - from __future__ import annotations from typing import TYPE_CHECKING @@ -22,6 +17,7 @@ from .auth_schemes import AuthSchemeType from .auth_schemes import OpenIdConnectWithConfig from .auth_tool import AuthConfig +from .credential_manager import CredentialManager from .exchanger.oauth2_credential_exchanger import OAuth2CredentialExchanger if TYPE_CHECKING: @@ -48,10 +44,14 @@ async def exchange_auth_token( self, ) -> AuthCredential: exchanger = OAuth2CredentialExchanger() - exchange_result = await exchanger.exchange( - self.auth_config.exchanged_auth_credential, self.auth_config.auth_scheme - ) - return exchange_result.credential + + credential = self.auth_config.exchanged_auth_credential + + with CredentialManager.restore_client_secret(credential): + res = await exchanger.exchange( + credential, self.auth_config.auth_scheme + ) + return res.credential async def parse_and_store_auth_response(self, state: State) -> None: @@ -185,9 +185,15 @@ def generate_auth_uri( ) scopes = list(scopes.keys()) + client_id = auth_credential.oauth2.client_id + + client_secret = CredentialManager.get_client_secret(client_id) + if not client_secret: + client_secret = auth_credential.oauth2.client_secret + client = OAuth2Session( - auth_credential.oauth2.client_id, - auth_credential.oauth2.client_secret, + client_id, + client_secret, scope=" ".join(scopes), redirect_uri=auth_credential.oauth2.redirect_uri, ) diff --git a/src/google/adk/auth/auth_tool.py b/src/google/adk/auth/auth_tool.py index cc8a244e71..932b9bf78f 100644 --- a/src/google/adk/auth/auth_tool.py +++ b/src/google/adk/auth/auth_tool.py @@ -113,7 +113,9 @@ def get_credential_key(self): ) auth_credential = self.raw_auth_credential - if auth_credential and auth_credential.model_extra: + if auth_credential and ( + auth_credential.model_extra or auth_credential.oauth2 + ): auth_credential = auth_credential.model_copy(deep=True) auth_credential.model_extra.clear() if auth_credential and auth_credential.oauth2: diff --git a/src/google/adk/auth/credential_manager.py b/src/google/adk/auth/credential_manager.py index d037a43e6b..8bdefb6f8a 100644 --- a/src/google/adk/auth/credential_manager.py +++ b/src/google/adk/auth/credential_manager.py @@ -14,13 +14,13 @@ from __future__ import annotations +import contextlib import logging from typing import Optional from fastapi.openapi.models import OAuth2 from ..agents.callback_context import CallbackContext -from ..tools.openapi_tool.auth.credential_exchangers.service_account_exchanger import ServiceAccountCredentialExchanger from ..utils.feature_decorator import experimental from .auth_credential import AuthCredential from .auth_credential import AuthCredentialTypes @@ -76,11 +76,23 @@ class CredentialManager: ``` """ + # A map to store client secrets in memory. Key is client_id, value is client_secret + _CLIENT_SECRETS: dict[str, str] = {} + def __init__( self, auth_config: AuthConfig, ): - self._auth_config = auth_config + # We deep copy the auth_config to avoid modifying the original object passed + # by the user. This allows for safe redaction of sensitive information without + # causing side effects. + + self._auth_config = auth_config.model_copy(deep=True) + + # Secure the client secret + self._secure_client_secret(self._auth_config.raw_auth_credential) + self._secure_client_secret(self._auth_config.exchanged_auth_credential) + self._exchanger_registry = CredentialExchangerRegistry() self._refresher_registry = CredentialRefresherRegistry() self._discovery_manager = OAuth2DiscoveryManager() @@ -98,6 +110,8 @@ def __init__( ) # TODO: Move ServiceAccountCredentialExchanger to the auth module + from ..tools.openapi_tool.auth.credential_exchangers.service_account_exchanger import ServiceAccountCredentialExchanger + self._exchanger_registry.register( AuthCredentialTypes.SERVICE_ACCOUNT, ServiceAccountCredentialExchanger(), @@ -111,6 +125,36 @@ def __init__( AuthCredentialTypes.OPEN_ID_CONNECT, oauth2_refresher ) + def _secure_client_secret(self, credential: Optional[AuthCredential]): + """Extracts client secret to memory and redacts it from the credential.""" + if ( + credential + and credential.oauth2 + and credential.oauth2.client_id + and credential.oauth2.client_secret + and credential.oauth2.client_secret != "" + ): + logger.info( + f"Securing client secret for client_id: {credential.oauth2.client_id}" + ) + # Store in memory map + CredentialManager._CLIENT_SECRETS[credential.oauth2.client_id] = ( + credential.oauth2.client_secret + ) + # Redact from config + credential.oauth2.client_secret = "" + else: + if credential and credential.oauth2: + logger.debug( + f"Not securing secret for client_id {credential.oauth2.client_id}:" + f" secret is {credential.oauth2.client_secret}" + ) + + @staticmethod + def get_client_secret(client_id: str) -> Optional[str]: + """Retrieves the client secret for a given client_id.""" + return CredentialManager._CLIENT_SECRETS.get(client_id) + def register_credential_exchanger( self, credential_type: AuthCredentialTypes, @@ -211,6 +255,40 @@ async def _load_from_auth_response( """Load credential from auth response in context.""" return context.get_auth_response(self._auth_config) + @staticmethod + @contextlib.contextmanager + def restore_client_secret(credential: AuthCredential, secret: str = None): + """Context manager to temporarily restore client secret in a credential. + + Args: + credential: The credential to restore secret for. + secret: Optional secret to use. If not provided, looks up by client_id. + """ + if not credential or not credential.oauth2: + yield + return + + restored = False + if secret: + credential.oauth2.client_secret = secret + restored = True + elif ( + credential.oauth2.client_id + and credential.oauth2.client_secret == "" + ): + stored_secret = CredentialManager.get_client_secret( + credential.oauth2.client_id + ) + if stored_secret: + credential.oauth2.client_secret = stored_secret + restored = True + + try: + yield + finally: + if restored: + credential.oauth2.client_secret = "" + async def _exchange_credential( self, credential: AuthCredential ) -> tuple[AuthCredential, bool]: @@ -219,6 +297,8 @@ async def _exchange_credential( if not exchanger: return credential, False + from ..tools.openapi_tool.auth.credential_exchangers.service_account_exchanger import ServiceAccountCredentialExchanger + if isinstance(exchanger, ServiceAccountCredentialExchanger): return ( exchanger.exchange_credential( @@ -226,11 +306,12 @@ async def _exchange_credential( ), True, ) - - exchange_result = await exchanger.exchange( - credential, self._auth_config.auth_scheme - ) - return exchange_result.credential, exchange_result.was_exchanged + else: + with self.restore_client_secret(credential): + exchanged_credential = await exchanger.exchange( + credential, self._auth_config.auth_scheme + ) + return exchanged_credential.credential, True async def _refresh_credential( self, credential: AuthCredential diff --git a/tests/unittests/agents/test_remote_a2a_agent.py b/tests/unittests/agents/test_remote_a2a_agent.py index 7643125d81..b884bc4048 100644 --- a/tests/unittests/agents/test_remote_a2a_agent.py +++ b/tests/unittests/agents/test_remote_a2a_agent.py @@ -17,6 +17,7 @@ import tempfile from unittest.mock import AsyncMock from unittest.mock import create_autospec +from unittest.mock import MagicMock from unittest.mock import Mock from unittest.mock import patch @@ -1002,7 +1003,7 @@ async def test_handle_a2a_response_with_task_submitted_and_no_update(self): mock_a2a_task, self.agent.name, self.mock_context, - self.mock_a2a_part_converter, + self.agent._a2a_part_converter, ) # Check the parts are updated as Thought assert result.content.parts[0].thought is True @@ -1770,7 +1771,7 @@ async def test_run_async_impl_successful_request(self): ) # Tuple with parts and context_id # Mock A2A client - mock_a2a_client = create_autospec(spec=A2AClient, instance=True) + mock_a2a_client = MagicMock(spec=A2AClient) mock_response = Mock() mock_send_message = AsyncMock() mock_send_message.__aiter__.return_value = [mock_response] @@ -1909,7 +1910,7 @@ async def test_run_async_impl_with_meta_provider(self): ) # Tuple with parts and context_id # Mock A2A client - mock_a2a_client = create_autospec(spec=A2AClient, instance=True) + mock_a2a_client = MagicMock(spec=A2AClient) mock_response = Mock() mock_send_message = AsyncMock() mock_send_message.__aiter__.return_value = [mock_response] @@ -2046,7 +2047,7 @@ async def test_run_async_impl_successful_request(self): ) # Tuple with parts and context_id # Mock A2A client - mock_a2a_client = create_autospec(spec=A2AClient, instance=True) + mock_a2a_client = MagicMock(spec=A2AClient) mock_response = Mock() mock_send_message = AsyncMock() mock_send_message.__aiter__.return_value = [mock_response] diff --git a/tests/unittests/auth/test_auth_handler_secrets.py b/tests/unittests/auth/test_auth_handler_secrets.py new file mode 100644 index 0000000000..11d8eb9d39 --- /dev/null +++ b/tests/unittests/auth/test_auth_handler_secrets.py @@ -0,0 +1,131 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest.mock import AsyncMock +from unittest.mock import MagicMock +from unittest.mock import Mock +from unittest.mock import patch + +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.auth.auth_credential import OAuth2Auth +from google.adk.auth.auth_handler import AuthHandler +from google.adk.auth.auth_tool import AuthConfig +from google.adk.auth.credential_manager import CredentialManager +from google.adk.auth.exchanger.base_credential_exchanger import ExchangeResult +import pytest + + +class TestAuthHandlerSecrets: + + @pytest.fixture(autouse=True) + def clear_credential_manager_secrets(self): + """Clear CredentialManager secrets buffer before/after each test.""" + CredentialManager._CLIENT_SECRETS = {} + yield + CredentialManager._CLIENT_SECRETS = {} + + @pytest.mark.asyncio + async def test_exchange_auth_token_restores_and_reredacts_secret(self): + client_id = "test_client_id" + secret = "super_secret_value" + + # Setup secure storage + CredentialManager._CLIENT_SECRETS[client_id] = secret + + # Create credential with redacted secret + credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth(client_id=client_id, client_secret=""), + ) + + auth_config = Mock(spec=AuthConfig) + auth_config.exchanged_auth_credential = credential + auth_config.auth_scheme = Mock() + + handler = AuthHandler(auth_config) + + # Mock exchanger + mock_exchanger = AsyncMock() + + # Check secret inside exchange + def check_secret(cred, scheme): + assert cred.oauth2.client_secret == secret + return ExchangeResult(cred, True) + + mock_exchanger.exchange.side_effect = check_secret + + with patch( + "google.adk.auth.auth_handler.OAuth2CredentialExchanger", + return_value=mock_exchanger, + ): + await handler.exchange_auth_token() + + # Verify secret is re-redacted + assert credential.oauth2.client_secret == "" + + def test_generate_auth_uri_uses_restored_secret(self): + client_id = "test_client_id" + secret = "super_secret_value" + + # Setup secure storage + CredentialManager._CLIENT_SECRETS[client_id] = secret + + # Create credential with redacted secret + credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id=client_id, + client_secret="", + redirect_uri="http://localhost/callback", + ), + ) + + auth_config = Mock(spec=AuthConfig) + auth_config.raw_auth_credential = credential + auth_config.auth_scheme = Mock() + # Mock flows for scopes + auth_config.auth_scheme.flows.implicit = None + auth_config.auth_scheme.flows.clientCredentials = None + auth_config.auth_scheme.flows.password = None + auth_config.auth_scheme.flows.authorizationCode.scopes = {"scope": "desc"} + auth_config.auth_scheme.flows.authorizationCode.authorizationUrl = ( + "http://auth" + ) + + handler = AuthHandler(auth_config) + + # Mock OAuth2Session + with ( + patch("google.adk.auth.auth_handler.OAuth2Session") as mock_session_cls, + patch("google.adk.auth.auth_handler.AUTHLIB_AVAILABLE", True), + ): + + mock_session = Mock() + mock_session.create_authorization_url.return_value = ( + "http://auth?param=1", + "state", + ) + mock_session_cls.return_value = mock_session + + handler.generate_auth_uri() + + # Verify session was created with the REAL secret, not redacted one + mock_session_cls.assert_called_with( + client_id, + secret, + scope="scope", + redirect_uri="http://localhost/callback", + ) diff --git a/tests/unittests/auth/test_auth_tool_key_stability.py b/tests/unittests/auth/test_auth_tool_key_stability.py new file mode 100644 index 0000000000..9f44299c2c --- /dev/null +++ b/tests/unittests/auth/test_auth_tool_key_stability.py @@ -0,0 +1,59 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest.mock import Mock + +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.auth.auth_credential import OAuth2Auth +from google.adk.auth.auth_tool import AuthConfig + + +class TestAuthToolKeyStability(unittest.TestCase): + + def test_key_stability_with_different_secrets(self): + from google.adk.auth.auth_schemes import AuthSchemeType + from google.adk.auth.auth_schemes import OAuth2 + + # Consistent scheme for both + auth_scheme = OAuth2(type=AuthSchemeType.oauth2, flows={}) + + # Config 1: Real secret + auth_credential_1 = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="client_id", client_secret="real_secret", auth_uri="uri" + ), + ) + config1 = AuthConfig( + auth_scheme=auth_scheme, raw_auth_credential=auth_credential_1 + ) + + # Config 2: Redacted secret + auth_credential_2 = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="client_id", client_secret="", auth_uri="uri" + ), + ) + config2 = AuthConfig( + auth_scheme=auth_scheme, raw_auth_credential=auth_credential_2 + ) + + # Keys should be identical + key1 = config1.credential_key + key2 = config2.credential_key + + self.assertEqual(key1, key2, f"Keys should match! {key1} vs {key2}") diff --git a/tests/unittests/auth/test_credential_manager.py b/tests/unittests/auth/test_credential_manager.py index 7000c9b8f8..6d2454ac98 100644 --- a/tests/unittests/auth/test_credential_manager.py +++ b/tests/unittests/auth/test_credential_manager.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy from unittest.mock import ANY from unittest.mock import AsyncMock from unittest.mock import Mock @@ -33,8 +34,8 @@ from google.adk.auth.auth_schemes import ExtendedOAuth2 from google.adk.auth.auth_tool import AuthConfig from google.adk.auth.credential_manager import CredentialManager -from google.adk.auth.credential_manager import ServiceAccountCredentialExchanger from google.adk.auth.oauth2_discovery import AuthorizationServerMetadata +from google.adk.tools.openapi_tool.auth.credential_exchangers.service_account_exchanger import ServiceAccountCredentialExchanger from google.adk.sessions.in_memory_session_service import InMemorySessionService from google.adk.tools.tool_context import ToolContext import pytest @@ -42,14 +43,59 @@ from .. import testing_utils +def deepcopy_auth_config_mock(m): + """Deep copy helper for AuthConfig mock.""" + new_m = Mock(spec=AuthConfig) + + # Copy attributes if they have been accessed/set on the original mock + # We focus on the ones used in tests + # Use shallow copy for Mocks to preserve identity for assertions + # Use deep copy for real objects to simulate correct behavior + + if hasattr(m, "auth_scheme"): + val = m.auth_scheme + if isinstance(val, Mock): + new_m.auth_scheme = val + else: + new_m.auth_scheme = copy.deepcopy(val) + + if hasattr(m, "raw_auth_credential"): + val = m.raw_auth_credential + if isinstance(val, Mock): + new_m.raw_auth_credential = val + else: + new_m.raw_auth_credential = copy.deepcopy(val) + else: + # Set default to None to avoid AttributeError if spec doesn't expose it + new_m.raw_auth_credential = None + + if hasattr(m, "exchanged_auth_credential"): + val = m.exchanged_auth_credential + if isinstance(val, Mock): + new_m.exchanged_auth_credential = val + else: + new_m.exchanged_auth_credential = copy.deepcopy(val) + else: + new_m.exchanged_auth_credential = None + + new_m.model_copy.side_effect = lambda **kwargs: deepcopy_auth_config_mock(new_m) + return new_m + +def create_auth_config_mock(): + """Creates a mock AuthConfig that returns a deep copy on model_copy.""" + m = Mock(spec=AuthConfig) + m.model_copy.side_effect = lambda **kwargs: deepcopy_auth_config_mock(m) + return m + + class TestCredentialManager: """Test suite for CredentialManager.""" def test_init(self): """Test CredentialManager initialization.""" - auth_config = Mock(spec=AuthConfig) + auth_config = create_auth_config_mock() manager = CredentialManager(auth_config) - assert manager._auth_config == auth_config + assert manager._auth_config is not auth_config @pytest.mark.asyncio async def test_request_credential(self): @@ -61,13 +107,13 @@ async def test_request_credential(self): manager = CredentialManager(auth_config) await manager.request_credential(tool_context) - tool_context.request_credential.assert_called_once_with(auth_config) + tool_context.request_credential.assert_called_once_with(manager._auth_config) @pytest.mark.asyncio async def test_load_auth_credentials_success(self): """Test load_auth_credential with successful flow.""" # Create mocks - auth_config = Mock(spec=AuthConfig) + auth_config = create_auth_config_mock() auth_config.raw_auth_credential = None auth_config.exchanged_auth_credential = None @@ -110,7 +156,7 @@ async def test_load_auth_credentials_success(self): @pytest.mark.asyncio async def test_load_auth_credentials_no_credential(self): """Test load_auth_credential when no credential is available.""" - auth_config = Mock(spec=AuthConfig) + auth_config = create_auth_config_mock() auth_config.raw_auth_credential = None auth_config.exchanged_auth_credential = None # Add auth_scheme for the _is_client_credentials_flow method @@ -139,9 +185,10 @@ async def test_load_auth_credentials_no_credential(self): @pytest.mark.asyncio async def test_load_existing_credential_already_exchanged(self): - """Test _load_existing_credential ignores shared config cache.""" - auth_config = Mock(spec=AuthConfig) + """Test _load_existing_credential when credential is already exchanged.""" + auth_config = create_auth_config_mock() mock_credential = Mock(spec=AuthCredential) + mock_credential.oauth2 = Mock() auth_config.exchanged_auth_credential = mock_credential tool_context = Mock() @@ -156,7 +203,7 @@ async def test_load_existing_credential_already_exchanged(self): @pytest.mark.asyncio async def test_load_existing_credential_with_credential_service(self): """Test _load_existing_credential with credential service.""" - auth_config = Mock(spec=AuthConfig) + auth_config = create_auth_config_mock() auth_config.exchanged_auth_credential = None mock_credential = Mock(spec=AuthCredential) @@ -194,13 +241,13 @@ async def test_load_from_credential_service_with_service(self): manager = CredentialManager(auth_config) result = await manager._load_from_credential_service(tool_context) - tool_context.load_credential.assert_called_once_with(auth_config) + tool_context.load_credential.assert_called_once_with(manager._auth_config) assert result == mock_credential @pytest.mark.asyncio async def test_load_from_credential_service_no_service(self): """Test _load_from_credential_service when no credential service is available.""" - auth_config = Mock(spec=AuthConfig) + auth_config = create_auth_config_mock() # Mock invocation context with no credential service invocation_context = Mock() @@ -358,7 +405,7 @@ async def test_request_credential_does_not_leak_across_users(self): session=session_b, credential_service=None, ) - + tool_context_a = ToolContext( invocation_context_a, function_call_id="call_a" ) @@ -387,9 +434,10 @@ async def test_refresh_credential_oauth2(self): mock_oauth2_auth = Mock(spec=OAuth2Auth) mock_credential = Mock(spec=AuthCredential) + mock_credential.oauth2 = Mock() mock_credential.auth_type = AuthCredentialTypes.OAUTH2 - auth_config = Mock(spec=AuthConfig) + auth_config = create_auth_config_mock() auth_config.auth_scheme = Mock() # Mock refresher @@ -424,7 +472,7 @@ async def test_refresh_credential_no_refresher(self): mock_credential = Mock(spec=AuthCredential) mock_credential.auth_type = AuthCredentialTypes.API_KEY - auth_config = Mock(spec=AuthConfig) + auth_config = create_auth_config_mock() manager = CredentialManager(auth_config) @@ -443,9 +491,10 @@ async def test_refresh_credential_no_refresher(self): async def test_is_credential_ready_api_key(self): """Test _is_credential_ready with API key credential.""" mock_raw_credential = Mock(spec=AuthCredential) + mock_raw_credential.oauth2 = Mock() mock_raw_credential.auth_type = AuthCredentialTypes.API_KEY - auth_config = Mock(spec=AuthConfig) + auth_config = create_auth_config_mock() auth_config.raw_auth_credential = mock_raw_credential manager = CredentialManager(auth_config) @@ -457,9 +506,10 @@ async def test_is_credential_ready_api_key(self): async def test_is_credential_ready_oauth2(self): """Test _is_credential_ready with OAuth2 credential (needs processing).""" mock_raw_credential = Mock(spec=AuthCredential) + mock_raw_credential.oauth2 = Mock() mock_raw_credential.auth_type = AuthCredentialTypes.OAUTH2 - auth_config = Mock(spec=AuthConfig) + auth_config = create_auth_config_mock() auth_config.raw_auth_credential = mock_raw_credential manager = CredentialManager(auth_config) @@ -473,7 +523,7 @@ async def test_validate_credential_no_raw_credential_oauth2(self): auth_scheme = Mock() auth_scheme.type_ = AuthSchemeType.oauth2 - auth_config = Mock(spec=AuthConfig) + auth_config = create_auth_config_mock() auth_config.raw_auth_credential = None auth_config.auth_scheme = auth_scheme @@ -488,7 +538,7 @@ async def test_validate_credential_no_raw_credential_openid(self): auth_scheme = Mock() auth_scheme.type_ = AuthSchemeType.openIdConnect - auth_config = Mock(spec=AuthConfig) + auth_config = create_auth_config_mock() auth_config.raw_auth_credential = None auth_config.auth_scheme = auth_scheme @@ -503,7 +553,7 @@ async def test_validate_credential_no_raw_credential_other_scheme(self): auth_scheme = Mock() auth_scheme.type_ = AuthSchemeType.apiKey - auth_config = Mock(spec=AuthConfig) + auth_config = create_auth_config_mock() auth_config.raw_auth_credential = None auth_config.auth_scheme = auth_scheme @@ -519,7 +569,7 @@ async def test_validate_credential_oauth2_missing_oauth2_field(self): mock_raw_credential.auth_type = AuthCredentialTypes.OAUTH2 mock_raw_credential.oauth2 = None - auth_config = Mock(spec=AuthConfig) + auth_config = create_auth_config_mock() auth_config.raw_auth_credential = mock_raw_credential auth_config.auth_scheme = Mock() @@ -534,10 +584,10 @@ async def test_validate_credential_oauth2_missing_scheme_info( ): """Test _validate_credential with OAuth2 missing scheme info.""" mock_raw_credential = Mock(spec=AuthCredential) + mock_raw_credential.oauth2 = Mock() mock_raw_credential.auth_type = AuthCredentialTypes.OAUTH2 - mock_raw_credential.oauth2 = Mock(spec=OAuth2Auth) - auth_config = Mock(spec=AuthConfig) + auth_config = create_auth_config_mock() auth_config.raw_auth_credential = mock_raw_credential auth_config.auth_scheme = extended_oauth2_scheme @@ -555,7 +605,7 @@ async def test_exchange_credentials_service_account( self, service_account_credential, oauth2_auth_scheme ): """Test _exchange_credential with service account credential.""" - auth_config = Mock(spec=AuthConfig) + auth_config = create_auth_config_mock() auth_config.auth_scheme = oauth2_auth_scheme exchanged_credential = Mock(spec=AuthCredential) @@ -584,7 +634,7 @@ async def test_exchange_credential_no_exchanger(self): mock_credential = Mock(spec=AuthCredential) mock_credential.auth_type = AuthCredentialTypes.API_KEY - auth_config = Mock(spec=AuthConfig) + auth_config = create_auth_config_mock() manager = CredentialManager(auth_config) @@ -640,7 +690,7 @@ async def test_populate_auth_scheme_success( self, auth_server_metadata, extended_oauth2_scheme ): """Test _populate_auth_scheme successfully populates missing info.""" - auth_config = Mock(spec=AuthConfig) + auth_config = create_auth_config_mock() auth_config.auth_scheme = extended_oauth2_scheme manager = CredentialManager(auth_config) @@ -663,7 +713,7 @@ async def test_populate_auth_scheme_success( @pytest.mark.asyncio async def test_populate_auth_scheme_fail(self, extended_oauth2_scheme): """Test _populate_auth_scheme when auto-discovery fails.""" - auth_config = Mock(spec=AuthConfig) + auth_config = create_auth_config_mock() auth_config.auth_scheme = extended_oauth2_scheme manager = CredentialManager(auth_config) @@ -682,7 +732,7 @@ async def test_populate_auth_scheme_fail(self, extended_oauth2_scheme): @pytest.mark.asyncio async def test_populate_auth_scheme_noop(self, implicit_oauth2_scheme): """Test _populate_auth_scheme when auth scheme info not missing.""" - auth_config = Mock(spec=AuthConfig) + auth_config = create_auth_config_mock() auth_config.auth_scheme = implicit_oauth2_scheme manager = CredentialManager(auth_config) @@ -705,11 +755,7 @@ def test_is_client_credentials_flow_oauth2_with_client_credentials(self): ) ) - auth_config = Mock(spec=AuthConfig) - auth_config.auth_scheme = auth_scheme - auth_config.raw_auth_credential = None - auth_config.exchanged_auth_credential = None - + auth_config = AuthConfig(auth_scheme=auth_scheme) manager = CredentialManager(auth_config) assert manager._is_client_credentials_flow() is True @@ -730,11 +776,7 @@ def test_is_client_credentials_flow_oauth2_without_client_credentials(self): ) ) - auth_config = Mock(spec=AuthConfig) - auth_config.auth_scheme = auth_scheme - auth_config.raw_auth_credential = None - auth_config.exchanged_auth_credential = None - + auth_config = AuthConfig(auth_scheme=auth_scheme) manager = CredentialManager(auth_config) assert manager._is_client_credentials_flow() is False @@ -750,11 +792,7 @@ def test_is_client_credentials_flow_oidc_with_client_credentials(self): grant_types_supported=["authorization_code", "client_credentials"], ) - auth_config = Mock(spec=AuthConfig) - auth_config.auth_scheme = auth_scheme - auth_config.raw_auth_credential = None - auth_config.exchanged_auth_credential = None - + auth_config = AuthConfig(auth_scheme=auth_scheme) manager = CredentialManager(auth_config) assert manager._is_client_credentials_flow() is True @@ -770,11 +808,7 @@ def test_is_client_credentials_flow_oidc_without_client_credentials(self): grant_types_supported=["authorization_code"], ) - auth_config = Mock(spec=AuthConfig) - auth_config.auth_scheme = auth_scheme - auth_config.raw_auth_credential = None - auth_config.exchanged_auth_credential = None - + auth_config = AuthConfig(auth_scheme=auth_scheme) manager = CredentialManager(auth_config) assert manager._is_client_credentials_flow() is False @@ -784,7 +818,7 @@ def test_is_client_credentials_flow_other_scheme(self): # Create a non-OAuth2/OIDC scheme auth_scheme = Mock() - auth_config = Mock(spec=AuthConfig) + auth_config = create_auth_config_mock() auth_config.auth_scheme = auth_scheme auth_config.raw_auth_credential = None auth_config.exchanged_auth_credential = None diff --git a/tests/unittests/auth/test_credential_manager_secrets.py b/tests/unittests/auth/test_credential_manager_secrets.py new file mode 100644 index 0000000000..f0394b9d61 --- /dev/null +++ b/tests/unittests/auth/test_credential_manager_secrets.py @@ -0,0 +1,185 @@ +from unittest.mock import AsyncMock +from unittest.mock import Mock +from unittest.mock import patch + +from fastapi.openapi.models import OAuth2 +from fastapi.openapi.models import OAuthFlowAuthorizationCode +from fastapi.openapi.models import OAuthFlows +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.auth.auth_credential import OAuth2Auth +from google.adk.auth.auth_tool import AuthConfig +from google.adk.auth.credential_manager import CredentialManager +import pytest + + +@pytest.fixture(autouse=True) +def clear_credential_manager_secrets(): + """Clear CredentialManager secrets buffer before/after each test.""" + CredentialManager._CLIENT_SECRETS = {} + yield + CredentialManager._CLIENT_SECRETS = {} + + +@pytest.mark.asyncio +async def test_credential_manager_redacts_secrets_in_raw_credential(): + """Test that CredentialManager redacts client_secret from raw_auth_credential upon initialization.""" + + # Setup + client_id = "test_client_id" + client_secret = "test_client_secret" + + oauth_auth = OAuth2Auth(client_id=client_id, client_secret=client_secret) + + auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, oauth2=oauth_auth + ) + + auth_scheme = OAuth2( + flows=OAuthFlows( + authorizationCode=OAuthFlowAuthorizationCode( + authorizationUrl="https://example.com/auth", + tokenUrl="https://example.com/token", + ) + ) + ) + + auth_config = AuthConfig( + auth_scheme=auth_scheme, raw_auth_credential=auth_credential + ) + + # Act + manager = CredentialManager(auth_config) + + # Assert + # 1. Check if secret is in memory map + assert client_id in manager._CLIENT_SECRETS + assert manager._CLIENT_SECRETS[client_id] == client_secret + + # 2. Check if secret is redacted in the manager's config + assert ( + manager._auth_config.raw_auth_credential.oauth2.client_secret + == "" + ) + + # 3. Check original config is NOT modified (AuthConfig copy behavior) + # Since we used model_copy(deep=True), calling on Pydantic model copies it. + assert auth_config.raw_auth_credential.oauth2.client_secret == client_secret + + +@pytest.mark.asyncio +async def test_credential_manager_redacts_secrets_in_exchanged_credential(): + """Test that CredentialManager redacts client_secret from exchanged_auth_credential if present.""" + + # Setup + client_id = "test_client_id_exchanged" + client_secret = "test_client_secret_exchanged" + + oauth_auth = OAuth2Auth( + client_id=client_id, + client_secret=client_secret, + access_token="some_token", + ) + + exchanged_credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, oauth2=oauth_auth + ) + + auth_scheme = OAuth2( + flows=OAuthFlows( + authorizationCode=OAuthFlowAuthorizationCode( + authorizationUrl="https://example.com/auth", + tokenUrl="https://example.com/token", + ) + ) + ) + + auth_config = AuthConfig( + auth_scheme=auth_scheme, + raw_auth_credential=None, + exchanged_auth_credential=exchanged_credential, + ) + + # Act + manager = CredentialManager(auth_config) + + # Assert + assert client_id in manager._CLIENT_SECRETS + assert manager._CLIENT_SECRETS[client_id] == client_secret + + assert ( + manager._auth_config.exchanged_auth_credential.oauth2.client_secret + == "" + ) + + +@pytest.mark.asyncio +async def test_exchange_credential_restores_secret(): + """Test that _exchange_credential restores the secret before calling exchanger.""" + + # Setup + client_id = "test_client_id_exchange" + client_secret = "test_client_secret_exchange" + + oauth_auth = OAuth2Auth(client_id=client_id, client_secret=client_secret) + + raw_credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, oauth2=oauth_auth + ) + + auth_scheme = OAuth2( + flows=OAuthFlows( + authorizationCode=OAuthFlowAuthorizationCode( + authorizationUrl="https://example.com/auth", + tokenUrl="https://example.com/token", + ) + ) + ) + + auth_config = AuthConfig( + auth_scheme=auth_scheme, raw_auth_credential=raw_credential + ) + + manager = CredentialManager(auth_config) + + # Secret should be redacted now + assert ( + manager._auth_config.raw_auth_credential.oauth2.client_secret + == "" + ) + + # Prepare a credential to be exchanged (e.g. from client response, has no secret or redacted) + credential_to_exchange = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id=client_id, + client_secret="", # or None + auth_code="some_code", + ), + ) + + # Mock exchanger + mock_exchanger = AsyncMock() + + # We use side_effect to verify the secret at the moment of call, because the object is mutated later + def check_secret(cred, scheme): + assert cred.oauth2.client_secret == client_secret + from google.adk.auth.exchanger.base_credential_exchanger import ExchangeResult + return ExchangeResult(credential=credential_to_exchange, was_exchanged=True) + + mock_exchanger.exchange.side_effect = check_secret + + with patch.object( + manager._exchanger_registry, "get_exchanger", return_value=mock_exchanger + ): + # Act + result_credential, exchanged = await manager._exchange_credential( + credential_to_exchange + ) + + # Assert + # Verification happened in side_effect + assert mock_exchanger.exchange.called + + # Check that the result credential (modified in place or returned) has secret REDACTED again + assert result_credential.oauth2.client_secret == "" diff --git a/tests/unittests/tools/mcp_tool/test_mcp_tool.py b/tests/unittests/tools/mcp_tool/test_mcp_tool.py index 0bf28cb3c2..d2b5ddca1b 100644 --- a/tests/unittests/tools/mcp_tool/test_mcp_tool.py +++ b/tests/unittests/tools/mcp_tool/test_mcp_tool.py @@ -86,10 +86,13 @@ def test_init_with_auth(self): # Create real auth scheme instances instead of mocks from fastapi.openapi.models import OAuth2 + test_client_secret = "test_secret" auth_scheme = OAuth2(flows={}) auth_credential = AuthCredential( auth_type=AuthCredentialTypes.OAUTH2, - oauth2=OAuth2Auth(client_id="test_id", client_secret="test_secret"), + oauth2=OAuth2Auth( + client_id="test_id", client_secret=test_client_secret + ), ) tool = MCPTool( @@ -102,6 +105,15 @@ def test_init_with_auth(self): # The auth config is stored in the parent class _credentials_manager assert tool._credentials_manager is not None assert tool._credentials_manager._auth_config.auth_scheme == auth_scheme + assert ( + tool._credentials_manager._auth_config.raw_auth_credential.oauth2.client_secret + == "" + ) + + # Restore the client secret and validate it's the same credential in the end. + tool._credentials_manager._auth_config.raw_auth_credential.oauth2.client_secret = ( + test_client_secret + ) assert ( tool._credentials_manager._auth_config.raw_auth_credential == auth_credential