From 6597f303e12bba50962f07c3a7bd06ed1a6866c6 Mon Sep 17 00:00:00 2001 From: David Soria Parra Date: Tue, 17 Jun 2025 11:34:06 +0100 Subject: [PATCH 01/10] Elicitation --- src/mcp/client/session.py | 28 ++++++++++++++ src/mcp/server/fastmcp/server.py | 33 ++++++++++++++++ src/mcp/server/session.py | 33 ++++++++++++++++ src/mcp/types.py | 39 ++++++++++++++++++- tests/server/fastmcp/test_integration.py | 49 +++++++++++++++++++++++- 5 files changed, 179 insertions(+), 3 deletions(-) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 8b819ad6d1..047e37a96d 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -22,6 +22,14 @@ async def __call__( ) -> types.CreateMessageResult | types.ErrorData: ... +class ElicitationFnT(Protocol): + async def __call__( + self, + context: RequestContext["ClientSession", Any], + params: types.ElicitRequestParams, + ) -> types.ElicitResult | types.ErrorData: ... + + class ListRootsFnT(Protocol): async def __call__( self, context: RequestContext["ClientSession", Any] @@ -58,6 +66,16 @@ async def _default_sampling_callback( ) +async def _default_elicitation_callback( + context: RequestContext["ClientSession", Any], + params: types.ElicitRequestParams, +) -> types.ElicitResult | types.ErrorData: + return types.ErrorData( + code=types.INVALID_REQUEST, + message="Elicitation not supported", + ) + + async def _default_list_roots_callback( context: RequestContext["ClientSession", Any], ) -> types.ListRootsResult | types.ErrorData: @@ -91,6 +109,7 @@ def __init__( write_stream: MemoryObjectSendStream[SessionMessage], read_timeout_seconds: timedelta | None = None, sampling_callback: SamplingFnT | None = None, + elicitation_callback: ElicitationFnT | None = None, list_roots_callback: ListRootsFnT | None = None, logging_callback: LoggingFnT | None = None, message_handler: MessageHandlerFnT | None = None, @@ -105,12 +124,14 @@ def __init__( ) self._client_info = client_info or DEFAULT_CLIENT_INFO self._sampling_callback = sampling_callback or _default_sampling_callback + self._elicitation_callback = elicitation_callback or _default_elicitation_callback self._list_roots_callback = list_roots_callback or _default_list_roots_callback self._logging_callback = logging_callback or _default_logging_callback self._message_handler = message_handler or _default_message_handler async def initialize(self) -> types.InitializeResult: sampling = types.SamplingCapability() if self._sampling_callback is not _default_sampling_callback else None + elicitation = types.ElicitationCapability() roots = ( # TODO: Should this be based on whether we # _will_ send notifications, or only whether @@ -128,6 +149,7 @@ async def initialize(self) -> types.InitializeResult: protocolVersion=types.LATEST_PROTOCOL_VERSION, capabilities=types.ClientCapabilities( sampling=sampling, + elicitation=elicitation, experimental=None, roots=roots, ), @@ -362,6 +384,12 @@ async def _received_request(self, responder: RequestResponder[types.ServerReques client_response = ClientResponse.validate_python(response) await responder.respond(client_response) + case types.ElicitRequest(params=params): + with responder: + response = await self._elicitation_callback(ctx, params) + client_response = ClientResponse.validate_python(response) + await responder.respond(client_response) + case types.ListRootsRequest(): with responder: response = await self._list_roots_callback(ctx) diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index a62974bc9c..21e82f2021 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -972,6 +972,39 @@ async def read_resource(self, uri: str | AnyUrl) -> Iterable[ReadResourceContent assert self._fastmcp is not None, "Context is not available outside of a request" return await self._fastmcp.read_resource(uri) + async def elicit( + self, + message: str, + requestedSchema: dict[str, Any], + ) -> dict[str, Any]: + """Elicit information from the client/user. + + This method can be used to interactively ask for additional information from the + client within a tool's execution. + The client might display the message to the user and collect a response + according to the provided schema. Or in case a client is an agent, it might + decide how to handle the elicitation -- either by asking the user or + automatically generating a response. + + Args: + message: The message to present to the user + requestedSchema: JSON Schema defining the expected response structure + + Returns: + The user's response as a dict matching the request schema structure + + Raises: + ValueError: If elicitation is not supported by the client or fails + """ + + result = await self.request_context.session.elicit( + message=message, + requestedSchema=requestedSchema, + related_request_id=self.request_id, + ) + + return result.response + async def log( self, level: Literal["debug", "info", "warning", "error"], diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index e6611b0d41..4376cb150b 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -121,6 +121,10 @@ def check_client_capability(self, capability: types.ClientCapabilities) -> bool: if client_caps.sampling is None: return False + if capability.elicitation is not None: + if client_caps.elicitation is None: + return False + if capability.experimental is not None: if client_caps.experimental is None: return False @@ -251,6 +255,35 @@ async def list_roots(self) -> types.ListRootsResult: types.ListRootsResult, ) + async def elicit( + self, + message: str, + requestedSchema: dict[str, Any], + related_request_id: types.RequestId | None = None, + ) -> types.ElicitResult: + """Send an elicitation/create request. + + Args: + message: The message to present to the user + requestedSchema: JSON Schema defining the expected response structure + + Returns: + The client's response + """ + return await self.send_request( + types.ServerRequest( + types.ElicitRequest( + method="elicitation/create", + params=types.ElicitRequestParams( + message=message, + requestedSchema=requestedSchema, + ), + ) + ), + types.ElicitResult, + metadata=ServerMessageMetadata(related_request_id=related_request_id), + ) + async def send_ping(self) -> types.EmptyResult: """Send a ping request.""" return await self.send_request( diff --git a/src/mcp/types.py b/src/mcp/types.py index be8d7326f1..a9cf4c4899 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -216,6 +216,12 @@ class SamplingCapability(BaseModel): model_config = ConfigDict(extra="allow") +class ElicitationCapability(BaseModel): + """Capability for elicitation operations.""" + + model_config = ConfigDict(extra="allow") + + class ClientCapabilities(BaseModel): """Capabilities a client may support.""" @@ -223,6 +229,8 @@ class ClientCapabilities(BaseModel): """Experimental, non-standard capabilities that the client supports.""" sampling: SamplingCapability | None = None """Present if the client supports sampling from an LLM.""" + elicitation: ElicitationCapability | None = None + """Present if the client supports elicitation from the user.""" roots: RootsCapability | None = None """Present if the client supports listing roots.""" model_config = ConfigDict(extra="allow") @@ -1186,11 +1194,38 @@ class ClientNotification( pass -class ClientResult(RootModel[EmptyResult | CreateMessageResult | ListRootsResult]): +class ElicitRequestParams(RequestParams): + """Parameters for elicitation requests.""" + + message: str + """The message to present to the user.""" + + requestedSchema: dict[str, Any] + """ + A JSON Schema object defining the expected structure of the response. + """ + model_config = ConfigDict(extra="allow") + + +class ElicitRequest(Request[ElicitRequestParams, Literal["elicitation/create"]]): + """A request from the server to elicit information from the client.""" + + method: Literal["elicitation/create"] + params: ElicitRequestParams + + +class ElicitResult(Result): + """The client's response to an elicitation/create request from the server.""" + + response: dict[str, Any] + """The response from the client, matching the structure of requestedSchema.""" + + +class ClientResult(RootModel[EmptyResult | CreateMessageResult | ListRootsResult | ElicitResult]): pass -class ServerRequest(RootModel[PingRequest | CreateMessageRequest | ListRootsRequest]): +class ServerRequest(RootModel[PingRequest | CreateMessageRequest | ListRootsRequest | ElicitRequest]): pass diff --git a/tests/server/fastmcp/test_integration.py b/tests/server/fastmcp/test_integration.py index 3caf994e01..570ed72c01 100644 --- a/tests/server/fastmcp/test_integration.py +++ b/tests/server/fastmcp/test_integration.py @@ -21,7 +21,7 @@ from mcp.client.session import ClientSession from mcp.client.sse import sse_client from mcp.client.streamable_http import streamablehttp_client -from mcp.server.fastmcp import FastMCP +from mcp.server.fastmcp import Context, FastMCP from mcp.server.fastmcp.resources import FunctionResource from mcp.shared.context import RequestContext from mcp.types import ( @@ -30,6 +30,7 @@ CompletionContext, CreateMessageRequestParams, CreateMessageResult, + ElicitResult, GetPromptResult, InitializeResult, LoggingMessageNotification, @@ -98,6 +99,23 @@ def make_fastmcp_app(): def echo(message: str) -> str: return f"Echo: {message}" + # Add a tool that uses elicitation + @mcp.tool(description="A tool that uses elicitation") + async def ask_user(prompt: str, ctx: Context) -> str: + schema = { + "type": "object", + "properties": { + "answer": {"type": "string"}, + }, + "required": ["answer"], + } + + response = await ctx.elicit( + message=f"Tool wants to ask: {prompt}", + requestedSchema=schema, + ) + return f"User answered: {response['answer']}" + # Create the SSE app app = mcp.sse_app() @@ -937,3 +955,32 @@ async def message_handler(message): ) as session: # Run the common test suite with HTTP-specific test suffix await call_all_mcp_features(session, collector) + + +@pytest.mark.anyio +async def test_elicitation_feature(server: None, server_url: str) -> None: + """Test the elicitation feature.""" + + # Create a custom handler for elicitation requests + async def elicitation_callback(context, params): + # Verify the elicitation parameters + if params.message == "Tool wants to ask: What is your name?": + return ElicitResult(response={"answer": "Test User"}) + else: + raise ValueError("Unexpected elicitation message") + + # Connect to the server with our custom elicitation handler + async with sse_client(server_url + "/sse") as streams: + async with ClientSession(*streams, elicitation_callback=elicitation_callback) as session: + # First initialize the session + result = await session.initialize() + assert isinstance(result, InitializeResult) + assert result.serverInfo.name == "NoAuthServer" + + # Call the tool that uses elicitation + tool_result = await session.call_tool("ask_user", {"prompt": "What is your name?"}) + # Verify the result + assert len(tool_result.content) == 1 + assert isinstance(tool_result.content[0], TextContent) + # # The test should only succeed with the successful elicitation response + assert tool_result.content[0].text == "User answered: Test User" From 40470d688b716020a6729ae465020266d3a3470f Mon Sep 17 00:00:00 2001 From: ihrpr Date: Sun, 4 May 2025 13:50:05 +0100 Subject: [PATCH 02/10] add elicitation test using create_client_server_memory_streams --- src/mcp/shared/memory.py | 11 +++- .../server/fastmcp/test_stdio_elicitation.py | 59 +++++++++++++++++++ 2 files changed, 69 insertions(+), 1 deletion(-) create mode 100644 tests/server/fastmcp/test_stdio_elicitation.py diff --git a/src/mcp/shared/memory.py b/src/mcp/shared/memory.py index f088d3f8b9..c94e5e6ac1 100644 --- a/src/mcp/shared/memory.py +++ b/src/mcp/shared/memory.py @@ -11,7 +11,14 @@ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream import mcp.types as types -from mcp.client.session import ClientSession, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT +from mcp.client.session import ( + ClientSession, + ElicitationFnT, + ListRootsFnT, + LoggingFnT, + MessageHandlerFnT, + SamplingFnT, +) from mcp.server import Server from mcp.shared.message import SessionMessage @@ -53,6 +60,7 @@ async def create_connected_server_and_client_session( message_handler: MessageHandlerFnT | None = None, client_info: types.Implementation | None = None, raise_exceptions: bool = False, + elicitation_callback: ElicitationFnT | None = None, ) -> AsyncGenerator[ClientSession, None]: """Creates a ClientSession that is connected to a running MCP server.""" async with create_client_server_memory_streams() as ( @@ -83,6 +91,7 @@ async def create_connected_server_and_client_session( logging_callback=logging_callback, message_handler=message_handler, client_info=client_info, + elicitation_callback=elicitation_callback, ) as client_session: await client_session.initialize() yield client_session diff --git a/tests/server/fastmcp/test_stdio_elicitation.py b/tests/server/fastmcp/test_stdio_elicitation.py new file mode 100644 index 0000000000..e41f704f6e --- /dev/null +++ b/tests/server/fastmcp/test_stdio_elicitation.py @@ -0,0 +1,59 @@ +""" +Test the elicitation feature using stdio transport. +""" + +import pytest + +from mcp.server.fastmcp import Context, FastMCP +from mcp.shared.memory import create_connected_server_and_client_session +from mcp.types import ElicitResult, TextContent + + +@pytest.mark.anyio +async def test_stdio_elicitation(): + """Test the elicitation feature using stdio transport.""" + + # Create a FastMCP server with a tool that uses elicitation + mcp = FastMCP(name="StdioElicitationServer") + + @mcp.tool(description="A tool that uses elicitation") + async def ask_user(prompt: str, ctx: Context) -> str: + schema = { + "type": "object", + "properties": { + "answer": {"type": "string"}, + }, + "required": ["answer"], + } + + response = await ctx.elicit( + message=f"Tool wants to ask: {prompt}", + requestedSchema=schema, + ) + return f"User answered: {response['answer']}" + + # Create a custom handler for elicitation requests + async def elicitation_callback(context, params): + # Verify the elicitation parameters + if params.message == "Tool wants to ask: What is your name?": + return ElicitResult(response={"answer": "Test User"}) + else: + raise ValueError(f"Unexpected elicitation message: {params.message}") + + # Use memory-based session to test with stdio transport + async with create_connected_server_and_client_session( + mcp._mcp_server, elicitation_callback=elicitation_callback + ) as client_session: + # First initialize the session + result = await client_session.initialize() + assert result.serverInfo.name == "StdioElicitationServer" + + # Call the tool that uses elicitation + tool_result = await client_session.call_tool( + "ask_user", {"prompt": "What is your name?"} + ) + + # Verify the result + assert len(tool_result.content) == 1 + assert isinstance(tool_result.content[0], TextContent) + assert tool_result.content[0].text == "User answered: Test User" From 427a634ed3443dc726d1f9fe978787dd1f6afcfc Mon Sep 17 00:00:00 2001 From: ihrpr Date: Tue, 6 May 2025 11:39:16 +0100 Subject: [PATCH 03/10] field rename --- src/mcp/server/fastmcp/server.py | 2 +- src/mcp/types.py | 2 +- tests/server/fastmcp/test_integration.py | 2 +- tests/server/fastmcp/test_stdio_elicitation.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 21e82f2021..0edbbcf5a6 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -1003,7 +1003,7 @@ async def elicit( related_request_id=self.request_id, ) - return result.response + return result.content async def log( self, diff --git a/src/mcp/types.py b/src/mcp/types.py index a9cf4c4899..5708027dfb 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -1217,7 +1217,7 @@ class ElicitRequest(Request[ElicitRequestParams, Literal["elicitation/create"]]) class ElicitResult(Result): """The client's response to an elicitation/create request from the server.""" - response: dict[str, Any] + content: dict[str, Any] """The response from the client, matching the structure of requestedSchema.""" diff --git a/tests/server/fastmcp/test_integration.py b/tests/server/fastmcp/test_integration.py index 570ed72c01..5aa719a6e6 100644 --- a/tests/server/fastmcp/test_integration.py +++ b/tests/server/fastmcp/test_integration.py @@ -965,7 +965,7 @@ async def test_elicitation_feature(server: None, server_url: str) -> None: async def elicitation_callback(context, params): # Verify the elicitation parameters if params.message == "Tool wants to ask: What is your name?": - return ElicitResult(response={"answer": "Test User"}) + return ElicitResult(content={"answer": "Test User"}) else: raise ValueError("Unexpected elicitation message") diff --git a/tests/server/fastmcp/test_stdio_elicitation.py b/tests/server/fastmcp/test_stdio_elicitation.py index e41f704f6e..52555a287d 100644 --- a/tests/server/fastmcp/test_stdio_elicitation.py +++ b/tests/server/fastmcp/test_stdio_elicitation.py @@ -36,7 +36,7 @@ async def ask_user(prompt: str, ctx: Context) -> str: async def elicitation_callback(context, params): # Verify the elicitation parameters if params.message == "Tool wants to ask: What is your name?": - return ElicitResult(response={"answer": "Test User"}) + return ElicitResult(content={"answer": "Test User"}) else: raise ValueError(f"Unexpected elicitation message: {params.message}") From 3a2d915eace097ae12b94574448f08dc89dae855 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Thu, 12 Jun 2025 00:09:42 +0100 Subject: [PATCH 04/10] adjust types after the spec revision --- src/mcp/client/session.py | 4 +- src/mcp/server/fastmcp/server.py | 45 +++++--- src/mcp/server/session.py | 4 +- src/mcp/types.py | 29 +++-- tests/server/fastmcp/test_integration.py | 100 +++++++++++++++--- .../server/fastmcp/test_stdio_elicitation.py | 35 +++--- 6 files changed, 157 insertions(+), 60 deletions(-) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 047e37a96d..9488171406 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -131,7 +131,9 @@ def __init__( async def initialize(self) -> types.InitializeResult: sampling = types.SamplingCapability() if self._sampling_callback is not _default_sampling_callback else None - elicitation = types.ElicitationCapability() + elicitation = ( + types.ElicitationCapability() if self._elicitation_callback is not _default_elicitation_callback else None + ) roots = ( # TODO: Should this be based on whether we # _will_ send notifications, or only whether diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 0edbbcf5a6..200900ee2a 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -10,11 +10,11 @@ asynccontextmanager, ) from itertools import chain -from typing import Any, Generic, Literal +from typing import Any, Generic, Literal, TypeVar import anyio import pydantic_core -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, ValidationError from pydantic.networks import AnyUrl from pydantic_settings import BaseSettings, SettingsConfigDict from starlette.applications import Starlette @@ -65,6 +65,8 @@ logger = get_logger(__name__) +ElicitedModelT = TypeVar("ElicitedModelT", bound=BaseModel) + class Settings(BaseSettings, Generic[LifespanResultT]): """FastMCP server settings. @@ -975,35 +977,48 @@ async def read_resource(self, uri: str | AnyUrl) -> Iterable[ReadResourceContent async def elicit( self, message: str, - requestedSchema: dict[str, Any], - ) -> dict[str, Any]: + schema: type[ElicitedModelT], + ) -> ElicitedModelT: """Elicit information from the client/user. This method can be used to interactively ask for additional information from the - client within a tool's execution. - The client might display the message to the user and collect a response - according to the provided schema. Or in case a client is an agent, it might - decide how to handle the elicitation -- either by asking the user or - automatically generating a response. + client within a tool's execution. The client might display the message to the + user and collect a response according to the provided schema. Or in case a + client + is an agent, it might decide how to handle the elicitation -- either by asking + the user or automatically generating a response. Args: - message: The message to present to the user - requestedSchema: JSON Schema defining the expected response structure + schema: A Pydantic model class defining the expected response structure + message: Optional message to present to the user. If not provided, will use + a default message based on the schema Returns: - The user's response as a dict matching the request schema structure + An instance of the schema type with the user's response Raises: - ValueError: If elicitation is not supported by the client or fails + Exception: If the user declines or cancels the elicitation + ValidationError: If the response doesn't match the schema """ + json_schema = schema.model_json_schema() + result = await self.request_context.session.elicit( message=message, - requestedSchema=requestedSchema, + requestedSchema=json_schema, related_request_id=self.request_id, ) - return result.content + if result.action == "accept" and result.content: + # Validate and parse the content using the schema + try: + return schema.model_validate(result.content) + except ValidationError as e: + raise ValueError(f"Invalid response: {e}") + elif result.action == "decline": + raise Exception("User declined to provide information") + else: + raise Exception("User cancelled the request") async def log( self, diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 4376cb150b..5c696b136a 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -258,14 +258,14 @@ async def list_roots(self) -> types.ListRootsResult: async def elicit( self, message: str, - requestedSchema: dict[str, Any], + requestedSchema: types.ElicitRequestedSchema, related_request_id: types.RequestId | None = None, ) -> types.ElicitResult: """Send an elicitation/create request. Args: message: The message to present to the user - requestedSchema: JSON Schema defining the expected response structure + requestedSchema: Schema defining the expected response structure Returns: The client's response diff --git a/src/mcp/types.py b/src/mcp/types.py index 5708027dfb..a678f101d4 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -1194,16 +1194,16 @@ class ClientNotification( pass +# Type for elicitation schema - a JSON Schema dict +ElicitRequestedSchema: TypeAlias = dict[str, Any] +"""Schema for elicitation requests.""" + + class ElicitRequestParams(RequestParams): """Parameters for elicitation requests.""" message: str - """The message to present to the user.""" - - requestedSchema: dict[str, Any] - """ - A JSON Schema object defining the expected structure of the response. - """ + requestedSchema: ElicitRequestedSchema model_config = ConfigDict(extra="allow") @@ -1215,10 +1215,21 @@ class ElicitRequest(Request[ElicitRequestParams, Literal["elicitation/create"]]) class ElicitResult(Result): - """The client's response to an elicitation/create request from the server.""" + """The client's response to an elicitation request.""" - content: dict[str, Any] - """The response from the client, matching the structure of requestedSchema.""" + action: Literal["accept", "decline", "cancel"] + """ + The user action in response to the elicitation. + - "accept": User submitted the form/confirmed the action + - "decline": User explicitly declined the action + - "cancel": User dismissed without making an explicit choice + """ + + content: dict[str, str | int | float | bool | None] | None = None + """ + The submitted form data, only present when action is "accept". + Contains values matching the requested schema. + """ class ClientResult(RootModel[EmptyResult | CreateMessageResult | ListRootsResult | ElicitResult]): diff --git a/tests/server/fastmcp/test_integration.py b/tests/server/fastmcp/test_integration.py index 5aa719a6e6..cb9be6e1ab 100644 --- a/tests/server/fastmcp/test_integration.py +++ b/tests/server/fastmcp/test_integration.py @@ -14,7 +14,7 @@ import pytest import uvicorn -from pydantic import AnyUrl +from pydantic import AnyUrl, BaseModel, Field from starlette.applications import Starlette from starlette.requests import Request @@ -102,19 +102,15 @@ def echo(message: str) -> str: # Add a tool that uses elicitation @mcp.tool(description="A tool that uses elicitation") async def ask_user(prompt: str, ctx: Context) -> str: - schema = { - "type": "object", - "properties": { - "answer": {"type": "string"}, - }, - "required": ["answer"], - } + class AnswerSchema(BaseModel): + answer: str = Field(description="The user's answer to the question") - response = await ctx.elicit( - message=f"Tool wants to ask: {prompt}", - requestedSchema=schema, - ) - return f"User answered: {response['answer']}" + try: + result = await ctx.elicit(message=f"Tool wants to ask: {prompt}", schema=AnswerSchema) + return f"User answered: {result.answer}" + except Exception as e: + # Handle cancellation or decline + return f"User cancelled or declined: {str(e)}" # Create the SSE app app = mcp.sse_app() @@ -279,6 +275,47 @@ def echo_context(custom_request_id: str, ctx: Context[Any, Any, Request]) -> str context_data["path"] = request.url.path return json.dumps(context_data) + # Restaurant booking tool with elicitation + @mcp.tool(description="Book a table at a restaurant with elicitation") + async def book_restaurant( + date: str, + time: str, + party_size: int, + ctx: Context, + ) -> str: + """Book a table - uses elicitation if requested date is unavailable.""" + + class AlternativeDateSchema(BaseModel): + checkAlternative: bool = Field(description="Would you like to try another date?") + alternativeDate: str = Field( + default="2024-12-26", + description="What date would you prefer? (YYYY-MM-DD)", + ) + + # For testing: assume dates starting with "2024-12-25" are unavailable + if date.startswith("2024-12-25"): + # Use elicitation to ask about alternatives + try: + result = await ctx.elicit( + message=( + f"No tables available for {party_size} people on {date} " + f"at {time}. Would you like to check another date?" + ), + schema=AlternativeDateSchema, + ) + + if result.checkAlternative: + alt_date = result.alternativeDate + return f"✅ Booked table for {party_size} on {alt_date} at {time}" + else: + return "❌ No booking made" + except Exception: + # User declined or cancelled + return "❌ Booking cancelled" + else: + # Available - book directly + return f"✅ Booked table for {party_size} on {date} at {time}" + return mcp @@ -670,6 +707,22 @@ async def handle_generic_notification(self, message) -> None: await self.handle_tool_list_changed(message.root.params) +async def create_test_elicitation_callback(context, params): + """Shared elicitation callback for tests. + + Handles elicitation requests for restaurant booking tests. + """ + # For restaurant booking test + if "No tables available" in params.message: + return ElicitResult( + action="accept", + content={"checkAlternative": True, "alternativeDate": "2024-12-26"}, + ) + else: + # Default response + return ElicitResult(action="decline") + + async def call_all_mcp_features(session: ClientSession, collector: NotificationCollector) -> None: """ Test all MCP features using the provided session. @@ -765,6 +818,21 @@ async def progress_callback(progress: float, total: float | None, message: str | assert "info" in log_levels assert "warning" in log_levels + # 5. Test elicitation tool + # Test restaurant booking with unavailable date (triggers elicitation) + booking_result = await session.call_tool( + "book_restaurant", + { + "date": "2024-12-25", # Unavailable date to trigger elicitation + "time": "19:00", + "party_size": 4, + }, + ) + assert len(booking_result.content) == 1 + assert isinstance(booking_result.content[0], TextContent) + # Should have booked the alternative date from elicitation callback + assert "✅ Booked table for 4 on 2024-12-26" in booking_result.content[0].text + # Test resources # 1. Static resource resources = await session.list_resources() @@ -905,8 +973,6 @@ async def test_fastmcp_all_features_sse(everything_server: None, everything_serv # Create notification collector collector = NotificationCollector() - # Create a sampling callback that simulates an LLM - # Connect to the server with callbacks async with sse_client(everything_server_url + "/sse") as streams: # Set up message handler to capture notifications @@ -919,6 +985,7 @@ async def message_handler(message): async with ClientSession( *streams, sampling_callback=sampling_callback, + elicitation_callback=create_test_elicitation_callback, message_handler=message_handler, ) as session: # Run the common test suite @@ -951,6 +1018,7 @@ async def message_handler(message): read_stream, write_stream, sampling_callback=sampling_callback, + elicitation_callback=create_test_elicitation_callback, message_handler=message_handler, ) as session: # Run the common test suite with HTTP-specific test suffix @@ -965,7 +1033,7 @@ async def test_elicitation_feature(server: None, server_url: str) -> None: async def elicitation_callback(context, params): # Verify the elicitation parameters if params.message == "Tool wants to ask: What is your name?": - return ElicitResult(content={"answer": "Test User"}) + return ElicitResult(content={"answer": "Test User"}, action="accept") else: raise ValueError("Unexpected elicitation message") diff --git a/tests/server/fastmcp/test_stdio_elicitation.py b/tests/server/fastmcp/test_stdio_elicitation.py index 52555a287d..b6a6a6edb8 100644 --- a/tests/server/fastmcp/test_stdio_elicitation.py +++ b/tests/server/fastmcp/test_stdio_elicitation.py @@ -3,6 +3,7 @@ """ import pytest +from pydantic import BaseModel, Field from mcp.server.fastmcp import Context, FastMCP from mcp.shared.memory import create_connected_server_and_client_session @@ -18,25 +19,27 @@ async def test_stdio_elicitation(): @mcp.tool(description="A tool that uses elicitation") async def ask_user(prompt: str, ctx: Context) -> str: - schema = { - "type": "object", - "properties": { - "answer": {"type": "string"}, - }, - "required": ["answer"], - } - - response = await ctx.elicit( - message=f"Tool wants to ask: {prompt}", - requestedSchema=schema, - ) - return f"User answered: {response['answer']}" + class AnswerSchema(BaseModel): + answer: str = Field(description="The user's answer to the question") + + try: + result = await ctx.elicit( + message=f"Tool wants to ask: {prompt}", + schema=AnswerSchema, + ) + return f"User answered: {result.answer}" + except Exception as e: + # Handle cancellation or decline + if "declined" in str(e): + return "User declined to answer" + else: + return "User cancelled" # Create a custom handler for elicitation requests async def elicitation_callback(context, params): # Verify the elicitation parameters if params.message == "Tool wants to ask: What is your name?": - return ElicitResult(content={"answer": "Test User"}) + return ElicitResult(action="accept", content={"answer": "Test User"}) else: raise ValueError(f"Unexpected elicitation message: {params.message}") @@ -49,9 +52,7 @@ async def elicitation_callback(context, params): assert result.serverInfo.name == "StdioElicitationServer" # Call the tool that uses elicitation - tool_result = await client_session.call_tool( - "ask_user", {"prompt": "What is your name?"} - ) + tool_result = await client_session.call_tool("ask_user", {"prompt": "What is your name?"}) # Verify the result assert len(tool_result.content) == 1 From a75afd46b59e9cdd0560f7afcb83c304febabd18 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Thu, 12 Jun 2025 09:51:12 +0100 Subject: [PATCH 05/10] add ElicitationResult to fastMCP --- src/mcp/server/fastmcp/server.py | 42 +++++--- tests/server/fastmcp/test_elicitation.py | 101 ++++++++++++++++++ tests/server/fastmcp/test_integration.py | 37 ++++--- .../server/fastmcp/test_stdio_elicitation.py | 60 ----------- 4 files changed, 148 insertions(+), 92 deletions(-) create mode 100644 tests/server/fastmcp/test_elicitation.py delete mode 100644 tests/server/fastmcp/test_stdio_elicitation.py diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 200900ee2a..78a73a957f 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -65,7 +65,20 @@ logger = get_logger(__name__) -ElicitedModelT = TypeVar("ElicitedModelT", bound=BaseModel) +ElicitSchemaModelT = TypeVar("ElicitSchemaModelT", bound=BaseModel) + + +class ElicitationResult(BaseModel, Generic[ElicitSchemaModelT]): + """Result of an elicitation request.""" + + action: Literal["accept", "decline", "cancel"] + """The user's action in response to the elicitation.""" + + data: ElicitSchemaModelT | None = None + """The validated data if action is 'accept', None otherwise.""" + + validation_error: str | None = None + """Validation error message if data failed to validate.""" class Settings(BaseSettings, Generic[LifespanResultT]): @@ -977,28 +990,28 @@ async def read_resource(self, uri: str | AnyUrl) -> Iterable[ReadResourceContent async def elicit( self, message: str, - schema: type[ElicitedModelT], - ) -> ElicitedModelT: + schema: type[ElicitSchemaModelT], + ) -> ElicitationResult[ElicitSchemaModelT]: """Elicit information from the client/user. This method can be used to interactively ask for additional information from the client within a tool's execution. The client might display the message to the user and collect a response according to the provided schema. Or in case a - client - is an agent, it might decide how to handle the elicitation -- either by asking + client is an agent, it might decide how to handle the elicitation -- either by asking the user or automatically generating a response. Args: - schema: A Pydantic model class defining the expected response structure + schema: A Pydantic model class defining the expected response structure, according to the specification, + only primive types are allowed. message: Optional message to present to the user. If not provided, will use a default message based on the schema Returns: - An instance of the schema type with the user's response + An ElicitationResult containing the action taken and the data if accepted - Raises: - Exception: If the user declines or cancels the elicitation - ValidationError: If the response doesn't match the schema + Note: + Check the result.action to determine if the user accepted, declined, or cancelled. + The result.data will only be populated if action is "accept" and validation succeeded. """ json_schema = schema.model_json_schema() @@ -1012,13 +1025,12 @@ async def elicit( if result.action == "accept" and result.content: # Validate and parse the content using the schema try: - return schema.model_validate(result.content) + validated_data = schema.model_validate(result.content) + return ElicitationResult(action="accept", data=validated_data) except ValidationError as e: - raise ValueError(f"Invalid response: {e}") - elif result.action == "decline": - raise Exception("User declined to provide information") + return ElicitationResult(action="accept", validation_error=str(e)) else: - raise Exception("User cancelled the request") + return ElicitationResult(action=result.action) async def log( self, diff --git a/tests/server/fastmcp/test_elicitation.py b/tests/server/fastmcp/test_elicitation.py new file mode 100644 index 0000000000..930a6f44c7 --- /dev/null +++ b/tests/server/fastmcp/test_elicitation.py @@ -0,0 +1,101 @@ +""" +Test the elicitation feature using stdio transport. +""" + +import pytest +from pydantic import BaseModel, Field + +from mcp.server.fastmcp import Context, FastMCP +from mcp.shared.memory import create_connected_server_and_client_session +from mcp.types import ElicitResult, TextContent + + +@pytest.mark.anyio +async def test_stdio_elicitation(): + """Test the elicitation feature using stdio transport.""" + + # Create a FastMCP server with a tool that uses elicitation + mcp = FastMCP(name="StdioElicitationServer") + + @mcp.tool(description="A tool that uses elicitation") + async def ask_user(prompt: str, ctx: Context) -> str: + class AnswerSchema(BaseModel): + answer: str = Field(description="The user's answer to the question") + + result = await ctx.elicit( + message=f"Tool wants to ask: {prompt}", + schema=AnswerSchema, + ) + + if result.action == "accept" and result.data: + return f"User answered: {result.data.answer}" + elif result.action == "decline": + return "User declined to answer" + else: + return "User cancelled" + + # Create a custom handler for elicitation requests + async def elicitation_callback(context, params): + # Verify the elicitation parameters + if params.message == "Tool wants to ask: What is your name?": + return ElicitResult(action="accept", content={"answer": "Test User"}) + else: + raise ValueError(f"Unexpected elicitation message: {params.message}") + + # Use memory-based session to test with stdio transport + async with create_connected_server_and_client_session( + mcp._mcp_server, elicitation_callback=elicitation_callback + ) as client_session: + # First initialize the session + result = await client_session.initialize() + assert result.serverInfo.name == "StdioElicitationServer" + + # Call the tool that uses elicitation + tool_result = await client_session.call_tool("ask_user", {"prompt": "What is your name?"}) + + # Verify the result + assert len(tool_result.content) == 1 + assert isinstance(tool_result.content[0], TextContent) + assert tool_result.content[0].text == "User answered: Test User" + + +@pytest.mark.anyio +async def test_stdio_elicitation_decline(): + """Test elicitation with user declining.""" + + mcp = FastMCP(name="StdioElicitationDeclineServer") + + @mcp.tool(description="A tool that uses elicitation") + async def ask_user(prompt: str, ctx: Context) -> str: + class AnswerSchema(BaseModel): + answer: str = Field(description="The user's answer to the question") + + result = await ctx.elicit( + message=f"Tool wants to ask: {prompt}", + schema=AnswerSchema, + ) + + if result.action == "accept" and result.data: + return f"User answered: {result.data.answer}" + elif result.action == "decline": + return "User declined to answer" + else: + return "User cancelled" + + # Create a custom handler that declines + async def elicitation_callback(context, params): + return ElicitResult(action="decline") + + async with create_connected_server_and_client_session( + mcp._mcp_server, elicitation_callback=elicitation_callback + ) as client_session: + # Initialize + await client_session.initialize() + + # Call the tool + tool_result = await client_session.call_tool("ask_user", {"prompt": "What is your name?"}) + + # Verify the result + assert len(tool_result.content) == 1 + assert isinstance(tool_result.content[0], TextContent) + assert tool_result.content[0].text == "User declined to answer" diff --git a/tests/server/fastmcp/test_integration.py b/tests/server/fastmcp/test_integration.py index cb9be6e1ab..4d385b8d28 100644 --- a/tests/server/fastmcp/test_integration.py +++ b/tests/server/fastmcp/test_integration.py @@ -105,12 +105,13 @@ async def ask_user(prompt: str, ctx: Context) -> str: class AnswerSchema(BaseModel): answer: str = Field(description="The user's answer to the question") - try: - result = await ctx.elicit(message=f"Tool wants to ask: {prompt}", schema=AnswerSchema) - return f"User answered: {result.answer}" - except Exception as e: + result = await ctx.elicit(message=f"Tool wants to ask: {prompt}", schema=AnswerSchema) + + if result.action == "accept" and result.data: + return f"User answered: {result.data.answer}" + else: # Handle cancellation or decline - return f"User cancelled or declined: {str(e)}" + return f"User cancelled or declined: {result.action}" # Create the SSE app app = mcp.sse_app() @@ -295,23 +296,25 @@ class AlternativeDateSchema(BaseModel): # For testing: assume dates starting with "2024-12-25" are unavailable if date.startswith("2024-12-25"): # Use elicitation to ask about alternatives - try: - result = await ctx.elicit( - message=( - f"No tables available for {party_size} people on {date} " - f"at {time}. Would you like to check another date?" - ), - schema=AlternativeDateSchema, - ) + result = await ctx.elicit( + message=( + f"No tables available for {party_size} people on {date} " + f"at {time}. Would you like to check another date?" + ), + schema=AlternativeDateSchema, + ) - if result.checkAlternative: - alt_date = result.alternativeDate + if result.action == "accept" and result.data: + if result.data.checkAlternative: + alt_date = result.data.alternativeDate return f"✅ Booked table for {party_size} on {alt_date} at {time}" else: return "❌ No booking made" - except Exception: - # User declined or cancelled + elif result.action in ("decline", "cancel"): return "❌ Booking cancelled" + else: + # Validation error + return f"❌ Invalid input: {result.validation_error}" else: # Available - book directly return f"✅ Booked table for {party_size} on {date} at {time}" diff --git a/tests/server/fastmcp/test_stdio_elicitation.py b/tests/server/fastmcp/test_stdio_elicitation.py deleted file mode 100644 index b6a6a6edb8..0000000000 --- a/tests/server/fastmcp/test_stdio_elicitation.py +++ /dev/null @@ -1,60 +0,0 @@ -""" -Test the elicitation feature using stdio transport. -""" - -import pytest -from pydantic import BaseModel, Field - -from mcp.server.fastmcp import Context, FastMCP -from mcp.shared.memory import create_connected_server_and_client_session -from mcp.types import ElicitResult, TextContent - - -@pytest.mark.anyio -async def test_stdio_elicitation(): - """Test the elicitation feature using stdio transport.""" - - # Create a FastMCP server with a tool that uses elicitation - mcp = FastMCP(name="StdioElicitationServer") - - @mcp.tool(description="A tool that uses elicitation") - async def ask_user(prompt: str, ctx: Context) -> str: - class AnswerSchema(BaseModel): - answer: str = Field(description="The user's answer to the question") - - try: - result = await ctx.elicit( - message=f"Tool wants to ask: {prompt}", - schema=AnswerSchema, - ) - return f"User answered: {result.answer}" - except Exception as e: - # Handle cancellation or decline - if "declined" in str(e): - return "User declined to answer" - else: - return "User cancelled" - - # Create a custom handler for elicitation requests - async def elicitation_callback(context, params): - # Verify the elicitation parameters - if params.message == "Tool wants to ask: What is your name?": - return ElicitResult(action="accept", content={"answer": "Test User"}) - else: - raise ValueError(f"Unexpected elicitation message: {params.message}") - - # Use memory-based session to test with stdio transport - async with create_connected_server_and_client_session( - mcp._mcp_server, elicitation_callback=elicitation_callback - ) as client_session: - # First initialize the session - result = await client_session.initialize() - assert result.serverInfo.name == "StdioElicitationServer" - - # Call the tool that uses elicitation - tool_result = await client_session.call_tool("ask_user", {"prompt": "What is your name?"}) - - # Verify the result - assert len(tool_result.content) == 1 - assert isinstance(tool_result.content[0], TextContent) - assert tool_result.content[0].text == "User answered: Test User" From 4603e89e794af92aab6775c1d973c5e2e491782e Mon Sep 17 00:00:00 2001 From: ihrpr Date: Thu, 12 Jun 2025 09:59:55 +0100 Subject: [PATCH 06/10] add readme --- README.md | 41 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 40 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 38154c878d..a4eba199b5 100644 --- a/README.md +++ b/README.md @@ -31,6 +31,8 @@ - [Images](#images) - [Context](#context) - [Completions](#completions) + - [Elicitation](#elicitation) + - [Authentication](#authentication) - [Running Your Server](#running-your-server) - [Development Mode](#development-mode) - [Claude Desktop Integration](#claude-desktop-integration) @@ -74,7 +76,7 @@ The Model Context Protocol allows applications to provide context for LLMs in a ### Adding MCP to your python project -We recommend using [uv](https://docs.astral.sh/uv/) to manage your Python projects. +We recommend using [uv](https://docs.astral.sh/uv/) to manage your Python projects. If you haven't created a uv-managed project yet, create one: @@ -372,6 +374,43 @@ async def handle_completion( return Completion(values=filtered) return None ``` +### Elicitation + +Request additional information from users during tool execution: + +```python +from mcp.server.fastmcp import FastMCP, Context +from pydantic import BaseModel, Field + +mcp = FastMCP("Booking System") + + +@mcp.tool() +async def book_table(date: str, party_size: int, ctx: Context) -> str: + """Book a table with confirmation""" + + class ConfirmBooking(BaseModel): + confirm: bool = Field(description="Confirm booking?") + notes: str = Field(default="", description="Special requests") + + result = await ctx.elicit( + message=f"Confirm booking for {party_size} on {date}?", + schema=ConfirmBooking + ) + + if result.action == "accept" and result.data: + if result.data.confirm: + return f"Booked! Notes: {result.data.notes or 'None'}" + return "Booking cancelled" + + # User declined or cancelled + return f"Booking {result.action}" +``` + +The `elicit()` method returns an `ElicitationResult` with: +- `action`: "accept", "decline", or "cancel" +- `data`: The validated response (only when accepted) +- `validation_error`: Any validation error message ### Authentication From 653c057759a8787a48ccb4393c3b10d3c11bfa9c Mon Sep 17 00:00:00 2001 From: ihrpr Date: Thu, 12 Jun 2025 10:54:35 +0100 Subject: [PATCH 07/10] add validation for primitive types --- README.md | 4 +- src/mcp/server/fastmcp/server.py | 50 +++++- tests/server/fastmcp/test_elicitation.py | 219 +++++++++++++++++------ 3 files changed, 212 insertions(+), 61 deletions(-) diff --git a/README.md b/README.md index a4eba199b5..c470596ff2 100644 --- a/README.md +++ b/README.md @@ -389,13 +389,13 @@ mcp = FastMCP("Booking System") async def book_table(date: str, party_size: int, ctx: Context) -> str: """Book a table with confirmation""" + # Schema must only contain primitive types (str, int, float, bool) class ConfirmBooking(BaseModel): confirm: bool = Field(description="Confirm booking?") notes: str = Field(default="", description="Special requests") result = await ctx.elicit( - message=f"Confirm booking for {party_size} on {date}?", - schema=ConfirmBooking + message=f"Confirm booking for {party_size} on {date}?", schema=ConfirmBooking ) if result.action == "accept" and result.data: diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 78a73a957f..f95907458b 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -4,17 +4,19 @@ import inspect import re +import types from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Sequence from contextlib import ( AbstractAsyncContextManager, asynccontextmanager, ) from itertools import chain -from typing import Any, Generic, Literal, TypeVar +from typing import Any, Generic, Literal, TypeVar, Union, get_args, get_origin import anyio import pydantic_core from pydantic import BaseModel, Field, ValidationError +from pydantic.fields import FieldInfo from pydantic.networks import AnyUrl from pydantic_settings import BaseSettings, SettingsConfigDict from starlette.applications import Starlette @@ -70,13 +72,13 @@ class ElicitationResult(BaseModel, Generic[ElicitSchemaModelT]): """Result of an elicitation request.""" - + action: Literal["accept", "decline", "cancel"] """The user's action in response to the elicitation.""" - + data: ElicitSchemaModelT | None = None """The validated data if action is 'accept', None otherwise.""" - + validation_error: str | None = None """Validation error message if data failed to validate.""" @@ -891,6 +893,43 @@ def _convert_to_content( return [TextContent(type="text", text=result)] +def _validate_elicitation_schema(schema: type[BaseModel]) -> None: + """Validate that a Pydantic model only contains primitive field types.""" + for field_name, field_info in schema.model_fields.items(): + if not _is_primitive_field(field_info): + raise TypeError( + f"Elicitation schema field '{field_name}' must be a primitive type " + f"{_ELICITATION_PRIMITIVE_TYPES} or Optional of these types. " + f"Complex types like lists, dicts, or nested models are not allowed." + ) + + +# Primitive types allowed in elicitation schemas +_ELICITATION_PRIMITIVE_TYPES = (str, int, float, bool) + + +def _is_primitive_field(field_info: FieldInfo) -> bool: + """Check if a field is a primitive type allowed in elicitation schemas.""" + annotation = field_info.annotation + + # Handle None type + if annotation is type(None): + return True + + # Handle basic primitive types + if annotation in _ELICITATION_PRIMITIVE_TYPES: + return True + + # Handle Union types (including Optional and Python 3.10+ union syntax) + origin = get_origin(annotation) + if origin is Union or (hasattr(types, 'UnionType') and isinstance(annotation, types.UnionType)): + args = get_args(annotation) + # All args must be primitive types or None + return all(arg is type(None) or arg in _ELICITATION_PRIMITIVE_TYPES for arg in args) + + return False + + class Context(BaseModel, Generic[ServerSessionT, LifespanContextT, RequestT]): """Context object providing access to MCP capabilities. @@ -1014,6 +1053,9 @@ async def elicit( The result.data will only be populated if action is "accept" and validation succeeded. """ + # Validate that schema only contains primitive types and fail loudly if not + _validate_elicitation_schema(schema) + json_schema = schema.model_json_schema() result = await self.request_context.session.elicit( diff --git a/tests/server/fastmcp/test_elicitation.py b/tests/server/fastmcp/test_elicitation.py index 930a6f44c7..20937d91dc 100644 --- a/tests/server/fastmcp/test_elicitation.py +++ b/tests/server/fastmcp/test_elicitation.py @@ -10,23 +10,21 @@ from mcp.types import ElicitResult, TextContent -@pytest.mark.anyio -async def test_stdio_elicitation(): - """Test the elicitation feature using stdio transport.""" +# Shared schema for basic tests +class AnswerSchema(BaseModel): + answer: str = Field(description="The user's answer to the question") - # Create a FastMCP server with a tool that uses elicitation - mcp = FastMCP(name="StdioElicitationServer") + +def create_ask_user_tool(mcp: FastMCP): + """Create a standard ask_user tool that handles all elicitation responses.""" @mcp.tool(description="A tool that uses elicitation") async def ask_user(prompt: str, ctx: Context) -> str: - class AnswerSchema(BaseModel): - answer: str = Field(description="The user's answer to the question") - result = await ctx.elicit( message=f"Tool wants to ask: {prompt}", schema=AnswerSchema, ) - + if result.action == "accept" and result.data: return f"User answered: {result.data.answer}" elif result.action == "decline": @@ -34,68 +32,179 @@ class AnswerSchema(BaseModel): else: return "User cancelled" + return ask_user + + +async def call_tool_and_assert( + mcp: FastMCP, + elicitation_callback, + tool_name: str, + args: dict, + expected_text: str | None = None, + text_contains: list[str] | None = None, +): + """Helper to create session, call tool, and assert result.""" + async with create_connected_server_and_client_session( + mcp._mcp_server, elicitation_callback=elicitation_callback + ) as client_session: + await client_session.initialize() + + result = await client_session.call_tool(tool_name, args) + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + + if expected_text is not None: + assert result.content[0].text == expected_text + elif text_contains is not None: + for substring in text_contains: + assert substring in result.content[0].text + + return result + + +@pytest.mark.anyio +async def test_stdio_elicitation(): + """Test the elicitation feature using stdio transport.""" + mcp = FastMCP(name="StdioElicitationServer") + create_ask_user_tool(mcp) + # Create a custom handler for elicitation requests async def elicitation_callback(context, params): - # Verify the elicitation parameters if params.message == "Tool wants to ask: What is your name?": return ElicitResult(action="accept", content={"answer": "Test User"}) else: raise ValueError(f"Unexpected elicitation message: {params.message}") - # Use memory-based session to test with stdio transport - async with create_connected_server_and_client_session( - mcp._mcp_server, elicitation_callback=elicitation_callback - ) as client_session: - # First initialize the session - result = await client_session.initialize() - assert result.serverInfo.name == "StdioElicitationServer" - - # Call the tool that uses elicitation - tool_result = await client_session.call_tool("ask_user", {"prompt": "What is your name?"}) - - # Verify the result - assert len(tool_result.content) == 1 - assert isinstance(tool_result.content[0], TextContent) - assert tool_result.content[0].text == "User answered: Test User" + await call_tool_and_assert( + mcp, elicitation_callback, "ask_user", {"prompt": "What is your name?"}, "User answered: Test User" + ) @pytest.mark.anyio async def test_stdio_elicitation_decline(): """Test elicitation with user declining.""" - mcp = FastMCP(name="StdioElicitationDeclineServer") - - @mcp.tool(description="A tool that uses elicitation") - async def ask_user(prompt: str, ctx: Context) -> str: - class AnswerSchema(BaseModel): - answer: str = Field(description="The user's answer to the question") - - result = await ctx.elicit( - message=f"Tool wants to ask: {prompt}", - schema=AnswerSchema, - ) - - if result.action == "accept" and result.data: - return f"User answered: {result.data.answer}" - elif result.action == "decline": - return "User declined to answer" - else: - return "User cancelled" - - # Create a custom handler that declines + create_ask_user_tool(mcp) + async def elicitation_callback(context, params): return ElicitResult(action="decline") - + + await call_tool_and_assert( + mcp, elicitation_callback, "ask_user", {"prompt": "What is your name?"}, "User declined to answer" + ) + + +@pytest.mark.anyio +async def test_elicitation_schema_validation(): + """Test that elicitation schemas must only contain primitive types.""" + mcp = FastMCP(name="ValidationTestServer") + + def create_validation_tool(name: str, schema_class: type[BaseModel]): + @mcp.tool(name=name, description=f"Tool testing {name}") + async def tool(ctx: Context) -> str: + try: + await ctx.elicit(message="This should fail validation", schema=schema_class) + return "Should not reach here" + except TypeError as e: + return f"Validation failed as expected: {str(e)}" + + return tool + + # Test cases for invalid schemas + class InvalidListSchema(BaseModel): + names: list[str] = Field(description="List of names") + + class NestedModel(BaseModel): + value: str + + class InvalidNestedSchema(BaseModel): + nested: NestedModel = Field(description="Nested model") + + create_validation_tool("invalid_list", InvalidListSchema) + create_validation_tool("nested_model", InvalidNestedSchema) + + # Dummy callback (won't be called due to validation failure) + async def elicitation_callback(context, params): + return ElicitResult(action="accept", content={}) + async with create_connected_server_and_client_session( mcp._mcp_server, elicitation_callback=elicitation_callback ) as client_session: - # Initialize await client_session.initialize() - - # Call the tool - tool_result = await client_session.call_tool("ask_user", {"prompt": "What is your name?"}) - - # Verify the result - assert len(tool_result.content) == 1 - assert isinstance(tool_result.content[0], TextContent) - assert tool_result.content[0].text == "User declined to answer" + + # Test both invalid schemas + for tool_name, field_name in [("invalid_list", "names"), ("nested_model", "nested")]: + result = await client_session.call_tool(tool_name, {}) + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert "Validation failed as expected" in result.content[0].text + assert field_name in result.content[0].text + + +@pytest.mark.anyio +async def test_elicitation_with_optional_fields(): + """Test that Optional fields work correctly in elicitation schemas.""" + mcp = FastMCP(name="OptionalFieldServer") + + class OptionalSchema(BaseModel): + required_name: str = Field(description="Your name (required)") + optional_age: int | None = Field(default=None, description="Your age (optional)") + optional_email: str | None = Field(default=None, description="Your email (optional)") + subscribe: bool | None = Field(default=False, description="Subscribe to newsletter?") + + @mcp.tool(description="Tool with optional fields") + async def optional_tool(ctx: Context) -> str: + result = await ctx.elicit(message="Please provide your information", schema=OptionalSchema) + + if result.action == "accept" and result.data: + info = [f"Name: {result.data.required_name}"] + if result.data.optional_age is not None: + info.append(f"Age: {result.data.optional_age}") + if result.data.optional_email is not None: + info.append(f"Email: {result.data.optional_email}") + info.append(f"Subscribe: {result.data.subscribe}") + return ", ".join(info) + else: + return f"User {result.action}" + + # Test cases with different field combinations + test_cases = [ + ( + # All fields provided + {"required_name": "John Doe", "optional_age": 30, "optional_email": "john@example.com", "subscribe": True}, + "Name: John Doe, Age: 30, Email: john@example.com, Subscribe: True", + ), + ( + # Only required fields + {"required_name": "Jane Smith"}, + "Name: Jane Smith, Subscribe: False", + ), + ] + + for content, expected in test_cases: + + async def callback(context, params): + return ElicitResult(action="accept", content=content) + + await call_tool_and_assert(mcp, callback, "optional_tool", {}, expected) + + # Test invalid optional field + class InvalidOptionalSchema(BaseModel): + name: str = Field(description="Name") + optional_list: list[str] | None = Field(default=None, description="Invalid optional list") + + @mcp.tool(description="Tool with invalid optional field") + async def invalid_optional_tool(ctx: Context) -> str: + try: + await ctx.elicit(message="This should fail", schema=InvalidOptionalSchema) + return "Should not reach here" + except TypeError as e: + return f"Validation failed: {str(e)}" + + await call_tool_and_assert( + mcp, + lambda c, p: ElicitResult(action="accept", content={}), + "invalid_optional_tool", + {}, + text_contains=["Validation failed:", "optional_list"], + ) From d4ae036379b862f02a2de8895c6e0293a9ad54d1 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Thu, 12 Jun 2025 11:00:25 +0100 Subject: [PATCH 08/10] cleanup --- src/mcp/server/fastmcp/server.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index f95907458b..da1645795d 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -893,6 +893,10 @@ def _convert_to_content( return [TextContent(type="text", text=result)] +# Primitive types allowed in elicitation schemas +_ELICITATION_PRIMITIVE_TYPES = (str, int, float, bool) + + def _validate_elicitation_schema(schema: type[BaseModel]) -> None: """Validate that a Pydantic model only contains primitive field types.""" for field_name, field_info in schema.model_fields.items(): @@ -904,28 +908,24 @@ def _validate_elicitation_schema(schema: type[BaseModel]) -> None: ) -# Primitive types allowed in elicitation schemas -_ELICITATION_PRIMITIVE_TYPES = (str, int, float, bool) - - def _is_primitive_field(field_info: FieldInfo) -> bool: """Check if a field is a primitive type allowed in elicitation schemas.""" annotation = field_info.annotation # Handle None type - if annotation is type(None): + if annotation is types.NoneType: return True # Handle basic primitive types if annotation in _ELICITATION_PRIMITIVE_TYPES: return True - # Handle Union types (including Optional and Python 3.10+ union syntax) + # Handle Union types origin = get_origin(annotation) - if origin is Union or (hasattr(types, 'UnionType') and isinstance(annotation, types.UnionType)): + if origin is Union or origin is types.UnionType: args = get_args(annotation) # All args must be primitive types or None - return all(arg is type(None) or arg in _ELICITATION_PRIMITIVE_TYPES for arg in args) + return all(arg is types.NoneType or arg in _ELICITATION_PRIMITIVE_TYPES for arg in args) return False From 67bfd9acd136d0647ba4b12a6656a242e107dcdf Mon Sep 17 00:00:00 2001 From: ihrpr Date: Thu, 12 Jun 2025 11:15:49 +0100 Subject: [PATCH 09/10] format --- tests/server/fastmcp/test_integration.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/server/fastmcp/test_integration.py b/tests/server/fastmcp/test_integration.py index 4d385b8d28..b696e1d909 100644 --- a/tests/server/fastmcp/test_integration.py +++ b/tests/server/fastmcp/test_integration.py @@ -106,7 +106,7 @@ class AnswerSchema(BaseModel): answer: str = Field(description="The user's answer to the question") result = await ctx.elicit(message=f"Tool wants to ask: {prompt}", schema=AnswerSchema) - + if result.action == "accept" and result.data: return f"User answered: {result.data.answer}" else: From 51b5ee86fa9d21d632cf25335d5626dbda10631b Mon Sep 17 00:00:00 2001 From: David Soria Parra Date: Mon, 16 Jun 2025 22:38:38 +0100 Subject: [PATCH 10/10] Update --- README.md | 21 +++-- src/mcp/server/elicitation.py | 111 +++++++++++++++++++++++ src/mcp/server/fastmcp/server.py | 80 +--------------- tests/server/fastmcp/test_integration.py | 4 +- 4 files changed, 132 insertions(+), 84 deletions(-) create mode 100644 src/mcp/server/elicitation.py diff --git a/README.md b/README.md index c470596ff2..1bdae1d203 100644 --- a/README.md +++ b/README.md @@ -380,6 +380,11 @@ Request additional information from users during tool execution: ```python from mcp.server.fastmcp import FastMCP, Context +from mcp.server.elicitation import ( + AcceptedElicitation, + DeclinedElicitation, + CancelledElicitation, +) from pydantic import BaseModel, Field mcp = FastMCP("Booking System") @@ -398,13 +403,15 @@ async def book_table(date: str, party_size: int, ctx: Context) -> str: message=f"Confirm booking for {party_size} on {date}?", schema=ConfirmBooking ) - if result.action == "accept" and result.data: - if result.data.confirm: - return f"Booked! Notes: {result.data.notes or 'None'}" - return "Booking cancelled" - - # User declined or cancelled - return f"Booking {result.action}" + match result: + case AcceptedElicitation(data=data): + if data.confirm: + return f"Booked! Notes: {data.notes or 'None'}" + return "Booking cancelled" + case DeclinedElicitation(): + return "Booking declined" + case CancelledElicitation(): + return "Booking cancelled" ``` The `elicit()` method returns an `ElicitationResult` with: diff --git a/src/mcp/server/elicitation.py b/src/mcp/server/elicitation.py new file mode 100644 index 0000000000..1e48738c84 --- /dev/null +++ b/src/mcp/server/elicitation.py @@ -0,0 +1,111 @@ +"""Elicitation utilities for MCP servers.""" + +from __future__ import annotations + +import types +from typing import Generic, Literal, TypeVar, Union, get_args, get_origin + +from pydantic import BaseModel +from pydantic.fields import FieldInfo + +from mcp.server.session import ServerSession +from mcp.types import RequestId + +ElicitSchemaModelT = TypeVar("ElicitSchemaModelT", bound=BaseModel) + + +class AcceptedElicitation(BaseModel, Generic[ElicitSchemaModelT]): + """Result when user accepts the elicitation.""" + + action: Literal["accept"] = "accept" + data: ElicitSchemaModelT + + +class DeclinedElicitation(BaseModel): + """Result when user declines the elicitation.""" + + action: Literal["decline"] = "decline" + + +class CancelledElicitation(BaseModel): + """Result when user cancels the elicitation.""" + + action: Literal["cancel"] = "cancel" + + +ElicitationResult = AcceptedElicitation[ElicitSchemaModelT] | DeclinedElicitation | CancelledElicitation + + +# Primitive types allowed in elicitation schemas +_ELICITATION_PRIMITIVE_TYPES = (str, int, float, bool) + + +def _validate_elicitation_schema(schema: type[BaseModel]) -> None: + """Validate that a Pydantic model only contains primitive field types.""" + for field_name, field_info in schema.model_fields.items(): + if not _is_primitive_field(field_info): + raise TypeError( + f"Elicitation schema field '{field_name}' must be a primitive type " + f"{_ELICITATION_PRIMITIVE_TYPES} or Optional of these types. " + f"Complex types like lists, dicts, or nested models are not allowed." + ) + + +def _is_primitive_field(field_info: FieldInfo) -> bool: + """Check if a field is a primitive type allowed in elicitation schemas.""" + annotation = field_info.annotation + + # Handle None type + if annotation is types.NoneType: + return True + + # Handle basic primitive types + if annotation in _ELICITATION_PRIMITIVE_TYPES: + return True + + # Handle Union types + origin = get_origin(annotation) + if origin is Union or origin is types.UnionType: + args = get_args(annotation) + # All args must be primitive types or None + return all(arg is types.NoneType or arg in _ELICITATION_PRIMITIVE_TYPES for arg in args) + + return False + + +async def elicit_with_validation( + session: ServerSession, + message: str, + schema: type[ElicitSchemaModelT], + related_request_id: RequestId | None = None, +) -> ElicitationResult[ElicitSchemaModelT]: + """Elicit information from the client/user with schema validation. + + This method can be used to interactively ask for additional information from the + client within a tool's execution. The client might display the message to the + user and collect a response according to the provided schema. Or in case a + client is an agent, it might decide how to handle the elicitation -- either by asking + the user or automatically generating a response. + """ + # Validate that schema only contains primitive types and fail loudly if not + _validate_elicitation_schema(schema) + + json_schema = schema.model_json_schema() + + result = await session.elicit( + message=message, + requestedSchema=json_schema, + related_request_id=related_request_id, + ) + + if result.action == "accept" and result.content: + # Validate and parse the content using the schema + validated_data = schema.model_validate(result.content) + return AcceptedElicitation(data=validated_data) + elif result.action == "decline": + return DeclinedElicitation() + elif result.action == "cancel": + return CancelledElicitation() + else: + # This should never happen, but handle it just in case + raise ValueError(f"Unexpected elicitation action: {result.action}") diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index da1645795d..a85c7117c9 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -4,19 +4,17 @@ import inspect import re -import types from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Sequence from contextlib import ( AbstractAsyncContextManager, asynccontextmanager, ) from itertools import chain -from typing import Any, Generic, Literal, TypeVar, Union, get_args, get_origin +from typing import Any, Generic, Literal import anyio import pydantic_core -from pydantic import BaseModel, Field, ValidationError -from pydantic.fields import FieldInfo +from pydantic import BaseModel, Field from pydantic.networks import AnyUrl from pydantic_settings import BaseSettings, SettingsConfigDict from starlette.applications import Starlette @@ -36,6 +34,7 @@ from mcp.server.auth.settings import ( AuthSettings, ) +from mcp.server.elicitation import ElicitationResult, ElicitSchemaModelT, elicit_with_validation from mcp.server.fastmcp.exceptions import ResourceError from mcp.server.fastmcp.prompts import Prompt, PromptManager from mcp.server.fastmcp.resources import FunctionResource, Resource, ResourceManager @@ -67,21 +66,6 @@ logger = get_logger(__name__) -ElicitSchemaModelT = TypeVar("ElicitSchemaModelT", bound=BaseModel) - - -class ElicitationResult(BaseModel, Generic[ElicitSchemaModelT]): - """Result of an elicitation request.""" - - action: Literal["accept", "decline", "cancel"] - """The user's action in response to the elicitation.""" - - data: ElicitSchemaModelT | None = None - """The validated data if action is 'accept', None otherwise.""" - - validation_error: str | None = None - """Validation error message if data failed to validate.""" - class Settings(BaseSettings, Generic[LifespanResultT]): """FastMCP server settings. @@ -893,43 +877,6 @@ def _convert_to_content( return [TextContent(type="text", text=result)] -# Primitive types allowed in elicitation schemas -_ELICITATION_PRIMITIVE_TYPES = (str, int, float, bool) - - -def _validate_elicitation_schema(schema: type[BaseModel]) -> None: - """Validate that a Pydantic model only contains primitive field types.""" - for field_name, field_info in schema.model_fields.items(): - if not _is_primitive_field(field_info): - raise TypeError( - f"Elicitation schema field '{field_name}' must be a primitive type " - f"{_ELICITATION_PRIMITIVE_TYPES} or Optional of these types. " - f"Complex types like lists, dicts, or nested models are not allowed." - ) - - -def _is_primitive_field(field_info: FieldInfo) -> bool: - """Check if a field is a primitive type allowed in elicitation schemas.""" - annotation = field_info.annotation - - # Handle None type - if annotation is types.NoneType: - return True - - # Handle basic primitive types - if annotation in _ELICITATION_PRIMITIVE_TYPES: - return True - - # Handle Union types - origin = get_origin(annotation) - if origin is Union or origin is types.UnionType: - args = get_args(annotation) - # All args must be primitive types or None - return all(arg is types.NoneType or arg in _ELICITATION_PRIMITIVE_TYPES for arg in args) - - return False - - class Context(BaseModel, Generic[ServerSessionT, LifespanContextT, RequestT]): """Context object providing access to MCP capabilities. @@ -1053,27 +1000,10 @@ async def elicit( The result.data will only be populated if action is "accept" and validation succeeded. """ - # Validate that schema only contains primitive types and fail loudly if not - _validate_elicitation_schema(schema) - - json_schema = schema.model_json_schema() - - result = await self.request_context.session.elicit( - message=message, - requestedSchema=json_schema, - related_request_id=self.request_id, + return await elicit_with_validation( + session=self.request_context.session, message=message, schema=schema, related_request_id=self.request_id ) - if result.action == "accept" and result.content: - # Validate and parse the content using the schema - try: - validated_data = schema.model_validate(result.content) - return ElicitationResult(action="accept", data=validated_data) - except ValidationError as e: - return ElicitationResult(action="accept", validation_error=str(e)) - else: - return ElicitationResult(action=result.action) - async def log( self, level: Literal["debug", "info", "warning", "error"], diff --git a/tests/server/fastmcp/test_integration.py b/tests/server/fastmcp/test_integration.py index b696e1d909..3eb0139467 100644 --- a/tests/server/fastmcp/test_integration.py +++ b/tests/server/fastmcp/test_integration.py @@ -313,8 +313,8 @@ class AlternativeDateSchema(BaseModel): elif result.action in ("decline", "cancel"): return "❌ Booking cancelled" else: - # Validation error - return f"❌ Invalid input: {result.validation_error}" + # Handle case where action is "accept" but data is None + return "❌ No booking data received" else: # Available - book directly return f"✅ Booked table for {party_size} on {date} at {time}"