diff --git a/src/a2a/client/card_resolver.py b/src/a2a/client/card_resolver.py index f13fe3ab..adb3c5ae 100644 --- a/src/a2a/client/card_resolver.py +++ b/src/a2a/client/card_resolver.py @@ -1,6 +1,7 @@ import json import logging +from collections.abc import Callable from typing import Any import httpx @@ -44,6 +45,7 @@ async def get_agent_card( self, relative_card_path: str | None = None, http_kwargs: dict[str, Any] | None = None, + signature_verifier: Callable[[AgentCard], None] | None = None, ) -> AgentCard: """Fetches an agent card from a specified path relative to the base_url. @@ -56,6 +58,7 @@ async def get_agent_card( agent card path. Use `'/'` for an empty path. http_kwargs: Optional dictionary of keyword arguments to pass to the underlying httpx.get request. + signature_verifier: A callable used to verify the agent card's signatures. Returns: An `AgentCard` object representing the agent's capabilities. @@ -86,6 +89,8 @@ async def get_agent_card( agent_card_data, ) agent_card = AgentCard.model_validate(agent_card_data) + if signature_verifier: + signature_verifier(agent_card) except httpx.HTTPStatusError as e: raise A2AClientHTTPError( e.response.status_code, diff --git a/src/a2a/client/client_factory.py b/src/a2a/client/client_factory.py index e2eb066a..c3d5762e 100644 --- a/src/a2a/client/client_factory.py +++ b/src/a2a/client/client_factory.py @@ -116,6 +116,7 @@ async def connect( # noqa: PLR0913 resolver_http_kwargs: dict[str, Any] | None = None, extra_transports: dict[str, TransportProducer] | None = None, extensions: list[str] | None = None, + signature_verifier: Callable[[AgentCard], None] | None = None, ) -> Client: """Convenience method for constructing a client. @@ -146,6 +147,7 @@ async def connect( # noqa: PLR0913 extra_transports: Additional transport protocols to enable when constructing the client. extensions: List of extensions to be activated. + signature_verifier: A callable used to verify the agent card's signatures. Returns: A `Client` object. @@ -158,12 +160,14 @@ async def connect( # noqa: PLR0913 card = await resolver.get_agent_card( relative_card_path=relative_card_path, http_kwargs=resolver_http_kwargs, + signature_verifier=signature_verifier, ) else: resolver = A2ACardResolver(client_config.httpx_client, agent) card = await resolver.get_agent_card( relative_card_path=relative_card_path, http_kwargs=resolver_http_kwargs, + signature_verifier=signature_verifier, ) else: card = agent diff --git a/src/a2a/client/transports/grpc.py b/src/a2a/client/transports/grpc.py index c5edf7a1..6a8b16f9 100644 --- a/src/a2a/client/transports/grpc.py +++ b/src/a2a/client/transports/grpc.py @@ -237,7 +237,7 @@ async def get_card( metadata=self._get_grpc_metadata(extensions), ) card = proto_utils.FromProto.agent_card(card_pb) - if signature_verifier is not None: + if signature_verifier: signature_verifier(card) self.agent_card = card diff --git a/src/a2a/client/transports/jsonrpc.py b/src/a2a/client/transports/jsonrpc.py index 54c758ff..a565e640 100644 --- a/src/a2a/client/transports/jsonrpc.py +++ b/src/a2a/client/transports/jsonrpc.py @@ -390,9 +390,10 @@ async def get_card( if not card: resolver = A2ACardResolver(self.httpx_client, self.url) - card = await resolver.get_agent_card(http_kwargs=modified_kwargs) - if signature_verifier is not None: - signature_verifier(card) + card = await resolver.get_agent_card( + http_kwargs=modified_kwargs, + signature_verifier=signature_verifier, + ) self._needs_extended_card = ( card.supports_authenticated_extended_card ) @@ -418,7 +419,7 @@ async def get_card( if isinstance(response.root, JSONRPCErrorResponse): raise A2AClientJSONRPCError(response.root) card = response.root.result - if signature_verifier is not None: + if signature_verifier: signature_verifier(card) self.agent_card = card diff --git a/src/a2a/client/transports/rest.py b/src/a2a/client/transports/rest.py index 1649be1c..afc9dd08 100644 --- a/src/a2a/client/transports/rest.py +++ b/src/a2a/client/transports/rest.py @@ -382,9 +382,10 @@ async def get_card( if not card: resolver = A2ACardResolver(self.httpx_client, self.url) - card = await resolver.get_agent_card(http_kwargs=modified_kwargs) - if signature_verifier is not None: - signature_verifier(card) + card = await resolver.get_agent_card( + http_kwargs=modified_kwargs, + signature_verifier=signature_verifier, + ) self._needs_extended_card = ( card.supports_authenticated_extended_card ) @@ -402,7 +403,7 @@ async def get_card( '/v1/card', {}, modified_kwargs ) card = AgentCard.model_validate(response_data) - if signature_verifier is not None: + if signature_verifier: signature_verifier(card) self.agent_card = card diff --git a/tests/client/test_card_resolver.py b/tests/client/test_card_resolver.py index f87d9450..26f3f106 100644 --- a/tests/client/test_card_resolver.py +++ b/tests/client/test_card_resolver.py @@ -1,7 +1,7 @@ import json import logging -from unittest.mock import AsyncMock, Mock, patch +from unittest.mock import AsyncMock, MagicMock, Mock, patch import httpx import pytest @@ -371,9 +371,30 @@ async def test_get_agent_card_returns_agent_card_instance( self, resolver, mock_httpx_client, mock_response, valid_agent_card_data ): """Test that get_agent_card returns an AgentCard instance.""" + mock_response.json.return_value = valid_agent_card_data + mock_httpx_client.get.return_value = mock_response mock_agent_card = Mock(spec=AgentCard) + with patch.object( AgentCard, 'model_validate', return_value=mock_agent_card ): result = await resolver.get_agent_card() assert result == mock_agent_card + mock_response.raise_for_status.assert_called_once() + + @pytest.mark.asyncio + async def test_get_agent_card_with_signature_verifier( + self, resolver, mock_httpx_client, valid_agent_card_data + ): + """Test that the signature verifier is called if provided.""" + mock_verifier = MagicMock() + + mock_response = MagicMock(spec=httpx.Response) + mock_response.json.return_value = valid_agent_card_data + mock_httpx_client.get.return_value = mock_response + + agent_card = await resolver.get_agent_card( + signature_verifier=mock_verifier + ) + + mock_verifier.assert_called_once_with(agent_card) diff --git a/tests/client/test_client_factory.py b/tests/client/test_client_factory.py index 16a1433f..c388974b 100644 --- a/tests/client/test_client_factory.py +++ b/tests/client/test_client_factory.py @@ -190,6 +190,7 @@ async def test_client_factory_connect_with_resolver_args( mock_resolver.return_value.get_agent_card.assert_awaited_once_with( relative_card_path=relative_path, http_kwargs=http_kwargs, + signature_verifier=None, ) @@ -216,6 +217,7 @@ async def test_client_factory_connect_resolver_args_without_client( mock_resolver.return_value.get_agent_card.assert_awaited_once_with( relative_card_path=relative_path, http_kwargs=http_kwargs, + signature_verifier=None, )