@@ -300,66 +300,65 @@ async def test_create_message_tool_result_validation():
300300 server_to_client_send ,
301301 server_to_client_receive ,
302302 ):
303- session = ServerSession (
303+ async with ServerSession (
304304 client_to_server_receive ,
305305 server_to_client_send ,
306306 InitializationOptions (
307307 server_name = "test" ,
308308 server_version = "0.1.0" ,
309309 capabilities = ServerCapabilities (),
310310 ),
311- )
312-
313- tool = types .Tool (name = "test_tool" , inputSchema = {"type" : "object" })
314- text = types .TextContent (type = "text" , text = "hello" )
315- tool_use = types .ToolUseContent (type = "tool_use" , id = "call_1" , name = "test_tool" , input = {})
316- tool_result = types .ToolResultContent (type = "tool_result" , toolUseId = "call_1" , content = [])
317-
318- # Case 1: tool_result mixed with other content
319- with pytest .raises (ValueError , match = "only tool_result content" ):
320- await session .create_message (
321- messages = [
322- types .SamplingMessage (role = "user" , content = text ),
323- types .SamplingMessage (role = "assistant" , content = tool_use ),
324- types .SamplingMessage (role = "user" , content = [tool_result , text ]), # mixed!
325- ],
326- max_tokens = 100 ,
327- tools = [tool ],
328- )
311+ ) as session :
312+ tool = types .Tool (name = "test_tool" , inputSchema = {"type" : "object" })
313+ text = types .TextContent (type = "text" , text = "hello" )
314+ tool_use = types .ToolUseContent (type = "tool_use" , id = "call_1" , name = "test_tool" , input = {})
315+ tool_result = types .ToolResultContent (type = "tool_result" , toolUseId = "call_1" , content = [])
316+
317+ # Case 1: tool_result mixed with other content
318+ with pytest .raises (ValueError , match = "only tool_result content" ):
319+ await session .create_message (
320+ messages = [
321+ types .SamplingMessage (role = "user" , content = text ),
322+ types .SamplingMessage (role = "assistant" , content = tool_use ),
323+ types .SamplingMessage (role = "user" , content = [tool_result , text ]), # mixed!
324+ ],
325+ max_tokens = 100 ,
326+ tools = [tool ],
327+ )
329328
330- # Case 2: tool_result without previous message
331- with pytest .raises (ValueError , match = "no previous message " ):
332- await session .create_message (
333- messages = [types .SamplingMessage (role = "user" , content = tool_result )],
334- max_tokens = 100 ,
335- tools = [tool ],
336- )
329+ # Case 2: tool_result without previous message
330+ with pytest .raises (ValueError , match = "not matching any tool_use " ):
331+ await session .create_message (
332+ messages = [types .SamplingMessage (role = "user" , content = tool_result )],
333+ max_tokens = 100 ,
334+ tools = [tool ],
335+ )
337336
338- # Case 3: tool_result without previous tool_use
339- with pytest .raises (ValueError , match = "previous message must contain tool_use" ):
340- await session .create_message (
341- messages = [
342- types .SamplingMessage (role = "user" , content = text ),
343- types .SamplingMessage (role = "user" , content = tool_result ),
344- ],
345- max_tokens = 100 ,
346- tools = [tool ],
347- )
337+ # Case 3: tool_result without previous tool_use
338+ with pytest .raises (ValueError , match = "not matching any tool_use" ):
339+ await session .create_message (
340+ messages = [
341+ types .SamplingMessage (role = "user" , content = text ),
342+ types .SamplingMessage (role = "user" , content = tool_result ),
343+ ],
344+ max_tokens = 100 ,
345+ tools = [tool ],
346+ )
348347
349- # Case 4: mismatched tool IDs
350- with pytest .raises (ValueError , match = "must correspond to all tool_use" ):
351- await session .create_message (
352- messages = [
353- types .SamplingMessage (role = "user" , content = text ),
354- types .SamplingMessage (role = "assistant" , content = tool_use ),
355- types .SamplingMessage (
356- role = "user" ,
357- content = types .ToolResultContent (type = "tool_result" , toolUseId = "wrong_id" , content = []),
358- ),
359- ],
360- max_tokens = 100 ,
361- tools = [tool ],
362- )
348+ # Case 4: mismatched tool IDs
349+ with pytest .raises (ValueError , match = "ids of tool_result blocks and tool_use blocks " ):
350+ await session .create_message (
351+ messages = [
352+ types .SamplingMessage (role = "user" , content = text ),
353+ types .SamplingMessage (role = "assistant" , content = tool_use ),
354+ types .SamplingMessage (
355+ role = "user" ,
356+ content = types .ToolResultContent (type = "tool_result" , toolUseId = "wrong_id" , content = []),
357+ ),
358+ ],
359+ max_tokens = 100 ,
360+ tools = [tool ],
361+ )
363362
364363
365364@pytest .mark .anyio
0 commit comments