Skip to content

Commit f505572

Browse files
committed
style: Apply ruff formatting to integration test changes
1 parent d0ec057 commit f505572

File tree

4 files changed

+74
-233
lines changed

4 files changed

+74
-233
lines changed

tests/client/test_stdio.py

Lines changed: 14 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -66,20 +66,14 @@ async def test_stdio_client():
6666
break
6767

6868
assert len(read_messages) == 2
69-
assert read_messages[0] == JSONRPCMessage(
70-
root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")
71-
)
72-
assert read_messages[1] == JSONRPCMessage(
73-
root=JSONRPCResponse(jsonrpc="2.0", id=2, result={})
74-
)
69+
assert read_messages[0] == JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping"))
70+
assert read_messages[1] == JSONRPCMessage(root=JSONRPCResponse(jsonrpc="2.0", id=2, result={}))
7571

7672

7773
@pytest.mark.anyio
7874
async def test_stdio_client_bad_path():
7975
"""Check that the connection doesn't hang if process errors."""
80-
server_params = StdioServerParameters(
81-
command="python", args=["-c", "non-existent-file.py"]
82-
)
76+
server_params = StdioServerParameters(command="python", args=["-c", "non-existent-file.py"])
8377
async with stdio_client(server_params) as (read_stream, write_stream):
8478
async with ClientSession(read_stream, write_stream) as session:
8579
# The session should raise an error when the connection closes
@@ -167,9 +161,7 @@ async def test_stdio_client_universal_cleanup():
167161

168162

169163
@pytest.mark.anyio
170-
@pytest.mark.skipif(
171-
sys.platform == "win32", reason="Windows signal handling is different"
172-
)
164+
@pytest.mark.skipif(sys.platform == "win32", reason="Windows signal handling is different")
173165
async def test_stdio_client_sigint_only_process():
174166
"""
175167
Test cleanup with a process that ignores SIGTERM but responds to SIGINT.
@@ -262,9 +254,7 @@ class TestChildProcessCleanup:
262254
"""
263255

264256
@pytest.mark.anyio
265-
@pytest.mark.filterwarnings(
266-
"ignore::ResourceWarning" if sys.platform == "win32" else "default"
267-
)
257+
@pytest.mark.filterwarnings("ignore::ResourceWarning" if sys.platform == "win32" else "default")
268258
async def test_basic_child_process_cleanup(self):
269259
"""
270260
Test basic parent-child process cleanup.
@@ -313,9 +303,7 @@ async def test_basic_child_process_cleanup(self):
313303
print("\nStarting child process termination test...")
314304

315305
# Start the parent process
316-
proc = await _create_platform_compatible_process(
317-
sys.executable, ["-c", parent_script]
318-
)
306+
proc = await _create_platform_compatible_process(sys.executable, ["-c", parent_script])
319307

320308
# Wait for processes to start
321309
await anyio.sleep(0.5)
@@ -329,9 +317,7 @@ async def test_basic_child_process_cleanup(self):
329317
await anyio.sleep(0.3)
330318
size_after_wait = os.path.getsize(marker_file)
331319
assert size_after_wait > initial_size, "Child process should be writing"
332-
print(
333-
f"Child is writing (file grew from {initial_size} to {size_after_wait} bytes)"
334-
)
320+
print(f"Child is writing (file grew from {initial_size} to {size_after_wait} bytes)")
335321

336322
# Terminate using our function
337323
print("Terminating process and children...")
@@ -347,9 +333,9 @@ async def test_basic_child_process_cleanup(self):
347333
final_size = os.path.getsize(marker_file)
348334

349335
print(f"After cleanup: file size {size_after_cleanup} -> {final_size}")
350-
assert (
351-
final_size == size_after_cleanup
352-
), f"Child process still running! File grew by {final_size - size_after_cleanup} bytes"
336+
assert final_size == size_after_cleanup, (
337+
f"Child process still running! File grew by {final_size - size_after_cleanup} bytes"
338+
)
353339

354340
print("SUCCESS: Child process was properly terminated")
355341

@@ -362,9 +348,7 @@ async def test_basic_child_process_cleanup(self):
362348
pass
363349

364350
@pytest.mark.anyio
365-
@pytest.mark.filterwarnings(
366-
"ignore::ResourceWarning" if sys.platform == "win32" else "default"
367-
)
351+
@pytest.mark.filterwarnings("ignore::ResourceWarning" if sys.platform == "win32" else "default")
368352
async def test_nested_process_tree(self):
369353
"""
370354
Test nested process tree cleanup (parent → child → grandchild).
@@ -424,9 +408,7 @@ async def test_nested_process_tree(self):
424408
)
425409

426410
# Start the parent process
427-
proc = await _create_platform_compatible_process(
428-
sys.executable, ["-c", parent_script]
429-
)
411+
proc = await _create_platform_compatible_process(sys.executable, ["-c", parent_script])
430412

431413
# Let all processes start
432414
await anyio.sleep(1.0)
@@ -472,9 +454,7 @@ async def test_nested_process_tree(self):
472454
pass
473455

474456
@pytest.mark.anyio
475-
@pytest.mark.filterwarnings(
476-
"ignore::ResourceWarning" if sys.platform == "win32" else "default"
477-
)
457+
@pytest.mark.filterwarnings("ignore::ResourceWarning" if sys.platform == "win32" else "default")
478458
async def test_early_parent_exit(self):
479459
"""
480460
Test cleanup when parent exits during termination sequence.
@@ -518,9 +498,7 @@ def handle_term(sig, frame):
518498
)
519499

520500
# Start the parent process
521-
proc = await _create_platform_compatible_process(
522-
sys.executable, ["-c", parent_script]
523-
)
501+
proc = await _create_platform_compatible_process(sys.executable, ["-c", parent_script])
524502

525503
# Let child start writing
526504
await anyio.sleep(0.5)

tests/server/test_sse_security.py

Lines changed: 14 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -46,22 +46,16 @@ async def on_list_tools(self) -> list[Tool]:
4646
return []
4747

4848

49-
def run_server_with_settings(
50-
port: int, security_settings: TransportSecuritySettings | None = None
51-
):
49+
def run_server_with_settings(port: int, security_settings: TransportSecuritySettings | None = None):
5250
"""Run the SSE server with specified security settings."""
5351
app = SecurityTestServer()
5452
sse_transport = SseServerTransport("/messages/", security_settings)
5553

5654
async def handle_sse(request: Request):
5755
try:
58-
async with sse_transport.connect_sse(
59-
request.scope, request.receive, request._send
60-
) as streams:
56+
async with sse_transport.connect_sse(request.scope, request.receive, request._send) as streams:
6157
if streams:
62-
await app.run(
63-
streams[0], streams[1], app.create_initialization_options()
64-
)
58+
await app.run(streams[0], streams[1], app.create_initialization_options())
6559
except ValueError as e:
6660
# Validation error was already handled inside connect_sse
6761
logger.debug(f"SSE connection failed validation: {e}")
@@ -76,13 +70,9 @@ async def handle_sse(request: Request):
7670
uvicorn.run(starlette_app, host="127.0.0.1", port=port, log_level="error")
7771

7872

79-
def start_server_process(
80-
port: int, security_settings: TransportSecuritySettings | None = None
81-
):
73+
def start_server_process(port: int, security_settings: TransportSecuritySettings | None = None):
8274
"""Start server in a separate process."""
83-
process = multiprocessing.Process(
84-
target=run_server_with_settings, args=(port, security_settings)
85-
)
75+
process = multiprocessing.Process(target=run_server_with_settings, args=(port, security_settings))
8676
process.start()
8777
# Give server time to start
8878
time.sleep(1)
@@ -98,9 +88,7 @@ async def test_sse_security_default_settings(server_port: int):
9888
headers = {"Host": "evil.com", "Origin": "http://evil.com"}
9989

10090
async with httpx.AsyncClient(timeout=5.0) as client:
101-
async with client.stream(
102-
"GET", f"http://127.0.0.1:{server_port}/sse", headers=headers
103-
) as response:
91+
async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response:
10492
assert response.status_code == 200
10593
finally:
10694
process.terminate()
@@ -111,19 +99,15 @@ async def test_sse_security_default_settings(server_port: int):
11199
async def test_sse_security_invalid_host_header(server_port: int):
112100
"""Test SSE with invalid Host header."""
113101
# Enable security by providing settings with an empty allowed_hosts list
114-
security_settings = TransportSecuritySettings(
115-
enable_dns_rebinding_protection=True, allowed_hosts=["example.com"]
116-
)
102+
security_settings = TransportSecuritySettings(enable_dns_rebinding_protection=True, allowed_hosts=["example.com"])
117103
process = start_server_process(server_port, security_settings)
118104

119105
try:
120106
# Test with invalid host header
121107
headers = {"Host": "evil.com"}
122108

123109
async with httpx.AsyncClient() as client:
124-
response = await client.get(
125-
f"http://127.0.0.1:{server_port}/sse", headers=headers
126-
)
110+
response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers)
127111
assert response.status_code == 421
128112
assert response.text == "Invalid Host header"
129113

@@ -148,9 +132,7 @@ async def test_sse_security_invalid_origin_header(server_port: int):
148132
headers = {"Origin": "http://evil.com"}
149133

150134
async with httpx.AsyncClient() as client:
151-
response = await client.get(
152-
f"http://127.0.0.1:{server_port}/sse", headers=headers
153-
)
135+
response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers)
154136
assert response.status_code == 400
155137
assert response.text == "Invalid Origin header"
156138

@@ -207,9 +189,7 @@ async def test_sse_security_disabled(server_port: int):
207189

208190
async with httpx.AsyncClient(timeout=5.0) as client:
209191
# For SSE endpoints, we need to use stream to avoid timeout
210-
async with client.stream(
211-
"GET", f"http://127.0.0.1:{server_port}/sse", headers=headers
212-
) as response:
192+
async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response:
213193
# Should connect successfully even with invalid host
214194
assert response.status_code == 200
215195

@@ -234,19 +214,15 @@ async def test_sse_security_custom_allowed_hosts(server_port: int):
234214

235215
async with httpx.AsyncClient(timeout=5.0) as client:
236216
# For SSE endpoints, we need to use stream to avoid timeout
237-
async with client.stream(
238-
"GET", f"http://127.0.0.1:{server_port}/sse", headers=headers
239-
) as response:
217+
async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response:
240218
# Should connect successfully with custom host
241219
assert response.status_code == 200
242220

243221
# Test with non-allowed host
244222
headers = {"Host": "evil.com"}
245223

246224
async with httpx.AsyncClient() as client:
247-
response = await client.get(
248-
f"http://127.0.0.1:{server_port}/sse", headers=headers
249-
)
225+
response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers)
250226
assert response.status_code == 421
251227
assert response.text == "Invalid Host header"
252228

@@ -272,19 +248,15 @@ async def test_sse_security_wildcard_ports(server_port: int):
272248

273249
async with httpx.AsyncClient(timeout=5.0) as client:
274250
# For SSE endpoints, we need to use stream to avoid timeout
275-
async with client.stream(
276-
"GET", f"http://127.0.0.1:{server_port}/sse", headers=headers
277-
) as response:
251+
async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response:
278252
# Should connect successfully with any port
279253
assert response.status_code == 200
280254

281255
headers = {"Origin": f"http://localhost:{test_port}"}
282256

283257
async with httpx.AsyncClient(timeout=5.0) as client:
284258
# For SSE endpoints, we need to use stream to avoid timeout
285-
async with client.stream(
286-
"GET", f"http://127.0.0.1:{server_port}/sse", headers=headers
287-
) as response:
259+
async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response:
288260
# Should connect successfully with any port
289261
assert response.status_code == 200
290262

0 commit comments

Comments
 (0)