Skip to content

Commit 9232cca

Browse files
feat: add experimental task-based session manager
Create a wrapper which wraps client creation and termination with a single task. This prevents AnyIO's CancelScope error ('Attempted to exit cancel scope in a different task than it was entered in') since creation and deletion happens in a single task.
1 parent 71b3289 commit 9232cca

File tree

3 files changed

+485
-0
lines changed

3 files changed

+485
-0
lines changed

src/google/adk/tools/mcp_tool/experimental/__init__.py

Whitespace-only changes.
Lines changed: 319 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,319 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Lifecycle-based MCP session manager for AnyIO CancelScope compatibility.
16+
17+
This module provides a MCPSessionManager that ensures MCP client
18+
lifecycle operations (creation and cleanup) occur within the same asyncio task.
19+
This is required because MCP clients use AnyIO internally, and AnyIO's
20+
TaskGroup/CancelScope requires that the start and end of a scope occur within
21+
the same task.
22+
23+
The manager uses SessionLifecycle to manage each session's lifecycle within
24+
a dedicated task. SessionLifecycle spawns a background task that handles the
25+
entire lifecycle (creation, usage, and cleanup) within a single task, ensuring
26+
CancelScope constraints are satisfied regardless of which task calls
27+
create_session() or close().
28+
29+
Use this manager instead of the standard MCPSessionManager when working with
30+
StreamableHTTPConnectionParams to avoid CancelScope constraint violations.
31+
"""
32+
33+
from __future__ import annotations
34+
35+
import asyncio
36+
from datetime import timedelta
37+
import hashlib
38+
import json
39+
import logging
40+
import sys
41+
from typing import Any
42+
from typing import Dict
43+
from typing import Optional
44+
from typing import TextIO
45+
from typing import Union
46+
47+
from mcp import ClientSession
48+
from mcp import StdioServerParameters
49+
from mcp.client.sse import sse_client
50+
from mcp.client.stdio import stdio_client
51+
from mcp.client.streamable_http import streamablehttp_client
52+
53+
from ..mcp_session_manager import SseConnectionParams
54+
from ..mcp_session_manager import StdioConnectionParams
55+
from ..mcp_session_manager import StreamableHTTPConnectionParams
56+
from .session_lifecycle import SessionLifecycle
57+
58+
logger = logging.getLogger('google_adk.' + __name__)
59+
60+
61+
class MCPSessionManager:
62+
"""Lifecycle-based MCP session manager for AnyIO CancelScope compatibility.
63+
64+
This class provides the same functionality as the standard MCPSessionManager
65+
but ensures that MCP client lifecycle operations (creation and cleanup)
66+
occur within the same asyncio task. This is required because MCP clients use
67+
AnyIO internally, and AnyIO's TaskGroup/CancelScope requires that the start
68+
and end of a scope occur within the same task.
69+
70+
The session lifecycle is managed by SessionLifecycle, which spawns a
71+
dedicated background task for each session. This background task:
72+
1. Enters the MCP client's async context and initializes the session
73+
2. Signals readiness via an asyncio.Event
74+
3. Waits for a close signal
75+
4. Cleans up the client within the same task
76+
77+
This ensures CancelScope constraints are satisfied regardless of which
78+
task calls create_session() or close().
79+
"""
80+
81+
def __init__(
82+
self,
83+
connection_params: Union[
84+
StdioServerParameters,
85+
StdioConnectionParams,
86+
SseConnectionParams,
87+
StreamableHTTPConnectionParams,
88+
],
89+
errlog: TextIO = sys.stderr,
90+
):
91+
"""Initializes the lifecycle-based MCP session manager.
92+
93+
Args:
94+
connection_params: Parameters for the MCP connection (Stdio, SSE or
95+
Streamable HTTP). Stdio by default also has a 5s read timeout as other
96+
parameters but it's not configurable for now.
97+
errlog: (Optional) TextIO stream for error logging. Use only for
98+
initializing a local stdio MCP session.
99+
"""
100+
if isinstance(connection_params, StdioServerParameters):
101+
# So far timeout is not configurable. Given MCP is still evolving, we
102+
# would expect stdio_client to evolve to accept timeout parameter like
103+
# other client.
104+
logger.warning(
105+
'StdioServerParameters is not recommended. Please use'
106+
' StdioConnectionParams.'
107+
)
108+
self._connection_params = StdioConnectionParams(
109+
server_params=connection_params,
110+
timeout=5,
111+
)
112+
else:
113+
self._connection_params = connection_params
114+
self._errlog = errlog
115+
116+
# Session pool: maps session keys to lifecycle managers
117+
self._sessions: Dict[str, SessionLifecycle] = {}
118+
119+
# Lock to prevent race conditions in session creation
120+
self._session_lock = asyncio.Lock()
121+
122+
def _generate_session_key(
123+
self, merged_headers: Optional[Dict[str, str]] = None
124+
) -> str:
125+
"""Generates a session key based on connection params and merged headers.
126+
127+
For StdioConnectionParams, returns a constant key since headers are not
128+
supported. For SSE and StreamableHTTP connections, generates a key based
129+
on the provided merged headers.
130+
131+
Args:
132+
merged_headers: Already merged headers (base + additional).
133+
134+
Returns:
135+
A unique session key string.
136+
"""
137+
if isinstance(self._connection_params, StdioConnectionParams):
138+
return 'stdio_session'
139+
140+
if merged_headers:
141+
headers_json = json.dumps(merged_headers, sort_keys=True)
142+
headers_hash = hashlib.md5(headers_json.encode()).hexdigest()
143+
return f'session_{headers_hash}'
144+
else:
145+
return 'session_no_headers'
146+
147+
def _merge_headers(
148+
self, additional_headers: Optional[Dict[str, str]] = None
149+
) -> Optional[Dict[str, str]]:
150+
"""Merges base connection headers with additional headers.
151+
152+
Args:
153+
additional_headers: Optional headers to merge with connection headers.
154+
155+
Returns:
156+
Merged headers dictionary, or None if no headers are provided.
157+
"""
158+
if isinstance(self._connection_params, StdioConnectionParams) or isinstance(
159+
self._connection_params, StdioServerParameters
160+
):
161+
return None
162+
163+
base_headers = {}
164+
if (
165+
hasattr(self._connection_params, 'headers')
166+
and self._connection_params.headers
167+
):
168+
base_headers = self._connection_params.headers.copy()
169+
170+
if additional_headers:
171+
base_headers.update(additional_headers)
172+
173+
return base_headers
174+
175+
def _is_session_disconnected(self, session: ClientSession) -> bool:
176+
"""Checks if a session is disconnected or closed.
177+
178+
Args:
179+
session: The ClientSession to check.
180+
181+
Returns:
182+
True if the session is disconnected, False otherwise.
183+
"""
184+
return session._read_stream._closed or session._write_stream._closed
185+
186+
def _create_client(self, merged_headers: Optional[Dict[str, str]] = None):
187+
"""Creates an MCP client based on the connection parameters.
188+
189+
Args:
190+
merged_headers: Optional headers to include in the connection.
191+
Only applicable for SSE and StreamableHTTP connections.
192+
193+
Returns:
194+
The appropriate MCP client instance.
195+
196+
Raises:
197+
ValueError: If the connection parameters are not supported.
198+
"""
199+
if isinstance(self._connection_params, StdioConnectionParams):
200+
client = stdio_client(
201+
server=self._connection_params.server_params,
202+
errlog=self._errlog,
203+
)
204+
elif isinstance(self._connection_params, SseConnectionParams):
205+
client = sse_client(
206+
url=self._connection_params.url,
207+
headers=merged_headers,
208+
timeout=self._connection_params.timeout,
209+
sse_read_timeout=self._connection_params.sse_read_timeout,
210+
)
211+
elif isinstance(self._connection_params, StreamableHTTPConnectionParams):
212+
client = streamablehttp_client(
213+
url=self._connection_params.url,
214+
headers=merged_headers,
215+
timeout=timedelta(seconds=self._connection_params.timeout),
216+
sse_read_timeout=timedelta(
217+
seconds=self._connection_params.sse_read_timeout
218+
),
219+
terminate_on_close=self._connection_params.terminate_on_close,
220+
)
221+
else:
222+
raise ValueError(
223+
'Unable to initialize connection. Connection should be'
224+
' StdioServerParameters or SseServerParams, but got'
225+
f' {self._connection_params}'
226+
)
227+
return client
228+
229+
async def create_session(
230+
self, headers: Optional[Dict[str, str]] = None
231+
) -> ClientSession:
232+
"""Creates and initializes an MCP client session.
233+
234+
This method will check if an existing session for the given headers
235+
is still connected. If it's disconnected, it will be cleaned up and
236+
a new session will be created.
237+
238+
The session lifecycle is managed by SessionLifecycle, which spawns a
239+
dedicated background task to handle the entire lifecycle (creation, usage,
240+
and cleanup) within a single task. This is required because MCP clients
241+
use AnyIO internally, and AnyIO's TaskGroup/CancelScope requires that the
242+
start and end of a scope occur within the same task.
243+
244+
Args:
245+
headers: Optional headers to include in the session. These will be
246+
merged with any existing connection headers. Only applicable
247+
for SSE and StreamableHTTP connections.
248+
249+
Returns:
250+
ClientSession: The initialized MCP client session.
251+
"""
252+
merged_headers = self._merge_headers(headers)
253+
session_key = self._generate_session_key(merged_headers)
254+
255+
async with self._session_lock:
256+
# Check if we have an existing session
257+
if session_key in self._sessions:
258+
lifecycle_manager = self._sessions[session_key]
259+
260+
if not self._is_session_disconnected(lifecycle_manager.session):
261+
return lifecycle_manager.session
262+
else:
263+
# Session is disconnected, clean it up
264+
logger.info('Cleaning up disconnected session: %s', session_key)
265+
try:
266+
await lifecycle_manager.close()
267+
except Exception as e:
268+
logger.warning('Error during disconnected session cleanup: %s', e)
269+
finally:
270+
del self._sessions[session_key]
271+
272+
# Create a new session
273+
timeout_in_seconds = (
274+
self._connection_params.timeout
275+
if hasattr(self._connection_params, 'timeout')
276+
else None
277+
)
278+
279+
is_stdio = isinstance(self._connection_params, StdioConnectionParams)
280+
281+
# Use SessionLifecycle to ensure client lifecycle operations
282+
# happen in the same task (required by AnyIO's CancelScope)
283+
client = self._create_client(merged_headers)
284+
lifecycle_manager = SessionLifecycle(
285+
client=client,
286+
timeout=timeout_in_seconds,
287+
is_stdio=is_stdio,
288+
)
289+
290+
try:
291+
session = await lifecycle_manager.start()
292+
self._sessions[session_key] = lifecycle_manager
293+
logger.debug('Created new session: %s', session_key)
294+
return session
295+
296+
except Exception as e:
297+
raise ConnectionError(f'Failed to create MCP session: {e}') from e
298+
299+
async def close(self):
300+
"""Closes all sessions and cleans up resources.
301+
302+
Each session's cleanup is performed by its SessionLifecycle,
303+
which ensures that the cleanup happens in the same task where the
304+
client was created (required by AnyIO's CancelScope).
305+
"""
306+
async with self._session_lock:
307+
for session_key in list(self._sessions.keys()):
308+
lifecycle_manager = self._sessions[session_key]
309+
try:
310+
await lifecycle_manager.close()
311+
except Exception as e:
312+
print(
313+
'Warning: Error during MCP session cleanup for'
314+
f' {session_key}: {e}',
315+
file=self._errlog,
316+
)
317+
finally:
318+
del self._sessions[session_key]
319+

0 commit comments

Comments
 (0)