diff --git a/tests/server/request_handlers/test_jsonrpc_handler.py b/tests/server/request_handlers/test_jsonrpc_handler.py index 7e36aeb9..6eb680a3 100644 --- a/tests/server/request_handlers/test_jsonrpc_handler.py +++ b/tests/server/request_handlers/test_jsonrpc_handler.py @@ -3,12 +3,15 @@ from collections.abc import AsyncGenerator from typing import Any -from unittest.mock import AsyncMock, MagicMock, patch, call +from unittest.mock import AsyncMock, MagicMock, call, patch -import pytest import httpx +import pytest from a2a.server.agent_execution import AgentExecutor +from a2a.server.agent_execution.request_context_builder import ( + RequestContextBuilder, +) from a2a.server.events import ( QueueManager, ) @@ -30,14 +33,15 @@ GetTaskRequest, GetTaskResponse, GetTaskSuccessResponse, + InternalError, JSONRPCErrorResponse, Message, MessageSendConfiguration, MessageSendParams, Part, + PushNotificationConfig, SendMessageRequest, SendMessageSuccessResponse, - PushNotificationConfig, SendStreamingMessageRequest, SendStreamingMessageSuccessResponse, SetTaskPushNotificationConfigRequest, @@ -58,6 +62,7 @@ ) from a2a.utils.errors import ServerError + MINIMAL_TASK: dict[str, Any] = { 'id': 'task_123', 'contextId': 'session-xyz', @@ -642,3 +647,262 @@ async def test_on_resubscribe_no_existing_task_error(self) -> None: assert len(collected_events) == 1 self.assertIsInstance(collected_events[0].root, JSONRPCErrorResponse) assert collected_events[0].root.error == TaskNotFoundError() + + async def test_streaming_not_supported_error( + self, + ) -> None: + """Test that on_message_send_stream raises an error when streaming not supported.""" + # Arrange + mock_agent_executor = AsyncMock(spec=AgentExecutor) + mock_task_store = AsyncMock(spec=TaskStore) + request_handler = DefaultRequestHandler( + mock_agent_executor, mock_task_store + ) + # Create agent card with streaming capability disabled + self.mock_agent_card.capabilities = AgentCapabilities(streaming=False) + handler = JSONRPCHandler(self.mock_agent_card, request_handler) + + # Act & Assert + request = SendStreamingMessageRequest( + id='1', + params=MessageSendParams(message=Message(**MESSAGE_PAYLOAD)), + ) + + # Should raise ServerError about streaming not supported + with self.assertRaises(ServerError) as context: + async for _ in handler.on_message_send_stream(request): + pass + + aaa = context.exception + self.assertEqual( + str(context.exception.error.message), + 'Streaming is not supported by the agent', + ) + + async def test_push_notifications_not_supported_error(self) -> None: + """Test that set_push_notification raises an error when push notifications not supported.""" + # Arrange + mock_agent_executor = AsyncMock(spec=AgentExecutor) + mock_task_store = AsyncMock(spec=TaskStore) + request_handler = DefaultRequestHandler( + mock_agent_executor, mock_task_store + ) + # Create agent card with push notifications capability disabled + self.mock_agent_card.capabilities = AgentCapabilities( + pushNotifications=False, streaming=True + ) + handler = JSONRPCHandler(self.mock_agent_card, request_handler) + + # Act & Assert + task_push_config = TaskPushNotificationConfig( + taskId='task_123', + pushNotificationConfig=PushNotificationConfig( + url='http://example.com' + ), + ) + request = SetTaskPushNotificationConfigRequest( + id='1', params=task_push_config + ) + + # Should raise ServerError about push notifications not supported + with self.assertRaises(ServerError) as context: + await handler.set_push_notification(request) + + self.assertEqual( + str(context.exception.error.message), + 'Push notifications are not supported by the agent', + ) + + async def test_on_get_push_notification_no_push_notifier(self) -> None: + """Test get_push_notification with no push notifier configured.""" + # Arrange + mock_agent_executor = AsyncMock(spec=AgentExecutor) + mock_task_store = AsyncMock(spec=TaskStore) + # Create request handler without a push notifier + request_handler = DefaultRequestHandler( + mock_agent_executor, mock_task_store + ) + self.mock_agent_card.capabilities = AgentCapabilities( + pushNotifications=True + ) + handler = JSONRPCHandler(self.mock_agent_card, request_handler) + + mock_task = Task(**MINIMAL_TASK) + mock_task_store.get.return_value = mock_task + + # Act + get_request = GetTaskPushNotificationConfigRequest( + id='1', params=TaskIdParams(id=mock_task.id) + ) + response = await handler.get_push_notification(get_request) + + # Assert + self.assertIsInstance(response.root, JSONRPCErrorResponse) + self.assertEqual(response.root.error, UnsupportedOperationError()) + + async def test_on_set_push_notification_no_push_notifier(self) -> None: + """Test set_push_notification with no push notifier configured.""" + # Arrange + mock_agent_executor = AsyncMock(spec=AgentExecutor) + mock_task_store = AsyncMock(spec=TaskStore) + # Create request handler without a push notifier + request_handler = DefaultRequestHandler( + mock_agent_executor, mock_task_store + ) + self.mock_agent_card.capabilities = AgentCapabilities( + pushNotifications=True + ) + handler = JSONRPCHandler(self.mock_agent_card, request_handler) + + mock_task = Task(**MINIMAL_TASK) + mock_task_store.get.return_value = mock_task + + # Act + task_push_config = TaskPushNotificationConfig( + taskId=mock_task.id, + pushNotificationConfig=PushNotificationConfig( + url='http://example.com' + ), + ) + request = SetTaskPushNotificationConfigRequest( + id='1', params=task_push_config + ) + response = await handler.set_push_notification(request) + + # Assert + self.assertIsInstance(response.root, JSONRPCErrorResponse) + self.assertEqual(response.root.error, UnsupportedOperationError()) + + async def test_on_message_send_internal_error(self) -> None: + """Test on_message_send with an internal error.""" + # Arrange + mock_agent_executor = AsyncMock(spec=AgentExecutor) + mock_task_store = AsyncMock(spec=TaskStore) + request_handler = DefaultRequestHandler( + mock_agent_executor, mock_task_store + ) + handler = JSONRPCHandler(self.mock_agent_card, request_handler) + + # Make the request handler raise an Internal error without specifying an error type + async def raise_server_error(*args, **kwargs): + raise ServerError(InternalError(message='Internal Error')) + + # Patch the method to raise an error + with patch.object( + request_handler, 'on_message_send', side_effect=raise_server_error + ): + # Act + request = SendMessageRequest( + id='1', + params=MessageSendParams(message=Message(**MESSAGE_PAYLOAD)), + ) + response = await handler.on_message_send(request) + + # Assert + self.assertIsInstance(response.root, JSONRPCErrorResponse) + self.assertIsInstance(response.root.error, InternalError) + + async def test_on_message_stream_internal_error(self) -> None: + """Test on_message_send_stream with an internal error.""" + # Arrange + mock_agent_executor = AsyncMock(spec=AgentExecutor) + mock_task_store = AsyncMock(spec=TaskStore) + request_handler = DefaultRequestHandler( + mock_agent_executor, mock_task_store + ) + self.mock_agent_card.capabilities = AgentCapabilities(streaming=True) + handler = JSONRPCHandler(self.mock_agent_card, request_handler) + + # Make the request handler raise an Internal error without specifying an error type + async def raise_server_error(*args, **kwargs): + raise ServerError(InternalError(message='Internal Error')) + yield # Need this to make it an async generator + + # Patch the method to raise an error + with patch.object( + request_handler, + 'on_message_send_stream', + return_value=raise_server_error(), + ): + # Act + request = SendStreamingMessageRequest( + id='1', + params=MessageSendParams(message=Message(**MESSAGE_PAYLOAD)), + ) + + # Get the single error response + responses = [] + async for response in handler.on_message_send_stream(request): + responses.append(response) + + # Assert + self.assertEqual(len(responses), 1) + self.assertIsInstance(responses[0].root, JSONRPCErrorResponse) + self.assertIsInstance(responses[0].root.error, InternalError) + + async def test_default_request_handler_with_custom_components(self) -> None: + """Test DefaultRequestHandler initialization with custom components.""" + # Arrange + mock_agent_executor = AsyncMock(spec=AgentExecutor) + mock_task_store = AsyncMock(spec=TaskStore) + mock_queue_manager = AsyncMock(spec=QueueManager) + mock_push_notifier = AsyncMock(spec=PushNotifier) + mock_request_context_builder = AsyncMock(spec=RequestContextBuilder) + + # Act + handler = DefaultRequestHandler( + agent_executor=mock_agent_executor, + task_store=mock_task_store, + queue_manager=mock_queue_manager, + push_notifier=mock_push_notifier, + request_context_builder=mock_request_context_builder, + ) + + # Assert + self.assertEqual(handler.agent_executor, mock_agent_executor) + self.assertEqual(handler.task_store, mock_task_store) + self.assertEqual(handler._queue_manager, mock_queue_manager) + self.assertEqual(handler._push_notifier, mock_push_notifier) + self.assertEqual( + handler._request_context_builder, mock_request_context_builder + ) + + async def test_on_message_send_error_handling(self) -> None: + """Test error handling in on_message_send when consuming raises ServerError.""" + # Arrange + mock_agent_executor = AsyncMock(spec=AgentExecutor) + mock_task_store = AsyncMock(spec=TaskStore) + request_handler = DefaultRequestHandler( + mock_agent_executor, mock_task_store + ) + handler = JSONRPCHandler(self.mock_agent_card, request_handler) + + # Let task exist + mock_task = Task(**MINIMAL_TASK) + mock_task_store.get.return_value = mock_task + + # Set up consume_and_break_on_interrupt to raise ServerError + async def consume_raises_error(*args, **kwargs): + raise ServerError(error=UnsupportedOperationError()) + + with patch( + 'a2a.server.tasks.result_aggregator.ResultAggregator.consume_and_break_on_interrupt', + side_effect=consume_raises_error, + ): + # Act + request = SendMessageRequest( + id='1', + params=MessageSendParams( + message=Message( + **MESSAGE_PAYLOAD, + taskId=mock_task.id, + contextId=mock_task.contextId, + ) + ), + ) + + response = await handler.on_message_send(request) + + # Assert + self.assertIsInstance(response.root, JSONRPCErrorResponse) + self.assertEqual(response.root.error, UnsupportedOperationError()) diff --git a/tests/server/test_integration.py b/tests/server/test_integration.py index 79814577..b116b2cc 100644 --- a/tests/server/test_integration.py +++ b/tests/server/test_integration.py @@ -230,7 +230,7 @@ def test_cancel_task(client: TestClient, handler: mock.AsyncMock): """Test cancelling a task.""" # Setup mock response task_status = TaskStatus(**MINIMAL_TASK_STATUS) - task_status.state = TaskState.canceled # 'cancelled' # + task_status.state = TaskState.canceled # 'cancelled' # task = Task( id='task1', contextId='ctx1', state='cancelled', status=task_status ) @@ -543,7 +543,7 @@ async def stream_generator(): def test_invalid_json(client: TestClient): """Test handling invalid JSON.""" - response = client.post('/', data='This is not JSON') + response = client.post('/', content=b'This is not JSON') # Use bytes assert response.status_code == 200 # JSON-RPC errors still return 200 data = response.json() assert 'error' in data