Skip to content

Commit 6597f30

Browse files
committed
Elicitation
1 parent a3bcabd commit 6597f30

File tree

5 files changed

+179
-3
lines changed

5 files changed

+179
-3
lines changed

src/mcp/client/session.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,14 @@ async def __call__(
2222
) -> types.CreateMessageResult | types.ErrorData: ...
2323

2424

25+
class ElicitationFnT(Protocol):
26+
async def __call__(
27+
self,
28+
context: RequestContext["ClientSession", Any],
29+
params: types.ElicitRequestParams,
30+
) -> types.ElicitResult | types.ErrorData: ...
31+
32+
2533
class ListRootsFnT(Protocol):
2634
async def __call__(
2735
self, context: RequestContext["ClientSession", Any]
@@ -58,6 +66,16 @@ async def _default_sampling_callback(
5866
)
5967

6068

69+
async def _default_elicitation_callback(
70+
context: RequestContext["ClientSession", Any],
71+
params: types.ElicitRequestParams,
72+
) -> types.ElicitResult | types.ErrorData:
73+
return types.ErrorData(
74+
code=types.INVALID_REQUEST,
75+
message="Elicitation not supported",
76+
)
77+
78+
6179
async def _default_list_roots_callback(
6280
context: RequestContext["ClientSession", Any],
6381
) -> types.ListRootsResult | types.ErrorData:
@@ -91,6 +109,7 @@ def __init__(
91109
write_stream: MemoryObjectSendStream[SessionMessage],
92110
read_timeout_seconds: timedelta | None = None,
93111
sampling_callback: SamplingFnT | None = None,
112+
elicitation_callback: ElicitationFnT | None = None,
94113
list_roots_callback: ListRootsFnT | None = None,
95114
logging_callback: LoggingFnT | None = None,
96115
message_handler: MessageHandlerFnT | None = None,
@@ -105,12 +124,14 @@ def __init__(
105124
)
106125
self._client_info = client_info or DEFAULT_CLIENT_INFO
107126
self._sampling_callback = sampling_callback or _default_sampling_callback
127+
self._elicitation_callback = elicitation_callback or _default_elicitation_callback
108128
self._list_roots_callback = list_roots_callback or _default_list_roots_callback
109129
self._logging_callback = logging_callback or _default_logging_callback
110130
self._message_handler = message_handler or _default_message_handler
111131

112132
async def initialize(self) -> types.InitializeResult:
113133
sampling = types.SamplingCapability() if self._sampling_callback is not _default_sampling_callback else None
134+
elicitation = types.ElicitationCapability()
114135
roots = (
115136
# TODO: Should this be based on whether we
116137
# _will_ send notifications, or only whether
@@ -128,6 +149,7 @@ async def initialize(self) -> types.InitializeResult:
128149
protocolVersion=types.LATEST_PROTOCOL_VERSION,
129150
capabilities=types.ClientCapabilities(
130151
sampling=sampling,
152+
elicitation=elicitation,
131153
experimental=None,
132154
roots=roots,
133155
),
@@ -362,6 +384,12 @@ async def _received_request(self, responder: RequestResponder[types.ServerReques
362384
client_response = ClientResponse.validate_python(response)
363385
await responder.respond(client_response)
364386

387+
case types.ElicitRequest(params=params):
388+
with responder:
389+
response = await self._elicitation_callback(ctx, params)
390+
client_response = ClientResponse.validate_python(response)
391+
await responder.respond(client_response)
392+
365393
case types.ListRootsRequest():
366394
with responder:
367395
response = await self._list_roots_callback(ctx)

src/mcp/server/fastmcp/server.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -972,6 +972,39 @@ async def read_resource(self, uri: str | AnyUrl) -> Iterable[ReadResourceContent
972972
assert self._fastmcp is not None, "Context is not available outside of a request"
973973
return await self._fastmcp.read_resource(uri)
974974

975+
async def elicit(
976+
self,
977+
message: str,
978+
requestedSchema: dict[str, Any],
979+
) -> dict[str, Any]:
980+
"""Elicit information from the client/user.
981+
982+
This method can be used to interactively ask for additional information from the
983+
client within a tool's execution.
984+
The client might display the message to the user and collect a response
985+
according to the provided schema. Or in case a client is an agent, it might
986+
decide how to handle the elicitation -- either by asking the user or
987+
automatically generating a response.
988+
989+
Args:
990+
message: The message to present to the user
991+
requestedSchema: JSON Schema defining the expected response structure
992+
993+
Returns:
994+
The user's response as a dict matching the request schema structure
995+
996+
Raises:
997+
ValueError: If elicitation is not supported by the client or fails
998+
"""
999+
1000+
result = await self.request_context.session.elicit(
1001+
message=message,
1002+
requestedSchema=requestedSchema,
1003+
related_request_id=self.request_id,
1004+
)
1005+
1006+
return result.response
1007+
9751008
async def log(
9761009
self,
9771010
level: Literal["debug", "info", "warning", "error"],

src/mcp/server/session.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,10 @@ def check_client_capability(self, capability: types.ClientCapabilities) -> bool:
121121
if client_caps.sampling is None:
122122
return False
123123

124+
if capability.elicitation is not None:
125+
if client_caps.elicitation is None:
126+
return False
127+
124128
if capability.experimental is not None:
125129
if client_caps.experimental is None:
126130
return False
@@ -251,6 +255,35 @@ async def list_roots(self) -> types.ListRootsResult:
251255
types.ListRootsResult,
252256
)
253257

258+
async def elicit(
259+
self,
260+
message: str,
261+
requestedSchema: dict[str, Any],
262+
related_request_id: types.RequestId | None = None,
263+
) -> types.ElicitResult:
264+
"""Send an elicitation/create request.
265+
266+
Args:
267+
message: The message to present to the user
268+
requestedSchema: JSON Schema defining the expected response structure
269+
270+
Returns:
271+
The client's response
272+
"""
273+
return await self.send_request(
274+
types.ServerRequest(
275+
types.ElicitRequest(
276+
method="elicitation/create",
277+
params=types.ElicitRequestParams(
278+
message=message,
279+
requestedSchema=requestedSchema,
280+
),
281+
)
282+
),
283+
types.ElicitResult,
284+
metadata=ServerMessageMetadata(related_request_id=related_request_id),
285+
)
286+
254287
async def send_ping(self) -> types.EmptyResult:
255288
"""Send a ping request."""
256289
return await self.send_request(

src/mcp/types.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -216,13 +216,21 @@ class SamplingCapability(BaseModel):
216216
model_config = ConfigDict(extra="allow")
217217

218218

219+
class ElicitationCapability(BaseModel):
220+
"""Capability for elicitation operations."""
221+
222+
model_config = ConfigDict(extra="allow")
223+
224+
219225
class ClientCapabilities(BaseModel):
220226
"""Capabilities a client may support."""
221227

222228
experimental: dict[str, dict[str, Any]] | None = None
223229
"""Experimental, non-standard capabilities that the client supports."""
224230
sampling: SamplingCapability | None = None
225231
"""Present if the client supports sampling from an LLM."""
232+
elicitation: ElicitationCapability | None = None
233+
"""Present if the client supports elicitation from the user."""
226234
roots: RootsCapability | None = None
227235
"""Present if the client supports listing roots."""
228236
model_config = ConfigDict(extra="allow")
@@ -1186,11 +1194,38 @@ class ClientNotification(
11861194
pass
11871195

11881196

1189-
class ClientResult(RootModel[EmptyResult | CreateMessageResult | ListRootsResult]):
1197+
class ElicitRequestParams(RequestParams):
1198+
"""Parameters for elicitation requests."""
1199+
1200+
message: str
1201+
"""The message to present to the user."""
1202+
1203+
requestedSchema: dict[str, Any]
1204+
"""
1205+
A JSON Schema object defining the expected structure of the response.
1206+
"""
1207+
model_config = ConfigDict(extra="allow")
1208+
1209+
1210+
class ElicitRequest(Request[ElicitRequestParams, Literal["elicitation/create"]]):
1211+
"""A request from the server to elicit information from the client."""
1212+
1213+
method: Literal["elicitation/create"]
1214+
params: ElicitRequestParams
1215+
1216+
1217+
class ElicitResult(Result):
1218+
"""The client's response to an elicitation/create request from the server."""
1219+
1220+
response: dict[str, Any]
1221+
"""The response from the client, matching the structure of requestedSchema."""
1222+
1223+
1224+
class ClientResult(RootModel[EmptyResult | CreateMessageResult | ListRootsResult | ElicitResult]):
11901225
pass
11911226

11921227

1193-
class ServerRequest(RootModel[PingRequest | CreateMessageRequest | ListRootsRequest]):
1228+
class ServerRequest(RootModel[PingRequest | CreateMessageRequest | ListRootsRequest | ElicitRequest]):
11941229
pass
11951230

11961231

tests/server/fastmcp/test_integration.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from mcp.client.session import ClientSession
2222
from mcp.client.sse import sse_client
2323
from mcp.client.streamable_http import streamablehttp_client
24-
from mcp.server.fastmcp import FastMCP
24+
from mcp.server.fastmcp import Context, FastMCP
2525
from mcp.server.fastmcp.resources import FunctionResource
2626
from mcp.shared.context import RequestContext
2727
from mcp.types import (
@@ -30,6 +30,7 @@
3030
CompletionContext,
3131
CreateMessageRequestParams,
3232
CreateMessageResult,
33+
ElicitResult,
3334
GetPromptResult,
3435
InitializeResult,
3536
LoggingMessageNotification,
@@ -98,6 +99,23 @@ def make_fastmcp_app():
9899
def echo(message: str) -> str:
99100
return f"Echo: {message}"
100101

102+
# Add a tool that uses elicitation
103+
@mcp.tool(description="A tool that uses elicitation")
104+
async def ask_user(prompt: str, ctx: Context) -> str:
105+
schema = {
106+
"type": "object",
107+
"properties": {
108+
"answer": {"type": "string"},
109+
},
110+
"required": ["answer"],
111+
}
112+
113+
response = await ctx.elicit(
114+
message=f"Tool wants to ask: {prompt}",
115+
requestedSchema=schema,
116+
)
117+
return f"User answered: {response['answer']}"
118+
101119
# Create the SSE app
102120
app = mcp.sse_app()
103121

@@ -937,3 +955,32 @@ async def message_handler(message):
937955
) as session:
938956
# Run the common test suite with HTTP-specific test suffix
939957
await call_all_mcp_features(session, collector)
958+
959+
960+
@pytest.mark.anyio
961+
async def test_elicitation_feature(server: None, server_url: str) -> None:
962+
"""Test the elicitation feature."""
963+
964+
# Create a custom handler for elicitation requests
965+
async def elicitation_callback(context, params):
966+
# Verify the elicitation parameters
967+
if params.message == "Tool wants to ask: What is your name?":
968+
return ElicitResult(response={"answer": "Test User"})
969+
else:
970+
raise ValueError("Unexpected elicitation message")
971+
972+
# Connect to the server with our custom elicitation handler
973+
async with sse_client(server_url + "/sse") as streams:
974+
async with ClientSession(*streams, elicitation_callback=elicitation_callback) as session:
975+
# First initialize the session
976+
result = await session.initialize()
977+
assert isinstance(result, InitializeResult)
978+
assert result.serverInfo.name == "NoAuthServer"
979+
980+
# Call the tool that uses elicitation
981+
tool_result = await session.call_tool("ask_user", {"prompt": "What is your name?"})
982+
# Verify the result
983+
assert len(tool_result.content) == 1
984+
assert isinstance(tool_result.content[0], TextContent)
985+
# # The test should only succeed with the successful elicitation response
986+
assert tool_result.content[0].text == "User answered: Test User"

0 commit comments

Comments
 (0)