Skip to content

Commit 8b31b61

Browse files
committed
Clean up task code: fix error codes, imports, type annotations, and add test fixtures
Source code: - Use INVALID_REQUEST constant instead of hardcoded -32600 - Move inline imports to module level in task_context.py - Use cast() instead of type: ignore in resolver.py Test improvements: - Add ClientTestStreams dataclass and client_streams fixture - Add store fixture for test_store.py with automatic cleanup - Remove unused ClientTaskContext dataclass - Add STREAM_BUFFER_SIZE constant for magic number
1 parent e200fbd commit 8b31b61

File tree

4 files changed

+87
-149
lines changed

4 files changed

+87
-149
lines changed

src/mcp/server/experimental/task_context.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,12 @@
1515
from mcp.server.session import ServerSession
1616
from mcp.shared.exceptions import McpError
1717
from mcp.shared.experimental.tasks.context import TaskContext
18+
from mcp.shared.experimental.tasks.helpers import create_task_state
1819
from mcp.shared.experimental.tasks.message_queue import QueuedMessage, TaskMessageQueue
1920
from mcp.shared.experimental.tasks.resolver import Resolver
2021
from mcp.shared.experimental.tasks.store import TaskStore
2122
from mcp.types import (
23+
INVALID_REQUEST,
2224
TASK_STATUS_INPUT_REQUIRED,
2325
TASK_STATUS_WORKING,
2426
ClientCapabilities,
@@ -35,6 +37,7 @@
3537
SamplingMessage,
3638
ServerNotification,
3739
Task,
40+
TaskMetadata,
3841
TaskStatusNotification,
3942
TaskStatusNotificationParams,
4043
)
@@ -90,14 +93,9 @@ def __init__(
9093
if task is not None and task_id is not None:
9194
raise ValueError("Provide either task or task_id, not both")
9295

93-
# If task_id provided, we need to get the task from the store synchronously
94-
# This is a limitation - for async task lookup, use task= parameter
96+
# If task_id provided, create a minimal task object
97+
# This is for backwards compatibility with tests that pass task_id
9598
if task is None:
96-
# Create a minimal task object - the real task state comes from the store
97-
# This is for backwards compatibility with tests that pass task_id
98-
from mcp.shared.experimental.tasks.helpers import create_task_state
99-
from mcp.types import TaskMetadata
100-
10199
task = create_task_state(TaskMetadata(ttl=None), task_id=task_id)
102100

103101
self._ctx = TaskContext(task=task, store=store)
@@ -191,7 +189,7 @@ def _check_elicitation_capability(self) -> None:
191189
if not self._session.check_client_capability(ClientCapabilities(elicitation=ElicitationCapability())):
192190
raise McpError(
193191
ErrorData(
194-
code=-32600, # INVALID_REQUEST
192+
code=INVALID_REQUEST,
195193
message="Client does not support elicitation capability",
196194
)
197195
)
@@ -201,7 +199,7 @@ def _check_sampling_capability(self) -> None:
201199
if not self._session.check_client_capability(ClientCapabilities(sampling=SamplingCapability())):
202200
raise McpError(
203201
ErrorData(
204-
code=-32600, # INVALID_REQUEST
202+
code=INVALID_REQUEST,
205203
message="Client does not support sampling capability",
206204
)
207205
)

src/mcp/shared/experimental/tasks/resolver.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
to another without depending on asyncio.Future.
66
"""
77

8-
from typing import Generic, TypeVar
8+
from typing import Generic, TypeVar, cast
99

1010
import anyio
1111

@@ -52,7 +52,8 @@ async def wait(self) -> T:
5252
await self._event.wait()
5353
if self._exception is not None:
5454
raise self._exception
55-
return self._value # type: ignore[return-value]
55+
# If we reach here, set_result() was called, so _value is set
56+
return cast(T, self._value)
5657

5758
def done(self) -> bool:
5859
"""Return True if the resolver has been completed."""

tests/experimental/tasks/client/test_handlers.py

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,14 @@
1010
client -> server task requests.
1111
"""
1212

13-
from dataclasses import dataclass, field
13+
from collections.abc import AsyncIterator
14+
from dataclasses import dataclass
1415

1516
import anyio
1617
import pytest
1718
from anyio import Event
1819
from anyio.abc import TaskGroup
20+
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
1921

2022
import mcp.types as types
2123
from mcp.client.experimental.task_handlers import ExperimentalTaskHandlers
@@ -43,14 +45,47 @@
4345
TextContent,
4446
)
4547

48+
# Buffer size for test streams
49+
STREAM_BUFFER_SIZE = 10
50+
4651

4752
@dataclass
48-
class ClientTaskContext:
49-
"""Context for managing client-side tasks during tests."""
53+
class ClientTestStreams:
54+
"""Bidirectional message streams for client/server communication in tests."""
55+
56+
server_send: MemoryObjectSendStream[SessionMessage]
57+
server_receive: MemoryObjectReceiveStream[SessionMessage]
58+
client_send: MemoryObjectSendStream[SessionMessage]
59+
client_receive: MemoryObjectReceiveStream[SessionMessage]
60+
5061

51-
task_group: TaskGroup
52-
store: InMemoryTaskStore
53-
task_done_events: dict[str, Event] = field(default_factory=lambda: {})
62+
@pytest.fixture
63+
async def client_streams() -> AsyncIterator[ClientTestStreams]:
64+
"""Create bidirectional message streams for client tests.
65+
66+
Automatically closes all streams after the test completes.
67+
"""
68+
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](
69+
STREAM_BUFFER_SIZE
70+
)
71+
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](
72+
STREAM_BUFFER_SIZE
73+
)
74+
75+
streams = ClientTestStreams(
76+
server_send=server_to_client_send,
77+
server_receive=client_to_server_receive,
78+
client_send=client_to_server_send,
79+
client_receive=server_to_client_receive,
80+
)
81+
82+
yield streams
83+
84+
# Cleanup
85+
await server_to_client_send.aclose()
86+
await server_to_client_receive.aclose()
87+
await client_to_server_send.aclose()
88+
await client_to_server_receive.aclose()
5489

5590

5691
@pytest.mark.anyio

0 commit comments

Comments
 (0)