Skip to content

Commit 4c0385f

Browse files
committed
Refactor client task handlers into ExperimentalTaskHandlers dataclass
- Replace 6 individual task handler parameters with single `experimental_task_handlers: ExperimentalTaskHandlers` (keyword-only) - ExperimentalTaskHandlers dataclass groups all handlers and provides: - `build_capability()` - auto-builds ClientTasksCapability from handlers - `handles_request()` - checks if request is task-related - `handle_request()` - dispatches to appropriate handler - Simplify ClientSession._received_request by delegating task requests - Update tests to use new ExperimentalTaskHandlers API
1 parent b709d6f commit 4c0385f

File tree

4 files changed

+184
-233
lines changed

4 files changed

+184
-233
lines changed

src/mcp/client/experimental/task_handlers.py

Lines changed: 117 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,12 @@
1212
- Server polls client's task status via tasks/get, tasks/result, etc.
1313
"""
1414

15+
from dataclasses import dataclass, field
1516
from typing import TYPE_CHECKING, Any, Protocol
1617

1718
import mcp.types as types
1819
from mcp.shared.context import RequestContext
20+
from mcp.shared.session import RequestResponder
1921

2022
if TYPE_CHECKING:
2123
from mcp.client.session import ClientSession
@@ -109,7 +111,11 @@ async def __call__(
109111
) -> types.CreateTaskResult | types.ErrorData: ... # pragma: no branch
110112

111113

112-
# Default handlers for experimental task requests (return "not supported" errors)
114+
# =============================================================================
115+
# Default Handlers (return "not supported" errors)
116+
# =============================================================================
117+
118+
113119
async def default_get_task_handler(
114120
context: RequestContext["ClientSession", Any],
115121
params: types.GetTaskRequestParams,
@@ -150,7 +156,7 @@ async def default_cancel_task_handler(
150156
)
151157

152158

153-
async def default_task_augmented_sampling_callback(
159+
async def default_task_augmented_sampling(
154160
context: RequestContext["ClientSession", Any],
155161
params: types.CreateMessageRequestParams,
156162
task_metadata: types.TaskMetadata,
@@ -161,7 +167,7 @@ async def default_task_augmented_sampling_callback(
161167
)
162168

163169

164-
async def default_task_augmented_elicitation_callback(
170+
async def default_task_augmented_elicitation(
165171
context: RequestContext["ClientSession", Any],
166172
params: types.ElicitRequestParams,
167173
task_metadata: types.TaskMetadata,
@@ -172,58 +178,118 @@ async def default_task_augmented_elicitation_callback(
172178
)
173179

174180

175-
def build_client_tasks_capability(
176-
*,
177-
list_tasks_handler: ListTasksHandlerFnT | None = None,
178-
cancel_task_handler: CancelTaskHandlerFnT | None = None,
179-
task_augmented_sampling_callback: TaskAugmentedSamplingFnT | None = None,
180-
task_augmented_elicitation_callback: TaskAugmentedElicitationFnT | None = None,
181-
) -> types.ClientTasksCapability | None:
182-
"""Build ClientTasksCapability from the provided handlers.
183-
184-
This helper builds the appropriate capability object based on which
185-
handlers are provided (non-None and not the default handlers).
181+
@dataclass
182+
class ExperimentalTaskHandlers:
183+
"""Container for experimental task handlers.
186184
187-
WARNING: This is experimental and may change without notice.
185+
Groups all task-related handlers that handle server -> client requests.
186+
This includes both pure task requests (get, list, cancel, result) and
187+
task-augmented request handlers (sampling, elicitation with task field).
188188
189-
Args:
190-
list_tasks_handler: Handler for tasks/list requests
191-
cancel_task_handler: Handler for tasks/cancel requests
192-
task_augmented_sampling_callback: Handler for task-augmented sampling
193-
task_augmented_elicitation_callback: Handler for task-augmented elicitation
189+
WARNING: These APIs are experimental and may change without notice.
194190
195-
Returns:
196-
ClientTasksCapability if any handlers are provided, None otherwise
191+
Example:
192+
handlers = ExperimentalTaskHandlers(
193+
get_task=my_get_task_handler,
194+
list_tasks=my_list_tasks_handler,
195+
)
196+
session = ClientSession(..., experimental_task_handlers=handlers)
197197
"""
198-
has_list = list_tasks_handler is not None and list_tasks_handler is not default_list_tasks_handler
199-
has_cancel = cancel_task_handler is not None and cancel_task_handler is not default_cancel_task_handler
200-
has_sampling = (
201-
task_augmented_sampling_callback is not None
202-
and task_augmented_sampling_callback is not default_task_augmented_sampling_callback
203-
)
204-
has_elicitation = (
205-
task_augmented_elicitation_callback is not None
206-
and task_augmented_elicitation_callback is not default_task_augmented_elicitation_callback
207-
)
208198

209-
# If no handlers are provided, return None
210-
if not any([has_list, has_cancel, has_sampling, has_elicitation]):
211-
return None
212-
213-
# Build requests capability if any request handlers are provided
214-
requests_capability: types.ClientTasksRequestsCapability | None = None
215-
if has_sampling or has_elicitation:
216-
requests_capability = types.ClientTasksRequestsCapability(
217-
sampling=types.TasksSamplingCapability(createMessage=types.TasksCreateMessageCapability())
218-
if has_sampling
219-
else None,
220-
elicitation=types.TasksElicitationCapability(create=types.TasksCreateElicitationCapability())
221-
if has_elicitation
222-
else None,
199+
# Pure task request handlers
200+
get_task: GetTaskHandlerFnT = field(default=default_get_task_handler)
201+
get_task_result: GetTaskResultHandlerFnT = field(default=default_get_task_result_handler)
202+
list_tasks: ListTasksHandlerFnT = field(default=default_list_tasks_handler)
203+
cancel_task: CancelTaskHandlerFnT = field(default=default_cancel_task_handler)
204+
205+
# Task-augmented request handlers
206+
augmented_sampling: TaskAugmentedSamplingFnT = field(default=default_task_augmented_sampling)
207+
augmented_elicitation: TaskAugmentedElicitationFnT = field(default=default_task_augmented_elicitation)
208+
209+
def build_capability(self) -> types.ClientTasksCapability | None:
210+
"""Build ClientTasksCapability from the configured handlers.
211+
212+
Returns a capability object that reflects which handlers are configured
213+
(i.e., not using the default "not supported" handlers).
214+
215+
Returns:
216+
ClientTasksCapability if any handlers are provided, None otherwise
217+
"""
218+
has_list = self.list_tasks is not default_list_tasks_handler
219+
has_cancel = self.cancel_task is not default_cancel_task_handler
220+
has_sampling = self.augmented_sampling is not default_task_augmented_sampling
221+
has_elicitation = self.augmented_elicitation is not default_task_augmented_elicitation
222+
223+
# If no handlers are provided, return None
224+
if not any([has_list, has_cancel, has_sampling, has_elicitation]):
225+
return None
226+
227+
# Build requests capability if any request handlers are provided
228+
requests_capability: types.ClientTasksRequestsCapability | None = None
229+
if has_sampling or has_elicitation:
230+
requests_capability = types.ClientTasksRequestsCapability(
231+
sampling=types.TasksSamplingCapability(createMessage=types.TasksCreateMessageCapability())
232+
if has_sampling
233+
else None,
234+
elicitation=types.TasksElicitationCapability(create=types.TasksCreateElicitationCapability())
235+
if has_elicitation
236+
else None,
237+
)
238+
239+
return types.ClientTasksCapability(
240+
list=types.TasksListCapability() if has_list else None,
241+
cancel=types.TasksCancelCapability() if has_cancel else None,
242+
requests=requests_capability,
223243
)
224244

225-
return types.ClientTasksCapability(
226-
list=types.TasksListCapability() if has_list else None,
227-
cancel=types.TasksCancelCapability() if has_cancel else None,
228-
requests=requests_capability,
229-
)
245+
@staticmethod
246+
def handles_request(request: types.ServerRequest) -> bool:
247+
"""Check if this handler handles the given request type."""
248+
return isinstance(
249+
request.root,
250+
types.GetTaskRequest | types.GetTaskPayloadRequest | types.ListTasksRequest | types.CancelTaskRequest,
251+
)
252+
253+
async def handle_request(
254+
self,
255+
ctx: RequestContext["ClientSession", Any],
256+
responder: RequestResponder[types.ServerRequest, types.ClientResult],
257+
) -> None:
258+
"""Handle a task-related request from the server.
259+
260+
Call handles_request() first to check if this handler can handle the request.
261+
"""
262+
from pydantic import TypeAdapter
263+
264+
client_response_type: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(
265+
types.ClientResult | types.ErrorData
266+
)
267+
268+
match responder.request.root:
269+
case types.GetTaskRequest(params=params):
270+
response = await self.get_task(ctx, params)
271+
client_response = client_response_type.validate_python(response)
272+
await responder.respond(client_response)
273+
274+
case types.GetTaskPayloadRequest(params=params):
275+
response = await self.get_task_result(ctx, params)
276+
client_response = client_response_type.validate_python(response)
277+
await responder.respond(client_response)
278+
279+
case types.ListTasksRequest(params=params):
280+
response = await self.list_tasks(ctx, params)
281+
client_response = client_response_type.validate_python(response)
282+
await responder.respond(client_response)
283+
284+
case types.CancelTaskRequest(params=params):
285+
response = await self.cancel_task(ctx, params)
286+
client_response = client_response_type.validate_python(response)
287+
await responder.respond(client_response)
288+
289+
case _: # pragma: no cover
290+
raise ValueError(f"Unhandled request type: {type(responder.request.root)}")
291+
292+
293+
# Backwards compatibility aliases
294+
default_task_augmented_sampling_callback = default_task_augmented_sampling
295+
default_task_augmented_elicitation_callback = default_task_augmented_elicitation

src/mcp/client/session.py

Lines changed: 21 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -9,21 +9,7 @@
99

1010
import mcp.types as types
1111
from mcp.client.experimental import ExperimentalClientFeatures
12-
from mcp.client.experimental.task_handlers import (
13-
CancelTaskHandlerFnT,
14-
GetTaskHandlerFnT,
15-
GetTaskResultHandlerFnT,
16-
ListTasksHandlerFnT,
17-
TaskAugmentedElicitationFnT,
18-
TaskAugmentedSamplingFnT,
19-
build_client_tasks_capability,
20-
default_cancel_task_handler,
21-
default_get_task_handler,
22-
default_get_task_result_handler,
23-
default_list_tasks_handler,
24-
default_task_augmented_elicitation_callback,
25-
default_task_augmented_sampling_callback,
26-
)
12+
from mcp.client.experimental.task_handlers import ExperimentalTaskHandlers
2713
from mcp.shared.context import RequestContext
2814
from mcp.shared.message import SessionMessage
2915
from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder
@@ -134,14 +120,8 @@ def __init__(
134120
logging_callback: LoggingFnT | None = None,
135121
message_handler: MessageHandlerFnT | None = None,
136122
client_info: types.Implementation | None = None,
137-
tasks_capability: types.ClientTasksCapability | None = None,
138-
# Experimental: Task handlers for server -> client requests
139-
get_task_handler: GetTaskHandlerFnT | None = None,
140-
get_task_result_handler: GetTaskResultHandlerFnT | None = None,
141-
list_tasks_handler: ListTasksHandlerFnT | None = None,
142-
cancel_task_handler: CancelTaskHandlerFnT | None = None,
143-
task_augmented_sampling_callback: TaskAugmentedSamplingFnT | None = None,
144-
task_augmented_elicitation_callback: TaskAugmentedElicitationFnT | None = None,
123+
*,
124+
experimental_task_handlers: ExperimentalTaskHandlers | None = None,
145125
) -> None:
146126
super().__init__(
147127
read_stream,
@@ -158,25 +138,10 @@ def __init__(
158138
self._message_handler = message_handler or _default_message_handler
159139
self._tool_output_schemas: dict[str, dict[str, Any] | None] = {}
160140
self._server_capabilities: types.ServerCapabilities | None = None
161-
self._experimental: ExperimentalClientFeatures | None = None
162-
# Experimental: Task handlers
163-
self._get_task_handler = get_task_handler or default_get_task_handler
164-
self._get_task_result_handler = get_task_result_handler or default_get_task_result_handler
165-
self._list_tasks_handler = list_tasks_handler or default_list_tasks_handler
166-
self._cancel_task_handler = cancel_task_handler or default_cancel_task_handler
167-
self._task_augmented_sampling_callback = (
168-
task_augmented_sampling_callback or default_task_augmented_sampling_callback
169-
)
170-
self._task_augmented_elicitation_callback = (
171-
task_augmented_elicitation_callback or default_task_augmented_elicitation_callback
172-
)
173-
# Build tasks capability from handlers if not explicitly provided
174-
self._tasks_capability = tasks_capability or build_client_tasks_capability(
175-
list_tasks_handler=list_tasks_handler,
176-
cancel_task_handler=cancel_task_handler,
177-
task_augmented_sampling_callback=task_augmented_sampling_callback,
178-
task_augmented_elicitation_callback=task_augmented_elicitation_callback,
179-
)
141+
self._experimental_features: ExperimentalClientFeatures | None = None
142+
143+
# Experimental: Task handlers (use defaults if not provided)
144+
self._task_handlers = experimental_task_handlers or ExperimentalTaskHandlers()
180145

181146
async def initialize(self) -> types.InitializeResult:
182147
sampling = types.SamplingCapability() if self._sampling_callback is not _default_sampling_callback else None
@@ -207,7 +172,7 @@ async def initialize(self) -> types.InitializeResult:
207172
elicitation=elicitation,
208173
experimental=None,
209174
roots=roots,
210-
tasks=self._tasks_capability,
175+
tasks=self._task_handlers.build_capability(),
211176
),
212177
clientInfo=self._client_info,
213178
),
@@ -242,9 +207,9 @@ def experimental(self) -> ExperimentalClientFeatures:
242207
status = await session.experimental.get_task(task_id)
243208
result = await session.experimental.get_task_result(task_id, CallToolResult)
244209
"""
245-
if self._experimental is None:
246-
self._experimental = ExperimentalClientFeatures(self)
247-
return self._experimental
210+
if self._experimental_features is None:
211+
self._experimental_features = ExperimentalClientFeatures(self)
212+
return self._experimental_features
248213

249214
async def send_ping(self) -> types.EmptyResult:
250215
"""Send a ping request."""
@@ -579,12 +544,19 @@ async def _received_request(self, responder: RequestResponder[types.ServerReques
579544
lifespan_context=None,
580545
)
581546

547+
# Delegate to experimental task handler if applicable
548+
if self._task_handlers.handles_request(responder.request):
549+
with responder:
550+
await self._task_handlers.handle_request(ctx, responder)
551+
return None
552+
553+
# Core request handling
582554
match responder.request.root:
583555
case types.CreateMessageRequest(params=params):
584556
with responder:
585557
# Check if this is a task-augmented request
586558
if params.task is not None:
587-
response = await self._task_augmented_sampling_callback(ctx, params, params.task)
559+
response = await self._task_handlers.augmented_sampling(ctx, params, params.task)
588560
else:
589561
response = await self._sampling_callback(ctx, params)
590562
client_response = ClientResponse.validate_python(response)
@@ -594,7 +566,7 @@ async def _received_request(self, responder: RequestResponder[types.ServerReques
594566
with responder:
595567
# Check if this is a task-augmented request
596568
if params.task is not None:
597-
response = await self._task_augmented_elicitation_callback(ctx, params, params.task)
569+
response = await self._task_handlers.augmented_elicitation(ctx, params, params.task)
598570
else:
599571
response = await self._elicitation_callback(ctx, params)
600572
client_response = ClientResponse.validate_python(response)
@@ -610,33 +582,9 @@ async def _received_request(self, responder: RequestResponder[types.ServerReques
610582
with responder:
611583
return await responder.respond(types.ClientResult(root=types.EmptyResult()))
612584

613-
# Experimental: Task management requests from server
614-
case types.GetTaskRequest(params=params):
615-
with responder:
616-
response = await self._get_task_handler(ctx, params)
617-
client_response = ClientResponse.validate_python(response)
618-
await responder.respond(client_response)
619-
620-
case types.GetTaskPayloadRequest(params=params):
621-
with responder:
622-
response = await self._get_task_result_handler(ctx, params)
623-
client_response = ClientResponse.validate_python(response)
624-
await responder.respond(client_response)
625-
626-
case types.ListTasksRequest(params=params):
627-
with responder:
628-
response = await self._list_tasks_handler(ctx, params)
629-
client_response = ClientResponse.validate_python(response)
630-
await responder.respond(client_response)
631-
632-
case types.CancelTaskRequest(params=params):
633-
with responder:
634-
response = await self._cancel_task_handler(ctx, params)
635-
client_response = ClientResponse.validate_python(response)
636-
await responder.respond(client_response)
637-
638585
case _: # pragma: no cover
639586
raise NotImplementedError()
587+
return None
640588

641589
async def _handle_incoming(
642590
self,

0 commit comments

Comments
 (0)