11from __future__ import annotations
22
3+ from collections .abc import Awaitable , Callable
4+ from unittest .mock import AsyncMock
35from uuid import uuid4
46
57import anyio
68import pytest
9+ from pydantic import ValidationError
10+ from starlette .responses import Response
711from starlette .types import Message
812
913from mcp .server .sse import SseServerTransport
1014from mcp .shared .message import SessionMessage
1115
1216
17+ def make_receive (body : bytes ) -> Callable [[], Awaitable [Message ]]:
18+ async def receive () -> Message :
19+ return {"type" : "http.request" , "body" : body , "more_body" : False }
20+
21+ return receive
22+
23+
1324@pytest .mark .anyio
1425async def test_sse_max_body_bytes_rejects_large_request ():
1526 sse_transport = SseServerTransport ("/messages/" , max_body_bytes = 10 )
@@ -38,8 +49,7 @@ async def send(message: Message):
3849
3950 body = b'{"a":"' + (b"x" * 20 ) + b'"}'
4051
41- async def receive ():
42- return {"type" : "http.request" , "body" : body , "more_body" : False } # pragma: no cover
52+ receive = make_receive (body )
4353
4454 await sse_transport .handle_post_message (scope , receive , send )
4555
@@ -53,3 +63,209 @@ async def receive():
5363 finally :
5464 await writer .aclose ()
5565 await reader .aclose ()
66+
67+
68+ @pytest .mark .anyio
69+ async def test_sse_handle_post_message_short_circuits_on_security_error ():
70+ sse_transport = SseServerTransport ("/messages/" )
71+ sse_transport ._security .validate_request = AsyncMock (return_value = Response ("blocked" , status_code = 403 )) # type: ignore[method-assign]
72+
73+ sent_messages : list [Message ] = []
74+ response_body = b""
75+
76+ async def send (message : Message ):
77+ nonlocal response_body
78+ sent_messages .append (message )
79+ if message ["type" ] == "http.response.body" :
80+ response_body += message .get ("body" , b"" )
81+
82+ scope = {
83+ "type" : "http" ,
84+ "method" : "POST" ,
85+ "path" : "/messages/" ,
86+ "query_string" : b"" ,
87+ "headers" : [(b"content-type" , b"application/json" )],
88+ }
89+ receive = make_receive (b"{}" )
90+
91+ await sse_transport .handle_post_message (scope , receive , send )
92+
93+ response_start = next ((msg for msg in sent_messages if msg ["type" ] == "http.response.start" ), None )
94+ assert response_start is not None , "Should have sent a response"
95+ assert response_start ["status" ] == 403
96+ assert response_body == b"blocked"
97+
98+
99+ @pytest .mark .anyio
100+ async def test_sse_handle_post_message_returns_400_when_session_id_missing ():
101+ sse_transport = SseServerTransport ("/messages/" )
102+
103+ sent_messages : list [Message ] = []
104+ response_body = b""
105+
106+ async def send (message : Message ):
107+ nonlocal response_body
108+ sent_messages .append (message )
109+ if message ["type" ] == "http.response.body" :
110+ response_body += message .get ("body" , b"" )
111+
112+ scope = {
113+ "type" : "http" ,
114+ "method" : "POST" ,
115+ "path" : "/messages/" ,
116+ "query_string" : b"" ,
117+ "headers" : [(b"content-type" , b"application/json" )],
118+ }
119+ receive = make_receive (b"{}" )
120+
121+ await sse_transport .handle_post_message (scope , receive , send )
122+
123+ response_start = next ((msg for msg in sent_messages if msg ["type" ] == "http.response.start" ), None )
124+ assert response_start is not None , "Should have sent a response"
125+ assert response_start ["status" ] == 400
126+ assert response_body == b"session_id is required"
127+
128+
129+ @pytest .mark .anyio
130+ async def test_sse_handle_post_message_returns_400_when_session_id_invalid ():
131+ sse_transport = SseServerTransport ("/messages/" )
132+
133+ sent_messages : list [Message ] = []
134+ response_body = b""
135+
136+ async def send (message : Message ):
137+ nonlocal response_body
138+ sent_messages .append (message )
139+ if message ["type" ] == "http.response.body" :
140+ response_body += message .get ("body" , b"" )
141+
142+ scope = {
143+ "type" : "http" ,
144+ "method" : "POST" ,
145+ "path" : "/messages/" ,
146+ "query_string" : b"session_id=not-a-uuid" ,
147+ "headers" : [(b"content-type" , b"application/json" )],
148+ }
149+ receive = make_receive (b"{}" )
150+
151+ await sse_transport .handle_post_message (scope , receive , send )
152+
153+ response_start = next ((msg for msg in sent_messages if msg ["type" ] == "http.response.start" ), None )
154+ assert response_start is not None , "Should have sent a response"
155+ assert response_start ["status" ] == 400
156+ assert response_body == b"Invalid session ID"
157+
158+
159+ @pytest .mark .anyio
160+ async def test_sse_handle_post_message_returns_404_when_session_not_found ():
161+ sse_transport = SseServerTransport ("/messages/" )
162+
163+ sent_messages : list [Message ] = []
164+ response_body = b""
165+
166+ async def send (message : Message ):
167+ nonlocal response_body
168+ sent_messages .append (message )
169+ if message ["type" ] == "http.response.body" :
170+ response_body += message .get ("body" , b"" )
171+
172+ scope = {
173+ "type" : "http" ,
174+ "method" : "POST" ,
175+ "path" : "/messages/" ,
176+ "query_string" : f"session_id={ uuid4 ().hex } " .encode (),
177+ "headers" : [(b"content-type" , b"application/json" )],
178+ }
179+ receive = make_receive (b"{}" )
180+
181+ await sse_transport .handle_post_message (scope , receive , send )
182+
183+ response_start = next ((msg for msg in sent_messages if msg ["type" ] == "http.response.start" ), None )
184+ assert response_start is not None , "Should have sent a response"
185+ assert response_start ["status" ] == 404
186+ assert response_body == b"Could not find session"
187+
188+
189+ @pytest .mark .anyio
190+ async def test_sse_handle_post_message_returns_400_and_sends_error_on_invalid_jsonrpc ():
191+ sse_transport = SseServerTransport ("/messages/" , max_body_bytes = 1_000 )
192+
193+ session_id = uuid4 ()
194+ writer , reader = anyio .create_memory_object_stream [SessionMessage | Exception ](1 )
195+ try :
196+ sse_transport ._read_stream_writers [session_id ] = writer
197+
198+ sent_messages : list [Message ] = []
199+ response_body = b""
200+
201+ async def send (message : Message ):
202+ nonlocal response_body
203+ sent_messages .append (message )
204+ if message ["type" ] == "http.response.body" :
205+ response_body += message .get ("body" , b"" )
206+
207+ scope = {
208+ "type" : "http" ,
209+ "method" : "POST" ,
210+ "path" : "/messages/" ,
211+ "query_string" : f"session_id={ session_id .hex } " .encode (),
212+ "headers" : [(b"content-type" , b"application/json" )],
213+ }
214+ receive = make_receive (b"{}" )
215+
216+ await sse_transport .handle_post_message (scope , receive , send )
217+
218+ response_start = next ((msg for msg in sent_messages if msg ["type" ] == "http.response.start" ), None )
219+ assert response_start is not None , "Should have sent a response"
220+ assert response_start ["status" ] == 400
221+ assert response_body == b"Could not parse message"
222+
223+ err = await reader .receive ()
224+ assert isinstance (err , ValidationError )
225+ finally :
226+ await writer .aclose ()
227+ await reader .aclose ()
228+
229+
230+ @pytest .mark .anyio
231+ async def test_sse_handle_post_message_accepts_valid_jsonrpc_and_sends_session_message ():
232+ sse_transport = SseServerTransport ("/messages/" , max_body_bytes = 1_000 )
233+
234+ session_id = uuid4 ()
235+ writer , reader = anyio .create_memory_object_stream [SessionMessage | Exception ](1 )
236+ try :
237+ sse_transport ._read_stream_writers [session_id ] = writer
238+
239+ sent_messages : list [Message ] = []
240+ response_body = b""
241+
242+ async def send (message : Message ):
243+ nonlocal response_body
244+ sent_messages .append (message )
245+ if message ["type" ] == "http.response.body" :
246+ response_body += message .get ("body" , b"" )
247+
248+ scope = {
249+ "type" : "http" ,
250+ "method" : "POST" ,
251+ "path" : "/messages/" ,
252+ "query_string" : f"session_id={ session_id .hex } " .encode (),
253+ "headers" : [(b"content-type" , b"application/json" )],
254+ }
255+
256+ body = b'{"jsonrpc":"2.0","id":1,"method":"initialize","params":{}}'
257+ receive = make_receive (body )
258+
259+ await sse_transport .handle_post_message (scope , receive , send )
260+
261+ response_start = next ((msg for msg in sent_messages if msg ["type" ] == "http.response.start" ), None )
262+ assert response_start is not None , "Should have sent a response"
263+ assert response_start ["status" ] == 202
264+ assert response_body == b"Accepted"
265+
266+ session_message = await reader .receive ()
267+ assert isinstance (session_message , SessionMessage )
268+ assert getattr (session_message .message , "method" , None ) == "initialize"
269+ finally :
270+ await writer .aclose ()
271+ await reader .aclose ()
0 commit comments