|
1 | 1 | import contextlib |
2 | 2 | import logging |
| 3 | +from http import HTTPStatus |
3 | 4 | from uuid import uuid4 |
4 | 5 |
|
5 | 6 | import anyio |
6 | 7 | import click |
7 | 8 | import mcp.types as types |
8 | 9 | from mcp.server.lowlevel import Server |
9 | | -from mcp.server.streamableHttp import StreamableHTTPServerTransport |
| 10 | +from mcp.server.streamableHttp import ( |
| 11 | + MCP_SESSION_ID_HEADER, |
| 12 | + StreamableHTTPServerTransport, |
| 13 | +) |
10 | 14 | from starlette.applications import Starlette |
| 15 | +from starlette.requests import Request |
| 16 | +from starlette.responses import Response |
11 | 17 | from starlette.routing import Mount |
12 | 18 |
|
13 | 19 | # Configure logging |
@@ -116,40 +122,56 @@ async def list_tools() -> list[types.Tool]: |
116 | 122 | ) |
117 | 123 | ] |
118 | 124 |
|
119 | | - # Create a Streamable HTTP transport |
120 | | - http_transport = StreamableHTTPServerTransport( |
121 | | - mcp_session_id=uuid4().hex, |
122 | | - ) |
123 | | - |
124 | 125 | # We need to store the server instances between requests |
125 | 126 | server_instances = {} |
| 127 | + # Lock to prevent race conditions when creating new sessions |
| 128 | + session_creation_lock = anyio.Lock() |
126 | 129 |
|
127 | 130 | # ASGI handler for streamable HTTP connections |
128 | 131 | async def handle_streamable_http(scope, receive, send): |
129 | | - if http_transport.mcp_session_id in server_instances: |
| 132 | + request = Request(scope, receive) |
| 133 | + request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER) |
| 134 | + if ( |
| 135 | + request_mcp_session_id is not None |
| 136 | + and request_mcp_session_id in server_instances |
| 137 | + ): |
| 138 | + transport = server_instances[request_mcp_session_id] |
130 | 139 | logger.debug("Session already exists, handling request directly") |
131 | | - await http_transport.handle_request(scope, receive, send) |
| 140 | + await transport.handle_request(scope, receive, send) |
| 141 | + elif request_mcp_session_id is None: |
| 142 | + # try to establish new session |
| 143 | + logger.debug("Creating new transport") |
| 144 | + # Use lock to prevent race conditions when creating new sessions |
| 145 | + async with session_creation_lock: |
| 146 | + new_session_id = uuid4().hex |
| 147 | + http_transport = StreamableHTTPServerTransport( |
| 148 | + mcp_session_id=new_session_id, |
| 149 | + ) |
| 150 | + async with http_transport.connect() as streams: |
| 151 | + read_stream, write_stream = streams |
| 152 | + |
| 153 | + async def run_server(): |
| 154 | + await app.run( |
| 155 | + read_stream, |
| 156 | + write_stream, |
| 157 | + app.create_initialization_options(), |
| 158 | + ) |
| 159 | + |
| 160 | + if not task_group: |
| 161 | + raise RuntimeError("Task group is not initialized") |
| 162 | + |
| 163 | + # Store the instance before starting the task to prevent races |
| 164 | + server_instances[http_transport.mcp_session_id] = http_transport |
| 165 | + task_group.start_soon(run_server) |
| 166 | + |
| 167 | + # Handle the HTTP request and return the response |
| 168 | + await http_transport.handle_request(scope, receive, send) |
132 | 169 | else: |
133 | | - # Start new server instance for this session |
134 | | - async with http_transport.connect() as streams: |
135 | | - read_stream, write_stream = streams |
136 | | - |
137 | | - async def run_server(): |
138 | | - await app.run( |
139 | | - read_stream, write_stream, app.create_initialization_options() |
140 | | - ) |
141 | | - |
142 | | - if not task_group: |
143 | | - raise RuntimeError("Task group is not initialized") |
144 | | - |
145 | | - task_group.start_soon(run_server) |
146 | | - |
147 | | - # For initialization requests, store the server reference |
148 | | - if http_transport.mcp_session_id: |
149 | | - server_instances[http_transport.mcp_session_id] = True |
150 | | - |
151 | | - # Handle the HTTP request and return the response |
152 | | - await http_transport.handle_request(scope, receive, send) |
| 170 | + response = Response( |
| 171 | + "Bad Request: No valid session ID provided", |
| 172 | + status_code=HTTPStatus.BAD_REQUEST, |
| 173 | + ) |
| 174 | + await response(scope, receive, send) |
153 | 175 |
|
154 | 176 | # Create an ASGI application using the transport |
155 | 177 | starlette_app = Starlette( |
|
0 commit comments