Skip to content

Commit c40f4e7

Browse files
committed
refactor, comments
1 parent 344fad1 commit c40f4e7

File tree

2 files changed

+33
-27
lines changed

2 files changed

+33
-27
lines changed

src/llama_stack_client/lib/agents/react/agent.py

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -43,29 +43,36 @@ def __init__(
4343
json_response_format: bool = False,
4444
custom_agent_config: Optional[AgentConfig] = None,
4545
):
46-
def get_tool_definition(tool):
47-
return {
48-
"name": tool.identifier,
49-
"description": tool.description,
50-
"parameters": tool.parameters,
51-
}
52-
53-
if custom_agent_config is None:
54-
tool_names = ""
55-
tool_descriptions = ""
46+
def get_tool_defs():
47+
tool_defs = []
5648
for x in builtin_toolgroups:
57-
tool_names += ", ".join([tool.identifier for tool in client.tools.list(toolgroup_id=x)])
58-
tool_descriptions += "\n".join(
59-
[f"- {tool.identifier}: {get_tool_definition(tool)}" for tool in client.tools.list(toolgroup_id=x)]
49+
tool_defs.extend(
50+
[
51+
{
52+
"name": tool.identifier,
53+
"description": tool.description,
54+
"parameters": tool.parameters,
55+
}
56+
for tool in client.tools.list(toolgroup_id=x)
57+
]
6058
)
61-
62-
tool_names += ", "
63-
tool_descriptions += "\n"
64-
tool_names += ", ".join([tool.get_name() for tool in client_tools])
65-
tool_descriptions += "\n".join(
66-
[f"- {tool.get_name()}: {tool.get_tool_definition()}" for tool in client_tools]
59+
tool_defs.extend(
60+
[
61+
{
62+
"name": tool.get_name(),
63+
"description": tool.get_description(),
64+
"parameters": tool.get_params_definition(),
65+
}
66+
for tool in client_tools
67+
]
6768
)
69+
return tool_defs
6870

71+
if custom_agent_config is None:
72+
tool_names, tool_descriptions = "", ""
73+
tool_defs = get_tool_defs()
74+
tool_names = ", ".join([x["name"] for x in tool_defs])
75+
tool_descriptions = "\n".join([f"- {x['name']}: {x}" for x in tool_defs])
6976
instruction = DEFAULT_REACT_AGENT_SYSTEM_PROMPT_TEMPLATE.replace("<<tool_names>>", tool_names).replace(
7077
"<<tool_descriptions>>", tool_descriptions
7178
)

src/llama_stack_client/lib/agents/react/output_parser.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,12 @@
44
# This source code is licensed under the terms described in the LICENSE file in
55
# the root directory of this source tree.
66

7-
from pydantic import BaseModel
7+
from pydantic import BaseModel, ValidationError
88
from typing import Dict, Any, Optional
99
from ..output_parser import OutputParser
1010
from llama_stack_client.types.shared.completion_message import CompletionMessage
1111
from llama_stack_client.types.shared.tool_call import ToolCall
1212

13-
import json
1413
import uuid
1514

1615

@@ -29,17 +28,17 @@ class ReActOutputParser(OutputParser):
2928
def parse(self, output_message: CompletionMessage) -> CompletionMessage:
3029
response_text = str(output_message.content)
3130
try:
32-
response_json = json.loads(response_text)
33-
except json.JSONDecodeError as e:
31+
react_output = ReActOutput.model_validate_json(response_text)
32+
except ValidationError as e:
3433
print(f"Error parsing action: {e}")
3534
return output_message
3635

37-
if response_json.get("answer", None):
36+
if react_output.answer:
3837
return output_message
3938

40-
if response_json.get("action", None):
41-
tool_name = response_json["action"].get("tool_name", None)
42-
tool_params = response_json["action"].get("tool_params", None)
39+
if react_output.action:
40+
tool_name = react_output.action.tool_name
41+
tool_params = react_output.action.tool_params
4342
if tool_name and tool_params:
4443
call_id = str(uuid.uuid4())
4544
output_message.tool_calls = [ToolCall(call_id=call_id, tool_name=tool_name, arguments=tool_params)]

0 commit comments

Comments
 (0)