Skip to content

Commit 27303bc

Browse files
committed
Add TaskResultHandler unit tests
1 parent b184785 commit 27303bc

File tree

1 file changed

+255
-0
lines changed

1 file changed

+255
-0
lines changed
Lines changed: 255 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,255 @@
1+
"""Tests for TaskResultHandler."""
2+
3+
from collections.abc import AsyncIterator
4+
from typing import Any
5+
from unittest.mock import AsyncMock, Mock
6+
7+
import anyio
8+
import pytest
9+
10+
from mcp.server.experimental.task_result_handler import TaskResultHandler
11+
from mcp.shared.exceptions import McpError
12+
from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore
13+
from mcp.shared.experimental.tasks.message_queue import InMemoryTaskMessageQueue, QueuedMessage
14+
from mcp.shared.experimental.tasks.resolver import Resolver
15+
from mcp.shared.message import SessionMessage
16+
from mcp.types import (
17+
CallToolResult,
18+
ErrorData,
19+
GetTaskPayloadRequest,
20+
GetTaskPayloadRequestParams,
21+
GetTaskPayloadResult,
22+
JSONRPCRequest,
23+
TaskMetadata,
24+
TextContent,
25+
)
26+
27+
28+
@pytest.fixture
29+
async def store() -> AsyncIterator[InMemoryTaskStore]:
30+
"""Provide a clean store for each test."""
31+
s = InMemoryTaskStore()
32+
yield s
33+
s.cleanup()
34+
35+
36+
@pytest.fixture
37+
def queue() -> InMemoryTaskMessageQueue:
38+
"""Provide a clean queue for each test."""
39+
return InMemoryTaskMessageQueue()
40+
41+
42+
@pytest.fixture
43+
def handler(store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue) -> TaskResultHandler:
44+
"""Provide a handler for each test."""
45+
return TaskResultHandler(store, queue)
46+
47+
48+
@pytest.mark.anyio
49+
async def test_handle_returns_result_for_completed_task(
50+
store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler
51+
) -> None:
52+
"""Test that handle() returns the stored result for a completed task."""
53+
task = await store.create_task(TaskMetadata(ttl=60000), task_id="test-task")
54+
result = CallToolResult(content=[TextContent(type="text", text="Done!")])
55+
await store.store_result(task.taskId, result)
56+
await store.update_task(task.taskId, status="completed")
57+
58+
mock_session = Mock()
59+
mock_session.send_message = AsyncMock()
60+
61+
request = GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(taskId=task.taskId))
62+
response = await handler.handle(request, mock_session, "req-1")
63+
64+
assert response is not None
65+
assert response.meta is not None
66+
assert "io.modelcontextprotocol/related-task" in response.meta
67+
68+
69+
@pytest.mark.anyio
70+
async def test_handle_raises_for_nonexistent_task(
71+
store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler
72+
) -> None:
73+
"""Test that handle() raises McpError for nonexistent task."""
74+
mock_session = Mock()
75+
request = GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(taskId="nonexistent"))
76+
77+
with pytest.raises(McpError) as exc_info:
78+
await handler.handle(request, mock_session, "req-1")
79+
80+
assert "not found" in exc_info.value.error.message
81+
82+
83+
@pytest.mark.anyio
84+
async def test_handle_returns_empty_result_when_no_result_stored(
85+
store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler
86+
) -> None:
87+
"""Test that handle() returns minimal result when task completed without stored result."""
88+
task = await store.create_task(TaskMetadata(ttl=60000), task_id="test-task")
89+
await store.update_task(task.taskId, status="completed")
90+
91+
mock_session = Mock()
92+
mock_session.send_message = AsyncMock()
93+
94+
request = GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(taskId=task.taskId))
95+
response = await handler.handle(request, mock_session, "req-1")
96+
97+
assert response is not None
98+
assert response.meta is not None
99+
assert "io.modelcontextprotocol/related-task" in response.meta
100+
101+
102+
@pytest.mark.anyio
103+
async def test_handle_delivers_queued_messages(
104+
store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler
105+
) -> None:
106+
"""Test that handle() delivers queued messages before returning."""
107+
task = await store.create_task(TaskMetadata(ttl=60000), task_id="test-task")
108+
109+
queued_msg = QueuedMessage(
110+
type="notification",
111+
message=JSONRPCRequest(
112+
jsonrpc="2.0",
113+
id="notif-1",
114+
method="test/notification",
115+
params={},
116+
),
117+
)
118+
await queue.enqueue(task.taskId, queued_msg)
119+
await store.update_task(task.taskId, status="completed")
120+
121+
sent_messages: list[SessionMessage] = []
122+
123+
async def track_send(msg: SessionMessage) -> None:
124+
sent_messages.append(msg)
125+
126+
mock_session = Mock()
127+
mock_session.send_message = track_send
128+
129+
request = GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(taskId=task.taskId))
130+
await handler.handle(request, mock_session, "req-1")
131+
132+
assert len(sent_messages) == 1
133+
134+
135+
@pytest.mark.anyio
136+
async def test_handle_waits_for_task_completion(
137+
store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler
138+
) -> None:
139+
"""Test that handle() waits for task to complete before returning."""
140+
task = await store.create_task(TaskMetadata(ttl=60000), task_id="test-task")
141+
142+
mock_session = Mock()
143+
mock_session.send_message = AsyncMock()
144+
145+
request = GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(taskId=task.taskId))
146+
result_holder: list[GetTaskPayloadResult | None] = [None]
147+
148+
async def run_handle() -> None:
149+
result_holder[0] = await handler.handle(request, mock_session, "req-1")
150+
151+
async with anyio.create_task_group() as tg:
152+
tg.start_soon(run_handle)
153+
await anyio.sleep(0.05)
154+
155+
await store.store_result(task.taskId, CallToolResult(content=[TextContent(type="text", text="Done")]))
156+
await store.update_task(task.taskId, status="completed")
157+
158+
assert result_holder[0] is not None
159+
160+
161+
@pytest.mark.anyio
162+
async def test_route_response_resolves_pending_request(
163+
store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler
164+
) -> None:
165+
"""Test that route_response() resolves a pending request."""
166+
resolver: Resolver[dict[str, Any]] = Resolver()
167+
handler._pending_requests["req-123"] = resolver
168+
169+
result = handler.route_response("req-123", {"status": "ok"})
170+
171+
assert result is True
172+
assert resolver.done()
173+
assert await resolver.wait() == {"status": "ok"}
174+
175+
176+
@pytest.mark.anyio
177+
async def test_route_response_returns_false_for_unknown_request(
178+
store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler
179+
) -> None:
180+
"""Test that route_response() returns False for unknown request ID."""
181+
result = handler.route_response("unknown-req", {"status": "ok"})
182+
assert result is False
183+
184+
185+
@pytest.mark.anyio
186+
async def test_route_response_returns_false_for_already_done_resolver(
187+
store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler
188+
) -> None:
189+
"""Test that route_response() returns False if resolver already completed."""
190+
resolver: Resolver[dict[str, Any]] = Resolver()
191+
resolver.set_result({"already": "done"})
192+
handler._pending_requests["req-123"] = resolver
193+
194+
result = handler.route_response("req-123", {"new": "data"})
195+
196+
assert result is False
197+
198+
199+
@pytest.mark.anyio
200+
async def test_route_error_resolves_pending_request_with_exception(
201+
store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler
202+
) -> None:
203+
"""Test that route_error() sets exception on pending request."""
204+
resolver: Resolver[dict[str, Any]] = Resolver()
205+
handler._pending_requests["req-123"] = resolver
206+
207+
error = ErrorData(code=-32600, message="Something went wrong")
208+
result = handler.route_error("req-123", error)
209+
210+
assert result is True
211+
assert resolver.done()
212+
213+
with pytest.raises(McpError) as exc_info:
214+
await resolver.wait()
215+
assert exc_info.value.error.message == "Something went wrong"
216+
217+
218+
@pytest.mark.anyio
219+
async def test_route_error_returns_false_for_unknown_request(
220+
store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler
221+
) -> None:
222+
"""Test that route_error() returns False for unknown request ID."""
223+
error = ErrorData(code=-32600, message="Error")
224+
result = handler.route_error("unknown-req", error)
225+
assert result is False
226+
227+
228+
@pytest.mark.anyio
229+
async def test_deliver_registers_resolver_for_request_messages(
230+
store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler
231+
) -> None:
232+
"""Test that _deliver_queued_messages registers resolvers for request messages."""
233+
task = await store.create_task(TaskMetadata(ttl=60000), task_id="test-task")
234+
235+
resolver: Resolver[dict[str, Any]] = Resolver()
236+
queued_msg = QueuedMessage(
237+
type="request",
238+
message=JSONRPCRequest(
239+
jsonrpc="2.0",
240+
id="inner-req-1",
241+
method="elicitation/create",
242+
params={},
243+
),
244+
resolver=resolver,
245+
original_request_id="inner-req-1",
246+
)
247+
await queue.enqueue(task.taskId, queued_msg)
248+
249+
mock_session = Mock()
250+
mock_session.send_message = AsyncMock()
251+
252+
await handler._deliver_queued_messages(task.taskId, mock_session, "outer-req-1")
253+
254+
assert "inner-req-1" in handler._pending_requests
255+
assert handler._pending_requests["inner-req-1"] is resolver

0 commit comments

Comments
 (0)