@@ -200,10 +200,13 @@ def create_session(self, session_name: str) -> str:
200200 self .sessions .append (self .session_id )
201201 return self .session_id
202202
203- def _run_tool (self , tool_calls : List [ToolCall ]) -> ToolResponseParam :
204- assert len (tool_calls ) == 1 , "Only one tool call is supported"
205- tool_call = tool_calls [0 ]
203+ def _run_tool (self , tool_calls : List [ToolCall ]) -> List [ToolResponseParam ]:
204+ responses = []
205+ for tool_call in tool_calls :
206+ responses .append (self ._run_single_tool (tool_call ))
207+ return responses
206208
209+ def _run_single_tool (self , tool_call : ToolCall ) -> ToolResponseParam :
207210 # custom client tools
208211 if tool_call .tool_name in self .client_tools :
209212 tool = self .client_tools [tool_call .tool_name ]
@@ -227,12 +230,11 @@ def _run_tool(self, tool_calls: List[ToolCall]) -> ToolResponseParam:
227230 tool_name = tool_call .tool_name ,
228231 kwargs = tool_call .arguments ,
229232 )
230- tool_response = ToolResponseParam (
233+ return ToolResponseParam (
231234 call_id = tool_call .call_id ,
232235 tool_name = tool_call .tool_name ,
233236 content = tool_result .content ,
234237 )
235- return tool_response
236238
237239 # cannot find tools
238240 return ToolResponseParam (
@@ -302,14 +304,14 @@ def _create_turn_streaming(
302304 yield chunk
303305
304306 # run the tools
305- tool_response = self ._run_tool (tool_calls )
307+ tool_responses = self ._run_tool (tool_calls )
306308
307309 # pass it to next iteration
308310 turn_response = self .client .agents .turn .resume (
309311 agent_id = self .agent_id ,
310312 session_id = session_id or self .session_id [- 1 ],
311313 turn_id = turn_id ,
312- tool_responses = [ tool_response ] ,
314+ tool_responses = tool_responses ,
313315 stream = True ,
314316 )
315317 n_iter += 1
@@ -439,10 +441,13 @@ async def create_turn(
439441 raise Exception ("Turn did not complete" )
440442 return chunks [- 1 ].event .payload .turn
441443
442- async def _run_tool (self , tool_calls : List [ToolCall ]) -> ToolResponseMessage :
443- assert len (tool_calls ) == 1 , "Only one tool call is supported"
444- tool_call = tool_calls [0 ]
444+ async def _run_tool (self , tool_calls : List [ToolCall ]) -> List [ToolResponseMessage ]:
445+ responses = []
446+ for tool_call in tool_calls :
447+ responses .append (await self ._run_single_tool (tool_call ))
448+ return responses
445449
450+ async def _run_single_tool (self , tool_call : ToolCall ) -> ToolResponseMessage :
446451 # custom client tools
447452 if tool_call .tool_name in self .client_tools :
448453 tool = self .client_tools [tool_call .tool_name ]
@@ -464,13 +469,12 @@ async def _run_tool(self, tool_calls: List[ToolCall]) -> ToolResponseMessage:
464469 tool_name = tool_call .tool_name ,
465470 kwargs = tool_call .arguments ,
466471 )
467- tool_response_message = ToolResponseMessage (
472+ return ToolResponseMessage (
468473 call_id = tool_call .call_id ,
469474 tool_name = tool_call .tool_name ,
470475 content = tool_result .content ,
471476 role = "tool" ,
472477 )
473- return tool_response_message
474478
475479 # cannot find tools
476480 return ToolResponseMessage (
@@ -524,14 +528,14 @@ async def _create_turn_streaming(
524528 yield chunk
525529
526530 # run the tools
527- tool_response_message = await self ._run_tool (tool_calls )
531+ tool_response_messages = await self ._run_tool (tool_calls )
528532
529533 # pass it to next iteration
530534 turn_response = await self .client .agents .turn .resume (
531535 agent_id = self .agent_id ,
532536 session_id = session_id or self .session_id [- 1 ],
533537 turn_id = turn_id ,
534- tool_responses = [ tool_response_message ] ,
538+ tool_responses = tool_response_messages ,
535539 stream = True ,
536540 )
537541 n_iter += 1
0 commit comments