Skip to content

Commit 9dda45e

Browse files
authored
[RFC] Client Agent SDK OutputParser (#121)
# What does this PR do? - See llamastack/llama-stack#975 **Changes** ✅ Bugfix ToolResponseMessage role ✅ Add ReACT default prompt + default output parser ✅ Add ReACTAgent wrapper 🚧 Remove ClientTool and simplify it as a decorator (separate PR, including llama-stack-apps) ✅ Make agent able to return structured outputs - Note that some remote providers do not support response_format structured outputs; add it as an optional flag when calling the `ReActAgent` wrapper. ## Test Plan see tests in llamastack/llama-stack-apps#166 ## Sources Please link relevant resources if necessary. ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Ran pre-commit to handle lint / formatting issues. - [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section? - [ ] Updated relevant documentation. - [ ] Wrote necessary unit or integration tests.
1 parent af09b41 commit 9dda45e

File tree

6 files changed

+400
-10
lines changed

6 files changed

+400
-10
lines changed

src/llama_stack_client/lib/agents/agent.py

Lines changed: 45 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212
from llama_stack_client.types.agents.turn_create_params import Document, Toolgroup
1313
from llama_stack_client.types.agents.turn_create_response import AgentTurnResponseStreamChunk
1414

15+
1516
from .client_tool import ClientTool
17+
from .output_parser import OutputParser
1618

1719
DEFAULT_MAX_ITER = 10
1820

@@ -23,14 +25,18 @@ def __init__(
2325
client: LlamaStackClient,
2426
agent_config: AgentConfig,
2527
client_tools: Tuple[ClientTool] = (),
26-
memory_bank_id: Optional[str] = None,
28+
output_parser: Optional[OutputParser] = None,
2729
):
2830
self.client = client
2931
self.agent_config = agent_config
3032
self.agent_id = self._create_agent(agent_config)
3133
self.client_tools = {t.get_name(): t for t in client_tools}
3234
self.sessions = []
33-
self.memory_bank_id = memory_bank_id
35+
self.output_parser = output_parser
36+
self.builtin_tools = {}
37+
for tg in agent_config["toolgroups"]:
38+
for tool in self.client.tools.list(toolgroup_id=tg):
39+
self.builtin_tools[tool.identifier] = tool
3440

3541
def _create_agent(self, agent_config: AgentConfig) -> int:
3642
agentic_system_create_response = self.client.agents.create(
@@ -48,28 +54,56 @@ def create_session(self, session_name: str) -> int:
4854
self.sessions.append(self.session_id)
4955
return self.session_id
5056

57+
def _process_chunk(self, chunk: AgentTurnResponseStreamChunk) -> None:
    """Apply the configured output parser to a completed turn's output message.

    Only acts on "turn_complete" events; all other chunk types pass through
    untouched. No-op when no output parser is configured.

    Args:
        chunk: A streamed agent turn response chunk.
    """
    if chunk.event.payload.event_type != "turn_complete":
        return
    message = chunk.event.payload.turn.output_message

    if self.output_parser:
        parsed_message = self.output_parser.parse(message)
        # BUGFIX: previously the parsed result was only rebound to a local
        # (`message = parsed_message`) and then discarded, so parsing took
        # effect only when parse() mutated the message in place. Write the
        # result back onto the turn so downstream readers of this chunk
        # (e.g. _has_tool_call / _run_tool) see it even when parse()
        # returns a new object.
        chunk.event.payload.turn.output_message = parsed_message
65+
5166
def _has_tool_call(self, chunk: AgentTurnResponseStreamChunk) -> bool:
    """Return True when a completed turn's output message requests tool calls.

    Chunks that are not "turn_complete" events never trigger tool execution,
    nor do turns that stopped because the token budget ran out.

    Args:
        chunk: A streamed agent turn response chunk.

    Returns:
        True iff the finished turn carries at least one tool call.
    """
    if chunk.event.payload.event_type != "turn_complete":
        return False
    output = chunk.event.payload.turn.output_message
    if output.stop_reason == "out_of_tokens":
        return False

    return bool(output.tool_calls)
5874

5975
def _run_tool(self, chunk: AgentTurnResponseStreamChunk) -> ToolResponseMessage:
    """Execute the first tool call of a completed turn and return its response.

    Dispatch order:
      1. client-side tools registered on this agent (run locally);
      2. builtin tools, invoked remotely through ``client.tool_runtime``;
      3. otherwise, an "unknown tool" response is returned so the model
         can recover instead of the loop crashing.

    Args:
        chunk: A streamed agent turn response chunk whose turn contains at
            least one tool call.

    Returns:
        The ToolResponseMessage to feed back into the next turn.
    """
    turn_message = chunk.event.payload.turn.output_message
    tool_call = turn_message.tool_calls[0]
    name = tool_call.tool_name

    # custom client tools
    if name in self.client_tools:
        responses = self.client_tools[name].run([turn_message])
        return responses[0]

    # builtin tools executed by tool_runtime
    if name in self.builtin_tools:
        invocation = self.client.tool_runtime.invoke_tool(
            tool_name=name,
            kwargs=tool_call.arguments,
        )
        return ToolResponseMessage(
            call_id=tool_call.call_id,
            tool_name=name,
            content=invocation.content,
            role="tool",
        )

    # cannot find tools
    return ToolResponseMessage(
        call_id=tool_call.call_id,
        tool_name=name,
        content=f"Unknown tool `{tool_call.tool_name}` was called.",
        role="tool",
    )
73107

74108
def create_turn(
75109
self,
@@ -115,6 +149,7 @@ def _create_turn_streaming(
115149
# by default, we stop after the first turn
116150
stop = True
117151
for chunk in response:
152+
self._process_chunk(chunk)
118153
if hasattr(chunk, "error"):
119154
yield chunk
120155
return
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the terms described in the LICENSE file in
5+
# the root directory of this source tree.
6+
7+
from abc import abstractmethod
8+
9+
from llama_stack_client.types.agents.turn import CompletionMessage
10+
11+
12+
class OutputParser:
    """Abstract base class for post-processing agent responses.

    Subclass this to customize how an agent's output message is processed —
    e.g. extracting structured information, reformatting the content, or
    validating/sanitizing the response.

    Usage:
        1. Subclass ``OutputParser``.
        2. Implement :meth:`parse`.
        3. Pass an instance to the Agent's constructor.

    Example:
        class MyCustomParser(OutputParser):
            def parse(self, output_message: CompletionMessage) -> CompletionMessage:
                # custom parsing logic here
                return processed_message
    """

    @abstractmethod
    def parse(self, output_message: CompletionMessage) -> CompletionMessage:
        """Process the agent turn's output message.

        Args:
            output_message: The response message from an agent turn.

        Returns:
            The processed/transformed response message.
        """
        raise NotImplementedError
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the terms described in the LICENSE file in
5+
# the root directory of this source tree.
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the terms described in the LICENSE file in
5+
# the root directory of this source tree.
6+
from pydantic import BaseModel
7+
from typing import Dict, Any
8+
from ..agent import Agent
9+
from .output_parser import ReActOutputParser
10+
from ..output_parser import OutputParser
11+
from .prompts import DEFAULT_REACT_AGENT_SYSTEM_PROMPT_TEMPLATE
12+
13+
from typing import Tuple, Optional
14+
from llama_stack_client import LlamaStackClient
15+
from ..client_tool import ClientTool
16+
from llama_stack_client.types.agent_create_params import AgentConfig
17+
18+
19+
class Action(BaseModel):
    # A single tool invocation requested by the model in a ReAct step.
    # tool_name: identifier of the tool to call.
    # tool_params: keyword arguments to pass to the tool.
    tool_name: str
    tool_params: Dict[str, Any]
22+
23+
24+
class ReActOutput(BaseModel):
    # Structured ReAct response schema (also used as the JSON schema for
    # response_format): a reasoning step plus either a tool action or a
    # final answer. Both are optional at the schema level; presumably
    # exactly one of them is populated per step — the schema itself does
    # not enforce that.
    thought: str
    action: Optional[Action] = None
    answer: Optional[str] = None
28+
29+
30+
class ReActAgent(Agent):
    """ReAct (Reasoning + Acting) agent.

    Thin wrapper around :class:`Agent` that builds a ReAct system prompt from
    the available tools and installs a ReAct-style output parser so the
    model's JSON "thought/action/answer" responses become tool calls.

    Args:
        client: The LlamaStackClient used to create the agent and list tools.
        model: Model identifier the agent runs with.
        builtin_toolgroups: Toolgroup ids whose tools are resolved via
            ``client.tools.list`` and advertised in the system prompt.
        client_tools: Client-side tools executed locally by the agent loop.
        output_parser: Parser applied to each completed turn; defaults to a
            fresh ``ReActOutputParser`` per instance.
        json_response_format: If True, request structured JSON output matching
            ``ReActOutput``. Optional because some remote providers do not
            support ``response_format``.
        custom_agent_config: Fully custom AgentConfig; when provided, the
            default prompt/config construction is skipped.
    """

    def __init__(
        self,
        client: LlamaStackClient,
        model: str,
        builtin_toolgroups: Tuple[str] = (),
        client_tools: Tuple[ClientTool] = (),
        output_parser: Optional[OutputParser] = None,
        json_response_format: bool = False,
        custom_agent_config: Optional[AgentConfig] = None,
    ):
        def get_tool_defs():
            # Collect name/description/parameters for every advertised tool,
            # both server-side (resolved via tools.list) and client-side.
            tool_defs = []
            for x in builtin_toolgroups:
                tool_defs.extend(
                    {
                        "name": tool.identifier,
                        "description": tool.description,
                        "parameters": tool.parameters,
                    }
                    for tool in client.tools.list(toolgroup_id=x)
                )
            tool_defs.extend(
                {
                    "name": tool.get_name(),
                    "description": tool.get_description(),
                    "parameters": tool.get_params_definition(),
                }
                for tool in client_tools
            )
            return tool_defs

        if output_parser is None:
            # BUGFIX: was a default argument (`output_parser=ReActOutputParser()`),
            # i.e. a single instance evaluated once and shared by every
            # ReActAgent; instantiate per-agent instead.
            output_parser = ReActOutputParser()

        if custom_agent_config is None:
            tool_defs = get_tool_defs()
            tool_names = ", ".join(x["name"] for x in tool_defs)
            tool_descriptions = "\n".join(f"- {x['name']}: {x}" for x in tool_defs)
            instruction = DEFAULT_REACT_AGENT_SYSTEM_PROMPT_TEMPLATE.replace(
                "<<tool_names>>", tool_names
            ).replace("<<tool_descriptions>>", tool_descriptions)

            # use default toolgroups
            agent_config = AgentConfig(
                model=model,
                instructions=instruction,
                toolgroups=builtin_toolgroups,
                client_tools=[client_tool.get_tool_definition() for client_tool in client_tools],
                tool_choice="auto",
                # TODO: refactor this to use SystemMessageBehaviour.replace
                tool_prompt_format="json",
                input_shields=[],
                output_shields=[],
                enable_session_persistence=False,
            )
        else:
            # BUGFIX: previously `agent_config` was never assigned in this
            # branch, so passing custom_agent_config raised NameError at the
            # super().__init__ call below.
            agent_config = custom_agent_config

        if json_response_format:
            # BUGFIX: AgentConfig (from agent_create_params) is a TypedDict,
            # i.e. a plain dict at runtime, so attribute assignment
            # (`agent_config.response_format = ...`) raises AttributeError;
            # use item assignment instead.
            agent_config["response_format"] = {
                "type": "json_schema",
                "json_schema": ReActOutput.model_json_schema(),
            }

        super().__init__(
            client=client,
            agent_config=agent_config,
            client_tools=client_tools,
            output_parser=output_parser,
        )
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the terms described in the LICENSE file in
5+
# the root directory of this source tree.
6+
7+
from pydantic import BaseModel, ValidationError
8+
from typing import Dict, Any, Optional
9+
from ..output_parser import OutputParser
10+
from llama_stack_client.types.shared.completion_message import CompletionMessage
11+
from llama_stack_client.types.shared.tool_call import ToolCall
12+
13+
import uuid
14+
15+
16+
class Action(BaseModel):
    # A single tool invocation requested by the model in a ReAct step.
    # NOTE(review): duplicated in the ReAct agent module — consider sharing
    # one definition.
    # tool_name: identifier of the tool to call.
    # tool_params: keyword arguments to pass to the tool.
    tool_name: str
    tool_params: Dict[str, Any]
19+
20+
21+
class ReActOutput(BaseModel):
    # Structured ReAct response: a reasoning step plus either a tool action
    # or a final answer. Both are optional at the schema level; presumably
    # exactly one is populated per step — the schema does not enforce it.
    # NOTE(review): duplicated in the ReAct agent module — consider sharing
    # one definition.
    thought: str
    action: Optional[Action] = None
    answer: Optional[str] = None
25+
26+
27+
class ReActOutputParser(OutputParser):
    """Parse a ReAct JSON response and convert its action into a tool call.

    Expects the message content to be a JSON document matching ``ReActOutput``
    (thought / optional action / optional answer). When an action is present,
    a ToolCall is attached to the message so the agent loop executes the tool;
    otherwise (final answer, or unparseable content) the message is returned
    unchanged.
    """

    def parse(self, output_message: CompletionMessage) -> CompletionMessage:
        """Attach tool calls parsed from the message content, if any.

        Args:
            output_message: The completed turn's output message.

        Returns:
            The same message, possibly with ``tool_calls`` populated.
        """
        response_text = str(output_message.content)
        try:
            react_output = ReActOutput.model_validate_json(response_text)
        except ValidationError as e:
            # Best-effort: malformed model output is passed through unparsed.
            print(f"Error parsing action: {e}")
            return output_message

        if react_output.answer:
            # Final answer — no tool call to issue.
            return output_message

        if react_output.action:
            tool_name = react_output.action.tool_name
            tool_params = react_output.action.tool_params
            # BUGFIX: `tool_params` may legitimately be an empty dict for a
            # zero-argument tool; the previous truthiness check
            # (`tool_name and tool_params`) silently dropped such actions.
            if tool_name and tool_params is not None:
                call_id = str(uuid.uuid4())
                output_message.tool_calls = [
                    ToolCall(call_id=call_id, tool_name=tool_name, arguments=tool_params)
                ]

        return output_message

0 commit comments

Comments
 (0)