Skip to content

Commit 1efe8b0

Browse files
committed
Add server→client task-augmented elicitation and sampling support
This implements the bidirectional task-augmented request pattern where the server can send task-augmented elicitation/sampling requests to the client, and the client can defer processing by returning CreateTaskResult. Key changes: - Add ExperimentalServerSessionFeatures with get_task(), get_task_result(), poll_task(), elicit_as_task(), and create_message_as_task() methods for server→client task operations - Add shared polling utility (poll_until_terminal) used by both client and server to avoid code duplication - Add elicit_as_task() and create_message_as_task() to ServerTaskContext for use inside task-augmented tool calls - Add capability checks for task-augmented elicitation/sampling in ServerSession.check_client_capability() - Add comprehensive tests for all four elicitation scenarios: 1. Normal tool call + normal elicitation 2. Normal tool call + task-augmented elicitation 3. Task-augmented tool call + normal elicitation 4. Task-augmented tool call + task-augmented elicitation The implementation correctly handles the complex bidirectional flow where the server polls the client while the client's tasks/result call is still blocking, waiting for the tool task to complete.
1 parent a28a650 commit 1efe8b0

File tree

7 files changed

+1188
-50
lines changed

7 files changed

+1188
-50
lines changed

src/mcp/client/experimental/tasks.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,8 @@
2727
from collections.abc import AsyncIterator
2828
from typing import TYPE_CHECKING, Any, TypeVar
2929

30-
import anyio
31-
3230
import mcp.types as types
33-
from mcp.shared.experimental.tasks.helpers import is_terminal
31+
from mcp.shared.experimental.tasks.polling import poll_until_terminal
3432

3533
if TYPE_CHECKING:
3634
from mcp.client.session import ClientSession
@@ -222,13 +220,5 @@ async def poll_task(self, task_id: str) -> AsyncIterator[types.GetTaskResult]:
222220
# Task is now terminal, get the result
223221
result = await session.experimental.get_task_result(task_id, CallToolResult)
224222
"""
225-
while True:
226-
status = await self.get_task(task_id)
223+
async for status in poll_until_terminal(self.get_task, task_id):
227224
yield status
228-
229-
if is_terminal(status.status):
230-
break
231-
232-
# Respect server's pollInterval hint, default to 500ms if not specified
233-
interval_ms = status.pollInterval if status.pollInterval is not None else 500
234-
await anyio.sleep(interval_ms / 1000)
Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
"""
2+
Experimental server session features for server→client task operations.
3+
4+
This module provides the server-side equivalent of ExperimentalClientFeatures,
5+
allowing the server to send task-augmented requests to the client and poll for results.
6+
7+
WARNING: These APIs are experimental and may change without notice.
8+
"""
9+
10+
from collections.abc import AsyncIterator
11+
from typing import TYPE_CHECKING, Any, TypeVar
12+
13+
import mcp.types as types
14+
from mcp.shared.experimental.tasks.polling import poll_until_terminal
15+
16+
if TYPE_CHECKING:
17+
from mcp.server.session import ServerSession
18+
19+
ResultT = TypeVar("ResultT", bound=types.Result)
20+
21+
22+
class ExperimentalServerSessionFeatures:
23+
"""
24+
Experimental server session features for server→client task operations.
25+
26+
This provides the server-side equivalent of ExperimentalClientFeatures,
27+
allowing the server to send task-augmented requests to the client and
28+
poll for results.
29+
30+
WARNING: These APIs are experimental and may change without notice.
31+
32+
Access via session.experimental:
33+
result = await session.experimental.elicit_as_task(...)
34+
"""
35+
36+
def __init__(self, session: "ServerSession") -> None:
37+
self._session = session
38+
39+
async def get_task(self, task_id: str) -> types.GetTaskResult:
40+
"""
41+
Send tasks/get to the client to get task status.
42+
43+
Args:
44+
task_id: The task identifier
45+
46+
Returns:
47+
GetTaskResult containing the task status
48+
"""
49+
return await self._session.send_request(
50+
types.ServerRequest(types.GetTaskRequest(params=types.GetTaskRequestParams(taskId=task_id))),
51+
types.GetTaskResult,
52+
)
53+
54+
async def get_task_result(
55+
self,
56+
task_id: str,
57+
result_type: type[ResultT],
58+
) -> ResultT:
59+
"""
60+
Send tasks/result to the client to retrieve the final result.
61+
62+
Args:
63+
task_id: The task identifier
64+
result_type: The expected result type
65+
66+
Returns:
67+
The task result, validated against result_type
68+
"""
69+
return await self._session.send_request(
70+
types.ServerRequest(types.GetTaskPayloadRequest(params=types.GetTaskPayloadRequestParams(taskId=task_id))),
71+
result_type,
72+
)
73+
74+
async def poll_task(self, task_id: str) -> AsyncIterator[types.GetTaskResult]:
75+
"""
76+
Poll a client task until it reaches terminal status.
77+
78+
Yields GetTaskResult for each poll, allowing the caller to react to
79+
status changes. Exits when task reaches a terminal status.
80+
81+
Respects the pollInterval hint from the client.
82+
83+
Args:
84+
task_id: The task identifier
85+
86+
Yields:
87+
GetTaskResult for each poll
88+
"""
89+
async for status in poll_until_terminal(self.get_task, task_id):
90+
yield status
91+
92+
async def elicit_as_task(
93+
self,
94+
message: str,
95+
requestedSchema: types.ElicitRequestedSchema,
96+
*,
97+
ttl: int = 60000,
98+
) -> types.ElicitResult:
99+
"""
100+
Send a task-augmented elicitation to the client and poll until complete.
101+
102+
The client will create a local task, process the elicitation asynchronously,
103+
and return the result when ready. This method handles the full flow:
104+
1. Send elicitation with task field
105+
2. Receive CreateTaskResult from client
106+
3. Poll client's task until terminal
107+
4. Retrieve and return the final ElicitResult
108+
109+
Args:
110+
message: The message to present to the user
111+
requestedSchema: Schema defining the expected response
112+
ttl: Task time-to-live in milliseconds
113+
114+
Returns:
115+
The client's elicitation response
116+
"""
117+
create_result = await self._session.send_request(
118+
types.ServerRequest(
119+
types.ElicitRequest(
120+
params=types.ElicitRequestFormParams(
121+
message=message,
122+
requestedSchema=requestedSchema,
123+
task=types.TaskMetadata(ttl=ttl),
124+
)
125+
)
126+
),
127+
types.CreateTaskResult,
128+
)
129+
130+
task_id = create_result.task.taskId
131+
132+
async for _ in self.poll_task(task_id):
133+
pass
134+
135+
return await self.get_task_result(task_id, types.ElicitResult)
136+
137+
async def create_message_as_task(
138+
self,
139+
messages: list[types.SamplingMessage],
140+
*,
141+
max_tokens: int,
142+
ttl: int = 60000,
143+
system_prompt: str | None = None,
144+
include_context: types.IncludeContext | None = None,
145+
temperature: float | None = None,
146+
stop_sequences: list[str] | None = None,
147+
metadata: dict[str, Any] | None = None,
148+
model_preferences: types.ModelPreferences | None = None,
149+
) -> types.CreateMessageResult:
150+
"""
151+
Send a task-augmented sampling request and poll until complete.
152+
153+
The client will create a local task, process the sampling request
154+
asynchronously, and return the result when ready.
155+
156+
Args:
157+
messages: The conversation messages for sampling
158+
max_tokens: Maximum tokens in the response
159+
ttl: Task time-to-live in milliseconds
160+
system_prompt: Optional system prompt
161+
include_context: Context inclusion strategy
162+
temperature: Sampling temperature
163+
stop_sequences: Stop sequences
164+
metadata: Additional metadata
165+
model_preferences: Model selection preferences
166+
167+
Returns:
168+
The sampling result from the client
169+
"""
170+
create_result = await self._session.send_request(
171+
types.ServerRequest(
172+
types.CreateMessageRequest(
173+
params=types.CreateMessageRequestParams(
174+
messages=messages,
175+
maxTokens=max_tokens,
176+
systemPrompt=system_prompt,
177+
includeContext=include_context,
178+
temperature=temperature,
179+
stopSequences=stop_sequences,
180+
metadata=metadata,
181+
modelPreferences=model_preferences,
182+
task=types.TaskMetadata(ttl=ttl),
183+
)
184+
)
185+
),
186+
types.CreateTaskResult,
187+
)
188+
189+
task_id = create_result.task.taskId
190+
191+
async for _ in self.poll_task(task_id):
192+
pass
193+
194+
return await self.get_task_result(task_id, types.CreateMessageResult)

0 commit comments

Comments
 (0)