|
| 1 | +import traceback |
| 2 | +from datetime import datetime |
| 3 | +from typing import Any |
| 4 | + |
| 5 | +from pydantic import BaseModel, ValidationError |
| 6 | + |
| 7 | +from mcp.server.minimcp.types import Message |
| 8 | +from mcp.types import ( |
| 9 | + ErrorData, |
| 10 | + JSONRPCError, |
| 11 | + JSONRPCMessage, |
| 12 | + JSONRPCNotification, |
| 13 | + JSONRPCResponse, |
| 14 | + ServerNotification, |
| 15 | + ServerResult, |
| 16 | +) |
| 17 | + |
| 18 | +# TODO: Remove once https://github.com/modelcontextprotocol/python-sdk/pull/1310 is merged |
| 19 | +JSON_RPC_VERSION = "2.0" |
| 20 | + |
| 21 | + |
| 22 | +def to_dict(model: BaseModel) -> dict[str, Any]: |
| 23 | + """ |
| 24 | + Convert a JSON-RPC Pydantic model to a dictionary. |
| 25 | +
|
| 26 | + Args: |
| 27 | + model: The Pydantic model to convert. |
| 28 | +
|
| 29 | + Returns: |
| 30 | + A dictionary representation of the model. |
| 31 | + """ |
| 32 | + return model.model_dump(by_alias=True, exclude_none=True) |
| 33 | + |
| 34 | + |
| 35 | +def _to_message(model: BaseModel) -> Message: |
| 36 | + return model.model_dump_json(by_alias=True, exclude_none=True) |
| 37 | + |
| 38 | + |
| 39 | +# --- Build JSON-RPC messages --- |
| 40 | + |
| 41 | + |
| 42 | +def build_response_message(request_id: str | int, response: ServerResult) -> Message: |
| 43 | + """ |
| 44 | + Build a JSON-RPC response message with the given message ID and response. |
| 45 | +
|
| 46 | + Args: |
| 47 | + request_id: The message ID to use. |
| 48 | + response: The response object to build the response message from. |
| 49 | +
|
| 50 | + Returns: |
| 51 | + A JSON-RPC response message string. |
| 52 | + """ |
| 53 | + json_rpc_response = JSONRPCResponse(jsonrpc=JSON_RPC_VERSION, id=request_id, result=to_dict(response)) |
| 54 | + return _to_message(JSONRPCMessage(json_rpc_response)) |
| 55 | + |
| 56 | + |
| 57 | +def build_notification_message(notification: ServerNotification) -> Message: |
| 58 | + """ |
| 59 | + Build a JSON-RPC notification message with the given notification. |
| 60 | +
|
| 61 | + Args: |
| 62 | + notification: The notification object to build the notification message from. |
| 63 | +
|
| 64 | + Returns: |
| 65 | + A JSON-RPC notification message string. |
| 66 | + """ |
| 67 | + json_rpc_notification = JSONRPCNotification(jsonrpc=JSON_RPC_VERSION, **to_dict(notification)) |
| 68 | + return _to_message(JSONRPCMessage(json_rpc_notification)) |
| 69 | + |
| 70 | + |
| 71 | +def build_error_message( |
| 72 | + error: BaseException, |
| 73 | + request_message: str, |
| 74 | + error_code: int, |
| 75 | + data: dict[str, Any] | None = None, |
| 76 | + include_stack_trace: bool = False, |
| 77 | +) -> tuple[Message, str]: |
| 78 | + """ |
| 79 | + Build a JSON-RPC error message with the given error code, message ID, and error. |
| 80 | +
|
| 81 | + Args: |
| 82 | + error: The error object to build the error message from. |
| 83 | + request_message: The request message that resulted in the error. |
| 84 | + error_code: The JSON-RPC error code to use. See mcp.types for available codes. |
| 85 | + data: Additional data to include in the error message. |
| 86 | + include_stack_trace: Whether to include the stack trace in the error message. |
| 87 | +
|
| 88 | + Returns: |
| 89 | + A tuple containing the error formatted as a JSON-RPC message and a human-readable string. |
| 90 | + """ |
| 91 | + |
| 92 | + request_id = get_request_id(request_message) |
| 93 | + error_type = error.__class__.__name__ |
| 94 | + error_message = f"{error_type}: {error} (Request ID {request_id})" |
| 95 | + |
| 96 | + # Build error data |
| 97 | + error_metadata: dict[str, Any] = { |
| 98 | + "errorType": error_type, |
| 99 | + "errorModule": error.__class__.__module__, |
| 100 | + "isoTimestamp": datetime.now().isoformat(), |
| 101 | + } |
| 102 | + |
| 103 | + if include_stack_trace: |
| 104 | + stack_trace = traceback.format_exception(type(error), error, error.__traceback__) |
| 105 | + error_metadata["stackTrace"] = "".join(stack_trace) |
| 106 | + |
| 107 | + error_data = ErrorData(code=error_code, message=error_message, data={**error_metadata, **(data or {})}) |
| 108 | + |
| 109 | + json_rpc_error = JSONRPCError(jsonrpc=JSON_RPC_VERSION, id=request_id, error=error_data) |
| 110 | + return _to_message(JSONRPCMessage(json_rpc_error)), error_message |
| 111 | + |
| 112 | + |
| 113 | +# --- Utility functions to extract basic details of out of JSON-RPC message --- |
| 114 | + |
| 115 | + |
| 116 | +# Using a custom model to extract basic details of out of JSON-RPC message |
| 117 | +# as pydantic model_validate_json is better than json.loads. |
| 118 | +# This could be further optimized using something like ijson, but would be an unnecessary dependency. |
| 119 | +class JSONRPCEnvelope(BaseModel): |
| 120 | + id: int | str | None = None |
| 121 | + method: str | None = None |
| 122 | + jsonrpc: str | None = None |
| 123 | + |
| 124 | + |
| 125 | +def get_request_id(request_message: str) -> str | int: |
| 126 | + """ |
| 127 | + Get the request ID from a JSON-RPC request message string. |
| 128 | + """ |
| 129 | + request_id = None |
| 130 | + try: |
| 131 | + request_id = JSONRPCEnvelope.model_validate_json(request_message).id |
| 132 | + except ValidationError: |
| 133 | + pass |
| 134 | + |
| 135 | + return "no-id" if request_id is None else request_id |
| 136 | + |
| 137 | + |
| 138 | +def is_initialize_request(request_message: str) -> bool: |
| 139 | + """ |
| 140 | + Check if the request message is an initialize request. |
| 141 | + """ |
| 142 | + try: |
| 143 | + if "initialize" in request_message: |
| 144 | + return JSONRPCEnvelope.model_validate_json(request_message).method == "initialize" |
| 145 | + except ValidationError: |
| 146 | + pass |
| 147 | + |
| 148 | + return False |
| 149 | + |
| 150 | + |
| 151 | +def check_jsonrpc_version(request_message: str) -> bool: |
| 152 | + """ |
| 153 | + Check if the JSON-RPC version is valid. |
| 154 | + """ |
| 155 | + try: |
| 156 | + return JSONRPCEnvelope.model_validate_json(request_message).jsonrpc == JSON_RPC_VERSION |
| 157 | + except ValidationError: |
| 158 | + return False |
0 commit comments