Skip to content

Commit 3a2d915

Browse files
ihrprdsp-ant
authored andcommitted
adjust types after the spec revision
1 parent 427a634 commit 3a2d915

File tree

6 files changed

+157
-60
lines changed

6 files changed

+157
-60
lines changed

src/mcp/client/session.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,9 @@ def __init__(
131131

132132
async def initialize(self) -> types.InitializeResult:
133133
sampling = types.SamplingCapability() if self._sampling_callback is not _default_sampling_callback else None
134-
elicitation = types.ElicitationCapability()
134+
elicitation = (
135+
types.ElicitationCapability() if self._elicitation_callback is not _default_elicitation_callback else None
136+
)
135137
roots = (
136138
# TODO: Should this be based on whether we
137139
# _will_ send notifications, or only whether

src/mcp/server/fastmcp/server.py

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,11 @@
1010
asynccontextmanager,
1111
)
1212
from itertools import chain
13-
from typing import Any, Generic, Literal
13+
from typing import Any, Generic, Literal, TypeVar
1414

1515
import anyio
1616
import pydantic_core
17-
from pydantic import BaseModel, Field
17+
from pydantic import BaseModel, Field, ValidationError
1818
from pydantic.networks import AnyUrl
1919
from pydantic_settings import BaseSettings, SettingsConfigDict
2020
from starlette.applications import Starlette
@@ -65,6 +65,8 @@
6565

6666
logger = get_logger(__name__)
6767

68+
ElicitedModelT = TypeVar("ElicitedModelT", bound=BaseModel)
69+
6870

6971
class Settings(BaseSettings, Generic[LifespanResultT]):
7072
"""FastMCP server settings.
@@ -975,35 +977,48 @@ async def read_resource(self, uri: str | AnyUrl) -> Iterable[ReadResourceContent
975977
async def elicit(
976978
self,
977979
message: str,
978-
requestedSchema: dict[str, Any],
979-
) -> dict[str, Any]:
980+
schema: type[ElicitedModelT],
981+
) -> ElicitedModelT:
980982
"""Elicit information from the client/user.
981983
982984
This method can be used to interactively ask for additional information from the
983-
client within a tool's execution.
984-
The client might display the message to the user and collect a response
985-
according to the provided schema. Or in case a client is an agent, it might
986-
decide how to handle the elicitation -- either by asking the user or
987-
automatically generating a response.
985+
client within a tool's execution. The client might display the message to the
986+
user and collect a response according to the provided schema. Or in case a
987+
client
988+
is an agent, it might decide how to handle the elicitation -- either by asking
989+
the user or automatically generating a response.
988990
989991
Args:
990-
message: The message to present to the user
991-
requestedSchema: JSON Schema defining the expected response structure
992+
schema: A Pydantic model class defining the expected response structure
993+
message: Optional message to present to the user. If not provided, will use
994+
a default message based on the schema
992995
993996
Returns:
994-
The user's response as a dict matching the request schema structure
997+
An instance of the schema type with the user's response
995998
996999
Raises:
997-
ValueError: If elicitation is not supported by the client or fails
1000+
Exception: If the user declines or cancels the elicitation
1001+
ValidationError: If the response doesn't match the schema
9981002
"""
9991003

1004+
json_schema = schema.model_json_schema()
1005+
10001006
result = await self.request_context.session.elicit(
10011007
message=message,
1002-
requestedSchema=requestedSchema,
1008+
requestedSchema=json_schema,
10031009
related_request_id=self.request_id,
10041010
)
10051011

1006-
return result.content
1012+
if result.action == "accept" and result.content:
1013+
# Validate and parse the content using the schema
1014+
try:
1015+
return schema.model_validate(result.content)
1016+
except ValidationError as e:
1017+
raise ValueError(f"Invalid response: {e}")
1018+
elif result.action == "decline":
1019+
raise Exception("User declined to provide information")
1020+
else:
1021+
raise Exception("User cancelled the request")
10071022

10081023
async def log(
10091024
self,

src/mcp/server/session.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -258,14 +258,14 @@ async def list_roots(self) -> types.ListRootsResult:
258258
async def elicit(
259259
self,
260260
message: str,
261-
requestedSchema: dict[str, Any],
261+
requestedSchema: types.ElicitRequestedSchema,
262262
related_request_id: types.RequestId | None = None,
263263
) -> types.ElicitResult:
264264
"""Send an elicitation/create request.
265265
266266
Args:
267267
message: The message to present to the user
268-
requestedSchema: JSON Schema defining the expected response structure
268+
requestedSchema: Schema defining the expected response structure
269269
270270
Returns:
271271
The client's response

src/mcp/types.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1194,16 +1194,16 @@ class ClientNotification(
11941194
pass
11951195

11961196

1197+
# Type for elicitation schema - a JSON Schema dict
1198+
ElicitRequestedSchema: TypeAlias = dict[str, Any]
1199+
"""Schema for elicitation requests."""
1200+
1201+
11971202
class ElicitRequestParams(RequestParams):
11981203
"""Parameters for elicitation requests."""
11991204

12001205
message: str
1201-
"""The message to present to the user."""
1202-
1203-
requestedSchema: dict[str, Any]
1204-
"""
1205-
A JSON Schema object defining the expected structure of the response.
1206-
"""
1206+
requestedSchema: ElicitRequestedSchema
12071207
model_config = ConfigDict(extra="allow")
12081208

12091209

@@ -1215,10 +1215,21 @@ class ElicitRequest(Request[ElicitRequestParams, Literal["elicitation/create"]])
12151215

12161216

12171217
class ElicitResult(Result):
1218-
"""The client's response to an elicitation/create request from the server."""
1218+
"""The client's response to an elicitation request."""
12191219

1220-
content: dict[str, Any]
1221-
"""The response from the client, matching the structure of requestedSchema."""
1220+
action: Literal["accept", "decline", "cancel"]
1221+
"""
1222+
The user action in response to the elicitation.
1223+
- "accept": User submitted the form/confirmed the action
1224+
- "decline": User explicitly declined the action
1225+
- "cancel": User dismissed without making an explicit choice
1226+
"""
1227+
1228+
content: dict[str, str | int | float | bool | None] | None = None
1229+
"""
1230+
The submitted form data, only present when action is "accept".
1231+
Contains values matching the requested schema.
1232+
"""
12221233

12231234

12241235
class ClientResult(RootModel[EmptyResult | CreateMessageResult | ListRootsResult | ElicitResult]):

tests/server/fastmcp/test_integration.py

Lines changed: 84 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import pytest
1616
import uvicorn
17-
from pydantic import AnyUrl
17+
from pydantic import AnyUrl, BaseModel, Field
1818
from starlette.applications import Starlette
1919
from starlette.requests import Request
2020

@@ -102,19 +102,15 @@ def echo(message: str) -> str:
102102
# Add a tool that uses elicitation
103103
@mcp.tool(description="A tool that uses elicitation")
104104
async def ask_user(prompt: str, ctx: Context) -> str:
105-
schema = {
106-
"type": "object",
107-
"properties": {
108-
"answer": {"type": "string"},
109-
},
110-
"required": ["answer"],
111-
}
105+
class AnswerSchema(BaseModel):
106+
answer: str = Field(description="The user's answer to the question")
112107

113-
response = await ctx.elicit(
114-
message=f"Tool wants to ask: {prompt}",
115-
requestedSchema=schema,
116-
)
117-
return f"User answered: {response['answer']}"
108+
try:
109+
result = await ctx.elicit(message=f"Tool wants to ask: {prompt}", schema=AnswerSchema)
110+
return f"User answered: {result.answer}"
111+
except Exception as e:
112+
# Handle cancellation or decline
113+
return f"User cancelled or declined: {str(e)}"
118114

119115
# Create the SSE app
120116
app = mcp.sse_app()
@@ -279,6 +275,47 @@ def echo_context(custom_request_id: str, ctx: Context[Any, Any, Request]) -> str
279275
context_data["path"] = request.url.path
280276
return json.dumps(context_data)
281277

278+
# Restaurant booking tool with elicitation
279+
@mcp.tool(description="Book a table at a restaurant with elicitation")
280+
async def book_restaurant(
281+
date: str,
282+
time: str,
283+
party_size: int,
284+
ctx: Context,
285+
) -> str:
286+
"""Book a table - uses elicitation if requested date is unavailable."""
287+
288+
class AlternativeDateSchema(BaseModel):
289+
checkAlternative: bool = Field(description="Would you like to try another date?")
290+
alternativeDate: str = Field(
291+
default="2024-12-26",
292+
description="What date would you prefer? (YYYY-MM-DD)",
293+
)
294+
295+
# For testing: assume dates starting with "2024-12-25" are unavailable
296+
if date.startswith("2024-12-25"):
297+
# Use elicitation to ask about alternatives
298+
try:
299+
result = await ctx.elicit(
300+
message=(
301+
f"No tables available for {party_size} people on {date} "
302+
f"at {time}. Would you like to check another date?"
303+
),
304+
schema=AlternativeDateSchema,
305+
)
306+
307+
if result.checkAlternative:
308+
alt_date = result.alternativeDate
309+
return f"✅ Booked table for {party_size} on {alt_date} at {time}"
310+
else:
311+
return "❌ No booking made"
312+
except Exception:
313+
# User declined or cancelled
314+
return "❌ Booking cancelled"
315+
else:
316+
# Available - book directly
317+
return f"✅ Booked table for {party_size} on {date} at {time}"
318+
282319
return mcp
283320

284321

@@ -670,6 +707,22 @@ async def handle_generic_notification(self, message) -> None:
670707
await self.handle_tool_list_changed(message.root.params)
671708

672709

710+
async def create_test_elicitation_callback(context, params):
711+
"""Shared elicitation callback for tests.
712+
713+
Handles elicitation requests for restaurant booking tests.
714+
"""
715+
# For restaurant booking test
716+
if "No tables available" in params.message:
717+
return ElicitResult(
718+
action="accept",
719+
content={"checkAlternative": True, "alternativeDate": "2024-12-26"},
720+
)
721+
else:
722+
# Default response
723+
return ElicitResult(action="decline")
724+
725+
673726
async def call_all_mcp_features(session: ClientSession, collector: NotificationCollector) -> None:
674727
"""
675728
Test all MCP features using the provided session.
@@ -765,6 +818,21 @@ async def progress_callback(progress: float, total: float | None, message: str |
765818
assert "info" in log_levels
766819
assert "warning" in log_levels
767820

821+
# 5. Test elicitation tool
822+
# Test restaurant booking with unavailable date (triggers elicitation)
823+
booking_result = await session.call_tool(
824+
"book_restaurant",
825+
{
826+
"date": "2024-12-25", # Unavailable date to trigger elicitation
827+
"time": "19:00",
828+
"party_size": 4,
829+
},
830+
)
831+
assert len(booking_result.content) == 1
832+
assert isinstance(booking_result.content[0], TextContent)
833+
# Should have booked the alternative date from elicitation callback
834+
assert "✅ Booked table for 4 on 2024-12-26" in booking_result.content[0].text
835+
768836
# Test resources
769837
# 1. Static resource
770838
resources = await session.list_resources()
@@ -905,8 +973,6 @@ async def test_fastmcp_all_features_sse(everything_server: None, everything_serv
905973
# Create notification collector
906974
collector = NotificationCollector()
907975

908-
# Create a sampling callback that simulates an LLM
909-
910976
# Connect to the server with callbacks
911977
async with sse_client(everything_server_url + "/sse") as streams:
912978
# Set up message handler to capture notifications
@@ -919,6 +985,7 @@ async def message_handler(message):
919985
async with ClientSession(
920986
*streams,
921987
sampling_callback=sampling_callback,
988+
elicitation_callback=create_test_elicitation_callback,
922989
message_handler=message_handler,
923990
) as session:
924991
# Run the common test suite
@@ -951,6 +1018,7 @@ async def message_handler(message):
9511018
read_stream,
9521019
write_stream,
9531020
sampling_callback=sampling_callback,
1021+
elicitation_callback=create_test_elicitation_callback,
9541022
message_handler=message_handler,
9551023
) as session:
9561024
# Run the common test suite with HTTP-specific test suffix
@@ -965,7 +1033,7 @@ async def test_elicitation_feature(server: None, server_url: str) -> None:
9651033
async def elicitation_callback(context, params):
9661034
# Verify the elicitation parameters
9671035
if params.message == "Tool wants to ask: What is your name?":
968-
return ElicitResult(content={"answer": "Test User"})
1036+
return ElicitResult(content={"answer": "Test User"}, action="accept")
9691037
else:
9701038
raise ValueError("Unexpected elicitation message")
9711039

tests/server/fastmcp/test_stdio_elicitation.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""
44

55
import pytest
6+
from pydantic import BaseModel, Field
67

78
from mcp.server.fastmcp import Context, FastMCP
89
from mcp.shared.memory import create_connected_server_and_client_session
@@ -18,25 +19,27 @@ async def test_stdio_elicitation():
1819

1920
@mcp.tool(description="A tool that uses elicitation")
2021
async def ask_user(prompt: str, ctx: Context) -> str:
21-
schema = {
22-
"type": "object",
23-
"properties": {
24-
"answer": {"type": "string"},
25-
},
26-
"required": ["answer"],
27-
}
28-
29-
response = await ctx.elicit(
30-
message=f"Tool wants to ask: {prompt}",
31-
requestedSchema=schema,
32-
)
33-
return f"User answered: {response['answer']}"
22+
class AnswerSchema(BaseModel):
23+
answer: str = Field(description="The user's answer to the question")
24+
25+
try:
26+
result = await ctx.elicit(
27+
message=f"Tool wants to ask: {prompt}",
28+
schema=AnswerSchema,
29+
)
30+
return f"User answered: {result.answer}"
31+
except Exception as e:
32+
# Handle cancellation or decline
33+
if "declined" in str(e):
34+
return "User declined to answer"
35+
else:
36+
return "User cancelled"
3437

3538
# Create a custom handler for elicitation requests
3639
async def elicitation_callback(context, params):
3740
# Verify the elicitation parameters
3841
if params.message == "Tool wants to ask: What is your name?":
39-
return ElicitResult(content={"answer": "Test User"})
42+
return ElicitResult(action="accept", content={"answer": "Test User"})
4043
else:
4144
raise ValueError(f"Unexpected elicitation message: {params.message}")
4245

@@ -49,9 +52,7 @@ async def elicitation_callback(context, params):
4952
assert result.serverInfo.name == "StdioElicitationServer"
5053

5154
# Call the tool that uses elicitation
52-
tool_result = await client_session.call_tool(
53-
"ask_user", {"prompt": "What is your name?"}
54-
)
55+
tool_result = await client_session.call_tool("ask_user", {"prompt": "What is your name?"})
5556

5657
# Verify the result
5758
assert len(tool_result.content) == 1

0 commit comments

Comments
 (0)