Skip to content

Commit 4165200

Browse files
committed
add methods to enable call tool requests to be started and joined at a later state and cancelled
1 parent bd84329 commit 4165200

File tree

5 files changed

+788
-80
lines changed

5 files changed

+788
-80
lines changed

src/mcp/client/session.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import mcp.types as types
1111
from mcp.shared.context import RequestContext
1212
from mcp.shared.message import SessionMessage
13-
from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder
13+
from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder, RequestStateManager
1414
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
1515

1616
DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0")
@@ -118,13 +118,15 @@ def __init__(
118118
logging_callback: LoggingFnT | None = None,
119119
message_handler: MessageHandlerFnT | None = None,
120120
client_info: types.Implementation | None = None,
121+
request_state_manager: RequestStateManager[types.ClientRequest, types.ClientResult] | None = None,
121122
) -> None:
122123
super().__init__(
123124
read_stream,
124125
write_stream,
125126
types.ServerRequest,
126127
types.ServerNotification,
127128
read_timeout_seconds=read_timeout_seconds,
129+
request_state_manager=request_state_manager,
128130
)
129131
self._client_info = client_info or DEFAULT_CLIENT_INFO
130132
self._sampling_callback = sampling_callback or _default_sampling_callback
@@ -281,6 +283,46 @@ async def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
281283
types.EmptyResult,
282284
)
283285

286+
async def request_call_tool(
287+
self,
288+
name: str,
289+
arguments: dict[str, Any] | None = None,
290+
progress_callback: ProgressFnT | None = None,
291+
) -> types.RequestId:
292+
return await self.start_request(
293+
types.ClientRequest(
294+
types.CallToolRequest(
295+
method="tools/call",
296+
params=types.CallToolRequestParams(
297+
name=name,
298+
arguments=arguments,
299+
),
300+
)
301+
),
302+
progress_callback=progress_callback,
303+
)
304+
305+
async def join_call_tool(
306+
self,
307+
request_id: types.RequestId,
308+
progress_callback: ProgressFnT | None = None,
309+
request_read_timeout_seconds: timedelta | None = None,
310+
fail_on_timeout: bool = True,
311+
) -> types.CallToolResult | None:
312+
return await self.join_request(
313+
request_id,
314+
types.CallToolResult,
315+
request_read_timeout_seconds=request_read_timeout_seconds,
316+
progress_callback=progress_callback,
317+
fail_on_timeout=fail_on_timeout,
318+
)
319+
320+
async def cancel_call_tool(
321+
self,
322+
request_id: types.RequestId,
323+
) -> bool:
324+
return await self.cancel_request(request_id)
325+
284326
async def call_tool(
285327
self,
286328
name: str,

0 commit comments

Comments
 (0)