Skip to content
5 changes: 5 additions & 0 deletions src/a2a/client/card_resolver.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import logging

from collections.abc import Callable
from typing import Any

import httpx
Expand Down Expand Up @@ -44,6 +45,7 @@ async def get_agent_card(
self,
relative_card_path: str | None = None,
http_kwargs: dict[str, Any] | None = None,
signature_verifier: Callable[[AgentCard], None] | None = None,
) -> AgentCard:
"""Fetches an agent card from a specified path relative to the base_url.

Expand All @@ -56,6 +58,7 @@ async def get_agent_card(
agent card path. Use `'/'` for an empty path.
http_kwargs: Optional dictionary of keyword arguments to pass to the
underlying httpx.get request.
signature_verifier: A callable used to verify the agent card's signatures.

Returns:
An `AgentCard` object representing the agent's capabilities.
Expand Down Expand Up @@ -86,6 +89,8 @@ async def get_agent_card(
agent_card_data,
)
agent_card = AgentCard.model_validate(agent_card_data)
if signature_verifier:
signature_verifier(agent_card)
except httpx.HTTPStatusError as e:
raise A2AClientHTTPError(
e.response.status_code,
Expand Down
4 changes: 4 additions & 0 deletions src/a2a/client/client_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ async def connect( # noqa: PLR0913
resolver_http_kwargs: dict[str, Any] | None = None,
extra_transports: dict[str, TransportProducer] | None = None,
extensions: list[str] | None = None,
signature_verifier: Callable[[AgentCard], None] | None = None,
) -> Client:
"""Convenience method for constructing a client.

Expand Down Expand Up @@ -146,6 +147,7 @@ async def connect( # noqa: PLR0913
extra_transports: Additional transport protocols to enable when
constructing the client.
extensions: List of extensions to be activated.
signature_verifier: A callable used to verify the agent card's signatures.

Returns:
A `Client` object.
Expand All @@ -158,12 +160,14 @@ async def connect( # noqa: PLR0913
card = await resolver.get_agent_card(
relative_card_path=relative_card_path,
http_kwargs=resolver_http_kwargs,
signature_verifier=signature_verifier,
)
else:
resolver = A2ACardResolver(client_config.httpx_client, agent)
card = await resolver.get_agent_card(
relative_card_path=relative_card_path,
http_kwargs=resolver_http_kwargs,
signature_verifier=signature_verifier,
)
else:
card = agent
Expand Down
2 changes: 1 addition & 1 deletion src/a2a/client/transports/grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ async def get_card(
metadata=self._get_grpc_metadata(extensions),
)
card = proto_utils.FromProto.agent_card(card_pb)
if signature_verifier is not None:
if signature_verifier:
signature_verifier(card)

self.agent_card = card
Expand Down
9 changes: 5 additions & 4 deletions src/a2a/client/transports/jsonrpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,45 +363,46 @@
if isinstance(response.root, JSONRPCErrorResponse):
raise A2AClientJSONRPCError(response.root)
yield response.root.result
except SSEError as e:
raise A2AClientHTTPError(
400, f'Invalid SSE response or protocol error: {e}'
) from e
except json.JSONDecodeError as e:
raise A2AClientJSONError(str(e)) from e
except httpx.RequestError as e:
raise A2AClientHTTPError(
503, f'Network communication error: {e}'
) from e

async def get_card(
self,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
signature_verifier: Callable[[AgentCard], None] | None = None,
) -> AgentCard:
"""Retrieves the agent's card."""
modified_kwargs = update_extension_header(
self._get_http_args(context),
extensions if extensions is not None else self.extensions,
)
card = self.agent_card

if not card:
resolver = A2ACardResolver(self.httpx_client, self.url)
card = await resolver.get_agent_card(http_kwargs=modified_kwargs)
if signature_verifier is not None:
signature_verifier(card)
card = await resolver.get_agent_card(
http_kwargs=modified_kwargs,
signature_verifier=signature_verifier,
)
self._needs_extended_card = (
card.supports_authenticated_extended_card
)
self.agent_card = card

if not self._needs_extended_card:
return card

request = GetAuthenticatedExtendedCardRequest(id=str(uuid4()))

Check notice on line 405 in src/a2a/client/transports/jsonrpc.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/client/transports/rest.py (162-397)
payload, modified_kwargs = await self._apply_interceptors(
request.method,
request.model_dump(mode='json', exclude_none=True),
Expand All @@ -418,7 +419,7 @@
if isinstance(response.root, JSONRPCErrorResponse):
raise A2AClientJSONRPCError(response.root)
card = response.root.result
if signature_verifier is not None:
if signature_verifier:
signature_verifier(card)

self.agent_card = card
Expand Down
9 changes: 5 additions & 4 deletions src/a2a/client/transports/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,241 +159,242 @@
yield proto_utils.FromProto.stream_response(event)
except httpx.HTTPStatusError as e:
raise A2AClientHTTPError(e.response.status_code, str(e)) from e
except SSEError as e:
raise A2AClientHTTPError(
400, f'Invalid SSE response or protocol error: {e}'
) from e
except json.JSONDecodeError as e:
raise A2AClientJSONError(str(e)) from e
except httpx.RequestError as e:
raise A2AClientHTTPError(
503, f'Network communication error: {e}'
) from e

async def _send_request(self, request: httpx.Request) -> dict[str, Any]:
try:
response = await self.httpx_client.send(request)
response.raise_for_status()
return response.json()
except httpx.HTTPStatusError as e:
raise A2AClientHTTPError(e.response.status_code, str(e)) from e
except json.JSONDecodeError as e:
raise A2AClientJSONError(str(e)) from e
except httpx.RequestError as e:
raise A2AClientHTTPError(
503, f'Network communication error: {e}'
) from e

async def _send_post_request(
self,
target: str,
rpc_request_payload: dict[str, Any],
http_kwargs: dict[str, Any] | None = None,
) -> dict[str, Any]:
return await self._send_request(
self.httpx_client.build_request(
'POST',
f'{self.url}{target}',
json=rpc_request_payload,
**(http_kwargs or {}),
)
)

async def _send_get_request(
self,
target: str,
query_params: dict[str, str],
http_kwargs: dict[str, Any] | None = None,
) -> dict[str, Any]:
return await self._send_request(
self.httpx_client.build_request(
'GET',
f'{self.url}{target}',
params=query_params,
**(http_kwargs or {}),
)
)

async def get_task(
self,
request: TaskQueryParams,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> Task:
"""Retrieves the current state and history of a specific task."""
modified_kwargs = update_extension_header(
self._get_http_args(context),
extensions if extensions is not None else self.extensions,
)
_payload, modified_kwargs = await self._apply_interceptors(
request.model_dump(mode='json', exclude_none=True),
modified_kwargs,
context,
)
response_data = await self._send_get_request(
f'/v1/tasks/{request.id}',
{'historyLength': str(request.history_length)}
if request.history_length is not None
else {},
modified_kwargs,
)
task = a2a_pb2.Task()
ParseDict(response_data, task)
return proto_utils.FromProto.task(task)

async def cancel_task(
self,
request: TaskIdParams,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> Task:
"""Requests the agent to cancel a specific task."""
pb = a2a_pb2.CancelTaskRequest(name=f'tasks/{request.id}')
payload = MessageToDict(pb)
modified_kwargs = update_extension_header(
self._get_http_args(context),
extensions if extensions is not None else self.extensions,
)
payload, modified_kwargs = await self._apply_interceptors(
payload,
modified_kwargs,
context,
)
response_data = await self._send_post_request(
f'/v1/tasks/{request.id}:cancel', payload, modified_kwargs
)
task = a2a_pb2.Task()
ParseDict(response_data, task)
return proto_utils.FromProto.task(task)

async def set_task_callback(
self,
request: TaskPushNotificationConfig,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> TaskPushNotificationConfig:
"""Sets or updates the push notification configuration for a specific task."""
pb = a2a_pb2.CreateTaskPushNotificationConfigRequest(
parent=f'tasks/{request.task_id}',
config_id=request.push_notification_config.id,
config=proto_utils.ToProto.task_push_notification_config(request),
)
payload = MessageToDict(pb)
modified_kwargs = update_extension_header(
self._get_http_args(context),
extensions if extensions is not None else self.extensions,
)
payload, modified_kwargs = await self._apply_interceptors(
payload, modified_kwargs, context
)
response_data = await self._send_post_request(
f'/v1/tasks/{request.task_id}/pushNotificationConfigs',
payload,
modified_kwargs,
)
config = a2a_pb2.TaskPushNotificationConfig()
ParseDict(response_data, config)
return proto_utils.FromProto.task_push_notification_config(config)

async def get_task_callback(
self,
request: GetTaskPushNotificationConfigParams,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> TaskPushNotificationConfig:
"""Retrieves the push notification configuration for a specific task."""
pb = a2a_pb2.GetTaskPushNotificationConfigRequest(
name=f'tasks/{request.id}/pushNotificationConfigs/{request.push_notification_config_id}',
)
payload = MessageToDict(pb)
modified_kwargs = update_extension_header(
self._get_http_args(context),
extensions if extensions is not None else self.extensions,
)
payload, modified_kwargs = await self._apply_interceptors(
payload,
modified_kwargs,
context,
)
response_data = await self._send_get_request(
f'/v1/tasks/{request.id}/pushNotificationConfigs/{request.push_notification_config_id}',
{},
modified_kwargs,
)
config = a2a_pb2.TaskPushNotificationConfig()
ParseDict(response_data, config)
return proto_utils.FromProto.task_push_notification_config(config)

async def resubscribe(
self,
request: TaskIdParams,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> AsyncGenerator[
Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent | Message
]:
"""Reconnects to get task updates."""
modified_kwargs = update_extension_header(
self._get_http_args(context),
extensions if extensions is not None else self.extensions,
)
modified_kwargs.setdefault('timeout', None)

async with aconnect_sse(
self.httpx_client,
'GET',
f'{self.url}/v1/tasks/{request.id}:subscribe',
**modified_kwargs,
) as event_source:
try:
async for sse in event_source.aiter_sse():
event = a2a_pb2.StreamResponse()
Parse(sse.data, event)
yield proto_utils.FromProto.stream_response(event)
except SSEError as e:
raise A2AClientHTTPError(
400, f'Invalid SSE response or protocol error: {e}'
) from e
except json.JSONDecodeError as e:
raise A2AClientJSONError(str(e)) from e
except httpx.RequestError as e:
raise A2AClientHTTPError(
503, f'Network communication error: {e}'
) from e

async def get_card(
self,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
signature_verifier: Callable[[AgentCard], None] | None = None,
) -> AgentCard:
"""Retrieves the agent's card."""
modified_kwargs = update_extension_header(
self._get_http_args(context),
extensions if extensions is not None else self.extensions,
)
card = self.agent_card

if not card:
resolver = A2ACardResolver(self.httpx_client, self.url)
card = await resolver.get_agent_card(http_kwargs=modified_kwargs)
if signature_verifier is not None:
signature_verifier(card)
card = await resolver.get_agent_card(
http_kwargs=modified_kwargs,
signature_verifier=signature_verifier,
)
self._needs_extended_card = (
card.supports_authenticated_extended_card
)
self.agent_card = card

if not self._needs_extended_card:
return card

_, modified_kwargs = await self._apply_interceptors(

Check notice on line 397 in src/a2a/client/transports/rest.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/client/transports/jsonrpc.py (366-405)
{},
modified_kwargs,
context,
Expand All @@ -402,7 +403,7 @@
'/v1/card', {}, modified_kwargs
)
card = AgentCard.model_validate(response_data)
if signature_verifier is not None:
if signature_verifier:
signature_verifier(card)

self.agent_card = card
Expand Down
23 changes: 22 additions & 1 deletion tests/client/test_card_resolver.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import logging

from unittest.mock import AsyncMock, Mock, patch
from unittest.mock import AsyncMock, MagicMock, Mock, patch

import httpx
import pytest
Expand Down Expand Up @@ -371,9 +371,30 @@ async def test_get_agent_card_returns_agent_card_instance(
self, resolver, mock_httpx_client, mock_response, valid_agent_card_data
):
"""Test that get_agent_card returns an AgentCard instance."""
mock_response.json.return_value = valid_agent_card_data
mock_httpx_client.get.return_value = mock_response
mock_agent_card = Mock(spec=AgentCard)

with patch.object(
AgentCard, 'model_validate', return_value=mock_agent_card
):
result = await resolver.get_agent_card()
assert result == mock_agent_card
mock_response.raise_for_status.assert_called_once()

@pytest.mark.asyncio
async def test_get_agent_card_with_signature_verifier(
self, resolver, mock_httpx_client, valid_agent_card_data
):
"""Test that the signature verifier is called if provided."""
mock_verifier = MagicMock()

mock_response = MagicMock(spec=httpx.Response)
mock_response.json.return_value = valid_agent_card_data
mock_httpx_client.get.return_value = mock_response

agent_card = await resolver.get_agent_card(
signature_verifier=mock_verifier
)

mock_verifier.assert_called_once_with(agent_card)
2 changes: 2 additions & 0 deletions tests/client/test_client_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ async def test_client_factory_connect_with_resolver_args(
mock_resolver.return_value.get_agent_card.assert_awaited_once_with(
relative_card_path=relative_path,
http_kwargs=http_kwargs,
signature_verifier=None,
)


Expand All @@ -216,6 +217,7 @@ async def test_client_factory_connect_resolver_args_without_client(
mock_resolver.return_value.get_agent_card.assert_awaited_once_with(
relative_card_path=relative_path,
http_kwargs=http_kwargs,
signature_verifier=None,
)


Expand Down
Loading