|
37 | 37 | import psutil |
38 | 38 | import zmq |
39 | 39 | import zmq_anyio |
40 | | -from anyio import TASK_STATUS_IGNORED, create_task_group, sleep, to_thread |
| 40 | +from anyio import ( |
| 41 | + TASK_STATUS_IGNORED, |
| 42 | + create_memory_object_stream, |
| 43 | + create_task_group, |
| 44 | + sleep, |
| 45 | + to_thread, |
| 46 | +) |
41 | 47 | from anyio.abc import TaskStatus |
42 | 48 | from IPython.core.error import StdinNotImplementedError |
43 | 49 | from jupyter_client.session import Session |
@@ -418,27 +424,39 @@ async def shell_main(self, subshell_id: str | None): |
418 | 424 | assert subshell_id is None |
419 | 425 | assert threading.current_thread() == threading.main_thread() |
420 | 426 | socket = None |
421 | | - |
| 427 | + send_stream, receive_stream = create_memory_object_stream() |
422 | 428 | async with create_task_group() as tg: |
423 | 429 | if not socket.started.is_set(): |
424 | 430 | await tg.start(socket.start) |
425 | | - tg.start_soon(self.process_shell, socket) |
| 431 | + tg.start_soon(self.process_shell, socket, send_stream) |
| 432 | + tg.start_soon(self._execute_request_handler, receive_stream) |
426 | 433 | if subshell_id is None: |
427 | 434 | # Main subshell. |
428 | 435 | await to_thread.run_sync(self.shell_stop.wait) |
429 | 436 | tg.cancel_scope.cancel() |
430 | 437 |
|
431 | | - async def process_shell(self, socket=None): |
| 438 | + async def _execute_request_handler(self, receive_stream): |
| 439 | + async with receive_stream: |
| 440 | + async for handler, (socket, idents, msg) in receive_stream: |
| 441 | + try: |
| 442 | + result = handler(socket, idents, msg) |
| 443 | + self.set_parent(idents, msg, channel="shell") |
| 444 | + if inspect.isawaitable(result): |
| 445 | + await result |
| 446 | + except Exception as e: |
| 447 | + self.log.exception("Execute request", exc_info=e) |
| 448 | + |
| 449 | + async def process_shell(self, socket, send_stream): |
432 | 450 | # socket=None is valid if kernel subshells are not supported. |
433 | 451 | try: |
434 | 452 | while True: |
435 | | - await self.process_shell_message(socket=socket) |
| 453 | + await self.process_shell_message(socket=socket, send_stream=send_stream) |
436 | 454 | except BaseException: |
437 | 455 | if self.shell_stop.is_set(): |
438 | 456 | return |
439 | 457 | raise |
440 | 458 |
|
441 | | - async def process_shell_message(self, msg=None, socket=None): |
| 459 | + async def process_shell_message(self, msg=None, socket=None, send_stream=None): |
442 | 460 | # If socket is None kernel subshells are not supported so use socket=shell_socket. |
443 | 461 | # If msg is set, process that message. |
444 | 462 | # If msg is None, await the next message to arrive on the socket. |
@@ -507,9 +525,12 @@ async def process_shell_message(self, msg=None, socket=None): |
507 | 525 | except Exception: |
508 | 526 | self.log.debug("Unable to signal in pre_handler_hook:", exc_info=True) |
509 | 527 | try: |
510 | | - result = handler(socket, idents, msg) |
511 | | - if inspect.isawaitable(result): |
512 | | - await result |
| 528 | + if msg_type == "execute_request" and send_stream: |
| 529 | + await send_stream.send((handler, (socket, idents, msg))) |
| 530 | + else: |
| 531 | + result = handler(socket, idents, msg) |
| 532 | + if inspect.isawaitable(result): |
| 533 | + await result |
513 | 534 | except Exception: |
514 | 535 | self.log.error("Exception in message handler:", exc_info=True) # noqa: G201 |
515 | 536 | except KeyboardInterrupt: |
|
0 commit comments