|
45 | 45 | from mcp.server.streamable_http_manager import StreamableHTTPSessionManager |
46 | 46 | from mcp.server.transport_security import TransportSecuritySettings |
47 | 47 | from mcp.shared.context import LifespanContextT, RequestContext, RequestT |
48 | | -from mcp.types import NEXT_PROTOCOL_VERSION, AnyFunction, ContentBlock, GetPromptResult, ToolAnnotations |
| 48 | +from mcp.types import ( |
| 49 | + NEXT_PROTOCOL_VERSION, |
| 50 | + AnyFunction, |
| 51 | + ContentBlock, |
| 52 | + GetOperationPayloadResult, |
| 53 | + GetOperationStatusResult, |
| 54 | + GetPromptResult, |
| 55 | + ToolAnnotations, |
| 56 | +) |
49 | 57 | from mcp.types import Prompt as MCPPrompt |
50 | 58 | from mcp.types import PromptArgument as MCPPromptArgument |
51 | 59 | from mcp.types import Resource as MCPResource |
@@ -170,17 +178,19 @@ def __init__( |
170 | 178 | transport_security=transport_security, |
171 | 179 | ) |
172 | 180 |
|
| 181 | + self._async_operations = async_operations or AsyncOperationManager() |
| 182 | + |
173 | 183 | self._mcp_server = MCPServer( |
174 | 184 | name=name or "FastMCP", |
175 | 185 | instructions=instructions, |
| 186 | + async_operations=self._async_operations, |
176 | 187 | # TODO(Marcelo): It seems there's a type mismatch between the lifespan type from an FastMCP and Server. |
177 | 188 | # We need to create a Lifespan type that is a generic on the server type, like Starlette does. |
178 | 189 | lifespan=(lifespan_wrapper(self, self.settings.lifespan) if self.settings.lifespan else default_lifespan), # type: ignore |
179 | 190 | ) |
180 | 191 | self._tool_manager = ToolManager(tools=tools, warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools) |
181 | 192 | self._resource_manager = ResourceManager(warn_on_duplicate_resources=self.settings.warn_on_duplicate_resources) |
182 | 193 | self._prompt_manager = PromptManager(warn_on_duplicate_prompts=self.settings.warn_on_duplicate_prompts) |
183 | | - self.async_operations = async_operations or AsyncOperationManager() |
184 | 194 | # Validate auth configuration |
185 | 195 | if self.settings.auth is not None: |
186 | 196 | if auth_server_provider and token_verifier: |
@@ -270,6 +280,45 @@ def _setup_handlers(self) -> None: |
270 | 280 | self._mcp_server.get_prompt()(self.get_prompt) |
271 | 281 | self._mcp_server.list_resource_templates()(self.list_resource_templates) |
272 | 282 |
|
| 283 | + # Register async operation handlers |
| 284 | + logger.info(f"Async operations manager: {self._async_operations}") |
| 285 | + logger.info("Registering async operation handlers") |
| 286 | + self._mcp_server.get_operation_status()(self.get_operation_status) |
| 287 | + self._mcp_server.get_operation_result()(self.get_operation_result) |
| 288 | + |
| 289 | + async def get_operation_status(self, token: str) -> GetOperationStatusResult: |
| 290 | + """Get the status of an async operation.""" |
| 291 | + try: |
| 292 | + operation = self._async_operations.get_operation(token) |
| 293 | + if not operation: |
| 294 | + raise ValueError(f"Operation not found: {token}") |
| 295 | + |
| 296 | + return GetOperationStatusResult( |
| 297 | + status=operation.status, |
| 298 | + error=operation.error if operation.status == "failed" else None, |
| 299 | + ) |
| 300 | + except Exception: |
| 301 | + logger.exception(f"Error getting operation status for token {token}") |
| 302 | + raise |
| 303 | + |
| 304 | + async def get_operation_result(self, token: str) -> GetOperationPayloadResult: |
| 305 | + """Get the result of a completed async operation.""" |
| 306 | + try: |
| 307 | + operation = self._async_operations.get_operation(token) |
| 308 | + if not operation: |
| 309 | + raise ValueError(f"Operation not found: {token}") |
| 310 | + |
| 311 | + if operation.status != "completed": |
| 312 | + raise ValueError(f"Operation not completed: {operation.status}") |
| 313 | + |
| 314 | + if not operation.result: |
| 315 | + raise ValueError("Operation completed but no result available") |
| 316 | + |
| 317 | + return GetOperationPayloadResult(result=operation.result) |
| 318 | + except Exception: |
| 319 | + logger.exception(f"Error getting operation result for token {token}") |
| 320 | + raise |
| 321 | + |
273 | 322 | def _client_supports_async(self) -> bool: |
274 | 323 | """Check if the current client supports async tools based on protocol version.""" |
275 | 324 | try: |
|
0 commit comments