
Commit e8c7c8a

Achieve 100% test coverage for tasks code
- Add tests for validation.py (check_sampling_tools_capability, validate_sampling_tools, validate_tool_use_result_messages)
- Add flow test for elicit_url() in ServerTaskContext
- Add pragma no cover comments to defensive _meta checks in builder methods (model_dump never includes _meta with current types)
- Fix test code to use assertions instead of conditional branches
- Add pragma no branch to polling loops in test scenarios
1 parent: 6eb1b3f
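
A note on the "assertions instead of conditional branches" tactic: under coverage.py's branch mode, an `if` guard that is always true during the test run leaves its False arc unvisited and is reported as a partial branch, while an always-passing `assert` does not leave the same footprint, since its failure path exits by raising, and branch coverage does not demand exception arcs. A standalone sketch of the difference, using a hypothetical `events` mapping in place of the suite's actual fixtures:

import threading

# Hypothetical stand-in for the test suite's task_complete_events mapping.
events: dict[str, threading.Event] = {"task-1": threading.Event()}
events["task-1"].set()

def wait_conditional(task_id: str) -> None:
    event = events.get(task_id)
    if event:  # partial branch under `coverage run --branch`: False arc never taken
        event.wait()

def wait_asserting(task_id: str) -> None:
    event = events.get(task_id)
    assert event is not None, f"No completion event for task: {task_id}"
    event.wait()  # fully covered: the assert's failure path raises instead of branching

wait_conditional("task-1")
wait_asserting("task-1")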

File tree

- src/mcp/server/session.py
- tests/experimental/tasks/server/test_server_task_context.py
- tests/experimental/tasks/test_elicitation_scenarios.py
- tests/server/test_validation.py

4 files changed: +242 −19 lines


src/mcp/server/session.py

Lines changed: 6 additions & 3 deletions
@@ -517,7 +517,8 @@ def _build_elicit_form_request(
 
         # Add related-task metadata if associated with a parent task
         if related_task_id is not None:
-            if "_meta" not in params_data:
+            # Defensive: model_dump() never includes _meta, but guard against future changes
+            if "_meta" not in params_data:  # pragma: no cover
                 params_data["_meta"] = {}
             params_data["_meta"]["io.modelcontextprotocol/related-task"] = {"taskId": related_task_id}
 
@@ -559,7 +560,8 @@ def _build_elicit_url_request(
 
         # Add related-task metadata if associated with a parent task
         if related_task_id is not None:
-            if "_meta" not in params_data:
+            # Defensive: model_dump() never includes _meta, but guard against future changes
+            if "_meta" not in params_data:  # pragma: no cover
                 params_data["_meta"] = {}
             params_data["_meta"]["io.modelcontextprotocol/related-task"] = {"taskId": related_task_id}
 
@@ -626,7 +628,8 @@ def _build_create_message_request(
 
         # Add related-task metadata if associated with a parent task
         if related_task_id is not None:
-            if "_meta" not in params_data:
+            # Defensive: model_dump() never includes _meta, but guard against future changes
+            if "_meta" not in params_data:  # pragma: no cover
                 params_data["_meta"] = {}
             params_data["_meta"]["io.modelcontextprotocol/related-task"] = {"taskId": related_task_id}
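
The pragma rests on a pydantic detail worth spelling out: model_dump() emits field names rather than aliases unless by_alias=True is passed, so a field declared with alias="_meta" never produces a "_meta" key under a plain dump. A minimal sketch of that behavior, where the Params model is a hypothetical stand-in for the SDK's request types, not actual SDK code:

from pydantic import BaseModel, ConfigDict, Field

class Params(BaseModel):
    """Hypothetical stand-in for an MCP request-params model."""
    model_config = ConfigDict(populate_by_name=True)
    message: str
    meta: dict | None = Field(default=None, alias="_meta")

params_data = Params(message="Authorize").model_dump()
print(params_data)             # {'message': 'Authorize', 'meta': None}
print("_meta" in params_data)  # False -- only model_dump(by_alias=True) emits "_meta"

Should the serialization ever change to emit aliases for a type that carries metadata, the guard would start doing real work, which is presumably why the commit keeps it behind a pragma rather than deleting it.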

tests/experimental/tasks/server/test_server_task_context.py

Lines changed: 72 additions & 0 deletions
@@ -370,6 +370,78 @@ async def run_elicit() -> None:
     store.cleanup()
 
 
+@pytest.mark.anyio
+async def test_elicit_url_queues_request_and_waits_for_response() -> None:
+    """Test that elicit_url() queues request and waits for response."""
+    import anyio
+
+    from mcp.types import JSONRPCRequest
+
+    store = InMemoryTaskStore()
+    queue = InMemoryTaskMessageQueue()
+    handler = TaskResultHandler(store, queue)
+    task = await store.create_task(TaskMetadata(ttl=60000))
+
+    mock_session = Mock()
+    mock_session.check_client_capability = Mock(return_value=True)
+    mock_session._build_elicit_url_request = Mock(
+        return_value=JSONRPCRequest(
+            jsonrpc="2.0",
+            id="test-url-req-1",
+            method="elicitation/create",
+            params={"message": "Authorize", "url": "https://example.com", "elicitationId": "123", "mode": "url"},
+        )
+    )
+
+    ctx = ServerTaskContext(
+        task=task,
+        store=store,
+        session=mock_session,
+        queue=queue,
+        handler=handler,
+    )
+
+    elicit_result = None
+
+    async def run_elicit_url() -> None:
+        nonlocal elicit_result
+        elicit_result = await ctx.elicit_url(
+            message="Authorize",
+            url="https://example.com/oauth",
+            elicitation_id="oauth-123",
+        )
+
+    async with anyio.create_task_group() as tg:
+        tg.start_soon(run_elicit_url)
+
+        # Wait for request to be queued
+        await queue.wait_for_message(task.taskId)
+
+        # Verify task is in input_required status
+        updated_task = await store.get_task(task.taskId)
+        assert updated_task is not None
+        assert updated_task.status == "input_required"
+
+        # Dequeue and simulate response
+        msg = await queue.dequeue(task.taskId)
+        assert msg is not None
+        assert msg.resolver is not None
+
+        # Resolve with mock elicitation response (URL mode just returns action)
+        msg.resolver.set_result({"action": "accept"})
+
+    # Verify result
+    assert elicit_result is not None
+    assert elicit_result.action == "accept"
+
+    # Verify task is back to working
+    final_task = await store.get_task(task.taskId)
+    assert final_task is not None
+    assert final_task.status == "working"
+
+    store.cleanup()
+
+
 @pytest.mark.anyio
 async def test_create_message_queues_request_and_waits_for_response() -> None:
     """Test that create_message() queues request and waits for response."""

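The blocking handoff the new test drives, where elicit_url() suspends until msg.resolver.set_result(...) is called from the test body, boils down to a one-shot future. A minimal sketch of that pattern, assuming nothing about the SDK's actual queue or resolver internals (the Resolver class below is hypothetical):

import anyio

class Resolver:
    """Hypothetical one-shot future: a waiter blocks until a result is set."""

    def __init__(self) -> None:
        self._event = anyio.Event()
        self._result: dict[str, str] | None = None

    def set_result(self, result: dict[str, str]) -> None:
        self._result = result
        self._event.set()

    async def wait(self) -> dict[str, str] | None:
        await self._event.wait()
        return self._result

async def main() -> None:
    resolver = Resolver()

    async def requester() -> None:
        # Suspends like ctx.elicit_url() until the "client" responds
        print("got:", await resolver.wait())

    async with anyio.create_task_group() as tg:
        tg.start_soon(requester)
        resolver.set_result({"action": "accept"})  # simulate the client's response

anyio.run(main)
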
tests/experimental/tasks/test_elicitation_scenarios.py

Lines changed: 23 additions & 16 deletions
@@ -95,8 +95,8 @@ async def handle_get_task_result(
     ) -> GetTaskPayloadResult | ErrorData:
         """Handle tasks/result from server."""
         event = task_complete_events.get(params.taskId)
-        if event:
-            await event.wait()
+        assert event is not None, f"No completion event for task: {params.taskId}"
+        await event.wait()
         result = await client_task_store.get_result(params.taskId)
         assert result is not None, f"Result not found for task: {params.taskId}"
         return GetTaskPayloadResult.model_validate(result.model_dump(by_alias=True))

@@ -163,8 +163,8 @@ async def handle_get_task_result(
     ) -> GetTaskPayloadResult | ErrorData:
         """Handle tasks/result from server."""
         event = task_complete_events.get(params.taskId)
-        if event:
-            await event.wait()
+        assert event is not None, f"No completion event for task: {params.taskId}"
+        await event.wait()
         result = await client_task_store.get_result(params.taskId)
         assert result is not None, f"Result not found for task: {params.taskId}"
         return GetTaskPayloadResult.model_validate(result.model_dump(by_alias=True))

@@ -417,9 +417,12 @@ async def run_client() -> None:
         assert create_result.task.status == "working"
 
         # Poll until input_required, then call tasks/result
-        async for status in client_session.experimental.poll_task(task_id):
-            if status.status == "input_required":
+        found_input_required = False
+        async for status in client_session.experimental.poll_task(task_id):  # pragma: no branch
+            if status.status == "input_required":  # pragma: no branch
+                found_input_required = True
                 break
+        assert found_input_required, "Expected to see input_required status"
 
         # This will deliver the elicitation and get the response
         final_result = await client_session.experimental.get_task_result(task_id, CallToolResult)

@@ -523,9 +526,12 @@ async def run_client() -> None:
         assert create_result.task.status == "working"
 
         # Poll until input_required or terminal, then call tasks/result
-        async for status in client_session.experimental.poll_task(task_id):
-            if status.status == "input_required" or is_terminal(status.status):
+        found_expected_status = False
+        async for status in client_session.experimental.poll_task(task_id):  # pragma: no branch
+            if status.status == "input_required" or is_terminal(status.status):  # pragma: no branch
+                found_expected_status = True
                 break
+        assert found_expected_status, "Expected to see input_required or terminal status"
 
         # This will deliver the task-augmented elicitation,
         # server will poll client, and eventually return the tool result

@@ -581,9 +587,8 @@ async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult:
             ttl=60000,
         )
 
-        response_text = ""
-        if isinstance(result.content, TextContent):
-            response_text = result.content.text
+        assert isinstance(result.content, TextContent), "Expected TextContent response"
+        response_text = result.content.text
 
         tool_result.append(response_text)
         return CallToolResult(content=[TextContent(type="text", text=response_text)])

@@ -671,9 +676,8 @@ async def work(task: ServerTaskContext) -> CallToolResult:
             ttl=60000,
         )
 
-        response_text = ""
-        if isinstance(result.content, TextContent):
-            response_text = result.content.text
+        assert isinstance(result.content, TextContent), "Expected TextContent response"
+        response_text = result.content.text
 
         work_completed.set()
         return CallToolResult(content=[TextContent(type="text", text=response_text)])

@@ -710,9 +714,12 @@ async def run_client() -> None:
         assert create_result.task.status == "working"
 
         # Poll until input_required or terminal
-        async for status in client_session.experimental.poll_task(task_id):
-            if status.status == "input_required" or is_terminal(status.status):
+        found_expected_status = False
+        async for status in client_session.experimental.poll_task(task_id):  # pragma: no branch
+            if status.status == "input_required" or is_terminal(status.status):  # pragma: no branch
+                found_expected_status = True
                 break
+        assert found_expected_status, "Expected to see input_required or terminal status"
 
         final_result = await client_session.experimental.get_task_result(task_id, CallToolResult)
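
Why the polling loops need `# pragma: no branch`: in coverage.py's branch mode, a loop that always exits via break never takes its loop-exhausted arc, and a condition that is true on every executed iteration never records its False arc, so both lines show up as partial branches even though every statement runs. A standalone illustration with a plain list standing in for poll_task() (run it under `coverage run --branch` to see the effect):

def first_input_required(statuses: list[str]) -> str:
    found = None
    for status in statuses:  # pragma: no branch - loop always exits via break
        if status == "input_required":  # pragma: no branch - true on the first hit
            found = status
            break
    assert found is not None, "Expected to see input_required status"
    return found

# A "server" already waiting for input: the condition is true immediately, so
# neither the loop-exhausted arc nor the False arc of the if is ever taken.
print(first_input_required(["input_required"]))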

tests/server/test_validation.py

Lines changed: 141 additions & 0 deletions
@@ -0,0 +1,141 @@
+"""Tests for server validation functions."""
+
+import pytest
+
+from mcp.server.validation import (
+    check_sampling_tools_capability,
+    validate_sampling_tools,
+    validate_tool_use_result_messages,
+)
+from mcp.shared.exceptions import McpError
+from mcp.types import (
+    ClientCapabilities,
+    SamplingCapability,
+    SamplingMessage,
+    SamplingToolsCapability,
+    TextContent,
+    Tool,
+    ToolChoice,
+    ToolResultContent,
+    ToolUseContent,
+)
+
+
+class TestCheckSamplingToolsCapability:
+    """Tests for check_sampling_tools_capability function."""
+
+    def test_returns_false_when_caps_none(self) -> None:
+        """Returns False when client_caps is None."""
+        assert check_sampling_tools_capability(None) is False
+
+    def test_returns_false_when_sampling_none(self) -> None:
+        """Returns False when client_caps.sampling is None."""
+        caps = ClientCapabilities()
+        assert check_sampling_tools_capability(caps) is False
+
+    def test_returns_false_when_tools_none(self) -> None:
+        """Returns False when client_caps.sampling.tools is None."""
+        caps = ClientCapabilities(sampling=SamplingCapability())
+        assert check_sampling_tools_capability(caps) is False
+
+    def test_returns_true_when_tools_present(self) -> None:
+        """Returns True when sampling.tools is present."""
+        caps = ClientCapabilities(sampling=SamplingCapability(tools=SamplingToolsCapability()))
+        assert check_sampling_tools_capability(caps) is True
+
+
+class TestValidateSamplingTools:
+    """Tests for validate_sampling_tools function."""
+
+    def test_no_error_when_tools_none(self) -> None:
+        """No error when tools and tool_choice are None."""
+        validate_sampling_tools(None, None, None)  # Should not raise
+
+    def test_raises_when_tools_provided_but_no_capability(self) -> None:
+        """Raises McpError when tools are provided but the client doesn't support them."""
+        tool = Tool(name="test", inputSchema={"type": "object"})
+        with pytest.raises(McpError) as exc_info:
+            validate_sampling_tools(None, [tool], None)
+        assert "sampling tools capability" in str(exc_info.value)
+
+    def test_raises_when_tool_choice_provided_but_no_capability(self) -> None:
+        """Raises McpError when tool_choice is provided but the client doesn't support it."""
+        with pytest.raises(McpError) as exc_info:
+            validate_sampling_tools(None, None, ToolChoice(mode="auto"))
+        assert "sampling tools capability" in str(exc_info.value)
+
+    def test_no_error_when_capability_present(self) -> None:
+        """No error when client has sampling.tools capability."""
+        caps = ClientCapabilities(sampling=SamplingCapability(tools=SamplingToolsCapability()))
+        tool = Tool(name="test", inputSchema={"type": "object"})
+        validate_sampling_tools(caps, [tool], ToolChoice(mode="auto"))  # Should not raise
+
+
+class TestValidateToolUseResultMessages:
+    """Tests for validate_tool_use_result_messages function."""
+
+    def test_no_error_for_empty_messages(self) -> None:
+        """No error when messages list is empty."""
+        validate_tool_use_result_messages([])  # Should not raise
+
+    def test_no_error_for_simple_text_messages(self) -> None:
+        """No error for simple text messages."""
+        messages = [
+            SamplingMessage(role="user", content=TextContent(type="text", text="Hello")),
+            SamplingMessage(role="assistant", content=TextContent(type="text", text="Hi")),
+        ]
+        validate_tool_use_result_messages(messages)  # Should not raise
+
+    def test_raises_when_tool_result_mixed_with_other_content(self) -> None:
+        """Raises when tool_result is mixed with other content types."""
+        messages = [
+            SamplingMessage(
+                role="user",
+                content=[
+                    ToolResultContent(type="tool_result", toolUseId="123"),
+                    TextContent(type="text", text="also this"),
+                ],
+            ),
+        ]
+        with pytest.raises(ValueError, match="only tool_result content"):
+            validate_tool_use_result_messages(messages)
+
+    def test_raises_when_tool_result_without_previous_tool_use(self) -> None:
+        """Raises when tool_result appears without a preceding tool_use."""
+        messages = [
+            SamplingMessage(
+                role="user",
+                content=ToolResultContent(type="tool_result", toolUseId="123"),
+            ),
+        ]
+        with pytest.raises(ValueError, match="previous message containing tool_use"):
+            validate_tool_use_result_messages(messages)
+
+    def test_raises_when_tool_result_ids_dont_match_tool_use(self) -> None:
+        """Raises when tool_result IDs don't match tool_use IDs."""
+        messages = [
+            SamplingMessage(
+                role="assistant",
+                content=ToolUseContent(type="tool_use", id="tool-1", name="test", input={}),
+            ),
+            SamplingMessage(
+                role="user",
+                content=ToolResultContent(type="tool_result", toolUseId="tool-2"),
+            ),
+        ]
+        with pytest.raises(ValueError, match="do not match"):
+            validate_tool_use_result_messages(messages)
+
+    def test_no_error_when_tool_result_matches_tool_use(self) -> None:
+        """No error when tool_result IDs match tool_use IDs."""
+        messages = [
+            SamplingMessage(
+                role="assistant",
+                content=ToolUseContent(type="tool_use", id="tool-1", name="test", input={}),
+            ),
+            SamplingMessage(
+                role="user",
+                content=ToolResultContent(type="tool_result", toolUseId="tool-1"),
+            ),
+        ]
+        validate_tool_use_result_messages(messages)  # Should not raise
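
Read together, the last four tests pin down a three-part contract: a message carrying tool_result may carry nothing else, it must follow a message containing tool_use blocks, and the result IDs must match those tool_use IDs. A simplified reimplementation consistent with those behaviors, illustrative only and operating on plain dicts rather than the SDK's actual validate_tool_use_result_messages signature:

def check_tool_use_result_ordering(messages: list[dict]) -> None:
    """Illustrative validator enforcing the ordering rules the tests above describe."""
    prev_tool_use_ids: set[str] = set()
    for msg in messages:
        content = msg["content"] if isinstance(msg["content"], list) else [msg["content"]]
        result_ids = {c["toolUseId"] for c in content if c["type"] == "tool_result"}
        if result_ids:
            if len(result_ids) != len(content):
                raise ValueError("messages with tool_result must contain only tool_result content")
            if not prev_tool_use_ids:
                raise ValueError("tool_result requires a previous message containing tool_use")
            if result_ids != prev_tool_use_ids:
                raise ValueError(f"tool_result IDs {result_ids} do not match tool_use IDs {prev_tool_use_ids}")
        prev_tool_use_ids = {c["id"] for c in content if c["type"] == "tool_use"}

# Passes: the result ID matches the preceding tool_use ID
check_tool_use_result_ordering([
    {"role": "assistant", "content": {"type": "tool_use", "id": "tool-1", "name": "test", "input": {}}},
    {"role": "user", "content": {"type": "tool_result", "toolUseId": "tool-1"}},
])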
