diff --git a/src/mcp/client/stdio/__init__.py b/src/mcp/client/stdio/__init__.py index 50bceddec8..030bc889a0 100644 --- a/src/mcp/client/stdio/__init__.py +++ b/src/mcp/client/stdio/__init__.py @@ -8,7 +8,6 @@ import anyio import anyio.lowlevel from anyio.abc import Process -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from anyio.streams.text import TextReceiveStream from pydantic import BaseModel, Field @@ -107,33 +106,19 @@ async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stder Client transport for stdio: this will connect to a server by spawning a process and communicating with it over stdin/stdout. """ - read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] - read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] + read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0) + write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0) - write_stream: MemoryObjectSendStream[SessionMessage] - write_stream_reader: MemoryObjectReceiveStream[SessionMessage] + command = _get_executable_command(server.command) - read_stream_writer, read_stream = anyio.create_memory_object_stream(0) - write_stream, write_stream_reader = anyio.create_memory_object_stream(0) - - try: - command = _get_executable_command(server.command) - - # Open process with stderr piped for capture - process = await _create_platform_compatible_process( - command=command, - args=server.args, - env=({**get_default_environment(), **server.env} if server.env is not None else get_default_environment()), - errlog=errlog, - cwd=server.cwd, - ) - except OSError: - # Clean up streams if process creation fails - await read_stream.aclose() - await write_stream.aclose() - await read_stream_writer.aclose() - await write_stream_reader.aclose() - raise + # Open process with stderr piped for capture + process = await _create_platform_compatible_process( + command=command, + args=server.args, + env=({**get_default_environment(), **server.env} if server.env is not None else get_default_environment()), + errlog=errlog, + cwd=server.cwd, + ) async def stdout_reader(): assert process.stdout, "Opened process is missing stdout" @@ -177,12 +162,10 @@ async def stdin_writer(): except anyio.ClosedResourceError: await anyio.lowlevel.checkpoint() - async with ( - anyio.create_task_group() as tg, - process, - ): + async with anyio.create_task_group() as tg, process: tg.start_soon(stdout_reader) tg.start_soon(stdin_writer) + try: yield read_stream, write_stream finally: diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index b2f49fc8bc..330f8cdd0d 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -252,12 +252,7 @@ async def send_request( self._progress_callbacks[request_id] = progress_callback try: - jsonrpc_request = JSONRPCRequest( - jsonrpc="2.0", - id=request_id, - **request_data, - ) - + jsonrpc_request = JSONRPCRequest(jsonrpc="2.0", id=request_id, **request_data) await self._write_stream.send(SessionMessage(message=JSONRPCMessage(jsonrpc_request), metadata=metadata)) # request read timeout takes precedence over session read timeout @@ -329,10 +324,7 @@ async def _send_response(self, request_id: RequestId, response: SendResultT | Er await self._write_stream.send(session_message) async def _receive_loop(self) -> None: - async with ( - self._read_stream, - self._write_stream, - ): + async with self._read_stream, self._write_stream: try: async for message in self._read_stream: if isinstance(message, Exception): @@ -418,10 +410,10 @@ async def _receive_loop(self) -> None: # Without this handler, the exception would propagate up and # crash the server's task group. logging.debug("Read stream closed by client") - except Exception as e: + except Exception: # Other exceptions are not expected and should be logged. We purposefully # catch all exceptions here to avoid crashing the server. - logging.exception(f"Unhandled exception in receive loop: {e}") + logging.exception("Unhandled exception in receive loop") finally: # after the read stream is closed, we need to send errors # to any pending requests diff --git a/tests/client/test_stdio.py b/tests/client/test_stdio.py index 2abb42e5cd..a424cbc51e 100644 --- a/tests/client/test_stdio.py +++ b/tests/client/test_stdio.py @@ -118,7 +118,7 @@ async def test_stdio_client_universal_cleanup(): """ import time import sys - + # Simulate a long-running process for i in range(100): time.sleep(0.1) @@ -532,7 +532,7 @@ async def test_stdio_client_graceful_stdin_exit(): script_content = textwrap.dedent( """ import sys - + # Read from stdin until it's closed try: while True: @@ -541,7 +541,7 @@ async def test_stdio_client_graceful_stdin_exit(): break except: pass - + # Exit gracefully sys.exit(0) """ @@ -590,16 +590,16 @@ async def test_stdio_client_stdin_close_ignored(): import signal import sys import time - + # Set up SIGTERM handler to exit cleanly def sigterm_handler(signum, frame): sys.exit(0) - + signal.signal(signal.SIGTERM, sigterm_handler) - + # Close stdin immediately to simulate ignoring it sys.stdin.close() - + # Keep running until SIGTERM while True: time.sleep(0.1)