# [RFC] Client Agent SDK OutputParser #121
**`agent.py`**

```diff
@@ -12,7 +12,9 @@
 from llama_stack_client.types.agents.turn_create_params import Document, Toolgroup
 from llama_stack_client.types.agents.turn_create_response import AgentTurnResponseStreamChunk
 
 from .client_tool import ClientTool
+from .output_parser import OutputParser
 
 DEFAULT_MAX_ITER = 10
@@ -23,14 +25,18 @@ def __init__(
         client: LlamaStackClient,
         agent_config: AgentConfig,
         client_tools: Tuple[ClientTool] = (),
-        memory_bank_id: Optional[str] = None,
+        output_parser: Optional[OutputParser] = None,
     ):
         self.client = client
         self.agent_config = agent_config
         self.agent_id = self._create_agent(agent_config)
         self.client_tools = {t.get_name(): t for t in client_tools}
         self.sessions = []
-        self.memory_bank_id = memory_bank_id
+        self.output_parser = output_parser
+        self.builtin_tools = {}
+        for tg in agent_config["toolgroups"]:
+            for tool in self.client.tools.list(toolgroup_id=tg):
+                self.builtin_tools[tool.identifier] = tool
 
     def _create_agent(self, agent_config: AgentConfig) -> int:
         agentic_system_create_response = self.client.agents.create(
@@ -48,28 +54,56 @@ def create_session(self, session_name: str) -> int:
         self.sessions.append(self.session_id)
         return self.session_id
 
+    def _process_chunk(self, chunk: AgentTurnResponseStreamChunk) -> None:
+        if chunk.event.payload.event_type != "turn_complete":
+            return
+        message = chunk.event.payload.turn.output_message
+
+        if self.output_parser:
+            parsed_message = self.output_parser.parse(message)
+            message = parsed_message
+
     def _has_tool_call(self, chunk: AgentTurnResponseStreamChunk) -> bool:
         if chunk.event.payload.event_type != "turn_complete":
             return False
         message = chunk.event.payload.turn.output_message
         if message.stop_reason == "out_of_tokens":
             return False
 
         return len(message.tool_calls) > 0
 
     def _run_tool(self, chunk: AgentTurnResponseStreamChunk) -> ToolResponseMessage:
         message = chunk.event.payload.turn.output_message
         tool_call = message.tool_calls[0]
-        if tool_call.tool_name not in self.client_tools:
-            return ToolResponseMessage(
-                call_id=tool_call.call_id,
-                tool_name=tool_call.tool_name,
-                content=f"Unknown tool `{tool_call.tool_name}` was called.",
-                role="ipython",
-            )
-        tool = self.client_tools[tool_call.tool_name]
-        result_messages = tool.run([message])
-        next_message = result_messages[0]
-        return next_message
+
+        # custom client tools
+        if tool_call.tool_name in self.client_tools:
+            tool = self.client_tools[tool_call.tool_name]
+            result_messages = tool.run([message])
+            next_message = result_messages[0]
+            return next_message
+
+        # builtin tools executed by tool_runtime
+        if tool_call.tool_name in self.builtin_tools:
+            tool_result = self.client.tool_runtime.invoke_tool(
+                tool_name=tool_call.tool_name,
+                kwargs=tool_call.arguments,
+            )
+            tool_response_message = ToolResponseMessage(
+                call_id=tool_call.call_id,
+                tool_name=tool_call.tool_name,
+                content=tool_result.content,
+                role="tool",
+            )
+            return tool_response_message
+
+        # cannot find tools
+        return ToolResponseMessage(
+            call_id=tool_call.call_id,
+            tool_name=tool_call.tool_name,
+            content=f"Unknown tool `{tool_call.tool_name}` was called.",
+            role="tool",
+        )
 
     def create_turn(
         self,
```
```diff
@@ -115,6 +149,7 @@ def _create_turn_streaming(
         # by default, we stop after the first turn
         stop = True
         for chunk in response:
+            self._process_chunk(chunk)
             if hasattr(chunk, "error"):
                 yield chunk
                 return
```

Review thread on the `self._process_chunk(chunk)` call:

> **Contributor:** I wonder if it's cleaner to only have the override as …
>
> **Contributor:** @ehhuang +100, especially the (3) bonus.
>
> **Author:** Some updates in #130. However, we still need to overwrite the chunk with parsed tool calls, as `ClientTool.run` takes in a message history and expects the `ToolCall` detail in the last message.
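To make the wiring concrete, here is a minimal sketch of constructing an `Agent` with the new `output_parser` argument. The base URL, model id, and config values are illustrative assumptions, and `UppercaseParser` is a toy parser invented for the example:

```python
from llama_stack_client import LlamaStackClient
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.output_parser import OutputParser
from llama_stack_client.types.agent_create_params import AgentConfig
from llama_stack_client.types.agents.turn import CompletionMessage


class UppercaseParser(OutputParser):
    """Toy parser: rewrites the final message content in place."""

    def parse(self, output_message: CompletionMessage) -> CompletionMessage:
        output_message.content = str(output_message.content).upper()
        return output_message


client = LlamaStackClient(base_url="http://localhost:5001")  # assumed local server
agent = Agent(
    client=client,
    agent_config=AgentConfig(
        model="meta-llama/Llama-3.1-8B-Instruct",  # hypothetical model id
        instructions="You are a helpful assistant.",
        toolgroups=[],  # __init__ above iterates this to index builtin tools
        client_tools=[],
        enable_session_persistence=False,
    ),
    output_parser=UppercaseParser(),  # parse() runs on every turn_complete chunk
)
```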
**`output_parser.py`** (new file)

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from abc import abstractmethod

from llama_stack_client.types.agents.turn import CompletionMessage


class OutputParser:
    """
    Abstract base class for parsing agent responses. Implement this class to customize how
    agent outputs are processed and transformed.

    This class allows developers to define custom parsing logic for agent responses,
    which can be useful for:
    - Extracting specific information from the response
    - Formatting or structuring the output in a specific way
    - Validating or sanitizing the agent's response

    To use this class:
    1. Create a subclass of OutputParser
    2. Implement the `parse` method
    3. Pass your parser instance to the Agent's constructor

    Example:
        class MyCustomParser(OutputParser):
            def parse(self, output_message: CompletionMessage) -> CompletionMessage:
                # Add your custom parsing logic here
                return processed_message

    Methods:
        parse(output_message: CompletionMessage) -> CompletionMessage:
            Abstract method that must be implemented by subclasses to process
            the agent's response.

            Args:
                output_message (CompletionMessage): The response message from the agent turn

            Returns:
                CompletionMessage: The processed/transformed response message
    """

    @abstractmethod
    def parse(self, output_message: CompletionMessage) -> CompletionMessage:
        raise NotImplementedError
```
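As a concrete instance of the contract described in the docstring, here is a hypothetical parser that unwraps a fenced code block from the final message before returning it (the fence-stripping rule is invented for illustration; it is not part of this PR):

```python
import re

from llama_stack_client.lib.agents.output_parser import OutputParser
from llama_stack_client.types.agents.turn import CompletionMessage


class CodeFenceParser(OutputParser):
    """Unwraps a ```...``` fence so downstream consumers see bare text."""

    def parse(self, output_message: CompletionMessage) -> CompletionMessage:
        text = str(output_message.content)
        match = re.search(r"```(?:\w+)?\n(.*?)```", text, re.DOTALL)
        if match:
            # CompletionMessage is mutable, so in-place edits propagate
            # to whatever consumes the finished turn.
            output_message.content = match.group(1).strip()
        return output_message
```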
**`react/__init__.py`** (new file)

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
```
**`react/agent.py`** (new file)

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any, Dict, Optional, Tuple

from pydantic import BaseModel

from llama_stack_client import LlamaStackClient
from llama_stack_client.types.agent_create_params import AgentConfig

from ..agent import Agent
from ..client_tool import ClientTool
from ..output_parser import OutputParser
from .output_parser import ReActOutputParser
from .prompts import DEFAULT_REACT_AGENT_SYSTEM_PROMPT_TEMPLATE


class Action(BaseModel):
    tool_name: str
    tool_params: Dict[str, Any]


class ReActOutput(BaseModel):
    thought: str
    action: Optional[Action] = None
    answer: Optional[str] = None


class ReActAgent(Agent):
    """ReAct agent.

    Simple wrapper around Agent that prepares the prompts for creating a ReAct agent
    from a list of tools.
    """

    def __init__(
        self,
        client: LlamaStackClient,
        model: str,
        builtin_toolgroups: Tuple[str] = (),
        client_tools: Tuple[ClientTool] = (),
        output_parser: OutputParser = ReActOutputParser(),
        json_response_format: bool = False,
        custom_agent_config: Optional[AgentConfig] = None,
    ):
        def get_tool_defs():
            tool_defs = []
            for x in builtin_toolgroups:
                tool_defs.extend(
                    [
                        {
                            "name": tool.identifier,
                            "description": tool.description,
                            "parameters": tool.parameters,
                        }
                        for tool in client.tools.list(toolgroup_id=x)
                    ]
                )
            tool_defs.extend(
                [
                    {
                        "name": tool.get_name(),
                        "description": tool.get_description(),
                        "parameters": tool.get_params_definition(),
                    }
                    for tool in client_tools
                ]
            )
            return tool_defs

        if custom_agent_config is None:
            tool_names, tool_descriptions = "", ""
            tool_defs = get_tool_defs()
            tool_names = ", ".join([x["name"] for x in tool_defs])
            tool_descriptions = "\n".join([f"- {x['name']}: {x}" for x in tool_defs])
            instruction = DEFAULT_REACT_AGENT_SYSTEM_PROMPT_TEMPLATE.replace("<<tool_names>>", tool_names).replace(
                "<<tool_descriptions>>", tool_descriptions
            )

            # use default toolgroups
            agent_config = AgentConfig(
                model=model,
                instructions=instruction,
                toolgroups=builtin_toolgroups,
                client_tools=[client_tool.get_tool_definition() for client_tool in client_tools],
                tool_choice="auto",
                # TODO: refactor this to use SystemMessageBehaviour.replace
                tool_prompt_format="json",
                input_shields=[],
                output_shields=[],
                enable_session_persistence=False,
            )
        else:
            agent_config = custom_agent_config

        if json_response_format:
            agent_config.response_format = {
                "type": "json_schema",
                "json_schema": ReActOutput.model_json_schema(),
            }

        super().__init__(
            client=client,
            agent_config=agent_config,
            client_tools=client_tools,
            output_parser=output_parser,
        )
```
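A usage sketch, assuming the module lands at `llama_stack_client.lib.agents.react.agent`, that a model and a `builtin::websearch` toolgroup are registered on the server, and that `create_turn` follows the existing streaming `Agent` API; all of these names are placeholders:

```python
from llama_stack_client import LlamaStackClient
from llama_stack_client.lib.agents.react.agent import ReActAgent

client = LlamaStackClient(base_url="http://localhost:5001")  # assumed local server

# json_response_format=True asks the server to constrain decoding to
# ReActOutput's JSON schema, so ReActOutputParser sees well-formed JSON.
agent = ReActAgent(
    client=client,
    model="meta-llama/Llama-3.1-8B-Instruct",  # hypothetical model id
    builtin_toolgroups=("builtin::websearch",),  # assumed server-side toolgroup
    json_response_format=True,
)

session_id = agent.create_session("react-demo")
for chunk in agent.create_turn(
    messages=[{"role": "user", "content": "Who won the 2022 World Cup?"}],
    session_id=session_id,
):
    print(chunk)
```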
**`react/output_parser.py`** (new file)

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import uuid
from typing import Any, Dict, Optional

from pydantic import BaseModel, ValidationError

from llama_stack_client.types.shared.completion_message import CompletionMessage
from llama_stack_client.types.shared.tool_call import ToolCall

from ..output_parser import OutputParser


class Action(BaseModel):
    tool_name: str
    tool_params: Dict[str, Any]


class ReActOutput(BaseModel):
    thought: str
    action: Optional[Action] = None
    answer: Optional[str] = None


class ReActOutputParser(OutputParser):
    def parse(self, output_message: CompletionMessage) -> CompletionMessage:
        response_text = str(output_message.content)
        try:
            react_output = ReActOutput.model_validate_json(response_text)
        except ValidationError as e:
            print(f"Error parsing action: {e}")
            return output_message

        if react_output.answer:
            return output_message

        if react_output.action:
            tool_name = react_output.action.tool_name
            tool_params = react_output.action.tool_params
            if tool_name and tool_params:
                call_id = str(uuid.uuid4())
                output_message.tool_calls = [ToolCall(call_id=call_id, tool_name=tool_name, arguments=tool_params)]

        return output_message
```

Review thread on the `ValidationError` early return:

> **Contributor:** Does the turn just terminate after this point?
>
> **Author:** Yes, the turn will terminate after this point, as there are no tool calls in …
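To exercise the parser in isolation, here is a sketch that feeds it a hand-written ReAct JSON response; the payload, tool name, and parameters are made up for the example:

```python
from llama_stack_client.lib.agents.react.output_parser import ReActOutputParser
from llama_stack_client.types.shared.completion_message import CompletionMessage

raw = CompletionMessage(
    role="assistant",
    content=(
        '{"thought": "I need to look up the weather.",'
        ' "action": {"tool_name": "get_weather", "tool_params": {"city": "Paris"}},'
        ' "answer": null}'
    ),
    stop_reason="end_of_turn",
    tool_calls=[],
)

parsed = ReActOutputParser().parse(raw)
print(parsed.tool_calls)  # one ToolCall: tool_name="get_weather", arguments={"city": "Paris"}
```

Since `answer` is null and an action is present, `parse` attaches a synthesized `ToolCall`, which is what lets `Agent._run_tool` dispatch the next step.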