Skip to content

Commit c17e2cd

Browse files
committed
refactor: extract helper functions to reduce stdio_client complexity
1 parent 850fba2 commit c17e2cd

File tree

1 file changed

+101
-81
lines changed

1 file changed

+101
-81
lines changed

src/mcp/client/stdio/__init__.py

Lines changed: 101 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,104 @@ class StdioServerParameters(BaseModel):
144144
"""
145145

146146

147+
async def _stdout_reader(
148+
process: Process | FallbackProcess,
149+
read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception],
150+
encoding: str,
151+
encoding_error_handler: str,
152+
):
153+
"""Read stdout from the process and parse JSONRPC messages."""
154+
assert process.stdout, "Opened process is missing stdout"
155+
156+
try:
157+
async with read_stream_writer:
158+
buffer = ""
159+
async for chunk in TextReceiveStream(
160+
process.stdout,
161+
encoding=encoding,
162+
errors=encoding_error_handler,
163+
):
164+
lines = (buffer + chunk).split("\n")
165+
buffer = lines.pop()
166+
167+
for line in lines:
168+
try:
169+
message = types.JSONRPCMessage.model_validate_json(line)
170+
except Exception as exc: # pragma: no cover
171+
logger.exception("Failed to parse JSONRPC message from server")
172+
await read_stream_writer.send(exc)
173+
continue
174+
175+
session_message = SessionMessage(message)
176+
await read_stream_writer.send(session_message)
177+
except anyio.ClosedResourceError: # pragma: no cover
178+
await anyio.lowlevel.checkpoint()
179+
180+
181+
async def _stdin_writer(
182+
process: Process | FallbackProcess,
183+
write_stream_reader: MemoryObjectReceiveStream[SessionMessage],
184+
encoding: str,
185+
encoding_error_handler: str,
186+
):
187+
"""Write session messages to the process stdin."""
188+
assert process.stdin, "Opened process is missing stdin"
189+
190+
try:
191+
async with write_stream_reader:
192+
async for session_message in write_stream_reader:
193+
json = session_message.message.model_dump_json(by_alias=True, exclude_none=True)
194+
await process.stdin.send(
195+
(json + "\n").encode(
196+
encoding=encoding,
197+
errors=encoding_error_handler,
198+
)
199+
)
200+
except anyio.ClosedResourceError: # pragma: no cover
201+
await anyio.lowlevel.checkpoint()
202+
203+
204+
async def _stderr_reader(
205+
process: Process | FallbackProcess,
206+
errlog: TextIO,
207+
encoding: str,
208+
encoding_error_handler: str,
209+
):
210+
"""Read stderr from the process and display it appropriately."""
211+
if not process.stderr:
212+
return
213+
214+
try:
215+
buffer = ""
216+
async for chunk in TextReceiveStream(
217+
process.stderr,
218+
encoding=encoding,
219+
errors=encoding_error_handler,
220+
):
221+
lines = (buffer + chunk).split("\n")
222+
buffer = lines.pop()
223+
224+
for line in lines:
225+
if line.strip(): # Only print non-empty lines
226+
try:
227+
_print_stderr(line, errlog)
228+
except Exception:
229+
# Log errors but continue (non-critical)
230+
logger.debug("Failed to print stderr line", exc_info=True)
231+
232+
# Print any remaining buffer content
233+
if buffer.strip():
234+
try:
235+
_print_stderr(buffer, errlog)
236+
except Exception:
237+
logger.debug("Failed to print final stderr buffer", exc_info=True)
238+
except anyio.ClosedResourceError: # pragma: no cover
239+
await anyio.lowlevel.checkpoint()
240+
except Exception:
241+
# Log errors but continue (non-critical)
242+
logger.debug("Error reading stderr", exc_info=True)
243+
244+
147245
@asynccontextmanager
148246
async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stderr):
149247
"""
@@ -190,92 +288,14 @@ async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stder
190288
await write_stream_reader.aclose()
191289
raise
192290

193-
async def stdout_reader():
194-
assert process.stdout, "Opened process is missing stdout"
195-
196-
try:
197-
async with read_stream_writer:
198-
buffer = ""
199-
async for chunk in TextReceiveStream(
200-
process.stdout,
201-
encoding=server.encoding,
202-
errors=server.encoding_error_handler,
203-
):
204-
lines = (buffer + chunk).split("\n")
205-
buffer = lines.pop()
206-
207-
for line in lines:
208-
try:
209-
message = types.JSONRPCMessage.model_validate_json(line)
210-
except Exception as exc: # pragma: no cover
211-
logger.exception("Failed to parse JSONRPC message from server")
212-
await read_stream_writer.send(exc)
213-
continue
214-
215-
session_message = SessionMessage(message)
216-
await read_stream_writer.send(session_message)
217-
except anyio.ClosedResourceError: # pragma: no cover
218-
await anyio.lowlevel.checkpoint()
219-
220-
async def stdin_writer():
221-
assert process.stdin, "Opened process is missing stdin"
222-
223-
try:
224-
async with write_stream_reader:
225-
async for session_message in write_stream_reader:
226-
json = session_message.message.model_dump_json(by_alias=True, exclude_none=True)
227-
await process.stdin.send(
228-
(json + "\n").encode(
229-
encoding=server.encoding,
230-
errors=server.encoding_error_handler,
231-
)
232-
)
233-
except anyio.ClosedResourceError: # pragma: no cover
234-
await anyio.lowlevel.checkpoint()
235-
236-
async def stderr_reader():
237-
"""Read stderr from the process and display it appropriately."""
238-
if not process.stderr:
239-
return
240-
241-
try:
242-
buffer = ""
243-
async for chunk in TextReceiveStream(
244-
process.stderr,
245-
encoding=server.encoding,
246-
errors=server.encoding_error_handler,
247-
):
248-
lines = (buffer + chunk).split("\n")
249-
buffer = lines.pop()
250-
251-
for line in lines:
252-
if line.strip(): # Only print non-empty lines
253-
try:
254-
_print_stderr(line, errlog)
255-
except Exception:
256-
# Log errors but continue (non-critical)
257-
logger.debug("Failed to print stderr line", exc_info=True)
258-
259-
# Print any remaining buffer content
260-
if buffer.strip():
261-
try:
262-
_print_stderr(buffer, errlog)
263-
except Exception:
264-
logger.debug("Failed to print final stderr buffer", exc_info=True)
265-
except anyio.ClosedResourceError: # pragma: no cover
266-
await anyio.lowlevel.checkpoint()
267-
except Exception:
268-
# Log errors but continue (non-critical)
269-
logger.debug("Error reading stderr", exc_info=True)
270-
271291
async with (
272292
anyio.create_task_group() as tg,
273293
process,
274294
):
275-
tg.start_soon(stdout_reader)
276-
tg.start_soon(stdin_writer)
295+
tg.start_soon(_stdout_reader, process, read_stream_writer, server.encoding, server.encoding_error_handler)
296+
tg.start_soon(_stdin_writer, process, write_stream_reader, server.encoding, server.encoding_error_handler)
277297
if process.stderr:
278-
tg.start_soon(stderr_reader)
298+
tg.start_soon(_stderr_reader, process, errlog, server.encoding, server.encoding_error_handler)
279299
try:
280300
yield read_stream, write_stream
281301
finally:

0 commit comments

Comments
 (0)