Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
270 changes: 267 additions & 3 deletions tests/server/request_handlers/test_jsonrpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@

from collections.abc import AsyncGenerator
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch, call
from unittest.mock import AsyncMock, MagicMock, call, patch

import pytest
import httpx
import pytest

from a2a.server.agent_execution import AgentExecutor
from a2a.server.agent_execution.request_context_builder import (
RequestContextBuilder,
)
from a2a.server.events import (
QueueManager,
)
Expand All @@ -30,14 +33,15 @@
GetTaskRequest,
GetTaskResponse,
GetTaskSuccessResponse,
InternalError,
JSONRPCErrorResponse,
Message,
MessageSendConfiguration,
MessageSendParams,
Part,
PushNotificationConfig,
SendMessageRequest,
SendMessageSuccessResponse,
PushNotificationConfig,
SendStreamingMessageRequest,
SendStreamingMessageSuccessResponse,
SetTaskPushNotificationConfigRequest,
Expand All @@ -58,6 +62,7 @@
)
from a2a.utils.errors import ServerError


MINIMAL_TASK: dict[str, Any] = {
'id': 'task_123',
'contextId': 'session-xyz',
Expand Down Expand Up @@ -642,3 +647,262 @@ async def test_on_resubscribe_no_existing_task_error(self) -> None:
assert len(collected_events) == 1
self.assertIsInstance(collected_events[0].root, JSONRPCErrorResponse)
assert collected_events[0].root.error == TaskNotFoundError()

async def test_streaming_not_supported_error(
self,
) -> None:
"""Test that on_message_send_stream raises an error when streaming not supported."""
# Arrange
mock_agent_executor = AsyncMock(spec=AgentExecutor)
mock_task_store = AsyncMock(spec=TaskStore)
request_handler = DefaultRequestHandler(
mock_agent_executor, mock_task_store
)
# Create agent card with streaming capability disabled
self.mock_agent_card.capabilities = AgentCapabilities(streaming=False)
handler = JSONRPCHandler(self.mock_agent_card, request_handler)

# Act & Assert
request = SendStreamingMessageRequest(
id='1',
params=MessageSendParams(message=Message(**MESSAGE_PAYLOAD)),
)

# Should raise ServerError about streaming not supported
with self.assertRaises(ServerError) as context:
async for _ in handler.on_message_send_stream(request):
pass

aaa = context.exception
self.assertEqual(
str(context.exception.error.message),
'Streaming is not supported by the agent',
)

async def test_push_notifications_not_supported_error(self) -> None:
"""Test that set_push_notification raises an error when push notifications not supported."""
# Arrange
mock_agent_executor = AsyncMock(spec=AgentExecutor)
mock_task_store = AsyncMock(spec=TaskStore)
request_handler = DefaultRequestHandler(
mock_agent_executor, mock_task_store
)
# Create agent card with push notifications capability disabled
self.mock_agent_card.capabilities = AgentCapabilities(
pushNotifications=False, streaming=True
)
handler = JSONRPCHandler(self.mock_agent_card, request_handler)

# Act & Assert
task_push_config = TaskPushNotificationConfig(
taskId='task_123',
pushNotificationConfig=PushNotificationConfig(
url='http://example.com'
),
)
request = SetTaskPushNotificationConfigRequest(
id='1', params=task_push_config
)

# Should raise ServerError about push notifications not supported
with self.assertRaises(ServerError) as context:
await handler.set_push_notification(request)

self.assertEqual(
str(context.exception.error.message),
'Push notifications are not supported by the agent',
)

async def test_on_get_push_notification_no_push_notifier(self) -> None:
"""Test get_push_notification with no push notifier configured."""
# Arrange
mock_agent_executor = AsyncMock(spec=AgentExecutor)
mock_task_store = AsyncMock(spec=TaskStore)
# Create request handler without a push notifier
request_handler = DefaultRequestHandler(
mock_agent_executor, mock_task_store
)
self.mock_agent_card.capabilities = AgentCapabilities(
pushNotifications=True
)
handler = JSONRPCHandler(self.mock_agent_card, request_handler)

mock_task = Task(**MINIMAL_TASK)
mock_task_store.get.return_value = mock_task

# Act
get_request = GetTaskPushNotificationConfigRequest(
id='1', params=TaskIdParams(id=mock_task.id)
)
response = await handler.get_push_notification(get_request)

# Assert
self.assertIsInstance(response.root, JSONRPCErrorResponse)
self.assertEqual(response.root.error, UnsupportedOperationError())

async def test_on_set_push_notification_no_push_notifier(self) -> None:
"""Test set_push_notification with no push notifier configured."""
# Arrange
mock_agent_executor = AsyncMock(spec=AgentExecutor)
mock_task_store = AsyncMock(spec=TaskStore)
# Create request handler without a push notifier
request_handler = DefaultRequestHandler(
mock_agent_executor, mock_task_store
)
self.mock_agent_card.capabilities = AgentCapabilities(
pushNotifications=True
)
handler = JSONRPCHandler(self.mock_agent_card, request_handler)

mock_task = Task(**MINIMAL_TASK)
mock_task_store.get.return_value = mock_task

# Act
task_push_config = TaskPushNotificationConfig(
taskId=mock_task.id,
pushNotificationConfig=PushNotificationConfig(
url='http://example.com'
),
)
request = SetTaskPushNotificationConfigRequest(
id='1', params=task_push_config
)
response = await handler.set_push_notification(request)

# Assert
self.assertIsInstance(response.root, JSONRPCErrorResponse)
self.assertEqual(response.root.error, UnsupportedOperationError())

async def test_on_message_send_internal_error(self) -> None:
"""Test on_message_send with an internal error."""
# Arrange
mock_agent_executor = AsyncMock(spec=AgentExecutor)
mock_task_store = AsyncMock(spec=TaskStore)
request_handler = DefaultRequestHandler(
mock_agent_executor, mock_task_store
)
handler = JSONRPCHandler(self.mock_agent_card, request_handler)

# Make the request handler raise an Internal error without specifying an error type
async def raise_server_error(*args, **kwargs):
raise ServerError(InternalError(message='Internal Error'))

# Patch the method to raise an error
with patch.object(
request_handler, 'on_message_send', side_effect=raise_server_error
):
# Act
request = SendMessageRequest(
id='1',
params=MessageSendParams(message=Message(**MESSAGE_PAYLOAD)),
)
response = await handler.on_message_send(request)

# Assert
self.assertIsInstance(response.root, JSONRPCErrorResponse)
self.assertIsInstance(response.root.error, InternalError)

async def test_on_message_stream_internal_error(self) -> None:
"""Test on_message_send_stream with an internal error."""
# Arrange
mock_agent_executor = AsyncMock(spec=AgentExecutor)
mock_task_store = AsyncMock(spec=TaskStore)
request_handler = DefaultRequestHandler(
mock_agent_executor, mock_task_store
)
self.mock_agent_card.capabilities = AgentCapabilities(streaming=True)
handler = JSONRPCHandler(self.mock_agent_card, request_handler)

# Make the request handler raise an Internal error without specifying an error type
async def raise_server_error(*args, **kwargs):
raise ServerError(InternalError(message='Internal Error'))
yield # Need this to make it an async generator

# Patch the method to raise an error
with patch.object(
request_handler,
'on_message_send_stream',
return_value=raise_server_error(),
):
# Act
request = SendStreamingMessageRequest(
id='1',
params=MessageSendParams(message=Message(**MESSAGE_PAYLOAD)),
)

# Get the single error response
responses = []
async for response in handler.on_message_send_stream(request):
responses.append(response)

# Assert
self.assertEqual(len(responses), 1)
self.assertIsInstance(responses[0].root, JSONRPCErrorResponse)
self.assertIsInstance(responses[0].root.error, InternalError)

async def test_default_request_handler_with_custom_components(self) -> None:
"""Test DefaultRequestHandler initialization with custom components."""
# Arrange
mock_agent_executor = AsyncMock(spec=AgentExecutor)
mock_task_store = AsyncMock(spec=TaskStore)
mock_queue_manager = AsyncMock(spec=QueueManager)
mock_push_notifier = AsyncMock(spec=PushNotifier)
mock_request_context_builder = AsyncMock(spec=RequestContextBuilder)

# Act
handler = DefaultRequestHandler(
agent_executor=mock_agent_executor,
task_store=mock_task_store,
queue_manager=mock_queue_manager,
push_notifier=mock_push_notifier,
request_context_builder=mock_request_context_builder,
)

# Assert
self.assertEqual(handler.agent_executor, mock_agent_executor)
self.assertEqual(handler.task_store, mock_task_store)
self.assertEqual(handler._queue_manager, mock_queue_manager)
self.assertEqual(handler._push_notifier, mock_push_notifier)
self.assertEqual(
handler._request_context_builder, mock_request_context_builder
)

async def test_on_message_send_error_handling(self) -> None:
"""Test error handling in on_message_send when consuming raises ServerError."""
# Arrange
mock_agent_executor = AsyncMock(spec=AgentExecutor)
mock_task_store = AsyncMock(spec=TaskStore)
request_handler = DefaultRequestHandler(
mock_agent_executor, mock_task_store
)
handler = JSONRPCHandler(self.mock_agent_card, request_handler)

# Let task exist
mock_task = Task(**MINIMAL_TASK)
mock_task_store.get.return_value = mock_task

# Set up consume_and_break_on_interrupt to raise ServerError
async def consume_raises_error(*args, **kwargs):
raise ServerError(error=UnsupportedOperationError())

with patch(
'a2a.server.tasks.result_aggregator.ResultAggregator.consume_and_break_on_interrupt',
side_effect=consume_raises_error,
):
# Act
request = SendMessageRequest(
id='1',
params=MessageSendParams(
message=Message(
**MESSAGE_PAYLOAD,
taskId=mock_task.id,
contextId=mock_task.contextId,
)
),
)

response = await handler.on_message_send(request)

# Assert
self.assertIsInstance(response.root, JSONRPCErrorResponse)
self.assertEqual(response.root.error, UnsupportedOperationError())
4 changes: 2 additions & 2 deletions tests/server/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def test_cancel_task(client: TestClient, handler: mock.AsyncMock):
"""Test cancelling a task."""
# Setup mock response
task_status = TaskStatus(**MINIMAL_TASK_STATUS)
task_status.state = TaskState.canceled # 'cancelled' #
task_status.state = TaskState.canceled # 'cancelled' #
task = Task(
id='task1', contextId='ctx1', state='cancelled', status=task_status
)
Expand Down Expand Up @@ -543,7 +543,7 @@ async def stream_generator():

def test_invalid_json(client: TestClient):
"""Test handling invalid JSON."""
response = client.post('/', data='This is not JSON')
response = client.post('/', content=b'This is not JSON') # Use bytes
assert response.status_code == 200 # JSON-RPC errors still return 200
data = response.json()
assert 'error' in data
Expand Down