@@ -288,6 +288,80 @@ async def mock_client():
288288 assert ping_response_id == 42
289289
290290
@pytest.mark.anyio
async def test_create_message_tool_result_validation():
    """Test tool_use/tool_result validation in create_message.

    Each case passes ``ServerSession.create_message`` a message list whose
    tool_use/tool_result pairing is malformed and expects a ``ValueError``
    whose text matches the given pattern:

    1. a message mixing tool_result with other content,
    2. a tool_result in the very first message,
    3. a tool_result whose preceding message has no tool_use,
    4. a tool_result whose ``toolUseId`` does not match the preceding tool_use.
    """
    server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
    client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1)

    # Enter all four stream ends so they are closed when the test finishes.
    async with (
        client_to_server_send,
        client_to_server_receive,
        server_to_client_send,
        server_to_client_receive,
    ):
        # The session is constructed but never entered; every call below is
        # expected to raise ValueError.
        session = ServerSession(
            client_to_server_receive,
            server_to_client_send,
            InitializationOptions(
                server_name="test",
                server_version="0.1.0",
                capabilities=ServerCapabilities(),
            ),
        )

        # Shared fixtures: tool_use and tool_result carry the same id ("call_1"),
        # so they form a matching pair wherever both are used.
        tool = types.Tool(name="test_tool", inputSchema={"type": "object"})
        text = types.TextContent(type="text", text="hello")
        tool_use = types.ToolUseContent(type="tool_use", id="call_1", name="test_tool", input={})
        tool_result = types.ToolResultContent(type="tool_result", toolUseId="call_1", content=[])

        # Case 1: tool_result mixed with other content
        with pytest.raises(ValueError, match="only tool_result content"):
            await session.create_message(
                messages=[
                    types.SamplingMessage(role="user", content=text),
                    types.SamplingMessage(role="assistant", content=tool_use),
                    types.SamplingMessage(role="user", content=[tool_result, text]),  # mixed!
                ],
                max_tokens=100,
                tools=[tool],
            )

        # Case 2: tool_result without previous message
        with pytest.raises(ValueError, match="no previous message"):
            await session.create_message(
                messages=[types.SamplingMessage(role="user", content=tool_result)],
                max_tokens=100,
                tools=[tool],
            )

        # Case 3: tool_result without previous tool_use
        with pytest.raises(ValueError, match="previous message must contain tool_use"):
            await session.create_message(
                messages=[
                    types.SamplingMessage(role="user", content=text),
                    types.SamplingMessage(role="user", content=tool_result),
                ],
                max_tokens=100,
                tools=[tool],
            )

        # Case 4: mismatched tool IDs ("wrong_id" vs. the tool_use id "call_1")
        with pytest.raises(ValueError, match="must correspond to all tool_use"):
            await session.create_message(
                messages=[
                    types.SamplingMessage(role="user", content=text),
                    types.SamplingMessage(role="assistant", content=tool_use),
                    types.SamplingMessage(
                        role="user",
                        content=types.ToolResultContent(type="tool_result", toolUseId="wrong_id", content=[]),
                    ),
                ],
                max_tokens=100,
                tools=[tool],
            )
363+
364+
291365@pytest .mark .anyio
292366async def test_other_requests_blocked_before_initialization ():
293367 """Test that non-ping requests are still blocked before initialization."""
0 commit comments