Skip to content

Commit caeb69b

Browse files
committed
Adds support for ElicitCompleteNotification callbacks
1 parent eb82efb commit caeb69b

File tree

3 files changed

+95
-1
lines changed

3 files changed

+95
-1
lines changed

src/mcp/client/session.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@ async def __call__(
3333
params: types.ElicitRequestParams,
3434
) -> types.ElicitResult | types.ErrorData: ... # pragma: no branch
3535

36+
class ElicitCompleteFnT(Protocol):
37+
async def __call__(
38+
self, params: types.ElicitCompleteNotificationParams,
39+
) -> None: ... #pragma: no branch
3640

3741
class ListRootsFnT(Protocol):
3842
async def __call__(
@@ -111,6 +115,11 @@ async def _default_elicitation_callback(
111115
message="Elicitation not supported",
112116
)
113117

118+
async def _default_elicit_complete_callback(
119+
params: types.ElicitCompleteNotificationParams
120+
) -> None:
121+
pass
122+
114123

115124
async def _default_list_roots_callback(
116125
context: RequestContext["ClientSession", Any],
@@ -172,6 +181,7 @@ def __init__(
172181
read_timeout_seconds: timedelta | None = None,
173182
sampling_callback: SamplingFnT | None = None,
174183
elicitation_callback: ElicitationFnT | None = None,
184+
elicit_complete_callback: ElicitCompleteFnT | None = None,
175185
list_roots_callback: ListRootsFnT | None = None,
176186
logging_callback: LoggingFnT | None = None,
177187
progress_notification_callback: ProgressNotificationFnT | None = None,
@@ -192,6 +202,7 @@ def __init__(
192202
self._client_info = client_info or DEFAULT_CLIENT_INFO
193203
self._sampling_callback = sampling_callback or _default_sampling_callback
194204
self._elicitation_callback = elicitation_callback or _default_elicitation_callback
205+
self._elicit_complete_callback = elicit_complete_callback or _default_elicit_complete_callback
195206
self._list_roots_callback = list_roots_callback or _default_list_roots_callback
196207
self._logging_callback = logging_callback or _default_logging_callback
197208
self._progress_notification_callback = progress_notification_callback or _default_progress_callback
@@ -638,7 +649,7 @@ async def _received_notification(self, notification: types.ServerNotification) -
638649
# Handle elicitation completion notification
639650
# Clients MAY use this to retry requests or update UI
640651
# The notification contains the elicitationId of the completed elicitation
641-
pass
652+
await self._elicit_complete_callback(params)
642653
case _: # pragma: no cover
643654
# CancelledNotification is handled separately in shared/session.py
644655
# and should never reach this point. This case is defensive.

src/mcp/shared/memory.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import mcp.types as types
1616
from mcp.client.session import (
1717
ClientSession,
18+
ElicitCompleteFnT,
1819
ElicitationFnT,
1920
ListRootsFnT,
2021
LoggingFnT,
@@ -74,6 +75,7 @@ async def create_connected_server_and_client_session(
7475
client_info: types.Implementation | None = None,
7576
raise_exceptions: bool = False,
7677
elicitation_callback: ElicitationFnT | None = None,
78+
elicit_complete_callback: ElicitCompleteFnT | None = None,
7779
) -> AsyncGenerator[ClientSession, None]:
7880
"""Creates a ClientSession that is connected to a running MCP server."""
7981

@@ -113,6 +115,7 @@ async def create_connected_server_and_client_session(
113115
message_handler=message_handler,
114116
client_info=client_info,
115117
elicitation_callback=elicitation_callback,
118+
elicit_complete_callback=elicit_complete_callback,
116119
) as client_session:
117120
await client_session.initialize()
118121
yield client_session

tests/client/test_notification_callbacks.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,18 @@ async def __call__(self) -> None:
8282
self.notification_count += 1
8383

8484

85+
class ElicitCompleteCollector:
86+
"""Collector for ElicitCompleteNotification events."""
87+
88+
def __init__(self) -> None:
89+
"""Initialize the collector."""
90+
self.notifications: list[types.ElicitCompleteNotificationParams] = []
91+
92+
async def __call__(self, params: types.ElicitCompleteNotificationParams) -> None:
93+
"""Collect an elicit complete notification."""
94+
self.notifications.append(params)
95+
96+
8597
@pytest.fixture
8698
def progress_collector() -> ProgressNotificationCollector:
8799
"""Create a progress notification collector."""
@@ -112,6 +124,12 @@ def prompt_list_changed_collector() -> PromptListChangedCollector:
112124
return PromptListChangedCollector()
113125

114126

127+
@pytest.fixture
128+
def elicit_complete_collector() -> ElicitCompleteCollector:
129+
"""Create an elicit complete collector."""
130+
return ElicitCompleteCollector()
131+
132+
115133
@pytest.mark.anyio
116134
async def test_progress_notification_callback(progress_collector: ProgressNotificationCollector) -> None:
117135
"""Test that progress notifications are correctly received by the callback."""
@@ -298,6 +316,41 @@ async def message_handler(
298316
assert prompt_list_changed_collector.notification_count == 1
299317

300318

319+
@pytest.mark.anyio
320+
async def test_elicit_complete_callback(elicit_complete_collector: ElicitCompleteCollector) -> None:
321+
"""Test that elicit complete notifications are correctly received by the callback."""
322+
from mcp.server.fastmcp import FastMCP
323+
324+
server = FastMCP("test")
325+
326+
@server.tool("send_elicit_complete")
327+
async def send_elicit_complete_tool(elicitation_id: str) -> bool:
328+
"""Send an elicit complete notification to the client."""
329+
await server.get_context().session.send_elicit_complete(elicitation_id)
330+
return True
331+
332+
async def message_handler(
333+
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
334+
) -> None:
335+
"""Handle exceptions from the session."""
336+
if isinstance(message, Exception): # pragma: no cover
337+
raise message
338+
339+
async with create_session(
340+
server._mcp_server,
341+
elicit_complete_callback=elicit_complete_collector,
342+
message_handler=message_handler,
343+
) as client_session:
344+
# Trigger elicit complete notification
345+
result = await client_session.call_tool("send_elicit_complete", {"elicitation_id": "test-elicit-123"})
346+
assert result.isError is False
347+
348+
# Verify the notification was received
349+
assert len(elicit_complete_collector.notifications) == 1
350+
notification = elicit_complete_collector.notifications[0]
351+
assert notification.elicitationId == "test-elicit-123"
352+
353+
301354
@pytest.mark.anyio
302355
@pytest.mark.parametrize(
303356
"notification_type,callback_param,collector_fixture,tool_name,tool_args,verification",
@@ -350,6 +403,17 @@ async def message_handler(
350403
{},
351404
lambda c: c.notification_count == 1, # type: ignore[attr-defined]
352405
),
406+
(
407+
"elicit_complete",
408+
"elicit_complete_callback",
409+
"elicit_complete_collector",
410+
"send_elicit_complete",
411+
{"elicitation_id": "param-test-elicit-456"},
412+
lambda c: ( # type: ignore[misc]
413+
len(c.notifications) == 1 # type: ignore[attr-defined]
414+
and c.notifications[0].elicitationId == "param-test-elicit-456" # type: ignore[attr-defined]
415+
),
416+
),
353417
],
354418
)
355419
async def test_notification_callback_parametrized(
@@ -407,6 +471,12 @@ async def change_prompt_list_tool() -> bool:
407471
await server.get_context().session.send_prompt_list_changed()
408472
return True
409473

474+
@server.tool("send_elicit_complete")
475+
async def send_elicit_complete_tool(elicitation_id: str) -> bool:
476+
"""Send an elicit complete notification to the client."""
477+
await server.get_context().session.send_elicit_complete(elicitation_id)
478+
return True
479+
410480
async def message_handler(
411481
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
412482
) -> None:
@@ -478,6 +548,12 @@ async def send_prompt_list_changed_tool() -> bool:
478548
await server.get_context().session.send_prompt_list_changed()
479549
return True
480550

551+
@server.tool("send_elicit_complete")
552+
async def send_elicit_complete_tool(elicitation_id: str) -> bool:
553+
"""Send an elicit complete notification."""
554+
await server.get_context().session.send_elicit_complete(elicitation_id)
555+
return True
556+
481557
# Create session WITHOUT custom callbacks - all will use defaults
482558
async with create_session(server._mcp_server) as client_session:
483559
# Test progress notification with default callback
@@ -507,6 +583,10 @@ async def send_prompt_list_changed_tool() -> bool:
507583
result5 = await client_session.call_tool("send_prompt_list_changed", {})
508584
assert result5.isError is False
509585

586+
# Test elicit complete with default callback
587+
result6 = await client_session.call_tool("send_elicit_complete", {"elicitation_id": "test-123"})
588+
assert result6.isError is False
589+
510590

511591
@pytest.mark.anyio
512592
async def test_progress_tool_without_progress_token() -> None:

0 commit comments

Comments
 (0)