|
10 | 10 | import mcp.types as types |
11 | 11 | from mcp.shared.context import RequestContext |
12 | 12 | from mcp.shared.message import SessionMessage |
13 | | -from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder |
| 13 | +from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder, RequestStateManager |
14 | 14 | from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS |
15 | 15 |
|
16 | 16 | DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0") |
@@ -118,13 +118,15 @@ def __init__( |
118 | 118 | logging_callback: LoggingFnT | None = None, |
119 | 119 | message_handler: MessageHandlerFnT | None = None, |
120 | 120 | client_info: types.Implementation | None = None, |
| 121 | + request_state_manager: RequestStateManager[types.ClientRequest, types.ClientResult] | None = None, |
121 | 122 | ) -> None: |
122 | 123 | super().__init__( |
123 | 124 | read_stream, |
124 | 125 | write_stream, |
125 | 126 | types.ServerRequest, |
126 | 127 | types.ServerNotification, |
127 | 128 | read_timeout_seconds=read_timeout_seconds, |
| 129 | + request_state_manager=request_state_manager, |
128 | 130 | ) |
129 | 131 | self._client_info = client_info or DEFAULT_CLIENT_INFO |
130 | 132 | self._sampling_callback = sampling_callback or _default_sampling_callback |
@@ -281,6 +283,46 @@ async def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult: |
281 | 283 | types.EmptyResult, |
282 | 284 | ) |
283 | 285 |
|
| 286 | + async def request_call_tool( |
| 287 | + self, |
| 288 | + name: str, |
| 289 | + arguments: dict[str, Any] | None = None, |
| 290 | + progress_callback: ProgressFnT | None = None, |
| 291 | + ) -> types.RequestId: |
| 292 | + return await self.start_request( |
| 293 | + types.ClientRequest( |
| 294 | + types.CallToolRequest( |
| 295 | + method="tools/call", |
| 296 | + params=types.CallToolRequestParams( |
| 297 | + name=name, |
| 298 | + arguments=arguments, |
| 299 | + ), |
| 300 | + ) |
| 301 | + ), |
| 302 | + progress_callback=progress_callback, |
| 303 | + ) |
| 304 | + |
| 305 | + async def join_call_tool( |
| 306 | + self, |
| 307 | + request_id: types.RequestId, |
| 308 | + progress_callback: ProgressFnT | None = None, |
| 309 | + request_read_timeout_seconds: timedelta | None = None, |
| 310 | + fail_on_timeout: bool = True, |
| 311 | + ) -> types.CallToolResult | None: |
| 312 | + return await self.join_request( |
| 313 | + request_id, |
| 314 | + types.CallToolResult, |
| 315 | + request_read_timeout_seconds=request_read_timeout_seconds, |
| 316 | + progress_callback=progress_callback, |
| 317 | + fail_on_timeout=fail_on_timeout, |
| 318 | + ) |
| 319 | + |
| 320 | + async def cancel_call_tool( |
| 321 | + self, |
| 322 | + request_id: types.RequestId, |
| 323 | + ) -> bool: |
| 324 | + return await self.cancel_request(request_id) |
| 325 | + |
284 | 326 | async def call_tool( |
285 | 327 | self, |
286 | 328 | name: str, |
|
0 commit comments