Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions src/a2a/server/request_handlers/jsonrpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
GetTaskRequest,
ListTaskPushNotificationConfigRequest,
ListTasksRequest,
ListTasksResponse,
Message,
SendMessageRequest,
SendMessageResponse,
Expand Down Expand Up @@ -388,25 +387,27 @@ async def list_tasks(
self,
request: ListTasksRequest,
context: ServerCallContext | None = None,
) -> ListTasksResponse:
) -> dict[str, Any]:
"""Handles the 'tasks/list' JSON-RPC method.

Args:
request: The incoming `ListTasksRequest` object.
context: Context provided by the server.

Returns:
A `ListTasksResponse` object containing the Task or a JSON-RPC error.
A dict representing the JSON-RPC response.
"""
request_id = self._get_request_id(context)
try:
result = await self.request_handler.on_list_tasks(request, context)
except ServerError:
return ListTasksResponse(
# This needs to be appropriately handled since error fields on proto messages
# might be different from the old pydantic models
# Ignoring proto error handling for now as it diverges from the current pattern
response = await self.request_handler.on_list_tasks(
request, context
)
result = MessageToDict(response, preserving_proto_field_name=False)
return _build_success_response(request_id, result)
except ServerError as e:
return _build_error_response(
request_id, e.error if e.error else InternalError()
)
return result

async def list_push_notification_config(
self,
Expand Down
59 changes: 59 additions & 0 deletions tests/integration/test_client_server_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@
TaskState,
TaskStatus,
TaskStatusUpdateEvent,
ListTasksRequest,
ListTasksResponse,
)
from cryptography.hazmat.primitives import asymmetric
from cryptography.hazmat.primitives.asymmetric import ec
Expand Down Expand Up @@ -91,6 +93,11 @@
status=TaskStatus(state=TaskState.TASK_STATE_WORKING),
)

LIST_TASKS_RESPONSE = ListTasksResponse(
tasks=[TASK_FROM_BLOCKING, GET_TASK_RESPONSE],
next_page_token='page-2',
)


def create_key_provider(verification_key: Any):
"""Creates a key provider function for testing."""
Expand Down Expand Up @@ -121,6 +128,7 @@ async def stream_side_effect(*args, **kwargs):
# Configure other methods
handler.on_get_task.return_value = GET_TASK_RESPONSE
handler.on_cancel_task.return_value = CANCEL_TASK_RESPONSE
handler.on_list_tasks.return_value = LIST_TASKS_RESPONSE
handler.on_create_task_push_notification_config.return_value = (
CALLBACK_CONFIG
)
Expand Down Expand Up @@ -450,6 +458,57 @@ def channel_factory(address: str) -> Channel:
await transport.close()


@pytest.mark.asyncio
@pytest.mark.parametrize(
'transport_setup_fixture',
[
pytest.param('jsonrpc_setup', id='JSON-RPC'),
pytest.param('rest_setup', id='REST'),
],
)
async def test_http_transport_list_tasks(
transport_setup_fixture: str, request
) -> None:
transport_setup: TransportSetup = request.getfixturevalue(
transport_setup_fixture
)
transport = transport_setup.transport
handler = transport_setup.handler

params = ListTasksRequest(page_size=10, page_token='page-1')
result = await transport.list_tasks(request=params)

assert len(result.tasks) == 2
assert result.next_page_token == 'page-2'
handler.on_list_tasks.assert_awaited_once()

if hasattr(transport, 'close'):
await transport.close()


@pytest.mark.asyncio
async def test_grpc_transport_list_tasks(
grpc_server_and_handler: tuple[str, AsyncMock],
agent_card: AgentCard,
) -> None:
server_address, handler = grpc_server_and_handler

def channel_factory(address: str) -> Channel:
return grpc.aio.insecure_channel(address)

channel = channel_factory(server_address)
transport = GrpcTransport(channel=channel, agent_card=agent_card)

params = ListTasksRequest(page_size=10, page_token='page-1')
result = await transport.list_tasks(request=params)

assert len(result.tasks) == 2
assert result.next_page_token == 'page-2'
handler.on_list_tasks.assert_awaited_once()

await transport.close()


@pytest.mark.asyncio
@pytest.mark.parametrize(
'transport_setup_fixture',
Expand Down
26 changes: 24 additions & 2 deletions tests/server/request_handlers/test_jsonrpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,30 @@ async def test_on_list_tasks_success(self) -> None:
response = await handler.list_tasks(request, call_context)

request_handler.on_list_tasks.assert_awaited_once()
self.assertIsInstance(response, ListTasksResponse)
self.assertEqual(response, mock_result)
self.assertIsInstance(response, dict)
self.assertTrue(is_success_response(response))
self.assertIn('tasks', response['result'])
self.assertEqual(len(response['result']['tasks']), 2)
self.assertEqual(response['result']['nextPageToken'], '123')

async def test_on_list_tasks_error(self) -> None:
request_handler = AsyncMock(spec=DefaultRequestHandler)
handler = JSONRPCHandler(self.mock_agent_card, request_handler)

request_handler.on_list_tasks.side_effect = ServerError(
InternalError(message='DB down')
)
from a2a.types.a2a_pb2 import ListTasksRequest

request = ListTasksRequest(page_size=10)
call_context = ServerCallContext(state={'request_id': '2'})

response = await handler.list_tasks(request, call_context)

request_handler.on_list_tasks.assert_awaited_once()
self.assertIsInstance(response, dict)
self.assertTrue(is_error_response(response))
self.assertEqual(response['error']['message'], 'DB down')

async def test_on_cancel_task_success(self) -> None:
mock_agent_executor = AsyncMock(spec=AgentExecutor)
Expand Down
Loading