Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .github/actions/spelling/allow.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,14 @@ initdb
inmemory
INR
isready
jku
JPY
JSONRPCt
jwk
jwks
JWS
jws
kid
kwarg
langgraph
lifecycles
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ grpc = ["grpcio>=1.60", "grpcio-tools>=1.60", "grpcio_reflection>=1.7.0"]
telemetry = ["opentelemetry-api>=1.33.0", "opentelemetry-sdk>=1.33.0"]
postgresql = ["sqlalchemy[asyncio,postgresql-asyncpg]>=2.0.0"]
mysql = ["sqlalchemy[asyncio,aiomysql]>=2.0.0"]
signing = ["PyJWT>=2.0.0"]
sqlite = ["sqlalchemy[asyncio,aiosqlite]>=2.0.0"]

sql = ["a2a-sdk[postgresql,mysql,sqlite]"]
Expand All @@ -45,6 +46,7 @@ all = [
"a2a-sdk[encryption]",
"a2a-sdk[grpc]",
"a2a-sdk[telemetry]",
"a2a-sdk[signing]",
]

[project.urls]
Expand Down Expand Up @@ -86,6 +88,7 @@ style = "pep440"
dev = [
"datamodel-code-generator>=0.30.0",
"mypy>=1.15.0",
"PyJWT>=2.0.0",
"pytest>=8.3.5",
"pytest-asyncio>=0.26.0",
"pytest-cov>=6.1.1",
Expand Down
8 changes: 6 additions & 2 deletions src/a2a/client/base_client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from collections.abc import AsyncIterator
from collections.abc import AsyncIterator, Callable
from typing import Any

from a2a.client.client import (
Expand Down Expand Up @@ -261,6 +261,7 @@ async def get_card(
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
signature_verifier: Callable[[AgentCard], None] | None = None,
) -> AgentCard:
"""Retrieves the agent's card.

Expand All @@ -270,12 +271,15 @@ async def get_card(
Args:
context: The client call context.
extensions: List of extensions to be activated.
signature_verifier: A callable used to verify the agent card's signatures.

Returns:
The `AgentCard` for the agent.
"""
card = await self._transport.get_card(
context=context, extensions=extensions
context=context,
extensions=extensions,
signature_verifier=signature_verifier,
)
self._card = card
return card
Expand Down
1 change: 1 addition & 0 deletions src/a2a/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,17 +176,18 @@
extensions: list[str] | None = None,
) -> AsyncIterator[ClientEvent]:
"""Resubscribes to a task's event stream."""
return
yield

@abstractmethod
async def get_card(
self,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
signature_verifier: Callable[[AgentCard], None] | None = None,
) -> AgentCard:
"""Retrieves the agent's card."""

Check notice on line 190 in src/a2a/client/client.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/client/transports/base.py (97-108)

async def add_event_consumer(self, consumer: Consumer) -> None:
"""Attaches additional consumers to the `Client`."""
Expand Down
3 changes: 2 additions & 1 deletion src/a2a/client/transports/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator, Callable

from a2a.client.middleware import ClientCallContext
from a2a.types import (
Expand Down Expand Up @@ -94,17 +94,18 @@
Task | Message | TaskStatusUpdateEvent | TaskArtifactUpdateEvent
]:
"""Reconnects to get task updates."""
return
yield

@abstractmethod
async def get_card(
self,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
signature_verifier: Callable[[AgentCard], None] | None = None,
) -> AgentCard:
"""Retrieves the AgentCard."""

Check notice on line 108 in src/a2a/client/transports/base.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/client/client.py (179-190)

@abstractmethod
async def close(self) -> None:
Expand Down
6 changes: 5 additions & 1 deletion src/a2a/client/transports/grpc.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging

from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator, Callable


try:
Expand Down Expand Up @@ -223,6 +223,7 @@ async def get_card(
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
signature_verifier: Callable[[AgentCard], None] | None = None,
) -> AgentCard:
"""Retrieves the agent's card."""
card = self.agent_card
Expand All @@ -236,6 +237,9 @@ async def get_card(
metadata=self._get_grpc_metadata(extensions),
)
card = proto_utils.FromProto.agent_card(card_pb)
if signature_verifier is not None:
signature_verifier(card)

self.agent_card = card
self._needs_extended_card = False
return card
Expand Down
14 changes: 11 additions & 3 deletions src/a2a/client/transports/jsonrpc.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import logging

from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator, Callable
from typing import Any
from uuid import uuid4

Expand Down Expand Up @@ -363,41 +363,45 @@
if isinstance(response.root, JSONRPCErrorResponse):
raise A2AClientJSONRPCError(response.root)
yield response.root.result
except SSEError as e:
raise A2AClientHTTPError(
400, f'Invalid SSE response or protocol error: {e}'
) from e
except json.JSONDecodeError as e:
raise A2AClientJSONError(str(e)) from e
except httpx.RequestError as e:
raise A2AClientHTTPError(
503, f'Network communication error: {e}'
) from e

async def get_card(
self,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
signature_verifier: Callable[[AgentCard], None] | None = None,
) -> AgentCard:
"""Retrieves the agent's card."""
modified_kwargs = update_extension_header(
self._get_http_args(context),
extensions if extensions is not None else self.extensions,
)
card = self.agent_card

if not card:
resolver = A2ACardResolver(self.httpx_client, self.url)
card = await resolver.get_agent_card(http_kwargs=modified_kwargs)
if signature_verifier is not None:
signature_verifier(card)
self._needs_extended_card = (
card.supports_authenticated_extended_card
)
self.agent_card = card

if not self._needs_extended_card:
return card

request = GetAuthenticatedExtendedCardRequest(id=str(uuid4()))

Check notice on line 404 in src/a2a/client/transports/jsonrpc.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/client/transports/rest.py (162-396)
payload, modified_kwargs = await self._apply_interceptors(
request.method,
request.model_dump(mode='json', exclude_none=True),
Expand All @@ -413,9 +417,13 @@
)
if isinstance(response.root, JSONRPCErrorResponse):
raise A2AClientJSONRPCError(response.root)
self.agent_card = response.root.result
card = response.root.result
if signature_verifier is not None:
signature_verifier(card)

self.agent_card = card
self._needs_extended_card = False
return self.agent_card
return card

async def close(self) -> None:
"""Closes the httpx client."""
Expand Down
9 changes: 8 additions & 1 deletion src/a2a/client/transports/rest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import logging

from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator, Callable
from typing import Any

import httpx
Expand Down Expand Up @@ -159,237 +159,241 @@
yield proto_utils.FromProto.stream_response(event)
except httpx.HTTPStatusError as e:
raise A2AClientHTTPError(e.response.status_code, str(e)) from e
except SSEError as e:
raise A2AClientHTTPError(
400, f'Invalid SSE response or protocol error: {e}'
) from e
except json.JSONDecodeError as e:
raise A2AClientJSONError(str(e)) from e
except httpx.RequestError as e:
raise A2AClientHTTPError(
503, f'Network communication error: {e}'
) from e

async def _send_request(self, request: httpx.Request) -> dict[str, Any]:
try:
response = await self.httpx_client.send(request)
response.raise_for_status()
return response.json()
except httpx.HTTPStatusError as e:
raise A2AClientHTTPError(e.response.status_code, str(e)) from e
except json.JSONDecodeError as e:
raise A2AClientJSONError(str(e)) from e
except httpx.RequestError as e:
raise A2AClientHTTPError(
503, f'Network communication error: {e}'
) from e

async def _send_post_request(
self,
target: str,
rpc_request_payload: dict[str, Any],
http_kwargs: dict[str, Any] | None = None,
) -> dict[str, Any]:
return await self._send_request(
self.httpx_client.build_request(
'POST',
f'{self.url}{target}',
json=rpc_request_payload,
**(http_kwargs or {}),
)
)

async def _send_get_request(
self,
target: str,
query_params: dict[str, str],
http_kwargs: dict[str, Any] | None = None,
) -> dict[str, Any]:
return await self._send_request(
self.httpx_client.build_request(
'GET',
f'{self.url}{target}',
params=query_params,
**(http_kwargs or {}),
)
)

async def get_task(
self,
request: TaskQueryParams,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> Task:
"""Retrieves the current state and history of a specific task."""
modified_kwargs = update_extension_header(
self._get_http_args(context),
extensions if extensions is not None else self.extensions,
)
_payload, modified_kwargs = await self._apply_interceptors(
request.model_dump(mode='json', exclude_none=True),
modified_kwargs,
context,
)
response_data = await self._send_get_request(
f'/v1/tasks/{request.id}',
{'historyLength': str(request.history_length)}
if request.history_length is not None
else {},
modified_kwargs,
)
task = a2a_pb2.Task()
ParseDict(response_data, task)
return proto_utils.FromProto.task(task)

async def cancel_task(
self,
request: TaskIdParams,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> Task:
"""Requests the agent to cancel a specific task."""
pb = a2a_pb2.CancelTaskRequest(name=f'tasks/{request.id}')
payload = MessageToDict(pb)
modified_kwargs = update_extension_header(
self._get_http_args(context),
extensions if extensions is not None else self.extensions,
)
payload, modified_kwargs = await self._apply_interceptors(
payload,
modified_kwargs,
context,
)
response_data = await self._send_post_request(
f'/v1/tasks/{request.id}:cancel', payload, modified_kwargs
)
task = a2a_pb2.Task()
ParseDict(response_data, task)
return proto_utils.FromProto.task(task)

async def set_task_callback(
self,
request: TaskPushNotificationConfig,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> TaskPushNotificationConfig:
"""Sets or updates the push notification configuration for a specific task."""
pb = a2a_pb2.CreateTaskPushNotificationConfigRequest(
parent=f'tasks/{request.task_id}',
config_id=request.push_notification_config.id,
config=proto_utils.ToProto.task_push_notification_config(request),
)
payload = MessageToDict(pb)
modified_kwargs = update_extension_header(
self._get_http_args(context),
extensions if extensions is not None else self.extensions,
)
payload, modified_kwargs = await self._apply_interceptors(
payload, modified_kwargs, context
)
response_data = await self._send_post_request(
f'/v1/tasks/{request.task_id}/pushNotificationConfigs',
payload,
modified_kwargs,
)
config = a2a_pb2.TaskPushNotificationConfig()
ParseDict(response_data, config)
return proto_utils.FromProto.task_push_notification_config(config)

async def get_task_callback(
self,
request: GetTaskPushNotificationConfigParams,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> TaskPushNotificationConfig:
"""Retrieves the push notification configuration for a specific task."""
pb = a2a_pb2.GetTaskPushNotificationConfigRequest(
name=f'tasks/{request.id}/pushNotificationConfigs/{request.push_notification_config_id}',
)
payload = MessageToDict(pb)
modified_kwargs = update_extension_header(
self._get_http_args(context),
extensions if extensions is not None else self.extensions,
)
payload, modified_kwargs = await self._apply_interceptors(
payload,
modified_kwargs,
context,
)
response_data = await self._send_get_request(
f'/v1/tasks/{request.id}/pushNotificationConfigs/{request.push_notification_config_id}',
{},
modified_kwargs,
)
config = a2a_pb2.TaskPushNotificationConfig()
ParseDict(response_data, config)
return proto_utils.FromProto.task_push_notification_config(config)

async def resubscribe(
self,
request: TaskIdParams,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> AsyncGenerator[
Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent | Message
]:
"""Reconnects to get task updates."""
modified_kwargs = update_extension_header(
self._get_http_args(context),
extensions if extensions is not None else self.extensions,
)
modified_kwargs.setdefault('timeout', None)

async with aconnect_sse(
self.httpx_client,
'GET',
f'{self.url}/v1/tasks/{request.id}:subscribe',
**modified_kwargs,
) as event_source:
try:
async for sse in event_source.aiter_sse():
event = a2a_pb2.StreamResponse()
Parse(sse.data, event)
yield proto_utils.FromProto.stream_response(event)
except SSEError as e:
raise A2AClientHTTPError(
400, f'Invalid SSE response or protocol error: {e}'
) from e
except json.JSONDecodeError as e:
raise A2AClientJSONError(str(e)) from e
except httpx.RequestError as e:
raise A2AClientHTTPError(
503, f'Network communication error: {e}'
) from e

async def get_card(
self,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
signature_verifier: Callable[[AgentCard], None] | None = None,
) -> AgentCard:
"""Retrieves the agent's card."""
modified_kwargs = update_extension_header(
self._get_http_args(context),
extensions if extensions is not None else self.extensions,
)
card = self.agent_card

if not card:
resolver = A2ACardResolver(self.httpx_client, self.url)
card = await resolver.get_agent_card(http_kwargs=modified_kwargs)
if signature_verifier is not None:
signature_verifier(card)
self._needs_extended_card = (
card.supports_authenticated_extended_card
)
self.agent_card = card

if not self._needs_extended_card:
return card

_, modified_kwargs = await self._apply_interceptors(

Check notice on line 396 in src/a2a/client/transports/rest.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/client/transports/jsonrpc.py (366-404)
{},
modified_kwargs,
context,
Expand All @@ -398,6 +402,9 @@
'/v1/card', {}, modified_kwargs
)
card = AgentCard.model_validate(response_data)
if signature_verifier is not None:
signature_verifier(card)

self.agent_card = card
self._needs_extended_card = False
return card
Expand Down
28 changes: 28 additions & 0 deletions src/a2a/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@

import functools
import inspect
import json
import logging

from collections.abc import Callable
from typing import Any
from uuid import uuid4

from a2a.types import (
AgentCard,
Artifact,
MessageSendParams,
Part,
Expand Down Expand Up @@ -340,3 +342,29 @@ def are_modalities_compatible(
return True

return any(x in server_output_modes for x in client_output_modes)


def _clean_empty(d: Any) -> Any:
"""Recursively remove empty strings, lists and dicts from a dictionary."""
if isinstance(d, dict):
cleaned_dict: dict[Any, Any] = {
k: _clean_empty(v) for k, v in d.items()
}
return {k: v for k, v in cleaned_dict.items() if v}
if isinstance(d, list):
cleaned_list: list[Any] = [_clean_empty(v) for v in d]
return [v for v in cleaned_list if v]
return d if d not in ['', [], {}] else None


def canonicalize_agent_card(agent_card: AgentCard) -> str:
"""Canonicalizes the Agent Card JSON according to RFC 8785 (JCS)."""
card_dict = agent_card.model_dump(
exclude={'signatures'},
exclude_defaults=True,
exclude_none=True,
by_alias=True,
)
# Recursively remove empty values
cleaned_dict = _clean_empty(card_dict)
return json.dumps(cleaned_dict, separators=(',', ':'), sort_keys=True)
28 changes: 28 additions & 0 deletions src/a2a/utils/proto_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,21 @@ def agent_card(
]
if card.additional_interfaces
else None,
signatures=[cls.agent_card_signature(x) for x in card.signatures]
if card.signatures
else None,
)

@classmethod
def agent_card_signature(
cls, signature: types.AgentCardSignature
) -> a2a_pb2.AgentCardSignature:
return a2a_pb2.AgentCardSignature(
protected=signature.protected,
signature=signature.signature,
header=dict_to_struct(signature.header)
if signature.header is not None
else None,
)

@classmethod
Expand Down Expand Up @@ -865,6 +880,19 @@ def agent_card(
]
if card.additional_interfaces
else None,
signatures=[cls.agent_card_signature(x) for x in card.signatures]
if card.signatures
else None,
)

@classmethod
def agent_card_signature(
cls, signature: a2a_pb2.AgentCardSignature
) -> types.AgentCardSignature:
return types.AgentCardSignature(
protected=signature.protected,
signature=signature.signature,
header=json_format.MessageToDict(signature.header),
)

@classmethod
Expand Down
Loading
Loading