Skip to content

Commit 78f270c

Browse files
committed
fix: Support string enum for elicitation
1 parent 06748eb commit 78f270c

File tree

2 files changed

+40
-1
lines changed

2 files changed

+40
-1
lines changed

src/mcp/server/elicitation.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import types
66
from collections.abc import Sequence
7+
from enum import Enum, StrEnum
78
from typing import Generic, Literal, TypeVar, Union, get_args, get_origin
89

910
from pydantic import BaseModel
@@ -46,7 +47,7 @@ class AcceptedUrlElicitation(BaseModel):
4647

4748

4849
# Primitive types allowed in elicitation schemas
49-
_ELICITATION_PRIMITIVE_TYPES = (str, int, float, bool)
50+
_ELICITATION_PRIMITIVE_TYPES = (str, int, float, bool, StrEnum)
5051

5152

5253
def _validate_elicitation_schema(schema: type[BaseModel]) -> None:
@@ -99,6 +100,10 @@ def _is_primitive_field(annotation: type) -> bool:
99100
arg is types.NoneType or arg in _ELICITATION_PRIMITIVE_TYPES or _is_string_sequence(arg) for arg in args
100101
)
101102

103+
# Handle Enum types
104+
if issubclass(annotation, str) and issubclass(annotation, Enum):
105+
return True
106+
102107
return False
103108

104109

tests/server/fastmcp/test_elicitation.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Test the elicitation feature using stdio transport.
33
"""
44

5+
from enum import StrEnum
56
from typing import Any
67

78
import pytest
@@ -147,6 +148,39 @@ async def elicitation_callback(
147148
assert "Validation failed as expected" in result.content[0].text
148149
assert field_name in result.content[0].text
149150

151+
# Test valid Enum types (should not fail validation)
152+
class Status(StrEnum):
153+
ACTIVE = "active"
154+
INACTIVE = "inactive"
155+
156+
class ValidStrEnumSchema(BaseModel):
157+
status: Status = Field(description="Status using StrEnum")
158+
159+
def create_valid_validation_tool(name: str, schema_class: type[BaseModel]):
160+
@mcp.tool(name=name, description=f"Tool testing {name}")
161+
async def tool(ctx: Context[ServerSession, None]) -> str:
162+
# This should succeed without validation error
163+
result = await ctx.elicit(message="Test valid schema", schema=schema_class)
164+
return f"Success: {result.action}"
165+
166+
return tool
167+
168+
create_valid_validation_tool("valid_strenum", ValidStrEnumSchema)
169+
170+
async def enum_callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams):
171+
# Return the required status field
172+
return ElicitResult(action="accept", content={"status": "active"})
173+
174+
async with create_connected_server_and_client_session(
175+
mcp._mcp_server, elicitation_callback=enum_callback
176+
) as client_session:
177+
await client_session.initialize()
178+
179+
result = await client_session.call_tool("valid_strenum", {})
180+
assert len(result.content) == 1
181+
assert isinstance(result.content[0], TextContent)
182+
assert "Success: accept" == result.content[0].text
183+
150184

151185
@pytest.mark.anyio
152186
async def test_elicitation_with_optional_fields():

0 commit comments

Comments
 (0)