Skip to content

Commit 9065a68

Browse files
committed
added tests
1 parent e895e52 commit 9065a68

File tree

2 files changed

+177
-0
lines changed

2 files changed

+177
-0
lines changed

tests/client/test_sampling_callback.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import anyio
12
import pytest
23

34
from mcp.client.session import ClientSession
@@ -71,3 +72,104 @@ async def test_sampling_tool(message: str):
7172
result.content[0].text
7273
== "Error executing tool test_sampling: Sampling not supported"
7374
)
75+
76+
77+
@pytest.mark.anyio
78+
async def test_concurrent_sampling_callback():
79+
"""Test multiple concurrent sampling calls using time-sort verification."""
80+
from mcp.server.fastmcp import FastMCP
81+
82+
server = FastMCP("test")
83+
84+
# Track completion order using time-sort approach
85+
completion_order = []
86+
87+
async def sampling_callback(
88+
context: RequestContext[ClientSession, None],
89+
params: CreateMessageRequestParams,
90+
) -> CreateMessageResult:
91+
# Extract delay from the message content (e.g., "delay_0.3")
92+
message_text = params.messages[0].content.text
93+
if message_text.startswith("delay_"):
94+
delay = float(message_text.split("_")[1])
95+
# Simulate different LLM response times
96+
await anyio.sleep(delay)
97+
completion_order.append(delay)
98+
return CreateMessageResult(
99+
role="assistant",
100+
content=TextContent(type="text", text=f"Response after {delay}s"),
101+
model="test-model",
102+
stopReason="endTurn",
103+
)
104+
105+
return CreateMessageResult(
106+
role="assistant",
107+
content=TextContent(type="text", text="Default response"),
108+
model="test-model",
109+
stopReason="endTurn",
110+
)
111+
112+
@server.tool("concurrent_sampling_tool")
113+
async def concurrent_sampling_tool():
114+
"""Tool that makes multiple concurrent sampling calls."""
115+
# Use TaskGroup to make multiple concurrent sampling calls
116+
# Using out-of-order durations: 0.6s, 0.2s, 0.4s
117+
# If concurrent, should complete in order: 0.2s, 0.4s, 0.6s
118+
async with anyio.create_task_group() as tg:
119+
results = {}
120+
121+
async def make_sampling_call(call_id: str, delay: float):
122+
result = await server.get_context().session.create_message(
123+
messages=[
124+
SamplingMessage(
125+
role="user",
126+
content=TextContent(type="text", text=f"delay_{delay}"),
127+
)
128+
],
129+
max_tokens=100,
130+
)
131+
results[call_id] = result
132+
133+
# Start operations with out-of-order timing
134+
tg.start_soon(make_sampling_call, "slow_call", 0.6) # Should finish last
135+
tg.start_soon(make_sampling_call, "fast_call", 0.2) # Should finish first
136+
tg.start_soon(
137+
make_sampling_call, "medium_call", 0.4
138+
) # Should finish middle
139+
140+
# Combine results to show all completed
141+
combined_response = " | ".join(
142+
[
143+
results["slow_call"].content.text,
144+
results["fast_call"].content.text,
145+
results["medium_call"].content.text,
146+
]
147+
)
148+
149+
return combined_response
150+
151+
# Test concurrent sampling calls with time-sort verification
152+
async with create_session(
153+
server._mcp_server, sampling_callback=sampling_callback
154+
) as client_session:
155+
# Make a request that triggers multiple concurrent sampling calls
156+
result = await client_session.call_tool("concurrent_sampling_tool", {})
157+
158+
assert result.isError is False
159+
assert isinstance(result.content[0], TextContent)
160+
161+
# Verify all sampling calls completed with expected responses
162+
expected_result = (
163+
"Response after 0.6s | Response after 0.2s | Response after 0.4s"
164+
)
165+
assert result.content[0].text == expected_result
166+
167+
# Key test: verify concurrent execution using time-sort
168+
# Started in order: 0.6s, 0.2s, 0.4s
169+
# Should complete in order: 0.2s, 0.4s, 0.6s (fastest first)
170+
assert len(completion_order) == 3
171+
assert completion_order == [
172+
0.2,
173+
0.4,
174+
0.6,
175+
], f"Expected [0.2, 0.4, 0.6] but got {completion_order}"

tests/shared/test_session.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
ClientNotification,
1818
ClientRequest,
1919
EmptyResult,
20+
TextContent,
2021
)
2122

2223

@@ -181,3 +182,77 @@ async def mock_server():
181182
await ev_closed.wait()
182183
with anyio.fail_after(1):
183184
await ev_response.wait()
185+
186+
187+
@pytest.mark.anyio
188+
async def test_async_request_handling_with_taskgroup():
189+
"""Test that multiple sampling requests are handled asynchronously."""
190+
# Track completion order
191+
completion_order = []
192+
193+
def make_server() -> Server:
194+
server = Server(name="AsyncTestServer")
195+
196+
@server.call_tool()
197+
async def handle_call_tool(name: str, arguments: dict | None) -> list:
198+
nonlocal completion_order
199+
200+
if name.startswith("timed_tool"):
201+
# Extract wait time from tool name (e.g., "timed_tool_0.2")
202+
wait_time = float(name.split("_")[-1])
203+
204+
# Wait for the specified time
205+
await anyio.sleep(wait_time)
206+
207+
# Record completion
208+
completion_order.append(wait_time)
209+
210+
return [TextContent(type="text", text=f"Waited {wait_time}s")]
211+
212+
raise ValueError(f"Unknown tool: {name}")
213+
214+
@server.list_tools()
215+
async def handle_list_tools() -> list[types.Tool]:
216+
return [
217+
types.Tool(
218+
name="timed_tool_0.1",
219+
description="Tool that waits 0.1s",
220+
inputSchema={},
221+
),
222+
types.Tool(
223+
name="timed_tool_0.2",
224+
description="Tool that waits 0.2s",
225+
inputSchema={},
226+
),
227+
types.Tool(
228+
name="timed_tool_0.05",
229+
description="Tool that waits 0.05s",
230+
inputSchema={},
231+
),
232+
]
233+
234+
return server
235+
236+
async with create_connected_server_and_client_session(
237+
make_server()
238+
) as client_session:
239+
# Test basic async handling with a single request
240+
result = await client_session.send_request(
241+
ClientRequest(
242+
types.CallToolRequest(
243+
method="tools/call",
244+
params=types.CallToolRequestParams(
245+
name="timed_tool_0.1", arguments={}
246+
),
247+
)
248+
),
249+
types.CallToolResult,
250+
)
251+
252+
# Verify the request completed successfully
253+
assert result.content[0].text == "Waited 0.1s"
254+
assert len(completion_order) == 1
255+
assert completion_order[0] == 0.1
256+
257+
# Verify no pending requests remain
258+
assert len(client_session._in_flight) == 0

0 commit comments

Comments
 (0)