Skip to content

Commit e895e52

Browse files
committed
added change and formatted with ruff
1 parent 05b7156 commit e895e52

File tree

1 file changed

+77
-62
lines changed

1 file changed

+77
-62
lines changed

src/mcp/shared/session.py

Lines changed: 77 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import logging
23
from collections.abc import Callable
34
from contextlib import AsyncExitStack
@@ -350,76 +351,90 @@ async def _receive_loop(self) -> None:
350351
self._read_stream,
351352
self._write_stream,
352353
):
353-
async for message in self._read_stream:
354-
if isinstance(message, Exception):
355-
await self._handle_incoming(message)
356-
elif isinstance(message.message.root, JSONRPCRequest):
357-
validated_request = self._receive_request_type.model_validate(
358-
message.message.root.model_dump(
359-
by_alias=True, mode="json", exclude_none=True
354+
async with asyncio.TaskGroup() as tg:
355+
async for message in self._read_stream:
356+
if isinstance(message, Exception):
357+
await self._handle_incoming(message)
358+
elif isinstance(message.message.root, JSONRPCRequest):
359+
validated_request = self._receive_request_type.model_validate(
360+
message.message.root.model_dump(
361+
by_alias=True, mode="json", exclude_none=True
362+
)
363+
)
364+
responder = RequestResponder(
365+
request_id=message.message.root.id,
366+
request_meta=validated_request.root.params.meta
367+
if validated_request.root.params
368+
else None,
369+
request=validated_request,
370+
session=self,
371+
on_complete=lambda r: self._in_flight.pop(
372+
r.request_id, None
373+
),
374+
message_metadata=message.metadata,
360375
)
361-
)
362-
responder = RequestResponder(
363-
request_id=message.message.root.id,
364-
request_meta=validated_request.root.params.meta
365-
if validated_request.root.params
366-
else None,
367-
request=validated_request,
368-
session=self,
369-
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
370-
message_metadata=message.metadata,
371-
)
372376

373-
self._in_flight[responder.request_id] = responder
374-
await self._received_request(responder)
377+
self._in_flight[responder.request_id] = responder
378+
task = tg.create_task(self._received_request(responder))
375379

376-
if not responder._completed: # type: ignore[reportPrivateUsage]
377-
await self._handle_incoming(responder)
380+
def _callback(task: asyncio.Task[None]) -> None:
381+
if not responder._completed: # type: ignore[reportPrivateUsage]
382+
tg.create_task(self._handle_incoming(responder))
378383

379-
elif isinstance(message.message.root, JSONRPCNotification):
380-
try:
381-
notification = self._receive_notification_type.model_validate(
382-
message.message.root.model_dump(
383-
by_alias=True, mode="json", exclude_none=True
384+
task.add_done_callback(_callback)
385+
386+
elif isinstance(message.message.root, JSONRPCNotification):
387+
try:
388+
notification = (
389+
self._receive_notification_type.model_validate(
390+
message.message.root.model_dump(
391+
by_alias=True, mode="json", exclude_none=True
392+
)
393+
)
384394
)
385-
)
386-
# Handle cancellation notifications
387-
if isinstance(notification.root, CancelledNotification):
388-
cancelled_id = notification.root.params.requestId
389-
if cancelled_id in self._in_flight:
390-
await self._in_flight[cancelled_id].cancel()
391-
else:
392-
# Handle progress notifications callback
393-
if isinstance(notification.root, ProgressNotification):
394-
progress_token = notification.root.params.progressToken
395-
# If there is a progress callback for this token,
396-
# call it with the progress information
397-
if progress_token in self._progress_callbacks:
398-
callback = self._progress_callbacks[progress_token]
399-
await callback(
400-
notification.root.params.progress,
401-
notification.root.params.total,
402-
notification.root.params.message,
395+
# Handle cancellation notifications
396+
if isinstance(notification.root, CancelledNotification):
397+
cancelled_id = notification.root.params.requestId
398+
if cancelled_id in self._in_flight:
399+
await self._in_flight[cancelled_id].cancel()
400+
else:
401+
# Handle progress notifications callback
402+
if isinstance(notification.root, ProgressNotification):
403+
progress_token = (
404+
notification.root.params.progressToken
403405
)
404-
await self._received_notification(notification)
405-
await self._handle_incoming(notification)
406-
except Exception as e:
407-
# For other validation errors, log and continue
408-
logging.warning(
409-
f"Failed to validate notification: {e}. "
410-
f"Message was: {message.message.root}"
411-
)
412-
else: # Response or error
413-
stream = self._response_streams.pop(message.message.root.id, None)
414-
if stream:
415-
await stream.send(message.message.root)
416-
else:
417-
await self._handle_incoming(
418-
RuntimeError(
419-
"Received response with an unknown "
420-
f"request ID: {message}"
406+
# If there is a progress callback for this token,
407+
# call it with the progress information
408+
if progress_token in self._progress_callbacks:
409+
callback = self._progress_callbacks[
410+
progress_token
411+
]
412+
await callback(
413+
notification.root.params.progress,
414+
notification.root.params.total,
415+
notification.root.params.message,
416+
)
417+
await self._received_notification(notification)
418+
await self._handle_incoming(notification)
419+
except Exception as e:
420+
# For other validation errors, log and continue
421+
logging.warning(
422+
f"Failed to validate notification: {e}. "
423+
f"Message was: {message.message.root}"
421424
)
425+
else: # Response or error
426+
stream = self._response_streams.pop(
427+
message.message.root.id, None
422428
)
429+
if stream:
430+
await stream.send(message.message.root)
431+
else:
432+
await self._handle_incoming(
433+
RuntimeError(
434+
"Received response with an unknown "
435+
f"request ID: {message}"
436+
)
437+
)
423438

424439
# after the read stream is closed, we need to send errors
425440
# to any pending requests

0 commit comments

Comments
 (0)