|
| 1 | +import asyncio |
1 | 2 | import logging |
2 | 3 | from collections.abc import Callable |
3 | 4 | from contextlib import AsyncExitStack |
@@ -350,76 +351,90 @@ async def _receive_loop(self) -> None: |
350 | 351 | self._read_stream, |
351 | 352 | self._write_stream, |
352 | 353 | ): |
353 | | - async for message in self._read_stream: |
354 | | - if isinstance(message, Exception): |
355 | | - await self._handle_incoming(message) |
356 | | - elif isinstance(message.message.root, JSONRPCRequest): |
357 | | - validated_request = self._receive_request_type.model_validate( |
358 | | - message.message.root.model_dump( |
359 | | - by_alias=True, mode="json", exclude_none=True |
| 354 | + async with asyncio.TaskGroup() as tg: |
| 355 | + async for message in self._read_stream: |
| 356 | + if isinstance(message, Exception): |
| 357 | + await self._handle_incoming(message) |
| 358 | + elif isinstance(message.message.root, JSONRPCRequest): |
| 359 | + validated_request = self._receive_request_type.model_validate( |
| 360 | + message.message.root.model_dump( |
| 361 | + by_alias=True, mode="json", exclude_none=True |
| 362 | + ) |
| 363 | + ) |
| 364 | + responder = RequestResponder( |
| 365 | + request_id=message.message.root.id, |
| 366 | + request_meta=validated_request.root.params.meta |
| 367 | + if validated_request.root.params |
| 368 | + else None, |
| 369 | + request=validated_request, |
| 370 | + session=self, |
| 371 | + on_complete=lambda r: self._in_flight.pop( |
| 372 | + r.request_id, None |
| 373 | + ), |
| 374 | + message_metadata=message.metadata, |
360 | 375 | ) |
361 | | - ) |
362 | | - responder = RequestResponder( |
363 | | - request_id=message.message.root.id, |
364 | | - request_meta=validated_request.root.params.meta |
365 | | - if validated_request.root.params |
366 | | - else None, |
367 | | - request=validated_request, |
368 | | - session=self, |
369 | | - on_complete=lambda r: self._in_flight.pop(r.request_id, None), |
370 | | - message_metadata=message.metadata, |
371 | | - ) |
372 | 376 |
|
373 | | - self._in_flight[responder.request_id] = responder |
374 | | - await self._received_request(responder) |
| 377 | + self._in_flight[responder.request_id] = responder |
| 378 | + task = tg.create_task(self._received_request(responder)) |
375 | 379 |
|
376 | | - if not responder._completed: # type: ignore[reportPrivateUsage] |
377 | | - await self._handle_incoming(responder) |
| 380 | + def _callback(task: asyncio.Task[None]) -> None: |
| 381 | + if not responder._completed: # type: ignore[reportPrivateUsage] |
| 382 | + tg.create_task(self._handle_incoming(responder)) |
378 | 383 |
|
379 | | - elif isinstance(message.message.root, JSONRPCNotification): |
380 | | - try: |
381 | | - notification = self._receive_notification_type.model_validate( |
382 | | - message.message.root.model_dump( |
383 | | - by_alias=True, mode="json", exclude_none=True |
| 384 | + task.add_done_callback(_callback) |
| 385 | + |
| 386 | + elif isinstance(message.message.root, JSONRPCNotification): |
| 387 | + try: |
| 388 | + notification = ( |
| 389 | + self._receive_notification_type.model_validate( |
| 390 | + message.message.root.model_dump( |
| 391 | + by_alias=True, mode="json", exclude_none=True |
| 392 | + ) |
| 393 | + ) |
384 | 394 | ) |
385 | | - ) |
386 | | - # Handle cancellation notifications |
387 | | - if isinstance(notification.root, CancelledNotification): |
388 | | - cancelled_id = notification.root.params.requestId |
389 | | - if cancelled_id in self._in_flight: |
390 | | - await self._in_flight[cancelled_id].cancel() |
391 | | - else: |
392 | | - # Handle progress notifications callback |
393 | | - if isinstance(notification.root, ProgressNotification): |
394 | | - progress_token = notification.root.params.progressToken |
395 | | - # If there is a progress callback for this token, |
396 | | - # call it with the progress information |
397 | | - if progress_token in self._progress_callbacks: |
398 | | - callback = self._progress_callbacks[progress_token] |
399 | | - await callback( |
400 | | - notification.root.params.progress, |
401 | | - notification.root.params.total, |
402 | | - notification.root.params.message, |
| 395 | + # Handle cancellation notifications |
| 396 | + if isinstance(notification.root, CancelledNotification): |
| 397 | + cancelled_id = notification.root.params.requestId |
| 398 | + if cancelled_id in self._in_flight: |
| 399 | + await self._in_flight[cancelled_id].cancel() |
| 400 | + else: |
| 401 | + # Handle progress notifications callback |
| 402 | + if isinstance(notification.root, ProgressNotification): |
| 403 | + progress_token = ( |
| 404 | + notification.root.params.progressToken |
403 | 405 | ) |
404 | | - await self._received_notification(notification) |
405 | | - await self._handle_incoming(notification) |
406 | | - except Exception as e: |
407 | | - # For other validation errors, log and continue |
408 | | - logging.warning( |
409 | | - f"Failed to validate notification: {e}. " |
410 | | - f"Message was: {message.message.root}" |
411 | | - ) |
412 | | - else: # Response or error |
413 | | - stream = self._response_streams.pop(message.message.root.id, None) |
414 | | - if stream: |
415 | | - await stream.send(message.message.root) |
416 | | - else: |
417 | | - await self._handle_incoming( |
418 | | - RuntimeError( |
419 | | - "Received response with an unknown " |
420 | | - f"request ID: {message}" |
| 406 | + # If there is a progress callback for this token, |
| 407 | + # call it with the progress information |
| 408 | + if progress_token in self._progress_callbacks: |
| 409 | + callback = self._progress_callbacks[ |
| 410 | + progress_token |
| 411 | + ] |
| 412 | + await callback( |
| 413 | + notification.root.params.progress, |
| 414 | + notification.root.params.total, |
| 415 | + notification.root.params.message, |
| 416 | + ) |
| 417 | + await self._received_notification(notification) |
| 418 | + await self._handle_incoming(notification) |
| 419 | + except Exception as e: |
| 420 | + # For other validation errors, log and continue |
| 421 | + logging.warning( |
| 422 | + f"Failed to validate notification: {e}. " |
| 423 | + f"Message was: {message.message.root}" |
421 | 424 | ) |
| 425 | + else: # Response or error |
| 426 | + stream = self._response_streams.pop( |
| 427 | + message.message.root.id, None |
422 | 428 | ) |
| 429 | + if stream: |
| 430 | + await stream.send(message.message.root) |
| 431 | + else: |
| 432 | + await self._handle_incoming( |
| 433 | + RuntimeError( |
| 434 | + "Received response with an unknown " |
| 435 | + f"request ID: {message}" |
| 436 | + ) |
| 437 | + ) |
423 | 438 |
|
424 | 439 | # after the read stream is closed, we need to send errors |
425 | 440 | # to any pending requests |
|
0 commit comments