@@ -893,6 +893,10 @@ def _convert_to_content(
893893 return [TextContent (type = "text" , text = result )]
894894
895895
896+ # Primitive types allowed in elicitation schemas
897+ _ELICITATION_PRIMITIVE_TYPES = (str , int , float , bool )
898+
899+
896900def _validate_elicitation_schema (schema : type [BaseModel ]) -> None :
897901 """Validate that a Pydantic model only contains primitive field types."""
898902 for field_name , field_info in schema .model_fields .items ():
@@ -904,28 +908,24 @@ def _validate_elicitation_schema(schema: type[BaseModel]) -> None:
904908 )
905909
906910
907- # Primitive types allowed in elicitation schemas
908- _ELICITATION_PRIMITIVE_TYPES = (str , int , float , bool )
909-
910-
911911def _is_primitive_field (field_info : FieldInfo ) -> bool :
912912 """Check if a field is a primitive type allowed in elicitation schemas."""
913913 annotation = field_info .annotation
914914
915915 # Handle None type
916- if annotation is type ( None ) :
916+ if annotation is types . NoneType :
917917 return True
918918
919919 # Handle basic primitive types
920920 if annotation in _ELICITATION_PRIMITIVE_TYPES :
921921 return True
922922
923- # Handle Union types (including Optional and Python 3.10+ union syntax)
923+ # Handle Union types
924924 origin = get_origin (annotation )
925- if origin is Union or ( hasattr ( types , 'UnionType' ) and isinstance ( annotation , types .UnionType )) :
925+ if origin is Union or origin is types .UnionType :
926926 args = get_args (annotation )
927927 # All args must be primitive types or None
928- return all (arg is type ( None ) or arg in _ELICITATION_PRIMITIVE_TYPES for arg in args )
928+ return all (arg is types . NoneType or arg in _ELICITATION_PRIMITIVE_TYPES for arg in args )
929929
930930 return False
931931
0 commit comments