Skip to content

Commit 61354eb

Browse files
committed
notifications and client side
1 parent d2968a9 commit 61354eb

File tree

17 files changed

+3323
-15
lines changed

17 files changed

+3323
-15
lines changed

examples/servers/simple-task/mcp_simple_task/server.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import anyio
99
import click
1010
import mcp.types as types
11+
import uvicorn
1112
from anyio.abc import TaskGroup
1213
from mcp.server.lowlevel import Server
1314
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
@@ -107,8 +108,6 @@ async def handle_get_task_result(request: types.GetTaskPayloadRequest) -> types.
107108
@click.command()
108109
@click.option("--port", default=8000, help="Port to listen on")
109110
def main(port: int) -> int:
110-
import uvicorn
111-
112111
session_manager = StreamableHTTPSessionManager(app=server)
113112

114113
@asynccontextmanager

src/mcp/client/session.py

Lines changed: 211 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,95 @@ async def __call__(
4848
) -> None: ... # pragma: no branch
4949

5050

51+
# Experimental: Task handler protocols for server -> client requests
52+
class GetTaskHandlerFnT(Protocol):
53+
"""Handler for tasks/get requests from server.
54+
55+
WARNING: This is experimental and may change without notice.
56+
"""
57+
58+
async def __call__(
59+
self,
60+
context: RequestContext["ClientSession", Any],
61+
params: types.GetTaskRequestParams,
62+
) -> types.GetTaskResult | types.ErrorData: ... # pragma: no branch
63+
64+
65+
class GetTaskResultHandlerFnT(Protocol):
66+
"""Handler for tasks/result requests from server.
67+
68+
WARNING: This is experimental and may change without notice.
69+
"""
70+
71+
async def __call__(
72+
self,
73+
context: RequestContext["ClientSession", Any],
74+
params: types.GetTaskPayloadRequestParams,
75+
) -> types.GetTaskPayloadResult | types.ErrorData: ... # pragma: no branch
76+
77+
78+
class ListTasksHandlerFnT(Protocol):
79+
"""Handler for tasks/list requests from server.
80+
81+
WARNING: This is experimental and may change without notice.
82+
"""
83+
84+
async def __call__(
85+
self,
86+
context: RequestContext["ClientSession", Any],
87+
params: types.PaginatedRequestParams | None,
88+
) -> types.ListTasksResult | types.ErrorData: ... # pragma: no branch
89+
90+
91+
class CancelTaskHandlerFnT(Protocol):
92+
"""Handler for tasks/cancel requests from server.
93+
94+
WARNING: This is experimental and may change without notice.
95+
"""
96+
97+
async def __call__(
98+
self,
99+
context: RequestContext["ClientSession", Any],
100+
params: types.CancelTaskRequestParams,
101+
) -> types.CancelTaskResult | types.ErrorData: ... # pragma: no branch
102+
103+
104+
class TaskAugmentedSamplingFnT(Protocol):
105+
"""Handler for task-augmented sampling/createMessage requests from server.
106+
107+
When server sends a CreateMessageRequest with task field, this callback
108+
is invoked. The callback should create a task, spawn background work,
109+
and return CreateTaskResult immediately.
110+
111+
WARNING: This is experimental and may change without notice.
112+
"""
113+
114+
async def __call__(
115+
self,
116+
context: RequestContext["ClientSession", Any],
117+
params: types.CreateMessageRequestParams,
118+
task_metadata: types.TaskMetadata,
119+
) -> types.CreateTaskResult | types.ErrorData: ... # pragma: no branch
120+
121+
122+
class TaskAugmentedElicitationFnT(Protocol):
123+
"""Handler for task-augmented elicitation/create requests from server.
124+
125+
When server sends an ElicitRequest with task field, this callback
126+
is invoked. The callback should create a task, spawn background work,
127+
and return CreateTaskResult immediately.
128+
129+
WARNING: This is experimental and may change without notice.
130+
"""
131+
132+
async def __call__(
133+
self,
134+
context: RequestContext["ClientSession", Any],
135+
params: types.ElicitRequestParams,
136+
task_metadata: types.TaskMetadata,
137+
) -> types.CreateTaskResult | types.ErrorData: ... # pragma: no branch
138+
139+
51140
class MessageHandlerFnT(Protocol):
52141
async def __call__(
53142
self,
@@ -96,6 +185,69 @@ async def _default_logging_callback(
96185
pass
97186

98187

188+
# Default handlers for experimental task requests (return "not supported" errors)
189+
async def _default_get_task_handler(
190+
context: RequestContext["ClientSession", Any],
191+
params: types.GetTaskRequestParams,
192+
) -> types.GetTaskResult | types.ErrorData:
193+
return types.ErrorData(
194+
code=types.METHOD_NOT_FOUND,
195+
message="tasks/get not supported",
196+
)
197+
198+
199+
async def _default_get_task_result_handler(
200+
context: RequestContext["ClientSession", Any],
201+
params: types.GetTaskPayloadRequestParams,
202+
) -> types.GetTaskPayloadResult | types.ErrorData:
203+
return types.ErrorData(
204+
code=types.METHOD_NOT_FOUND,
205+
message="tasks/result not supported",
206+
)
207+
208+
209+
async def _default_list_tasks_handler(
210+
context: RequestContext["ClientSession", Any],
211+
params: types.PaginatedRequestParams | None,
212+
) -> types.ListTasksResult | types.ErrorData:
213+
return types.ErrorData(
214+
code=types.METHOD_NOT_FOUND,
215+
message="tasks/list not supported",
216+
)
217+
218+
219+
async def _default_cancel_task_handler(
220+
context: RequestContext["ClientSession", Any],
221+
params: types.CancelTaskRequestParams,
222+
) -> types.CancelTaskResult | types.ErrorData:
223+
return types.ErrorData(
224+
code=types.METHOD_NOT_FOUND,
225+
message="tasks/cancel not supported",
226+
)
227+
228+
229+
async def _default_task_augmented_sampling_callback(
230+
context: RequestContext["ClientSession", Any],
231+
params: types.CreateMessageRequestParams,
232+
task_metadata: types.TaskMetadata,
233+
) -> types.CreateTaskResult | types.ErrorData:
234+
return types.ErrorData(
235+
code=types.INVALID_REQUEST,
236+
message="Task-augmented sampling not supported",
237+
)
238+
239+
240+
async def _default_task_augmented_elicitation_callback(
241+
context: RequestContext["ClientSession", Any],
242+
params: types.ElicitRequestParams,
243+
task_metadata: types.TaskMetadata,
244+
) -> types.CreateTaskResult | types.ErrorData:
245+
return types.ErrorData(
246+
code=types.INVALID_REQUEST,
247+
message="Task-augmented elicitation not supported",
248+
)
249+
250+
99251
ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(types.ClientResult | types.ErrorData)
100252

101253

@@ -119,6 +271,14 @@ def __init__(
119271
logging_callback: LoggingFnT | None = None,
120272
message_handler: MessageHandlerFnT | None = None,
121273
client_info: types.Implementation | None = None,
274+
tasks_capability: types.ClientTasksCapability | None = None,
275+
# Experimental: Task handlers for server -> client requests
276+
get_task_handler: GetTaskHandlerFnT | None = None,
277+
get_task_result_handler: GetTaskResultHandlerFnT | None = None,
278+
list_tasks_handler: ListTasksHandlerFnT | None = None,
279+
cancel_task_handler: CancelTaskHandlerFnT | None = None,
280+
task_augmented_sampling_callback: TaskAugmentedSamplingFnT | None = None,
281+
task_augmented_elicitation_callback: TaskAugmentedElicitationFnT | None = None,
122282
) -> None:
123283
super().__init__(
124284
read_stream,
@@ -133,9 +293,21 @@ def __init__(
133293
self._list_roots_callback = list_roots_callback or _default_list_roots_callback
134294
self._logging_callback = logging_callback or _default_logging_callback
135295
self._message_handler = message_handler or _default_message_handler
296+
self._tasks_capability = tasks_capability
136297
self._tool_output_schemas: dict[str, dict[str, Any] | None] = {}
137298
self._server_capabilities: types.ServerCapabilities | None = None
138299
self._experimental: ExperimentalClientFeatures | None = None
300+
# Experimental: Task handlers
301+
self._get_task_handler = get_task_handler or _default_get_task_handler
302+
self._get_task_result_handler = get_task_result_handler or _default_get_task_result_handler
303+
self._list_tasks_handler = list_tasks_handler or _default_list_tasks_handler
304+
self._cancel_task_handler = cancel_task_handler or _default_cancel_task_handler
305+
self._task_augmented_sampling_callback = (
306+
task_augmented_sampling_callback or _default_task_augmented_sampling_callback
307+
)
308+
self._task_augmented_elicitation_callback = (
309+
task_augmented_elicitation_callback or _default_task_augmented_elicitation_callback
310+
)
139311

140312
async def initialize(self) -> types.InitializeResult:
141313
sampling = types.SamplingCapability() if self._sampling_callback is not _default_sampling_callback else None
@@ -166,6 +338,7 @@ async def initialize(self) -> types.InitializeResult:
166338
elicitation=elicitation,
167339
experimental=None,
168340
roots=roots,
341+
tasks=self._tasks_capability,
169342
),
170343
clientInfo=self._client_info,
171344
),
@@ -191,7 +364,7 @@ def get_server_capabilities(self) -> types.ServerCapabilities | None:
191364
return self._server_capabilities
192365

193366
@property
194-
def experimental(self) -> "ExperimentalClientFeatures":
367+
def experimental(self) -> ExperimentalClientFeatures:
195368
"""Experimental APIs for tasks and other features.
196369
197370
WARNING: These APIs are experimental and may change without notice.
@@ -540,13 +713,21 @@ async def _received_request(self, responder: RequestResponder[types.ServerReques
540713
match responder.request.root:
541714
case types.CreateMessageRequest(params=params):
542715
with responder:
543-
response = await self._sampling_callback(ctx, params)
716+
# Check if this is a task-augmented request
717+
if params.task is not None:
718+
response = await self._task_augmented_sampling_callback(ctx, params, params.task)
719+
else:
720+
response = await self._sampling_callback(ctx, params)
544721
client_response = ClientResponse.validate_python(response)
545722
await responder.respond(client_response)
546723

547724
case types.ElicitRequest(params=params):
548725
with responder:
549-
response = await self._elicitation_callback(ctx, params)
726+
# Check if this is a task-augmented request
727+
if params.task is not None:
728+
response = await self._task_augmented_elicitation_callback(ctx, params, params.task)
729+
else:
730+
response = await self._elicitation_callback(ctx, params)
550731
client_response = ClientResponse.validate_python(response)
551732
await responder.respond(client_response)
552733

@@ -559,7 +740,33 @@ async def _received_request(self, responder: RequestResponder[types.ServerReques
559740
case types.PingRequest(): # pragma: no cover
560741
with responder:
561742
return await responder.respond(types.ClientResult(root=types.EmptyResult()))
562-
case _:
743+
744+
# Experimental: Task management requests from server
745+
case types.GetTaskRequest(params=params):
746+
with responder:
747+
response = await self._get_task_handler(ctx, params)
748+
client_response = ClientResponse.validate_python(response)
749+
await responder.respond(client_response)
750+
751+
case types.GetTaskPayloadRequest(params=params):
752+
with responder:
753+
response = await self._get_task_result_handler(ctx, params)
754+
client_response = ClientResponse.validate_python(response)
755+
await responder.respond(client_response)
756+
757+
case types.ListTasksRequest(params=params):
758+
with responder:
759+
response = await self._list_tasks_handler(ctx, params)
760+
client_response = ClientResponse.validate_python(response)
761+
await responder.respond(client_response)
762+
763+
case types.CancelTaskRequest(params=params):
764+
with responder:
765+
response = await self._cancel_task_handler(ctx, params)
766+
client_response = ClientResponse.validate_python(response)
767+
await responder.respond(client_response)
768+
769+
case _: # pragma: no cover
563770
raise NotImplementedError()
564771

565772
async def _handle_incoming(

src/mcp/server/lowlevel/server.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -67,12 +67,14 @@ async def main():
6767

6868
from __future__ import annotations as _annotations
6969

70+
import base64
7071
import contextvars
7172
import json
7273
import logging
7374
import warnings
7475
from collections.abc import AsyncIterator, Awaitable, Callable, Iterable
7576
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
77+
from importlib.metadata import version as pkg_version
7678
from typing import Any, Generic, TypeAlias, cast
7779

7880
import anyio
@@ -166,19 +168,17 @@ def create_initialization_options(
166168
) -> InitializationOptions:
167169
"""Create initialization options from this server instance."""
168170

169-
def pkg_version(package: str) -> str:
171+
def get_package_version(package: str) -> str:
170172
try:
171-
from importlib.metadata import version
172-
173-
return version(package)
173+
return pkg_version(package)
174174
except Exception: # pragma: no cover
175175
pass
176176

177177
return "unknown" # pragma: no cover
178178

179179
return InitializationOptions(
180180
server_name=self.name,
181-
server_version=self.version if self.version else pkg_version("mcp"),
181+
server_version=self.version if self.version else get_package_version("mcp"),
182182
capabilities=self.get_capabilities(
183183
notification_options or NotificationOptions(),
184184
experimental_capabilities or {},
@@ -345,8 +345,6 @@ def create_content(data: str | bytes, mime_type: str | None):
345345
mimeType=mime_type or "text/plain",
346346
)
347347
case bytes() as data: # pragma: no cover
348-
import base64
349-
350348
return types.BlobResourceContents(
351349
uri=req.params.uri,
352350
blob=base64.b64encode(data).decode(),

src/mcp/shared/experimental/tasks/__init__.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,13 @@
55
- TaskStore: Abstract interface for task state storage
66
- TaskContext: Context object for task work to interact with state/notifications
77
- InMemoryTaskStore: Reference implementation for testing/development
8+
- TaskMessageQueue: FIFO queue for task messages delivered via tasks/result
9+
- InMemoryTaskMessageQueue: Reference implementation for message queue
810
- Helper functions: run_task, is_terminal, create_task_state, generate_task_id
911
1012
Architecture:
1113
- TaskStore is pure storage - it doesn't know about execution
14+
- TaskMessageQueue stores messages to be delivered via tasks/result
1215
- TaskContext wraps store + session, providing a clean API for task work
1316
- run_task is optional convenience for spawning in-process tasks
1417
@@ -24,15 +27,31 @@
2427
task_execution,
2528
)
2629
from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore
30+
from mcp.shared.experimental.tasks.message_queue import (
31+
InMemoryTaskMessageQueue,
32+
QueuedMessage,
33+
TaskMessageQueue,
34+
)
35+
from mcp.shared.experimental.tasks.result_handler import (
36+
TaskResultHandler,
37+
create_task_result_handler,
38+
)
2739
from mcp.shared.experimental.tasks.store import TaskStore
40+
from mcp.shared.experimental.tasks.task_session import TaskSession
2841

2942
__all__ = [
3043
"TaskStore",
3144
"TaskContext",
45+
"TaskSession",
46+
"TaskResultHandler",
3247
"InMemoryTaskStore",
48+
"TaskMessageQueue",
49+
"InMemoryTaskMessageQueue",
50+
"QueuedMessage",
3351
"run_task",
3452
"task_execution",
3553
"is_terminal",
3654
"create_task_state",
3755
"generate_task_id",
56+
"create_task_result_handler",
3857
]

0 commit comments

Comments
 (0)