From 10636349b43a13d49d4477a318ebeaea168eb31d Mon Sep 17 00:00:00 2001 From: Greg Spencer Date: Thu, 26 Jun 2025 11:33:46 -0700 Subject: [PATCH 1/3] Adding support for roots changed notification and initialized notification. --- src/mcp/server/lowlevel/server.py | 62 +++++- ...notifications.py => test_notifications.py} | 191 +++++++++++++++++- 2 files changed, 242 insertions(+), 11 deletions(-) rename tests/shared/{test_progress_notifications.py => test_notifications.py} (63%) diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index faad95aca6..8ae261117c 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -68,6 +68,7 @@ async def main(): from __future__ import annotations as _annotations import contextvars +import inspect import json import logging import warnings @@ -104,6 +105,9 @@ async def main(): # This will be properly typed in each Server instance's context request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any, Any]] = contextvars.ContextVar("request_ctx") +# Context variable to hold the current ServerSession, accessible by notification handlers +current_session_ctx: contextvars.ContextVar[ServerSession] = contextvars.ContextVar("current_server_session") + class NotificationOptions: def __init__( @@ -520,6 +524,36 @@ async def handler(req: types.ProgressNotification): return decorator + def initialized_notification(self): + """Decorator to register a handler for InitializedNotification.""" + + def decorator( + func: ( + Callable[[types.InitializedNotification, ServerSession], Awaitable[None]] + | Callable[[types.InitializedNotification], Awaitable[None]] + ), + ): + logger.debug("Registering handler for InitializedNotification") + self.notification_handlers[types.InitializedNotification] = func + return func + + return decorator + + def roots_list_changed_notification(self): + """Decorator to register a handler for RootsListChangedNotification.""" + + def decorator( + func: ( + Callable[[types.RootsListChangedNotification, ServerSession], Awaitable[None]] + | Callable[[types.RootsListChangedNotification], Awaitable[None]] + ), + ): + logger.debug("Registering handler for RootsListChangedNotification") + self.notification_handlers[types.RootsListChangedNotification] = func + return func + + return decorator + def completion(self): """Provides completions for prompts and resource templates""" @@ -591,22 +625,26 @@ async def run( async def _handle_message( self, - message: RequestResponder[types.ClientRequest, types.ServerResult] | types.ClientNotification | Exception, + message: (RequestResponder[types.ClientRequest, types.ServerResult] | types.ClientNotification | Exception), session: ServerSession, lifespan_context: LifespanResultT, raise_exceptions: bool = False, ): - with warnings.catch_warnings(record=True) as w: - # TODO(Marcelo): We should be checking if message is Exception here. - match message: # type: ignore[reportMatchNotExhaustive] - case RequestResponder(request=types.ClientRequest(root=req)) as responder: - with responder: - await self._handle_request(message, req, session, lifespan_context, raise_exceptions) - case types.ClientNotification(root=notify): - await self._handle_notification(notify) + session_token = current_session_ctx.set(session) + try: + with warnings.catch_warnings(record=True) as w: + # TODO(Marcelo): We should be checking if message is Exception here. + match message: # type: ignore[reportMatchNotExhaustive] + case RequestResponder(request=types.ClientRequest(root=req)) as responder: + with responder: + await self._handle_request(message, req, session, lifespan_context, raise_exceptions) + case types.ClientNotification(root=notify): + await self._handle_notification(notify) for warning in w: logger.info("Warning: %s: %s", warning.category.__name__, warning.message) + finally: + current_session_ctx.reset(session_token) async def _handle_request( self, @@ -666,7 +704,11 @@ async def _handle_notification(self, notify: Any): logger.debug("Dispatching notification of type %s", type(notify).__name__) try: - await handler(notify) + sig = inspect.signature(handler) + if "session" in sig.parameters: + await handler(notify, current_session_ctx.get()) + else: + await handler(notify) except Exception: logger.exception("Uncaught exception in notification handler") diff --git a/tests/shared/test_progress_notifications.py b/tests/shared/test_notifications.py similarity index 63% rename from tests/shared/test_progress_notifications.py rename to tests/shared/test_notifications.py index 08bcb26623..fe835cd9e0 100644 --- a/tests/shared/test_progress_notifications.py +++ b/tests/shared/test_notifications.py @@ -1,3 +1,4 @@ +import asyncio from typing import Any, cast import anyio @@ -10,11 +11,11 @@ from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession from mcp.shared.context import RequestContext +from mcp.shared.message import SessionMessage from mcp.shared.progress import progress from mcp.shared.session import ( BaseSession, RequestResponder, - SessionMessage, ) @@ -333,3 +334,191 @@ async def handle_client_message( assert server_progress_updates[3]["progress"] == 100 assert server_progress_updates[3]["total"] == 100 assert server_progress_updates[3]["message"] == "Processing results..." + + +@pytest.mark.anyio +async def test_initialized_notification(): + """Test that the server receives and handles InitializedNotification.""" + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) + + server = Server("test") + initialized_received = asyncio.Event() + + @server.initialized_notification() + async def handle_initialized(notification: types.InitializedNotification): + initialized_received.set() + + async def run_server(): + await server.run( + client_to_server_receive, + server_to_client_send, + server.create_initialization_options(), + ) + + async def message_handler( + message: (RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception), + ) -> None: + if isinstance(message, Exception): + raise message + + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + message_handler=message_handler, + ) as client_session, + anyio.create_task_group() as tg, + ): + tg.start_soon(run_server) + await client_session.initialize() + await initialized_received.wait() + tg.cancel_scope.cancel() + + assert initialized_received.is_set() + + +@pytest.mark.anyio +async def test_roots_list_changed_notification(): + """Test that the server receives and handles RootsListChangedNotification.""" + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) + + server = Server("test") + roots_list_changed_received = asyncio.Event() + + @server.roots_list_changed_notification() + async def handle_roots_list_changed( + notification: types.RootsListChangedNotification, + ): + roots_list_changed_received.set() + + async def run_server(): + await server.run( + client_to_server_receive, + server_to_client_send, + server.create_initialization_options(), + ) + + async def message_handler( + message: (RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception), + ) -> None: + if isinstance(message, Exception): + raise message + + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + message_handler=message_handler, + ) as client_session, + anyio.create_task_group() as tg, + ): + tg.start_soon(run_server) + await client_session.initialize() + await client_session.send_notification( + types.ClientNotification( + root=types.RootsListChangedNotification(method="notifications/roots/list_changed", params=None) + ) + ) + await roots_list_changed_received.wait() + tg.cancel_scope.cancel() + + assert roots_list_changed_received.is_set() + + +@pytest.mark.anyio +async def test_initialized_notification_with_session(): + """Test that the server receives and handles InitializedNotification with a session.""" + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) + + server = Server("test") + initialized_received = asyncio.Event() + received_session = None + + @server.initialized_notification() + async def handle_initialized(notification: types.InitializedNotification, session: ServerSession): + nonlocal received_session + received_session = session + initialized_received.set() + + async def run_server(): + await server.run( + client_to_server_receive, + server_to_client_send, + server.create_initialization_options(), + ) + + async def message_handler( + message: (RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception), + ) -> None: + if isinstance(message, Exception): + raise message + + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + message_handler=message_handler, + ) as client_session, + anyio.create_task_group() as tg, + ): + tg.start_soon(run_server) + await client_session.initialize() + await initialized_received.wait() + tg.cancel_scope.cancel() + + assert initialized_received.is_set() + assert isinstance(received_session, ServerSession) + + +@pytest.mark.anyio +async def test_roots_list_changed_notification_with_session(): + """Test that the server receives and handles RootsListChangedNotification with a session.""" + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) + + server = Server("test") + roots_list_changed_received = asyncio.Event() + received_session = None + + @server.roots_list_changed_notification() + async def handle_roots_list_changed(notification: types.RootsListChangedNotification, session: ServerSession): + nonlocal received_session + received_session = session + roots_list_changed_received.set() + + async def run_server(): + await server.run( + client_to_server_receive, + server_to_client_send, + server.create_initialization_options(), + ) + + async def message_handler( + message: (RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception), + ) -> None: + if isinstance(message, Exception): + raise message + + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + message_handler=message_handler, + ) as client_session, + anyio.create_task_group() as tg, + ): + tg.start_soon(run_server) + await client_session.initialize() + await client_session.send_notification( + types.ClientNotification( + root=types.RootsListChangedNotification(method="notifications/roots/list_changed", params=None) + ) + ) + await roots_list_changed_received.wait() + tg.cancel_scope.cancel() + + assert roots_list_changed_received.is_set() + assert isinstance(received_session, ServerSession) From 4c2bdb7e2394cef78139868af2437bbc2c3c20e2 Mon Sep 17 00:00:00 2001 From: Greg Spencer Date: Tue, 23 Sep 2025 10:06:02 -0700 Subject: [PATCH 2/3] refactor: simplify notification handling by passing session This commit refactors the notification handling logic to eliminate the global context variable and introspection. The `ServerSession` is now explicitly passed to notification handlers, simplifying the control flow and improving explicitness. This addresses the code review feedback to avoid introspection and global state in the low-level server code. --- src/mcp/server/lowlevel/server.py | 66 +++++++++-------- tests/shared/test_notifications.py | 109 ++++------------------------- 2 files changed, 45 insertions(+), 130 deletions(-) diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 25fe2520d3..0fc36ba617 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -68,7 +68,6 @@ async def main(): from __future__ import annotations as _annotations import contextvars -import inspect import json import logging import warnings @@ -105,9 +104,6 @@ async def main(): # This will be properly typed in each Server instance's context request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any, Any]] = contextvars.ContextVar("request_ctx") -# Context variable to hold the current ServerSession, accessible by notification handlers -current_session_ctx: contextvars.ContextVar[ServerSession] = contextvars.ContextVar("current_server_session") - class NotificationOptions: def __init__( @@ -512,16 +508,23 @@ async def handler(req: types.CallToolRequest): def progress_notification(self): def decorator( - func: Callable[[str | int, float, float | None, str | None], Awaitable[None]], + func: Callable[ + [str | int, float, float | None, str | None, ServerSession | None], + Awaitable[None], + ], ): logger.debug("Registering handler for ProgressNotification") - async def handler(req: types.ProgressNotification): + async def handler( + req: types.ProgressNotification, + session: ServerSession | None = None, + ): await func( req.params.progressToken, req.params.progress, req.params.total, req.params.message, + session, ) self.notification_handlers[types.ProgressNotification] = handler @@ -533,10 +536,10 @@ def initialized_notification(self): """Decorator to register a handler for InitializedNotification.""" def decorator( - func: ( - Callable[[types.InitializedNotification, ServerSession], Awaitable[None]] - | Callable[[types.InitializedNotification], Awaitable[None]] - ), + func: Callable[ + [types.InitializedNotification, ServerSession | None], + Awaitable[None], + ], ): logger.debug("Registering handler for InitializedNotification") self.notification_handlers[types.InitializedNotification] = func @@ -548,10 +551,10 @@ def roots_list_changed_notification(self): """Decorator to register a handler for RootsListChangedNotification.""" def decorator( - func: ( - Callable[[types.RootsListChangedNotification, ServerSession], Awaitable[None]] - | Callable[[types.RootsListChangedNotification], Awaitable[None]] - ), + func: Callable[ + [types.RootsListChangedNotification, ServerSession | None], + Awaitable[None], + ], ): logger.debug("Registering handler for RootsListChangedNotification") self.notification_handlers[types.RootsListChangedNotification] = func @@ -635,21 +638,17 @@ async def _handle_message( lifespan_context: LifespanResultT, raise_exceptions: bool = False, ): - session_token = current_session_ctx.set(session) - try: - with warnings.catch_warnings(record=True) as w: - # TODO(Marcelo): We should be checking if message is Exception here. - match message: # type: ignore[reportMatchNotExhaustive] - case RequestResponder(request=types.ClientRequest(root=req)) as responder: - with responder: - await self._handle_request(message, req, session, lifespan_context, raise_exceptions) - case types.ClientNotification(root=notify): - await self._handle_notification(notify) - - for warning in w: - logger.info("Warning: %s: %s", warning.category.__name__, warning.message) - finally: - current_session_ctx.reset(session_token) + with warnings.catch_warnings(record=True) as w: + # TODO(Marcelo): We should be checking if message is Exception here. + match message: # type: ignore[reportMatchNotExhaustive] + case RequestResponder(request=types.ClientRequest(root=req)) as responder: + with responder: + await self._handle_request(message, req, session, lifespan_context, raise_exceptions) + case types.ClientNotification(root=notify): + await self._handle_notification(notify, session) + + for warning in w: + logger.info("Warning: %s: %s", warning.category.__name__, warning.message) async def _handle_request( self, @@ -710,15 +709,14 @@ async def _handle_request( logger.debug("Response sent") - async def _handle_notification(self, notify: Any): + async def _handle_notification(self, notify: Any, session: ServerSession): if handler := self.notification_handlers.get(type(notify)): # type: ignore logger.debug("Dispatching notification of type %s", type(notify).__name__) try: - sig = inspect.signature(handler) - if "session" in sig.parameters: - await handler(notify, current_session_ctx.get()) - else: + try: + await handler(notify, session) + except TypeError: await handler(notify) except Exception: logger.exception("Uncaught exception in notification handler") diff --git a/tests/shared/test_notifications.py b/tests/shared/test_notifications.py index dd0671c0cc..cc34cf5189 100644 --- a/tests/shared/test_notifications.py +++ b/tests/shared/test_notifications.py @@ -13,7 +13,7 @@ from mcp.shared.context import RequestContext from mcp.shared.message import SessionMessage from mcp.shared.progress import progress -from mcp.shared.session import BaseSession, RequestResponder, SessionMessage +from mcp.shared.session import BaseSession, RequestResponder @pytest.mark.anyio @@ -62,6 +62,7 @@ async def handle_progress( progress: float, total: float | None, message: str | None, + session: ServerSession | None, ): server_progress_updates.append( { @@ -228,6 +229,7 @@ async def handle_progress( progress: float, total: float | None, message: str | None, + session: ServerSession | None, ): server_progress_updates.append( {"token": progress_token, "progress": progress, "total": total, "message": message} @@ -332,9 +334,15 @@ async def test_initialized_notification(): server = Server("test") initialized_received = asyncio.Event() + received_session: ServerSession | None = None @server.initialized_notification() - async def handle_initialized(notification: types.InitializedNotification): + async def handle_initialized( + notification: types.InitializedNotification, + session: ServerSession | None = None, + ): + nonlocal received_session + received_session = session initialized_received.set() async def run_server(): @@ -364,6 +372,7 @@ async def message_handler( tg.cancel_scope.cancel() assert initialized_received.is_set() + assert isinstance(received_session, ServerSession) @pytest.mark.anyio @@ -374,105 +383,13 @@ async def test_roots_list_changed_notification(): server = Server("test") roots_list_changed_received = asyncio.Event() + received_session: ServerSession | None = None @server.roots_list_changed_notification() async def handle_roots_list_changed( notification: types.RootsListChangedNotification, + session: ServerSession | None = None, ): - roots_list_changed_received.set() - - async def run_server(): - await server.run( - client_to_server_receive, - server_to_client_send, - server.create_initialization_options(), - ) - - async def message_handler( - message: (RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception), - ) -> None: - if isinstance(message, Exception): - raise message - - async with ( - ClientSession( - server_to_client_receive, - client_to_server_send, - message_handler=message_handler, - ) as client_session, - anyio.create_task_group() as tg, - ): - tg.start_soon(run_server) - await client_session.initialize() - await client_session.send_notification( - types.ClientNotification( - root=types.RootsListChangedNotification(method="notifications/roots/list_changed", params=None) - ) - ) - await roots_list_changed_received.wait() - tg.cancel_scope.cancel() - - assert roots_list_changed_received.is_set() - - -@pytest.mark.anyio -async def test_initialized_notification_with_session(): - """Test that the server receives and handles InitializedNotification with a session.""" - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) - - server = Server("test") - initialized_received = asyncio.Event() - received_session = None - - @server.initialized_notification() - async def handle_initialized(notification: types.InitializedNotification, session: ServerSession): - nonlocal received_session - received_session = session - initialized_received.set() - - async def run_server(): - await server.run( - client_to_server_receive, - server_to_client_send, - server.create_initialization_options(), - ) - - async def message_handler( - message: (RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception), - ) -> None: - if isinstance(message, Exception): - raise message - - async with ( - ClientSession( - server_to_client_receive, - client_to_server_send, - message_handler=message_handler, - ) as client_session, - anyio.create_task_group() as tg, - ): - tg.start_soon(run_server) - await client_session.initialize() - await initialized_received.wait() - tg.cancel_scope.cancel() - - assert initialized_received.is_set() - assert isinstance(received_session, ServerSession) - - -@pytest.mark.anyio -async def test_roots_list_changed_notification_with_session(): - """Test that the server receives and handles RootsListChangedNotification with a session.""" - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) - - server = Server("test") - roots_list_changed_received = asyncio.Event() - received_session = None - - @server.roots_list_changed_notification() - async def handle_roots_list_changed(notification: types.RootsListChangedNotification, session: ServerSession): nonlocal received_session received_session = session roots_list_changed_received.set() From 235dbc2a9d32ffc8f0f37846783d70e5525b394d Mon Sep 17 00:00:00 2001 From: Greg Spencer Date: Tue, 23 Sep 2025 13:00:48 -0700 Subject: [PATCH 3/3] refactor(tests): remove global statement in notification tests Removes the use of a `global` statement in `tests/shared/test_notifications.py` to resolve a `PLW0603` linting error reported by Ruff. The `serv_sesh` global variable has been replaced with a mutable list (`server_session_ref`) to share the server session object between the `run_server` and `handle_call_tool` functions within the test. This maintains the test's functionality while adhering to better coding practices. --- tests/shared/test_notifications.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tests/shared/test_notifications.py b/tests/shared/test_notifications.py index cc34cf5189..a1ff968444 100644 --- a/tests/shared/test_notifications.py +++ b/tests/shared/test_notifications.py @@ -23,6 +23,8 @@ async def test_bidirectional_progress_notifications(): server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](5) client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](5) + server_session_ref: list[ServerSession | None] = [None] + # Run a server session so we can send progress updates in tool async def run_server(): # Create a server session @@ -35,9 +37,7 @@ async def run_server(): capabilities=server.get_capabilities(NotificationOptions(), {}), ), ) as server_session: - global serv_sesh - - serv_sesh = server_session + server_session_ref[0] = server_session async for message in server_session.incoming_messages: try: await server._handle_message(message, server_session, {}) @@ -87,6 +87,10 @@ async def handle_list_tools() -> list[types.Tool]: # Register tool handler @server.call_tool() async def handle_call_tool(name: str, arguments: dict[str, Any] | None) -> list[types.TextContent]: + serv_sesh = server_session_ref[0] + if not serv_sesh: + raise ValueError("Server session not available") + # Make sure we received a progress token if name == "test_tool": if arguments and "_meta" in arguments: