Skip to content

Commit 04354dd

Browse files
committed
Drop RootModel from JSONRPCMessage
1 parent dcc9b4f commit 04354dd

File tree

18 files changed

+252
-403
lines changed

18 files changed

+252
-403
lines changed

src/mcp/client/streamable_http.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -416,7 +416,7 @@ async def _send_session_terminated_error(self, read_stream_writer: StreamWriter,
416416
id=request_id,
417417
error=ErrorData(code=32600, message="Session terminated"),
418418
)
419-
session_message = SessionMessage(JSONRPCMessage(jsonrpc_error))
419+
session_message = SessionMessage(jsonrpc_error)
420420
await read_stream_writer.send(session_message)
421421

422422
async def post_writer(

src/mcp/server/experimental/task_result_handler.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
ErrorData,
2727
GetTaskPayloadRequest,
2828
GetTaskPayloadResult,
29-
JSONRPCMessage,
3029
RelatedTaskMetadata,
3130
RequestId,
3231
)
@@ -107,12 +106,7 @@ async def handle(
107106
while True:
108107
task = await self._store.get_task(task_id)
109108
if task is None:
110-
raise McpError(
111-
ErrorData(
112-
code=INVALID_PARAMS,
113-
message=f"Task not found: {task_id}",
114-
)
115-
)
109+
raise McpError(ErrorData(code=INVALID_PARAMS, message=f"Task not found: {task_id}"))
116110

117111
await self._deliver_queued_messages(task_id, session, request_id)
118112

@@ -161,7 +155,7 @@ async def _deliver_queued_messages(
161155

162156
# Send the message with relatedRequestId for routing
163157
session_message = SessionMessage(
164-
message=JSONRPCMessage(message.message),
158+
message=message.message,
165159
metadata=ServerMessageMetadata(related_request_id=request_id),
166160
)
167161
await self.send_message(session, session_message)

src/mcp/shared/session.py

Lines changed: 21 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
ClientResult,
2525
ErrorData,
2626
JSONRPCError,
27-
JSONRPCMessage,
2827
JSONRPCNotification,
2928
JSONRPCRequest,
3029
JSONRPCResponse,
@@ -271,7 +270,7 @@ async def send_request(
271270
**request_data,
272271
)
273272

274-
await self._write_stream.send(SessionMessage(message=JSONRPCMessage(jsonrpc_request), metadata=metadata))
273+
await self._write_stream.send(SessionMessage(message=jsonrpc_request, metadata=metadata))
275274

276275
# request read timeout takes precedence over session read timeout
277276
timeout = None
@@ -321,23 +320,23 @@ async def send_notification(
321320
**notification.model_dump(by_alias=True, mode="json", exclude_none=True),
322321
)
323322
session_message = SessionMessage( # pragma: no cover
324-
message=JSONRPCMessage(jsonrpc_notification),
323+
message=jsonrpc_notification,
325324
metadata=ServerMessageMetadata(related_request_id=related_request_id) if related_request_id else None,
326325
)
327326
await self._write_stream.send(session_message)
328327

329328
async def _send_response(self, request_id: RequestId, response: SendResultT | ErrorData) -> None:
330329
if isinstance(response, ErrorData):
331330
jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response)
332-
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_error))
331+
session_message = SessionMessage(message=jsonrpc_error)
333332
await self._write_stream.send(session_message)
334333
else:
335334
jsonrpc_response = JSONRPCResponse(
336335
jsonrpc="2.0",
337336
id=request_id,
338337
result=response.model_dump(by_alias=True, mode="json", exclude_none=True),
339338
)
340-
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_response))
339+
session_message = SessionMessage(message=jsonrpc_response)
341340
await self._write_stream.send(session_message)
342341

343342
async def _receive_loop(self) -> None:
@@ -349,14 +348,14 @@ async def _receive_loop(self) -> None:
349348
async for message in self._read_stream:
350349
if isinstance(message, Exception): # pragma: no cover
351350
await self._handle_incoming(message)
352-
elif isinstance(message.message.root, JSONRPCRequest):
351+
elif isinstance(message.message, JSONRPCRequest):
353352
try:
354353
validated_request = self._receive_request_type.model_validate(
355-
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True),
354+
message.message.model_dump(by_alias=True, mode="json", exclude_none=True),
356355
by_name=False,
357356
)
358357
responder = RequestResponder(
359-
request_id=message.message.root.id,
358+
request_id=message.message.id,
360359
request_meta=validated_request.root.params.meta
361360
if validated_request.root.params
362361
else None,
@@ -374,23 +373,23 @@ async def _receive_loop(self) -> None:
374373
# For request validation errors, send a proper JSON-RPC error
375374
# response instead of crashing the server
376375
logging.warning(f"Failed to validate request: {e}")
377-
logging.debug(f"Message that failed validation: {message.message.root}")
376+
logging.debug(f"Message that failed validation: {message.message}")
378377
error_response = JSONRPCError(
379378
jsonrpc="2.0",
380-
id=message.message.root.id,
379+
id=message.message.id,
381380
error=ErrorData(
382381
code=INVALID_PARAMS,
383382
message="Invalid request parameters",
384383
data="",
385384
),
386385
)
387-
session_message = SessionMessage(message=JSONRPCMessage(error_response))
386+
session_message = SessionMessage(message=error_response)
388387
await self._write_stream.send(session_message)
389388

390-
elif isinstance(message.message.root, JSONRPCNotification):
389+
elif isinstance(message.message, JSONRPCNotification):
391390
try:
392391
notification = self._receive_notification_type.model_validate(
393-
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True),
392+
message.message.model_dump(by_alias=True, mode="json", exclude_none=True),
394393
by_name=False,
395394
)
396395
# Handle cancellation notifications
@@ -419,10 +418,11 @@ async def _receive_loop(self) -> None:
419418
)
420419
await self._received_notification(notification)
421420
await self._handle_incoming(notification)
422-
except Exception as e: # pragma: no cover
421+
except Exception: # pragma: no cover
423422
# For other validation errors, log and continue
424423
logging.warning(
425-
f"Failed to validate notification: {e}. Message was: {message.message.root}"
424+
f"Failed to validate notification:. Message was: {message.message}",
425+
exc_info=True,
426426
)
427427
else: # Response or error
428428
await self._handle_response(message)
@@ -475,35 +475,33 @@ async def _handle_response(self, message: SessionMessage) -> None:
475475
Checks response routers first (e.g., for task-related responses),
476476
then falls back to the normal response stream mechanism.
477477
"""
478-
root = message.message.root
479-
480478
# This check is always true at runtime: the caller (_receive_loop) only invokes
481479
# this method in the else branch after checking for JSONRPCRequest and
482480
# JSONRPCNotification. However, the type checker can't infer this from the
483481
# method signature, so we need this guard for type narrowing.
484-
if not isinstance(root, JSONRPCResponse | JSONRPCError):
482+
if not isinstance(message.message, JSONRPCResponse | JSONRPCError):
485483
return # pragma: no cover
486484

487485
# Normalize response ID to handle type mismatches (e.g., "0" vs 0)
488-
response_id = self._normalize_request_id(root.id)
486+
response_id = self._normalize_request_id(message.message.id)
489487

490488
# First, check response routers (e.g., TaskResultHandler)
491-
if isinstance(root, JSONRPCError):
489+
if isinstance(message.message, JSONRPCError):
492490
# Route error to routers
493491
for router in self._response_routers:
494-
if router.route_error(response_id, root.error):
492+
if router.route_error(response_id, message.message.error):
495493
return # Handled
496494
else:
497495
# Route success response to routers
498-
response_data: dict[str, Any] = root.result or {}
496+
response_data: dict[str, Any] = message.message.result or {}
499497
for router in self._response_routers:
500498
if router.route_response(response_id, response_data):
501499
return # Handled
502500

503501
# Fall back to normal response streams
504502
stream = self._response_streams.pop(response_id, None)
505503
if stream: # pragma: no cover
506-
await stream.send(root)
504+
await stream.send(message.message)
507505
else: # pragma: no cover
508506
await self._handle_incoming(RuntimeError(f"Received response with an unknown request ID: {message}"))
509507

src/mcp/types.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -197,8 +197,7 @@ class JSONRPCError(MCPModel):
197197
error: ErrorData
198198

199199

200-
class JSONRPCMessage(RootModel[JSONRPCRequest | JSONRPCNotification | JSONRPCResponse | JSONRPCError]):
201-
pass
200+
JSONRPCMessage = JSONRPCRequest | JSONRPCNotification | JSONRPCResponse | JSONRPCError
202201

203202

204203
class EmptyResult(Result):

0 commit comments

Comments
 (0)