Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 14 additions & 34 deletions src/llama_stack_client/lib/agents/react/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,13 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, Optional

from llama_stack_client import LlamaStackClient
from llama_stack_client.types.agent_create_params import AgentConfig
from pydantic import BaseModel

from ..agent import Agent
from ..client_tool import ClientTool
from ..tool_parser import ToolParser
from .prompts import DEFAULT_REACT_AGENT_SYSTEM_PROMPT_TEMPLATE

Expand All @@ -37,16 +36,15 @@ class ReActAgent(Agent):
def __init__(
self,
client: LlamaStackClient,
model: str,
builtin_toolgroups: Tuple[str] = (),
client_tools: Tuple[ClientTool] = (),
agent_config: AgentConfig,
tool_parser: ToolParser = ReActToolParser(),
json_response_format: bool = False,
custom_agent_config: Optional[AgentConfig] = None,
):
self.agent_config = agent_config

def get_tool_defs():
tool_defs = []
for x in builtin_toolgroups:
for x in agent_config["toolgroups"]:
tool_defs.extend(
[
{
Expand All @@ -64,37 +62,20 @@ def get_tool_defs():
"description": tool.get_description(),
"parameters": tool.get_params_definition(),
}
for tool in client_tools
for tool in agent_config["client_tools"]
]
)
return tool_defs

if custom_agent_config is None:
tool_names, tool_descriptions = "", ""
tool_defs = get_tool_defs()
tool_names = ", ".join([x["name"] for x in tool_defs])
tool_descriptions = "\n".join([f"- {x['name']}: {x}" for x in tool_defs])
instruction = DEFAULT_REACT_AGENT_SYSTEM_PROMPT_TEMPLATE.replace("<<tool_names>>", tool_names).replace(
"<<tool_descriptions>>", tool_descriptions
)
tool_names, tool_descriptions = "", ""
tool_defs = get_tool_defs()
tool_names = ", ".join([x["name"] for x in tool_defs])
tool_descriptions = "\n".join([f"- {x['name']}: {x}" for x in tool_defs])
instruction = DEFAULT_REACT_AGENT_SYSTEM_PROMPT_TEMPLATE.replace("<<tool_names>>", tool_names).replace(
"<<tool_descriptions>>", tool_descriptions
)

# user default toolgroups
agent_config = AgentConfig(
model=model,
instructions=instruction,
toolgroups=builtin_toolgroups,
client_tools=[client_tool.get_tool_definition() for client_tool in client_tools],
tool_config={
"tool_choice": "auto",
"tool_prompt_format": "json" if "3.1" in model else "python_list",
"system_message_behavior": "replace",
},
input_shields=[],
output_shields=[],
enable_session_persistence=False,
)
else:
agent_config = custom_agent_config
agent_config["instructions"] = instruction

if json_response_format:
agent_config.response_format = {
Expand All @@ -105,6 +86,5 @@ def get_tool_defs():
super().__init__(
client=client,
agent_config=agent_config,
client_tools=client_tools,
tool_parser=tool_parser,
)