diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py
index fd8db879..56d8907b 100644
--- a/src/llama_stack_client/lib/agents/agent.py
+++ b/src/llama_stack_client/lib/agents/agent.py
@@ -12,7 +12,9 @@
 from llama_stack_client.types.agents.turn_create_params import Document, Toolgroup
 from llama_stack_client.types.agents.turn_create_response import AgentTurnResponseStreamChunk
+
 from .client_tool import ClientTool
+from .output_parser import OutputParser
 
 DEFAULT_MAX_ITER = 10
 
@@ -23,14 +25,18 @@ def __init__(
         self,
         client: LlamaStackClient,
         agent_config: AgentConfig,
         client_tools: Tuple[ClientTool] = (),
-        memory_bank_id: Optional[str] = None,
+        output_parser: Optional[OutputParser] = None,
     ):
         self.client = client
         self.agent_config = agent_config
         self.agent_id = self._create_agent(agent_config)
         self.client_tools = {t.get_name(): t for t in client_tools}
         self.sessions = []
-        self.memory_bank_id = memory_bank_id
+        self.output_parser = output_parser
+        self.builtin_tools = {}
+        for tg in agent_config.get("toolgroups", []):
+            for tool in self.client.tools.list(toolgroup_id=tg):
+                self.builtin_tools[tool.identifier] = tool
 
     def _create_agent(self, agent_config: AgentConfig) -> int:
         agentic_system_create_response = self.client.agents.create(
@@ -48,28 +54,56 @@ def create_session(self, session_name: str) -> int:
         self.sessions.append(self.session_id)
         return self.session_id
 
+    def _process_chunk(self, chunk: AgentTurnResponseStreamChunk) -> None:
+        if chunk.event.payload.event_type != "turn_complete":
+            return
+        message = chunk.event.payload.turn.output_message
+
+        if self.output_parser:
+            parsed_message = self.output_parser.parse(message)
+            message = parsed_message
+
     def _has_tool_call(self, chunk: AgentTurnResponseStreamChunk) -> bool:
         if chunk.event.payload.event_type != "turn_complete":
             return False
         message = chunk.event.payload.turn.output_message
         if message.stop_reason == "out_of_tokens":
             return False
+
         return len(message.tool_calls) > 0
 
     def _run_tool(self, chunk: AgentTurnResponseStreamChunk) -> ToolResponseMessage:
         message = chunk.event.payload.turn.output_message
         tool_call = message.tool_calls[0]
-        if tool_call.tool_name not in self.client_tools:
-            return ToolResponseMessage(
+
+        # custom client tools
+        if tool_call.tool_name in self.client_tools:
+            tool = self.client_tools[tool_call.tool_name]
+            result_messages = tool.run([message])
+            next_message = result_messages[0]
+            return next_message
+
+        # builtin tools executed by tool_runtime
+        if tool_call.tool_name in self.builtin_tools:
+            tool_result = self.client.tool_runtime.invoke_tool(
+                tool_name=tool_call.tool_name,
+                kwargs=tool_call.arguments,
+            )
+            tool_response_message = ToolResponseMessage(
                 call_id=tool_call.call_id,
                 tool_name=tool_call.tool_name,
-                content=f"Unknown tool `{tool_call.tool_name}` was called.",
-                role="ipython",
+                content=tool_result.content,
+                role="tool",
             )
-        tool = self.client_tools[tool_call.tool_name]
-        result_messages = tool.run([message])
-        next_message = result_messages[0]
-        return next_message
+            return tool_response_message
+
+        # cannot find the tool
+        return ToolResponseMessage(
+            call_id=tool_call.call_id,
+            tool_name=tool_call.tool_name,
+            content=f"Unknown tool `{tool_call.tool_name}` was called.",
+            role="tool",
+        )
 
     def create_turn(
         self,
@@ -115,6 +149,7 @@ def _create_turn_streaming(
             # by default, we stop after the first turn
            stop = True
            for chunk in response:
+                self._process_chunk(chunk)
                if hasattr(chunk, "error"):
                    yield chunk
                    return
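For orientation, here is a minimal sketch of how the reworked `Agent` is driven after this change. The server URL, model id, and toolgroup id are placeholders, and the `create_turn` call follows the existing streaming API rather than anything added in this patch:

```python
from llama_stack_client import LlamaStackClient
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.types.agent_create_params import AgentConfig

client = LlamaStackClient(base_url="http://localhost:5000")  # placeholder URL

agent_config = AgentConfig(
    model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model id
    instructions="You are a helpful assistant.",
    toolgroups=["builtin::websearch"],  # resolved into self.builtin_tools at construction time
    input_shields=[],
    output_shields=[],
    enable_session_persistence=False,
)

# No output_parser here: turn_complete messages pass through unchanged.
agent = Agent(client, agent_config)
session_id = agent.create_session("demo-session")

# Builtin tool calls emitted by the model are now executed client-side via
# client.tool_runtime.invoke_tool inside _run_tool.
for chunk in agent.create_turn(
    messages=[{"role": "user", "content": "What is the current weather in Tokyo?"}],
    session_id=session_id,
):
    print(chunk)
```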
diff --git a/src/llama_stack_client/lib/agents/output_parser.py b/src/llama_stack_client/lib/agents/output_parser.py
new file mode 100644
index 00000000..1097d6d5
--- /dev/null
+++ b/src/llama_stack_client/lib/agents/output_parser.py
@@ -0,0 +1,48 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from abc import abstractmethod
+
+from llama_stack_client.types.agents.turn import CompletionMessage
+
+
+class OutputParser:
+    """
+    Abstract base class for parsing agent responses. Implement this class to customize how
+    agent outputs are processed and transformed.
+
+    This class allows developers to define custom parsing logic for agent responses,
+    which can be useful for:
+    - Extracting specific information from the response
+    - Formatting or structuring the output in a specific way
+    - Validating or sanitizing the agent's response
+
+    To use this class:
+    1. Create a subclass of OutputParser
+    2. Implement the `parse` method
+    3. Pass your parser instance to the Agent's constructor
+
+    Example:
+        class MyCustomParser(OutputParser):
+            def parse(self, output_message: CompletionMessage) -> CompletionMessage:
+                # Add your custom parsing logic here
+                return processed_message
+
+    Methods:
+        parse(output_message: CompletionMessage) -> CompletionMessage:
+            Abstract method that must be implemented by subclasses to process
+            the agent's response.
+
+            Args:
+                output_message (CompletionMessage): The response message from the agent turn
+
+            Returns:
+                CompletionMessage: The processed/transformed response message
+    """
+
+    @abstractmethod
+    def parse(self, output_message: CompletionMessage) -> CompletionMessage:
+        raise NotImplementedError
diff --git a/src/llama_stack_client/lib/agents/react/__init__.py b/src/llama_stack_client/lib/agents/react/__init__.py
new file mode 100644
index 00000000..756f351d
--- /dev/null
+++ b/src/llama_stack_client/lib/agents/react/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
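A slightly fuller, runnable version of the subclassing pattern the docstring sketches; `StripPrefixParser` and its stripping rule are purely illustrative:

```python
from llama_stack_client.lib.agents.output_parser import OutputParser
from llama_stack_client.types.agents.turn import CompletionMessage


class StripPrefixParser(OutputParser):
    """Illustrative parser that removes a leading "Answer:" prefix from the turn output."""

    def parse(self, output_message: CompletionMessage) -> CompletionMessage:
        text = str(output_message.content)
        if text.startswith("Answer:"):
            # Mutate the message in place; Agent._process_chunk feeds each
            # turn_complete message through this hook.
            output_message.content = text[len("Answer:"):].lstrip()
        return output_message


# Wired up via the new constructor argument:
# agent = Agent(client, agent_config, output_parser=StripPrefixParser())
```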
+ """ + + def __init__( + self, + client: LlamaStackClient, + model: str, + builtin_toolgroups: Tuple[str] = (), + client_tools: Tuple[ClientTool] = (), + output_parser: OutputParser = ReActOutputParser(), + json_response_format: bool = False, + custom_agent_config: Optional[AgentConfig] = None, + ): + def get_tool_defs(): + tool_defs = [] + for x in builtin_toolgroups: + tool_defs.extend( + [ + { + "name": tool.identifier, + "description": tool.description, + "parameters": tool.parameters, + } + for tool in client.tools.list(toolgroup_id=x) + ] + ) + tool_defs.extend( + [ + { + "name": tool.get_name(), + "description": tool.get_description(), + "parameters": tool.get_params_definition(), + } + for tool in client_tools + ] + ) + return tool_defs + + if custom_agent_config is None: + tool_names, tool_descriptions = "", "" + tool_defs = get_tool_defs() + tool_names = ", ".join([x["name"] for x in tool_defs]) + tool_descriptions = "\n".join([f"- {x['name']}: {x}" for x in tool_defs]) + instruction = DEFAULT_REACT_AGENT_SYSTEM_PROMPT_TEMPLATE.replace("<>", tool_names).replace( + "<>", tool_descriptions + ) + + # user default toolgroups + agent_config = AgentConfig( + model=model, + instructions=instruction, + toolgroups=builtin_toolgroups, + client_tools=[client_tool.get_tool_definition() for client_tool in client_tools], + tool_choice="auto", + # TODO: refactor this to use SystemMessageBehaviour.replace + tool_prompt_format="json", + input_shields=[], + output_shields=[], + enable_session_persistence=False, + ) + + if json_response_format: + agent_config.response_format = { + "type": "json_schema", + "json_schema": ReActOutput.model_json_schema(), + } + + super().__init__( + client=client, + agent_config=agent_config, + client_tools=client_tools, + output_parser=output_parser, + ) diff --git a/src/llama_stack_client/lib/agents/react/output_parser.py b/src/llama_stack_client/lib/agents/react/output_parser.py new file mode 100644 index 00000000..6e4861a9 --- /dev/null +++ b/src/llama_stack_client/lib/agents/react/output_parser.py @@ -0,0 +1,46 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
diff --git a/src/llama_stack_client/lib/agents/react/output_parser.py b/src/llama_stack_client/lib/agents/react/output_parser.py
new file mode 100644
index 00000000..6e4861a9
--- /dev/null
+++ b/src/llama_stack_client/lib/agents/react/output_parser.py
@@ -0,0 +1,46 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from pydantic import BaseModel, ValidationError
+from typing import Dict, Any, Optional
+from ..output_parser import OutputParser
+from llama_stack_client.types.shared.completion_message import CompletionMessage
+from llama_stack_client.types.shared.tool_call import ToolCall
+
+import uuid
+
+
+class Action(BaseModel):
+    tool_name: str
+    tool_params: Dict[str, Any]
+
+
+class ReActOutput(BaseModel):
+    thought: str
+    action: Optional[Action] = None
+    answer: Optional[str] = None
+
+
+class ReActOutputParser(OutputParser):
+    def parse(self, output_message: CompletionMessage) -> CompletionMessage:
+        response_text = str(output_message.content)
+        try:
+            react_output = ReActOutput.model_validate_json(response_text)
+        except ValidationError as e:
+            print(f"Error parsing action: {e}")
+            return output_message
+
+        if react_output.answer:
+            return output_message
+
+        if react_output.action:
+            tool_name = react_output.action.tool_name
+            tool_params = react_output.action.tool_params
+            if tool_name and tool_params:
+                call_id = str(uuid.uuid4())
+                output_message.tool_calls = [ToolCall(call_id=call_id, tool_name=tool_name, arguments=tool_params)]
+
+        return output_message
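As a rough illustration of the parser's behavior (assuming the `CompletionMessage` constructor accepts the fields used elsewhere in this client), a well-formed ReAct JSON response is promoted into a `ToolCall`, so the agent loop dispatches the tool instead of ending the turn:

```python
from llama_stack_client.lib.agents.react.output_parser import ReActOutputParser
from llama_stack_client.types.shared.completion_message import CompletionMessage

raw = CompletionMessage(
    role="assistant",
    content=(
        '{"thought": "I should look up the population first.", '
        '"action": {"tool_name": "search", "tool_params": {"query": "Population Shanghai"}}, '
        '"answer": null}'
    ),
    stop_reason="end_of_turn",
    tool_calls=[],
)

parsed = ReActOutputParser().parse(raw)
print(parsed.tool_calls[0].tool_name)   # "search"
print(parsed.tool_calls[0].arguments)   # {"query": "Population Shanghai"}
```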
diff --git a/src/llama_stack_client/lib/agents/react/prompts.py b/src/llama_stack_client/lib/agents/react/prompts.py
new file mode 100644
index 00000000..4ce228f2
--- /dev/null
+++ b/src/llama_stack_client/lib/agents/react/prompts.py
@@ -0,0 +1,151 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+DEFAULT_REACT_AGENT_SYSTEM_PROMPT_TEMPLATE = """
+You are an expert assistant who can solve any task using tool calls. You will be given a task to solve as best you can.
+To do so, you have been given access to the following tools: <<tool_names>>
+
+You must always respond in the following JSON format:
+{
+    "thought": $THOUGHT_PROCESS,
+    "action": {
+        "tool_name": $TOOL_NAME,
+        "tool_params": $TOOL_PARAMS
+    },
+    "answer": $ANSWER
+}
+
+Specifically, this JSON should have a `thought` key, an `action` key, and an `answer` key.
+
+The `action` key should specify $TOOL_NAME, the name of the tool to use, and the `tool_params` key should specify the parameters to pass as input to the tool.
+
+Make sure to have the $TOOL_PARAMS as a dictionary in the right format for the tool you are using, and do not put variable names as input if you can find the right values.
+
+You should always think about one action to take, and have the `thought` key contain your thought process about this action.
+If the tool responds, the tool will return an observation containing the result of the action.
+... (this Thought/Action/Observation can repeat N times, you should take several steps when needed. The action key must only use a SINGLE tool at a time.)
+
+You can use the result of the previous action as input for the next action.
+The observation will always be the response from calling the tool: it can represent a file, like "image_1.jpg". You do not need to generate it, it will be provided to you.
+Then you can use it as input for the next action. You can do it for instance as follows:
+
+Observation: "image_1.jpg"
+{
+    "thought": "I need to transform the image that I received in the previous observation to make it green.",
+    "action": {
+        "tool_name": "image_transformer",
+        "tool_params": {"image": "image_1.jpg"}
+    },
+    "answer": null
+}
+
+
+To provide the final answer to the task, use the `answer` key. It is the only way to complete the task, or you will be stuck in a loop. So your final output should look like this:
+Observation: "your observation"
+
+{
+    "thought": "your thought process",
+    "action": null,
+    "answer": "insert your final answer here"
+}
+
+Here are a few examples using notional tools:
+---
+Task: "Generate an image of the oldest person in this document."
+
+Your Response:
+{
+    "thought": "I will proceed step by step and use the following tools: `document_qa` to find the oldest person in the document, then `image_generator` to generate an image according to the answer.",
+    "action": {
+        "tool_name": "document_qa",
+        "tool_params": {"document": "document.pdf", "question": "Who is the oldest person mentioned?"}
+    },
+    "answer": null
+}
+
+Your Observation: "The oldest person in the document is John Doe, a 55 year old lumberjack living in Newfoundland."
+
+Your Response:
+{
+    "thought": "I will now generate an image showcasing the oldest person.",
+    "action": {
+        "tool_name": "image_generator",
+        "tool_params": {"prompt": "A portrait of John Doe, a 55-year-old man living in Canada."}
+    },
+    "answer": null
+}
+
+Your Observation: "image.png"
+
+{
+    "thought": "I will now return the generated image.",
+    "action": null,
+    "answer": "image.png"
+}
+
+---
+Task: "What is the result of the following operation: 5 + 3 + 1294.678?"
+
+Your Response:
+{
+    "thought": "I will use the python code evaluator to compute the result of the operation and then return the final answer.",
+    "action": {
+        "tool_name": "python_interpreter",
+        "tool_params": {"code": "5 + 3 + 1294.678"}
+    },
+    "answer": null
+}
+
+Your Observation: 1302.678
+
+{
+    "thought": "Now that I know the result, I will return it.",
+    "action": null,
+    "answer": 1302.678
+}
+
+---
+Task: "Which city has the highest population, Guangzhou or Shanghai?"
+
+Your Response:
+{
+    "thought": "I need to get the populations for both cities and compare them: I will use the tool `search` to get the population of both cities.",
+    "action": {
+        "tool_name": "search",
+        "tool_params": {"query": "Population Guangzhou"}
+    },
+    "answer": null
+}
+
+Your Observation: ['Guangzhou has a population of 15 million inhabitants as of 2021.']
+
+Your Response:
+{
+    "thought": "Now let's get the population of Shanghai using the tool 'search'.",
+    "action": {
+        "tool_name": "search",
+        "tool_params": {"query": "Population Shanghai"}
+    },
+    "answer": null
+}
+
+Your Observation: "26 million (2019)"
+
+Your Response:
+{
+    "thought": "Now I know that Shanghai has a larger population. Let's return the result.",
+    "action": null,
+    "answer": "Shanghai"
+}
+
+The above examples used notional tools that might not exist for you. You only have access to these tools:
+<<tool_descriptions>>
+
+Here are the rules you should always follow to solve your task:
+1. ALWAYS answer in the JSON format with keys "thought", "action", "answer", else you will fail.
+2. Always use the right arguments for the tools. Never use variable names in the 'tool_params' field, use the value instead.
+3. Call a tool only when needed: do not call the search agent if you do not need information, try to solve the task yourself.
+4. Never re-do a tool call that you previously did with the exact same parameters.
+5. Observations will be provided to you, no need to generate them.
+
+Now Begin! If you solve the task correctly, you will receive a reward of $1,000,000.
+"""