Skip to content

Commit 7105d9d

Browse files
committed
add mvp for server side tasks
1 parent 2f23ceb commit 7105d9d

File tree

5 files changed

+266
-92
lines changed

5 files changed

+266
-92
lines changed
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
"""Experimental handlers for the low-level MCP server.
2+
3+
WARNING: These APIs are experimental and may change without notice.
4+
"""
5+
6+
import logging
7+
from collections.abc import Awaitable, Callable
8+
9+
from mcp.server.lowlevel.func_inspection import create_call_wrapper
10+
from mcp.types import (
11+
CancelTaskRequest,
12+
CancelTaskResult,
13+
GetTaskPayloadRequest,
14+
GetTaskPayloadResult,
15+
GetTaskRequest,
16+
GetTaskResult,
17+
ListTasksRequest,
18+
ListTasksResult,
19+
ServerCapabilities,
20+
ServerResult,
21+
ServerTasksCapability,
22+
ServerTasksRequestsCapability,
23+
TasksCancelCapability,
24+
TasksListCapability,
25+
TasksToolsCapability,
26+
)
27+
28+
logger = logging.getLogger(__name__)
29+
30+
31+
class ExperimentalHandlers:
32+
"""Experimental request/notification handlers.
33+
34+
WARNING: These APIs are experimental and may change without notice.
35+
"""
36+
37+
def __init__(
38+
self,
39+
request_handlers: dict[type, Callable[..., Awaitable[ServerResult]]],
40+
notification_handlers: dict[type, Callable[..., Awaitable[None]]],
41+
):
42+
self._request_handlers = request_handlers
43+
self._notification_handlers = notification_handlers
44+
45+
def update_capabilities(self, capabilities: ServerCapabilities) -> None:
46+
capabilities.tasks = ServerTasksCapability()
47+
if ListTasksRequest in self._request_handlers:
48+
capabilities.tasks.list = TasksListCapability()
49+
if CancelTaskRequest in self._request_handlers:
50+
capabilities.tasks.cancel = TasksCancelCapability()
51+
52+
capabilities.tasks.requests = ServerTasksRequestsCapability(
53+
tools=TasksToolsCapability()
54+
) # assuming always supported for now
55+
56+
def list_tasks(
57+
self,
58+
) -> Callable[
59+
[Callable[[ListTasksRequest], Awaitable[ListTasksResult]]],
60+
Callable[[ListTasksRequest], Awaitable[ListTasksResult]],
61+
]:
62+
"""Register a handler for listing tasks.
63+
64+
WARNING: This API is experimental and may change without notice.
65+
"""
66+
67+
def decorator(
68+
func: Callable[[ListTasksRequest], Awaitable[ListTasksResult]],
69+
) -> Callable[[ListTasksRequest], Awaitable[ListTasksResult]]:
70+
logger.debug("Registering handler for ListTasksRequest")
71+
wrapper = create_call_wrapper(func, ListTasksRequest)
72+
73+
async def handler(req: ListTasksRequest):
74+
result = await wrapper(req)
75+
return ServerResult(result)
76+
77+
self._request_handlers[ListTasksRequest] = handler
78+
return func
79+
80+
return decorator
81+
82+
def get_task(self):
83+
"""Register a handler for getting task status.
84+
85+
WARNING: This API is experimental and may change without notice.
86+
"""
87+
88+
def decorator(func: Callable[[GetTaskRequest], Awaitable[GetTaskResult]]):
89+
logger.debug("Registering handler for GetTaskRequest")
90+
wrapper = create_call_wrapper(func, GetTaskRequest)
91+
92+
async def handler(req: GetTaskRequest):
93+
result = await wrapper(req)
94+
return ServerResult(result)
95+
96+
self._request_handlers[GetTaskRequest] = handler
97+
return func
98+
99+
return decorator
100+
101+
def get_task_result(self):
102+
"""Register a handler for getting task results/payload.
103+
104+
WARNING: This API is experimental and may change without notice.
105+
"""
106+
107+
def decorator(func: Callable[[GetTaskPayloadRequest], Awaitable[GetTaskPayloadResult]]):
108+
logger.debug("Registering handler for GetTaskPayloadRequest")
109+
wrapper = create_call_wrapper(func, GetTaskPayloadRequest)
110+
111+
async def handler(req: GetTaskPayloadRequest):
112+
result = await wrapper(req)
113+
return ServerResult(result)
114+
115+
self._request_handlers[GetTaskPayloadRequest] = handler
116+
return func
117+
118+
return decorator
119+
120+
def cancel_task(self):
121+
"""Register a handler for cancelling tasks.
122+
123+
WARNING: This API is experimental and may change without notice.
124+
"""
125+
126+
def decorator(func: Callable[[CancelTaskRequest], Awaitable[CancelTaskResult]]):
127+
logger.debug("Registering handler for CancelTaskRequest")
128+
wrapper = create_call_wrapper(func, CancelTaskRequest)
129+
130+
async def handler(req: CancelTaskRequest):
131+
result = await wrapper(req)
132+
return ServerResult(result)
133+
134+
self._request_handlers[CancelTaskRequest] = handler
135+
return func
136+
137+
return decorator

src/mcp/server/lowlevel/server.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,12 @@ async def main():
8282
from typing_extensions import TypeVar
8383

8484
import mcp.types as types
85+
from mcp.server.lowlevel.experimental import ExperimentalHandlers
8586
from mcp.server.lowlevel.func_inspection import create_call_wrapper
8687
from mcp.server.lowlevel.helper_types import ReadResourceContents
8788
from mcp.server.models import InitializationOptions
8889
from mcp.server.session import ServerSession
89-
from mcp.shared.context import RequestContext
90+
from mcp.shared.context import Experimental, RequestContext
9091
from mcp.shared.exceptions import McpError
9192
from mcp.shared.message import ServerMessageMetadata, SessionMessage
9293
from mcp.shared.session import RequestResponder
@@ -155,6 +156,7 @@ def __init__(
155156
}
156157
self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {}
157158
self._tool_cache: dict[str, types.Tool] = {}
159+
self._experimental_handlers: ExperimentalHandlers | None = None
158160
logger.debug("Initializing server %r", name)
159161

160162
def create_initialization_options(
@@ -220,14 +222,17 @@ def get_capabilities(
220222
if types.CompleteRequest in self.request_handlers:
221223
completions_capability = types.CompletionsCapability()
222224

223-
return types.ServerCapabilities(
225+
capabilities = types.ServerCapabilities(
224226
prompts=prompts_capability,
225227
resources=resources_capability,
226228
tools=tools_capability,
227229
logging=logging_capability,
228230
experimental=experimental_capabilities,
229231
completions=completions_capability,
230232
)
233+
if self._experimental_handlers:
234+
self._experimental_handlers.update_capabilities(capabilities)
235+
return capabilities
231236

232237
@property
233238
def request_context(
@@ -236,6 +241,18 @@ def request_context(
236241
"""If called outside of a request context, this will raise a LookupError."""
237242
return request_ctx.get()
238243

244+
@property
245+
def experimental(self) -> ExperimentalHandlers:
246+
"""Experimental APIs for tasks and other features.
247+
248+
WARNING: These APIs are experimental and may change without notice.
249+
"""
250+
251+
# We create this inline so we only add these capabilities _if_ they're actually used
252+
if self._experimental_handlers is None:
253+
self._experimental_handlers = ExperimentalHandlers(self.request_handlers, self.notification_handlers)
254+
return self._experimental_handlers
255+
239256
def list_prompts(self):
240257
def decorator(
241258
func: Callable[[], Awaitable[list[types.Prompt]]]
@@ -669,13 +686,14 @@ async def _handle_message(
669686
async def _handle_request(
670687
self,
671688
message: RequestResponder[types.ClientRequest, types.ServerResult],
672-
req: Any,
689+
req: types.ClientRequestType,
673690
session: ServerSession,
674691
lifespan_context: LifespanResultT,
675692
raise_exceptions: bool,
676693
):
677694
logger.info("Processing request of type %s", type(req).__name__)
678-
if handler := self.request_handlers.get(type(req)): # type: ignore
695+
696+
if handler := self.request_handlers.get(type(req)):
679697
logger.debug("Dispatching request of type %s", type(req).__name__)
680698

681699
token = None
@@ -695,6 +713,7 @@ async def _handle_request(
695713
message.request_meta,
696714
session,
697715
lifespan_context,
716+
Experimental(task_metadata=message.request_params.task if message.request_params else None),
698717
request=request_data,
699718
)
700719
)

src/mcp/shared/context.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,27 @@
44
from typing_extensions import TypeVar
55

66
from mcp.shared.session import BaseSession
7-
from mcp.types import RequestId, RequestParams
7+
from mcp.types import RequestId, RequestParams, TaskMetadata
88

99
SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any])
1010
LifespanContextT = TypeVar("LifespanContextT")
1111
RequestT = TypeVar("RequestT", default=Any)
1212

1313

14+
@dataclass
15+
class Experimental:
16+
task_metadata: TaskMetadata | None = None
17+
18+
@property
19+
def is_task(self) -> bool:
20+
return self.task_metadata is not None
21+
22+
1423
@dataclass
1524
class RequestContext(Generic[SessionT, LifespanContextT, RequestT]):
1625
request_id: RequestId
1726
meta: RequestParams.Meta | None
1827
session: SessionT
1928
lifespan_context: LifespanContextT
29+
experimental: Experimental = Experimental()
2030
request: RequestT | None = None

src/mcp/shared/session.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,11 @@ def __init__(
8181
]""",
8282
on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any],
8383
message_metadata: MessageMetadata = None,
84+
request_params: RequestParams | None = None,
8485
) -> None:
8586
self.request_id = request_id
8687
self.request_meta = request_meta
88+
self.request_params = request_params
8789
self.request = request
8890
self.message_metadata = message_metadata
8991
self._session = session
@@ -353,6 +355,7 @@ async def _receive_loop(self) -> None:
353355
session=self,
354356
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
355357
message_metadata=message.metadata,
358+
request_params=validated_request.root.params,
356359
)
357360
self._in_flight[responder.request_id] = responder
358361
await self._received_request(responder)

0 commit comments

Comments
 (0)