@@ -22,6 +22,14 @@ async def __call__(
2222 ) -> types .CreateMessageResult | types .ErrorData : ...
2323
2424
25+ class ElicitationFnT (Protocol ):
26+ async def __call__ (
27+ self ,
28+ context : RequestContext ["ClientSession" , Any ],
29+ params : types .ElicitRequestParams ,
30+ ) -> types .ElicitResult | types .ErrorData : ...
31+
32+
2533class ListRootsFnT (Protocol ):
2634 async def __call__ (
2735 self , context : RequestContext ["ClientSession" , Any ]
@@ -58,6 +66,16 @@ async def _default_sampling_callback(
5866 )
5967
6068
69+ async def _default_elicitation_callback (
70+ context : RequestContext ["ClientSession" , Any ],
71+ params : types .ElicitRequestParams ,
72+ ) -> types .ElicitResult | types .ErrorData :
73+ return types .ErrorData (
74+ code = types .INVALID_REQUEST ,
75+ message = "Elicitation not supported" ,
76+ )
77+
78+
6179async def _default_list_roots_callback (
6280 context : RequestContext ["ClientSession" , Any ],
6381) -> types .ListRootsResult | types .ErrorData :
@@ -91,6 +109,7 @@ def __init__(
91109 write_stream : MemoryObjectSendStream [SessionMessage ],
92110 read_timeout_seconds : timedelta | None = None ,
93111 sampling_callback : SamplingFnT | None = None ,
112+ elicitation_callback : ElicitationFnT | None = None ,
94113 list_roots_callback : ListRootsFnT | None = None ,
95114 logging_callback : LoggingFnT | None = None ,
96115 message_handler : MessageHandlerFnT | None = None ,
@@ -105,12 +124,14 @@ def __init__(
105124 )
106125 self ._client_info = client_info or DEFAULT_CLIENT_INFO
107126 self ._sampling_callback = sampling_callback or _default_sampling_callback
127+ self ._elicitation_callback = elicitation_callback or _default_elicitation_callback
108128 self ._list_roots_callback = list_roots_callback or _default_list_roots_callback
109129 self ._logging_callback = logging_callback or _default_logging_callback
110130 self ._message_handler = message_handler or _default_message_handler
111131
112132 async def initialize (self ) -> types .InitializeResult :
113133 sampling = types .SamplingCapability () if self ._sampling_callback is not _default_sampling_callback else None
134+ elicitation = types .ElicitationCapability ()
114135 roots = (
115136 # TODO: Should this be based on whether we
116137 # _will_ send notifications, or only whether
@@ -128,6 +149,7 @@ async def initialize(self) -> types.InitializeResult:
128149 protocolVersion = types .LATEST_PROTOCOL_VERSION ,
129150 capabilities = types .ClientCapabilities (
130151 sampling = sampling ,
152+ elicitation = elicitation ,
131153 experimental = None ,
132154 roots = roots ,
133155 ),
@@ -362,6 +384,12 @@ async def _received_request(self, responder: RequestResponder[types.ServerReques
362384 client_response = ClientResponse .validate_python (response )
363385 await responder .respond (client_response )
364386
387+ case types .ElicitRequest (params = params ):
388+ with responder :
389+ response = await self ._elicitation_callback (ctx , params )
390+ client_response = ClientResponse .validate_python (response )
391+ await responder .respond (client_response )
392+
365393 case types .ListRootsRequest ():
366394 with responder :
367395 response = await self ._list_roots_callback (ctx )
0 commit comments