Skip to content

Commit 49543bb

Browse files
committed
Add coverage for remaining branch gaps
- Add pragma for defensive _meta checks in session.py (unreachable code) - Add tests for router loop continuation (non-matching routers) - Add tests for enable_tasks with custom store/queue - Add tests for skipping default handlers when custom registered
1 parent a118f98 commit 49543bb

File tree

3 files changed

+219
-2
lines changed

3 files changed

+219
-2
lines changed

src/mcp/server/session.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -536,7 +536,9 @@ def _build_elicit_request(
536536

537537
# Add related-task metadata if in task mode
538538
if task_id is not None:
539-
if "_meta" not in params_data:
539+
# Defensive check: _meta can't exist currently since ElicitRequestFormParams
540+
# doesn't pass meta to model_dump, but guard against future changes.
541+
if "_meta" not in params_data: # pragma: no cover
540542
params_data["_meta"] = {}
541543
params_data["_meta"]["io.modelcontextprotocol/related-task"] = {"taskId": task_id}
542544

@@ -594,7 +596,9 @@ def _build_create_message_request(
594596

595597
# Add related-task metadata if in task mode
596598
if task_id is not None:
597-
if "_meta" not in params_data:
599+
# Defensive check: _meta can't exist currently since CreateMessageRequestParams
600+
# doesn't pass meta to model_dump, but guard against future changes.
601+
if "_meta" not in params_data: # pragma: no cover
598602
params_data["_meta"] = {}
599603
params_data["_meta"]["io.modelcontextprotocol/related-task"] = {"taskId": task_id}
600604

tests/experimental/tasks/server/test_run_task_flow.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
These are integration tests that verify the complete flow works end-to-end.
1010
"""
1111

12+
from datetime import datetime, timezone
1213
from typing import Any
1314

1415
import anyio
@@ -214,6 +215,92 @@ async def test_enable_tasks_auto_registers_handlers() -> None:
214215
assert caps_after.tasks.cancel is not None
215216

216217

218+
@pytest.mark.anyio
219+
async def test_enable_tasks_with_custom_store_and_queue() -> None:
220+
"""Test that enable_tasks() uses provided store and queue instead of defaults."""
221+
from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore
222+
from mcp.shared.experimental.tasks.message_queue import InMemoryTaskMessageQueue
223+
224+
server = Server("test-custom-store-queue")
225+
226+
# Create custom store and queue
227+
custom_store = InMemoryTaskStore()
228+
custom_queue = InMemoryTaskMessageQueue()
229+
230+
# Enable tasks with custom implementations
231+
task_support = server.experimental.enable_tasks(store=custom_store, queue=custom_queue)
232+
233+
# Verify our custom implementations are used
234+
assert task_support.store is custom_store
235+
assert task_support.queue is custom_queue
236+
237+
238+
@pytest.mark.anyio
239+
async def test_enable_tasks_skips_default_handlers_when_custom_registered() -> None:
240+
"""Test that enable_tasks() doesn't override already-registered handlers."""
241+
from mcp.types import (
242+
CancelTaskRequest,
243+
CancelTaskResult,
244+
GetTaskPayloadRequest,
245+
GetTaskPayloadResult,
246+
GetTaskRequest,
247+
GetTaskResult,
248+
ListTasksRequest,
249+
ListTasksResult,
250+
)
251+
252+
server = Server("test-custom-handlers")
253+
254+
# Track which custom handlers were called
255+
custom_handlers_called: list[str] = []
256+
257+
# Use a fixed timestamp for deterministic tests
258+
fixed_time = datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc)
259+
260+
# Register custom handlers BEFORE enable_tasks
261+
@server.experimental.get_task()
262+
async def custom_get_task(req: GetTaskRequest) -> GetTaskResult:
263+
custom_handlers_called.append("get_task")
264+
return GetTaskResult(
265+
taskId="custom",
266+
status="working",
267+
createdAt=fixed_time,
268+
lastUpdatedAt=fixed_time,
269+
ttl=60000,
270+
)
271+
272+
@server.experimental.get_task_result()
273+
async def custom_get_task_result(req: GetTaskPayloadRequest) -> GetTaskPayloadResult:
274+
custom_handlers_called.append("get_task_result")
275+
return GetTaskPayloadResult()
276+
277+
@server.experimental.list_tasks()
278+
async def custom_list_tasks(req: ListTasksRequest) -> ListTasksResult:
279+
custom_handlers_called.append("list_tasks")
280+
return ListTasksResult(tasks=[])
281+
282+
@server.experimental.cancel_task()
283+
async def custom_cancel_task(req: CancelTaskRequest) -> CancelTaskResult:
284+
custom_handlers_called.append("cancel_task")
285+
return CancelTaskResult(
286+
taskId="custom",
287+
status="cancelled",
288+
createdAt=fixed_time,
289+
lastUpdatedAt=fixed_time,
290+
ttl=60000,
291+
)
292+
293+
# Now enable tasks - should NOT override our custom handlers
294+
server.experimental.enable_tasks()
295+
296+
# Verify our custom handlers are still registered (not replaced by defaults)
297+
# The handlers dict should contain our custom handlers
298+
assert GetTaskRequest in server.request_handlers
299+
assert GetTaskPayloadRequest in server.request_handlers
300+
assert ListTasksRequest in server.request_handlers
301+
assert CancelTaskRequest in server.request_handlers
302+
303+
217304
@pytest.mark.anyio
218305
async def test_run_task_without_enable_tasks_raises() -> None:
219306
"""Test that run_task raises when enable_tasks() wasn't called."""

tests/experimental/tasks/server/test_server.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -864,3 +864,129 @@ def route_error(self, request_id: str | int, error: ErrorData) -> bool:
864864
await server_to_client_receive.aclose()
865865
await client_to_server_send.aclose()
866866
await client_to_server_receive.aclose()
867+
868+
869+
@pytest.mark.anyio
870+
async def test_response_routing_skips_non_matching_routers() -> None:
871+
"""Test that routing continues to next router when first doesn't match."""
872+
from mcp.shared.session import ResponseRouter
873+
874+
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10)
875+
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10)
876+
877+
# Track which routers were called
878+
router_calls: list[str] = []
879+
response_received = anyio.Event()
880+
881+
class NonMatchingRouter(ResponseRouter):
882+
def route_response(self, request_id: str | int, response: dict[str, Any]) -> bool:
883+
router_calls.append("non_matching_response")
884+
return False # Doesn't handle it
885+
886+
def route_error(self, request_id: str | int, error: ErrorData) -> bool:
887+
router_calls.append("non_matching_error")
888+
return False # Doesn't handle it
889+
890+
class MatchingRouter(ResponseRouter):
891+
def route_response(self, request_id: str | int, response: dict[str, Any]) -> bool:
892+
router_calls.append("matching_response")
893+
response_received.set()
894+
return True # Handles it
895+
896+
def route_error(self, request_id: str | int, error: ErrorData) -> bool:
897+
router_calls.append("matching_error")
898+
response_received.set()
899+
return True # Handles it
900+
901+
try:
902+
async with ServerSession(
903+
client_to_server_receive,
904+
server_to_client_send,
905+
InitializationOptions(
906+
server_name="test-server",
907+
server_version="1.0.0",
908+
capabilities=ServerCapabilities(),
909+
),
910+
) as server_session:
911+
# Add non-matching router first, then matching router
912+
server_session.add_response_router(NonMatchingRouter())
913+
server_session.add_response_router(MatchingRouter())
914+
915+
# Send a response - should skip first router and be handled by second
916+
response = JSONRPCResponse(jsonrpc="2.0", id="test-req-1", result={"status": "ok"})
917+
message = SessionMessage(message=JSONRPCMessage(response))
918+
await client_to_server_send.send(message)
919+
920+
with anyio.fail_after(5):
921+
await response_received.wait()
922+
923+
# Verify both routers were called (first returned False, second returned True)
924+
assert router_calls == ["non_matching_response", "matching_response"]
925+
finally:
926+
await server_to_client_send.aclose()
927+
await server_to_client_receive.aclose()
928+
await client_to_server_send.aclose()
929+
await client_to_server_receive.aclose()
930+
931+
932+
@pytest.mark.anyio
933+
async def test_error_routing_skips_non_matching_routers() -> None:
934+
"""Test that error routing continues to next router when first doesn't match."""
935+
from mcp.shared.session import ResponseRouter
936+
937+
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10)
938+
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10)
939+
940+
# Track which routers were called
941+
router_calls: list[str] = []
942+
error_received = anyio.Event()
943+
944+
class NonMatchingRouter(ResponseRouter):
945+
def route_response(self, request_id: str | int, response: dict[str, Any]) -> bool:
946+
router_calls.append("non_matching_response")
947+
return False
948+
949+
def route_error(self, request_id: str | int, error: ErrorData) -> bool:
950+
router_calls.append("non_matching_error")
951+
return False # Doesn't handle it
952+
953+
class MatchingRouter(ResponseRouter):
954+
def route_response(self, request_id: str | int, response: dict[str, Any]) -> bool:
955+
router_calls.append("matching_response")
956+
return True
957+
958+
def route_error(self, request_id: str | int, error: ErrorData) -> bool:
959+
router_calls.append("matching_error")
960+
error_received.set()
961+
return True # Handles it
962+
963+
try:
964+
async with ServerSession(
965+
client_to_server_receive,
966+
server_to_client_send,
967+
InitializationOptions(
968+
server_name="test-server",
969+
server_version="1.0.0",
970+
capabilities=ServerCapabilities(),
971+
),
972+
) as server_session:
973+
# Add non-matching router first, then matching router
974+
server_session.add_response_router(NonMatchingRouter())
975+
server_session.add_response_router(MatchingRouter())
976+
977+
# Send an error - should skip first router and be handled by second
978+
error_data = ErrorData(code=-32600, message="Test error")
979+
error_response = JSONRPCError(jsonrpc="2.0", id="test-req-2", error=error_data)
980+
message = SessionMessage(message=JSONRPCMessage(error_response))
981+
await client_to_server_send.send(message)
982+
983+
with anyio.fail_after(5):
984+
await error_received.wait()
985+
986+
# Verify both routers were called (first returned False, second returned True)
987+
assert router_calls == ["non_matching_error", "matching_error"]
988+
finally:
989+
await server_to_client_send.aclose()
990+
await server_to_client_receive.aclose()
991+
await client_to_server_send.aclose()
992+
await client_to_server_receive.aclose()

0 commit comments

Comments
 (0)