Skip to content

Commit 9e33007

Browse files
committed
Validate tool_use / tool_result blocks in create_message
1 parent eeab4fe commit 9e33007

File tree

2 files changed

+108
-0
lines changed

2 files changed

+108
-0
lines changed

src/mcp/server/session.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,41 @@ async def create_message(
246246
247247
Returns:
248248
The sampling result from the client.
249+
250+
Raises:
251+
McpError: If tool_use or tool_result blocks are misused
249252
"""
253+
254+
if messages and tools:
255+
last_msg_content = messages[-1].content
256+
last_content = last_msg_content if isinstance(last_msg_content, list) else [last_msg_content]
257+
has_tool_results = any(c.type == "tool_result" for c in last_content)
258+
259+
previous_content: list[types.SamplingMessageContentBlock] | None = None
260+
if len(messages) >= 2:
261+
prev_msg_content = messages[-2].content
262+
previous_content = prev_msg_content if isinstance(prev_msg_content, list) else [prev_msg_content]
263+
has_previous_tool_use = previous_content and any(c.type == "tool_use" for c in previous_content)
264+
265+
if has_tool_results:
266+
if any(c.type != "tool_result" for c in last_content):
267+
raise ValueError("The last message must contain only tool_result content if any is present")
268+
if len(messages) == 1:
269+
raise ValueError(
270+
"The last message cannot contain tool_result content if there is no previous message"
271+
)
272+
if not has_previous_tool_use:
273+
raise ValueError(
274+
"The previous message must contain tool_use content if the last message contains tool_result content"
275+
)
276+
if has_previous_tool_use and previous_content:
277+
tool_use_ids = {c.id for c in previous_content if c.type == "tool_use"}
278+
tool_result_ids = {c.toolUseId for c in last_content if c.type == "tool_result"}
279+
if tool_use_ids != tool_result_ids:
280+
raise ValueError(
281+
"The tool_result content in the last message must correspond to all tool_use content in the previous message"
282+
)
283+
250284
return await self.send_request(
251285
request=types.ServerRequest(
252286
types.CreateMessageRequest(

tests/server/test_session.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,80 @@ async def mock_client():
288288
assert ping_response_id == 42
289289

290290

291+
@pytest.mark.anyio
292+
async def test_create_message_tool_result_validation():
293+
"""Test tool_use/tool_result validation in create_message."""
294+
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
295+
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1)
296+
297+
async with (
298+
client_to_server_send,
299+
client_to_server_receive,
300+
server_to_client_send,
301+
server_to_client_receive,
302+
):
303+
session = ServerSession(
304+
client_to_server_receive,
305+
server_to_client_send,
306+
InitializationOptions(
307+
server_name="test",
308+
server_version="0.1.0",
309+
capabilities=ServerCapabilities(),
310+
),
311+
)
312+
313+
tool = types.Tool(name="test_tool", inputSchema={"type": "object"})
314+
text = types.TextContent(type="text", text="hello")
315+
tool_use = types.ToolUseContent(type="tool_use", id="call_1", name="test_tool", input={})
316+
tool_result = types.ToolResultContent(type="tool_result", toolUseId="call_1", content=[])
317+
318+
# Case 1: tool_result mixed with other content
319+
with pytest.raises(ValueError, match="only tool_result content"):
320+
await session.create_message(
321+
messages=[
322+
types.SamplingMessage(role="user", content=text),
323+
types.SamplingMessage(role="assistant", content=tool_use),
324+
types.SamplingMessage(role="user", content=[tool_result, text]), # mixed!
325+
],
326+
max_tokens=100,
327+
tools=[tool],
328+
)
329+
330+
# Case 2: tool_result without previous message
331+
with pytest.raises(ValueError, match="no previous message"):
332+
await session.create_message(
333+
messages=[types.SamplingMessage(role="user", content=tool_result)],
334+
max_tokens=100,
335+
tools=[tool],
336+
)
337+
338+
# Case 3: tool_result without previous tool_use
339+
with pytest.raises(ValueError, match="previous message must contain tool_use"):
340+
await session.create_message(
341+
messages=[
342+
types.SamplingMessage(role="user", content=text),
343+
types.SamplingMessage(role="user", content=tool_result),
344+
],
345+
max_tokens=100,
346+
tools=[tool],
347+
)
348+
349+
# Case 4: mismatched tool IDs
350+
with pytest.raises(ValueError, match="must correspond to all tool_use"):
351+
await session.create_message(
352+
messages=[
353+
types.SamplingMessage(role="user", content=text),
354+
types.SamplingMessage(role="assistant", content=tool_use),
355+
types.SamplingMessage(
356+
role="user",
357+
content=types.ToolResultContent(type="tool_result", toolUseId="wrong_id", content=[]),
358+
),
359+
],
360+
max_tokens=100,
361+
tools=[tool],
362+
)
363+
364+
291365
@pytest.mark.anyio
292366
async def test_other_requests_blocked_before_initialization():
293367
"""Test that non-ping requests are still blocked before initialization."""

0 commit comments

Comments
 (0)