From bdc81a165b1baffd842e0d8920598a516edc46ca Mon Sep 17 00:00:00 2001 From: sokoliva Date: Mon, 24 Nov 2025 10:04:27 +0000 Subject: [PATCH 1/2] fix(jsonrpc, rest): fix get_card methods in json-rpc and rest transports. Headers are now updated with extensions before the get_agent_card call. --- src/a2a/client/transports/jsonrpc.py | 12 +- src/a2a/client/transports/rest.py | 12 +- .../client/transports/test_jsonrpc_client.py | 84 +++++++++++++ tests/client/transports/test_rest_client.py | 111 +++++++++++++++++- 4 files changed, 204 insertions(+), 15 deletions(-) diff --git a/src/a2a/client/transports/jsonrpc.py b/src/a2a/client/transports/jsonrpc.py index d8011cf4..0006b5fb 100644 --- a/src/a2a/client/transports/jsonrpc.py +++ b/src/a2a/client/transports/jsonrpc.py @@ -378,12 +378,14 @@ async def get_card( extensions: list[str] | None = None, ) -> AgentCard: """Retrieves the agent's card.""" + modified_kwargs = update_extension_header( + self._get_http_args(context), + extensions if extensions is not None else self.extensions, + ) card = self.agent_card if not card: resolver = A2ACardResolver(self.httpx_client, self.url) - card = await resolver.get_agent_card( - http_kwargs=self._get_http_args(context) - ) + card = await resolver.get_agent_card(http_kwargs=modified_kwargs) self._needs_extended_card = ( card.supports_authenticated_extended_card ) @@ -393,10 +395,6 @@ async def get_card( return card request = GetAuthenticatedExtendedCardRequest(id=str(uuid4())) - modified_kwargs = update_extension_header( - self._get_http_args(context), - extensions if extensions is not None else self.extensions, - ) payload, modified_kwargs = await self._apply_interceptors( request.method, request.model_dump(mode='json', exclude_none=True), diff --git a/src/a2a/client/transports/rest.py b/src/a2a/client/transports/rest.py index 83c26787..948f3f35 100644 --- a/src/a2a/client/transports/rest.py +++ b/src/a2a/client/transports/rest.py @@ -370,12 +370,14 @@ async def get_card( extensions: list[str] | None = None, ) -> AgentCard: """Retrieves the agent's card.""" + modified_kwargs = update_extension_header( + self._get_http_args(context), + extensions if extensions is not None else self.extensions, + ) card = self.agent_card if not card: resolver = A2ACardResolver(self.httpx_client, self.url) - card = await resolver.get_agent_card( - http_kwargs=self._get_http_args(context) - ) + card = await resolver.get_agent_card(http_kwargs=modified_kwargs) self._needs_extended_card = ( card.supports_authenticated_extended_card ) @@ -384,10 +386,6 @@ async def get_card( if not self._needs_extended_card: return card - modified_kwargs = update_extension_header( - self._get_http_args(context), - extensions if extensions is not None else self.extensions, - ) _, modified_kwargs = await self._apply_interceptors( {}, modified_kwargs, diff --git a/tests/client/transports/test_jsonrpc_client.py b/tests/client/transports/test_jsonrpc_client.py index bd705d93..30665987 100644 --- a/tests/client/transports/test_jsonrpc_client.py +++ b/tests/client/transports/test_jsonrpc_client.py @@ -875,3 +875,87 @@ async def test_send_message_streaming_with_new_extensions( assert ( headers[HTTP_EXTENSION_HEADER] == 'https://example.com/test-ext/v2' ) + + @pytest.mark.asyncio + async def test_get_card_no_card_provided_with_extensions( + self, mock_httpx_client: AsyncMock + ): + """Test get_card with extensions set in Client when no card is initially provided. + Tests that the extensions are added to the HTTP GET request.""" + extensions = [ + 'https://example.com/test-ext/v1', + 'https://example.com/test-ext/v2', + ] + client = JsonRpcTransport( + httpx_client=mock_httpx_client, + url=TestJsonRpcTransport.AGENT_URL, + extensions=extensions, + ) + mock_response = AsyncMock(spec=httpx.Response) + mock_response.status_code = 200 + mock_response.json.return_value = AGENT_CARD.model_dump(mode='json') + mock_httpx_client.get.return_value = mock_response + + await client.get_card() + + mock_httpx_client.get.assert_called_once() + _, mock_kwargs = mock_httpx_client.get.call_args + + headers = mock_kwargs.get('headers', {}) + assert HTTP_EXTENSION_HEADER in headers + header_value = headers[HTTP_EXTENSION_HEADER] + actual_extensions_list = [e.strip() for e in header_value.split(',')] + actual_extensions = set(actual_extensions_list) + + expected_extensions = { + 'https://example.com/test-ext/v1', + 'https://example.com/test-ext/v2', + } + assert len(actual_extensions_list) == 2 + assert actual_extensions == expected_extensions + + @pytest.mark.asyncio + async def test_get_card_with_extended_card_support_with_extensions( + self, mock_httpx_client: AsyncMock + ): + """Test get_card with extensions passed to get_card call when extended card support is enabled. + Tests that the extensions are added to the RPC request.""" + extensions = [ + 'https://example.com/test-ext/v1', + 'https://example.com/test-ext/v2', + ] + agent_card = AGENT_CARD.model_copy( + update={'supports_authenticated_extended_card': True} + ) + client = JsonRpcTransport( + httpx_client=mock_httpx_client, + agent_card=agent_card, + extensions=extensions, + ) + + rpc_response = { + 'id': '123', + 'jsonrpc': '2.0', + 'result': AGENT_CARD_EXTENDED.model_dump(mode='json'), + } + with patch.object( + client, '_send_request', new_callable=AsyncMock + ) as mock_send_request: + mock_send_request.return_value = rpc_response + await client.get_card(extensions=extensions) + + mock_send_request.assert_called_once() + _, mock_kwargs = mock_send_request.call_args[0] + + headers = mock_kwargs.get('headers', {}) + assert HTTP_EXTENSION_HEADER in headers + header_value = headers[HTTP_EXTENSION_HEADER] + actual_extensions_list = [e.strip() for e in header_value.split(',')] + actual_extensions = set(actual_extensions_list) + + expected_extensions = { + 'https://example.com/test-ext/v1', + 'https://example.com/test-ext/v2', + } + assert len(actual_extensions_list) == 2 + assert actual_extensions == expected_extensions diff --git a/tests/client/transports/test_rest_client.py b/tests/client/transports/test_rest_client.py index 04bd1036..df7f5c85 100644 --- a/tests/client/transports/test_rest_client.py +++ b/tests/client/transports/test_rest_client.py @@ -9,7 +9,13 @@ from a2a.client import create_text_message_object from a2a.client.transports.rest import RestTransport from a2a.extensions.common import HTTP_EXTENSION_HEADER -from a2a.types import AgentCard, MessageSendParams, Role +from a2a.types import ( + AgentCapabilities, + AgentCard, + AgentSkill, + MessageSendParams, + Role, +) @pytest.fixture @@ -119,3 +125,106 @@ async def test_send_message_streaming_with_new_extensions( assert ( headers[HTTP_EXTENSION_HEADER] == 'https://example.com/test-ext/v2' ) + + @pytest.mark.asyncio + async def test_get_card_no_card_provided_with_extensions( + self, mock_httpx_client: AsyncMock + ): + """Test get_card with extensions set in Client when no card is initially provided. + Tests that the extensions are added to the HTTP GET request.""" + extensions = [ + 'https://example.com/test-ext/v1', + 'https://example.com/test-ext/v2', + ] + client = RestTransport( + httpx_client=mock_httpx_client, + url='http://agent.example.com/api', + extensions=extensions, + ) + + mock_response = AsyncMock(spec=httpx.Response) + mock_response.status_code = 200 + mock_response.json.return_value = { + 'name': 'Test Agent', + 'description': 'Test Agent Description', + 'url': 'http://agent.example.com/api', + 'version': '1.0.0', + 'default_input_modes': ['text'], + 'default_output_modes': ['text'], + 'capabilities': AgentCapabilities().model_dump(), + 'skills': [], + } + mock_httpx_client.get.return_value = mock_response + + await client.get_card() + + mock_httpx_client.get.assert_called_once() + _, mock_kwargs = mock_httpx_client.get.call_args + + headers = mock_kwargs.get('headers', {}) + assert HTTP_EXTENSION_HEADER in headers + header_value = headers[HTTP_EXTENSION_HEADER] + actual_extensions_list = [e.strip() for e in header_value.split(',')] + actual_extensions = set(actual_extensions_list) + + expected_extensions = { + 'https://example.com/test-ext/v1', + 'https://example.com/test-ext/v2', + } + assert len(actual_extensions_list) == 2 + assert actual_extensions == expected_extensions + + @pytest.mark.asyncio + async def test_get_card_with_extended_card_support_with_extensions( + self, mock_httpx_client: AsyncMock + ): + """Test get_card with extensions passed to get_card call when extended card support is enabled. + Tests that the extensions are added to the GET request.""" + extensions = [ + 'https://example.com/test-ext/v1', + 'https://example.com/test-ext/v2', + ] + agent_card = AgentCard( + name='Test Agent', + description='Test Agent Description', + url='http://agent.example.com/api', + version='1.0.0', + default_input_modes=['text'], + default_output_modes=['text'], + capabilities=AgentCapabilities(), + skills=[], + supports_authenticated_extended_card=True, + ) + client = RestTransport( + httpx_client=mock_httpx_client, + agent_card=agent_card, + ) + + mock_response = AsyncMock(spec=httpx.Response) + mock_response.status_code = 200 + mock_response.json.return_value = agent_card.model_dump(mode='json') + mock_httpx_client.send.return_value = mock_response + + with patch.object( + client, '_send_get_request', new_callable=AsyncMock + ) as mock_send_get_request: + mock_send_get_request.return_value = agent_card.model_dump( + mode='json' + ) + await client.get_card(extensions=extensions) + + mock_send_get_request.assert_called_once() + _, _, mock_kwargs = mock_send_get_request.call_args[0] + + headers = mock_kwargs.get('headers', {}) + assert HTTP_EXTENSION_HEADER in headers + header_value = headers[HTTP_EXTENSION_HEADER] + actual_extensions_list = [e.strip() for e in header_value.split(',')] + actual_extensions = set(actual_extensions_list) + + expected_extensions = { + 'https://example.com/test-ext/v1', + 'https://example.com/test-ext/v2', + } + assert len(actual_extensions_list) == 2 + assert actual_extensions == expected_extensions From 900133d1e4d739e09c4dcbb03e92cda5f71534ab Mon Sep 17 00:00:00 2001 From: sokoliva Date: Mon, 24 Nov 2025 10:45:10 +0000 Subject: [PATCH 2/2] refractor: reduce redundancy by extracting duplicated code into a shared helper function. --- .../client/transports/test_jsonrpc_client.py | 74 +++++++++---------- tests/client/transports/test_rest_client.py | 74 +++++++++---------- 2 files changed, 68 insertions(+), 80 deletions(-) diff --git a/tests/client/transports/test_jsonrpc_client.py b/tests/client/transports/test_jsonrpc_client.py index 30665987..df04d71c 100644 --- a/tests/client/transports/test_jsonrpc_client.py +++ b/tests/client/transports/test_jsonrpc_client.py @@ -114,6 +114,14 @@ async def async_iterable_from_list( yield item +def _assert_extensions_header(mock_kwargs: dict, expected_extensions: set[str]): + headers = mock_kwargs.get('headers', {}) + assert HTTP_EXTENSION_HEADER in headers + header_value = headers[HTTP_EXTENSION_HEADER] + actual_extensions = {e.strip() for e in header_value.split(',')} + assert actual_extensions == expected_extensions + + class TestA2ACardResolver: BASE_URL = 'http://example.com' AGENT_CARD_PATH = AGENT_CARD_WELL_KNOWN_PATH @@ -823,18 +831,13 @@ async def test_send_message_with_default_extensions( mock_httpx_client.post.assert_called_once() _, mock_kwargs = mock_httpx_client.post.call_args - headers = mock_kwargs.get('headers', {}) - assert HTTP_EXTENSION_HEADER in headers - header_value = headers[HTTP_EXTENSION_HEADER] - actual_extensions_list = [e.strip() for e in header_value.split(',')] - actual_extensions = set(actual_extensions_list) - - expected_extensions = { - 'https://example.com/test-ext/v1', - 'https://example.com/test-ext/v2', - } - assert len(actual_extensions_list) == 2 - assert actual_extensions == expected_extensions + _assert_extensions_header( + mock_kwargs, + { + 'https://example.com/test-ext/v1', + 'https://example.com/test-ext/v2', + }, + ) @pytest.mark.asyncio @patch('a2a.client.transports.jsonrpc.aconnect_sse') @@ -870,10 +873,11 @@ async def test_send_message_streaming_with_new_extensions( mock_aconnect_sse.assert_called_once() _, kwargs = mock_aconnect_sse.call_args - headers = kwargs.get('headers', {}) - assert HTTP_EXTENSION_HEADER in headers - assert ( - headers[HTTP_EXTENSION_HEADER] == 'https://example.com/test-ext/v2' + _assert_extensions_header( + kwargs, + { + 'https://example.com/test-ext/v2', + }, ) @pytest.mark.asyncio @@ -901,18 +905,13 @@ async def test_get_card_no_card_provided_with_extensions( mock_httpx_client.get.assert_called_once() _, mock_kwargs = mock_httpx_client.get.call_args - headers = mock_kwargs.get('headers', {}) - assert HTTP_EXTENSION_HEADER in headers - header_value = headers[HTTP_EXTENSION_HEADER] - actual_extensions_list = [e.strip() for e in header_value.split(',')] - actual_extensions = set(actual_extensions_list) - - expected_extensions = { - 'https://example.com/test-ext/v1', - 'https://example.com/test-ext/v2', - } - assert len(actual_extensions_list) == 2 - assert actual_extensions == expected_extensions + _assert_extensions_header( + mock_kwargs, + { + 'https://example.com/test-ext/v1', + 'https://example.com/test-ext/v2', + }, + ) @pytest.mark.asyncio async def test_get_card_with_extended_card_support_with_extensions( @@ -947,15 +946,10 @@ async def test_get_card_with_extended_card_support_with_extensions( mock_send_request.assert_called_once() _, mock_kwargs = mock_send_request.call_args[0] - headers = mock_kwargs.get('headers', {}) - assert HTTP_EXTENSION_HEADER in headers - header_value = headers[HTTP_EXTENSION_HEADER] - actual_extensions_list = [e.strip() for e in header_value.split(',')] - actual_extensions = set(actual_extensions_list) - - expected_extensions = { - 'https://example.com/test-ext/v1', - 'https://example.com/test-ext/v2', - } - assert len(actual_extensions_list) == 2 - assert actual_extensions == expected_extensions + _assert_extensions_header( + mock_kwargs, + { + 'https://example.com/test-ext/v1', + 'https://example.com/test-ext/v2', + }, + ) diff --git a/tests/client/transports/test_rest_client.py b/tests/client/transports/test_rest_client.py index df7f5c85..49d20d9d 100644 --- a/tests/client/transports/test_rest_client.py +++ b/tests/client/transports/test_rest_client.py @@ -38,6 +38,14 @@ async def async_iterable_from_list( yield item +def _assert_extensions_header(mock_kwargs: dict, expected_extensions: set[str]): + headers = mock_kwargs.get('headers', {}) + assert HTTP_EXTENSION_HEADER in headers + header_value = headers[HTTP_EXTENSION_HEADER] + actual_extensions = {e.strip() for e in header_value.split(',')} + assert actual_extensions == expected_extensions + + class TestRestTransportExtensions: @pytest.mark.asyncio async def test_send_message_with_default_extensions( @@ -73,18 +81,13 @@ async def test_send_message_with_default_extensions( mock_build_request.assert_called_once() _, kwargs = mock_build_request.call_args - headers = kwargs.get('headers', {}) - assert HTTP_EXTENSION_HEADER in headers - header_value = kwargs['headers'][HTTP_EXTENSION_HEADER] - actual_extensions_list = [e.strip() for e in header_value.split(',')] - actual_extensions = set(actual_extensions_list) - - expected_extensions = { - 'https://example.com/test-ext/v1', - 'https://example.com/test-ext/v2', - } - assert len(actual_extensions_list) == 2 - assert actual_extensions == expected_extensions + _assert_extensions_header( + kwargs, + { + 'https://example.com/test-ext/v1', + 'https://example.com/test-ext/v2', + }, + ) @pytest.mark.asyncio @patch('a2a.client.transports.rest.aconnect_sse') @@ -120,10 +123,11 @@ async def test_send_message_streaming_with_new_extensions( mock_aconnect_sse.assert_called_once() _, kwargs = mock_aconnect_sse.call_args - headers = kwargs.get('headers', {}) - assert HTTP_EXTENSION_HEADER in headers - assert ( - headers[HTTP_EXTENSION_HEADER] == 'https://example.com/test-ext/v2' + _assert_extensions_header( + kwargs, + { + 'https://example.com/test-ext/v2', + }, ) @pytest.mark.asyncio @@ -161,18 +165,13 @@ async def test_get_card_no_card_provided_with_extensions( mock_httpx_client.get.assert_called_once() _, mock_kwargs = mock_httpx_client.get.call_args - headers = mock_kwargs.get('headers', {}) - assert HTTP_EXTENSION_HEADER in headers - header_value = headers[HTTP_EXTENSION_HEADER] - actual_extensions_list = [e.strip() for e in header_value.split(',')] - actual_extensions = set(actual_extensions_list) - - expected_extensions = { - 'https://example.com/test-ext/v1', - 'https://example.com/test-ext/v2', - } - assert len(actual_extensions_list) == 2 - assert actual_extensions == expected_extensions + _assert_extensions_header( + mock_kwargs, + { + 'https://example.com/test-ext/v1', + 'https://example.com/test-ext/v2', + }, + ) @pytest.mark.asyncio async def test_get_card_with_extended_card_support_with_extensions( @@ -216,15 +215,10 @@ async def test_get_card_with_extended_card_support_with_extensions( mock_send_get_request.assert_called_once() _, _, mock_kwargs = mock_send_get_request.call_args[0] - headers = mock_kwargs.get('headers', {}) - assert HTTP_EXTENSION_HEADER in headers - header_value = headers[HTTP_EXTENSION_HEADER] - actual_extensions_list = [e.strip() for e in header_value.split(',')] - actual_extensions = set(actual_extensions_list) - - expected_extensions = { - 'https://example.com/test-ext/v1', - 'https://example.com/test-ext/v2', - } - assert len(actual_extensions_list) == 2 - assert actual_extensions == expected_extensions + _assert_extensions_header( + mock_kwargs, + { + 'https://example.com/test-ext/v1', + 'https://example.com/test-ext/v2', + }, + )