Skip to content

Commit 3576adc

Browse files
committed
Implement capability gating for tasks
1 parent 73b1799 commit 3576adc

File tree

7 files changed

+595
-5
lines changed

7 files changed

+595
-5
lines changed

src/mcp/client/session.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from mcp.shared.context import RequestContext
1616
from mcp.shared.message import SessionMessage
1717
from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder
18+
from mcp.shared.task import TaskStore
1819
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
1920

2021
if TYPE_CHECKING:
@@ -126,13 +127,15 @@ def __init__(
126127
logging_callback: LoggingFnT | None = None,
127128
message_handler: MessageHandlerFnT | None = None,
128129
client_info: types.Implementation | None = None,
130+
task_store: TaskStore | None = None,
129131
) -> None:
130132
super().__init__(
131133
read_stream,
132134
write_stream,
133135
types.ServerRequest,
134136
types.ServerNotification,
135137
read_timeout_seconds=read_timeout_seconds,
138+
task_store=task_store,
136139
)
137140
self._client_info = client_info or DEFAULT_CLIENT_INFO
138141
self._sampling_callback = sampling_callback or _default_sampling_callback
@@ -156,6 +159,18 @@ async def initialize(self) -> types.InitializeResult:
156159
else None
157160
)
158161

162+
# Build tasks capability - only if task store is configured
163+
tasks = None
164+
if self._task_store is not None:
165+
tasks = types.ClientTasksCapability(
166+
requests=types.ClientTasksRequestsCapability(
167+
sampling=types.TaskSamplingCapability(createMessage=True),
168+
elicitation=types.TaskElicitationCapability(create=True),
169+
roots=types.TaskRootsCapability(list=True),
170+
tasks=types.TasksOperationsCapability(get=True, list=True, result=True, delete=True),
171+
)
172+
)
173+
159174
result = await self.send_request(
160175
types.ClientRequest(
161176
types.InitializeRequest(
@@ -166,6 +181,7 @@ async def initialize(self) -> types.InitializeResult:
166181
elicitation=elicitation,
167182
experimental=None,
168183
roots=roots,
184+
tasks=tasks,
169185
),
170186
clientInfo=self._client_info,
171187
),

src/mcp/server/lowlevel/server.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,10 +112,12 @@ def __init__(
112112
prompts_changed: bool = False,
113113
resources_changed: bool = False,
114114
tools_changed: bool = False,
115+
tasks_changed: bool = False,
115116
):
116117
self.prompts_changed = prompts_changed
117118
self.resources_changed = resources_changed
118119
self.tools_changed = tools_changed
120+
self.tasks_changed = tasks_changed
119121

120122

121123
@asynccontextmanager
@@ -199,6 +201,7 @@ def get_capabilities(
199201
tools_capability = None
200202
logging_capability = None
201203
completions_capability = None
204+
tasks_capability = None
202205

203206
# Set prompt capabilities if handler exists
204207
if types.ListPromptsRequest in self.request_handlers:
@@ -222,13 +225,59 @@ def get_capabilities(
222225
if types.CompleteRequest in self.request_handlers:
223226
completions_capability = types.CompletionsCapability()
224227

228+
# Set tasks capabilities if task store is configured
229+
if self.task_store is not None:
230+
# Build nested request capabilities based on available handlers
231+
tools_req_cap = None
232+
resources_req_cap = None
233+
prompts_req_cap = None
234+
tasks_ops_cap = None
235+
236+
# Check for tool capabilities
237+
has_call_tool = types.CallToolRequest in self.request_handlers
238+
has_list_tools = types.ListToolsRequest in self.request_handlers
239+
if has_call_tool or has_list_tools:
240+
tools_req_cap = types.TaskToolsCapability(
241+
call=True if has_call_tool else None, list=True if has_list_tools else None
242+
)
243+
244+
# Check for resource capabilities
245+
has_read_resource = types.ReadResourceRequest in self.request_handlers
246+
has_list_resources = types.ListResourcesRequest in self.request_handlers
247+
if has_read_resource or has_list_resources:
248+
resources_req_cap = types.TaskResourcesCapability(
249+
read=True if has_read_resource else None, list=True if has_list_resources else None
250+
)
251+
252+
# Check for prompt capabilities
253+
has_get_prompt = types.GetPromptRequest in self.request_handlers
254+
has_list_prompts = types.ListPromptsRequest in self.request_handlers
255+
if has_get_prompt or has_list_prompts:
256+
prompts_req_cap = types.TaskPromptsCapability(
257+
get=True if has_get_prompt else None, list=True if has_list_prompts else None
258+
)
259+
260+
# Task operations are always available if task_store is configured
261+
tasks_ops_cap = types.TasksOperationsCapability(get=True, list=True, result=True, delete=True)
262+
263+
# Build the nested tasks capability
264+
tasks_capability = types.ServerTasksCapability(
265+
requests=types.ServerTasksRequestsCapability(
266+
tools=tools_req_cap,
267+
resources=resources_req_cap,
268+
prompts=prompts_req_cap,
269+
tasks=tasks_ops_cap,
270+
)
271+
)
272+
225273
return types.ServerCapabilities(
226274
prompts=prompts_capability,
227275
resources=resources_capability,
228276
tools=tools_capability,
229277
logging=logging_capability,
230278
experimental=experimental_capabilities,
231279
completions=completions_capability,
280+
tasks=tasks_capability,
232281
)
233282

234283
@property

src/mcp/server/session.py

Lines changed: 92 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,52 @@ def __init__(
107107
def client_params(self) -> types.InitializeRequestParams | None:
108108
return self._client_params
109109

110+
def _check_tasks_capability(
111+
self, required: types.ClientTasksCapability, client: types.ClientTasksCapability
112+
) -> bool:
113+
"""Check if client supports required tasks capabilities."""
114+
if required.requests is None:
115+
return True
116+
if client.requests is None:
117+
return False
118+
119+
req_cap = required.requests
120+
client_req_cap = client.requests
121+
122+
# Check sampling requests
123+
if req_cap.sampling is not None and (
124+
client_req_cap.sampling is None
125+
or (req_cap.sampling.createMessage and not client_req_cap.sampling.createMessage)
126+
):
127+
return False
128+
129+
# Check elicitation requests
130+
if req_cap.elicitation is not None and (
131+
client_req_cap.elicitation is None or (req_cap.elicitation.create and not client_req_cap.elicitation.create)
132+
):
133+
return False
134+
135+
# Check roots requests
136+
if req_cap.roots is not None and (
137+
client_req_cap.roots is None or (req_cap.roots.list and not client_req_cap.roots.list)
138+
):
139+
return False
140+
141+
# Check tasks operations
142+
if req_cap.tasks is not None:
143+
if client_req_cap.tasks is None:
144+
return False
145+
tasks_checks = [
146+
not req_cap.tasks.get or client_req_cap.tasks.get,
147+
not req_cap.tasks.list or client_req_cap.tasks.list,
148+
not req_cap.tasks.result or client_req_cap.tasks.result,
149+
not req_cap.tasks.delete or client_req_cap.tasks.delete,
150+
]
151+
if not all(tasks_checks):
152+
return False
153+
154+
return True
155+
110156
def check_client_capability(self, capability: types.ClientCapabilities) -> bool:
111157
"""Check if the client supports a specific capability."""
112158
if self._client_params is None:
@@ -138,6 +184,12 @@ def check_client_capability(self, capability: types.ClientCapabilities) -> bool:
138184
if exp_key not in client_caps.experimental or client_caps.experimental[exp_key] != exp_value:
139185
return False
140186

187+
if capability.tasks is not None:
188+
if client_caps.tasks is None:
189+
return False
190+
if not self._check_tasks_capability(capability.tasks, client_caps.tasks):
191+
return False
192+
141193
return True
142194

143195
async def _receive_loop(self) -> None:
@@ -193,8 +245,17 @@ async def _received_request( # noqa: PLR0912
193245
# Ping requests are allowed at any time
194246
pass
195247
case types.GetTaskRequest(params=params):
248+
# Check if client has announced tasks capability
249+
if self._client_params is None or self._client_params.capabilities.tasks is None:
250+
with responder:
251+
await responder.respond(
252+
types.ErrorData(
253+
code=types.INVALID_REQUEST,
254+
message="Client has not announced tasks capability",
255+
)
256+
)
196257
# Handle get task requests if task store is available
197-
if self._task_store:
258+
elif self._task_store:
198259
task = await self._task_store.get_task(params.taskId)
199260
if task is None:
200261
with responder:
@@ -220,8 +281,17 @@ async def _received_request( # noqa: PLR0912
220281
types.ErrorData(code=types.INVALID_REQUEST, message="Task store not configured")
221282
)
222283
case types.GetTaskPayloadRequest(params=params):
284+
# Check if client has announced tasks capability
285+
if self._client_params is None or self._client_params.capabilities.tasks is None:
286+
with responder:
287+
await responder.respond(
288+
types.ErrorData(
289+
code=types.INVALID_REQUEST,
290+
message="Client has not announced tasks capability",
291+
)
292+
)
223293
# Handle get task result requests if task store is available
224-
if self._task_store:
294+
elif self._task_store:
225295
task = await self._task_store.get_task(params.taskId)
226296
if task is None:
227297
with responder:
@@ -253,8 +323,17 @@ async def _received_request( # noqa: PLR0912
253323
types.ErrorData(code=types.INVALID_REQUEST, message="Task store not configured")
254324
)
255325
case types.ListTasksRequest(params=params):
326+
# Check if client has announced tasks capability
327+
if self._client_params is None or self._client_params.capabilities.tasks is None:
328+
with responder:
329+
await responder.respond(
330+
types.ErrorData(
331+
code=types.INVALID_REQUEST,
332+
message="Client has not announced tasks capability",
333+
)
334+
)
256335
# Handle list tasks requests if task store is available
257-
if self._task_store:
336+
elif self._task_store:
258337
try:
259338
result = await self._task_store.list_tasks(params.cursor if params else None)
260339
with responder:
@@ -278,8 +357,17 @@ async def _received_request( # noqa: PLR0912
278357
types.ErrorData(code=types.INVALID_REQUEST, message="Task store not configured")
279358
)
280359
case types.DeleteTaskRequest(params=params):
360+
# Check if client has announced tasks capability
361+
if self._client_params is None or self._client_params.capabilities.tasks is None:
362+
with responder:
363+
await responder.respond(
364+
types.ErrorData(
365+
code=types.INVALID_REQUEST,
366+
message="Client has not announced tasks capability",
367+
)
368+
)
281369
# Handle delete task requests if task store is available
282-
if self._task_store:
370+
elif self._task_store:
283371
try:
284372
await self._task_store.delete_task(params.taskId)
285373
with responder:

0 commit comments

Comments
 (0)