From 3290e87c3b2f642550d03a857482b26d10adcbf6 Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Mon, 3 Feb 2025 16:00:20 -0800
Subject: [PATCH 01/12] react agent tmp

---
 src/llama_stack_client/lib/agents/agent.py | 61 ++++++++++++++++++++++
 1 file changed, 61 insertions(+)

diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py
index fd8db879..45836949 100644
--- a/src/llama_stack_client/lib/agents/agent.py
+++ b/src/llama_stack_client/lib/agents/agent.py
@@ -12,11 +12,55 @@
 from llama_stack_client.types.agents.turn_create_params import Document, Toolgroup
 from llama_stack_client.types.agents.turn_create_response import AgentTurnResponseStreamChunk
 
+import re
+import json
+from typing import Dict, Any
+
+from llama_stack_client.types.shared.tool_call import ToolCall
+
 from .client_tool import ClientTool
 
 DEFAULT_MAX_ITER = 10
 
 
+def maybe_extract_action(text: str) -> Optional[Tuple[str, Dict[str, Any]]]:
+    """
+    Extract action name and parameters from the text format:
+
+    Thought: <thought_process>
+
+    Action:
+    {
+        "action": <action_name>,
+        "action_input": <action_input>
+    }
+
+    Args:
+        text (str): Input text containing the action block
+
+    Returns:
+        Optional[Tuple[str, Dict[str, Any]]]: Tuple of (action_name, action_parameters)
+
+    Note:
+        Returns None if the action block cannot be parsed or is missing required fields
+    """
+    try:
+        # Find the action block using regex
+        action_pattern = r'Action:\s*{\s*"action":\s*"([^"]+)",\s*"action_input":\s*({[^}]+})\s*}'
+        match = re.search(action_pattern, text, re.DOTALL)
+
+        if not match:
+            raise ValueError("Could not find valid action block in text")
+
+        action_name = match.group(1)
+        action_params = json.loads(match.group(2))
+
+        return action_name, action_params
+    except (ValueError, json.JSONDecodeError) as e:
+        print(f"Error parsing action: {e}")
+        return None
+
+
 class Agent:
     def __init__(
         self,
@@ -54,6 +98,19 @@ def _has_tool_call(self, chunk: AgentTurnResponseStreamChunk) -> bool:
         message = chunk.event.payload.turn.output_message
         if message.stop_reason == "out_of_tokens":
             return False
+
+        # Has tool call if it is using the ReAct pattern
+        action = maybe_extract_action(message.content)
+        if action and action[0] in self.client_tools:
+            message.tool_calls = [
+                ToolCall(
+                    call_id="random-id",
+                    tool_name=action[0],
+                    arguments=action[1],
+                )
+            ]
+            print(f"!!Action: {action}")
+
         return len(message.tool_calls) > 0
 
     def _run_tool(self, chunk: AgentTurnResponseStreamChunk) -> ToolResponseMessage:
@@ -121,6 +178,10 @@ def _create_turn_streaming(
             elif not self._has_tool_call(chunk):
                 yield chunk
             else:
+                from rich.pretty import pprint
+
+                print("Running Tools...")
+                pprint(chunk)
                 next_message = self._run_tool(chunk)
                 yield next_message
 

From a68dc2b09029b043fe4d5454d64e71a982305dc8 Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Tue, 4 Feb 2025 11:34:51 -0800
Subject: [PATCH 02/12] output parser

---
 src/llama_stack_client/lib/agents/agent.py | 66 ++-----------------
 .../lib/agents/output_parser.py            | 22 +++++++
 2 files changed, 29 insertions(+), 59 deletions(-)
 create mode 100644 src/llama_stack_client/lib/agents/output_parser.py

diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py
index 45836949..1d96cac9 100644
--- a/src/llama_stack_client/lib/agents/agent.py
+++ b/src/llama_stack_client/lib/agents/agent.py
@@ -12,55 +12,13 @@
 from llama_stack_client.types.agents.turn_create_params import Document, Toolgroup
 from llama_stack_client.types.agents.turn_create_response import AgentTurnResponseStreamChunk
 
-import re
-import json
-from typing import Dict, Any
-
-from llama_stack_client.types.shared.tool_call import ToolCall
-
 from .client_tool import ClientTool
+from .output_parser import OutputParser
 
 DEFAULT_MAX_ITER = 10
 
 
-def maybe_extract_action(text: str) -> Optional[Tuple[str, Dict[str, Any]]]:
-    """
-    Extract action name and parameters from the text format:
-
-    Thought: <thought_process>
-
-    Action:
-    {
-        "action": <action_name>,
-        "action_input": <action_input>
-    }
-
-    Args:
-        text (str): Input text containing the action block
-
-    Returns:
-        Optional[Tuple[str, Dict[str, Any]]]: Tuple of (action_name, action_parameters)
-
-    Note:
-        Returns None if the action block cannot be parsed or is missing required fields
-    """
-    try:
-        # Find the action block using regex
-        action_pattern = r'Action:\s*{\s*"action":\s*"([^"]+)",\s*"action_input":\s*({[^}]+})\s*}'
-        match = re.search(action_pattern, text, re.DOTALL)
-
-        if not match:
-            raise ValueError("Could not find valid action block in text")
-
-        action_name = match.group(1)
-        action_params = json.loads(match.group(2))
-
-        return action_name, action_params
-    except (ValueError, json.JSONDecodeError) as e:
-        print(f"Error parsing action: {e}")
-        return None
-
-
 class Agent:
     def __init__(
         self,
@@ -68,6 +26,7 @@ def __init__(
         agent_config: AgentConfig,
         client_tools: Tuple[ClientTool] = (),
         memory_bank_id: Optional[str] = None,
+        output_parser: Optional[OutputParser] = None,
     ):
         self.client = client
         self.agent_config = agent_config
@@ -75,6 +34,7 @@ def __init__(
         self.client_tools = {t.get_name(): t for t in client_tools}
         self.sessions = []
         self.memory_bank_id = memory_bank_id
+        self.output_parser = output_parser
 
     def _create_agent(self, agent_config: AgentConfig) -> int:
         agentic_system_create_response = self.client.agents.create(
@@ -99,17 +59,9 @@ def _has_tool_call(self, chunk: AgentTurnResponseStreamChunk) -> bool:
         if message.stop_reason == "out_of_tokens":
             return False
 
-        # Has tool call if it is using the ReAct pattern
-        action = maybe_extract_action(message.content)
-        if action and action[0] in self.client_tools:
-            message.tool_calls = [
-                ToolCall(
-                    call_id="random-id",
-                    tool_name=action[0],
-                    arguments=action[1],
-                )
-            ]
-            print(f"!!Action: {action}")
+        if self.output_parser:
+            parsed_message = self.output_parser.parse(message)
+            message = parsed_message
 
         return len(message.tool_calls) > 0
 
@@ -121,7 +73,7 @@ def _run_tool(self, chunk: AgentTurnResponseStreamChunk) -> ToolResponseMessage:
             call_id=tool_call.call_id,
             tool_name=tool_call.tool_name,
             content=f"Unknown tool `{tool_call.tool_name}` was called.",
-            role="ipython",
+            role="tool",
         )
         tool = self.client_tools[tool_call.tool_name]
         result_messages = tool.run([message])
@@ -178,10 +130,6 @@ def _create_turn_streaming(
             elif not self._has_tool_call(chunk):
                 yield chunk
             else:
-                from rich.pretty import pprint
-
-                print("Running Tools...")
-                pprint(chunk)
                 next_message = self._run_tool(chunk)
                 yield next_message
 
diff --git a/src/llama_stack_client/lib/agents/output_parser.py b/src/llama_stack_client/lib/agents/output_parser.py
new file mode 100644
index 00000000..19bb298a
--- /dev/null
+++ b/src/llama_stack_client/lib/agents/output_parser.py
@@ -0,0 +1,22 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from abc import abstractmethod
+
+from llama_stack_client.types.agents.turn import CompletionMessage
+
+
+class OutputParser:
+    """
+    Developers can define their own response output parser to parse the response from the agent turn.
+
+    Developers need to implement the `parse` method to parse the response from the agent turn.
+    The return value should be a CompletionMessage object with the parsed result populated in content and tool_calls.
+    """
+
+    @abstractmethod
+    def parse(self, output_message: CompletionMessage) -> CompletionMessage:
+        raise NotImplementedError

From 4307a65e36f160961bfa7e6593688b7e622bc421 Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Tue, 4 Feb 2025 13:14:26 -0800
Subject: [PATCH 03/12] doc

---
 .../lib/agents/output_parser.py | 32 +++++++++++++++++--
 1 file changed, 29 insertions(+), 3 deletions(-)

diff --git a/src/llama_stack_client/lib/agents/output_parser.py b/src/llama_stack_client/lib/agents/output_parser.py
index 19bb298a..1097d6d5 100644
--- a/src/llama_stack_client/lib/agents/output_parser.py
+++ b/src/llama_stack_client/lib/agents/output_parser.py
@@ -11,10 +11,36 @@
 class OutputParser:
     """
-    Developers can define their own response output parser to parse the response from the agent turn.
+    Abstract base class for parsing agent responses. Implement this class to customize how
+    agent outputs are processed and transformed.
 
-    Developers need to implement the `parse` method to parse the response from the agent turn.
-    The return value should be a CompletionMessage object with the parsed result populated in content and tool_calls.
+    This class allows developers to define custom parsing logic for agent responses,
+    which can be useful for:
+    - Extracting specific information from the response
+    - Formatting or structuring the output in a specific way
+    - Validating or sanitizing the agent's response
+
+    To use this class:
+    1. Create a subclass of OutputParser
+    2. Implement the `parse` method
+    3. Pass your parser instance to the Agent's constructor
+
+    Example:
+        class MyCustomParser(OutputParser):
+            def parse(self, output_message: CompletionMessage) -> CompletionMessage:
+                # Add your custom parsing logic here
+                return processed_message
+
+    Methods:
+        parse(output_message: CompletionMessage) -> CompletionMessage:
+            Abstract method that must be implemented by subclasses to process
+            the agent's response.
+
+            Args:
+                output_message (CompletionMessage): The response message from the agent turn
+
+            Returns:
+                CompletionMessage: The processed/transformed response message
     """
 
     @abstractmethod

From 1d0e56c11b056915cc0506650c77e8166ccac650 Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Tue, 4 Feb 2025 17:47:05 -0800
Subject: [PATCH 04/12] refactor

---
 src/llama_stack_client/lib/agents/agent.py | 14 ++++++++++----
 1 file changed, 10 insertions(+), 4 deletions(-)

diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py
index 1d96cac9..cb1c19a0 100644
--- a/src/llama_stack_client/lib/agents/agent.py
+++ b/src/llama_stack_client/lib/agents/agent.py
@@ -52,17 +52,22 @@ def create_session(self, session_name: str) -> int:
         self.sessions.append(self.session_id)
         return self.session_id
 
-    def _has_tool_call(self, chunk: AgentTurnResponseStreamChunk) -> bool:
+    def _process_chunk(self, chunk: AgentTurnResponseStreamChunk) -> None:
         if chunk.event.payload.event_type != "turn_complete":
-            return False
+            return
         message = chunk.event.payload.turn.output_message
-        if message.stop_reason == "out_of_tokens":
-            return False
 
         if self.output_parser:
             parsed_message = self.output_parser.parse(message)
             message = parsed_message
 
+    def _has_tool_call(self, chunk: AgentTurnResponseStreamChunk) -> bool:
+        if chunk.event.payload.event_type != "turn_complete":
+            return False
+        message = chunk.event.payload.turn.output_message
+        if message.stop_reason == "out_of_tokens":
+            return False
+
         return len(message.tool_calls) > 0
 
     def _run_tool(self, chunk: AgentTurnResponseStreamChunk) -> ToolResponseMessage:
@@ -124,6 +129,7 @@ def _create_turn_streaming(
         # by default, we stop after the first turn
         stop = True
         for chunk in response:
+            self._process_chunk(chunk)
             if hasattr(chunk, "error"):
                 yield chunk
                 return

From e4dc8e21cf17801885c36d968a4c8d10ac516c32 Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Wed, 5 Feb 2025 10:37:49 -0800
Subject: [PATCH 05/12] react prompt and output parser

---
 .../lib/agents/react/__init__.py      |   5 +
 .../lib/agents/react/output_parser.py |  34 ++++
 .../lib/agents/react/prompts.py       | 152 ++++++++++++++++++
 3 files changed, 191 insertions(+)
 create mode 100644 src/llama_stack_client/lib/agents/react/__init__.py
 create mode 100644 src/llama_stack_client/lib/agents/react/output_parser.py
 create mode 100644 src/llama_stack_client/lib/agents/react/prompts.py

diff --git a/src/llama_stack_client/lib/agents/react/__init__.py b/src/llama_stack_client/lib/agents/react/__init__.py
new file mode 100644
index 00000000..756f351d
--- /dev/null
+++ b/src/llama_stack_client/lib/agents/react/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
diff --git a/src/llama_stack_client/lib/agents/react/output_parser.py b/src/llama_stack_client/lib/agents/react/output_parser.py
new file mode 100644
index 00000000..7d764e42
--- /dev/null
+++ b/src/llama_stack_client/lib/agents/react/output_parser.py
@@ -0,0 +1,34 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from ..output_parser import OutputParser
+from llama_stack_client.types.shared.completion_message import CompletionMessage
+from llama_stack_client.types.shared.tool_call import ToolCall
+
+import json
+import uuid
+
+
+class ReActOutputParser(OutputParser):
+    def parse(self, output_message: CompletionMessage) -> CompletionMessage:
+        response_text = str(output_message.content)
+        try:
+            response_json = json.loads(response_text)
+        except json.JSONDecodeError as e:
+            print(f"Error parsing action: {e}")
+            return output_message
+
+        if response_json.get("answer", None):
+            return output_message
+
+        if response_json.get("action", None):
+            tool_name = response_json["action"].get("tool_name", None)
+            tool_params = response_json["action"].get("tool_params", None)
+            if tool_name and tool_params:
+                call_id = str(uuid.uuid4())
+                output_message.tool_calls = [ToolCall(call_id=call_id, tool_name=tool_name, arguments=tool_params)]
+
+        return output_message
diff --git a/src/llama_stack_client/lib/agents/react/prompts.py b/src/llama_stack_client/lib/agents/react/prompts.py
new file mode 100644
index 00000000..44bbae00
--- /dev/null
+++ b/src/llama_stack_client/lib/agents/react/prompts.py
@@ -0,0 +1,152 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+DEFAULT_REACT_AGENT_SYSTEM_PROMPT_TEMPLATE = """
+You are an expert assistant who can solve any task using tool calls. You will be given a task to solve as best you can.
+To do so, you have been given access to the following tools: <<tool_names>>
+
+You must always respond in the following JSON format:
+{
+    "thought": $THOUGHT_PROCESS,
+    "action": {
+        "tool_name": $TOOL_NAME,
+        "tool_params": $TOOL_PARAMS
+    },
+    "answer": $ANSWER
+}
+
+Specifically, this JSON should have a `thought` key, an `action` key, and an `answer` key.
+
+The `action` key should specify $TOOL_NAME, the name of the tool to use, and the `tool_params` key should specify the parameters passed as input to the tool.
+
+Make sure to have the $TOOL_PARAMS as a dictionary in the right format for the tool you are using, and do not put variable names as input if you can find the right values.
+
+You should always think about one action to take, and have the `thought` key contain your thought process about this action.
+If the tool responds, the tool will return an observation containing the result of the action.
+... (this Thought/Action/Observation can repeat N times, you should take several steps when needed. The action key must only use a SINGLE tool at a time.)
+
+You can use the result of the previous action as input for the next action.
+The observation will always be a string: it can represent a file, like "image_1.jpg".
+Then you can use it as input for the next action. You can do it for instance as follows:
+
+Observation: "image_1.jpg"
+{
+    "thought": "I need to transform the image that I received in the previous observation to make it green.",
+    "action": {
+        "tool_name": "image_transformer",
+        "tool_params": {"image": "image_1.jpg"}
+    },
+    "answer": null
+}
+
+
+To provide the final answer to the task, use the `answer` key. It is the only way to complete the task, else you will be stuck in a loop. So your final output should look like this:
+Observation: "your observation"
+
+{
+    "thought": "your thought process",
+    "action": null,
+    "answer": "insert your final answer here"
+}
+
+Here are a few examples using notional tools:
+---
+Task: "Generate an image of the oldest person in this document."
+
+{
+    "thought": "I will proceed step by step and use the following tools: `document_qa` to find the oldest person in the document, then `image_generator` to generate an image according to the answer.",
+    "action": {
+        "tool_name": "document_qa",
+        "tool_params": {"document": "document.pdf", "question": "Who is the oldest person mentioned?"}
+    },
+    "answer": null
+}
+Observation: "The oldest person in the document is John Doe, a 55 year old lumberjack living in Newfoundland."
+
+{
+    "thought": "I will now generate an image showcasing the oldest person.",
+    "action": {
+        "tool_name": "image_generator",
+        "tool_params": {"prompt": "A portrait of John Doe, a 55-year-old man living in Canada."}
+    },
+    "answer": null
+}
+Observation: "image.png"
+
+{
+    "thought": "I will now return the generated image.",
+    "action": null,
+    "answer": "image.png"
+}
+
+---
+Task: "What is the result of the following operation: 5 + 3 + 1294.678?"
+
+{
+    "thought": "I will use python code evaluator to compute the result of the operation and then return the final answer using the `final_answer` tool",
+    "action": {
+        "tool_name": "python_interpreter",
+        "tool_params": {"code": "5 + 3 + 1294.678"}
+    },
+    "answer": null
+}
+Observation: 1302.678
+
+{
+    "thought": "Now that I know the result, I will now return it.",
+    "action": null,
+    "answer": 1302.678
+}
+
+---
+Task: "Which city has the highest population, Guangzhou or Shanghai?"
+
+{
+    "thought": "I need to get the populations for both cities and compare them: I will use the tool `search` to get the population of both cities.",
+    "action": {
+        "tool_name": "search",
+        "tool_params": {"query": "Population Guangzhou"}
+    },
+    "answer": null
+}
+Observation: ['Guangzhou has a population of 15 million inhabitants as of 2021.']
+
+{
+    "thought": "Now let's get the population of Shanghai using the tool 'search'.",
+    "action": {
+        "tool_name": "search",
+        "tool_params": {"query": "Population Shanghai"}
+    },
+    "answer": null
+}
+Observation: "26 million (2019)"
+
+{
+    "thought": "Now I know that Shanghai has a larger population. Let's return the result.",
+    "action": null,
+    "answer": "Shanghai"
+}
+
+The above examples use notional tools that might not exist for you. You only have access to these tools:
+<<tool_descriptions>>
+
+Here are the rules you should always follow to solve your task:
+1. ALWAYS answer in the JSON format with keys "observation", "thought", "action", "answer", else you will fail.
+2. Always use the right arguments for the tools. Never use variable names in the 'tool_params' field, use the value instead.
+3. Call a tool only when needed: do not call the search agent if you do not need information, try to solve the task yourself.
+4. Never re-do a tool call that you previously did with the exact same parameters.
+5. Observations will be provided to you, no need to generate them
+
+Now Begin! If you solve the task correctly, you will receive a reward of $1,000,000.
+""" + +from pydantic import BaseModel + + +class PromptTemplate(BaseModel): + system_prompt_template: str = DEFAULT_REACT_AGENT_SYSTEM_PROMPT_TEMPLATE + tool_names: str + tool_descriptions: str From 0c5a39afecbf56775dd4f2ce6c871b533029439a Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Wed, 5 Feb 2025 10:51:22 -0800 Subject: [PATCH 06/12] tweak prompt --- .../lib/agents/react/output_parser.py | 3 +++ .../lib/agents/react/prompts.py | 21 ++++++++++++------- 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/src/llama_stack_client/lib/agents/react/output_parser.py b/src/llama_stack_client/lib/agents/react/output_parser.py index 7d764e42..b6722612 100644 --- a/src/llama_stack_client/lib/agents/react/output_parser.py +++ b/src/llama_stack_client/lib/agents/react/output_parser.py @@ -11,12 +11,15 @@ import json import uuid +from rich.pretty import pprint + class ReActOutputParser(OutputParser): def parse(self, output_message: CompletionMessage) -> CompletionMessage: response_text = str(output_message.content) try: response_json = json.loads(response_text) + pprint(response_json) except json.JSONDecodeError as e: print(f"Error parsing action: {e}") return output_message diff --git a/src/llama_stack_client/lib/agents/react/prompts.py b/src/llama_stack_client/lib/agents/react/prompts.py index 44bbae00..b146e605 100644 --- a/src/llama_stack_client/lib/agents/react/prompts.py +++ b/src/llama_stack_client/lib/agents/react/prompts.py @@ -29,7 +29,7 @@ ... (this Thought/Action/Observation can repeat N times, you should take several steps when needed. The action key must only use a SINGLE tool at a time.) You can use the result of the previous action as input for the next action. -The observation will always be a string: it can represent a file, like "image_1.jpg". +The observation will always be the response from calling the tool: it can represent a file, like "image_1.jpg". You do not need to generate them, it will be provided to you. Then you can use it as input for the next action. You can do it for instance as follows: Observation: "image_1.jpg" @@ -56,6 +56,7 @@ --- Task: "Generate an image of the oldest person in this document." +Your Response: { "thought": "I will proceed step by step and use the following tools: `document_qa` to find the oldest person in the document, then `image_generator` to generate an image according to the answer.", "action": { @@ -64,8 +65,10 @@ }, "answer": null } -Observation: "The oldest person in the document is John Doe, a 55 year old lumberjack living in Newfoundland." +Your Observation: "The oldest person in the document is John Doe, a 55 year old lumberjack living in Newfoundland." + +Your Response: { "thought": "I will now generate an image showcasing the oldest person.", "action": { @@ -74,7 +77,7 @@ }, "answer": null } -Observation: "image.png" +Your Observation: "image.png" { "thought": "I will now return the generated image.", @@ -85,6 +88,7 @@ --- Task: "What is the result of the following operation: 5 + 3 + 1294.678?" +Your Response: { "thought": "I will use python code evaluator to compute the result of the operation and then return the final answer using the `final_answer` tool", "action": { @@ -93,7 +97,7 @@ }, "answer": null } -Observation: 1302.678 +Your Observation: 1302.678 { "thought": "Now that I know the result, I will now return it.", @@ -104,6 +108,7 @@ --- Task: "Which city has the highest population , Guangzhou or Shanghai?" 
 
+Your Response:
 {
     "thought": "I need to get the populations for both cities and compare them: I will use the tool `search` to get the population of both cities.",
     "action": {
@@ -112,8 +117,9 @@
     },
     "answer": null
 }
-Observation: ['Guangzhou has a population of 15 million inhabitants as of 2021.']
+Your Observation: ['Guangzhou has a population of 15 million inhabitants as of 2021.']
 
+Your Response:
 {
     "thought": "Now let's get the population of Shanghai using the tool 'search'.",
     "action": {
@@ -122,8 +128,9 @@
     },
     "answer": null
 }
-Observation: "26 million (2019)"
+Your Observation: "26 million (2019)"
 
+Your Response:
 {
     "thought": "Now I know that Shanghai has a larger population. Let's return the result.",
     "action": null,
@@ -134,7 +141,7 @@
 <<tool_descriptions>>
 
 Here are the rules you should always follow to solve your task:
-1. ALWAYS answer in the JSON format with keys "observation", "thought", "action", "answer", else you will fail.
+1. ALWAYS answer in the JSON format with keys "thought", "action", "answer", else you will fail.
 2. Always use the right arguments for the tools. Never use variable names in the 'tool_params' field, use the value instead.
 3. Call a tool only when needed: do not call the search agent if you do not need information, try to solve the task yourself.
 4. Never re-do a tool call that you previously did with the exact same parameters.
 5. Observations will be provided to you, no need to generate them

From 080e34b5eb67975195b5834efa7ed24f750ee8b6 Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Wed, 5 Feb 2025 10:55:18 -0800
Subject: [PATCH 07/12] tweak prompt

---
 src/llama_stack_client/lib/agents/react/output_parser.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/src/llama_stack_client/lib/agents/react/output_parser.py b/src/llama_stack_client/lib/agents/react/output_parser.py
index b6722612..7d764e42 100644
--- a/src/llama_stack_client/lib/agents/react/output_parser.py
+++ b/src/llama_stack_client/lib/agents/react/output_parser.py
@@ -11,15 +11,12 @@
 import json
 import uuid
 
-from rich.pretty import pprint
-
 
 class ReActOutputParser(OutputParser):
     def parse(self, output_message: CompletionMessage) -> CompletionMessage:
         response_text = str(output_message.content)
         try:
             response_json = json.loads(response_text)
-            pprint(response_json)
         except json.JSONDecodeError as e:
             print(f"Error parsing action: {e}")
             return output_message

From c0fccd776579deeae463367bc2a3ecfd24ac2e84 Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Wed, 5 Feb 2025 12:04:49 -0800
Subject: [PATCH 08/12] wip builtin tool

---
 src/llama_stack_client/lib/agents/agent.py | 37 +++++++++++++++----
 .../lib/agents/react/agent.py              | 20 ++++++++++
 .../lib/agents/react/output_parser.py      | 13 +++++++
 3 files changed, 63 insertions(+), 7 deletions(-)
 create mode 100644 src/llama_stack_client/lib/agents/react/agent.py

diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py
index cb1c19a0..6063ad42 100644
--- a/src/llama_stack_client/lib/agents/agent.py
+++ b/src/llama_stack_client/lib/agents/agent.py
@@ -36,6 +36,11 @@ def __init__(
         self.memory_bank_id = memory_bank_id
         self.output_parser = output_parser
 
+        self.builtin_tools = {}
+        for tg in agent_config["toolgroups"]:
+            for tool in self.client.tools.list(toolgroup_id=tg):
+                self.builtin_tools[tool.identifier] = tool
+
     def _create_agent(self, agent_config: AgentConfig) -> int:
         agentic_system_create_response = self.client.agents.create(
             agent_config=agent_config,
@@ -73,17 +78,35 @@ def _has_tool_call(self, chunk: AgentTurnResponseStreamChunk) -> bool:
 
     def _run_tool(self, chunk: AgentTurnResponseStreamChunk) -> ToolResponseMessage:
         message = chunk.event.payload.turn.output_message
         tool_call = message.tool_calls[0]
-        if tool_call.tool_name not in self.client_tools:
-            return ToolResponseMessage(
+
+        # custom client tools
+        if tool_call.tool_name in self.client_tools:
+            tool = self.client_tools[tool_call.tool_name]
+            result_messages = tool.run([message])
+            next_message = result_messages[0]
+            return next_message
+
+        # builtin tools executed by tool_runtime
+        if tool_call.tool_name in self.builtin_tools:
+            tool_result = self.client.tool_runtime.invoke_tool(
+                tool_name=tool_call.tool_name,
+                kwargs=tool_call.arguments,
+            )
+            tool_response_message = ToolResponseMessage(
                 call_id=tool_call.call_id,
                 tool_name=tool_call.tool_name,
-                content=f"Unknown tool `{tool_call.tool_name}` was called.",
+                content=tool_result.content,
                 role="tool",
             )
-        tool = self.client_tools[tool_call.tool_name]
-        result_messages = tool.run([message])
-        next_message = result_messages[0]
-        return next_message
+            return tool_response_message
+
+        # tool not found in either client tools or builtin tools
+        return ToolResponseMessage(
+            call_id=tool_call.call_id,
+            tool_name=tool_call.tool_name,
+            content=f"Unknown tool `{tool_call.tool_name}` was called.",
+            role="tool",
+        )
 
     def create_turn(
         self,
diff --git a/src/llama_stack_client/lib/agents/react/agent.py b/src/llama_stack_client/lib/agents/react/agent.py
new file mode 100644
index 00000000..8db6c867
--- /dev/null
+++ b/src/llama_stack_client/lib/agents/react/agent.py
@@ -0,0 +1,20 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+
+# class ReActAgent(Agent):
+#     """ReAct agent.
+
+#     Simple wrapper around Agent + ReActOutputParser + ReActPromptTemplate to create a ReAct agent.
+#     """
+#     def __init__(self, model: str, output_parser: OutputParser, prompt_template: PromptTemplate):
+#         super().__init__(model, output_parser, prompt_template)
+#         self.output_parser = ReActOutputParser()
+#         self.prompt_template = ReActPromptTemplate()
+
+#     @classmethod
+#     def create(cls, **kwargs) -> "ReActAgent":
+#         return cls(**kwargs)
diff --git a/src/llama_stack_client/lib/agents/react/output_parser.py b/src/llama_stack_client/lib/agents/react/output_parser.py
index 7d764e42..7c8eb38e 100644
--- a/src/llama_stack_client/lib/agents/react/output_parser.py
+++ b/src/llama_stack_client/lib/agents/react/output_parser.py
@@ -4,6 +4,8 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+from pydantic import BaseModel
+from typing import Dict, Any, Optional
 from ..output_parser import OutputParser
 from llama_stack_client.types.shared.completion_message import CompletionMessage
 from llama_stack_client.types.shared.tool_call import ToolCall
@@ -12,6 +14,17 @@
 import uuid
 
 
+class Action(BaseModel):
+    tool_name: str
+    tool_params: Dict[str, Any]
+
+
+class ReActOutput(BaseModel):
+    thought: str
+    action: Optional[Action] = None
+    answer: Optional[str] = None
+
+
 class ReActOutputParser(OutputParser):
     def parse(self, output_message: CompletionMessage) -> CompletionMessage:
         response_text = str(output_message.content)

From 0b8d7a986ddcf511944b3d2222d957c49f1d7963 Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Wed, 5 Feb 2025 13:00:52 -0800
Subject: [PATCH 09/12] react agent wrapper

---
 src/llama_stack_client/lib/agents/agent.py |  3 ---
 .../lib/agents/react/agent.py              | 99 ++++++++++++++++---
 .../lib/agents/react/prompts.py            |  8 --
 3 files changed, 88 insertions(+), 22 deletions(-)

diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py
index 6063ad42..56d8907b 100644
--- a/src/llama_stack_client/lib/agents/agent.py
+++ b/src/llama_stack_client/lib/agents/agent.py
@@ -25,7 +25,6 @@ def __init__(
         client: LlamaStackClient,
         agent_config: AgentConfig,
         client_tools: Tuple[ClientTool] = (),
-        memory_bank_id: Optional[str] = None,
         output_parser: Optional[OutputParser] = None,
     ):
         self.client = client
@@ -33,9 +32,7 @@ def __init__(
         self.agent_id = self._create_agent(agent_config)
         self.client_tools = {t.get_name(): t for t in client_tools}
         self.sessions = []
-        self.memory_bank_id = memory_bank_id
         self.output_parser = output_parser
-
         self.builtin_tools = {}
         for tg in agent_config["toolgroups"]:
             for tool in self.client.tools.list(toolgroup_id=tg):
diff --git a/src/llama_stack_client/lib/agents/react/agent.py b/src/llama_stack_client/lib/agents/react/agent.py
index 8db6c867..4036cbdc 100644
--- a/src/llama_stack_client/lib/agents/react/agent.py
+++ b/src/llama_stack_client/lib/agents/react/agent.py
@@ -3,18 +3,95 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+from pydantic import BaseModel
+from typing import Dict, Any
+from ..agent import Agent
+from .output_parser import ReActOutputParser
+from ..output_parser import OutputParser
+from .prompts import DEFAULT_REACT_AGENT_SYSTEM_PROMPT_TEMPLATE
+from typing import Tuple, Optional
+from llama_stack_client import LlamaStackClient
+from ..client_tool import ClientTool
+from llama_stack_client.types.agent_create_params import AgentConfig
 
 
-# class ReActAgent(Agent):
-#     """ReAct agent.
 
-#     Simple wrapper around Agent + ReActOutputParser + ReActPromptTemplate to create a ReAct agent.
-#     """
-#     def __init__(self, model: str, output_parser: OutputParser, prompt_template: PromptTemplate):
-#         super().__init__(model, output_parser, prompt_template)
-#         self.output_parser = ReActOutputParser()
-#         self.prompt_template = ReActPromptTemplate()
+class Action(BaseModel):
+    tool_name: str
+    tool_params: Dict[str, Any]
 
-#     @classmethod
-#     def create(cls, **kwargs) -> "ReActAgent":
-#         return cls(**kwargs)
+
+class ReActOutput(BaseModel):
+    thought: str
+    action: Optional[Action] = None
+    answer: Optional[str] = None
+
+
+class ReActAgent(Agent):
+    """ReAct agent.
+
+    Simple wrapper around Agent that prepares the prompts for creating a ReAct agent from a list of tools.
+ """ + + def __init__( + self, + client: LlamaStackClient, + model: str, + builtin_toolgroups: Tuple[str] = (), + client_tools: Tuple[ClientTool] = (), + output_parser: OutputParser = ReActOutputParser(), + json_response_format: bool = False, + custom_agent_config: Optional[AgentConfig] = None, + ): + def get_tool_definition(tool): + return { + "name": tool.identifier, + "description": tool.description, + "parameters": tool.parameters, + } + + if custom_agent_config is None: + tool_names = "" + tool_descriptions = "" + for x in builtin_toolgroups: + tool_names += ", ".join([tool.identifier for tool in client.tools.list(toolgroup_id=x)]) + tool_descriptions += "\n".join( + [f"- {tool.identifier}: {get_tool_definition(tool)}" for tool in client.tools.list(toolgroup_id=x)] + ) + + tool_names += ", " + tool_descriptions += "\n" + tool_names += ", ".join([tool.get_name() for tool in client_tools]) + tool_descriptions += "\n".join( + [f"- {tool.get_name()}: {tool.get_tool_definition()}" for tool in client_tools] + ) + + instruction = DEFAULT_REACT_AGENT_SYSTEM_PROMPT_TEMPLATE.replace("<>", tool_names).replace( + "<>", tool_descriptions + ) + + # user default toolgroups + agent_config = AgentConfig( + model=model, + instructions=instruction, + toolgroups=builtin_toolgroups, + client_tools=[client_tool.get_tool_definition() for client_tool in client_tools], + tool_choice="auto", + tool_prompt_format="json", + input_shields=[], + output_shields=[], + enable_session_persistence=False, + ) + + if json_response_format: + agent_config.response_format = { + "type": "json_schema", + "json_schema": ReActOutput.model_json_schema(), + } + + super().__init__( + client=client, + agent_config=agent_config, + client_tools=client_tools, + output_parser=output_parser, + ) diff --git a/src/llama_stack_client/lib/agents/react/prompts.py b/src/llama_stack_client/lib/agents/react/prompts.py index b146e605..4ce228f2 100644 --- a/src/llama_stack_client/lib/agents/react/prompts.py +++ b/src/llama_stack_client/lib/agents/react/prompts.py @@ -149,11 +149,3 @@ Now Begin! If you solve the task correctly, you will receive a reward of $1,000,000. 
""" - -from pydantic import BaseModel - - -class PromptTemplate(BaseModel): - system_prompt_template: str = DEFAULT_REACT_AGENT_SYSTEM_PROMPT_TEMPLATE - tool_names: str - tool_descriptions: str From 344fad15561cc48243428c084a16c6e2723d8ba4 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Wed, 5 Feb 2025 14:57:34 -0800 Subject: [PATCH 10/12] add todo --- src/llama_stack_client/lib/agents/react/agent.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/llama_stack_client/lib/agents/react/agent.py b/src/llama_stack_client/lib/agents/react/agent.py index 4036cbdc..9cd3fa21 100644 --- a/src/llama_stack_client/lib/agents/react/agent.py +++ b/src/llama_stack_client/lib/agents/react/agent.py @@ -77,6 +77,7 @@ def get_tool_definition(tool): toolgroups=builtin_toolgroups, client_tools=[client_tool.get_tool_definition() for client_tool in client_tools], tool_choice="auto", + # TODO: refactor this to use SystemMessageBehaviour.replace tool_prompt_format="json", input_shields=[], output_shields=[], From c40f4e7328431e223b886dabaed7c5b6138cb462 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Wed, 5 Feb 2025 15:11:07 -0800 Subject: [PATCH 11/12] refactor, comments --- .../lib/agents/react/agent.py | 45 +++++++++++-------- .../lib/agents/react/output_parser.py | 15 +++---- 2 files changed, 33 insertions(+), 27 deletions(-) diff --git a/src/llama_stack_client/lib/agents/react/agent.py b/src/llama_stack_client/lib/agents/react/agent.py index 9cd3fa21..3d40a08b 100644 --- a/src/llama_stack_client/lib/agents/react/agent.py +++ b/src/llama_stack_client/lib/agents/react/agent.py @@ -43,29 +43,36 @@ def __init__( json_response_format: bool = False, custom_agent_config: Optional[AgentConfig] = None, ): - def get_tool_definition(tool): - return { - "name": tool.identifier, - "description": tool.description, - "parameters": tool.parameters, - } - - if custom_agent_config is None: - tool_names = "" - tool_descriptions = "" + def get_tool_defs(): + tool_defs = [] for x in builtin_toolgroups: - tool_names += ", ".join([tool.identifier for tool in client.tools.list(toolgroup_id=x)]) - tool_descriptions += "\n".join( - [f"- {tool.identifier}: {get_tool_definition(tool)}" for tool in client.tools.list(toolgroup_id=x)] + tool_defs.extend( + [ + { + "name": tool.identifier, + "description": tool.description, + "parameters": tool.parameters, + } + for tool in client.tools.list(toolgroup_id=x) + ] ) - - tool_names += ", " - tool_descriptions += "\n" - tool_names += ", ".join([tool.get_name() for tool in client_tools]) - tool_descriptions += "\n".join( - [f"- {tool.get_name()}: {tool.get_tool_definition()}" for tool in client_tools] + tool_defs.extend( + [ + { + "name": tool.get_name(), + "description": tool.get_description(), + "parameters": tool.get_params_definition(), + } + for tool in client_tools + ] ) + return tool_defs + if custom_agent_config is None: + tool_names, tool_descriptions = "", "" + tool_defs = get_tool_defs() + tool_names = ", ".join([x["name"] for x in tool_defs]) + tool_descriptions = "\n".join([f"- {x['name']}: {x}" for x in tool_defs]) instruction = DEFAULT_REACT_AGENT_SYSTEM_PROMPT_TEMPLATE.replace("<>", tool_names).replace( "<>", tool_descriptions ) diff --git a/src/llama_stack_client/lib/agents/react/output_parser.py b/src/llama_stack_client/lib/agents/react/output_parser.py index 7c8eb38e..6e4861a9 100644 --- a/src/llama_stack_client/lib/agents/react/output_parser.py +++ b/src/llama_stack_client/lib/agents/react/output_parser.py @@ -4,13 +4,12 @@ # This source code is licensed under the terms described 
 # the root directory of this source tree.
 
-from pydantic import BaseModel
+from pydantic import BaseModel, ValidationError
 from typing import Dict, Any, Optional
 from ..output_parser import OutputParser
 from llama_stack_client.types.shared.completion_message import CompletionMessage
 from llama_stack_client.types.shared.tool_call import ToolCall
 
-import json
 import uuid
@@ -29,17 +28,17 @@ class ReActOutputParser(OutputParser):
     def parse(self, output_message: CompletionMessage) -> CompletionMessage:
         response_text = str(output_message.content)
         try:
-            response_json = json.loads(response_text)
-        except json.JSONDecodeError as e:
+            react_output = ReActOutput.model_validate_json(response_text)
+        except ValidationError as e:
             print(f"Error parsing action: {e}")
             return output_message
 
-        if response_json.get("answer", None):
+        if react_output.answer:
             return output_message
 
-        if response_json.get("action", None):
-            tool_name = response_json["action"].get("tool_name", None)
-            tool_params = response_json["action"].get("tool_params", None)
+        if react_output.action:
+            tool_name = react_output.action.tool_name
+            tool_params = react_output.action.tool_params
             if tool_name and tool_params:
                 call_id = str(uuid.uuid4())
                 output_message.tool_calls = [ToolCall(call_id=call_id, tool_name=tool_name, arguments=tool_params)]

From a1265a93f67236c495e42e48c20e154370d6a710 Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Wed, 5 Feb 2025 15:29:45 -0800
Subject: [PATCH 12/12] address comments

---
 src/llama_stack_client/lib/agents/agent.py                | 3 +--
 src/llama_stack_client/lib/agents/react/output_parser.py  | 8 +++-----
 2 files changed, 4 insertions(+), 7 deletions(-)

diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py
index 56d8907b..315a0641 100644
--- a/src/llama_stack_client/lib/agents/agent.py
+++ b/src/llama_stack_client/lib/agents/agent.py
@@ -60,8 +60,7 @@ def _process_chunk(self, chunk: AgentTurnResponseStreamChunk) -> None:
         message = chunk.event.payload.turn.output_message
 
         if self.output_parser:
-            parsed_message = self.output_parser.parse(message)
-            message = parsed_message
+            self.output_parser.parse(message)
 
     def _has_tool_call(self, chunk: AgentTurnResponseStreamChunk) -> bool:
         if chunk.event.payload.event_type != "turn_complete":
diff --git a/src/llama_stack_client/lib/agents/react/output_parser.py b/src/llama_stack_client/lib/agents/react/output_parser.py
index 6e4861a9..884a2b0d 100644
--- a/src/llama_stack_client/lib/agents/react/output_parser.py
+++ b/src/llama_stack_client/lib/agents/react/output_parser.py
@@ -25,16 +25,16 @@ class ReActOutput(BaseModel):
 
 
 class ReActOutputParser(OutputParser):
-    def parse(self, output_message: CompletionMessage) -> CompletionMessage:
+    def parse(self, output_message: CompletionMessage) -> None:
         response_text = str(output_message.content)
         try:
             react_output = ReActOutput.model_validate_json(response_text)
         except ValidationError as e:
             print(f"Error parsing action: {e}")
-            return output_message
+            return
 
         if react_output.answer:
-            return output_message
+            return
 
         if react_output.action:
             tool_name = react_output.action.tool_name
             tool_params = react_output.action.tool_params
             if tool_name and tool_params:
                 call_id = str(uuid.uuid4())
                 output_message.tool_calls = [ToolCall(call_id=call_id, tool_name=tool_name, arguments=tool_params)]
-
-        return output_message
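
---

Usage note (reviewer sketch, not part of the patches): a minimal end-to-end example of the
API this series builds up, for context. It assumes a running Llama Stack server; the base
URL, model id, and toolgroup id below are illustrative assumptions, not values taken from
the patches.

    # Hypothetical usage of ReActAgent (PATCH 09-12); names below are assumptions.
    from llama_stack_client import LlamaStackClient
    from llama_stack_client.lib.agents.react.agent import ReActAgent

    client = LlamaStackClient(base_url="http://localhost:8321")  # assumed server address

    # ReActAgent renders <<tool_names>>/<<tool_descriptions>> into the system prompt
    # and installs ReActOutputParser as the default output_parser.
    agent = ReActAgent(
        client=client,
        model="meta-llama/Llama-3.1-8B-Instruct",    # assumed model id
        builtin_toolgroups=("builtin::websearch",),  # assumed toolgroup id
        json_response_format=True,  # constrain decoding to the ReActOutput schema
    )

    session_id = agent.create_session("react-demo")

    # create_turn streams chunks; after each turn completes, _process_chunk runs the
    # output parser, and _run_tool executes any client or builtin tool call it produced.
    for chunk in agent.create_turn(
        messages=[{"role": "user", "content": "Which city has the highest population, Guangzhou or Shanghai?"}],
        session_id=session_id,
    ):
        print(chunk)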