diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py index 76841e27..bb62cd00 100644 --- a/src/llama_stack_client/lib/agents/agent.py +++ b/src/llama_stack_client/lib/agents/agent.py @@ -439,7 +439,7 @@ async def create_turn( raise Exception("Turn did not complete") return chunks[-1].event.payload.turn - async def _run_tool(self, tool_calls: List[ToolCall]) -> ToolResponseMessage: + async def _run_tool(self, tool_calls: List[ToolCall]) -> ToolResponseParam: assert len(tool_calls) == 1, "Only one tool call is supported" tool_call = tool_calls[0] @@ -464,20 +464,18 @@ async def _run_tool(self, tool_calls: List[ToolCall]) -> ToolResponseMessage: tool_name=tool_call.tool_name, kwargs=tool_call.arguments, ) - tool_response_message = ToolResponseMessage( + tool_response = ToolResponseParam( call_id=tool_call.call_id, tool_name=tool_call.tool_name, content=tool_result.content, - role="tool", ) - return tool_response_message + return tool_response # cannot find tools - return ToolResponseMessage( + return ToolResponseParam( call_id=tool_call.call_id, tool_name=tool_call.tool_name, content=f"Unknown tool `{tool_call.tool_name}` was called.", - role="tool", ) async def _create_turn_streaming( @@ -524,14 +522,14 @@ async def _create_turn_streaming( yield chunk # run the tools - tool_response_message = await self._run_tool(tool_calls) + tool_response = await self._run_tool(tool_calls) # pass it to next iteration turn_response = await self.client.agents.turn.resume( agent_id=self.agent_id, session_id=session_id or self.session_id[-1], turn_id=turn_id, - tool_responses=[tool_response_message], + tool_responses=[tool_response], stream=True, ) n_iter += 1