@@ -46,22 +46,16 @@ async def on_list_tools(self) -> list[Tool]:
4646 return []
4747
4848
49- def run_server_with_settings (
50- port : int , security_settings : TransportSecuritySettings | None = None
51- ):
49+ def run_server_with_settings (port : int , security_settings : TransportSecuritySettings | None = None ):
5250 """Run the SSE server with specified security settings."""
5351 app = SecurityTestServer ()
5452 sse_transport = SseServerTransport ("/messages/" , security_settings )
5553
5654 async def handle_sse (request : Request ):
5755 try :
58- async with sse_transport .connect_sse (
59- request .scope , request .receive , request ._send
60- ) as streams :
56+ async with sse_transport .connect_sse (request .scope , request .receive , request ._send ) as streams :
6157 if streams :
62- await app .run (
63- streams [0 ], streams [1 ], app .create_initialization_options ()
64- )
58+ await app .run (streams [0 ], streams [1 ], app .create_initialization_options ())
6559 except ValueError as e :
6660 # Validation error was already handled inside connect_sse
6761 logger .debug (f"SSE connection failed validation: { e } " )
@@ -76,13 +70,9 @@ async def handle_sse(request: Request):
7670 uvicorn .run (starlette_app , host = "127.0.0.1" , port = port , log_level = "error" )
7771
7872
79- def start_server_process (
80- port : int , security_settings : TransportSecuritySettings | None = None
81- ):
73+ def start_server_process (port : int , security_settings : TransportSecuritySettings | None = None ):
8274 """Start server in a separate process."""
83- process = multiprocessing .Process (
84- target = run_server_with_settings , args = (port , security_settings )
85- )
75+ process = multiprocessing .Process (target = run_server_with_settings , args = (port , security_settings ))
8676 process .start ()
8777 # Give server time to start
8878 time .sleep (1 )
@@ -98,9 +88,7 @@ async def test_sse_security_default_settings(server_port: int):
9888 headers = {"Host" : "evil.com" , "Origin" : "http://evil.com" }
9989
10090 async with httpx .AsyncClient (timeout = 5.0 ) as client :
101- async with client .stream (
102- "GET" , f"http://127.0.0.1:{ server_port } /sse" , headers = headers
103- ) as response :
91+ async with client .stream ("GET" , f"http://127.0.0.1:{ server_port } /sse" , headers = headers ) as response :
10492 assert response .status_code == 200
10593 finally :
10694 process .terminate ()
@@ -111,19 +99,15 @@ async def test_sse_security_default_settings(server_port: int):
11199async def test_sse_security_invalid_host_header (server_port : int ):
112100 """Test SSE with invalid Host header."""
113101 # Enable security by providing settings with an empty allowed_hosts list
114- security_settings = TransportSecuritySettings (
115- enable_dns_rebinding_protection = True , allowed_hosts = ["example.com" ]
116- )
102+ security_settings = TransportSecuritySettings (enable_dns_rebinding_protection = True , allowed_hosts = ["example.com" ])
117103 process = start_server_process (server_port , security_settings )
118104
119105 try :
120106 # Test with invalid host header
121107 headers = {"Host" : "evil.com" }
122108
123109 async with httpx .AsyncClient () as client :
124- response = await client .get (
125- f"http://127.0.0.1:{ server_port } /sse" , headers = headers
126- )
110+ response = await client .get (f"http://127.0.0.1:{ server_port } /sse" , headers = headers )
127111 assert response .status_code == 421
128112 assert response .text == "Invalid Host header"
129113
@@ -148,9 +132,7 @@ async def test_sse_security_invalid_origin_header(server_port: int):
148132 headers = {"Origin" : "http://evil.com" }
149133
150134 async with httpx .AsyncClient () as client :
151- response = await client .get (
152- f"http://127.0.0.1:{ server_port } /sse" , headers = headers
153- )
135+ response = await client .get (f"http://127.0.0.1:{ server_port } /sse" , headers = headers )
154136 assert response .status_code == 400
155137 assert response .text == "Invalid Origin header"
156138
@@ -207,9 +189,7 @@ async def test_sse_security_disabled(server_port: int):
207189
208190 async with httpx .AsyncClient (timeout = 5.0 ) as client :
209191 # For SSE endpoints, we need to use stream to avoid timeout
210- async with client .stream (
211- "GET" , f"http://127.0.0.1:{ server_port } /sse" , headers = headers
212- ) as response :
192+ async with client .stream ("GET" , f"http://127.0.0.1:{ server_port } /sse" , headers = headers ) as response :
213193 # Should connect successfully even with invalid host
214194 assert response .status_code == 200
215195
@@ -234,19 +214,15 @@ async def test_sse_security_custom_allowed_hosts(server_port: int):
234214
235215 async with httpx .AsyncClient (timeout = 5.0 ) as client :
236216 # For SSE endpoints, we need to use stream to avoid timeout
237- async with client .stream (
238- "GET" , f"http://127.0.0.1:{ server_port } /sse" , headers = headers
239- ) as response :
217+ async with client .stream ("GET" , f"http://127.0.0.1:{ server_port } /sse" , headers = headers ) as response :
240218 # Should connect successfully with custom host
241219 assert response .status_code == 200
242220
243221 # Test with non-allowed host
244222 headers = {"Host" : "evil.com" }
245223
246224 async with httpx .AsyncClient () as client :
247- response = await client .get (
248- f"http://127.0.0.1:{ server_port } /sse" , headers = headers
249- )
225+ response = await client .get (f"http://127.0.0.1:{ server_port } /sse" , headers = headers )
250226 assert response .status_code == 421
251227 assert response .text == "Invalid Host header"
252228
@@ -272,19 +248,15 @@ async def test_sse_security_wildcard_ports(server_port: int):
272248
273249 async with httpx .AsyncClient (timeout = 5.0 ) as client :
274250 # For SSE endpoints, we need to use stream to avoid timeout
275- async with client .stream (
276- "GET" , f"http://127.0.0.1:{ server_port } /sse" , headers = headers
277- ) as response :
251+ async with client .stream ("GET" , f"http://127.0.0.1:{ server_port } /sse" , headers = headers ) as response :
278252 # Should connect successfully with any port
279253 assert response .status_code == 200
280254
281255 headers = {"Origin" : f"http://localhost:{ test_port } " }
282256
283257 async with httpx .AsyncClient (timeout = 5.0 ) as client :
284258 # For SSE endpoints, we need to use stream to avoid timeout
285- async with client .stream (
286- "GET" , f"http://127.0.0.1:{ server_port } /sse" , headers = headers
287- ) as response :
259+ async with client .stream ("GET" , f"http://127.0.0.1:{ server_port } /sse" , headers = headers ) as response :
288260 # Should connect successfully with any port
289261 assert response .status_code == 200
290262
0 commit comments