Skip to content

Commit b184785

Browse files
committed
Add coverage tests and fix gaps
Coverage improvements: - Add test_server_task_context.py with tests for ServerTaskContext - Add tests for default task handlers (list, cancel, result, augmented sampling/elicitation) - Add test for task-augmented handler capability building - Add tests for elicit/create_message flows including cancellation - Add test for meta parameter in call_tool_as_task - Add test for model_immediate_response parameter - Add tests for run_task error cases Code changes: - Remove backwards compat task_id parameter from ServerTaskContext (require task directly) - Remove type: ignore from request_context.py - Add pragma: no cover for async exception handlers (coverage limitation with async)
1 parent 2f3b792 commit b184785

File tree

6 files changed

+1126
-21
lines changed

6 files changed

+1126
-21
lines changed

src/mcp/server/experimental/request_context.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ async def work(task: ServerTaskContext) -> CallToolResult:
220220
session=self._session,
221221
queue=support.queue,
222222
handler=support.handler,
223-
) # type: ignore[call-arg]
223+
)
224224

225225
# Spawn the work
226226
async def execute() -> None:

src/mcp/server/experimental/task_context.py

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from mcp.server.session import ServerSession
1616
from mcp.shared.exceptions import McpError
1717
from mcp.shared.experimental.tasks.context import TaskContext
18-
from mcp.shared.experimental.tasks.helpers import create_task_state
1918
from mcp.shared.experimental.tasks.message_queue import QueuedMessage, TaskMessageQueue
2019
from mcp.shared.experimental.tasks.resolver import Resolver
2120
from mcp.shared.experimental.tasks.store import TaskStore
@@ -37,7 +36,6 @@
3736
SamplingMessage,
3837
ServerNotification,
3938
Task,
40-
TaskMetadata,
4139
TaskStatusNotification,
4240
TaskStatusNotificationParams,
4341
)
@@ -70,8 +68,7 @@ async def my_task_work(task: ServerTaskContext) -> CallToolResult:
7068
def __init__(
7169
self,
7270
*,
73-
task: Task | None = None,
74-
task_id: str | None = None,
71+
task: Task,
7572
store: TaskStore,
7673
session: ServerSession,
7774
queue: TaskMessageQueue,
@@ -81,23 +78,12 @@ def __init__(
8178
Create a ServerTaskContext.
8279
8380
Args:
84-
task: The Task object (provide either task or task_id)
85-
task_id: The task ID to look up (provide either task or task_id)
81+
task: The Task object
8682
store: The task store
8783
session: The server session
8884
queue: The message queue for elicitation/sampling
8985
handler: The result handler for response routing (required for elicit/create_message)
9086
"""
91-
if task is None and task_id is None:
92-
raise ValueError("Must provide either task or task_id")
93-
if task is not None and task_id is not None:
94-
raise ValueError("Provide either task or task_id, not both")
95-
96-
# If task_id provided, create a minimal task object
97-
# This is for backwards compatibility with tests that pass task_id
98-
if task is None:
99-
task = create_task_state(TaskMetadata(ttl=None), task_id=task_id)
100-
10187
self._ctx = TaskContext(task=task, store=store)
10288
self._session = session
10389
self._queue = queue
@@ -264,7 +250,10 @@ async def elicit(
264250
response_data = await resolver.wait()
265251
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
266252
return ElicitResult.model_validate(response_data)
267-
except anyio.get_cancelled_exc_class():
253+
except anyio.get_cancelled_exc_class(): # pragma: no cover
254+
# Coverage can't track async exception handlers reliably.
255+
# This path is tested in test_elicit_restores_status_on_cancellation
256+
# which verifies status is restored to "working" after cancellation.
268257
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
269258
raise
270259

@@ -347,6 +336,9 @@ async def create_message(
347336
response_data = await resolver.wait()
348337
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
349338
return CreateMessageResult.model_validate(response_data)
350-
except anyio.get_cancelled_exc_class():
339+
except anyio.get_cancelled_exc_class(): # pragma: no cover
340+
# Coverage can't track async exception handlers reliably.
341+
# This path is tested in test_create_message_restores_status_on_cancellation
342+
# which verifies status is restored to "working" after cancellation.
351343
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
352344
raise

tests/experimental/tasks/client/test_capabilities.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,3 +251,81 @@ async def mock_server():
251251
assert received_capabilities.tasks.cancel is not None
252252
# requests should be None since we didn't provide task-augmented handlers
253253
assert received_capabilities.tasks.requests is None
254+
255+
256+
@pytest.mark.anyio
257+
async def test_client_capabilities_with_task_augmented_handlers():
258+
"""Test that requests capability is built when augmented handlers are provided."""
259+
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
260+
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
261+
262+
received_capabilities: ClientCapabilities | None = None
263+
264+
# Define task-augmented handler
265+
async def my_augmented_sampling_handler(
266+
context: RequestContext[ClientSession, None],
267+
params: types.CreateMessageRequestParams,
268+
task_metadata: types.TaskMetadata,
269+
) -> types.CreateTaskResult | types.ErrorData:
270+
return types.ErrorData(code=types.INVALID_REQUEST, message="Not implemented")
271+
272+
async def mock_server():
273+
nonlocal received_capabilities
274+
275+
session_message = await client_to_server_receive.receive()
276+
jsonrpc_request = session_message.message
277+
assert isinstance(jsonrpc_request.root, JSONRPCRequest)
278+
request = ClientRequest.model_validate(
279+
jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True)
280+
)
281+
assert isinstance(request.root, InitializeRequest)
282+
received_capabilities = request.root.params.capabilities
283+
284+
result = ServerResult(
285+
InitializeResult(
286+
protocolVersion=LATEST_PROTOCOL_VERSION,
287+
capabilities=ServerCapabilities(),
288+
serverInfo=Implementation(name="mock-server", version="0.1.0"),
289+
)
290+
)
291+
292+
async with server_to_client_send:
293+
await server_to_client_send.send(
294+
SessionMessage(
295+
JSONRPCMessage(
296+
JSONRPCResponse(
297+
jsonrpc="2.0",
298+
id=jsonrpc_request.root.id,
299+
result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
300+
)
301+
)
302+
)
303+
)
304+
await client_to_server_receive.receive()
305+
306+
# Provide task-augmented sampling handler
307+
task_handlers = ExperimentalTaskHandlers(
308+
augmented_sampling=my_augmented_sampling_handler,
309+
)
310+
311+
async with (
312+
ClientSession(
313+
server_to_client_receive,
314+
client_to_server_send,
315+
experimental_task_handlers=task_handlers,
316+
) as session,
317+
anyio.create_task_group() as tg,
318+
client_to_server_send,
319+
client_to_server_receive,
320+
server_to_client_send,
321+
server_to_client_receive,
322+
):
323+
tg.start_soon(mock_server)
324+
await session.initialize()
325+
326+
# Assert that tasks capability includes requests.sampling
327+
assert received_capabilities is not None
328+
assert received_capabilities.tasks is not None
329+
assert received_capabilities.tasks.requests is not None
330+
assert received_capabilities.tasks.requests.sampling is not None
331+
assert received_capabilities.tasks.requests.elicitation is None # Not provided

tests/experimental/tasks/client/test_handlers.py

Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -686,3 +686,201 @@ async def run_client() -> None:
686686
)
687687

688688
tg.cancel_scope.cancel()
689+
690+
691+
@pytest.mark.anyio
692+
async def test_client_returns_error_for_unhandled_task_result_request(client_streams: ClientTestStreams) -> None:
693+
"""Test that client returns error for unhandled tasks/result request."""
694+
with anyio.fail_after(10):
695+
client_ready = anyio.Event()
696+
697+
async with anyio.create_task_group() as tg:
698+
699+
async def run_client() -> None:
700+
async with ClientSession(
701+
client_streams.client_receive,
702+
client_streams.client_send,
703+
message_handler=_default_message_handler,
704+
):
705+
client_ready.set()
706+
await anyio.sleep_forever()
707+
708+
tg.start_soon(run_client)
709+
await client_ready.wait()
710+
711+
typed_request = GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(taskId="nonexistent"))
712+
request = types.JSONRPCRequest(
713+
jsonrpc="2.0",
714+
id="req-result",
715+
**typed_request.model_dump(by_alias=True),
716+
)
717+
await client_streams.server_send.send(SessionMessage(types.JSONRPCMessage(request)))
718+
719+
response_msg = await client_streams.server_receive.receive()
720+
response = response_msg.message.root
721+
assert isinstance(response, types.JSONRPCError)
722+
assert "not supported" in response.error.message.lower()
723+
724+
tg.cancel_scope.cancel()
725+
726+
727+
@pytest.mark.anyio
728+
async def test_client_returns_error_for_unhandled_list_tasks_request(client_streams: ClientTestStreams) -> None:
729+
"""Test that client returns error for unhandled tasks/list request."""
730+
with anyio.fail_after(10):
731+
client_ready = anyio.Event()
732+
733+
async with anyio.create_task_group() as tg:
734+
735+
async def run_client() -> None:
736+
async with ClientSession(
737+
client_streams.client_receive,
738+
client_streams.client_send,
739+
message_handler=_default_message_handler,
740+
):
741+
client_ready.set()
742+
await anyio.sleep_forever()
743+
744+
tg.start_soon(run_client)
745+
await client_ready.wait()
746+
747+
typed_request = ListTasksRequest()
748+
request = types.JSONRPCRequest(
749+
jsonrpc="2.0",
750+
id="req-list",
751+
**typed_request.model_dump(by_alias=True),
752+
)
753+
await client_streams.server_send.send(SessionMessage(types.JSONRPCMessage(request)))
754+
755+
response_msg = await client_streams.server_receive.receive()
756+
response = response_msg.message.root
757+
assert isinstance(response, types.JSONRPCError)
758+
assert "not supported" in response.error.message.lower()
759+
760+
tg.cancel_scope.cancel()
761+
762+
763+
@pytest.mark.anyio
764+
async def test_client_returns_error_for_unhandled_cancel_task_request(client_streams: ClientTestStreams) -> None:
765+
"""Test that client returns error for unhandled tasks/cancel request."""
766+
with anyio.fail_after(10):
767+
client_ready = anyio.Event()
768+
769+
async with anyio.create_task_group() as tg:
770+
771+
async def run_client() -> None:
772+
async with ClientSession(
773+
client_streams.client_receive,
774+
client_streams.client_send,
775+
message_handler=_default_message_handler,
776+
):
777+
client_ready.set()
778+
await anyio.sleep_forever()
779+
780+
tg.start_soon(run_client)
781+
await client_ready.wait()
782+
783+
typed_request = CancelTaskRequest(params=CancelTaskRequestParams(taskId="nonexistent"))
784+
request = types.JSONRPCRequest(
785+
jsonrpc="2.0",
786+
id="req-cancel",
787+
**typed_request.model_dump(by_alias=True),
788+
)
789+
await client_streams.server_send.send(SessionMessage(types.JSONRPCMessage(request)))
790+
791+
response_msg = await client_streams.server_receive.receive()
792+
response = response_msg.message.root
793+
assert isinstance(response, types.JSONRPCError)
794+
assert "not supported" in response.error.message.lower()
795+
796+
tg.cancel_scope.cancel()
797+
798+
799+
@pytest.mark.anyio
800+
async def test_client_returns_error_for_unhandled_task_augmented_sampling(client_streams: ClientTestStreams) -> None:
801+
"""Test that client returns error for task-augmented sampling without handler."""
802+
with anyio.fail_after(10):
803+
client_ready = anyio.Event()
804+
805+
async with anyio.create_task_group() as tg:
806+
807+
async def run_client() -> None:
808+
# No task handlers provided - uses defaults
809+
async with ClientSession(
810+
client_streams.client_receive,
811+
client_streams.client_send,
812+
message_handler=_default_message_handler,
813+
):
814+
client_ready.set()
815+
await anyio.sleep_forever()
816+
817+
tg.start_soon(run_client)
818+
await client_ready.wait()
819+
820+
# Send task-augmented sampling request
821+
typed_request = CreateMessageRequest(
822+
params=CreateMessageRequestParams(
823+
messages=[SamplingMessage(role="user", content=TextContent(type="text", text="Hello"))],
824+
maxTokens=100,
825+
task=TaskMetadata(ttl=60000),
826+
)
827+
)
828+
request = types.JSONRPCRequest(
829+
jsonrpc="2.0",
830+
id="req-sampling",
831+
**typed_request.model_dump(by_alias=True),
832+
)
833+
await client_streams.server_send.send(SessionMessage(types.JSONRPCMessage(request)))
834+
835+
response_msg = await client_streams.server_receive.receive()
836+
response = response_msg.message.root
837+
assert isinstance(response, types.JSONRPCError)
838+
assert "not supported" in response.error.message.lower()
839+
840+
tg.cancel_scope.cancel()
841+
842+
843+
@pytest.mark.anyio
844+
async def test_client_returns_error_for_unhandled_task_augmented_elicitation(
845+
client_streams: ClientTestStreams,
846+
) -> None:
847+
"""Test that client returns error for task-augmented elicitation without handler."""
848+
with anyio.fail_after(10):
849+
client_ready = anyio.Event()
850+
851+
async with anyio.create_task_group() as tg:
852+
853+
async def run_client() -> None:
854+
# No task handlers provided - uses defaults
855+
async with ClientSession(
856+
client_streams.client_receive,
857+
client_streams.client_send,
858+
message_handler=_default_message_handler,
859+
):
860+
client_ready.set()
861+
await anyio.sleep_forever()
862+
863+
tg.start_soon(run_client)
864+
await client_ready.wait()
865+
866+
# Send task-augmented elicitation request
867+
typed_request = ElicitRequest(
868+
params=ElicitRequestFormParams(
869+
message="What is your name?",
870+
requestedSchema={"type": "object", "properties": {"name": {"type": "string"}}},
871+
task=TaskMetadata(ttl=60000),
872+
)
873+
)
874+
request = types.JSONRPCRequest(
875+
jsonrpc="2.0",
876+
id="req-elicit",
877+
**typed_request.model_dump(by_alias=True),
878+
)
879+
await client_streams.server_send.send(SessionMessage(types.JSONRPCMessage(request)))
880+
881+
response_msg = await client_streams.server_receive.receive()
882+
response = response_msg.message.root
883+
assert isinstance(response, types.JSONRPCError)
884+
assert "not supported" in response.error.message.lower()
885+
886+
tg.cancel_scope.cancel()

0 commit comments

Comments
 (0)