Skip to content

Commit 0357258

Browse files
committed
feat: add MCP proxy pattern convenience function
Implements mcp_proxy() function in mcp.shared.proxy module that enables bidirectional message forwarding between two MCP transports. Features: - Bidirectional message forwarding using anyio task groups - Error handling with optional sync/async callback support - Automatic cleanup when one transport closes - Proper handling of SessionMessage and Exception objects - Comprehensive test coverage Closes #12
1 parent c92bb2f commit 0357258

File tree

2 files changed

+549
-0
lines changed

2 files changed

+549
-0
lines changed

src/mcp/shared/proxy.py

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
"""
2+
MCP Proxy Module
3+
4+
This module provides utilities for proxying messages between two MCP transports,
5+
enabling bidirectional message forwarding with proper error handling and cleanup.
6+
"""
7+
8+
import logging
9+
from collections.abc import Awaitable, Callable
10+
from contextlib import asynccontextmanager
11+
from typing import AsyncGenerator
12+
13+
import anyio
14+
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
15+
16+
from mcp.shared.message import SessionMessage
17+
18+
logger = logging.getLogger(__name__)
19+
20+
MessageStream = tuple[
21+
MemoryObjectReceiveStream[SessionMessage | Exception],
22+
MemoryObjectSendStream[SessionMessage],
23+
]
24+
25+
26+
@asynccontextmanager
27+
async def mcp_proxy(
28+
transport_to_client: MessageStream,
29+
transport_to_server: MessageStream,
30+
onerror: Callable[[Exception], None | Awaitable[None]] | None = None,
31+
) -> AsyncGenerator[None, None]:
32+
"""
33+
Proxy messages bidirectionally between two MCP transports.
34+
35+
This function sets up bidirectional message forwarding between two transport pairs.
36+
When one transport closes, the other is also closed. Errors are forwarded to the
37+
error callback if provided.
38+
39+
Args:
40+
transport_to_client: A tuple of (read_stream, write_stream) for the client-facing transport.
41+
transport_to_server: A tuple of (read_stream, write_stream) for the server-facing transport.
42+
onerror: Optional callback function for handling errors. Can be sync or async.
43+
Called with the Exception object when an error occurs.
44+
45+
Example:
46+
```python
47+
async with mcp_proxy(
48+
transport_to_client=(client_read, client_write),
49+
transport_to_server=(server_read, server_write),
50+
onerror=lambda e: logger.error(f"Proxy error: {e}"),
51+
):
52+
# Proxy is active, forwarding messages bidirectionally
53+
await some_operation()
54+
# Both transports are closed when exiting the context
55+
```
56+
57+
Yields:
58+
None: The context manager yields control while the proxy is active.
59+
"""
60+
client_read, client_write = transport_to_client
61+
server_read, server_write = transport_to_server
62+
63+
async def forward_to_server():
64+
"""Forward messages from client to server."""
65+
try:
66+
async with client_read:
67+
async for message in client_read:
68+
try:
69+
# Forward SessionMessage objects directly
70+
if isinstance(message, SessionMessage):
71+
await server_write.send(message)
72+
# Handle Exception objects via error callback
73+
elif isinstance(message, Exception):
74+
logger.debug(f"Exception received from client: {message}")
75+
if onerror:
76+
try:
77+
result = onerror(message)
78+
if isinstance(result, Awaitable):
79+
await result
80+
except Exception as callback_error: # pragma: no cover
81+
logger.exception("Error in onerror callback", exc_info=callback_error)
82+
# Exceptions are not forwarded as messages (write streams only accept SessionMessage)
83+
except anyio.ClosedResourceError:
84+
logger.debug("Server write stream closed while forwarding from client")
85+
break
86+
except Exception as exc: # pragma: no cover
87+
logger.exception("Error forwarding message from client to server", exc_info=exc)
88+
if onerror:
89+
try:
90+
result = onerror(exc)
91+
if isinstance(result, Awaitable):
92+
await result
93+
except Exception as callback_error: # pragma: no cover
94+
logger.exception("Error in onerror callback", exc_info=callback_error)
95+
except anyio.ClosedResourceError:
96+
logger.debug("Client read stream closed")
97+
except Exception as exc: # pragma: no cover
98+
logger.exception("Error in forward_to_server task", exc_info=exc)
99+
if onerror:
100+
try:
101+
result = onerror(exc)
102+
if isinstance(result, Awaitable):
103+
await result
104+
except Exception as callback_error: # pragma: no cover
105+
logger.exception("Error in onerror callback", exc_info=callback_error)
106+
finally:
107+
# Close server write stream when client read closes
108+
try:
109+
await server_write.aclose()
110+
except Exception: # pragma: no cover
111+
# Stream might already be closed
112+
pass
113+
114+
async def forward_to_client():
115+
"""Forward messages from server to client."""
116+
try:
117+
async with server_read:
118+
async for message in server_read:
119+
try:
120+
# Forward SessionMessage objects directly
121+
if isinstance(message, SessionMessage):
122+
await client_write.send(message)
123+
# Handle Exception objects via error callback
124+
elif isinstance(message, Exception):
125+
logger.debug(f"Exception received from server: {message}")
126+
if onerror:
127+
try:
128+
result = onerror(message)
129+
if isinstance(result, Awaitable):
130+
await result
131+
except Exception as callback_error: # pragma: no cover
132+
logger.exception("Error in onerror callback", exc_info=callback_error)
133+
# Exceptions are not forwarded as messages (write streams only accept SessionMessage)
134+
except anyio.ClosedResourceError:
135+
logger.debug("Client write stream closed while forwarding from server")
136+
break
137+
except Exception as exc: # pragma: no cover
138+
logger.exception("Error forwarding message from server to client", exc_info=exc)
139+
if onerror:
140+
try:
141+
result = onerror(exc)
142+
if isinstance(result, Awaitable):
143+
await result
144+
except Exception as callback_error: # pragma: no cover
145+
logger.exception("Error in onerror callback", exc_info=callback_error)
146+
except anyio.ClosedResourceError:
147+
logger.debug("Server read stream closed")
148+
except Exception as exc: # pragma: no cover
149+
logger.exception("Error in forward_to_client task", exc_info=exc)
150+
if onerror:
151+
try:
152+
result = onerror(exc)
153+
if isinstance(result, Awaitable):
154+
await result
155+
except Exception as callback_error: # pragma: no cover
156+
logger.exception("Error in onerror callback", exc_info=callback_error)
157+
finally:
158+
# Close client write stream when server read closes
159+
try:
160+
await client_write.aclose()
161+
except Exception: # pragma: no cover
162+
# Stream might already be closed
163+
pass
164+
165+
async with anyio.create_task_group() as tg:
166+
tg.start_soon(forward_to_server)
167+
tg.start_soon(forward_to_client)
168+
try:
169+
yield
170+
finally:
171+
# Cancel the task group to stop forwarding
172+
tg.cancel_scope.cancel()
173+
# Close both write streams
174+
try:
175+
await client_write.aclose()
176+
except Exception: # pragma: no cover
177+
pass
178+
try:
179+
await server_write.aclose()
180+
except Exception: # pragma: no cover
181+
pass

0 commit comments

Comments
 (0)