Skip to content

Commit 1688c6a

Browse files
committed
Simplify catch-all case in client session match
Replace NotImplementedError with pass since task requests are handled earlier by _task_handlers. The catch-all satisfies pyright's exhaustiveness check while making it clear these cases are intentionally handled elsewhere.
1 parent 757df38 commit 1688c6a

File tree

8 files changed

+18
-126
lines changed

8 files changed

+18
-126
lines changed

src/mcp/client/experimental/task_handlers.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
from dataclasses import dataclass, field
1616
from typing import TYPE_CHECKING, Any, Protocol
1717

18+
from pydantic import TypeAdapter
19+
1820
import mcp.types as types
1921
from mcp.shared.context import RequestContext
2022
from mcp.shared.session import RequestResponder
@@ -111,11 +113,6 @@ async def __call__(
111113
) -> types.CreateTaskResult | types.ErrorData: ... # pragma: no branch
112114

113115

114-
# =============================================================================
115-
# Default Handlers (return "not supported" errors)
116-
# =============================================================================
117-
118-
119116
async def default_get_task_handler(
120117
context: RequestContext["ClientSession", Any],
121118
params: types.GetTaskRequestParams,
@@ -259,8 +256,6 @@ async def handle_request(
259256
260257
Call handles_request() first to check if this handler can handle the request.
261258
"""
262-
from pydantic import TypeAdapter
263-
264259
client_response_type: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(
265260
types.ClientResult | types.ErrorData
266261
)

src/mcp/client/session.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -582,8 +582,9 @@ async def _received_request(self, responder: RequestResponder[types.ServerReques
582582
with responder:
583583
return await responder.respond(types.ClientResult(root=types.EmptyResult()))
584584

585-
case _: # pragma: no cover
586-
raise NotImplementedError()
585+
case _:
586+
pass # Task requests handled above by _task_handlers
587+
587588
return None
588589

589590
async def _handle_incoming(

src/mcp/server/experimental/request_context.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -210,10 +210,8 @@ async def work(task: ServerTaskContext) -> CallToolResult:
210210
# Access task_group via TaskSupport - raises if not in run() context
211211
task_group = support.task_group
212212

213-
# Create the task
214213
task = await support.store.create_task(self.task_metadata, task_id)
215214

216-
# Build ServerTaskContext with full capabilities
217215
task_ctx = ServerTaskContext(
218216
task=task,
219217
store=support.store,
@@ -222,21 +220,17 @@ async def work(task: ServerTaskContext) -> CallToolResult:
222220
handler=support.handler,
223221
)
224222

225-
# Spawn the work
226223
async def execute() -> None:
227224
try:
228225
result = await work(task_ctx)
229-
# Auto-complete if work returns successfully and not already terminal
230226
if not is_terminal(task_ctx.task.status):
231227
await task_ctx.complete(result)
232228
except Exception as e:
233-
# Auto-fail if not already terminal
234229
if not is_terminal(task_ctx.task.status):
235230
await task_ctx.fail(str(e))
236231

237232
task_group.start_soon(execute)
238233

239-
# Build _meta if model_immediate_response is provided
240234
meta: dict[str, Any] | None = None
241235
if model_immediate_response is not None:
242236
meta = {MODEL_IMMEDIATE_RESPONSE_KEY: model_immediate_response}

src/mcp/server/experimental/task_context.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -241,11 +241,9 @@ async def elicit(
241241
)
242242
request_id: RequestId = request.id
243243

244-
# Create resolver and register with handler for response routing
245244
resolver: Resolver[dict[str, Any]] = Resolver()
246245
self._handler._pending_requests[request_id] = resolver # pyright: ignore[reportPrivateUsage]
247246

248-
# Queue the request
249247
queued = QueuedMessage(
250248
type="request",
251249
message=request,
@@ -315,11 +313,9 @@ async def elicit_url(
315313
)
316314
request_id: RequestId = request.id
317315

318-
# Create resolver and register with handler for response routing
319316
resolver: Resolver[dict[str, Any]] = Resolver()
320317
self._handler._pending_requests[request_id] = resolver # pyright: ignore[reportPrivateUsage]
321318

322-
# Queue the request
323319
queued = QueuedMessage(
324320
type="request",
325321
message=request,
@@ -408,11 +404,9 @@ async def create_message(
408404
)
409405
request_id: RequestId = request.id
410406

411-
# Create resolver and register with handler for response routing
412407
resolver: Resolver[dict[str, Any]] = Resolver()
413408
self._handler._pending_requests[request_id] = resolver # pyright: ignore[reportPrivateUsage]
414409

415-
# Queue the request
416410
queued = QueuedMessage(
417411
type="request",
418412
message=request,
@@ -469,7 +463,6 @@ async def elicit_as_task(
469463
# Update status to input_required
470464
await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED)
471465

472-
# Build request WITH task field for task-augmented elicitation
473466
request = self._session._build_elicit_form_request( # pyright: ignore[reportPrivateUsage]
474467
message=message,
475468
requestedSchema=requestedSchema,
@@ -478,11 +471,9 @@ async def elicit_as_task(
478471
)
479472
request_id: RequestId = request.id
480473

481-
# Create resolver and register with handler for response routing
482474
resolver: Resolver[dict[str, Any]] = Resolver()
483475
self._handler._pending_requests[request_id] = resolver # pyright: ignore[reportPrivateUsage]
484476

485-
# Queue the request
486477
queued = QueuedMessage(
487478
type="request",
488479
message=request,
@@ -586,11 +577,9 @@ async def create_message_as_task(
586577
)
587578
request_id: RequestId = request.id
588579

589-
# Create resolver and register with handler for response routing
590580
resolver: Resolver[dict[str, Any]] = Resolver()
591581
self._handler._pending_requests[request_id] = resolver # pyright: ignore[reportPrivateUsage]
592582

593-
# Queue the request
594583
queued = QueuedMessage(
595584
type="request",
596585
message=request,

src/mcp/server/experimental/task_result_handler.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,6 @@ async def handle(
109109
task_id = request.params.taskId
110110

111111
while True:
112-
# Get fresh task state each iteration
113112
task = await self._store.get_task(task_id)
114113
if task is None:
115114
raise McpError(
@@ -119,7 +118,6 @@ async def handle(
119118
)
120119
)
121120

122-
# Dequeue and send all pending messages
123121
await self._deliver_queued_messages(task_id, session, request_id)
124122

125123
# If task is terminal, return result
@@ -131,9 +129,7 @@ async def handle(
131129
related_task = RelatedTaskMetadata(taskId=task_id)
132130
related_task_meta: dict[str, Any] = {RELATED_TASK_METADATA_KEY: related_task.model_dump(by_alias=True)}
133131
if result is not None:
134-
# Copy result fields and add required metadata
135132
result_data = result.model_dump(by_alias=True)
136-
# Merge with existing _meta if present
137133
existing_meta: dict[str, Any] = result_data.get("_meta") or {}
138134
result_data["_meta"] = {**existing_meta, **related_task_meta}
139135
return GetTaskPayloadResult.model_validate(result_data)

src/mcp/server/session.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]:
5050
from mcp.server.models import InitializationOptions
5151
from mcp.server.validation import validate_sampling_tools, validate_tool_use_result_messages
5252
from mcp.shared.experimental.tasks.capabilities import check_tasks_capability
53+
from mcp.shared.experimental.tasks.helpers import RELATED_TASK_METADATA_KEY
5354
from mcp.shared.message import ServerMessageMetadata, SessionMessage
5455
from mcp.shared.response_router import ResponseRouter
5556
from mcp.shared.session import (
@@ -520,7 +521,9 @@ def _build_elicit_form_request(
520521
# Defensive: model_dump() never includes _meta, but guard against future changes
521522
if "_meta" not in params_data: # pragma: no cover
522523
params_data["_meta"] = {}
523-
params_data["_meta"]["io.modelcontextprotocol/related-task"] = {"taskId": related_task_id}
524+
params_data["_meta"][RELATED_TASK_METADATA_KEY] = types.RelatedTaskMetadata(
525+
taskId=related_task_id
526+
).model_dump(by_alias=True)
524527

525528
request_id = f"task-{related_task_id}-{id(params)}" if related_task_id else self._request_id
526529
if related_task_id is None:
@@ -563,7 +566,9 @@ def _build_elicit_url_request(
563566
# Defensive: model_dump() never includes _meta, but guard against future changes
564567
if "_meta" not in params_data: # pragma: no cover
565568
params_data["_meta"] = {}
566-
params_data["_meta"]["io.modelcontextprotocol/related-task"] = {"taskId": related_task_id}
569+
params_data["_meta"][RELATED_TASK_METADATA_KEY] = types.RelatedTaskMetadata(
570+
taskId=related_task_id
571+
).model_dump(by_alias=True)
567572

568573
request_id = f"task-{related_task_id}-{id(params)}" if related_task_id else self._request_id
569574
if related_task_id is None:
@@ -631,7 +636,9 @@ def _build_create_message_request(
631636
# Defensive: model_dump() never includes _meta, but guard against future changes
632637
if "_meta" not in params_data: # pragma: no cover
633638
params_data["_meta"] = {}
634-
params_data["_meta"]["io.modelcontextprotocol/related-task"] = {"taskId": related_task_id}
639+
params_data["_meta"][RELATED_TASK_METADATA_KEY] = types.RelatedTaskMetadata(
640+
taskId=related_task_id
641+
).model_dump(by_alias=True)
635642

636643
request_id = f"task-{related_task_id}-{id(params)}" if related_task_id else self._request_id
637644
if related_task_id is None:

tests/experimental/tasks/test_request_context.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@
1616
ToolExecution,
1717
)
1818

19-
# --- Experimental.is_task ---
20-
2119

2220
def test_is_task_true_when_metadata_present() -> None:
2321
exp = Experimental(task_metadata=TaskMetadata(ttl=60000))
@@ -29,9 +27,6 @@ def test_is_task_false_when_no_metadata() -> None:
2927
assert exp.is_task is False
3028

3129

32-
# --- Experimental.client_supports_tasks ---
33-
34-
3530
def test_client_supports_tasks_true() -> None:
3631
exp = Experimental(_client_capabilities=ClientCapabilities(tasks=ClientTasksCapability()))
3732
assert exp.client_supports_tasks is True
@@ -47,9 +42,6 @@ def test_client_supports_tasks_false_no_capabilities() -> None:
4742
assert exp.client_supports_tasks is False
4843

4944

50-
# --- Experimental.validate_task_mode ---
51-
52-
5345
def test_validate_task_mode_required_with_task_is_valid() -> None:
5446
exp = Experimental(task_metadata=TaskMetadata(ttl=60000))
5547
error = exp.validate_task_mode(TASK_REQUIRED, raise_error=False)
@@ -111,9 +103,6 @@ def test_validate_task_mode_optional_without_task_is_valid() -> None:
111103
assert error is None
112104

113105

114-
# --- Experimental.validate_for_tool ---
115-
116-
117106
def test_validate_for_tool_with_execution_required() -> None:
118107
exp = Experimental(task_metadata=None)
119108
tool = Tool(
@@ -152,9 +141,6 @@ def test_validate_for_tool_optional_with_task() -> None:
152141
assert error is None
153142

154143

155-
# --- Experimental.can_use_tool ---
156-
157-
158144
def test_can_use_tool_required_with_task_support() -> None:
159145
exp = Experimental(_client_capabilities=ClientCapabilities(tasks=ClientTasksCapability()))
160146
assert exp.can_use_tool(TASK_REQUIRED) is True

0 commit comments

Comments
 (0)