Skip to content

Commit bca59f2

Browse files
committed
feat: add client connection manager to manage multiple sessions without requiring async with
1 parent 7bc190b commit bca59f2

File tree

5 files changed

+571
-0
lines changed

5 files changed

+571
-0
lines changed
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import asyncio
2+
3+
from mcp.client.client_connection_manager import ClientConnectionManager, StreamalbeHttpClientParams
4+
5+
6+
async def main():
7+
s1_name = "s1_name"
8+
s2_name = "s2_name"
9+
s1 = StreamalbeHttpClientParams(name=s1_name, url="http://localhost:8910/mcp/")
10+
s2 = StreamalbeHttpClientParams(name=s2_name, url="http://localhost:8910/mcp/")
11+
12+
m = ClientConnectionManager()
13+
14+
await m.connect(s1)
15+
await m.connect(s2)
16+
17+
print("---session initialize---")
18+
19+
await m.session_initialize(s1_name)
20+
await m.session_initialize(s2_name)
21+
await asyncio.sleep(1)
22+
23+
print("---session list tools---")
24+
res = await m.session_list_tools(s1_name)
25+
26+
await asyncio.sleep(1)
27+
print("---session call tool---")
28+
res = await m.session_call_tool(s1_name, "create_user")
29+
print(res)
30+
await asyncio.sleep(3)
31+
print("---session disconnect---")
32+
await m.disconnect(s1_name)
33+
# await m.cleanup(s2_name)
34+
35+
36+
if __name__ == "__main__":
37+
asyncio.run(main())
Lines changed: 294 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,294 @@
1+
import asyncio
2+
import logging
3+
from collections.abc import Coroutine
4+
from contextlib import asynccontextmanager
5+
from datetime import timedelta
6+
from typing import Any, TypeVar
7+
8+
from pydantic import AnyUrl, BaseModel, ConfigDict, Field
9+
10+
import mcp
11+
from mcp import types
12+
from mcp.client.exceptions import ConnectTimeOut
13+
from mcp.client.streamable_http import streamablehttp_client
14+
from mcp.shared.exceptions import McpError
15+
from mcp.shared.session import ProgressFnT
16+
from mcp.types import StreamalbeHttpClientParams
17+
18+
logger = logging.getLogger(__name__)
19+
20+
R = TypeVar("R")
21+
22+
23+
class ClientSessionState(BaseModel):
24+
session: mcp.ClientSession | None = None
25+
lifespan_task: asyncio.Task[Any] | None = None
26+
running_event: asyncio.Event = Field(default_factory=asyncio.Event)
27+
error: Exception | None = None
28+
request_task: dict[str, asyncio.Task[Any]] = Field(default_factory=dict)
29+
model_config = ConfigDict(arbitrary_types_allowed=True)
30+
31+
@property
32+
def lifespan(self) -> asyncio.Task[Any]:
33+
if self.lifespan_task is None:
34+
raise RuntimeError("lifespan_task is not set")
35+
return self.lifespan_task
36+
37+
@property
38+
def active_session(self) -> mcp.ClientSession:
39+
if self.session is None:
40+
raise RuntimeError("session is not set")
41+
return self.session
42+
43+
44+
class ClientConnectionManager:
45+
def __init__(
46+
self,
47+
):
48+
self._session: dict[str, ClientSessionState] = {}
49+
50+
async def connect(self, parameter: StreamalbeHttpClientParams):
51+
logger.info(f"Attempting to connect to MCP server: {parameter.name} ({parameter.url})")
52+
state = ClientSessionState()
53+
if not self._is_session_exists(parameter.name):
54+
self._session[parameter.name] = state
55+
logger.debug(f"Session state created for: {parameter.name}")
56+
else:
57+
raise McpError(
58+
types.ErrorData(
59+
code=types.CONNECTION_CLOSED,
60+
message=f"Session with name '{parameter.name}' already exists. \
61+
Duplicate connections are not allowed.",
62+
)
63+
)
64+
ready_future = asyncio.get_running_loop().create_future()
65+
66+
task = asyncio.create_task(self._maintain_session(parameter, ready_future))
67+
state.lifespan_task = task
68+
69+
try:
70+
await asyncio.wait_for(ready_future, timeout=5)
71+
except asyncio.TimeoutError:
72+
task.cancel()
73+
try:
74+
await task # 等待 task 真正結束或取消
75+
except asyncio.CancelledError:
76+
pass
77+
state.error = ConnectTimeOut(f"Connection to {parameter.name} timed out")
78+
raise state.error
79+
except Exception as e:
80+
task.cancel()
81+
state.error = e
82+
raise e
83+
84+
async def _maintain_session(self, parameter: StreamalbeHttpClientParams, connect_res: asyncio.Future[Any]):
85+
try:
86+
async with self._session_context(parameter):
87+
if not connect_res.done():
88+
connect_res.set_result(True)
89+
90+
logger.debug(f"Session maintenance started for: {parameter.name}. Waiting for shutdown event...")
91+
await self._session[parameter.name].running_event.wait()
92+
logger.info(f"Graceful shutdown initiated for session: {parameter.name}")
93+
94+
except Exception as e:
95+
if not connect_res.done():
96+
connect_res.set_exception(e)
97+
self._session[parameter.name].running_event.set()
98+
self._session[parameter.name].error = e
99+
raise e
100+
101+
@asynccontextmanager
102+
async def _session_context(self, parameter: StreamalbeHttpClientParams):
103+
try:
104+
async with streamablehttp_client(parameter.url) as streams:
105+
read_stream, write_stream, _ = streams
106+
async with mcp.ClientSession(read_stream, write_stream) as session:
107+
state = self._session[parameter.name]
108+
state.session = session
109+
110+
logger.info(f"Connected to MCP server: {parameter.name} ({parameter.url})")
111+
yield
112+
logger.info(f"MCP server {parameter.name} ({parameter.url}): disconnected")
113+
114+
except Exception as e:
115+
raise e
116+
117+
def _is_session_exists(self, session_name: str) -> bool:
118+
if session_name in self._session:
119+
return True
120+
return False
121+
122+
def _validate_session(self, session_name: str) -> ClientSessionState:
123+
if self._is_session_exists(session_name):
124+
state = self._session[session_name]
125+
if state.error:
126+
raise McpError(
127+
types.ErrorData(
128+
code=types.CONNECTION_CLOSED,
129+
message=f"Session with name '{session_name}' has error. {state.error}",
130+
)
131+
)
132+
return state
133+
else:
134+
raise McpError(
135+
types.ErrorData(
136+
code=types.CONNECTION_CLOSED,
137+
message=f"Session with name '{session_name}' does not exist. Please establish a connection first.",
138+
)
139+
)
140+
141+
async def _safe_run_task(self, session_name: str, task_cor: Coroutine[Any, Any, R]) -> R:
142+
actived_task = asyncio.create_task(task_cor)
143+
144+
async def monitor():
145+
await asyncio.sleep(0.1)
146+
while not actived_task.done():
147+
if self._session[session_name].error is not None:
148+
actived_task.cancel()
149+
break
150+
151+
await asyncio.sleep(2)
152+
153+
asyncio.create_task(monitor())
154+
try:
155+
res = await actived_task
156+
except asyncio.exceptions.CancelledError as err:
157+
session_err = self._session[session_name].error
158+
if session_err is not None:
159+
raise session_err
160+
raise err
161+
# except Exception as err:
162+
# raise err
163+
return res
164+
165+
async def session_initialize(self, session_name: str) -> types.InitializeResult:
166+
session_state = self._validate_session(session_name)
167+
168+
try:
169+
res = await self._safe_run_task(session_name, session_state.active_session.initialize())
170+
171+
except Exception as e:
172+
raise e
173+
174+
return res
175+
176+
async def session_send_pings(self, session_name: str) -> types.EmptyResult:
177+
session_state = self._validate_session(session_name)
178+
return await self._safe_run_task(session_name, session_state.active_session.send_ping())
179+
180+
async def session_send_progress_notification(
181+
self,
182+
session_name: str,
183+
progress_token: str | int,
184+
progress: float,
185+
total: float | None = None,
186+
message: str | None = None,
187+
) -> None:
188+
session_state = self._validate_session(session_name)
189+
return await self._safe_run_task(
190+
session_name,
191+
session_state.active_session.send_progress_notification(progress_token, progress, total, message),
192+
)
193+
194+
async def session_set_logging_level(self, session_name: str, level: types.LoggingLevel) -> types.EmptyResult:
195+
session_state = self._validate_session(session_name)
196+
return await self._safe_run_task(session_name, session_state.active_session.set_logging_level(level))
197+
198+
async def session_list_resources(self, session_name: str, cursor: str | None = None) -> types.ListResourcesResult:
199+
session_state = self._validate_session(session_name)
200+
return await self._safe_run_task(
201+
session_name,
202+
session_state.active_session.list_resources(cursor),
203+
)
204+
205+
async def session_list_resource_templates(
206+
self, session_name: str, cursor: str | None = None
207+
) -> types.ListResourceTemplatesResult:
208+
session_state = self._validate_session(session_name)
209+
return await self._safe_run_task(
210+
session_name,
211+
session_state.active_session.list_resource_templates(cursor),
212+
)
213+
214+
async def session_read_resource(self, session_name: str, uri: AnyUrl) -> types.ReadResourceResult:
215+
session_state = self._validate_session(session_name)
216+
return await self._safe_run_task(
217+
session_name,
218+
session_state.active_session.read_resource(uri),
219+
)
220+
221+
async def session_subscribe_resource(self, session_name: str, uri: AnyUrl) -> types.EmptyResult:
222+
session_state = self._validate_session(session_name)
223+
return await self._safe_run_task(
224+
session_name,
225+
session_state.active_session.subscribe_resource(uri),
226+
)
227+
228+
async def session_unsubscribe_resource(self, session_name: str, uri: AnyUrl) -> types.EmptyResult:
229+
session_state = self._validate_session(session_name)
230+
return await self._safe_run_task(
231+
session_name,
232+
session_state.active_session.unsubscribe_resource(uri),
233+
)
234+
235+
async def session_call_tool(
236+
self,
237+
session_name: str,
238+
name: str,
239+
arguments: dict[str, Any] | None = None,
240+
read_timeout_seconds: timedelta | None = None,
241+
progress_callback: ProgressFnT | None = None,
242+
) -> types.CallToolResult:
243+
session_state = self._validate_session(session_name)
244+
return await self._safe_run_task(
245+
session_name,
246+
session_state.active_session.call_tool(name, arguments, read_timeout_seconds, progress_callback),
247+
)
248+
249+
async def session_list_prompts(self, session_name: str, cursor: str | None = None) -> types.ListPromptsResult:
250+
session_state = self._validate_session(session_name)
251+
return await self._safe_run_task(
252+
session_name,
253+
session_state.active_session.list_prompts(cursor),
254+
)
255+
256+
async def session_get_prompt(
257+
self, session_name: str, name: str, arguments: dict[str, str] | None = None
258+
) -> types.GetPromptResult:
259+
session_state = self._validate_session(session_name)
260+
return await self._safe_run_task(
261+
session_name,
262+
session_state.active_session.get_prompt(name, arguments),
263+
)
264+
265+
async def session_list_tools(self, session_name: str, cursor: str | None = None) -> types.ListToolsResult:
266+
session_state = self._validate_session(session_name)
267+
268+
return await self._safe_run_task(session_name, session_state.active_session.list_tools(cursor))
269+
270+
async def session_send_roots_list_changed(self, session_name: str) -> None:
271+
session_state = self._validate_session(session_name)
272+
273+
return await self._safe_run_task(session_name, session_state.active_session.send_roots_list_changed())
274+
275+
async def disconnect(self, name: str) -> None:
276+
session = self._session[name]
277+
if not session.session:
278+
return
279+
280+
if session.lifespan_task and not session.lifespan_task.done():
281+
session.running_event.set()
282+
283+
try:
284+
await session.lifespan
285+
except Exception as e:
286+
raise McpError(
287+
types.ErrorData(
288+
code=types.CONNECTION_CLOSED,
289+
message=f"MCP server {name} disconnect failed {e}",
290+
)
291+
)
292+
finally:
293+
session.session = None
294+
session.lifespan_task = None

src/mcp/client/exceptions.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
class ConnectTimeOut(Exception):
2+
"""Failed to connect: timeout"""

src/mcp/types.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from collections.abc import Callable
2+
from datetime import timedelta
23
from typing import Annotated, Any, Generic, Literal, TypeAlias, TypeVar
34

45
from pydantic import BaseModel, ConfigDict, Field, FileUrl, RootModel
@@ -1310,3 +1311,11 @@ class ServerResult(
13101311
]
13111312
):
13121313
pass
1314+
1315+
1316+
class StreamalbeHttpClientParams(BaseModel):
1317+
name: str
1318+
url: str
1319+
headers: dict[str, Any] | None = None
1320+
timeout: timedelta = timedelta(seconds=30)
1321+
terminate_on_close: bool = True

0 commit comments

Comments
 (0)