Skip to content

Commit 27bc01e

Browse files
committed
session management
1 parent 3d790f8 commit 27bc01e

File tree

3 files changed

+128
-31
lines changed

3 files changed

+128
-31
lines changed

examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py

Lines changed: 50 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,19 @@
11
import contextlib
22
import logging
3+
from http import HTTPStatus
34
from uuid import uuid4
45

56
import anyio
67
import click
78
import mcp.types as types
89
from mcp.server.lowlevel import Server
9-
from mcp.server.streamableHttp import StreamableHTTPServerTransport
10+
from mcp.server.streamableHttp import (
11+
MCP_SESSION_ID_HEADER,
12+
StreamableHTTPServerTransport,
13+
)
1014
from starlette.applications import Starlette
15+
from starlette.requests import Request
16+
from starlette.responses import Response
1117
from starlette.routing import Mount
1218

1319
# Configure logging
@@ -116,40 +122,56 @@ async def list_tools() -> list[types.Tool]:
116122
)
117123
]
118124

119-
# Create a Streamable HTTP transport
120-
http_transport = StreamableHTTPServerTransport(
121-
mcp_session_id=uuid4().hex,
122-
)
123-
124125
# We need to store the server instances between requests
125126
server_instances = {}
127+
# Lock to prevent race conditions when creating new sessions
128+
session_creation_lock = anyio.Lock()
126129

127130
# ASGI handler for streamable HTTP connections
128131
async def handle_streamable_http(scope, receive, send):
129-
if http_transport.mcp_session_id in server_instances:
132+
request = Request(scope, receive)
133+
request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER)
134+
if (
135+
request_mcp_session_id is not None
136+
and request_mcp_session_id in server_instances
137+
):
138+
transport = server_instances[request_mcp_session_id]
130139
logger.debug("Session already exists, handling request directly")
131-
await http_transport.handle_request(scope, receive, send)
140+
await transport.handle_request(scope, receive, send)
141+
elif request_mcp_session_id is None:
142+
# try to establish new session
143+
logger.debug("Creating new transport")
144+
# Use lock to prevent race conditions when creating new sessions
145+
async with session_creation_lock:
146+
new_session_id = uuid4().hex
147+
http_transport = StreamableHTTPServerTransport(
148+
mcp_session_id=new_session_id,
149+
)
150+
async with http_transport.connect() as streams:
151+
read_stream, write_stream = streams
152+
153+
async def run_server():
154+
await app.run(
155+
read_stream,
156+
write_stream,
157+
app.create_initialization_options(),
158+
)
159+
160+
if not task_group:
161+
raise RuntimeError("Task group is not initialized")
162+
163+
# Store the instance before starting the task to prevent races
164+
server_instances[http_transport.mcp_session_id] = http_transport
165+
task_group.start_soon(run_server)
166+
167+
# Handle the HTTP request and return the response
168+
await http_transport.handle_request(scope, receive, send)
132169
else:
133-
# Start new server instance for this session
134-
async with http_transport.connect() as streams:
135-
read_stream, write_stream = streams
136-
137-
async def run_server():
138-
await app.run(
139-
read_stream, write_stream, app.create_initialization_options()
140-
)
141-
142-
if not task_group:
143-
raise RuntimeError("Task group is not initialized")
144-
145-
task_group.start_soon(run_server)
146-
147-
# For initialization requests, store the server reference
148-
if http_transport.mcp_session_id:
149-
server_instances[http_transport.mcp_session_id] = True
150-
151-
# Handle the HTTP request and return the response
152-
await http_transport.handle_request(scope, receive, send)
170+
response = Response(
171+
"Bad Request: No valid session ID provided",
172+
status_code=HTTPStatus.BAD_REQUEST,
173+
)
174+
await response(scope, receive, send)
153175

154176
# Create an ASGI application using the transport
155177
starlette_app = Starlette(

src/mcp/server/streamableHttp.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import json
1111
import logging
12+
import re
1213
from collections.abc import AsyncGenerator
1314
from contextlib import asynccontextmanager
1415
from http import HTTPStatus
@@ -42,6 +43,10 @@
4243
CONTENT_TYPE_JSON = "application/json"
4344
CONTENT_TYPE_SSE = "text/event-stream"
4445

46+
# Session ID validation pattern (visible ASCII characters ranging from 0x21 to 0x7E)
47+
# Pattern ensures entire string contains only valid characters by using ^ and $ anchors
48+
SESSION_ID_PATTERN = re.compile(r"^[\x21-\x7E]+$")
49+
4550

4651
class StreamableHTTPServerTransport:
4752
"""
@@ -65,8 +70,20 @@ def __init__(
6570
Initialize a new StreamableHTTP server transport.
6671
6772
Args:
68-
mcp_session_id: Optional session identifier for this connection
73+
mcp_session_id: Optional session identifier for this connection.
74+
Must contain only visible ASCII characters (0x21-0x7E).
75+
76+
Raises:
77+
ValueError: If the session ID contains invalid characters.
6978
"""
79+
if mcp_session_id is not None and (
80+
not SESSION_ID_PATTERN.match(mcp_session_id) or
81+
SESSION_ID_PATTERN.fullmatch(mcp_session_id) is None
82+
):
83+
raise ValueError(
84+
"Session ID must only contain visible ASCII characters (0x21-0x7E)"
85+
)
86+
7087
self.mcp_session_id = mcp_session_id
7188
self._request_streams = {}
7289

@@ -439,7 +456,7 @@ async def _handle_delete_request(self, request: Request, send: Send) -> None:
439456
return
440457
if not await self._validate_session(request, send):
441458
return
442-
# TODO : Implement session termination logic
459+
# TODO : Implement session termination logic
443460

444461
async def _handle_unsupported_request(self, request: Request, send: Send) -> None:
445462
"""

tests/server/test_streamableHttp.py

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,14 @@
1313
import pytest
1414
import requests
1515
import uvicorn
16-
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
1716
from starlette.applications import Starlette
1817
from starlette.requests import Request
1918
from starlette.responses import Response
2019
from starlette.routing import Route
2120

2221
from mcp.server.streamableHttp import (
2322
MCP_SESSION_ID_HEADER,
23+
SESSION_ID_PATTERN,
2424
StreamableHTTPServerTransport,
2525
)
2626
from mcp.types import JSONRPCMessage
@@ -376,3 +376,61 @@ def test_delete_without_session_support(basic_server, server_url):
376376
response = requests.delete(f"{server_url}/mcp")
377377
assert response.status_code == 405
378378
assert "Method Not Allowed" in response.text
379+
380+
381+
def test_session_id_pattern():
382+
"""Test that SESSION_ID_PATTERN correctly validates session IDs."""
383+
# Valid session IDs (visible ASCII characters from 0x21 to 0x7E)
384+
valid_session_ids = [
385+
"test-session-id",
386+
"1234567890",
387+
"session!@#$%^&*()_+-=[]{}|;:,.<>?/",
388+
"~`",
389+
]
390+
391+
for session_id in valid_session_ids:
392+
assert SESSION_ID_PATTERN.match(session_id) is not None
393+
# Ensure fullmatch matches too (whole string)
394+
assert SESSION_ID_PATTERN.fullmatch(session_id) is not None
395+
396+
# Invalid session IDs
397+
invalid_session_ids = [
398+
"", # Empty string
399+
" test", # Space (0x20)
400+
"test\t", # Tab
401+
"test\n", # Newline
402+
"test\r", # Carriage return
403+
"test" + chr(0x7F), # DEL character
404+
"test" + chr(0x80), # Extended ASCII
405+
"test" + chr(0x00), # Null character
406+
"test" + chr(0x20), # Space (0x20)
407+
]
408+
409+
for session_id in invalid_session_ids:
410+
# For invalid IDs, either match will fail or fullmatch will fail
411+
if SESSION_ID_PATTERN.match(session_id) is not None:
412+
# If match succeeds, fullmatch should fail (partial match case)
413+
assert SESSION_ID_PATTERN.fullmatch(session_id) is None
414+
415+
416+
def test_streamable_http_transport_init_validation():
417+
"""Test that StreamableHTTPServerTransport validates session ID on initialization."""
418+
# Valid session ID should initialize without errors
419+
valid_transport = StreamableHTTPServerTransport(mcp_session_id="valid-id")
420+
assert valid_transport.mcp_session_id == "valid-id"
421+
422+
# None should be accepted
423+
none_transport = StreamableHTTPServerTransport(mcp_session_id=None)
424+
assert none_transport.mcp_session_id is None
425+
426+
# Invalid session ID should raise ValueError
427+
with pytest.raises(ValueError) as excinfo:
428+
StreamableHTTPServerTransport(mcp_session_id="invalid id with space")
429+
assert "Session ID must only contain visible ASCII characters" in str(excinfo.value)
430+
431+
# Test with control characters
432+
with pytest.raises(ValueError):
433+
StreamableHTTPServerTransport(mcp_session_id="test\nid")
434+
435+
with pytest.raises(ValueError):
436+
StreamableHTTPServerTransport(mcp_session_id="test\n")

0 commit comments

Comments
 (0)