Skip to content

Commit 98782fc

Browse files
committed
Add poll_task() method and update examples to use spec-compliant polling
The MCP Tasks spec requires clients to poll tasks/get watching for status changes, then call tasks/result when status becomes input_required to receive elicitation/sampling requests. - Add poll_task() async iterator to ExperimentalClientFeatures that yields status on each poll and respects the server's pollInterval hint - Update simple-task-client to use poll_task() instead of manual loop - Update simple-task-interactive-client to poll first, then call tasks/result on input_required per the spec pattern
1 parent b529e26 commit 98782fc

File tree

4 files changed

+195
-15
lines changed
  • examples/clients
    • simple-task-client/mcp_simple_task_client
    • simple-task-interactive-client/mcp_simple_task_interactive_client
  • src/mcp/client/experimental
  • tests/experimental/tasks/client

4 files changed

+195
-15
lines changed

examples/clients/simple-task-client/mcp_simple_task_client/main.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,18 +28,14 @@ async def run(url: str) -> None:
2828
task_id = result.task.taskId
2929
print(f"Task created: {task_id}")
3030

31-
# Poll until done
32-
while True:
33-
status = await session.experimental.get_task(task_id)
31+
# Poll until done (respects server's pollInterval hint)
32+
async for status in session.experimental.poll_task(task_id):
3433
print(f" Status: {status.status} - {status.statusMessage or ''}")
3534

36-
if status.status == "completed":
37-
break
38-
elif status.status in ("failed", "cancelled"):
39-
print(f"Task ended with status: {status.status}")
40-
return
41-
42-
await asyncio.sleep(0.5)
35+
# Check final status
36+
if status.status != "completed":
37+
print(f"Task ended with status: {status.status}")
38+
return
4339

4440
# Get the result
4541
task_result = await session.experimental.get_task_result(task_id, CallToolResult)

examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/main.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,10 @@
1-
"""Simple interactive task client demonstrating elicitation and sampling responses."""
1+
"""Simple interactive task client demonstrating elicitation and sampling responses.
2+
3+
This example demonstrates the spec-compliant polling pattern:
4+
1. Poll tasks/get watching for status changes
5+
2. On input_required, call tasks/result to receive elicitation/sampling requests
6+
3. Continue until terminal status, then retrieve final result
7+
"""
28

39
import asyncio
410
from typing import Any
@@ -88,8 +94,17 @@ async def run(url: str) -> None:
8894
task_id = result.task.taskId
8995
print(f"Task created: {task_id}")
9096

91-
# get_task_result() delivers elicitation requests and blocks until complete
92-
final = await session.experimental.get_task_result(task_id, CallToolResult)
97+
# Poll until terminal, calling tasks/result on input_required
98+
async for status in session.experimental.poll_task(task_id):
99+
print(f"[Poll] Status: {status.status}")
100+
if status.status == "input_required":
101+
# Server needs input - tasks/result delivers the elicitation request
102+
final = await session.experimental.get_task_result(task_id, CallToolResult)
103+
break
104+
else:
105+
# poll_task exited due to terminal status
106+
final = await session.experimental.get_task_result(task_id, CallToolResult)
107+
93108
print(f"Result: {get_text(final)}")
94109

95110
# Demo 2: Sampling (write_haiku)
@@ -100,8 +115,15 @@ async def run(url: str) -> None:
100115
task_id = result.task.taskId
101116
print(f"Task created: {task_id}")
102117

103-
# get_task_result() delivers sampling requests and blocks until complete
104-
final = await session.experimental.get_task_result(task_id, CallToolResult)
118+
# Poll until terminal, calling tasks/result on input_required
119+
async for status in session.experimental.poll_task(task_id):
120+
print(f"[Poll] Status: {status.status}")
121+
if status.status == "input_required":
122+
final = await session.experimental.get_task_result(task_id, CallToolResult)
123+
break
124+
else:
125+
final = await session.experimental.get_task_result(task_id, CallToolResult)
126+
105127
print(f"Result:\n{get_text(final)}")
106128

107129

src/mcp/client/experimental/tasks.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,13 @@
2424
await session.experimental.cancel_task(task_id)
2525
"""
2626

27+
from collections.abc import AsyncIterator
2728
from typing import TYPE_CHECKING, Any, TypeVar
2829

30+
import anyio
31+
2932
import mcp.types as types
33+
from mcp.shared.experimental.tasks.helpers import is_terminal
3034

3135
if TYPE_CHECKING:
3236
from mcp.client.session import ClientSession
@@ -191,3 +195,40 @@ async def cancel_task(self, task_id: str) -> types.CancelTaskResult:
191195
),
192196
types.CancelTaskResult,
193197
)
198+
199+
async def poll_task(self, task_id: str) -> AsyncIterator[types.GetTaskResult]:
200+
"""
201+
Poll a task until it reaches a terminal status.
202+
203+
Yields GetTaskResult for each poll, allowing the caller to react to
204+
status changes (e.g., handle input_required). Exits when task reaches
205+
a terminal status (completed, failed, cancelled).
206+
207+
Respects the pollInterval hint from the server.
208+
209+
Args:
210+
task_id: The task identifier
211+
212+
Yields:
213+
GetTaskResult for each poll
214+
215+
Example:
216+
async for status in session.experimental.poll_task(task_id):
217+
print(f"Status: {status.status}")
218+
if status.status == "input_required":
219+
# Handle elicitation request via tasks/result
220+
pass
221+
222+
# Task is now terminal, get the result
223+
result = await session.experimental.get_task_result(task_id, CallToolResult)
224+
"""
225+
while True:
226+
status = await self.get_task(task_id)
227+
yield status
228+
229+
if is_terminal(status.status):
230+
break
231+
232+
# Respect server's pollInterval hint, default to 500ms if not specified
233+
interval_ms = status.pollInterval if status.pollInterval is not None else 500
234+
await anyio.sleep(interval_ms / 1000)
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
"""Tests for poll_task async iterator."""
2+
3+
from collections.abc import Callable, Coroutine
4+
from datetime import datetime, timezone
5+
from typing import Any
6+
from unittest.mock import AsyncMock
7+
8+
import pytest
9+
10+
from mcp.client.experimental.tasks import ExperimentalClientFeatures
11+
from mcp.types import GetTaskResult, TaskStatus
12+
13+
14+
def make_task_result(
15+
status: TaskStatus = "working",
16+
poll_interval: int = 0,
17+
task_id: str = "test-task",
18+
status_message: str | None = None,
19+
) -> GetTaskResult:
20+
"""Create GetTaskResult with sensible defaults."""
21+
now = datetime.now(timezone.utc)
22+
return GetTaskResult(
23+
taskId=task_id,
24+
status=status,
25+
statusMessage=status_message,
26+
createdAt=now,
27+
lastUpdatedAt=now,
28+
ttl=60000,
29+
pollInterval=poll_interval,
30+
)
31+
32+
33+
def make_status_sequence(
34+
*statuses: TaskStatus,
35+
task_id: str = "test-task",
36+
) -> Callable[[str], Coroutine[Any, Any, GetTaskResult]]:
37+
"""Create mock get_task that returns statuses in sequence."""
38+
status_iter = iter(statuses)
39+
40+
async def mock_get_task(tid: str) -> GetTaskResult:
41+
return make_task_result(status=next(status_iter), task_id=tid)
42+
43+
return mock_get_task
44+
45+
46+
@pytest.fixture
47+
def mock_session() -> AsyncMock:
48+
return AsyncMock()
49+
50+
51+
@pytest.fixture
52+
def features(mock_session: AsyncMock) -> ExperimentalClientFeatures:
53+
return ExperimentalClientFeatures(mock_session)
54+
55+
56+
@pytest.mark.anyio
57+
async def test_poll_task_yields_until_completed(features: ExperimentalClientFeatures) -> None:
58+
"""poll_task yields each status until terminal."""
59+
features.get_task = make_status_sequence("working", "working", "completed") # type: ignore[method-assign]
60+
61+
statuses = [s.status async for s in features.poll_task("test-task")]
62+
63+
assert statuses == ["working", "working", "completed"]
64+
65+
66+
@pytest.mark.anyio
67+
@pytest.mark.parametrize("terminal_status", ["completed", "failed", "cancelled"])
68+
async def test_poll_task_exits_on_terminal(features: ExperimentalClientFeatures, terminal_status: TaskStatus) -> None:
69+
"""poll_task exits immediately when task is already terminal."""
70+
features.get_task = make_status_sequence(terminal_status) # type: ignore[method-assign]
71+
72+
statuses = [s.status async for s in features.poll_task("test-task")]
73+
74+
assert statuses == [terminal_status]
75+
76+
77+
@pytest.mark.anyio
78+
async def test_poll_task_continues_through_input_required(features: ExperimentalClientFeatures) -> None:
79+
"""poll_task yields input_required and continues (non-terminal)."""
80+
features.get_task = make_status_sequence("working", "input_required", "working", "completed") # type: ignore[method-assign]
81+
82+
statuses = [s.status async for s in features.poll_task("test-task")]
83+
84+
assert statuses == ["working", "input_required", "working", "completed"]
85+
86+
87+
@pytest.mark.anyio
88+
async def test_poll_task_passes_task_id(features: ExperimentalClientFeatures) -> None:
89+
"""poll_task passes correct task_id to get_task."""
90+
received_ids: list[str] = []
91+
92+
async def mock_get_task(task_id: str) -> GetTaskResult:
93+
received_ids.append(task_id)
94+
return make_task_result(status="completed", task_id=task_id)
95+
96+
features.get_task = mock_get_task # type: ignore[method-assign]
97+
98+
_ = [s async for s in features.poll_task("my-task-123")]
99+
100+
assert received_ids == ["my-task-123"]
101+
102+
103+
@pytest.mark.anyio
104+
async def test_poll_task_yields_full_result(features: ExperimentalClientFeatures) -> None:
105+
"""poll_task yields complete GetTaskResult objects."""
106+
107+
async def mock_get_task(task_id: str) -> GetTaskResult:
108+
return make_task_result(
109+
status="completed",
110+
task_id=task_id,
111+
status_message="All done!",
112+
)
113+
114+
features.get_task = mock_get_task # type: ignore[method-assign]
115+
116+
results = [r async for r in features.poll_task("test-task")]
117+
118+
assert len(results) == 1
119+
assert results[0].status == "completed"
120+
assert results[0].statusMessage == "All done!"
121+
assert results[0].taskId == "test-task"

0 commit comments

Comments
 (0)