From 9bfacdf42289ea8603711944ed3858988b4d4341 Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Thu, 6 Feb 2025 16:43:54 -0800
Subject: [PATCH 1/5] tool refactor

---
 src/llama_stack_client/lib/agents/agent.py    | 38 +++++++-------
 .../lib/agents/output_parser.py               | 48 ------------------
 .../lib/agents/react/agent.py                 |  8 +--
 .../{output_parser.py => tool_parser.py}      | 17 ++++---
 .../lib/agents/tool_parser.py                 | 50 +++++++++++++++++++
 5 files changed, 81 insertions(+), 80 deletions(-)
 delete mode 100644 src/llama_stack_client/lib/agents/output_parser.py
 rename src/llama_stack_client/lib/agents/react/{output_parser.py => tool_parser.py} (72%)
 create mode 100644 src/llama_stack_client/lib/agents/tool_parser.py

diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py
index dcf38426..0d2ddc05 100644
--- a/src/llama_stack_client/lib/agents/agent.py
+++ b/src/llama_stack_client/lib/agents/agent.py
@@ -11,10 +11,10 @@
 from llama_stack_client.types.agents.turn import Turn
 from llama_stack_client.types.agents.turn_create_params import Document, Toolgroup
 from llama_stack_client.types.agents.turn_create_response import AgentTurnResponseStreamChunk
-
+from llama_stack_client.types.shared.tool_call import ToolCall
 from .client_tool import ClientTool
-from .output_parser import OutputParser
+from .tool_parser import ToolParser
 
 DEFAULT_MAX_ITER = 10
 
 
@@ -25,14 +25,14 @@ def __init__(
         client: LlamaStackClient,
         agent_config: AgentConfig,
         client_tools: Tuple[ClientTool] = (),
-        output_parser: Optional[OutputParser] = None,
+        tool_parser: Optional[ToolParser] = None,
     ):
         self.client = client
         self.agent_config = agent_config
         self.agent_id = self._create_agent(agent_config)
         self.client_tools = {t.get_name(): t for t in client_tools}
         self.sessions = []
-        self.output_parser = output_parser
+        self.tool_parser = tool_parser
         self.builtin_tools = {}
         for tg in agent_config["toolgroups"]:
             for tool in self.client.tools.list(toolgroup_id=tg):
@@ -54,25 +54,23 @@ def create_session(self, session_name: str) -> int:
         self.sessions.append(self.session_id)
         return self.session_id
 
-    def _process_chunk(self, chunk: AgentTurnResponseStreamChunk) -> None:
-        if chunk.event.payload.event_type != "turn_complete":
-            return
-        message = chunk.event.payload.turn.output_message
-
-        if self.output_parser:
-            self.output_parser.parse(message)
-
-    def _has_tool_call(self, chunk: AgentTurnResponseStreamChunk) -> bool:
+    def _get_tool_calls(self, chunk: AgentTurnResponseStreamChunk) -> List[ToolCall]:
         if chunk.event.payload.event_type != "turn_complete":
-            return False
+            return None
+
         message = chunk.event.payload.turn.output_message
         if message.stop_reason == "out_of_tokens":
-            return False
+            return None
+
+        if self.tool_parser:
+            return self.tool_parser.get_tool_calls(message)
+
+        return message.tool_calls
 
-        return len(message.tool_calls) > 0
 
-    def _run_tool(self, chunk: AgentTurnResponseStreamChunk) -> ToolResponseMessage:
+    def _run_tool(self, chunk: AgentTurnResponseStreamChunk, tool_calls: List[ToolCall]) -> ToolResponseMessage:
         message = chunk.event.payload.turn.output_message
+        message.tool_calls = tool_calls
         tool_call = message.tool_calls[0]
 
         # custom client tools
@@ -149,14 +147,14 @@ def _create_turn_streaming(
             # by default, we stop after the first turn
             stop = True
             for chunk in response:
-                self._process_chunk(chunk)
+                tool_calls = self._get_tool_calls(chunk)
                 if hasattr(chunk, "error"):
                     yield chunk
                     return
-                elif not self._has_tool_call(chunk):
+                elif not tool_calls:
                     yield chunk
                 else:
-                    next_message = self._run_tool(chunk)
+                    next_message = self._run_tool(chunk, tool_calls)
                     yield next_message
 
                 # continue the turn when there's a tool call
diff --git a/src/llama_stack_client/lib/agents/output_parser.py b/src/llama_stack_client/lib/agents/output_parser.py
deleted file mode 100644
index 20c8468e..00000000
--- a/src/llama_stack_client/lib/agents/output_parser.py
+++ /dev/null
@@ -1,48 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from abc import abstractmethod
-
-from llama_stack_client.types.agents.turn import CompletionMessage
-
-
-class OutputParser:
-    """
-    Abstract base class for parsing agent responses. Implement this class to customize how
-    agent outputs are processed and transformed.
-
-    This class allows developers to define custom parsing logic for agent responses,
-    which can be useful for:
-    - Extracting specific information from the response
-    - Formatting or structuring the output in a specific way
-    - Validating or sanitizing the agent's response
-
-    To use this class:
-    1. Create a subclass of OutputParser
-    2. Implement the `parse` method
-    3. Pass your parser instance to the Agent's constructor
-
-    Example:
-        class MyCustomParser(OutputParser):
-            def parse(self, output_message: CompletionMessage) -> CompletionMessage:
-                # Add your custom parsing logic here
-                return processed_message
-
-    Methods:
-        parse(output_message: CompletionMessage) -> CompletionMessage:
-            Abstract method that must be implemented by subclasses to process
-            the agent's response.
-
-    Args:
-        output_message (CompletionMessage): The response message from agent turn
-
-    Returns: None
-        Modifies the output_message in place
-    """
-
-    @abstractmethod
-    def parse(self, output_message: CompletionMessage) -> None:
-        raise NotImplementedError
diff --git a/src/llama_stack_client/lib/agents/react/agent.py b/src/llama_stack_client/lib/agents/react/agent.py
index 3d40a08b..622d4420 100644
--- a/src/llama_stack_client/lib/agents/react/agent.py
+++ b/src/llama_stack_client/lib/agents/react/agent.py
@@ -6,8 +6,8 @@
 from pydantic import BaseModel
 from typing import Dict, Any
 from ..agent import Agent
-from .output_parser import ReActOutputParser
-from ..output_parser import OutputParser
+from .tool_parser import ReActToolParser
+from ..tool_parser import ToolParser
 from .prompts import DEFAULT_REACT_AGENT_SYSTEM_PROMPT_TEMPLATE
 from typing import Tuple, Optional
 
@@ -39,7 +39,7 @@ def __init__(
         model: str,
         builtin_toolgroups: Tuple[str] = (),
         client_tools: Tuple[ClientTool] = (),
-        output_parser: OutputParser = ReActOutputParser(),
+        tool_parser: ToolParser = ReActToolParser(),
         json_response_format: bool = False,
         custom_agent_config: Optional[AgentConfig] = None,
     ):
@@ -101,5 +101,5 @@ def get_tool_defs():
         client=client,
         agent_config=agent_config,
         client_tools=client_tools,
-        output_parser=output_parser,
+        tool_parser=tool_parser,
     )
diff --git a/src/llama_stack_client/lib/agents/react/output_parser.py b/src/llama_stack_client/lib/agents/react/tool_parser.py
similarity index 72%
rename from src/llama_stack_client/lib/agents/react/output_parser.py
rename to src/llama_stack_client/lib/agents/react/tool_parser.py
index 71177a6f..e668d28d 100644
--- a/src/llama_stack_client/lib/agents/react/output_parser.py
+++ b/src/llama_stack_client/lib/agents/react/tool_parser.py
@@ -5,8 +5,8 @@
 # the root directory of this source tree.
 
 from pydantic import BaseModel, ValidationError
-from typing import Dict, Any, Optional
-from ..output_parser import OutputParser
+from typing import Dict, Any, Optional, List
+from ..tool_parser import ToolParser
 
 from llama_stack_client.types.shared.completion_message import CompletionMessage
 from llama_stack_client.types.shared.tool_call import ToolCall
@@ -24,23 +24,24 @@ class ReActOutput(BaseModel):
     answer: Optional[str] = None
 
 
-class ReActOutputParser(OutputParser):
-    def parse(self, output_message: CompletionMessage) -> None:
+class ReActToolParser(ToolParser):
+    def get_tool_calls(self, output_message: CompletionMessage) -> List[ToolCall]:
+        tool_calls = []
         response_text = str(output_message.content)
         try:
             react_output = ReActOutput.model_validate_json(response_text)
         except ValidationError as e:
             print(f"Error parsing action: {e}")
-            return
+            return tool_calls
 
         if react_output.answer:
-            return
+            return tool_calls
 
         if react_output.action:
             tool_name = react_output.action.tool_name
             tool_params = react_output.action.tool_params
             if tool_name and tool_params:
                 call_id = str(uuid.uuid4())
-                output_message.tool_calls = [ToolCall(call_id=call_id, tool_name=tool_name, arguments=tool_params)]
+                tool_calls = [ToolCall(call_id=call_id, tool_name=tool_name, arguments=tool_params)]
 
-        return
+        return tool_calls
diff --git a/src/llama_stack_client/lib/agents/tool_parser.py b/src/llama_stack_client/lib/agents/tool_parser.py
new file mode 100644
index 00000000..091a1728
--- /dev/null
+++ b/src/llama_stack_client/lib/agents/tool_parser.py
@@ -0,0 +1,50 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from abc import abstractmethod
+from typing import List, Optional
+
+from llama_stack_client.types.agents.turn import CompletionMessage
+from llama_stack_client.types.shared.tool_call import ToolCall
+
+
+class ToolParser:
+    """
+    Abstract base class for parsing agent responses into tool calls. Implement this class to customize how
+    agent outputs are processed and transformed into executable tool calls.
+
+    This class allows developers to define custom parsing logic for agent responses,
+    which can be useful for:
+    - Extracting tool calls from the response
+    - Validating tool parameters and arguments
+    - Transforming raw output into structured tool calls
+
+    To use this class:
+    1. Create a subclass of ToolParser
+    2. Implement the `get_tool_calls` method
+    3. Pass your parser instance to the Agent's constructor
+
+    Example:
+        class MyCustomParser(ToolParser):
+            def get_tool_calls(self, output_message: CompletionMessage) -> List[ToolCall]:
+                # Add your custom parsing logic here
+                return extracted_tool_calls
+
+    Methods:
+        get_tool_calls(output_message: CompletionMessage) -> List[ToolCall]:
+            Abstract method that must be implemented by subclasses to process
+            the agent's response and extract tool calls.
+
+    Args:
+        output_message (CompletionMessage): The response message from agent turn
+
+    Returns:
+        Optional[List[ToolCall]]: A list of parsed tool calls, or None if no tools should be called
+    """
+
+    @abstractmethod
+    def get_tool_calls(self, output_message: CompletionMessage) -> List[ToolCall]:
+        raise NotImplementedError

From e13be0f65cf1d3bb6a85227c374a8e58ca9e278e Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Thu, 6 Feb 2025 16:45:24 -0800
Subject: [PATCH 2/5] precommit

---
 src/llama_stack_client/lib/agents/agent.py       | 7 +++----
 src/llama_stack_client/lib/agents/tool_parser.py | 2 +-
 2 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py
index 0d2ddc05..2ba1ac9a 100644
--- a/src/llama_stack_client/lib/agents/agent.py
+++ b/src/llama_stack_client/lib/agents/agent.py
@@ -57,16 +57,15 @@ def create_session(self, session_name: str) -> int:
     def _get_tool_calls(self, chunk: AgentTurnResponseStreamChunk) -> List[ToolCall]:
         if chunk.event.payload.event_type != "turn_complete":
             return None
-
+
         message = chunk.event.payload.turn.output_message
         if message.stop_reason == "out_of_tokens":
             return None
-
+
         if self.tool_parser:
             return self.tool_parser.get_tool_calls(message)
-
-        return message.tool_calls
 
+        return message.tool_calls
 
     def _run_tool(self, chunk: AgentTurnResponseStreamChunk, tool_calls: List[ToolCall]) -> ToolResponseMessage:
         message = chunk.event.payload.turn.output_message
diff --git a/src/llama_stack_client/lib/agents/tool_parser.py b/src/llama_stack_client/lib/agents/tool_parser.py
index 091a1728..cb476663 100644
--- a/src/llama_stack_client/lib/agents/tool_parser.py
+++ b/src/llama_stack_client/lib/agents/tool_parser.py
@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 
 from abc import abstractmethod
-from typing import List, Optional
+from typing import List
 
 from llama_stack_client.types.agents.turn import CompletionMessage
 from llama_stack_client.types.shared.tool_call import ToolCall

From a5b8eb3ac13b0b6976de08764f664ae9555b5356 Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Thu, 6 Feb 2025 17:10:59 -0800
Subject: [PATCH 3/5] update docs & types sig

---
 src/llama_stack_client/lib/agents/agent.py       | 4 ++--
 src/llama_stack_client/lib/agents/tool_parser.py | 6 ------
 2 files changed, 2 insertions(+), 8 deletions(-)

diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py
index 2ba1ac9a..3aa767a5 100644
--- a/src/llama_stack_client/lib/agents/agent.py
+++ b/src/llama_stack_client/lib/agents/agent.py
@@ -56,11 +56,11 @@ def create_session(self, session_name: str) -> int:
 
     def _get_tool_calls(self, chunk: AgentTurnResponseStreamChunk) -> List[ToolCall]:
         if chunk.event.payload.event_type != "turn_complete":
-            return None
+            return []
 
         message = chunk.event.payload.turn.output_message
         if message.stop_reason == "out_of_tokens":
-            return None
+            return []
 
         if self.tool_parser:
             return self.tool_parser.get_tool_calls(message)
diff --git a/src/llama_stack_client/lib/agents/tool_parser.py b/src/llama_stack_client/lib/agents/tool_parser.py
index cb476663..dc0c5ba4 100644
--- a/src/llama_stack_client/lib/agents/tool_parser.py
+++ b/src/llama_stack_client/lib/agents/tool_parser.py
@@ -16,12 +16,6 @@ class ToolParser:
     Abstract base class for parsing agent responses into tool calls. Implement this class to customize how
     agent outputs are processed and transformed into executable tool calls.
 
-    This class allows developers to define custom parsing logic for agent responses,
-    which can be useful for:
-    - Extracting tool calls from the response
-    - Validating tool parameters and arguments
-    - Transforming raw output into structured tool calls
-
     To use this class:
     1. Create a subclass of ToolParser
     2. Implement the `get_tool_calls` method
     3. Pass your parser instance to the Agent's constructor

From bc6cc56fd3acc8d13cde78ecb0912c7b364334db Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Thu, 6 Feb 2025 17:17:40 -0800
Subject: [PATCH 4/5] address comments

---
 src/llama_stack_client/lib/agents/agent.py | 18 +++++++++++-------
 1 file changed, 11 insertions(+), 7 deletions(-)

diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py
index 3aa767a5..b513e03d 100644
--- a/src/llama_stack_client/lib/agents/agent.py
+++ b/src/llama_stack_client/lib/agents/agent.py
@@ -12,7 +12,7 @@
 from llama_stack_client.types.agents.turn_create_params import Document, Toolgroup
 from llama_stack_client.types.agents.turn_create_response import AgentTurnResponseStreamChunk
 from llama_stack_client.types.shared.tool_call import ToolCall
-
+from llama_stack_client.types.agents.turn import CompletionMessage
 from .client_tool import ClientTool
 from .tool_parser import ToolParser
 
@@ -67,17 +67,21 @@ def _get_tool_calls(self, chunk: AgentTurnResponseStreamChunk) -> List[ToolCall]:
 
         return message.tool_calls
 
-    def _run_tool(self, chunk: AgentTurnResponseStreamChunk, tool_calls: List[ToolCall]) -> ToolResponseMessage:
-        message = chunk.event.payload.turn.output_message
-        message.tool_calls = tool_calls
-        tool_call = message.tool_calls[0]
+    def _run_tool(self, tool_calls: List[ToolCall]) -> ToolResponseMessage:
+        assert len(tool_calls) == 1, "Only one tool call is supported"
+        tool_call = tool_calls[0]
 
         # custom client tools
         if tool_call.tool_name in self.client_tools:
             tool = self.client_tools[tool_call.tool_name]
             # NOTE: tool.run() expects a list of messages, we only pass in last message here
             # but we could pass in the entire message history
-            result_message = tool.run([message])
+            result_message = tool.run([CompletionMessage(
+                role="assistant",
+                content=tool_call.tool_name,
+                tool_calls=[tool_call],
+                stop_reason="end_of_turn",
+            )])
             return result_message
 
         # builtin tools executed by tool_runtime
@@ -153,7 +157,7 @@ def _create_turn_streaming(
             elif not tool_calls:
                 yield chunk
             else:
-                next_message = self._run_tool(chunk, tool_calls)
+                next_message = self._run_tool(tool_calls)
                 yield next_message
 
                 # continue the turn when there's a tool call

From c6934dd102570acd62cadb7b0c95468bd3387510 Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Thu, 6 Feb 2025 17:22:43 -0800
Subject: [PATCH 5/5] precommit

---
 src/llama_stack_client/lib/agents/agent.py | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py
index b513e03d..c40ef4c8 100644
--- a/src/llama_stack_client/lib/agents/agent.py
+++ b/src/llama_stack_client/lib/agents/agent.py
@@ -76,12 +76,16 @@ def _run_tool(self, tool_calls: List[ToolCall]) -> ToolResponseMessage:
             tool = self.client_tools[tool_call.tool_name]
             # NOTE: tool.run() expects a list of messages, we only pass in last message here
             # but we could pass in the entire message history
-            result_message = tool.run([CompletionMessage(
-                role="assistant",
-                content=tool_call.tool_name,
-                tool_calls=[tool_call],
-                stop_reason="end_of_turn",
-            )])
+            result_message = tool.run(
+                [
+                    CompletionMessage(
+                        role="assistant",
+                        content=tool_call.tool_name,
+                        tool_calls=[tool_call],
+                        stop_reason="end_of_turn",
+                    )
+                ]
+            )
             return result_message
 
         # builtin tools executed by tool_runtime
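
The series keeps ToolParser deliberately small: a single override point, get_tool_calls, which maps a completed assistant message to a list of ToolCalls. After PATCH 3, an empty list means "no tool call", so the streaming loop yields the chunk unchanged. A minimal sketch of a downstream subclass follows; the {"tool": ..., "args": ...} response format here is hypothetical and not something this series defines, but the ToolCall construction mirrors ReActToolParser above.

    import json
    import uuid
    from typing import List

    from llama_stack_client.lib.agents.tool_parser import ToolParser
    from llama_stack_client.types.shared.completion_message import CompletionMessage
    from llama_stack_client.types.shared.tool_call import ToolCall


    class JSONToolParser(ToolParser):
        """Sketch: extract a single tool call from a JSON-formatted response."""

        def get_tool_calls(self, output_message: CompletionMessage) -> List[ToolCall]:
            try:
                payload = json.loads(str(output_message.content))
            except json.JSONDecodeError:
                # Not JSON: treat as a plain answer, i.e. no tool call.
                return []
            if not isinstance(payload, dict) or "tool" not in payload:
                return []
            # Same shape as ReActToolParser: fresh call_id, name, and arguments.
            return [
                ToolCall(
                    call_id=str(uuid.uuid4()),
                    tool_name=payload["tool"],
                    arguments=payload.get("args", {}),
                )
            ]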
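Wiring is unchanged apart from the renamed keyword argument. In this hypothetical snippet, client, agent_config, and calculator_tool stand in for objects the surrounding application already has:

    from llama_stack_client.lib.agents.agent import Agent

    agent = Agent(
        client=client,                # an existing LlamaStackClient
        agent_config=agent_config,    # an existing AgentConfig
        client_tools=(calculator_tool,),
        tool_parser=JSONToolParser(),
    )

The design shift is worth noting: the old OutputParser.parse mutated output_message in place, whereas ToolParser.get_tool_calls returns the calls, and after PATCH 4 _run_tool no longer touches the turn's output message at all; it synthesizes a fresh CompletionMessage around the single supported tool call.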