From a1f20bddb294730e909bc25c123d70783297eca3 Mon Sep 17 00:00:00 2001 From: Michael Clifford Date: Thu, 27 Feb 2025 12:54:14 -0500 Subject: [PATCH] fix: ReACT Agent should be consistent with Agent Signed-off-by: Michael Clifford --- .../lib/agents/react/agent.py | 48 ++++++------------- 1 file changed, 14 insertions(+), 34 deletions(-) diff --git a/src/llama_stack_client/lib/agents/react/agent.py b/src/llama_stack_client/lib/agents/react/agent.py index bad7e46e..da833cea 100644 --- a/src/llama_stack_client/lib/agents/react/agent.py +++ b/src/llama_stack_client/lib/agents/react/agent.py @@ -3,14 +3,13 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Optional from llama_stack_client import LlamaStackClient from llama_stack_client.types.agent_create_params import AgentConfig from pydantic import BaseModel from ..agent import Agent -from ..client_tool import ClientTool from ..tool_parser import ToolParser from .prompts import DEFAULT_REACT_AGENT_SYSTEM_PROMPT_TEMPLATE @@ -37,16 +36,15 @@ class ReActAgent(Agent): def __init__( self, client: LlamaStackClient, - model: str, - builtin_toolgroups: Tuple[str] = (), - client_tools: Tuple[ClientTool] = (), + agent_config: AgentConfig, tool_parser: ToolParser = ReActToolParser(), json_response_format: bool = False, - custom_agent_config: Optional[AgentConfig] = None, ): + self.agent_config = agent_config + def get_tool_defs(): tool_defs = [] - for x in builtin_toolgroups: + for x in agent_config["toolgroups"]: tool_defs.extend( [ { @@ -64,37 +62,20 @@ def get_tool_defs(): "description": tool.get_description(), "parameters": tool.get_params_definition(), } - for tool in client_tools + for tool in agent_config["client_tools"] ] ) return tool_defs - if custom_agent_config is None: - tool_names, tool_descriptions = "", "" - tool_defs = get_tool_defs() - tool_names = ", ".join([x["name"] for x in tool_defs]) - tool_descriptions = "\n".join([f"- {x['name']}: {x}" for x in tool_defs]) - instruction = DEFAULT_REACT_AGENT_SYSTEM_PROMPT_TEMPLATE.replace("<>", tool_names).replace( - "<>", tool_descriptions - ) + tool_names, tool_descriptions = "", "" + tool_defs = get_tool_defs() + tool_names = ", ".join([x["name"] for x in tool_defs]) + tool_descriptions = "\n".join([f"- {x['name']}: {x}" for x in tool_defs]) + instruction = DEFAULT_REACT_AGENT_SYSTEM_PROMPT_TEMPLATE.replace("<>", tool_names).replace( + "<>", tool_descriptions + ) - # user default toolgroups - agent_config = AgentConfig( - model=model, - instructions=instruction, - toolgroups=builtin_toolgroups, - client_tools=[client_tool.get_tool_definition() for client_tool in client_tools], - tool_config={ - "tool_choice": "auto", - "tool_prompt_format": "json" if "3.1" in model else "python_list", - "system_message_behavior": "replace", - }, - input_shields=[], - output_shields=[], - enable_session_persistence=False, - ) - else: - agent_config = custom_agent_config + agent_config["instructions"] = instruction if json_response_format: agent_config.response_format = { @@ -105,6 +86,5 @@ def get_tool_defs(): super().__init__( client=client, agent_config=agent_config, - client_tools=client_tools, tool_parser=tool_parser, )