|
| 1 | +import logging |
| 2 | +from collections.abc import AsyncGenerator, Callable, Mapping |
| 3 | +from contextlib import asynccontextmanager |
| 4 | +from http import HTTPStatus |
| 5 | +from types import TracebackType |
| 6 | +from typing import Any, NamedTuple |
| 7 | + |
| 8 | +import anyio |
| 9 | +from anyio.abc import TaskGroup, TaskStatus |
| 10 | +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream |
| 11 | +from sse_starlette.sse import EventSourceResponse |
| 12 | +from starlette.applications import Starlette |
| 13 | +from starlette.requests import Request |
| 14 | +from starlette.responses import Response |
| 15 | +from starlette.routing import Route |
| 16 | +from typing_extensions import override |
| 17 | + |
| 18 | +from mcp.server.minimcp.exceptions import MCPRuntimeError, MiniMCPError |
| 19 | +from mcp.server.minimcp.managers.context_manager import ScopeT |
| 20 | +from mcp.server.minimcp.minimcp import MiniMCP |
| 21 | +from mcp.server.minimcp.transports.base_http import MEDIA_TYPE_JSON, BaseHTTPTransport, MCPHTTPResponse |
| 22 | +from mcp.server.minimcp.types import MESSAGE_ENCODING, Message |
| 23 | + |
| 24 | +logger = logging.getLogger(__name__) |
| 25 | + |
| 26 | + |
| 27 | +MEDIA_TYPE_SSE = "text/event-stream" |
| 28 | + |
| 29 | +SSE_HEADERS = { |
| 30 | + "Cache-Control": "no-cache, no-transform", |
| 31 | + "Connection": "keep-alive", |
| 32 | + "Content-Type": MEDIA_TYPE_SSE, |
| 33 | +} |
| 34 | + |
| 35 | + |
| 36 | +class MCPStreamingHTTPResponse(NamedTuple): |
| 37 | + """ |
| 38 | + Represents the response from a MiniMCP server to a client HTTP request. |
| 39 | +
|
| 40 | + Attributes: |
| 41 | + status_code: The HTTP status code to return to the client. |
| 42 | + content: The response content as a MemoryObjectReceiveStream. |
| 43 | + media_type: The MIME type of the response response stream ("text/event-stream"). |
| 44 | + headers: Additional HTTP headers to include in the response. |
| 45 | + """ |
| 46 | + |
| 47 | + status_code: HTTPStatus |
| 48 | + content: MemoryObjectReceiveStream[str] |
| 49 | + headers: Mapping[str, str] = SSE_HEADERS |
| 50 | + media_type: str = MEDIA_TYPE_SSE |
| 51 | + |
| 52 | + |
| 53 | +class StreamManager: |
| 54 | + """ |
| 55 | + Manages the lifecycle of memory object streams for the StreamableHTTPTransport. |
| 56 | +
|
| 57 | + Streams are created on demand - Once streaming is activated, the receive stream |
| 58 | + is handed off to the consumer via on_create callback, while the send stream |
| 59 | + remains owned by the StreamManager. |
| 60 | +
|
| 61 | + Once the handling completes, the close method needs to be called manually to close |
| 62 | + the streams gracefully - Not using a context manager to keep the approach explicit. |
| 63 | +
|
| 64 | + On close, the send stream is closed immediately and the receive stream is closed after |
| 65 | + a configurable delay to allow consumers to finish draining the stream. The cleanup is |
| 66 | + shielded from cancellation to prevent resource leaks when tasks are cancelled during |
| 67 | + transport shutdown. |
| 68 | + """ |
| 69 | + |
| 70 | + _lock: anyio.Lock |
| 71 | + _on_create: Callable[[MCPStreamingHTTPResponse], None] |
| 72 | + |
| 73 | + _send_stream: MemoryObjectSendStream[Message] | None |
| 74 | + _receive_stream: MemoryObjectReceiveStream[Message] | None |
| 75 | + |
| 76 | + def __init__(self, on_create: Callable[[MCPStreamingHTTPResponse], None]) -> None: |
| 77 | + """ |
| 78 | + Args: |
| 79 | + on_create: Callback to be called when the streams are created. |
| 80 | + """ |
| 81 | + self._on_create = on_create |
| 82 | + self._lock = anyio.Lock() |
| 83 | + |
| 84 | + self._send_stream = None |
| 85 | + self._receive_stream = None |
| 86 | + |
| 87 | + def is_streaming(self) -> bool: |
| 88 | + """ |
| 89 | + Returns: |
| 90 | + True if the streams are created and ready to be used, False otherwise. |
| 91 | + """ |
| 92 | + return self._send_stream is not None and self._receive_stream is not None |
| 93 | + |
| 94 | + async def create_and_send(self, message: Message, create_timeout: float = 0.1) -> None: |
| 95 | + """ |
| 96 | + Creates the streams and sends the message. If the streams are already available, |
| 97 | + it sends the message over the existing streams. |
| 98 | +
|
| 99 | + Args: |
| 100 | + message: Message to send. |
| 101 | + create_timeout: Timeout to create the streams. |
| 102 | + """ |
| 103 | + if self._send_stream is None: |
| 104 | + with anyio.fail_after(create_timeout): |
| 105 | + async with self._lock: |
| 106 | + if self._send_stream is None: |
| 107 | + send_stream, receive_stream = anyio.create_memory_object_stream[Message](0) |
| 108 | + self._on_create(MCPStreamingHTTPResponse(HTTPStatus.OK, receive_stream)) |
| 109 | + self._send_stream = send_stream |
| 110 | + self._receive_stream = receive_stream |
| 111 | + |
| 112 | + await self.send(message) |
| 113 | + |
| 114 | + async def send(self, message: Message) -> None: |
| 115 | + """ |
| 116 | + Sends the message to the send stream. |
| 117 | +
|
| 118 | + Args: |
| 119 | + message: Message to send. |
| 120 | +
|
| 121 | + Raises: |
| 122 | + MiniMCPError: If the send stream is unavailable. |
| 123 | + """ |
| 124 | + if self._send_stream is None: |
| 125 | + raise MiniMCPError("Send stream is unavailable") |
| 126 | + |
| 127 | + try: |
| 128 | + await self._send_stream.send(message) |
| 129 | + except (anyio.BrokenResourceError, anyio.ClosedResourceError) as e: |
| 130 | + # Consumer went away or stream closed or stream not created; ignore further sends. |
| 131 | + logger.debug("Failed to send message: consumer disconnected. Error: %s", e) |
| 132 | + pass |
| 133 | + |
| 134 | + async def close(self, receive_close_delay: float) -> None: |
| 135 | + """ |
| 136 | + Closes the send and receive streams gracefully if they were created by the StreamManager. |
| 137 | + After closing the send stream, it waits for the receive stream to be closed by the consumer. If the |
| 138 | + consumer does not close the receive stream, it will be closed after the delay. |
| 139 | +
|
| 140 | + Args: |
| 141 | + receive_close_delay: Delay to wait for the receive stream to be closed by the consumer. |
| 142 | + """ |
| 143 | + if self._send_stream is not None: |
| 144 | + try: |
| 145 | + await self._send_stream.aclose() |
| 146 | + except (anyio.BrokenResourceError, anyio.ClosedResourceError): |
| 147 | + pass |
| 148 | + |
| 149 | + if self._receive_stream is not None: |
| 150 | + try: |
| 151 | + with anyio.CancelScope(shield=True): |
| 152 | + await anyio.sleep(receive_close_delay) |
| 153 | + await self._receive_stream.aclose() |
| 154 | + except (anyio.BrokenResourceError, anyio.ClosedResourceError): |
| 155 | + pass |
| 156 | + |
| 157 | + |
| 158 | +# TODO: Add resumability based on Last-Event-ID header on GET method. |
| 159 | +class StreamableHTTPTransport(BaseHTTPTransport[ScopeT]): |
| 160 | + """ |
| 161 | + Adds support for MCP's streamable HTTP transport mechanism. |
| 162 | + More details @ https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#streamable-http |
| 163 | +
|
| 164 | + With Streamable HTTP the MCP server can operates as an independent process that can handle multiple |
| 165 | + client connections using HTTP. |
| 166 | +
|
| 167 | + Security Warning: Security is not provided inbuilt. It is the responsibility of the web framework to |
| 168 | + provide security. |
| 169 | + """ |
| 170 | + |
| 171 | + _ping_interval: int |
| 172 | + _receive_close_delay: float |
| 173 | + |
| 174 | + _tg: TaskGroup | None |
| 175 | + |
| 176 | + RESPONSE_MEDIA_TYPES: frozenset[str] = frozenset[str]([MEDIA_TYPE_JSON, MEDIA_TYPE_SSE]) |
| 177 | + SUPPORTED_HTTP_METHODS: frozenset[str] = frozenset[str](["POST"]) |
| 178 | + |
| 179 | + def __init__( |
| 180 | + self, |
| 181 | + minimcp: MiniMCP[ScopeT], |
| 182 | + ping_interval: int = 15, |
| 183 | + receive_close_delay: float = 0.1, |
| 184 | + ) -> None: |
| 185 | + """ |
| 186 | + Args: |
| 187 | + minimcp: The MiniMCP instance to use. |
| 188 | + ping_interval: The ping interval in seconds to keep the streams alive. By default, it is set to |
| 189 | + 15 seconds based on a widely adopted convention. |
| 190 | + receive_close_delay: After request handling is complete, wait for these many seconds before |
| 191 | + automatically closing the receive stream. By default, it is set to 0.1 seconds to allow |
| 192 | + the consumer to finish draining the stream. |
| 193 | + """ |
| 194 | + super().__init__(minimcp) |
| 195 | + self._ping_interval = ping_interval |
| 196 | + self._receive_close_delay = receive_close_delay |
| 197 | + self._tg = None |
| 198 | + |
| 199 | + async def __aenter__(self) -> "StreamableHTTPTransport[ScopeT]": |
| 200 | + self._tg = await anyio.create_task_group().__aenter__() |
| 201 | + return self |
| 202 | + |
| 203 | + async def __aexit__( |
| 204 | + self, exc_type: type[BaseException] | None, exc: BaseException | None, tb: TracebackType | None |
| 205 | + ) -> bool | None: |
| 206 | + if self._tg is not None: |
| 207 | + logger.debug("Shutting down StreamableHTTPTransport") |
| 208 | + |
| 209 | + # Cancel all background tasks to prevent hanging on. |
| 210 | + self._tg.cancel_scope.cancel() |
| 211 | + |
| 212 | + # Exit the task group |
| 213 | + result = await self._tg.__aexit__(exc_type, exc, tb) |
| 214 | + |
| 215 | + if self._tg.cancel_scope.cancelled_caught: |
| 216 | + logger.warning("Background tasks were cancelled during StreamableHTTPTransport shutdown") |
| 217 | + |
| 218 | + self._tg = None |
| 219 | + return result |
| 220 | + return None |
| 221 | + |
| 222 | + @asynccontextmanager |
| 223 | + async def lifespan(self, _: Any) -> AsyncGenerator[None, None]: |
| 224 | + """ |
| 225 | + Provides an easy to use lifespan context manager for the StreamableHTTPTransport. |
| 226 | + """ |
| 227 | + async with self: |
| 228 | + yield |
| 229 | + |
| 230 | + @override |
| 231 | + async def dispatch( |
| 232 | + self, method: str, headers: Mapping[str, str], body: str, scope: ScopeT | None = None |
| 233 | + ) -> MCPHTTPResponse | MCPStreamingHTTPResponse: |
| 234 | + """ |
| 235 | + Dispatch an HTTP request to the MiniMCP server. |
| 236 | +
|
| 237 | + Args: |
| 238 | + method: The HTTP method of the request. |
| 239 | + headers: HTTP request headers. |
| 240 | + body: HTTP request body as a string. |
| 241 | + scope: Optional message scope passed to the MiniMCP server. |
| 242 | +
|
| 243 | + Returns: |
| 244 | + MCPHTTPResponse object with the response from the MiniMCP server. |
| 245 | + """ |
| 246 | + |
| 247 | + if self._tg is None: |
| 248 | + raise MCPRuntimeError( |
| 249 | + "dispatch can only be used inside an 'async with' block or a lifespan of StreamableHTTPTransport" |
| 250 | + ) |
| 251 | + |
| 252 | + logger.debug("Handling HTTP request. Method: %s, Headers: %s", method, headers) |
| 253 | + |
| 254 | + if method == "POST": |
| 255 | + # Start _handle_post_request in a separate task and await for readiness. |
| 256 | + # Once ready _handle_post_request_task returns a MCPHTTPResponse or MCPStreamingHTTPResponse. |
| 257 | + return await self._tg.start(self._handle_post_request_task, headers, body, scope) |
| 258 | + else: |
| 259 | + return self._handle_unsupported_request() |
| 260 | + |
| 261 | + @override |
| 262 | + async def starlette_dispatch(self, request: Request, scope: ScopeT | None = None) -> Response: |
| 263 | + """ |
| 264 | + Dispatch a Starlette request to the MiniMCP server and return the response as a Starlette response object. |
| 265 | +
|
| 266 | + Args: |
| 267 | + request: Starlette request object. |
| 268 | + scope: Optional message scope passed to the MiniMCP server. |
| 269 | +
|
| 270 | + Returns: |
| 271 | + MiniMCP server response formatted as a Starlette Response object. |
| 272 | + """ |
| 273 | + msg = await request.body() |
| 274 | + body_str = msg.decode(MESSAGE_ENCODING) |
| 275 | + result = await self.dispatch(request.method, request.headers, body_str, scope) |
| 276 | + |
| 277 | + if isinstance(result, MCPStreamingHTTPResponse): |
| 278 | + return EventSourceResponse(result.content, headers=result.headers, ping=self._ping_interval) |
| 279 | + |
| 280 | + return Response(result.content, result.status_code, result.headers, result.media_type) |
| 281 | + |
| 282 | + @override |
| 283 | + def as_starlette(self, path: str = "/", debug: bool = False) -> Starlette: |
| 284 | + """ |
| 285 | + Provide the HTTP transport as a Starlette application. |
| 286 | +
|
| 287 | + Args: |
| 288 | + path: The path to the MCP application endpoint. |
| 289 | + debug: Whether to enable debug mode. |
| 290 | +
|
| 291 | + Returns: |
| 292 | + Starlette application. |
| 293 | + """ |
| 294 | + |
| 295 | + route = Route(path, endpoint=self.starlette_dispatch, methods=self.SUPPORTED_HTTP_METHODS) |
| 296 | + |
| 297 | + logger.info("Creating MCP application at path: %s", path) |
| 298 | + return Starlette(routes=[route], debug=debug, lifespan=self.lifespan) |
| 299 | + |
| 300 | + async def _handle_post_request_task( |
| 301 | + self, |
| 302 | + headers: Mapping[str, str], |
| 303 | + body: str, |
| 304 | + scope: ScopeT | None, |
| 305 | + task_status: TaskStatus[MCPHTTPResponse | MCPStreamingHTTPResponse], |
| 306 | + ): |
| 307 | + """ |
| 308 | + This is the special sauce that makes the smart StreamableHTTPTransport possible. |
| 309 | + _handle_post_request_task runs as a separate task and manages the handler execution. Once ready, it sends a |
| 310 | + MCPHTTPResponse via the task_status. If the handler calls the send callback, streaming is activated, |
| 311 | + else it acts like a regular HTTP transport. For streaming, _handle_post_request_task sends a MCPHTTPResponse |
| 312 | + with a MemoryObjectReceiveStream as the content and continues running in the background until |
| 313 | + the handler finishes executing. |
| 314 | +
|
| 315 | + Args: |
| 316 | + headers: HTTP request headers. |
| 317 | + body: HTTP request body as a string. |
| 318 | + scope: Optional message scope passed to the MiniMCP server. |
| 319 | + task_status: Task status object to communicate task readiness and result. |
| 320 | + """ |
| 321 | + |
| 322 | + stream_manager = StreamManager(on_create=task_status.started) |
| 323 | + not_completed = True |
| 324 | + |
| 325 | + try: |
| 326 | + result = await self._handle_post_request(headers, body, scope, send_callback=stream_manager.create_and_send) |
| 327 | + |
| 328 | + if stream_manager.is_streaming(): |
| 329 | + if result.content: |
| 330 | + await stream_manager.send(result.content) |
| 331 | + else: |
| 332 | + task_status.started(result) |
| 333 | + |
| 334 | + not_completed = False |
| 335 | + finally: |
| 336 | + if stream_manager.is_streaming(): |
| 337 | + await stream_manager.close(self._receive_close_delay) |
| 338 | + elif not_completed: |
| 339 | + # This should never happen, _handle_post_request should handle all exceptions, |
| 340 | + # but adding this fallback to ensure the task is always started. |
| 341 | + try: |
| 342 | + error = MCPRuntimeError("Task was not completed by StreamableHTTPTransport") |
| 343 | + task_status.started(self._build_error_response(error, body)) |
| 344 | + except RuntimeError as e: |
| 345 | + logger.error("Task is not completed: %s", e) |
0 commit comments