Skip to content

Commit aabcecf

Browse files
Tapan Chughfelixweinberger
authored andcommitted
SEP: Elicitation Enum Schema Improvements and Standards Compliance
1 parent c51936f commit aabcecf

File tree

3 files changed

+134
-8
lines changed

3 files changed

+134
-8
lines changed

src/mcp/server/elicitation.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from __future__ import annotations
44

55
import types
6+
from collections.abc import Sequence
67
from typing import Generic, Literal, TypeVar, Union, get_args, get_origin
78

89
from pydantic import BaseModel
@@ -46,11 +47,22 @@ def _validate_elicitation_schema(schema: type[BaseModel]) -> None:
4647
if not _is_primitive_field(field_info):
4748
raise TypeError(
4849
f"Elicitation schema field '{field_name}' must be a primitive type "
49-
f"{_ELICITATION_PRIMITIVE_TYPES} or Optional of these types. "
50-
f"Complex types like lists, dicts, or nested models are not allowed."
50+
f"{_ELICITATION_PRIMITIVE_TYPES}, a sequence of strings (list[str], etc.), "
51+
f"or Optional of these types. Nested models and complex types are not allowed."
5152
)
5253

5354

55+
def _is_string_sequence(annotation: type) -> bool:
56+
"""Check if annotation is a sequence of strings (list[str], Sequence[str], etc)."""
57+
origin = get_origin(annotation)
58+
# Check if it's a sequence-like type with str elements
59+
if origin and issubclass(origin, Sequence):
60+
args = get_args(annotation)
61+
# Should have single str type arg
62+
return len(args) == 1 and args[0] is str
63+
return False
64+
65+
5466
def _is_primitive_field(field_info: FieldInfo) -> bool:
5567
"""Check if a field is a primitive type allowed in elicitation schemas."""
5668
annotation = field_info.annotation
@@ -63,12 +75,21 @@ def _is_primitive_field(field_info: FieldInfo) -> bool:
6375
if annotation in _ELICITATION_PRIMITIVE_TYPES:
6476
return True
6577

78+
# Handle string sequences for multi-select enums
79+
if annotation is not None and _is_string_sequence(annotation):
80+
return True
81+
6682
# Handle Union types
6783
origin = get_origin(annotation)
6884
if origin is Union or origin is types.UnionType:
6985
args = get_args(annotation)
70-
# All args must be primitive types or None
71-
return all(arg is types.NoneType or arg in _ELICITATION_PRIMITIVE_TYPES for arg in args)
86+
# All args must be primitive types, None, or string sequences
87+
return all(
88+
arg is types.NoneType
89+
or arg in _ELICITATION_PRIMITIVE_TYPES
90+
or (arg is not None and _is_string_sequence(arg))
91+
for arg in args
92+
)
7293

7394
return False
7495

src/mcp/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1304,7 +1304,7 @@ class ElicitResult(Result):
13041304
- "cancel": User dismissed without making an explicit choice
13051305
"""
13061306

1307-
content: dict[str, str | int | float | bool | None] | None = None
1307+
content: dict[str, str | int | float | bool | list[str] | None] | None = None
13081308
"""
13091309
The submitted form data, only present when action is "accept".
13101310
Contains values matching the requested schema.

tests/server/fastmcp/test_elicitation.py

Lines changed: 108 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ async def tool(ctx: Context[ServerSession, None]) -> str: # pragma: no cover
116116

117117
# Test cases for invalid schemas
118118
class InvalidListSchema(BaseModel):
119-
names: list[str] = Field(description="List of names")
119+
numbers: list[int] = Field(description="List of numbers")
120120

121121
class NestedModel(BaseModel):
122122
value: str
@@ -139,7 +139,7 @@ async def elicitation_callback(
139139
await client_session.initialize()
140140

141141
# Test both invalid schemas
142-
for tool_name, field_name in [("invalid_list", "names"), ("nested_model", "nested")]:
142+
for tool_name, field_name in [("invalid_list", "numbers"), ("nested_model", "nested")]:
143143
result = await client_session.call_tool(tool_name, {})
144144
assert len(result.content) == 1
145145
assert isinstance(result.content[0], TextContent)
@@ -197,7 +197,7 @@ async def callback(context: RequestContext[ClientSession, None], params: ElicitR
197197
# Test invalid optional field
198198
class InvalidOptionalSchema(BaseModel):
199199
name: str = Field(description="Name")
200-
optional_list: list[str] | None = Field(default=None, description="Invalid optional list")
200+
optional_list: list[int] | None = Field(default=None, description="Invalid optional list")
201201

202202
@mcp.tool(description="Tool with invalid optional field")
203203
async def invalid_optional_tool(ctx: Context[ServerSession, None]) -> str: # pragma: no cover
@@ -220,6 +220,25 @@ async def elicitation_callback(
220220
text_contains=["Validation failed:", "optional_list"],
221221
)
222222

223+
# Test valid list[str] for multi-select enum
224+
class ValidMultiSelectSchema(BaseModel):
225+
name: str = Field(description="Name")
226+
tags: list[str] = Field(description="Tags")
227+
228+
@mcp.tool(description="Tool with valid list[str] field")
229+
async def valid_multiselect_tool(ctx: Context[ServerSession, None]) -> str:
230+
result = await ctx.elicit(message="Please provide tags", schema=ValidMultiSelectSchema)
231+
if result.action == "accept" and result.data:
232+
return f"Name: {result.data.name}, Tags: {', '.join(result.data.tags)}"
233+
return f"User {result.action}"
234+
235+
async def multiselect_callback(context: RequestContext[ClientSession, Any], params: ElicitRequestParams):
236+
if "Please provide tags" in params.message:
237+
return ElicitResult(action="accept", content={"name": "Test", "tags": ["tag1", "tag2"]})
238+
return ElicitResult(action="decline")
239+
240+
await call_tool_and_assert(mcp, multiselect_callback, "valid_multiselect_tool", {}, "Name: Test, Tags: tag1, tag2")
241+
223242

224243
@pytest.mark.anyio
225244
async def test_elicitation_with_default_values():
@@ -274,3 +293,89 @@ async def callback_override(context: RequestContext[ClientSession, None], params
274293
await call_tool_and_assert(
275294
mcp, callback_override, "defaults_tool", {}, "Name: John, Age: 25, Subscribe: False, Email: john@example.com"
276295
)
296+
297+
298+
@pytest.mark.anyio
299+
async def test_elicitation_with_enum_titles():
300+
"""Test elicitation with enum schemas using oneOf/anyOf for titles."""
301+
mcp = FastMCP(name="ColorPreferencesApp")
302+
303+
# Test single-select with titles using oneOf
304+
class FavoriteColorSchema(BaseModel):
305+
user_name: str = Field(description="Your name")
306+
favorite_color: str = Field(
307+
description="Select your favorite color",
308+
json_schema_extra={
309+
"oneOf": [
310+
{"const": "red", "title": "Red"},
311+
{"const": "green", "title": "Green"},
312+
{"const": "blue", "title": "Blue"},
313+
{"const": "yellow", "title": "Yellow"},
314+
]
315+
},
316+
)
317+
318+
@mcp.tool(description="Single color selection")
319+
async def select_favorite_color(ctx: Context) -> str:
320+
result = await ctx.elicit(message="Select your favorite color", schema=FavoriteColorSchema)
321+
if result.action == "accept" and result.data:
322+
return f"User: {result.data.user_name}, Favorite: {result.data.favorite_color}"
323+
return f"User {result.action}"
324+
325+
# Test multi-select with titles using anyOf
326+
class FavoriteColorsSchema(BaseModel):
327+
user_name: str = Field(description="Your name")
328+
favorite_colors: list[str] = Field(
329+
description="Select your favorite colors",
330+
json_schema_extra={
331+
"items": {
332+
"anyOf": [
333+
{"const": "red", "title": "Red"},
334+
{"const": "green", "title": "Green"},
335+
{"const": "blue", "title": "Blue"},
336+
{"const": "yellow", "title": "Yellow"},
337+
]
338+
}
339+
},
340+
)
341+
342+
@mcp.tool(description="Multiple color selection")
343+
async def select_favorite_colors(ctx: Context) -> str:
344+
result = await ctx.elicit(message="Select your favorite colors", schema=FavoriteColorsSchema)
345+
if result.action == "accept" and result.data:
346+
return f"User: {result.data.user_name}, Colors: {', '.join(result.data.favorite_colors)}"
347+
return f"User {result.action}"
348+
349+
# Test deprecated enumNames format
350+
class DeprecatedColorSchema(BaseModel):
351+
user_name: str = Field(description="Your name")
352+
color: str = Field(
353+
description="Select a color",
354+
json_schema_extra={"enum": ["red", "green", "blue"], "enumNames": ["Red", "Green", "Blue"]},
355+
)
356+
357+
@mcp.tool(description="Deprecated enum format")
358+
async def select_color_deprecated(ctx: Context) -> str:
359+
result = await ctx.elicit(message="Select a color (deprecated format)", schema=DeprecatedColorSchema)
360+
if result.action == "accept" and result.data:
361+
return f"User: {result.data.user_name}, Color: {result.data.color}"
362+
return f"User {result.action}"
363+
364+
async def enum_callback(context, params):
365+
if "colors" in params.message and "deprecated" not in params.message:
366+
return ElicitResult(action="accept", content={"user_name": "Bob", "favorite_colors": ["red", "green"]})
367+
elif "color" in params.message:
368+
if "deprecated" in params.message:
369+
return ElicitResult(action="accept", content={"user_name": "Charlie", "color": "green"})
370+
else:
371+
return ElicitResult(action="accept", content={"user_name": "Alice", "favorite_color": "blue"})
372+
return ElicitResult(action="decline")
373+
374+
# Test single-select with titles
375+
await call_tool_and_assert(mcp, enum_callback, "select_favorite_color", {}, "User: Alice, Favorite: blue")
376+
377+
# Test multi-select with titles
378+
await call_tool_and_assert(mcp, enum_callback, "select_favorite_colors", {}, "User: Bob, Colors: red, green")
379+
380+
# Test deprecated enumNames format
381+
await call_tool_and_assert(mcp, enum_callback, "select_color_deprecated", {}, "User: Charlie, Color: green")

0 commit comments

Comments
 (0)