Skip to content

Commit 8cd2765

Browse files
committed
Refactor tasks capability checking into isolated module
Move all task-related capability checking logic into mcp/shared/experimental/tasks/capabilities.py to keep tasks code isolated from core session code. Changes: - Create capabilities.py with check_tasks_capability() and require_* helpers - Update ServerSession to import and use the shared function - Update ServerTaskContext to use require_* helpers instead of inline checks - Add missing capability checks to ExperimentalServerSessionFeatures This improves code organization and fixes a bug where session.experimental.elicit_as_task() wasn't checking capabilities.
1 parent 1efe8b0 commit 8cd2765

File tree

4 files changed

+141
-69
lines changed

4 files changed

+141
-69
lines changed

src/mcp/server/experimental/session_features.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@
1111
from typing import TYPE_CHECKING, Any, TypeVar
1212

1313
import mcp.types as types
14+
from mcp.shared.experimental.tasks.capabilities import (
15+
require_task_augmented_elicitation,
16+
require_task_augmented_sampling,
17+
)
1418
from mcp.shared.experimental.tasks.polling import poll_until_terminal
1519

1620
if TYPE_CHECKING:
@@ -113,7 +117,13 @@ async def elicit_as_task(
113117
114118
Returns:
115119
The client's elicitation response
120+
121+
Raises:
122+
McpError: If client doesn't support task-augmented elicitation
116123
"""
124+
client_caps = self._session.client_params.capabilities if self._session.client_params else None
125+
require_task_augmented_elicitation(client_caps)
126+
117127
create_result = await self._session.send_request(
118128
types.ServerRequest(
119129
types.ElicitRequest(
@@ -166,7 +176,13 @@ async def create_message_as_task(
166176
167177
Returns:
168178
The sampling result from the client
179+
180+
Raises:
181+
McpError: If client doesn't support task-augmented sampling
169182
"""
183+
client_caps = self._session.client_params.capabilities if self._session.client_params else None
184+
require_task_augmented_sampling(client_caps)
185+
170186
create_result = await self._session.send_request(
171187
types.ServerRequest(
172188
types.CreateMessageRequest(

src/mcp/server/experimental/task_context.py

Lines changed: 8 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@
1414
from mcp.server.experimental.task_result_handler import TaskResultHandler
1515
from mcp.server.session import ServerSession
1616
from mcp.shared.exceptions import McpError
17+
from mcp.shared.experimental.tasks.capabilities import (
18+
require_task_augmented_elicitation,
19+
require_task_augmented_sampling,
20+
)
1721
from mcp.shared.experimental.tasks.context import TaskContext
1822
from mcp.shared.experimental.tasks.message_queue import QueuedMessage, TaskMessageQueue
1923
from mcp.shared.experimental.tasks.resolver import Resolver
@@ -23,8 +27,6 @@
2327
TASK_STATUS_INPUT_REQUIRED,
2428
TASK_STATUS_WORKING,
2529
ClientCapabilities,
26-
ClientTasksCapability,
27-
ClientTasksRequestsCapability,
2830
CreateMessageResult,
2931
CreateTaskResult,
3032
ElicitationCapability,
@@ -40,10 +42,6 @@
4042
ServerNotification,
4143
Task,
4244
TaskMetadata,
43-
TasksCreateElicitationCapability,
44-
TasksCreateMessageCapability,
45-
TasksElicitationCapability,
46-
TasksSamplingCapability,
4745
TaskStatusNotification,
4846
TaskStatusNotificationParams,
4947
)
@@ -198,40 +196,6 @@ def _check_sampling_capability(self) -> None:
198196
)
199197
)
200198

201-
def _check_task_augmented_elicitation_capability(self) -> None:
202-
"""Check if the client supports task-augmented elicitation."""
203-
capability = ClientCapabilities(
204-
tasks=ClientTasksCapability(
205-
requests=ClientTasksRequestsCapability(
206-
elicitation=TasksElicitationCapability(create=TasksCreateElicitationCapability())
207-
)
208-
)
209-
)
210-
if not self._session.check_client_capability(capability):
211-
raise McpError(
212-
ErrorData(
213-
code=INVALID_REQUEST,
214-
message="Client does not support task-augmented elicitation capability",
215-
)
216-
)
217-
218-
def _check_task_augmented_sampling_capability(self) -> None:
219-
"""Check if the client supports task-augmented sampling."""
220-
capability = ClientCapabilities(
221-
tasks=ClientTasksCapability(
222-
requests=ClientTasksRequestsCapability(
223-
sampling=TasksSamplingCapability(createMessage=TasksCreateMessageCapability())
224-
)
225-
)
226-
)
227-
if not self._session.check_client_capability(capability):
228-
raise McpError(
229-
ErrorData(
230-
code=INVALID_REQUEST,
231-
message="Client does not support task-augmented sampling capability",
232-
)
233-
)
234-
235199
async def elicit(
236200
self,
237201
message: str,
@@ -412,7 +376,8 @@ async def elicit_as_task(
412376
McpError: If client doesn't support task-augmented elicitation
413377
RuntimeError: If handler is not configured
414378
"""
415-
self._check_task_augmented_elicitation_capability()
379+
client_caps = self._session.client_params.capabilities if self._session.client_params else None
380+
require_task_augmented_elicitation(client_caps)
416381

417382
if self._handler is None:
418383
raise RuntimeError("handler is required for elicit_as_task()")
@@ -504,7 +469,8 @@ async def create_message_as_task(
504469
McpError: If client doesn't support task-augmented sampling
505470
RuntimeError: If handler is not configured
506471
"""
507-
self._check_task_augmented_sampling_capability()
472+
client_caps = self._session.client_params.capabilities if self._session.client_params else None
473+
require_task_augmented_sampling(client_caps)
508474

509475
if self._handler is None:
510476
raise RuntimeError("handler is required for create_message_as_task()")

src/mcp/server/session.py

Lines changed: 2 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]:
4949
from mcp.server.experimental.session_features import ExperimentalServerSessionFeatures
5050
from mcp.server.models import InitializationOptions
5151
from mcp.shared.exceptions import McpError
52+
from mcp.shared.experimental.tasks.capabilities import check_tasks_capability
5253
from mcp.shared.message import ServerMessageMetadata, SessionMessage
5354
from mcp.shared.response_router import ResponseRouter
5455
from mcp.shared.session import (
@@ -116,32 +117,6 @@ def experimental(self) -> ExperimentalServerSessionFeatures:
116117
self._experimental_features = ExperimentalServerSessionFeatures(self)
117118
return self._experimental_features
118119

119-
def _check_tasks_capability(
120-
self,
121-
required: types.ClientTasksCapability,
122-
client: types.ClientTasksCapability,
123-
) -> bool: # pragma: no cover
124-
"""Check if client's tasks capability matches the required capability."""
125-
if required.requests is None:
126-
return True
127-
if client.requests is None:
128-
return False
129-
# Check elicitation.create
130-
if required.requests.elicitation is not None:
131-
if client.requests.elicitation is None:
132-
return False
133-
if required.requests.elicitation.create is not None:
134-
if client.requests.elicitation.create is None:
135-
return False
136-
# Check sampling.createMessage
137-
if required.requests.sampling is not None:
138-
if client.requests.sampling is None:
139-
return False
140-
if required.requests.sampling.createMessage is not None:
141-
if client.requests.sampling.createMessage is None:
142-
return False
143-
return True
144-
145120
def check_client_capability(self, capability: types.ClientCapabilities) -> bool: # pragma: no cover
146121
"""Check if the client supports a specific capability."""
147122
if self._client_params is None:
@@ -176,7 +151,7 @@ def check_client_capability(self, capability: types.ClientCapabilities) -> bool:
176151
if capability.tasks is not None:
177152
if client_caps.tasks is None:
178153
return False
179-
if not self._check_tasks_capability(capability.tasks, client_caps.tasks):
154+
if not check_tasks_capability(capability.tasks, client_caps.tasks):
180155
return False
181156

182157
return True
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
"""
2+
Tasks capability checking utilities.
3+
4+
This module provides functions for checking and requiring task-related
5+
capabilities. All tasks capability logic is centralized here to keep
6+
the main session code clean.
7+
8+
WARNING: These APIs are experimental and may change without notice.
9+
"""
10+
11+
from mcp.shared.exceptions import McpError
12+
from mcp.types import (
13+
INVALID_REQUEST,
14+
ClientCapabilities,
15+
ClientTasksCapability,
16+
ErrorData,
17+
)
18+
19+
20+
def check_tasks_capability(
21+
required: ClientTasksCapability,
22+
client: ClientTasksCapability,
23+
) -> bool:
24+
"""
25+
Check if client's tasks capability matches the required capability.
26+
27+
Args:
28+
required: The capability being checked for
29+
client: The client's declared capabilities
30+
31+
Returns:
32+
True if client has the required capability, False otherwise
33+
"""
34+
if required.requests is None:
35+
return True
36+
if client.requests is None:
37+
return False
38+
39+
# Check elicitation.create
40+
if required.requests.elicitation is not None:
41+
if client.requests.elicitation is None:
42+
return False
43+
if required.requests.elicitation.create is not None:
44+
if client.requests.elicitation.create is None:
45+
return False
46+
47+
# Check sampling.createMessage
48+
if required.requests.sampling is not None:
49+
if client.requests.sampling is None:
50+
return False
51+
if required.requests.sampling.createMessage is not None:
52+
if client.requests.sampling.createMessage is None:
53+
return False
54+
55+
return True
56+
57+
58+
def has_task_augmented_elicitation(caps: ClientCapabilities) -> bool:
59+
"""Check if capabilities include task-augmented elicitation support."""
60+
if caps.tasks is None:
61+
return False
62+
if caps.tasks.requests is None:
63+
return False
64+
if caps.tasks.requests.elicitation is None:
65+
return False
66+
return caps.tasks.requests.elicitation.create is not None
67+
68+
69+
def has_task_augmented_sampling(caps: ClientCapabilities) -> bool:
70+
"""Check if capabilities include task-augmented sampling support."""
71+
if caps.tasks is None:
72+
return False
73+
if caps.tasks.requests is None:
74+
return False
75+
if caps.tasks.requests.sampling is None:
76+
return False
77+
return caps.tasks.requests.sampling.createMessage is not None
78+
79+
80+
def require_task_augmented_elicitation(client_caps: ClientCapabilities | None) -> None:
81+
"""
82+
Raise McpError if client doesn't support task-augmented elicitation.
83+
84+
Args:
85+
client_caps: The client's declared capabilities, or None if not initialized
86+
87+
Raises:
88+
McpError: If client doesn't support task-augmented elicitation
89+
"""
90+
if client_caps is None or not has_task_augmented_elicitation(client_caps):
91+
raise McpError(
92+
ErrorData(
93+
code=INVALID_REQUEST,
94+
message="Client does not support task-augmented elicitation",
95+
)
96+
)
97+
98+
99+
def require_task_augmented_sampling(client_caps: ClientCapabilities | None) -> None:
100+
"""
101+
Raise McpError if client doesn't support task-augmented sampling.
102+
103+
Args:
104+
client_caps: The client's declared capabilities, or None if not initialized
105+
106+
Raises:
107+
McpError: If client doesn't support task-augmented sampling
108+
"""
109+
if client_caps is None or not has_task_augmented_sampling(client_caps):
110+
raise McpError(
111+
ErrorData(
112+
code=INVALID_REQUEST,
113+
message="Client does not support task-augmented sampling",
114+
)
115+
)

0 commit comments

Comments
 (0)