Skip to content

Commit 3d790f8

Browse files
committed
add request validation and tests
1 parent 2b95598 commit 3d790f8

File tree

2 files changed

+601
-47
lines changed

2 files changed

+601
-47
lines changed

src/mcp/server/streamableHttp.py

Lines changed: 223 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import logging
1212
from collections.abc import AsyncGenerator
1313
from contextlib import asynccontextmanager
14+
from http import HTTPStatus
1415
from typing import Any
1516

1617
import anyio
@@ -33,6 +34,14 @@
3334
# Maximum size for incoming messages
3435
MAXIMUM_MESSAGE_SIZE = 4 * 1024 * 1024 # 4MB
3536

37+
# Header names
38+
MCP_SESSION_ID_HEADER = "mcp-session-id"
39+
LAST_EVENT_ID_HEADER = "last-event-id"
40+
41+
# Content types
42+
CONTENT_TYPE_JSON = "application/json"
43+
CONTENT_TYPE_SSE = "text/event-stream"
44+
3645

3746
class StreamableHTTPServerTransport:
3847
"""
@@ -61,6 +70,34 @@ def __init__(
6170
self.mcp_session_id = mcp_session_id
6271
self._request_streams = {}
6372

73+
def _create_error_response(
74+
self,
75+
message: str,
76+
status_code: HTTPStatus,
77+
headers: dict[str, str] | None = None,
78+
) -> Response:
79+
"""
80+
Create a standardized error response.
81+
"""
82+
response_headers = {"Content-Type": CONTENT_TYPE_JSON}
83+
if headers:
84+
response_headers.update(headers)
85+
86+
if self.mcp_session_id:
87+
response_headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id
88+
89+
return Response(
90+
message,
91+
status_code=status_code,
92+
headers=response_headers,
93+
)
94+
95+
def _get_session_id(self, request: Request) -> str | None:
96+
"""
97+
Extract the session ID from request headers.
98+
"""
99+
return request.headers.get(MCP_SESSION_ID_HEADER)
100+
64101
async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None:
65102
"""
66103
ASGI application entry point that handles all HTTP requests
@@ -80,7 +117,46 @@ async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> No
80117
elif request.method == "DELETE":
81118
await self._handle_delete_request(request, send)
82119
else:
83-
await self._handle_unsupported_request(send)
120+
await self._handle_unsupported_request(request, send)
121+
122+
def _check_accept_headers(self, request: Request) -> tuple[bool, bool]:
123+
"""
124+
Check if the request accepts the required media types.
125+
126+
Args:
127+
request: The HTTP request
128+
129+
Returns:
130+
Tuple of (has_json, has_sse) indicating whether each media type is accepted
131+
"""
132+
accept_header = request.headers.get("accept", "")
133+
accept_types = [media_type.strip() for media_type in accept_header.split(",")]
134+
135+
has_json = any(
136+
media_type.startswith(CONTENT_TYPE_JSON) for media_type in accept_types
137+
)
138+
has_sse = any(
139+
media_type.startswith(CONTENT_TYPE_SSE) for media_type in accept_types
140+
)
141+
142+
return has_json, has_sse
143+
144+
def _check_content_type(self, request: Request) -> bool:
145+
"""
146+
Check if the request has the correct Content-Type.
147+
148+
Args:
149+
request: The HTTP request
150+
151+
Returns:
152+
True if Content-Type is acceptable, False otherwise
153+
"""
154+
content_type = request.headers.get("content-type", "")
155+
content_type_parts = [
156+
part.strip() for part in content_type.split(";")[0].split(",")
157+
]
158+
159+
return any(part == CONTENT_TYPE_JSON for part in content_type_parts)
84160

85161
async def _handle_post_request(
86162
self, scope: Scope, request: Request, receive: Receive, send: Send
@@ -89,85 +165,75 @@ async def _handle_post_request(
89165
Handles POST requests containing JSON-RPC messages
90166
91167
Args:
92-
stream_id: Unique identifier for this stream
93168
scope: ASGI scope
94169
request: Starlette Request object
95170
receive: ASGI receive function
96171
send: ASGI send function
97172
"""
98-
body = await request.body()
99173
writer = self._read_stream_writer
100174
if writer is None:
101175
raise ValueError(
102176
"No read stream writer available. Ensure connect() is called first."
103177
)
104178
return
105179
try:
106-
# Validate Accept header
107-
accept_header = request.headers.get("accept", "")
108-
if (
109-
"application/json" not in accept_header
110-
or "text/event-stream" not in accept_header
111-
):
112-
response = Response(
180+
# Check Accept headers
181+
has_json, has_sse = self._check_accept_headers(request)
182+
if not (has_json and has_sse):
183+
response = self._create_error_response(
113184
(
114185
"Not Acceptable: Client must accept both application/json and "
115186
"text/event-stream"
116187
),
117-
status_code=406,
118-
headers={"Content-Type": "application/json"},
188+
HTTPStatus.NOT_ACCEPTABLE,
119189
)
120190
await response(scope, receive, send)
121191
return
122192

123193
# Validate Content-Type
124-
content_type = request.headers.get("content-type", "")
125-
if "application/json" not in content_type:
126-
response = Response(
194+
if not self._check_content_type(request):
195+
response = self._create_error_response(
127196
"Unsupported Media Type: Content-Type must be application/json",
128-
status_code=415,
129-
headers={"Content-Type": "application/json"},
197+
HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
130198
)
131199
await response(scope, receive, send)
132200
return
133201

134-
# Parse the body
202+
# Parse the body - only read it once
135203
body = await request.body()
136204
if len(body) > MAXIMUM_MESSAGE_SIZE:
137-
response = Response(
205+
response = self._create_error_response(
138206
"Payload Too Large: Message exceeds maximum size",
139-
status_code=413,
140-
headers={"Content-Type": "application/json"},
207+
HTTPStatus.REQUEST_ENTITY_TOO_LARGE,
141208
)
142209
await response(scope, receive, send)
143210
return
144211

145212
try:
146213
raw_message = json.loads(body)
147214
except json.JSONDecodeError as e:
148-
response = Response(
215+
response = self._create_error_response(
149216
f"Parse error: {str(e)}",
150-
status_code=400,
151-
headers={"Content-Type": "application/json"},
217+
HTTPStatus.BAD_REQUEST,
152218
)
153219
await response(scope, receive, send)
154220
return
221+
155222
message = None
156223
try:
157224
message = JSONRPCMessage.model_validate(raw_message)
158225
except ValidationError as e:
159-
response = Response(
226+
response = self._create_error_response(
160227
f"Validation error: {str(e)}",
161-
status_code=400,
162-
headers={"Content-Type": "application/json"},
228+
HTTPStatus.BAD_REQUEST,
163229
)
164230
await response(scope, receive, send)
165231
return
232+
166233
if not message:
167-
response = Response(
234+
response = self._create_error_response(
168235
"Invalid Request: Message is empty",
169-
status_code=400,
170-
headers={"Content-Type": "application/json"},
236+
HTTPStatus.BAD_REQUEST,
171237
)
172238
await response(scope, receive, send)
173239
return
@@ -179,8 +245,19 @@ async def _handle_post_request(
179245
)
180246

181247
if is_initialization_request:
182-
# TODO validate
183-
logger.info("INITIALIZATION REQUEST")
248+
# Check if the server already has an established session
249+
if self.mcp_session_id:
250+
# Check if request has a session ID
251+
request_session_id = self._get_session_id(request)
252+
253+
# If request has a session ID but doesn't match, return 404
254+
if request_session_id and request_session_id != self.mcp_session_id:
255+
response = self._create_error_response(
256+
"Not Found: Invalid or expired session ID",
257+
HTTPStatus.NOT_FOUND,
258+
)
259+
await response(scope, receive, send)
260+
return
184261
# For non-initialization requests, validate the session
185262
elif not await self._validate_session(request, send):
186263
return
@@ -189,12 +266,11 @@ async def _handle_post_request(
189266

190267
# For notifications and responses only, return 202 Accepted
191268
if not is_request:
192-
headers: dict[str, str] = {}
193-
if self.mcp_session_id:
194-
headers["mcp-session-id"] = self.mcp_session_id
195-
196269
# Create response object and send it
197-
response = Response("Accepted", status_code=202, headers=headers)
270+
response = self._create_error_response(
271+
"Accepted",
272+
HTTPStatus.ACCEPTED,
273+
)
198274
await response(scope, receive, send)
199275

200276
# Process the message after sending the response
@@ -208,13 +284,11 @@ async def _handle_post_request(
208284
headers = {
209285
"Cache-Control": "no-cache, no-transform",
210286
"Connection": "keep-alive",
287+
"Content-Type": CONTENT_TYPE_SSE,
211288
}
212289

213290
if self.mcp_session_id:
214-
headers["mcp-session-id"] = self.mcp_session_id
215-
216-
# For SSE responses, set up SSE stream
217-
headers["Content-Type"] = "text/event-stream"
291+
headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id
218292
# Create SSE stream
219293
sse_stream_writer, sse_stream_reader = (
220294
anyio.create_memory_object_stream[dict[str, Any]](0)
@@ -306,23 +380,125 @@ async def sse_writer():
306380

307381
except Exception as err:
308382
logger.exception("Error handling POST request")
309-
response = Response(f"Error handling POST request: {err}", status_code=500)
383+
response = self._create_error_response(
384+
f"Error handling POST request: {err}",
385+
HTTPStatus.INTERNAL_SERVER_ERROR,
386+
)
310387
await response(scope, receive, send)
311388
if writer:
312389
await writer.send(err)
313390
return
314391

315392
async def _handle_get_request(self, request: Request, send: Send) -> None:
316-
pass
393+
"""
394+
Handle GET requests for SSE stream establishment
395+
396+
Args:
397+
request: The HTTP request
398+
send: ASGI send function
399+
"""
400+
# Validate session ID if server has one
401+
if not await self._validate_session(request, send):
402+
return
403+
404+
# Validate Accept header - must include text/event-stream
405+
_, has_sse = self._check_accept_headers(request)
406+
407+
if not has_sse:
408+
response = self._create_error_response(
409+
"Not Acceptable: Client must accept text/event-stream",
410+
HTTPStatus.NOT_ACCEPTABLE,
411+
)
412+
await response(request.scope, request.receive, send)
413+
return
414+
415+
# TODO: Implement SSE stream for GET requests
416+
# For now, return 501 Not Implemented
417+
response = self._create_error_response(
418+
"SSE stream from GET request not implemented yet",
419+
HTTPStatus.NOT_IMPLEMENTED,
420+
)
421+
await response(request.scope, request.receive, send)
317422

318423
async def _handle_delete_request(self, request: Request, send: Send) -> None:
319-
pass
424+
"""
425+
Handle DELETE requests for explicit session termination
426+
427+
Args:
428+
request: The HTTP request
429+
send: ASGI send function
430+
"""
431+
# Validate session ID
432+
if not self.mcp_session_id:
433+
# If no session ID set, return Method Not Allowed
434+
response = self._create_error_response(
435+
"Method Not Allowed: Session termination not supported",
436+
HTTPStatus.METHOD_NOT_ALLOWED,
437+
)
438+
await response(request.scope, request.receive, send)
439+
return
440+
if not await self._validate_session(request, send):
441+
return
442+
# TODO : Implement session termination logic
320443

321-
async def _handle_unsupported_request(self, send: Send) -> None:
322-
pass
444+
async def _handle_unsupported_request(self, request: Request, send: Send) -> None:
445+
"""
446+
Handle unsupported HTTP methods
447+
448+
Args:
449+
request: The HTTP request
450+
send: ASGI send function
451+
"""
452+
headers = {
453+
"Content-Type": CONTENT_TYPE_JSON,
454+
"Allow": "GET, POST, DELETE",
455+
}
456+
if self.mcp_session_id:
457+
headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id
458+
459+
response = Response(
460+
"Method Not Allowed",
461+
status_code=HTTPStatus.METHOD_NOT_ALLOWED,
462+
headers=headers,
463+
)
464+
await response(request.scope, request.receive, send)
323465

324466
async def _validate_session(self, request: Request, send: Send) -> bool:
325-
# TODO
467+
"""
468+
Validate the session ID in the request.
469+
470+
Args:
471+
request: The HTTP request
472+
send: ASGI send function
473+
474+
Returns:
475+
bool: True if session is valid, False otherwise
476+
"""
477+
if not self.mcp_session_id:
478+
# If we're not using session IDs, return True
479+
return True
480+
481+
# Get the session ID from the request headers
482+
request_session_id = self._get_session_id(request)
483+
484+
# If no session ID provided but required, return error
485+
if not request_session_id:
486+
response = self._create_error_response(
487+
"Bad Request: Missing session ID",
488+
HTTPStatus.BAD_REQUEST,
489+
)
490+
await response(request.scope, request.receive, send)
491+
return False
492+
493+
# If session ID doesn't match, return error
494+
if request_session_id != self.mcp_session_id:
495+
response = self._create_error_response(
496+
"Not Found: Invalid or expired session ID",
497+
HTTPStatus.NOT_FOUND,
498+
)
499+
await response(request.scope, request.receive, send)
500+
return False
501+
326502
return True
327503

328504
@asynccontextmanager

0 commit comments

Comments
 (0)