|
4 | 4 |
|
5 | 5 | import inspect |
6 | 6 | import re |
7 | | -import types |
8 | 7 | from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Sequence |
9 | 8 | from contextlib import ( |
10 | 9 | AbstractAsyncContextManager, |
11 | 10 | asynccontextmanager, |
12 | 11 | ) |
13 | 12 | from itertools import chain |
14 | | -from typing import Any, Generic, Literal, TypeVar, Union, get_args, get_origin |
| 13 | +from typing import Any, Generic, Literal |
15 | 14 |
|
16 | 15 | import anyio |
17 | 16 | import pydantic_core |
18 | | -from pydantic import BaseModel, Field, ValidationError |
19 | | -from pydantic.fields import FieldInfo |
| 17 | +from pydantic import BaseModel, Field |
20 | 18 | from pydantic.networks import AnyUrl |
21 | 19 | from pydantic_settings import BaseSettings, SettingsConfigDict |
22 | 20 | from starlette.applications import Starlette |
|
36 | 34 | from mcp.server.auth.settings import ( |
37 | 35 | AuthSettings, |
38 | 36 | ) |
| 37 | +from mcp.server.elicitation import ElicitationResult, ElicitSchemaModelT, elicit_with_validation |
39 | 38 | from mcp.server.fastmcp.exceptions import ResourceError |
40 | 39 | from mcp.server.fastmcp.prompts import Prompt, PromptManager |
41 | 40 | from mcp.server.fastmcp.resources import FunctionResource, Resource, ResourceManager |
|
67 | 66 |
|
68 | 67 | logger = get_logger(__name__) |
69 | 68 |
|
70 | | -ElicitSchemaModelT = TypeVar("ElicitSchemaModelT", bound=BaseModel) |
71 | | - |
72 | | - |
73 | | -class ElicitationResult(BaseModel, Generic[ElicitSchemaModelT]): |
74 | | - """Result of an elicitation request.""" |
75 | | - |
76 | | - action: Literal["accept", "decline", "cancel"] |
77 | | - """The user's action in response to the elicitation.""" |
78 | | - |
79 | | - data: ElicitSchemaModelT | None = None |
80 | | - """The validated data if action is 'accept', None otherwise.""" |
81 | | - |
82 | | - validation_error: str | None = None |
83 | | - """Validation error message if data failed to validate.""" |
84 | | - |
85 | 69 |
|
86 | 70 | class Settings(BaseSettings, Generic[LifespanResultT]): |
87 | 71 | """FastMCP server settings. |
@@ -893,43 +877,6 @@ def _convert_to_content( |
893 | 877 | return [TextContent(type="text", text=result)] |
894 | 878 |
|
895 | 879 |
|
896 | | -# Primitive types allowed in elicitation schemas |
897 | | -_ELICITATION_PRIMITIVE_TYPES = (str, int, float, bool) |
898 | | - |
899 | | - |
900 | | -def _validate_elicitation_schema(schema: type[BaseModel]) -> None: |
901 | | - """Validate that a Pydantic model only contains primitive field types.""" |
902 | | - for field_name, field_info in schema.model_fields.items(): |
903 | | - if not _is_primitive_field(field_info): |
904 | | - raise TypeError( |
905 | | - f"Elicitation schema field '{field_name}' must be a primitive type " |
906 | | - f"{_ELICITATION_PRIMITIVE_TYPES} or Optional of these types. " |
907 | | - f"Complex types like lists, dicts, or nested models are not allowed." |
908 | | - ) |
909 | | - |
910 | | - |
911 | | -def _is_primitive_field(field_info: FieldInfo) -> bool: |
912 | | - """Check if a field is a primitive type allowed in elicitation schemas.""" |
913 | | - annotation = field_info.annotation |
914 | | - |
915 | | - # Handle None type |
916 | | - if annotation is types.NoneType: |
917 | | - return True |
918 | | - |
919 | | - # Handle basic primitive types |
920 | | - if annotation in _ELICITATION_PRIMITIVE_TYPES: |
921 | | - return True |
922 | | - |
923 | | - # Handle Union types |
924 | | - origin = get_origin(annotation) |
925 | | - if origin is Union or origin is types.UnionType: |
926 | | - args = get_args(annotation) |
927 | | - # All args must be primitive types or None |
928 | | - return all(arg is types.NoneType or arg in _ELICITATION_PRIMITIVE_TYPES for arg in args) |
929 | | - |
930 | | - return False |
931 | | - |
932 | | - |
933 | 880 | class Context(BaseModel, Generic[ServerSessionT, LifespanContextT, RequestT]): |
934 | 881 | """Context object providing access to MCP capabilities. |
935 | 882 |
|
@@ -1053,27 +1000,10 @@ async def elicit( |
1053 | 1000 | The result.data will only be populated if action is "accept" and validation succeeded. |
1054 | 1001 | """ |
1055 | 1002 |
|
1056 | | - # Validate that schema only contains primitive types and fail loudly if not |
1057 | | - _validate_elicitation_schema(schema) |
1058 | | - |
1059 | | - json_schema = schema.model_json_schema() |
1060 | | - |
1061 | | - result = await self.request_context.session.elicit( |
1062 | | - message=message, |
1063 | | - requestedSchema=json_schema, |
1064 | | - related_request_id=self.request_id, |
| 1003 | + return await elicit_with_validation( |
| 1004 | + session=self.request_context.session, message=message, schema=schema, related_request_id=self.request_id |
1065 | 1005 | ) |
1066 | 1006 |
|
1067 | | - if result.action == "accept" and result.content: |
1068 | | - # Validate and parse the content using the schema |
1069 | | - try: |
1070 | | - validated_data = schema.model_validate(result.content) |
1071 | | - return ElicitationResult(action="accept", data=validated_data) |
1072 | | - except ValidationError as e: |
1073 | | - return ElicitationResult(action="accept", validation_error=str(e)) |
1074 | | - else: |
1075 | | - return ElicitationResult(action=result.action) |
1076 | | - |
1077 | 1007 | async def log( |
1078 | 1008 | self, |
1079 | 1009 | level: Literal["debug", "info", "warning", "error"], |
|
0 commit comments