Skip to content

Commit dbe7abf

Browse files
committed
Simplify code on stdio_client
1 parent c260e29 commit dbe7abf

File tree

2 files changed

+20
-48
lines changed

2 files changed

+20
-48
lines changed

src/mcp/client/stdio/__init__.py

Lines changed: 16 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
import logging
22
import os
33
import sys
4-
from contextlib import asynccontextmanager
4+
from contextlib import AsyncExitStack, asynccontextmanager
55
from pathlib import Path
66
from typing import Literal, TextIO
77

88
import anyio
99
import anyio.lowlevel
1010
from anyio.abc import Process
11-
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
1211
from anyio.streams.text import TextReceiveStream
1312
from pydantic import BaseModel, Field
1413

@@ -107,33 +106,19 @@ async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stder
107106
Client transport for stdio: this will connect to a server by spawning a
108107
process and communicating with it over stdin/stdout.
109108
"""
110-
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
111-
read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]
109+
read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0)
110+
write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0)
112111

113-
write_stream: MemoryObjectSendStream[SessionMessage]
114-
write_stream_reader: MemoryObjectReceiveStream[SessionMessage]
112+
command = _get_executable_command(server.command)
115113

116-
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
117-
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
118-
119-
try:
120-
command = _get_executable_command(server.command)
121-
122-
# Open process with stderr piped for capture
123-
process = await _create_platform_compatible_process(
124-
command=command,
125-
args=server.args,
126-
env=({**get_default_environment(), **server.env} if server.env is not None else get_default_environment()),
127-
errlog=errlog,
128-
cwd=server.cwd,
129-
)
130-
except OSError:
131-
# Clean up streams if process creation fails
132-
await read_stream.aclose()
133-
await write_stream.aclose()
134-
await read_stream_writer.aclose()
135-
await write_stream_reader.aclose()
136-
raise
114+
# Open process with stderr piped for capture
115+
process = await _create_platform_compatible_process(
116+
command=command,
117+
args=server.args,
118+
env=({**get_default_environment(), **server.env} if server.env is not None else get_default_environment()),
119+
errlog=errlog,
120+
cwd=server.cwd,
121+
)
137122

138123
async def stdout_reader():
139124
assert process.stdout, "Opened process is missing stdout"
@@ -177,14 +162,13 @@ async def stdin_writer():
177162
except anyio.ClosedResourceError:
178163
await anyio.lowlevel.checkpoint()
179164

180-
async with (
181-
anyio.create_task_group() as tg,
182-
process,
183-
):
165+
async with anyio.create_task_group() as tg, process:
184166
tg.start_soon(stdout_reader)
185167
tg.start_soon(stdin_writer)
168+
186169
try:
187-
yield read_stream, write_stream
170+
async with read_stream, write_stream:
171+
yield read_stream, write_stream
188172
finally:
189173
# MCP spec: stdio shutdown sequence
190174
# 1. Close input stream to server
@@ -208,10 +192,6 @@ async def stdin_writer():
208192
except ProcessLookupError:
209193
# Process already exited, which is fine
210194
pass
211-
await read_stream.aclose()
212-
await write_stream.aclose()
213-
await read_stream_writer.aclose()
214-
await write_stream_reader.aclose()
215195

216196

217197
def _get_executable_command(command: str) -> str:

src/mcp/shared/session.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -252,12 +252,7 @@ async def send_request(
252252
self._progress_callbacks[request_id] = progress_callback
253253

254254
try:
255-
jsonrpc_request = JSONRPCRequest(
256-
jsonrpc="2.0",
257-
id=request_id,
258-
**request_data,
259-
)
260-
255+
jsonrpc_request = JSONRPCRequest(jsonrpc="2.0", id=request_id, **request_data)
261256
await self._write_stream.send(SessionMessage(message=JSONRPCMessage(jsonrpc_request), metadata=metadata))
262257

263258
# request read timeout takes precedence over session read timeout
@@ -329,10 +324,7 @@ async def _send_response(self, request_id: RequestId, response: SendResultT | Er
329324
await self._write_stream.send(session_message)
330325

331326
async def _receive_loop(self) -> None:
332-
async with (
333-
self._read_stream,
334-
self._write_stream,
335-
):
327+
async with self._read_stream, self._write_stream:
336328
try:
337329
async for message in self._read_stream:
338330
if isinstance(message, Exception):
@@ -418,10 +410,10 @@ async def _receive_loop(self) -> None:
418410
# Without this handler, the exception would propagate up and
419411
# crash the server's task group.
420412
logging.debug("Read stream closed by client")
421-
except Exception as e:
413+
except Exception:
422414
# Other exceptions are not expected and should be logged. We purposefully
423415
# catch all exceptions here to avoid crashing the server.
424-
logging.exception(f"Unhandled exception in receive loop: {e}")
416+
logging.exception("Unhandled exception in receive loop")
425417
finally:
426418
# after the read stream is closed, we need to send errors
427419
# to any pending requests

0 commit comments

Comments
 (0)