|
1 | 1 | import unittest |
2 | 2 | import unittest.async_case |
3 | | -from unittest.mock import AsyncMock, patch, MagicMock |
| 3 | + |
| 4 | +from collections.abc import AsyncGenerator |
| 5 | +from typing import Any |
| 6 | +from unittest.mock import AsyncMock, MagicMock, patch |
| 7 | + |
4 | 8 | import pytest |
5 | | -from a2a.server.events.event_queue import EventQueue |
| 9 | + |
6 | 10 | from a2a.server.agent_execution import AgentExecutor |
7 | | -from a2a.utils.errors import ServerError |
| 11 | +from a2a.server.events import ( |
| 12 | + QueueManager, |
| 13 | +) |
| 14 | +from a2a.server.events.event_queue import EventQueue |
8 | 15 | from a2a.server.request_handlers import ( |
9 | 16 | DefaultRequestHandler, |
10 | 17 | JSONRPCHandler, |
11 | 18 | ) |
12 | | -from a2a.server.events import ( |
13 | | - QueueManager, |
14 | | -) |
15 | 19 | from a2a.server.tasks import TaskStore |
16 | 20 | from a2a.types import ( |
17 | | - AgentCard, |
18 | 21 | AgentCapabilities, |
| 22 | + AgentCard, |
| 23 | + Artifact, |
| 24 | + CancelTaskRequest, |
| 25 | + CancelTaskSuccessResponse, |
19 | 26 | GetTaskRequest, |
20 | 27 | GetTaskResponse, |
21 | 28 | GetTaskSuccessResponse, |
22 | | - Task, |
23 | | - TaskQueryParams, |
24 | 29 | JSONRPCErrorResponse, |
25 | | - TaskNotFoundError, |
26 | | - TaskIdParams, |
27 | | - CancelTaskRequest, |
28 | | - CancelTaskSuccessResponse, |
29 | | - UnsupportedOperationError, |
30 | | - SendMessageRequest, |
31 | 30 | Message, |
32 | 31 | MessageSendParams, |
| 32 | + Part, |
| 33 | + SendMessageRequest, |
33 | 34 | SendMessageSuccessResponse, |
34 | 35 | SendStreamingMessageRequest, |
35 | 36 | SendStreamingMessageSuccessResponse, |
| 37 | + Task, |
36 | 38 | TaskArtifactUpdateEvent, |
| 39 | + TaskIdParams, |
| 40 | + TaskNotFoundError, |
| 41 | + TaskQueryParams, |
| 42 | + TaskResubscriptionRequest, |
| 43 | + TaskState, |
| 44 | + TaskStatus, |
37 | 45 | TaskStatusUpdateEvent, |
38 | | - Artifact, |
39 | | - Part, |
40 | 46 | TextPart, |
41 | | - TaskStatus, |
42 | | - TaskState, |
43 | | - TaskResubscriptionRequest, |
| 47 | + UnsupportedOperationError, |
44 | 48 | ) |
45 | | -from collections.abc import AsyncGenerator |
46 | | -from typing import Any |
| 49 | +from a2a.utils.errors import ServerError |
| 50 | + |
47 | 51 |
|
48 | 52 | MINIMAL_TASK: dict[str, Any] = { |
49 | 53 | 'id': 'task_123', |
@@ -316,7 +320,7 @@ async def streaming_coro(): |
316 | 320 | assert isinstance( |
317 | 321 | event.root, SendStreamingMessageSuccessResponse |
318 | 322 | ) |
319 | | - assert collected_events[i].root.result == events[i] |
| 323 | + assert event.root.result == events[i] |
320 | 324 | mock_agent_executor.execute.assert_called_once() |
321 | 325 |
|
322 | 326 | async def test_on_message_stream_new_message_existing_task_success( |
@@ -387,7 +391,7 @@ async def test_on_resubscribe_existing_task_success( |
387 | 391 | request_handler = DefaultRequestHandler( |
388 | 392 | mock_agent_executor, mock_task_store, mock_queue_manager |
389 | 393 | ) |
390 | | - mock_agent_card = MagicMock(spec=AgentCard) |
| 394 | + self.mock_agent_card = MagicMock(spec=AgentCard) |
391 | 395 | handler = JSONRPCHandler(self.mock_agent_card, request_handler) |
392 | 396 | mock_task = Task(**MINIMAL_TASK, history=[]) |
393 | 397 | events: list[Any] = [ |
|
0 commit comments