1111import logging
1212from collections .abc import AsyncGenerator
1313from contextlib import asynccontextmanager
14+ from http import HTTPStatus
1415from typing import Any
1516
1617import anyio
3334# Maximum size for incoming messages
3435MAXIMUM_MESSAGE_SIZE = 4 * 1024 * 1024 # 4MB
3536
37+ # Header names
38+ MCP_SESSION_ID_HEADER = "mcp-session-id"
39+ LAST_EVENT_ID_HEADER = "last-event-id"
40+
41+ # Content types
42+ CONTENT_TYPE_JSON = "application/json"
43+ CONTENT_TYPE_SSE = "text/event-stream"
44+
3645
3746class StreamableHTTPServerTransport :
3847 """
@@ -61,6 +70,34 @@ def __init__(
6170 self .mcp_session_id = mcp_session_id
6271 self ._request_streams = {}
6372
73+ def _create_error_response (
74+ self ,
75+ message : str ,
76+ status_code : HTTPStatus ,
77+ headers : dict [str , str ] | None = None ,
78+ ) -> Response :
79+ """
80+ Create a standardized error response.
81+ """
82+ response_headers = {"Content-Type" : CONTENT_TYPE_JSON }
83+ if headers :
84+ response_headers .update (headers )
85+
86+ if self .mcp_session_id :
87+ response_headers [MCP_SESSION_ID_HEADER ] = self .mcp_session_id
88+
89+ return Response (
90+ message ,
91+ status_code = status_code ,
92+ headers = response_headers ,
93+ )
94+
95+ def _get_session_id (self , request : Request ) -> str | None :
96+ """
97+ Extract the session ID from request headers.
98+ """
99+ return request .headers .get (MCP_SESSION_ID_HEADER )
100+
64101 async def handle_request (self , scope : Scope , receive : Receive , send : Send ) -> None :
65102 """
66103 ASGI application entry point that handles all HTTP requests
@@ -80,7 +117,46 @@ async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> No
80117 elif request .method == "DELETE" :
81118 await self ._handle_delete_request (request , send )
82119 else :
83- await self ._handle_unsupported_request (send )
120+ await self ._handle_unsupported_request (request , send )
121+
122+ def _check_accept_headers (self , request : Request ) -> tuple [bool , bool ]:
123+ """
124+ Check if the request accepts the required media types.
125+
126+ Args:
127+ request: The HTTP request
128+
129+ Returns:
130+ Tuple of (has_json, has_sse) indicating whether each media type is accepted
131+ """
132+ accept_header = request .headers .get ("accept" , "" )
133+ accept_types = [media_type .strip () for media_type in accept_header .split ("," )]
134+
135+ has_json = any (
136+ media_type .startswith (CONTENT_TYPE_JSON ) for media_type in accept_types
137+ )
138+ has_sse = any (
139+ media_type .startswith (CONTENT_TYPE_SSE ) for media_type in accept_types
140+ )
141+
142+ return has_json , has_sse
143+
144+ def _check_content_type (self , request : Request ) -> bool :
145+ """
146+ Check if the request has the correct Content-Type.
147+
148+ Args:
149+ request: The HTTP request
150+
151+ Returns:
152+ True if Content-Type is acceptable, False otherwise
153+ """
154+ content_type = request .headers .get ("content-type" , "" )
155+ content_type_parts = [
156+ part .strip () for part in content_type .split (";" )[0 ].split ("," )
157+ ]
158+
159+ return any (part == CONTENT_TYPE_JSON for part in content_type_parts )
84160
85161 async def _handle_post_request (
86162 self , scope : Scope , request : Request , receive : Receive , send : Send
@@ -89,85 +165,75 @@ async def _handle_post_request(
89165 Handles POST requests containing JSON-RPC messages
90166
91167 Args:
92- stream_id: Unique identifier for this stream
93168 scope: ASGI scope
94169 request: Starlette Request object
95170 receive: ASGI receive function
96171 send: ASGI send function
97172 """
98- body = await request .body ()
99173 writer = self ._read_stream_writer
100174 if writer is None :
101175 raise ValueError (
102176 "No read stream writer available. Ensure connect() is called first."
103177 )
104178 return
105179 try :
106- # Validate Accept header
107- accept_header = request .headers .get ("accept" , "" )
108- if (
109- "application/json" not in accept_header
110- or "text/event-stream" not in accept_header
111- ):
112- response = Response (
180+ # Check Accept headers
181+ has_json , has_sse = self ._check_accept_headers (request )
182+ if not (has_json and has_sse ):
183+ response = self ._create_error_response (
113184 (
114185 "Not Acceptable: Client must accept both application/json and "
115186 "text/event-stream"
116187 ),
117- status_code = 406 ,
118- headers = {"Content-Type" : "application/json" },
188+ HTTPStatus .NOT_ACCEPTABLE ,
119189 )
120190 await response (scope , receive , send )
121191 return
122192
123193 # Validate Content-Type
124- content_type = request .headers .get ("content-type" , "" )
125- if "application/json" not in content_type :
126- response = Response (
194+ if not self ._check_content_type (request ):
195+ response = self ._create_error_response (
127196 "Unsupported Media Type: Content-Type must be application/json" ,
128- status_code = 415 ,
129- headers = {"Content-Type" : "application/json" },
197+ HTTPStatus .UNSUPPORTED_MEDIA_TYPE ,
130198 )
131199 await response (scope , receive , send )
132200 return
133201
134- # Parse the body
202+ # Parse the body - only read it once
135203 body = await request .body ()
136204 if len (body ) > MAXIMUM_MESSAGE_SIZE :
137- response = Response (
205+ response = self . _create_error_response (
138206 "Payload Too Large: Message exceeds maximum size" ,
139- status_code = 413 ,
140- headers = {"Content-Type" : "application/json" },
207+ HTTPStatus .REQUEST_ENTITY_TOO_LARGE ,
141208 )
142209 await response (scope , receive , send )
143210 return
144211
145212 try :
146213 raw_message = json .loads (body )
147214 except json .JSONDecodeError as e :
148- response = Response (
215+ response = self . _create_error_response (
149216 f"Parse error: { str (e )} " ,
150- status_code = 400 ,
151- headers = {"Content-Type" : "application/json" },
217+ HTTPStatus .BAD_REQUEST ,
152218 )
153219 await response (scope , receive , send )
154220 return
221+
155222 message = None
156223 try :
157224 message = JSONRPCMessage .model_validate (raw_message )
158225 except ValidationError as e :
159- response = Response (
226+ response = self . _create_error_response (
160227 f"Validation error: { str (e )} " ,
161- status_code = 400 ,
162- headers = {"Content-Type" : "application/json" },
228+ HTTPStatus .BAD_REQUEST ,
163229 )
164230 await response (scope , receive , send )
165231 return
232+
166233 if not message :
167- response = Response (
234+ response = self . _create_error_response (
168235 "Invalid Request: Message is empty" ,
169- status_code = 400 ,
170- headers = {"Content-Type" : "application/json" },
236+ HTTPStatus .BAD_REQUEST ,
171237 )
172238 await response (scope , receive , send )
173239 return
@@ -179,8 +245,19 @@ async def _handle_post_request(
179245 )
180246
181247 if is_initialization_request :
182- # TODO validate
183- logger .info ("INITIALIZATION REQUEST" )
248+ # Check if the server already has an established session
249+ if self .mcp_session_id :
250+ # Check if request has a session ID
251+ request_session_id = self ._get_session_id (request )
252+
253+ # If request has a session ID but doesn't match, return 404
254+ if request_session_id and request_session_id != self .mcp_session_id :
255+ response = self ._create_error_response (
256+ "Not Found: Invalid or expired session ID" ,
257+ HTTPStatus .NOT_FOUND ,
258+ )
259+ await response (scope , receive , send )
260+ return
184261 # For non-initialization requests, validate the session
185262 elif not await self ._validate_session (request , send ):
186263 return
@@ -189,12 +266,11 @@ async def _handle_post_request(
189266
190267 # For notifications and responses only, return 202 Accepted
191268 if not is_request :
192- headers : dict [str , str ] = {}
193- if self .mcp_session_id :
194- headers ["mcp-session-id" ] = self .mcp_session_id
195-
196269 # Create response object and send it
197- response = Response ("Accepted" , status_code = 202 , headers = headers )
270+ response = self ._create_error_response (
271+ "Accepted" ,
272+ HTTPStatus .ACCEPTED ,
273+ )
198274 await response (scope , receive , send )
199275
200276 # Process the message after sending the response
@@ -208,13 +284,11 @@ async def _handle_post_request(
208284 headers = {
209285 "Cache-Control" : "no-cache, no-transform" ,
210286 "Connection" : "keep-alive" ,
287+ "Content-Type" : CONTENT_TYPE_SSE ,
211288 }
212289
213290 if self .mcp_session_id :
214- headers ["mcp-session-id" ] = self .mcp_session_id
215-
216- # For SSE responses, set up SSE stream
217- headers ["Content-Type" ] = "text/event-stream"
291+ headers [MCP_SESSION_ID_HEADER ] = self .mcp_session_id
218292 # Create SSE stream
219293 sse_stream_writer , sse_stream_reader = (
220294 anyio .create_memory_object_stream [dict [str , Any ]](0 )
@@ -306,23 +380,125 @@ async def sse_writer():
306380
307381 except Exception as err :
308382 logger .exception ("Error handling POST request" )
309- response = Response (f"Error handling POST request: { err } " , status_code = 500 )
383+ response = self ._create_error_response (
384+ f"Error handling POST request: { err } " ,
385+ HTTPStatus .INTERNAL_SERVER_ERROR ,
386+ )
310387 await response (scope , receive , send )
311388 if writer :
312389 await writer .send (err )
313390 return
314391
315392 async def _handle_get_request (self , request : Request , send : Send ) -> None :
316- pass
393+ """
394+ Handle GET requests for SSE stream establishment
395+
396+ Args:
397+ request: The HTTP request
398+ send: ASGI send function
399+ """
400+ # Validate session ID if server has one
401+ if not await self ._validate_session (request , send ):
402+ return
403+
404+ # Validate Accept header - must include text/event-stream
405+ _ , has_sse = self ._check_accept_headers (request )
406+
407+ if not has_sse :
408+ response = self ._create_error_response (
409+ "Not Acceptable: Client must accept text/event-stream" ,
410+ HTTPStatus .NOT_ACCEPTABLE ,
411+ )
412+ await response (request .scope , request .receive , send )
413+ return
414+
415+ # TODO: Implement SSE stream for GET requests
416+ # For now, return 501 Not Implemented
417+ response = self ._create_error_response (
418+ "SSE stream from GET request not implemented yet" ,
419+ HTTPStatus .NOT_IMPLEMENTED ,
420+ )
421+ await response (request .scope , request .receive , send )
317422
318423 async def _handle_delete_request (self , request : Request , send : Send ) -> None :
319- pass
424+ """
425+ Handle DELETE requests for explicit session termination
426+
427+ Args:
428+ request: The HTTP request
429+ send: ASGI send function
430+ """
431+ # Validate session ID
432+ if not self .mcp_session_id :
433+ # If no session ID set, return Method Not Allowed
434+ response = self ._create_error_response (
435+ "Method Not Allowed: Session termination not supported" ,
436+ HTTPStatus .METHOD_NOT_ALLOWED ,
437+ )
438+ await response (request .scope , request .receive , send )
439+ return
440+ if not await self ._validate_session (request , send ):
441+ return
442+ # TODO : Implement session termination logic
320443
321- async def _handle_unsupported_request (self , send : Send ) -> None :
322- pass
444+ async def _handle_unsupported_request (self , request : Request , send : Send ) -> None :
445+ """
446+ Handle unsupported HTTP methods
447+
448+ Args:
449+ request: The HTTP request
450+ send: ASGI send function
451+ """
452+ headers = {
453+ "Content-Type" : CONTENT_TYPE_JSON ,
454+ "Allow" : "GET, POST, DELETE" ,
455+ }
456+ if self .mcp_session_id :
457+ headers [MCP_SESSION_ID_HEADER ] = self .mcp_session_id
458+
459+ response = Response (
460+ "Method Not Allowed" ,
461+ status_code = HTTPStatus .METHOD_NOT_ALLOWED ,
462+ headers = headers ,
463+ )
464+ await response (request .scope , request .receive , send )
323465
324466 async def _validate_session (self , request : Request , send : Send ) -> bool :
325- # TODO
467+ """
468+ Validate the session ID in the request.
469+
470+ Args:
471+ request: The HTTP request
472+ send: ASGI send function
473+
474+ Returns:
475+ bool: True if session is valid, False otherwise
476+ """
477+ if not self .mcp_session_id :
478+ # If we're not using session IDs, return True
479+ return True
480+
481+ # Get the session ID from the request headers
482+ request_session_id = self ._get_session_id (request )
483+
484+ # If no session ID provided but required, return error
485+ if not request_session_id :
486+ response = self ._create_error_response (
487+ "Bad Request: Missing session ID" ,
488+ HTTPStatus .BAD_REQUEST ,
489+ )
490+ await response (request .scope , request .receive , send )
491+ return False
492+
493+ # If session ID doesn't match, return error
494+ if request_session_id != self .mcp_session_id :
495+ response = self ._create_error_response (
496+ "Not Found: Invalid or expired session ID" ,
497+ HTTPStatus .NOT_FOUND ,
498+ )
499+ await response (request .scope , request .receive , send )
500+ return False
501+
326502 return True
327503
328504 @asynccontextmanager
0 commit comments