Skip to content

Commit a0c1553

Browse files
committed
fix: proxy resource management and test cleanup
- Remove stream closing from proxy (streams are owned by transport context managers) - Add proper test stream cleanup in all test cases - Fix test_proxy_cleans_up_streams to match actual behavior - Fix formatting (add missing newlines)
1 parent 24e9aae commit a0c1553

File tree

5 files changed

+58
-37
lines changed

5 files changed

+58
-37
lines changed

examples/servers/simple-proxy/mcp_simple_proxy/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,3 @@
33
from .server import main
44

55
__all__ = ["main"]
6-
7-

examples/servers/simple-proxy/mcp_simple_proxy/__main__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,3 @@
66

77
if __name__ == "__main__":
88
sys.exit(main())
9-
10-

examples/servers/simple-proxy/mcp_simple_proxy/server.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,5 +109,3 @@ async def run_proxy():
109109

110110
if __name__ == "__main__":
111111
sys.exit(main())
112-
113-

src/mcp/shared/proxy.py

Lines changed: 19 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -135,33 +135,26 @@ async def forward_messages(
135135
onerror(exc)
136136

137137
async with anyio.create_task_group() as tg:
138-
try:
139-
# Start forwarding tasks in both directions
140-
tg.start_soon(
141-
forward_messages,
142-
client_read,
143-
server_write,
144-
on_client_message,
145-
"client→server",
146-
)
147-
tg.start_soon(
148-
forward_messages,
149-
server_read,
150-
client_write,
151-
on_server_message,
152-
"server→client",
153-
)
138+
# Start forwarding tasks in both directions
139+
tg.start_soon(
140+
forward_messages,
141+
client_read,
142+
server_write,
143+
on_client_message,
144+
"client→server",
145+
)
146+
tg.start_soon(
147+
forward_messages,
148+
server_read,
149+
client_write,
150+
on_server_message,
151+
"server→client",
152+
)
154153

154+
try:
155155
yield
156-
157156
finally:
158-
# Cancel all forwarding tasks
157+
# Cancel all forwarding tasks when the context exits
158+
# Note: We don't close the streams here because they're owned by the caller
159+
# (the transport context managers that created them)
159160
tg.cancel_scope.cancel()
160-
161-
# Close all streams to ensure proper cleanup
162-
# Note: We close streams even if they might already be closed
163-
for stream in [client_read, client_write, server_read, server_write]:
164-
try:
165-
await stream.aclose()
166-
except Exception as exc:
167-
logger.debug(f"Error closing stream during cleanup: {exc}")

tests/shared/test_proxy.py

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,7 @@ def _create():
2222
server_write, server_write_reader = anyio.create_memory_object_stream[SessionMessage](10)
2323

2424
# Track all streams for cleanup
25-
streams_to_cleanup.extend(
26-
[client_read_writer, client_write_reader, server_read_writer, server_write_reader]
27-
)
25+
streams_to_cleanup.extend([client_read_writer, client_write_reader, server_read_writer, server_write_reader])
2826

2927
return (
3028
(client_read, client_write),
@@ -62,6 +60,10 @@ async def test_proxy_forwards_client_to_server(create_streams):
6260
assert received.message.root.id == "1"
6361
assert received.message.root.method == "test_method"
6462

63+
# Clean up test streams
64+
await client_read_writer.aclose()
65+
await server_write_reader.aclose()
66+
6567

6668
@pytest.mark.anyio
6769
async def test_proxy_forwards_server_to_client(create_streams):
@@ -82,6 +84,10 @@ async def test_proxy_forwards_server_to_client(create_streams):
8284
assert received.message.root.id == "2"
8385
assert received.message.root.method == "server_method"
8486

87+
# Clean up test streams
88+
await server_read_writer.aclose()
89+
await client_write_reader.aclose()
90+
8591

8692
@pytest.mark.anyio
8793
async def test_proxy_bidirectional_forwarding(create_streams):
@@ -118,6 +124,12 @@ async def test_proxy_bidirectional_forwarding(create_streams):
118124
received_at_client = await client_write_reader.receive()
119125
assert received_at_client.message.root.id == "server_1"
120126

127+
# Clean up test streams
128+
await client_read_writer.aclose()
129+
await client_write_reader.aclose()
130+
await server_read_writer.aclose()
131+
await server_write_reader.aclose()
132+
121133

122134
@pytest.mark.anyio
123135
async def test_proxy_message_transformation(create_streams):
@@ -146,6 +158,10 @@ async def transform_client_message(msg: SessionMessage) -> SessionMessage | None
146158
assert received.message.root.id == "transformed_original"
147159
assert "client" in transformed
148160

161+
# Clean up test streams
162+
await client_read_writer.aclose()
163+
await server_write_reader.aclose()
164+
149165

150166
@pytest.mark.anyio
151167
async def test_proxy_message_dropping(create_streams):
@@ -168,11 +184,15 @@ async def drop_message(msg: SessionMessage) -> SessionMessage | None:
168184
# If we get here, the message was not dropped
169185
pytest.fail("Message should have been dropped")
170186

187+
# Clean up test streams
188+
await client_read_writer.aclose()
189+
await server_write_reader.aclose()
190+
171191

172192
@pytest.mark.anyio
173193
async def test_proxy_error_handling(create_streams):
174194
"""Test that errors are caught and onerror callback is invoked."""
175-
client_streams, server_streams, (client_read_writer, _), (_, _) = create_streams()
195+
client_streams, server_streams, (client_read_writer, _), (_, server_write_reader) = create_streams()
176196

177197
errors = []
178198

@@ -194,6 +214,10 @@ def error_handler(error: Exception) -> None:
194214
assert isinstance(errors[0], ValueError)
195215
assert str(errors[0]) == "Test error"
196216

217+
# Clean up test streams
218+
await client_read_writer.aclose()
219+
await server_write_reader.aclose()
220+
197221

198222
@pytest.mark.anyio
199223
async def test_proxy_continues_after_error(create_streams):
@@ -222,6 +246,10 @@ def error_handler(error: Exception) -> None:
222246
# Error should have been captured
223247
assert len(errors) == 1
224248

249+
# Clean up test streams
250+
await client_read_writer.aclose()
251+
await server_write_reader.aclose()
252+
225253

226254
@pytest.mark.anyio
227255
async def test_proxy_transform_error_handling(create_streams):
@@ -255,6 +283,10 @@ async def failing_transform(msg: SessionMessage) -> SessionMessage | None:
255283
received = await server_write_reader.receive()
256284
pytest.fail("Message should not have been forwarded after transform error")
257285

286+
# Clean up test streams
287+
await client_read_writer.aclose()
288+
await server_write_reader.aclose()
289+
258290

259291
@pytest.mark.anyio
260292
async def test_proxy_cleans_up_streams(create_streams):
@@ -306,4 +338,6 @@ async def test_proxy_multiple_messages(create_streams):
306338
assert received.message.root.id == str(i)
307339
assert received.message.root.method == f"method_{i}"
308340

309-
341+
# Clean up test streams
342+
await client_read_writer.aclose()
343+
await server_write_reader.aclose()

0 commit comments

Comments
 (0)