Skip to content

Commit c9de04b

Browse files
committed
feat(socket-transport): improve cancellation handling and cleanup
- Add graceful cancellation handling for socket reader/writer - Implement timeout-based cleanup mechanism (5s timeout) - Add force process termination for hanging cleanup - Improve server-side resource cleanup - Add comprehensive tests for cancellation and cleanup scenarios - Fix Python interpreter path resolution in tests
1 parent 98ede50 commit c9de04b

File tree

3 files changed

+257
-12
lines changed

3 files changed

+257
-12
lines changed

src/mcp/client/socket_transport.py

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,10 @@ async def socket_reader():
163163
continue
164164
except anyio.ClosedResourceError:
165165
await anyio.lowlevel.checkpoint()
166+
except anyio.get_cancelled_exc_class():
167+
# Handle cancellation gracefully
168+
logger.info("Socket reader cancelled")
169+
return
166170
except Exception as e:
167171
logger.error(f"Error in socket reader: {e}")
168172
raise
@@ -181,6 +185,10 @@ async def socket_writer():
181185
await stream.send(data)
182186
except anyio.ClosedResourceError:
183187
await anyio.lowlevel.checkpoint()
188+
except anyio.get_cancelled_exc_class():
189+
# Handle cancellation gracefully
190+
logger.info("Socket writer cancelled")
191+
return
184192
except Exception as e:
185193
logger.error(f"Error in socket writer: {e}")
186194
raise
@@ -201,18 +209,29 @@ async def socket_writer():
201209
async with process, stream:
202210
yield read_stream, write_stream
203211
finally:
204-
# Cancel all tasks and clean up
212+
# Cancel all tasks and clean up with timeout
205213
tg.cancel_scope.cancel()
206-
# Clean up process to prevent any dangling orphaned processes
214+
215+
# Force cleanup with timeout to prevent hanging
207216
try:
208-
process.terminate()
209-
except ProcessLookupError:
210-
# Process already exited, which is fine
211-
pass
212-
await read_stream.aclose()
213-
await write_stream.aclose()
214-
await read_stream_writer.aclose()
215-
await write_stream_reader.aclose()
217+
with anyio.fail_after(5.0): # 5 second timeout for cleanup
218+
# Clean up process to prevent any dangling orphaned processes
219+
try:
220+
process.terminate()
221+
except ProcessLookupError:
222+
# Process already exited, which is fine
223+
pass
224+
await read_stream.aclose()
225+
await write_stream.aclose()
226+
await read_stream_writer.aclose()
227+
await write_stream_reader.aclose()
228+
except anyio.get_cancelled_exc_class():
229+
# If cleanup times out, force kill the process
230+
logger.warning("Cleanup timed out, force killing process")
231+
try:
232+
process.kill()
233+
except ProcessLookupError:
234+
pass
216235

217236
finally:
218237
# Clean up process

src/mcp/server/socket_transport.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,13 @@ async def socket_reader():
103103
continue
104104
except anyio.ClosedResourceError:
105105
await anyio.lowlevel.checkpoint()
106+
except anyio.get_cancelled_exc_class():
107+
# Handle cancellation gracefully
108+
logger.info("Socket reader cancelled")
109+
return
110+
except Exception as e:
111+
logger.error(f"Error in socket reader: {e}")
112+
raise
106113
finally:
107114
await stream.aclose()
108115

@@ -118,6 +125,13 @@ async def socket_writer():
118125
await stream.send(data)
119126
except anyio.ClosedResourceError:
120127
await anyio.lowlevel.checkpoint()
128+
except anyio.get_cancelled_exc_class():
129+
# Handle cancellation gracefully
130+
logger.info("Socket writer cancelled")
131+
return
132+
except Exception as e:
133+
logger.error(f"Error in socket writer: {e}")
134+
raise
121135
finally:
122136
await stream.aclose()
123137

@@ -128,9 +142,21 @@ async def socket_writer():
128142
try:
129143
yield read_stream, write_stream
130144
finally:
131-
await stream.aclose()
145+
# Cancel all tasks and clean up with timeout
132146
tg.cancel_scope.cancel()
133147

148+
# Force cleanup with timeout to prevent hanging
149+
try:
150+
with anyio.fail_after(5.0): # 5 second timeout for cleanup
151+
await stream.aclose()
152+
await read_stream.aclose()
153+
await write_stream.aclose()
154+
await read_stream_writer.aclose()
155+
await write_stream_reader.aclose()
156+
except anyio.get_cancelled_exc_class():
157+
# If cleanup times out, log warning
158+
logger.warning("Server cleanup timed out")
159+
134160
except Exception:
135161
await read_stream.aclose()
136162
await write_stream.aclose()

tests/client/test_socket_transport.py

Lines changed: 201 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from mcp.shared.message import SessionMessage
1818
from mcp.types import CONNECTION_CLOSED, JSONRPCMessage, JSONRPCRequest, JSONRPCResponse
1919

20-
python: str = shutil.which("python") # type: ignore
20+
python = shutil.which("python") or "python"
2121

2222

2323
@pytest.mark.anyio
@@ -267,3 +267,203 @@ async def test_socket_client_invalid_json():
267267
async for message in read_stream:
268268
assert isinstance(message, Exception)
269269
break
270+
271+
272+
@pytest.mark.anyio
273+
async def test_socket_client_cancellation_handling():
274+
"""Test that socket_client handles cancellation gracefully."""
275+
server_params = SocketServerParameters(
276+
command=python,
277+
args=[
278+
"-c",
279+
"""
280+
import socket, sys, time
281+
s = socket.socket()
282+
s.connect(('127.0.0.1', int(sys.argv[2])))
283+
# Keep connection alive for a bit
284+
time.sleep(2)
285+
s.close()
286+
""",
287+
],
288+
)
289+
290+
# Test that cancellation works properly
291+
with anyio.move_on_after(0.5) as cancel_scope:
292+
async with socket_client(server_params) as (read_stream, write_stream):
293+
# Wait a bit, then the move_on_after should cancel
294+
await anyio.sleep(1)
295+
296+
# The cancellation should have occurred
297+
assert cancel_scope.cancelled_caught
298+
299+
300+
@pytest.mark.anyio
301+
async def test_socket_client_cleanup_timeout():
302+
"""Test that socket_client cleanup has timeout protection."""
303+
server_params = SocketServerParameters(
304+
command=python,
305+
args=[
306+
"-c",
307+
"""
308+
import socket, sys, time
309+
s = socket.socket()
310+
s.connect(('127.0.0.1', int(sys.argv[2])))
311+
312+
# Send a message and keep connection alive briefly
313+
s.send(b'{"jsonrpc": "2.0", "id": 1, "method": "test"}\\n')
314+
time.sleep(1) # Brief delay, then exit normally
315+
s.close()
316+
""",
317+
],
318+
)
319+
320+
# Test that cleanup completes within reasonable time
321+
start_time = anyio.current_time()
322+
323+
async with socket_client(server_params) as (read_stream, write_stream):
324+
# Do some work
325+
await anyio.sleep(0.1)
326+
327+
end_time = anyio.current_time()
328+
329+
# Normal cleanup should complete quickly (within 3 seconds)
330+
# This tests that the cleanup mechanism works without hanging
331+
assert end_time - start_time < 3.0
332+
333+
334+
@pytest.mark.anyio
335+
async def test_socket_client_cleanup_mechanism():
336+
"""Test that socket_client cleanup mechanism is robust."""
337+
server_params = SocketServerParameters(
338+
command=python,
339+
args=[
340+
"-c",
341+
"""
342+
import socket, sys, time
343+
s = socket.socket()
344+
s.connect(('127.0.0.1', int(sys.argv[2])))
345+
346+
# Send a test message
347+
s.send(b'{"jsonrpc": "2.0", "id": 1, "method": "test"}\\n')
348+
349+
# Close after brief delay
350+
time.sleep(0.2)
351+
s.close()
352+
""",
353+
],
354+
)
355+
356+
# Test that cleanup works correctly
357+
async with socket_client(server_params) as (read_stream, write_stream):
358+
# Process at least one message
359+
async for message in read_stream:
360+
if isinstance(message, Exception):
361+
continue
362+
# Exit after first valid message
363+
break
364+
365+
# If we reach here, cleanup worked properly
366+
assert True
367+
368+
369+
@pytest.mark.anyio
370+
async def test_socket_client_reader_writer_exception_handling():
371+
"""Test that socket reader/writer handle exceptions properly."""
372+
server_params = SocketServerParameters(
373+
command=python,
374+
args=[
375+
"-c",
376+
"""
377+
import socket, sys, time
378+
s = socket.socket()
379+
s.connect(('127.0.0.1', int(sys.argv[2])))
380+
381+
# Send some data then close abruptly
382+
s.send(b'{"jsonrpc": "2.0", "id": 1, "method": "test"}\\n')
383+
time.sleep(0.1)
384+
s.close() # Close connection abruptly
385+
""",
386+
],
387+
)
388+
389+
async with socket_client(server_params) as (read_stream, write_stream):
390+
# Should handle the abrupt connection close gracefully
391+
messages_received = 0
392+
async for message in read_stream:
393+
if isinstance(message, Exception):
394+
# Exceptions in the stream are expected
395+
continue
396+
messages_received += 1
397+
if messages_received >= 1:
398+
break
399+
400+
assert messages_received >= 1
401+
402+
403+
@pytest.mark.anyio
404+
async def test_socket_client_process_cleanup():
405+
"""Test that socket_client cleans up processes properly."""
406+
server_params = SocketServerParameters(
407+
command=python,
408+
args=[
409+
"-c",
410+
"""
411+
import socket, sys, time, os
412+
pid = os.getpid()
413+
print(f"Process PID: {pid}", file=sys.stderr)
414+
415+
s = socket.socket()
416+
s.connect(('127.0.0.1', int(sys.argv[2])))
417+
time.sleep(0.5)
418+
s.close()
419+
""",
420+
],
421+
)
422+
423+
async with socket_client(server_params) as (read_stream, write_stream):
424+
# Brief interaction
425+
await anyio.sleep(0.1)
426+
427+
# Process should be cleaned up after context exit
428+
# This is mainly to ensure no zombie processes remain
429+
await anyio.sleep(0.1) # Give cleanup time to complete
430+
431+
432+
@pytest.mark.anyio
433+
async def test_socket_client_multiple_messages_with_cancellation():
434+
"""Test handling multiple messages with cancellation."""
435+
server_params = SocketServerParameters(
436+
command=python,
437+
args=[
438+
"-c",
439+
"""
440+
import socket, sys, time, json
441+
s = socket.socket()
442+
s.connect(('127.0.0.1', int(sys.argv[2])))
443+
444+
# Send multiple messages
445+
for i in range(10):
446+
msg = {"jsonrpc": "2.0", "id": i, "method": "test", "params": {"counter": i}}
447+
s.send((json.dumps(msg) + '\\n').encode())
448+
time.sleep(0.01) # Small delay between messages
449+
450+
s.close()
451+
""",
452+
],
453+
)
454+
455+
messages_received = 0
456+
457+
with anyio.move_on_after(1.0) as cancel_scope:
458+
async with socket_client(server_params) as (read_stream, write_stream):
459+
async for message in read_stream:
460+
if isinstance(message, Exception):
461+
continue
462+
messages_received += 1
463+
if messages_received >= 5:
464+
# Cancel after receiving some messages
465+
cancel_scope.cancel()
466+
467+
# Should have received some messages before cancellation
468+
assert messages_received >= 5
469+
assert cancel_scope.cancelled_caught

0 commit comments

Comments
 (0)