diff --git a/src/a2a/server/request_handlers/jsonrpc_handler.py b/src/a2a/server/request_handlers/jsonrpc_handler.py index d2c502023..de168719f 100644 --- a/src/a2a/server/request_handlers/jsonrpc_handler.py +++ b/src/a2a/server/request_handlers/jsonrpc_handler.py @@ -26,7 +26,6 @@ GetTaskRequest, ListTaskPushNotificationConfigRequest, ListTasksRequest, - ListTasksResponse, Message, SendMessageRequest, SendMessageResponse, @@ -388,7 +387,7 @@ async def list_tasks( self, request: ListTasksRequest, context: ServerCallContext | None = None, - ) -> ListTasksResponse: + ) -> dict[str, Any]: """Handles the 'tasks/list' JSON-RPC method. Args: @@ -396,17 +395,19 @@ async def list_tasks( context: Context provided by the server. Returns: - A `ListTasksResponse` object containing the Task or a JSON-RPC error. + A dict representing the JSON-RPC response. """ + request_id = self._get_request_id(context) try: - result = await self.request_handler.on_list_tasks(request, context) - except ServerError: - return ListTasksResponse( - # This needs to be appropriately handled since error fields on proto messages - # might be different from the old pydantic models - # Ignoring proto error handling for now as it diverges from the current pattern + response = await self.request_handler.on_list_tasks( + request, context + ) + result = MessageToDict(response, preserving_proto_field_name=False) + return _build_success_response(request_id, result) + except ServerError as e: + return _build_error_response( + request_id, e.error if e.error else InternalError() ) - return result async def list_push_notification_config( self, diff --git a/tests/integration/test_client_server_integration.py b/tests/integration/test_client_server_integration.py index 6acb9b685..011359fc3 100644 --- a/tests/integration/test_client_server_integration.py +++ b/tests/integration/test_client_server_integration.py @@ -47,6 +47,8 @@ TaskState, TaskStatus, TaskStatusUpdateEvent, + ListTasksRequest, + ListTasksResponse, ) from cryptography.hazmat.primitives import asymmetric from cryptography.hazmat.primitives.asymmetric import ec @@ -91,6 +93,11 @@ status=TaskStatus(state=TaskState.TASK_STATE_WORKING), ) +LIST_TASKS_RESPONSE = ListTasksResponse( + tasks=[TASK_FROM_BLOCKING, GET_TASK_RESPONSE], + next_page_token='page-2', +) + def create_key_provider(verification_key: Any): """Creates a key provider function for testing.""" @@ -121,6 +128,7 @@ async def stream_side_effect(*args, **kwargs): # Configure other methods handler.on_get_task.return_value = GET_TASK_RESPONSE handler.on_cancel_task.return_value = CANCEL_TASK_RESPONSE + handler.on_list_tasks.return_value = LIST_TASKS_RESPONSE handler.on_create_task_push_notification_config.return_value = ( CALLBACK_CONFIG ) @@ -450,6 +458,57 @@ def channel_factory(address: str) -> Channel: await transport.close() +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'transport_setup_fixture', + [ + pytest.param('jsonrpc_setup', id='JSON-RPC'), + pytest.param('rest_setup', id='REST'), + ], +) +async def test_http_transport_list_tasks( + transport_setup_fixture: str, request +) -> None: + transport_setup: TransportSetup = request.getfixturevalue( + transport_setup_fixture + ) + transport = transport_setup.transport + handler = transport_setup.handler + + params = ListTasksRequest(page_size=10, page_token='page-1') + result = await transport.list_tasks(request=params) + + assert len(result.tasks) == 2 + assert result.next_page_token == 'page-2' + handler.on_list_tasks.assert_awaited_once() + + if hasattr(transport, 'close'): + await transport.close() + + +@pytest.mark.asyncio +async def test_grpc_transport_list_tasks( + grpc_server_and_handler: tuple[str, AsyncMock], + agent_card: AgentCard, +) -> None: + server_address, handler = grpc_server_and_handler + + def channel_factory(address: str) -> Channel: + return grpc.aio.insecure_channel(address) + + channel = channel_factory(server_address) + transport = GrpcTransport(channel=channel, agent_card=agent_card) + + params = ListTasksRequest(page_size=10, page_token='page-1') + result = await transport.list_tasks(request=params) + + assert len(result.tasks) == 2 + assert result.next_page_token == 'page-2' + handler.on_list_tasks.assert_awaited_once() + + await transport.close() + + @pytest.mark.asyncio @pytest.mark.parametrize( 'transport_setup_fixture', diff --git a/tests/server/request_handlers/test_jsonrpc_handler.py b/tests/server/request_handlers/test_jsonrpc_handler.py index 71890e8be..b5a5a07ad 100644 --- a/tests/server/request_handlers/test_jsonrpc_handler.py +++ b/tests/server/request_handlers/test_jsonrpc_handler.py @@ -190,8 +190,30 @@ async def test_on_list_tasks_success(self) -> None: response = await handler.list_tasks(request, call_context) request_handler.on_list_tasks.assert_awaited_once() - self.assertIsInstance(response, ListTasksResponse) - self.assertEqual(response, mock_result) + self.assertIsInstance(response, dict) + self.assertTrue(is_success_response(response)) + self.assertIn('tasks', response['result']) + self.assertEqual(len(response['result']['tasks']), 2) + self.assertEqual(response['result']['nextPageToken'], '123') + + async def test_on_list_tasks_error(self) -> None: + request_handler = AsyncMock(spec=DefaultRequestHandler) + handler = JSONRPCHandler(self.mock_agent_card, request_handler) + + request_handler.on_list_tasks.side_effect = ServerError( + InternalError(message='DB down') + ) + from a2a.types.a2a_pb2 import ListTasksRequest + + request = ListTasksRequest(page_size=10) + call_context = ServerCallContext(state={'request_id': '2'}) + + response = await handler.list_tasks(request, call_context) + + request_handler.on_list_tasks.assert_awaited_once() + self.assertIsInstance(response, dict) + self.assertTrue(is_error_response(response)) + self.assertEqual(response['error']['message'], 'DB down') async def test_on_cancel_task_success(self) -> None: mock_agent_executor = AsyncMock(spec=AgentExecutor)