Skip to content

Commit b7d44fa

Browse files
committed
Add comprehensive capability tests and improve test coverage
- Add test_capabilities.py with unit tests for all capability checking functions - Add tests for elicit_as_task and create_message_as_task without handler - Add scenario 4 sampling test (task-augmented tool call + task-augmented sampling) - Replace sleep-based polling with event-based synchronization for faster, deterministic tests - Simplify for/else patterns in test code - Add additional check_tasks_capability edge case tests Test coverage improved to 99.94% with 0 missing statements.
1 parent 8cd2765 commit b7d44fa

File tree

3 files changed

+519
-60
lines changed

3 files changed

+519
-60
lines changed

tests/experimental/tasks/server/test_server_task_context.py

Lines changed: 96 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -502,7 +502,7 @@ async def test_create_message_restores_status_on_cancellation() -> None:
502502
"""Test that create_message() restores task status to working when cancelled."""
503503
import anyio
504504

505-
from mcp.types import JSONRPCRequest, SamplingMessage, TextContent
505+
from mcp.types import JSONRPCRequest, SamplingMessage
506506

507507
store = InMemoryTaskStore()
508508
queue = InMemoryTaskMessageQueue()
@@ -570,3 +570,98 @@ async def do_sampling() -> None:
570570
assert cancelled_error_raised
571571

572572
store.cleanup()
573+
574+
575+
@pytest.mark.anyio
576+
async def test_elicit_as_task_raises_without_handler() -> None:
577+
"""Test that elicit_as_task() raises when handler is not provided."""
578+
from mcp.types import (
579+
ClientCapabilities,
580+
ClientTasksCapability,
581+
ClientTasksRequestsCapability,
582+
Implementation,
583+
InitializeRequestParams,
584+
TasksCreateElicitationCapability,
585+
TasksElicitationCapability,
586+
)
587+
588+
store = InMemoryTaskStore()
589+
queue = InMemoryTaskMessageQueue()
590+
task = await store.create_task(TaskMetadata(ttl=60000))
591+
592+
# Create mock session with proper client capabilities
593+
mock_session = Mock()
594+
mock_session.client_params = InitializeRequestParams(
595+
protocolVersion="2025-01-01",
596+
capabilities=ClientCapabilities(
597+
tasks=ClientTasksCapability(
598+
requests=ClientTasksRequestsCapability(
599+
elicitation=TasksElicitationCapability(create=TasksCreateElicitationCapability())
600+
)
601+
)
602+
),
603+
clientInfo=Implementation(name="test", version="1.0"),
604+
)
605+
606+
ctx = ServerTaskContext(
607+
task=task,
608+
store=store,
609+
session=mock_session,
610+
queue=queue,
611+
handler=None, # No handler
612+
)
613+
614+
with pytest.raises(RuntimeError, match="handler is required for elicit_as_task"):
615+
await ctx.elicit_as_task(message="Test?", requestedSchema={"type": "object"})
616+
617+
store.cleanup()
618+
619+
620+
@pytest.mark.anyio
621+
async def test_create_message_as_task_raises_without_handler() -> None:
622+
"""Test that create_message_as_task() raises when handler is not provided."""
623+
from mcp.types import (
624+
ClientCapabilities,
625+
ClientTasksCapability,
626+
ClientTasksRequestsCapability,
627+
Implementation,
628+
InitializeRequestParams,
629+
SamplingMessage,
630+
TasksCreateMessageCapability,
631+
TasksSamplingCapability,
632+
TextContent,
633+
)
634+
635+
store = InMemoryTaskStore()
636+
queue = InMemoryTaskMessageQueue()
637+
task = await store.create_task(TaskMetadata(ttl=60000))
638+
639+
# Create mock session with proper client capabilities
640+
mock_session = Mock()
641+
mock_session.client_params = InitializeRequestParams(
642+
protocolVersion="2025-01-01",
643+
capabilities=ClientCapabilities(
644+
tasks=ClientTasksCapability(
645+
requests=ClientTasksRequestsCapability(
646+
sampling=TasksSamplingCapability(createMessage=TasksCreateMessageCapability())
647+
)
648+
)
649+
),
650+
clientInfo=Implementation(name="test", version="1.0"),
651+
)
652+
653+
ctx = ServerTaskContext(
654+
task=task,
655+
store=store,
656+
session=mock_session,
657+
queue=queue,
658+
handler=None, # No handler
659+
)
660+
661+
with pytest.raises(RuntimeError, match="handler is required for create_message_as_task"):
662+
await ctx.create_message_as_task(
663+
messages=[SamplingMessage(role="user", content=TextContent(type="text", text="Hello"))],
664+
max_tokens=100,
665+
)
666+
667+
store.cleanup()
Lines changed: 283 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,283 @@
1+
"""Tests for tasks capability checking utilities."""
2+
3+
import pytest
4+
5+
from mcp.shared.exceptions import McpError
6+
from mcp.shared.experimental.tasks.capabilities import (
7+
check_tasks_capability,
8+
has_task_augmented_elicitation,
9+
has_task_augmented_sampling,
10+
require_task_augmented_elicitation,
11+
require_task_augmented_sampling,
12+
)
13+
from mcp.types import (
14+
ClientCapabilities,
15+
ClientTasksCapability,
16+
ClientTasksRequestsCapability,
17+
TasksCreateElicitationCapability,
18+
TasksCreateMessageCapability,
19+
TasksElicitationCapability,
20+
TasksSamplingCapability,
21+
)
22+
23+
24+
class TestCheckTasksCapability:
25+
"""Tests for check_tasks_capability function."""
26+
27+
def test_required_requests_none_returns_true(self) -> None:
28+
"""When required.requests is None, should return True."""
29+
required = ClientTasksCapability()
30+
client = ClientTasksCapability()
31+
assert check_tasks_capability(required, client) is True
32+
33+
def test_client_requests_none_returns_false(self) -> None:
34+
"""When client.requests is None but required.requests is set, should return False."""
35+
required = ClientTasksCapability(requests=ClientTasksRequestsCapability())
36+
client = ClientTasksCapability()
37+
assert check_tasks_capability(required, client) is False
38+
39+
def test_elicitation_required_but_client_missing(self) -> None:
40+
"""When elicitation is required but client doesn't have it."""
41+
required = ClientTasksCapability(
42+
requests=ClientTasksRequestsCapability(elicitation=TasksElicitationCapability())
43+
)
44+
client = ClientTasksCapability(requests=ClientTasksRequestsCapability())
45+
assert check_tasks_capability(required, client) is False
46+
47+
def test_elicitation_create_required_but_client_missing(self) -> None:
48+
"""When elicitation.create is required but client doesn't have it."""
49+
required = ClientTasksCapability(
50+
requests=ClientTasksRequestsCapability(
51+
elicitation=TasksElicitationCapability(create=TasksCreateElicitationCapability())
52+
)
53+
)
54+
client = ClientTasksCapability(
55+
requests=ClientTasksRequestsCapability(
56+
elicitation=TasksElicitationCapability() # No create
57+
)
58+
)
59+
assert check_tasks_capability(required, client) is False
60+
61+
def test_elicitation_create_present(self) -> None:
62+
"""When elicitation.create is required and client has it."""
63+
required = ClientTasksCapability(
64+
requests=ClientTasksRequestsCapability(
65+
elicitation=TasksElicitationCapability(create=TasksCreateElicitationCapability())
66+
)
67+
)
68+
client = ClientTasksCapability(
69+
requests=ClientTasksRequestsCapability(
70+
elicitation=TasksElicitationCapability(create=TasksCreateElicitationCapability())
71+
)
72+
)
73+
assert check_tasks_capability(required, client) is True
74+
75+
def test_sampling_required_but_client_missing(self) -> None:
76+
"""When sampling is required but client doesn't have it."""
77+
required = ClientTasksCapability(requests=ClientTasksRequestsCapability(sampling=TasksSamplingCapability()))
78+
client = ClientTasksCapability(requests=ClientTasksRequestsCapability())
79+
assert check_tasks_capability(required, client) is False
80+
81+
def test_sampling_create_message_required_but_client_missing(self) -> None:
82+
"""When sampling.createMessage is required but client doesn't have it."""
83+
required = ClientTasksCapability(
84+
requests=ClientTasksRequestsCapability(
85+
sampling=TasksSamplingCapability(createMessage=TasksCreateMessageCapability())
86+
)
87+
)
88+
client = ClientTasksCapability(
89+
requests=ClientTasksRequestsCapability(
90+
sampling=TasksSamplingCapability() # No createMessage
91+
)
92+
)
93+
assert check_tasks_capability(required, client) is False
94+
95+
def test_sampling_create_message_present(self) -> None:
96+
"""When sampling.createMessage is required and client has it."""
97+
required = ClientTasksCapability(
98+
requests=ClientTasksRequestsCapability(
99+
sampling=TasksSamplingCapability(createMessage=TasksCreateMessageCapability())
100+
)
101+
)
102+
client = ClientTasksCapability(
103+
requests=ClientTasksRequestsCapability(
104+
sampling=TasksSamplingCapability(createMessage=TasksCreateMessageCapability())
105+
)
106+
)
107+
assert check_tasks_capability(required, client) is True
108+
109+
def test_both_elicitation_and_sampling_present(self) -> None:
110+
"""When both elicitation.create and sampling.createMessage are required and client has both."""
111+
required = ClientTasksCapability(
112+
requests=ClientTasksRequestsCapability(
113+
elicitation=TasksElicitationCapability(create=TasksCreateElicitationCapability()),
114+
sampling=TasksSamplingCapability(createMessage=TasksCreateMessageCapability()),
115+
)
116+
)
117+
client = ClientTasksCapability(
118+
requests=ClientTasksRequestsCapability(
119+
elicitation=TasksElicitationCapability(create=TasksCreateElicitationCapability()),
120+
sampling=TasksSamplingCapability(createMessage=TasksCreateMessageCapability()),
121+
)
122+
)
123+
assert check_tasks_capability(required, client) is True
124+
125+
def test_elicitation_without_create_required(self) -> None:
126+
"""When elicitation is required but not create specifically."""
127+
required = ClientTasksCapability(
128+
requests=ClientTasksRequestsCapability(
129+
elicitation=TasksElicitationCapability() # No create
130+
)
131+
)
132+
client = ClientTasksCapability(
133+
requests=ClientTasksRequestsCapability(
134+
elicitation=TasksElicitationCapability(create=TasksCreateElicitationCapability())
135+
)
136+
)
137+
assert check_tasks_capability(required, client) is True
138+
139+
def test_sampling_without_create_message_required(self) -> None:
140+
"""When sampling is required but not createMessage specifically."""
141+
required = ClientTasksCapability(
142+
requests=ClientTasksRequestsCapability(
143+
sampling=TasksSamplingCapability() # No createMessage
144+
)
145+
)
146+
client = ClientTasksCapability(
147+
requests=ClientTasksRequestsCapability(
148+
sampling=TasksSamplingCapability(createMessage=TasksCreateMessageCapability())
149+
)
150+
)
151+
assert check_tasks_capability(required, client) is True
152+
153+
154+
class TestHasTaskAugmentedElicitation:
155+
"""Tests for has_task_augmented_elicitation function."""
156+
157+
def test_tasks_none(self) -> None:
158+
"""Returns False when caps.tasks is None."""
159+
caps = ClientCapabilities()
160+
assert has_task_augmented_elicitation(caps) is False
161+
162+
def test_requests_none(self) -> None:
163+
"""Returns False when caps.tasks.requests is None."""
164+
caps = ClientCapabilities(tasks=ClientTasksCapability())
165+
assert has_task_augmented_elicitation(caps) is False
166+
167+
def test_elicitation_none(self) -> None:
168+
"""Returns False when caps.tasks.requests.elicitation is None."""
169+
caps = ClientCapabilities(tasks=ClientTasksCapability(requests=ClientTasksRequestsCapability()))
170+
assert has_task_augmented_elicitation(caps) is False
171+
172+
def test_create_none(self) -> None:
173+
"""Returns False when caps.tasks.requests.elicitation.create is None."""
174+
caps = ClientCapabilities(
175+
tasks=ClientTasksCapability(
176+
requests=ClientTasksRequestsCapability(elicitation=TasksElicitationCapability())
177+
)
178+
)
179+
assert has_task_augmented_elicitation(caps) is False
180+
181+
def test_create_present(self) -> None:
182+
"""Returns True when full capability path is present."""
183+
caps = ClientCapabilities(
184+
tasks=ClientTasksCapability(
185+
requests=ClientTasksRequestsCapability(
186+
elicitation=TasksElicitationCapability(create=TasksCreateElicitationCapability())
187+
)
188+
)
189+
)
190+
assert has_task_augmented_elicitation(caps) is True
191+
192+
193+
class TestHasTaskAugmentedSampling:
194+
"""Tests for has_task_augmented_sampling function."""
195+
196+
def test_tasks_none(self) -> None:
197+
"""Returns False when caps.tasks is None."""
198+
caps = ClientCapabilities()
199+
assert has_task_augmented_sampling(caps) is False
200+
201+
def test_requests_none(self) -> None:
202+
"""Returns False when caps.tasks.requests is None."""
203+
caps = ClientCapabilities(tasks=ClientTasksCapability())
204+
assert has_task_augmented_sampling(caps) is False
205+
206+
def test_sampling_none(self) -> None:
207+
"""Returns False when caps.tasks.requests.sampling is None."""
208+
caps = ClientCapabilities(tasks=ClientTasksCapability(requests=ClientTasksRequestsCapability()))
209+
assert has_task_augmented_sampling(caps) is False
210+
211+
def test_create_message_none(self) -> None:
212+
"""Returns False when caps.tasks.requests.sampling.createMessage is None."""
213+
caps = ClientCapabilities(
214+
tasks=ClientTasksCapability(requests=ClientTasksRequestsCapability(sampling=TasksSamplingCapability()))
215+
)
216+
assert has_task_augmented_sampling(caps) is False
217+
218+
def test_create_message_present(self) -> None:
219+
"""Returns True when full capability path is present."""
220+
caps = ClientCapabilities(
221+
tasks=ClientTasksCapability(
222+
requests=ClientTasksRequestsCapability(
223+
sampling=TasksSamplingCapability(createMessage=TasksCreateMessageCapability())
224+
)
225+
)
226+
)
227+
assert has_task_augmented_sampling(caps) is True
228+
229+
230+
class TestRequireTaskAugmentedElicitation:
231+
"""Tests for require_task_augmented_elicitation function."""
232+
233+
def test_raises_when_none(self) -> None:
234+
"""Raises McpError when client_caps is None."""
235+
with pytest.raises(McpError) as exc_info:
236+
require_task_augmented_elicitation(None)
237+
assert "task-augmented elicitation" in str(exc_info.value)
238+
239+
def test_raises_when_missing(self) -> None:
240+
"""Raises McpError when capability is missing."""
241+
caps = ClientCapabilities()
242+
with pytest.raises(McpError) as exc_info:
243+
require_task_augmented_elicitation(caps)
244+
assert "task-augmented elicitation" in str(exc_info.value)
245+
246+
def test_passes_when_present(self) -> None:
247+
"""Does not raise when capability is present."""
248+
caps = ClientCapabilities(
249+
tasks=ClientTasksCapability(
250+
requests=ClientTasksRequestsCapability(
251+
elicitation=TasksElicitationCapability(create=TasksCreateElicitationCapability())
252+
)
253+
)
254+
)
255+
require_task_augmented_elicitation(caps) # Should not raise
256+
257+
258+
class TestRequireTaskAugmentedSampling:
259+
"""Tests for require_task_augmented_sampling function."""
260+
261+
def test_raises_when_none(self) -> None:
262+
"""Raises McpError when client_caps is None."""
263+
with pytest.raises(McpError) as exc_info:
264+
require_task_augmented_sampling(None)
265+
assert "task-augmented sampling" in str(exc_info.value)
266+
267+
def test_raises_when_missing(self) -> None:
268+
"""Raises McpError when capability is missing."""
269+
caps = ClientCapabilities()
270+
with pytest.raises(McpError) as exc_info:
271+
require_task_augmented_sampling(caps)
272+
assert "task-augmented sampling" in str(exc_info.value)
273+
274+
def test_passes_when_present(self) -> None:
275+
"""Does not raise when capability is present."""
276+
caps = ClientCapabilities(
277+
tasks=ClientTasksCapability(
278+
requests=ClientTasksRequestsCapability(
279+
sampling=TasksSamplingCapability(createMessage=TasksCreateMessageCapability())
280+
)
281+
)
282+
)
283+
require_task_augmented_sampling(caps) # Should not raise

0 commit comments

Comments
 (0)