Skip to content

Commit 40470d6

Browse files
ihrprdsp-ant
authored andcommitted
add elicitation test using create_client_server_memory_streams
1 parent 6597f30 commit 40470d6

File tree

2 files changed

+69
-1
lines changed

2 files changed

+69
-1
lines changed

src/mcp/shared/memory.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,14 @@
1111
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
1212

1313
import mcp.types as types
14-
from mcp.client.session import ClientSession, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT
14+
from mcp.client.session import (
15+
ClientSession,
16+
ElicitationFnT,
17+
ListRootsFnT,
18+
LoggingFnT,
19+
MessageHandlerFnT,
20+
SamplingFnT,
21+
)
1522
from mcp.server import Server
1623
from mcp.shared.message import SessionMessage
1724

@@ -53,6 +60,7 @@ async def create_connected_server_and_client_session(
5360
message_handler: MessageHandlerFnT | None = None,
5461
client_info: types.Implementation | None = None,
5562
raise_exceptions: bool = False,
63+
elicitation_callback: ElicitationFnT | None = None,
5664
) -> AsyncGenerator[ClientSession, None]:
5765
"""Creates a ClientSession that is connected to a running MCP server."""
5866
async with create_client_server_memory_streams() as (
@@ -83,6 +91,7 @@ async def create_connected_server_and_client_session(
8391
logging_callback=logging_callback,
8492
message_handler=message_handler,
8593
client_info=client_info,
94+
elicitation_callback=elicitation_callback,
8695
) as client_session:
8796
await client_session.initialize()
8897
yield client_session
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
"""
2+
Test the elicitation feature using stdio transport.
3+
"""
4+
5+
import pytest
6+
7+
from mcp.server.fastmcp import Context, FastMCP
8+
from mcp.shared.memory import create_connected_server_and_client_session
9+
from mcp.types import ElicitResult, TextContent
10+
11+
12+
@pytest.mark.anyio
13+
async def test_stdio_elicitation():
14+
"""Test the elicitation feature using stdio transport."""
15+
16+
# Create a FastMCP server with a tool that uses elicitation
17+
mcp = FastMCP(name="StdioElicitationServer")
18+
19+
@mcp.tool(description="A tool that uses elicitation")
20+
async def ask_user(prompt: str, ctx: Context) -> str:
21+
schema = {
22+
"type": "object",
23+
"properties": {
24+
"answer": {"type": "string"},
25+
},
26+
"required": ["answer"],
27+
}
28+
29+
response = await ctx.elicit(
30+
message=f"Tool wants to ask: {prompt}",
31+
requestedSchema=schema,
32+
)
33+
return f"User answered: {response['answer']}"
34+
35+
# Create a custom handler for elicitation requests
36+
async def elicitation_callback(context, params):
37+
# Verify the elicitation parameters
38+
if params.message == "Tool wants to ask: What is your name?":
39+
return ElicitResult(response={"answer": "Test User"})
40+
else:
41+
raise ValueError(f"Unexpected elicitation message: {params.message}")
42+
43+
# Use memory-based session to test with stdio transport
44+
async with create_connected_server_and_client_session(
45+
mcp._mcp_server, elicitation_callback=elicitation_callback
46+
) as client_session:
47+
# First initialize the session
48+
result = await client_session.initialize()
49+
assert result.serverInfo.name == "StdioElicitationServer"
50+
51+
# Call the tool that uses elicitation
52+
tool_result = await client_session.call_tool(
53+
"ask_user", {"prompt": "What is your name?"}
54+
)
55+
56+
# Verify the result
57+
assert len(tool_result.content) == 1
58+
assert isinstance(tool_result.content[0], TextContent)
59+
assert tool_result.content[0].text == "User answered: Test User"

0 commit comments

Comments
 (0)