Skip to content
Merged
6 changes: 2 additions & 4 deletions examples/langgraph/agent_executor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from agent import CurrencyAgent # type: ignore[import-untyped]
from typing_extensions import override
from agent import CurrencyAgent # type: ignore[import-untyped]

from a2a.server.agent_execution import AgentExecutor, RequestContext
from a2a.server.events.event_queue import EventQueue
from a2a.types import (
Expand All @@ -17,7 +17,6 @@ class CurrencyAgentExecutor(AgentExecutor):
def __init__(self):
self.agent = CurrencyAgent()

@override
async def execute(
self,
context: RequestContext,
Expand Down Expand Up @@ -89,7 +88,6 @@ async def execute(
)
)

@override
async def cancel(
self, context: RequestContext, event_queue: EventQueue
) -> None:
Expand Down
34 changes: 21 additions & 13 deletions src/a2a/server/request_handlers/default_request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
EventQueue,
InMemoryQueueManager,
QueueManager,
TaskQueueExists,
)
from a2a.server.request_handlers.request_handler import RequestHandler
from a2a.server.tasks import (
Expand Down Expand Up @@ -212,6 +211,15 @@ async def on_message_send(
) = await result_aggregator.consume_and_break_on_interrupt(consumer)
if not result:
raise ServerError(error=InternalError())

if isinstance(result, Task) and task_id != result.id:
logger.error(
f'Agent generated task_id={result.id} does not match the RequestContext task_id={task_id}.'
)
raise ServerError(
InternalError(message='Task ID mismatch in agent response')
)

finally:
if interrupted:
# TODO: Track this disconnected cleanup task.
Expand Down Expand Up @@ -278,27 +286,27 @@ async def on_message_send_stream(
consumer = EventConsumer(queue)
producer_task.add_done_callback(consumer.agent_task_callback)
async for event in result_aggregator.consume_and_emit(consumer):
if isinstance(event, Task) and task_id != event.id:
logger.warning(
f'Agent generated task_id={event.id} does not match the RequestContext task_id={task_id}.'
)
try:
created_task: Task = event
await self._queue_manager.add(created_task.id, queue)
task_id = created_task.id
except TaskQueueExists:
logging.info(
'Multiple Task objects created in event stream.'
if isinstance(event, Task):
if task_id != event.id:
logger.error(
f'Agent generated task_id={event.id} does not match the RequestContext task_id={task_id}.'
)
raise ServerError(
InternalError(
message='Task ID mismatch in agent response'
)
)

if (
self._push_notifier
and params.configuration
and params.configuration.pushNotificationConfig
):
await self._push_notifier.set_info(
created_task.id,
task_id,
params.configuration.pushNotificationConfig,
)

if self._push_notifier and task_id:
latest_task = await result_aggregator.current_result
if isinstance(latest_task, Task):
Expand Down
120 changes: 112 additions & 8 deletions tests/server/request_handlers/test_jsonrpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
import httpx
import pytest

from a2a.server.agent_execution import AgentExecutor

from a2a.server.agent_execution import AgentExecutor, RequestContext
from a2a.server.agent_execution.request_context_builder import (
RequestContextBuilder,
)
Expand Down Expand Up @@ -59,6 +60,7 @@
TaskStatusUpdateEvent,
TextPart,
UnsupportedOperationError,
InternalError,
)
from a2a.utils.errors import ServerError

Expand Down Expand Up @@ -188,7 +190,12 @@ async def test_on_cancel_task_not_found(self) -> None:
mock_task_store.get.assert_called_once_with('nonexistent_id')
mock_agent_executor.cancel.assert_not_called()

async def test_on_message_new_message_success(self) -> None:
@patch(
'a2a.server.agent_execution.simple_request_context_builder.SimpleRequestContextBuilder.build'
)
async def test_on_message_new_message_success(
self, _mock_builder_build: AsyncMock
) -> None:
mock_agent_executor = AsyncMock(spec=AgentExecutor)
mock_task_store = AsyncMock(spec=TaskStore)
request_handler = DefaultRequestHandler(
Expand All @@ -199,6 +206,14 @@ async def test_on_message_new_message_success(self) -> None:
mock_task_store.get.return_value = mock_task
mock_agent_executor.execute.return_value = None

_mock_builder_build.return_value = RequestContext(
request=MagicMock(),
task_id='task_123',
context_id='session-xyz',
task=None,
related_tasks=None,
)

async def streaming_coro():
yield mock_task

Expand Down Expand Up @@ -284,15 +299,28 @@ async def streaming_coro():
assert response.root.error == UnsupportedOperationError() # type: ignore
mock_agent_executor.execute.assert_called_once()

async def test_on_message_stream_new_message_success(self) -> None:
@patch(
'a2a.server.agent_execution.simple_request_context_builder.SimpleRequestContextBuilder.build'
)
async def test_on_message_stream_new_message_success(
self, _mock_builder_build: AsyncMock
) -> None:
mock_agent_executor = AsyncMock(spec=AgentExecutor)
mock_task_store = AsyncMock(spec=TaskStore)
request_handler = DefaultRequestHandler(
mock_agent_executor, mock_task_store
)
self.mock_agent_card.capabilities = AgentCapabilities(streaming=True)

self.mock_agent_card.capabilities = AgentCapabilities(streaming=True)
handler = JSONRPCHandler(self.mock_agent_card, request_handler)
_mock_builder_build.return_value = RequestContext(
request=MagicMock(),
task_id='task_123',
context_id='session-xyz',
task=None,
related_tasks=None,
)

events: list[Any] = [
Task(**MINIMAL_TASK),
TaskArtifactUpdateEvent(
Expand Down Expand Up @@ -467,8 +495,11 @@ async def test_get_push_notification_success(self) -> None:
)
assert get_response.root.result == task_push_config # type: ignore

@patch(
'a2a.server.agent_execution.simple_request_context_builder.SimpleRequestContextBuilder.build'
)
async def test_on_message_stream_new_message_send_push_notification_success(
self,
self, _mock_builder_build: AsyncMock
) -> None:
mock_agent_executor = AsyncMock(spec=AgentExecutor)
mock_task_store = AsyncMock(spec=TaskStore)
Expand All @@ -480,6 +511,13 @@ async def test_on_message_stream_new_message_send_push_notification_success(
self.mock_agent_card.capabilities = AgentCapabilities(
streaming=True, pushNotifications=True
)
_mock_builder_build.return_value = RequestContext(
request=MagicMock(),
task_id='task_123',
context_id='session-xyz',
task=None,
related_tasks=None,
)

handler = JSONRPCHandler(self.mock_agent_card, request_handler)
events: list[Any] = [
Expand Down Expand Up @@ -738,7 +776,8 @@ async def test_on_get_push_notification_no_push_notifier(self) -> None:

# Assert
self.assertIsInstance(response.root, JSONRPCErrorResponse)
self.assertEqual(response.root.error, UnsupportedOperationError())
self.assertEqual(response.root.error, UnsupportedOperationError()) # type: ignore


async def test_on_set_push_notification_no_push_notifier(self) -> None:
"""Test set_push_notification with no push notifier configured."""
Expand Down Expand Up @@ -771,7 +810,8 @@ async def test_on_set_push_notification_no_push_notifier(self) -> None:

# Assert
self.assertIsInstance(response.root, JSONRPCErrorResponse)
self.assertEqual(response.root.error, UnsupportedOperationError())
self.assertEqual(response.root.error, UnsupportedOperationError()) # type: ignore


async def test_on_message_send_internal_error(self) -> None:
"""Test on_message_send with an internal error."""
Expand Down Expand Up @@ -800,7 +840,8 @@ async def raise_server_error(*args, **kwargs):

# Assert
self.assertIsInstance(response.root, JSONRPCErrorResponse)
self.assertIsInstance(response.root.error, InternalError)
self.assertIsInstance(response.root.error, InternalError) # type: ignore


async def test_on_message_stream_internal_error(self) -> None:
"""Test on_message_send_stream with an internal error."""
Expand Down Expand Up @@ -906,3 +947,66 @@ async def consume_raises_error(*args, **kwargs):
# Assert
self.assertIsInstance(response.root, JSONRPCErrorResponse)
self.assertEqual(response.root.error, UnsupportedOperationError())

async def test_on_message_send_task_id_mismatch(self) -> None:
mock_agent_executor = AsyncMock(spec=AgentExecutor)
mock_task_store = AsyncMock(spec=TaskStore)
request_handler = DefaultRequestHandler(
mock_agent_executor, mock_task_store
)
handler = JSONRPCHandler(self.mock_agent_card, request_handler)
mock_task = Task(**MINIMAL_TASK)
mock_task_store.get.return_value = mock_task
mock_agent_executor.execute.return_value = None

async def streaming_coro():
yield mock_task

with patch(
'a2a.server.request_handlers.default_request_handler.EventConsumer.consume_all',
return_value=streaming_coro(),
):
request = SendMessageRequest(
id='1',
params=MessageSendParams(message=Message(**MESSAGE_PAYLOAD)),
)
response = await handler.on_message_send(request)
assert mock_agent_executor.execute.call_count == 1
self.assertIsInstance(response.root, JSONRPCErrorResponse)
self.assertIsInstance(response.root.error, InternalError) # type: ignore

async def test_on_message_stream_task_id_mismatch(self) -> None:
mock_agent_executor = AsyncMock(spec=AgentExecutor)
mock_task_store = AsyncMock(spec=TaskStore)
request_handler = DefaultRequestHandler(
mock_agent_executor, mock_task_store
)

self.mock_agent_card.capabilities = AgentCapabilities(streaming=True)
handler = JSONRPCHandler(self.mock_agent_card, request_handler)
events: list[Any] = [Task(**MINIMAL_TASK)]

async def streaming_coro():
for event in events:
yield event

with patch(
'a2a.server.request_handlers.default_request_handler.EventConsumer.consume_all',
return_value=streaming_coro(),
):
mock_task_store.get.return_value = None
mock_agent_executor.execute.return_value = None
request = SendStreamingMessageRequest(
id='1',
params=MessageSendParams(message=Message(**MESSAGE_PAYLOAD)),
)
response = handler.on_message_send_stream(request)
assert isinstance(response, AsyncGenerator)
collected_events: list[Any] = []
async for event in response:
collected_events.append(event)
assert len(collected_events) == 1
self.assertIsInstance(
collected_events[0].root, JSONRPCErrorResponse
)
self.assertIsInstance(collected_events[0].root.error, InternalError)