Skip to content

Commit c87874f

Browse files
Tapan Chughfelixweinberger
authored andcommitted
cleanup changes a bit
1 parent aabcecf commit c87874f

File tree

1 file changed

+12
-20
lines changed

1 file changed

+12
-20
lines changed

src/mcp/server/elicitation.py

Lines changed: 12 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from typing import Generic, Literal, TypeVar, Union, get_args, get_origin
88

99
from pydantic import BaseModel
10-
from pydantic.fields import FieldInfo
1110

1211
from mcp.server.session import ServerSession
1312
from mcp.types import RequestId
@@ -44,7 +43,15 @@ class CancelledElicitation(BaseModel):
4443
def _validate_elicitation_schema(schema: type[BaseModel]) -> None:
4544
"""Validate that a Pydantic model only contains primitive field types."""
4645
for field_name, field_info in schema.model_fields.items():
47-
if not _is_primitive_field(field_info):
46+
annotation = field_info.annotation
47+
48+
if annotation is None or annotation is types.NoneType:
49+
continue
50+
elif _is_primitive_field(annotation):
51+
continue
52+
elif _is_string_sequence(annotation):
53+
continue
54+
else:
4855
raise TypeError(
4956
f"Elicitation schema field '{field_name}' must be a primitive type "
5057
f"{_ELICITATION_PRIMITIVE_TYPES}, a sequence of strings (list[str], etc.), "
@@ -63,33 +70,18 @@ def _is_string_sequence(annotation: type) -> bool:
6370
return False
6471

6572

66-
def _is_primitive_field(field_info: FieldInfo) -> bool:
73+
def _is_primitive_field(annotation: type) -> bool:
6774
"""Check if a field is a primitive type allowed in elicitation schemas."""
68-
annotation = field_info.annotation
69-
70-
# Handle None type
71-
if annotation is types.NoneType: # pragma: no cover
72-
return True
73-
7475
# Handle basic primitive types
7576
if annotation in _ELICITATION_PRIMITIVE_TYPES:
7677
return True
7778

78-
# Handle string sequences for multi-select enums
79-
if annotation is not None and _is_string_sequence(annotation):
80-
return True
81-
8279
# Handle Union types
8380
origin = get_origin(annotation)
8481
if origin is Union or origin is types.UnionType:
8582
args = get_args(annotation)
86-
# All args must be primitive types, None, or string sequences
87-
return all(
88-
arg is types.NoneType
89-
or arg in _ELICITATION_PRIMITIVE_TYPES
90-
or (arg is not None and _is_string_sequence(arg))
91-
for arg in args
92-
)
83+
# All args must be primitive types or None
84+
return all(arg is types.NoneType or arg in _ELICITATION_PRIMITIVE_TYPES for arg in args)
9385

9486
return False
9587

0 commit comments

Comments
 (0)