@@ -48,6 +48,95 @@ async def __call__(
4848 ) -> None : ... # pragma: no branch
4949
5050
51+ # Experimental: Task handler protocols for server -> client requests
52+ class GetTaskHandlerFnT (Protocol ):
53+ """Handler for tasks/get requests from server.
54+
55+ WARNING: This is experimental and may change without notice.
56+ """
57+
58+ async def __call__ (
59+ self ,
60+ context : RequestContext ["ClientSession" , Any ],
61+ params : types .GetTaskRequestParams ,
62+ ) -> types .GetTaskResult | types .ErrorData : ... # pragma: no branch
63+
64+
65+ class GetTaskResultHandlerFnT (Protocol ):
66+ """Handler for tasks/result requests from server.
67+
68+ WARNING: This is experimental and may change without notice.
69+ """
70+
71+ async def __call__ (
72+ self ,
73+ context : RequestContext ["ClientSession" , Any ],
74+ params : types .GetTaskPayloadRequestParams ,
75+ ) -> types .GetTaskPayloadResult | types .ErrorData : ... # pragma: no branch
76+
77+
78+ class ListTasksHandlerFnT (Protocol ):
79+ """Handler for tasks/list requests from server.
80+
81+ WARNING: This is experimental and may change without notice.
82+ """
83+
84+ async def __call__ (
85+ self ,
86+ context : RequestContext ["ClientSession" , Any ],
87+ params : types .PaginatedRequestParams | None ,
88+ ) -> types .ListTasksResult | types .ErrorData : ... # pragma: no branch
89+
90+
91+ class CancelTaskHandlerFnT (Protocol ):
92+ """Handler for tasks/cancel requests from server.
93+
94+ WARNING: This is experimental and may change without notice.
95+ """
96+
97+ async def __call__ (
98+ self ,
99+ context : RequestContext ["ClientSession" , Any ],
100+ params : types .CancelTaskRequestParams ,
101+ ) -> types .CancelTaskResult | types .ErrorData : ... # pragma: no branch
102+
103+
104+ class TaskAugmentedSamplingFnT (Protocol ):
105+ """Handler for task-augmented sampling/createMessage requests from server.
106+
107+ When server sends a CreateMessageRequest with task field, this callback
108+ is invoked. The callback should create a task, spawn background work,
109+ and return CreateTaskResult immediately.
110+
111+ WARNING: This is experimental and may change without notice.
112+ """
113+
114+ async def __call__ (
115+ self ,
116+ context : RequestContext ["ClientSession" , Any ],
117+ params : types .CreateMessageRequestParams ,
118+ task_metadata : types .TaskMetadata ,
119+ ) -> types .CreateTaskResult | types .ErrorData : ... # pragma: no branch
120+
121+
122+ class TaskAugmentedElicitationFnT (Protocol ):
123+ """Handler for task-augmented elicitation/create requests from server.
124+
125+ When server sends an ElicitRequest with task field, this callback
126+ is invoked. The callback should create a task, spawn background work,
127+ and return CreateTaskResult immediately.
128+
129+ WARNING: This is experimental and may change without notice.
130+ """
131+
132+ async def __call__ (
133+ self ,
134+ context : RequestContext ["ClientSession" , Any ],
135+ params : types .ElicitRequestParams ,
136+ task_metadata : types .TaskMetadata ,
137+ ) -> types .CreateTaskResult | types .ErrorData : ... # pragma: no branch
138+
139+
51140class MessageHandlerFnT (Protocol ):
52141 async def __call__ (
53142 self ,
@@ -96,6 +185,69 @@ async def _default_logging_callback(
96185 pass
97186
98187
188+ # Default handlers for experimental task requests (return "not supported" errors)
189+ async def _default_get_task_handler (
190+ context : RequestContext ["ClientSession" , Any ],
191+ params : types .GetTaskRequestParams ,
192+ ) -> types .GetTaskResult | types .ErrorData :
193+ return types .ErrorData (
194+ code = types .METHOD_NOT_FOUND ,
195+ message = "tasks/get not supported" ,
196+ )
197+
198+
199+ async def _default_get_task_result_handler (
200+ context : RequestContext ["ClientSession" , Any ],
201+ params : types .GetTaskPayloadRequestParams ,
202+ ) -> types .GetTaskPayloadResult | types .ErrorData :
203+ return types .ErrorData (
204+ code = types .METHOD_NOT_FOUND ,
205+ message = "tasks/result not supported" ,
206+ )
207+
208+
209+ async def _default_list_tasks_handler (
210+ context : RequestContext ["ClientSession" , Any ],
211+ params : types .PaginatedRequestParams | None ,
212+ ) -> types .ListTasksResult | types .ErrorData :
213+ return types .ErrorData (
214+ code = types .METHOD_NOT_FOUND ,
215+ message = "tasks/list not supported" ,
216+ )
217+
218+
219+ async def _default_cancel_task_handler (
220+ context : RequestContext ["ClientSession" , Any ],
221+ params : types .CancelTaskRequestParams ,
222+ ) -> types .CancelTaskResult | types .ErrorData :
223+ return types .ErrorData (
224+ code = types .METHOD_NOT_FOUND ,
225+ message = "tasks/cancel not supported" ,
226+ )
227+
228+
229+ async def _default_task_augmented_sampling_callback (
230+ context : RequestContext ["ClientSession" , Any ],
231+ params : types .CreateMessageRequestParams ,
232+ task_metadata : types .TaskMetadata ,
233+ ) -> types .CreateTaskResult | types .ErrorData :
234+ return types .ErrorData (
235+ code = types .INVALID_REQUEST ,
236+ message = "Task-augmented sampling not supported" ,
237+ )
238+
239+
240+ async def _default_task_augmented_elicitation_callback (
241+ context : RequestContext ["ClientSession" , Any ],
242+ params : types .ElicitRequestParams ,
243+ task_metadata : types .TaskMetadata ,
244+ ) -> types .CreateTaskResult | types .ErrorData :
245+ return types .ErrorData (
246+ code = types .INVALID_REQUEST ,
247+ message = "Task-augmented elicitation not supported" ,
248+ )
249+
250+
99251ClientResponse : TypeAdapter [types .ClientResult | types .ErrorData ] = TypeAdapter (types .ClientResult | types .ErrorData )
100252
101253
@@ -119,6 +271,14 @@ def __init__(
119271 logging_callback : LoggingFnT | None = None ,
120272 message_handler : MessageHandlerFnT | None = None ,
121273 client_info : types .Implementation | None = None ,
274+ tasks_capability : types .ClientTasksCapability | None = None ,
275+ # Experimental: Task handlers for server -> client requests
276+ get_task_handler : GetTaskHandlerFnT | None = None ,
277+ get_task_result_handler : GetTaskResultHandlerFnT | None = None ,
278+ list_tasks_handler : ListTasksHandlerFnT | None = None ,
279+ cancel_task_handler : CancelTaskHandlerFnT | None = None ,
280+ task_augmented_sampling_callback : TaskAugmentedSamplingFnT | None = None ,
281+ task_augmented_elicitation_callback : TaskAugmentedElicitationFnT | None = None ,
122282 ) -> None :
123283 super ().__init__ (
124284 read_stream ,
@@ -133,9 +293,21 @@ def __init__(
133293 self ._list_roots_callback = list_roots_callback or _default_list_roots_callback
134294 self ._logging_callback = logging_callback or _default_logging_callback
135295 self ._message_handler = message_handler or _default_message_handler
296+ self ._tasks_capability = tasks_capability
136297 self ._tool_output_schemas : dict [str , dict [str , Any ] | None ] = {}
137298 self ._server_capabilities : types .ServerCapabilities | None = None
138299 self ._experimental : ExperimentalClientFeatures | None = None
300+ # Experimental: Task handlers
301+ self ._get_task_handler = get_task_handler or _default_get_task_handler
302+ self ._get_task_result_handler = get_task_result_handler or _default_get_task_result_handler
303+ self ._list_tasks_handler = list_tasks_handler or _default_list_tasks_handler
304+ self ._cancel_task_handler = cancel_task_handler or _default_cancel_task_handler
305+ self ._task_augmented_sampling_callback = (
306+ task_augmented_sampling_callback or _default_task_augmented_sampling_callback
307+ )
308+ self ._task_augmented_elicitation_callback = (
309+ task_augmented_elicitation_callback or _default_task_augmented_elicitation_callback
310+ )
139311
140312 async def initialize (self ) -> types .InitializeResult :
141313 sampling = types .SamplingCapability () if self ._sampling_callback is not _default_sampling_callback else None
@@ -166,6 +338,7 @@ async def initialize(self) -> types.InitializeResult:
166338 elicitation = elicitation ,
167339 experimental = None ,
168340 roots = roots ,
341+ tasks = self ._tasks_capability ,
169342 ),
170343 clientInfo = self ._client_info ,
171344 ),
@@ -191,7 +364,7 @@ def get_server_capabilities(self) -> types.ServerCapabilities | None:
191364 return self ._server_capabilities
192365
193366 @property
194- def experimental (self ) -> " ExperimentalClientFeatures" :
367+ def experimental (self ) -> ExperimentalClientFeatures :
195368 """Experimental APIs for tasks and other features.
196369
197370 WARNING: These APIs are experimental and may change without notice.
@@ -540,13 +713,21 @@ async def _received_request(self, responder: RequestResponder[types.ServerReques
540713 match responder .request .root :
541714 case types .CreateMessageRequest (params = params ):
542715 with responder :
543- response = await self ._sampling_callback (ctx , params )
716+ # Check if this is a task-augmented request
717+ if params .task is not None :
718+ response = await self ._task_augmented_sampling_callback (ctx , params , params .task )
719+ else :
720+ response = await self ._sampling_callback (ctx , params )
544721 client_response = ClientResponse .validate_python (response )
545722 await responder .respond (client_response )
546723
547724 case types .ElicitRequest (params = params ):
548725 with responder :
549- response = await self ._elicitation_callback (ctx , params )
726+ # Check if this is a task-augmented request
727+ if params .task is not None :
728+ response = await self ._task_augmented_elicitation_callback (ctx , params , params .task )
729+ else :
730+ response = await self ._elicitation_callback (ctx , params )
550731 client_response = ClientResponse .validate_python (response )
551732 await responder .respond (client_response )
552733
@@ -559,7 +740,33 @@ async def _received_request(self, responder: RequestResponder[types.ServerReques
559740 case types .PingRequest (): # pragma: no cover
560741 with responder :
561742 return await responder .respond (types .ClientResult (root = types .EmptyResult ()))
562- case _:
743+
744+ # Experimental: Task management requests from server
745+ case types .GetTaskRequest (params = params ):
746+ with responder :
747+ response = await self ._get_task_handler (ctx , params )
748+ client_response = ClientResponse .validate_python (response )
749+ await responder .respond (client_response )
750+
751+ case types .GetTaskPayloadRequest (params = params ):
752+ with responder :
753+ response = await self ._get_task_result_handler (ctx , params )
754+ client_response = ClientResponse .validate_python (response )
755+ await responder .respond (client_response )
756+
757+ case types .ListTasksRequest (params = params ):
758+ with responder :
759+ response = await self ._list_tasks_handler (ctx , params )
760+ client_response = ClientResponse .validate_python (response )
761+ await responder .respond (client_response )
762+
763+ case types .CancelTaskRequest (params = params ):
764+ with responder :
765+ response = await self ._cancel_task_handler (ctx , params )
766+ client_response = ClientResponse .validate_python (response )
767+ await responder .respond (client_response )
768+
769+ case _: # pragma: no cover
563770 raise NotImplementedError ()
564771
565772 async def _handle_incoming (
0 commit comments