Skip to content

Commit 8e567ec

Browse files
committed
Drop RootModel from JSONRPCMessage
1 parent 04354dd commit 8e567ec

File tree

12 files changed

+62
-73
lines changed

12 files changed

+62
-73
lines changed

src/mcp/client/sse.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,7 @@ async def sse_client(
7373
event_source.response.raise_for_status()
7474
logger.debug("SSE connection established")
7575

76-
async def sse_reader(
77-
task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED,
78-
):
76+
async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED):
7977
try:
8078
async for sse in event_source.aiter_sse(): # pragma: no branch
8179
logger.debug(f"Received SSE event: {sse.event}")
@@ -108,7 +106,7 @@ async def sse_reader(
108106
if not sse.data:
109107
continue
110108
try:
111-
message = types.JSONRPCMessage.model_validate_json( # noqa: E501
109+
message = types.jsonrpc_message_adapter.validate_json(
112110
sse.data, by_name=False
113111
)
114112
logger.debug(f"Received server message: {message}")

src/mcp/client/stdio/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ async def stdout_reader():
150150

151151
for line in lines:
152152
try:
153-
message = types.JSONRPCMessage.model_validate_json(line, by_name=False)
153+
message = types.jsonrpc_message_adapter.validate_json(line, by_name=False)
154154
except Exception as exc: # pragma: no cover
155155
logger.exception("Failed to parse JSONRPC message from server")
156156
await read_stream_writer.send(exc)

src/mcp/client/streamable_http.py

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
JSONRPCRequest,
2626
JSONRPCResponse,
2727
RequestId,
28+
jsonrpc_message_adapter,
2829
)
2930

3031
logger = logging.getLogger(__name__)
@@ -95,11 +96,11 @@ def _prepare_headers(self) -> dict[str, str]:
9596

9697
def _is_initialization_request(self, message: JSONRPCMessage) -> bool:
9798
"""Check if the message is an initialization request."""
98-
return isinstance(message.root, JSONRPCRequest) and message.root.method == "initialize"
99+
return isinstance(message, JSONRPCRequest) and message.method == "initialize"
99100

100101
def _is_initialized_notification(self, message: JSONRPCMessage) -> bool:
101102
"""Check if the message is an initialized notification."""
102-
return isinstance(message.root, JSONRPCNotification) and message.root.method == "notifications/initialized"
103+
return isinstance(message, JSONRPCNotification) and message.method == "notifications/initialized"
103104

104105
def _maybe_extract_session_id_from_response(self, response: httpx.Response) -> None:
105106
"""Extract and store session ID from response headers."""
@@ -110,15 +111,15 @@ def _maybe_extract_session_id_from_response(self, response: httpx.Response) -> N
110111

111112
def _maybe_extract_protocol_version_from_message(self, message: JSONRPCMessage) -> None:
112113
"""Extract protocol version from initialization response message."""
113-
if isinstance(message.root, JSONRPCResponse) and message.root.result: # pragma: no branch
114+
if isinstance(message, JSONRPCResponse) and message.result: # pragma: no branch
114115
try:
115116
# Parse the result as InitializeResult for type safety
116-
init_result = InitializeResult.model_validate(message.root.result, by_name=False)
117+
init_result = InitializeResult.model_validate(message.result, by_name=False)
117118
self.protocol_version = str(init_result.protocol_version)
118119
logger.info(f"Negotiated protocol version: {self.protocol_version}")
119120
except Exception: # pragma: no cover
120121
logger.warning("Failed to parse initialization response as InitializeResult", exc_info=True)
121-
logger.warning(f"Raw result: {message.root.result}")
122+
logger.warning(f"Raw result: {message.result}")
122123

123124
async def _handle_sse_event(
124125
self,
@@ -137,16 +138,16 @@ async def _handle_sse_event(
137138
await resumption_callback(sse.id)
138139
return False
139140
try:
140-
message = JSONRPCMessage.model_validate_json(sse.data, by_name=False)
141+
message = jsonrpc_message_adapter.validate_json(sse.data, by_name=False)
141142
logger.debug(f"SSE message: {message}")
142143

143144
# Extract protocol version from initialization response
144145
if is_initialization:
145146
self._maybe_extract_protocol_version_from_message(message)
146147

147148
# If this is a response and we have original_request_id, replace it
148-
if original_request_id is not None and isinstance(message.root, JSONRPCResponse | JSONRPCError):
149-
message.root.id = original_request_id
149+
if original_request_id is not None and isinstance(message, JSONRPCResponse | JSONRPCError):
150+
message.id = original_request_id
150151

151152
session_message = SessionMessage(message)
152153
await read_stream_writer.send(session_message)
@@ -157,7 +158,7 @@ async def _handle_sse_event(
157158

158159
# If this is a response or error return True indicating completion
159160
# Otherwise, return False to continue listening
160-
return isinstance(message.root, JSONRPCResponse | JSONRPCError)
161+
return isinstance(message, JSONRPCResponse | JSONRPCError)
161162

162163
except Exception as exc: # pragma: no cover
163164
logger.exception("Error parsing SSE message")
@@ -222,8 +223,8 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:
222223

223224
# Extract original request ID to map responses
224225
original_request_id = None
225-
if isinstance(ctx.session_message.message.root, JSONRPCRequest): # pragma: no branch
226-
original_request_id = ctx.session_message.message.root.id
226+
if isinstance(ctx.session_message.message, JSONRPCRequest): # pragma: no branch
227+
original_request_id = ctx.session_message.message.id
227228

228229
async with aconnect_sse(ctx.client, "GET", self.url, headers=headers) as event_source:
229230
event_source.response.raise_for_status()
@@ -257,20 +258,17 @@ async def _handle_post_request(self, ctx: RequestContext) -> None:
257258
return
258259

259260
if response.status_code == 404: # pragma: no branch
260-
if isinstance(message.root, JSONRPCRequest):
261-
await self._send_session_terminated_error( # pragma: no cover
262-
ctx.read_stream_writer, # pragma: no cover
263-
message.root.id, # pragma: no cover
264-
) # pragma: no cover
265-
return # pragma: no cover
261+
if isinstance(message, JSONRPCRequest): # pragma: no branch
262+
await self._send_session_terminated_error(ctx.read_stream_writer, message.id)
263+
return
266264

267265
response.raise_for_status()
268266
if is_initialization:
269267
self._maybe_extract_session_id_from_response(response)
270268

271269
# Per https://modelcontextprotocol.io/specification/2025-06-18/basic#notifications:
272270
# The server MUST NOT send a response to notifications.
273-
if isinstance(message.root, JSONRPCRequest):
271+
if isinstance(message, JSONRPCRequest):
274272
content_type = response.headers.get("content-type", "").lower()
275273
if content_type.startswith("application/json"):
276274
await self._handle_json_response(response, ctx.read_stream_writer, is_initialization)
@@ -291,7 +289,7 @@ async def _handle_json_response(
291289
"""Handle JSON response from the server."""
292290
try:
293291
content = await response.aread()
294-
message = JSONRPCMessage.model_validate_json(content, by_name=False)
292+
message = jsonrpc_message_adapter.validate_json(content, by_name=False)
295293

296294
# Extract protocol version from initialization response
297295
if is_initialization:
@@ -365,8 +363,8 @@ async def _handle_reconnection(
365363

366364
# Extract original request ID to map responses
367365
original_request_id = None
368-
if isinstance(ctx.session_message.message.root, JSONRPCRequest): # pragma: no branch
369-
original_request_id = ctx.session_message.message.root.id
366+
if isinstance(ctx.session_message.message, JSONRPCRequest): # pragma: no branch
367+
original_request_id = ctx.session_message.message.id
370368

371369
try:
372370
async with aconnect_sse(ctx.client, "GET", self.url, headers=headers) as event_source:
@@ -463,7 +461,7 @@ async def handle_request_async():
463461
await self._handle_post_request(ctx)
464462

465463
# If this is a request, start a new task to handle it
466-
if isinstance(message.root, JSONRPCRequest):
464+
if isinstance(message, JSONRPCRequest):
467465
tg.start_soon(handle_request_async)
468466
else:
469467
await handle_request_async()

src/mcp/client/websocket.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ async def ws_reader():
5151
async with read_stream_writer:
5252
async for raw_text in ws:
5353
try:
54-
message = types.JSONRPCMessage.model_validate_json(raw_text, by_name=False)
54+
message = types.jsonrpc_message_adapter.validate_json(raw_text, by_name=False)
5555
session_message = SessionMessage(message)
5656
await read_stream_writer.send(session_message)
5757
except ValidationError as exc: # pragma: no cover

src/mcp/server/sse.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ async def handle_post_message(self, scope: Scope, receive: Receive, send: Send)
227227
logger.debug(f"Received JSON: {body}")
228228

229229
try:
230-
message = types.JSONRPCMessage.model_validate_json(body, by_name=False)
230+
message = types.jsonrpc_message_adapter.validate_json(body, by_name=False)
231231
logger.debug(f"Validated client message: {message}")
232232
except ValidationError as err:
233233
logger.exception("Failed to parse message")

src/mcp/server/stdio.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ async def stdin_reader():
6060
async with read_stream_writer:
6161
async for line in stdin:
6262
try:
63-
message = types.JSONRPCMessage.model_validate_json(line, by_name=False)
63+
message = types.jsonrpc_message_adapter.validate_json(line, by_name=False)
6464
except Exception as exc: # pragma: no cover
6565
await read_stream_writer.send(exc)
6666
continue

src/mcp/server/streamable_http.py

Lines changed: 15 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
JSONRPCRequest,
4343
JSONRPCResponse,
4444
RequestId,
45+
jsonrpc_message_adapter,
4546
)
4647

4748
logger = logging.getLogger(__name__)
@@ -301,10 +302,7 @@ def _create_error_response(
301302
error_response = JSONRPCError(
302303
jsonrpc="2.0",
303304
id="server-error", # We don't have a request ID for general errors
304-
error=ErrorData(
305-
code=error_code,
306-
message=error_message,
307-
),
305+
error=ErrorData(code=error_code, message=error_message),
308306
)
309307

310308
return Response(
@@ -455,14 +453,15 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re
455453
body = await request.body()
456454

457455
try:
456+
# TODO(Marcelo): Replace `json.loads` with `pydantic_core.from_json`.
458457
raw_message = json.loads(body)
459458
except json.JSONDecodeError as e:
460459
response = self._create_error_response(f"Parse error: {str(e)}", HTTPStatus.BAD_REQUEST, PARSE_ERROR)
461460
await response(scope, receive, send)
462461
return
463462

464463
try: # pragma: no cover
465-
message = JSONRPCMessage.model_validate(raw_message, by_name=False)
464+
message = jsonrpc_message_adapter.validate_python(raw_message, by_name=False)
466465
except ValidationError as e: # pragma: no cover
467466
response = self._create_error_response(
468467
f"Validation error: {str(e)}",
@@ -473,9 +472,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re
473472
return
474473

475474
# Check if this is an initialization request
476-
is_initialization_request = (
477-
isinstance(message.root, JSONRPCRequest) and message.root.method == "initialize"
478-
) # pragma: no cover
475+
is_initialization_request = isinstance(message, JSONRPCRequest) and message.method == "initialize"
479476

480477
if is_initialization_request: # pragma: no cover
481478
# Check if the server already has an established session
@@ -495,7 +492,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re
495492
return
496493

497494
# For notifications and responses only, return 202 Accepted
498-
if not isinstance(message.root, JSONRPCRequest): # pragma: no cover
495+
if not isinstance(message, JSONRPCRequest): # pragma: no cover
499496
# Create response object and send it
500497
response = self._create_json_response(
501498
None,
@@ -514,13 +511,13 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re
514511
# For initialize requests, get from request params.
515512
# For other requests, get from header (already validated).
516513
protocol_version = (
517-
str(message.root.params.get("protocolVersion", DEFAULT_NEGOTIATED_VERSION))
518-
if is_initialization_request and message.root.params
514+
str(message.params.get("protocolVersion", DEFAULT_NEGOTIATED_VERSION))
515+
if is_initialization_request and message.params
519516
else request.headers.get(MCP_PROTOCOL_VERSION_HEADER, DEFAULT_NEGOTIATED_VERSION)
520517
)
521518

522519
# Extract the request ID outside the try block for proper scope
523-
request_id = str(message.root.id) # pragma: no cover
520+
request_id = str(message.id) # pragma: no cover
524521
# Register this stream for the request ID
525522
self._request_streams[request_id] = anyio.create_memory_object_stream[EventMessage](0) # pragma: no cover
526523
request_stream_reader = self._request_streams[request_id][1] # pragma: no cover
@@ -538,12 +535,12 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re
538535
# Use similar approach to SSE writer for consistency
539536
async for event_message in request_stream_reader:
540537
# If it's a response, this is what we're waiting for
541-
if isinstance(event_message.message.root, JSONRPCResponse | JSONRPCError):
538+
if isinstance(event_message.message, JSONRPCResponse | JSONRPCError):
542539
response_message = event_message.message
543540
break
544541
# For notifications and request, keep waiting
545542
else:
546-
logger.debug(f"received: {event_message.message.root.method}")
543+
logger.debug(f"received: {event_message.message.method}")
547544

548545
# At this point we should have a response
549546
if response_message:
@@ -589,10 +586,7 @@ async def sse_writer():
589586
await sse_stream_writer.send(event_data)
590587

591588
# If response, remove from pending streams and close
592-
if isinstance(
593-
event_message.message.root,
594-
JSONRPCResponse | JSONRPCError,
595-
):
589+
if isinstance(event_message.message, JSONRPCResponse | JSONRPCError):
596590
break
597591
except anyio.ClosedResourceError:
598592
# Expected when close_sse_stream() is called
@@ -984,8 +978,8 @@ async def message_router(): # pragma: no cover
984978
message = session_message.message
985979
target_request_id = None
986980
# Check if this is a response
987-
if isinstance(message.root, JSONRPCResponse | JSONRPCError):
988-
response_id = str(message.root.id)
981+
if isinstance(message, JSONRPCResponse | JSONRPCError):
982+
response_id = str(message.id)
989983
# If this response is for an existing request stream,
990984
# send it there
991985
target_request_id = response_id
@@ -1022,7 +1016,7 @@ async def message_router(): # pragma: no cover
10221016
self._request_streams.pop(request_stream_id, None)
10231017
else:
10241018
logger.debug(
1025-
f"""Request stream {request_stream_id} not found
1019+
f"""Request stream {request_stream_id} not found
10261020
for message. Still processing message as the client
10271021
might reconnect and replay."""
10281022
)

src/mcp/server/websocket.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ async def ws_reader():
3636
async with read_stream_writer:
3737
async for msg in websocket.iter_text():
3838
try:
39-
client_message = types.JSONRPCMessage.model_validate_json(msg, by_name=False)
39+
client_message = types.jsonrpc_message_adapter.validate_json(msg, by_name=False)
4040
except ValidationError as exc:
4141
await read_stream_writer.send(exc)
4242
continue

src/mcp/types.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from datetime import datetime
55
from typing import Annotated, Any, Final, Generic, Literal, TypeAlias, TypeVar
66

7-
from pydantic import BaseModel, ConfigDict, Field, FileUrl, RootModel
7+
from pydantic import BaseModel, ConfigDict, Field, FileUrl, RootModel, TypeAdapter
88
from pydantic.alias_generators import to_camel
99

1010
LATEST_PROTOCOL_VERSION = "2025-11-25"
@@ -198,6 +198,7 @@ class JSONRPCError(MCPModel):
198198

199199

200200
JSONRPCMessage = JSONRPCRequest | JSONRPCNotification | JSONRPCResponse | JSONRPCError
201+
jsonrpc_message_adapter = TypeAdapter[JSONRPCMessage](JSONRPCMessage)
201202

202203

203204
class EmptyResult(Result):

0 commit comments

Comments
 (0)