77from typing import Generic , Literal , TypeVar , Union , get_args , get_origin
88
99from pydantic import BaseModel
10- from pydantic .fields import FieldInfo
1110
1211from mcp .server .session import ServerSession
1312from mcp .types import RequestId
@@ -44,7 +43,15 @@ class CancelledElicitation(BaseModel):
4443def _validate_elicitation_schema (schema : type [BaseModel ]) -> None :
4544 """Validate that a Pydantic model only contains primitive field types."""
4645 for field_name , field_info in schema .model_fields .items ():
47- if not _is_primitive_field (field_info ):
46+ annotation = field_info .annotation
47+
48+ if annotation is None or annotation is types .NoneType :
49+ continue
50+ elif _is_primitive_field (annotation ):
51+ continue
52+ elif _is_string_sequence (annotation ):
53+ continue
54+ else :
4855 raise TypeError (
4956 f"Elicitation schema field '{ field_name } ' must be a primitive type "
5057 f"{ _ELICITATION_PRIMITIVE_TYPES } , a sequence of strings (list[str], etc.), "
@@ -63,33 +70,18 @@ def _is_string_sequence(annotation: type) -> bool:
6370 return False
6471
6572
66- def _is_primitive_field (field_info : FieldInfo ) -> bool :
73+ def _is_primitive_field (annotation : type ) -> bool :
6774 """Check if a field is a primitive type allowed in elicitation schemas."""
68- annotation = field_info .annotation
69-
70- # Handle None type
71- if annotation is types .NoneType : # pragma: no cover
72- return True
73-
7475 # Handle basic primitive types
7576 if annotation in _ELICITATION_PRIMITIVE_TYPES :
7677 return True
7778
78- # Handle string sequences for multi-select enums
79- if annotation is not None and _is_string_sequence (annotation ):
80- return True
81-
8279 # Handle Union types
8380 origin = get_origin (annotation )
8481 if origin is Union or origin is types .UnionType :
8582 args = get_args (annotation )
86- # All args must be primitive types, None, or string sequences
87- return all (
88- arg is types .NoneType
89- or arg in _ELICITATION_PRIMITIVE_TYPES
90- or (arg is not None and _is_string_sequence (arg ))
91- for arg in args
92- )
83+ # All args must be primitive types or None
84+ return all (arg is types .NoneType or arg in _ELICITATION_PRIMITIVE_TYPES for arg in args )
9385
9486 return False
9587
0 commit comments