Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 74 additions & 24 deletions src/a2a/client/base_client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from collections.abc import AsyncIterator
from typing import cast

from a2a.client.client import (
Client,
Expand All @@ -11,6 +12,13 @@
from a2a.client.errors import A2AClientInvalidStateError
from a2a.client.middleware import ClientCallInterceptor
from a2a.client.transports.base import ClientTransport
from a2a.extensions.base import Extension
from a2a.extensions.trace import (
AgentInvocation,
CallTypeEnum,
StepAction,
TraceExtension,
)
from a2a.types import (
AgentCard,
GetTaskPushNotificationConfigParams,
Expand Down Expand Up @@ -41,6 +49,12 @@
self._card = card
self._config = config
self._transport = transport
self._extensions: list[Extension] = []

def install_extension(self, extension: Extension) -> None:
"""Installs an extension on the client."""
extension.install(self)
self._extensions.append(extension)

async def send_message(
self,
Expand All @@ -61,6 +75,31 @@
Yields:
An async iterator of `ClientEvent` or a final `Message` response.
"""
trace_extension: TraceExtension | None = None
for extension in self._extensions:
if isinstance(extension, TraceExtension):
trace_extension = cast(TraceExtension, extension)

Check failure on line 81 in src/a2a/client/base_client.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Ruff (TC006)

src/a2a/client/base_client.py:81:40: TC006 Add quotes to type expression in `typing.cast()`
extension.on_client_message(request)

step = None
if trace_extension:
trace_id = request.metadata.get('trace', {}).get('trace_id')

Check failure on line 86 in src/a2a/client/base_client.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

"get" is not a known attribute of "None" (reportOptionalMemberAccess)
parent_step_id = request.metadata.get('trace', {}).get(

Check failure on line 87 in src/a2a/client/base_client.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

"get" is not a known attribute of "None" (reportOptionalMemberAccess)
'parent_step_id'
)
step = trace_extension.start_step(
trace_id=trace_id,
parent_step_id=parent_step_id,
call_type=CallTypeEnum.AGENT,
step_action=StepAction(
agent_invocation=AgentInvocation(
agent_url=self._card.url,
agent_name=self._card.name,
requests=request.model_dump(mode='json'),
)
),
)

config = MessageSendConfiguration(
accepted_output_modes=self._config.accepted_output_modes,
blocking=not self._config.polling,
Expand All @@ -72,33 +111,44 @@
)
params = MessageSendParams(message=request, configuration=config)

if not self._config.streaming or not self._card.capabilities.streaming:
response = await self._transport.send_message(
try:
if (
not self._config.streaming
or not self._card.capabilities.streaming
):
response = await self._transport.send_message(
params, context=context
)
result = (
(response, None)
if isinstance(response, Task)
else response
)
await self.consume(result, self._card)
yield result
return

tracker = ClientTaskManager()
stream = self._transport.send_message_streaming(
params, context=context
)
result = (
(response, None) if isinstance(response, Task) else response
)
await self.consume(result, self._card)
yield result
return

tracker = ClientTaskManager()
stream = self._transport.send_message_streaming(params, context=context)

first_event = await anext(stream)
# The response from a server may be either exactly one Message or a
# series of Task updates. Separate out the first message for special
# case handling, which allows us to simplify further stream processing.
if isinstance(first_event, Message):
await self.consume(first_event, self._card)
yield first_event
return

yield await self._process_response(tracker, first_event)

async for event in stream:
yield await self._process_response(tracker, event)
first_event = await anext(stream)
# The response from a server may be either exactly one Message or a
# series of Task updates. Separate out the first message for special
# case handling, which allows us to simplify further stream processing.
if isinstance(first_event, Message):
await self.consume(first_event, self._card)
yield first_event
return

yield await self._process_response(tracker, first_event)

async for event in stream:
yield await self._process_response(tracker, event)
finally:
if trace_extension and step:
trace_extension.end_step(step.step_id)

async def _process_response(
self,
Expand Down
6 changes: 6 additions & 0 deletions src/a2a/extensions/__init__.py
Original file line number Diff line number Diff line change
@@ -1,0 +1,6 @@
"""A2A extensions."""

from a2a.extensions.base import Extension
from a2a.extensions import common, trace

Check failure on line 4 in src/a2a/extensions/__init__.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Ruff (I001)

src/a2a/extensions/__init__.py:3:1: I001 Import block is un-sorted or un-formatted

__all__ = ['Extension', 'common', 'trace']
26 changes: 26 additions & 0 deletions src/a2a/extensions/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any

Check failure on line 3 in src/a2a/extensions/base.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Ruff (I001)

src/a2a/extensions/base.py:1:1: I001 Import block is un-sorted or un-formatted

if TYPE_CHECKING:
from a2a.client.client import A2AClient

Check failure on line 6 in src/a2a/extensions/base.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

"A2AClient" is unknown import symbol (reportAttributeAccessIssue)
from a2a.server.server import A2AServer


class Extension:
"""Base class for all extensions."""

def __init__(self, **kwargs: Any) -> None:
...

def on_client_message(self, message: Any) -> None:
"""Called when a message is sent from the client."""
Comment on lines +16 to +17
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For better type safety, the message parameter in on_client_message should be typed as Message instead of Any. Based on the usage in base_client.py and default_request_handler.py, this parameter is always an instance of a2a.types.Message.

To make this work, you'll also need to add from a2a.types import Message inside the if TYPE_CHECKING: block at the top of the file to avoid circular imports.

Suggested change
def on_client_message(self, message: Any) -> None:
"""Called when a message is sent from the client."""
def on_client_message(self, message: Message) -> None:

...

Check failure on line 18 in src/a2a/extensions/base.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Ruff (PIE790)

src/a2a/extensions/base.py:18:9: PIE790 Unnecessary `...` literal

def on_server_message(self, message: Any) -> None:
"""Called when a message is received by the server."""
Comment on lines +20 to +21
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For better type safety, the message parameter in on_server_message should be typed as Message instead of Any. Based on the usage in base_client.py and default_request_handler.py, this parameter is always an instance of a2a.types.Message.

To make this work, you'll also need to add from a2a.types import Message inside the if TYPE_CHECKING: block at the top of the file to avoid circular imports.

Suggested change
def on_server_message(self, message: Any) -> None:
"""Called when a message is received by the server."""
def on_server_message(self, message: Message) -> None:

...

Check failure on line 22 in src/a2a/extensions/base.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Ruff (PIE790)

src/a2a/extensions/base.py:22:9: PIE790 Unnecessary `...` literal

def install(self, client_or_server: A2AClient | A2AServer) -> None:
"""Called when the extension is installed on a client or server."""
...

Check failure on line 26 in src/a2a/extensions/base.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Ruff (PIE790)

src/a2a/extensions/base.py:26:9: PIE790 Unnecessary `...` literal
146 changes: 146 additions & 0 deletions src/a2a/extensions/trace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
from __future__ import annotations

import time

Check failure on line 3 in src/a2a/extensions/trace.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Ruff (F401)

src/a2a/extensions/trace.py:3:8: F401 `time` imported but unused
import uuid
from datetime import datetime, timezone
from enum import Enum
from typing import Any

from a2a._base import A2ABaseModel
from a2a.extensions.base import Extension

Check failure on line 10 in src/a2a/extensions/trace.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Ruff (I001)

src/a2a/extensions/trace.py:1:1: I001 Import block is un-sorted or un-formatted


class CallTypeEnum(str, Enum):
"""The type of the operation a step represents."""

AGENT = 'AGENT'
TOOL = 'TOOL'


class ToolInvocation(A2ABaseModel):
"""A tool invocation."""

tool_name: str
parameters: dict[str, Any]


class AgentInvocation(A2ABaseModel):
"""An agent invocation."""

agent_url: str
agent_name: str
requests: dict[str, Any]
response_trace: ResponseTrace | None = None


class StepAction(A2ABaseModel):
"""The action of a step."""

tool_invocation: ToolInvocation | None = None
agent_invocation: AgentInvocation | None = None


class Step(A2ABaseModel):
"""A single operation within a trace."""

step_id: str
trace_id: str
parent_step_id: str | None = None
call_type: CallTypeEnum
step_action: StepAction
cost: int | None = None
total_tokens: int | None = None
additional_attributes: dict[str, str] | None = None
latency: int | None = None
start_time: datetime
end_time: datetime | None = None


class ResponseTrace(A2ABaseModel):
"""A trace message that contains a collection of spans."""

trace_id: str
steps: list[Step]


class TraceExtension(Extension):
"""An extension for traceability."""

def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.traces: dict[str, ResponseTrace] = {}
self._current_steps: dict[str, Step] = {}

def _generate_id(self, prefix: str) -> str:
return f'{prefix}-{uuid.uuid4()}'

def start_trace(self) -> ResponseTrace:
"""Starts a new trace."""
trace_id = self._generate_id('trace')
trace = ResponseTrace(trace_id=trace_id, steps=[])
self.traces[trace_id] = trace
return trace

def start_step(
self,
trace_id: str,
parent_step_id: str | None,
call_type: CallTypeEnum,
step_action: StepAction,
) -> Step:
"""Starts a new step."""
step_id = self._generate_id('step')
step = Step(
step_id=step_id,
trace_id=trace_id,
parent_step_id=parent_step_id,
call_type=call_type,
step_action=step_action,
start_time=datetime.now(timezone.utc),
)
self._current_steps[step_id] = step
return step

def end_step(
self,
step_id: str,
cost: int | None = None,
total_tokens: int | None = None,
additional_attributes: dict[str, str] | None = None,
) -> None:
"""Ends a step."""
if step_id not in self._current_steps:
return

step = self._current_steps.pop(step_id)
step.end_time = datetime.now(timezone.utc)
step.latency = int(
(step.end_time - step.start_time).total_seconds() * 1000
)
step.cost = cost
step.total_tokens = total_tokens
step.additional_attributes = additional_attributes

if step.trace_id in self.traces:
self.traces[step.trace_id].steps.append(step)

def on_client_message(self, message: Any) -> None:
"""Appends trace information to the message."""
trace = self.start_trace()
if message.metadata is None:
message.metadata = {}
message.metadata['trace'] = trace.model_dump(mode='json')

def on_server_message(self, message: Any) -> None:
"""Processes trace information from the message."""
if (
hasattr(message, 'metadata')
and message.metadata is not None
and 'trace' in message.metadata
):
trace_data = message.metadata['trace']
trace = ResponseTrace.model_validate(trace_data)
self.traces[trace.trace_id] = trace


AgentInvocation.model_rebuild()
7 changes: 6 additions & 1 deletion src/a2a/server/agent_execution/agent_executor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any

from a2a.server.agent_execution.context import RequestContext
from a2a.server.events.event_queue import EventQueue
Expand All @@ -13,7 +14,10 @@ class AgentExecutor(ABC):

@abstractmethod
async def execute(
self, context: RequestContext, event_queue: EventQueue
self,
context: RequestContext,
event_queue: EventQueue,
request_handler: Any,
) -> None:
"""Execute the agent's logic for a given request context.

Expand All @@ -26,6 +30,7 @@ async def execute(
Args:
context: The request context containing the message, task ID, etc.
event_queue: The queue to publish events to.
request_handler: The request handler that is executing the agent.
"""

@abstractmethod
Expand Down
Loading
Loading