From 605af7cae7c6f62ff4698529f04e9f1db17a37bf Mon Sep 17 00:00:00 2001
From: Hardik Shah
Date: Wed, 19 Feb 2025 16:45:56 -0800
Subject: [PATCH 1/2] fix: react agent should be able to work with provided custom config

---
 src/llama_stack_client/lib/agents/agent.py   | 30 ++++++++++------
 .../lib/agents/client_tool.py                | 36 +++++++++++++++----
 .../lib/agents/react/agent.py                | 26 ++++++++------
 3 files changed, 63 insertions(+), 29 deletions(-)

diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py
index 0a8ab226..0a9ec004 100644
--- a/src/llama_stack_client/lib/agents/agent.py
+++ b/src/llama_stack_client/lib/agents/agent.py
@@ -3,12 +3,14 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+import uuid
+from datetime import datetime
 from typing import Iterator, List, Optional, Tuple, Union
 
 from llama_stack_client import LlamaStackClient
 from llama_stack_client.types import ToolResponseMessage, UserMessage
 from llama_stack_client.types.agent_create_params import AgentConfig
-from llama_stack_client.types.agents.turn import Turn
+from llama_stack_client.types.agents.turn import CompletionMessage, Turn
 from llama_stack_client.types.agents.turn_create_params import Document, Toolgroup
 from llama_stack_client.types.agents.turn_create_response import (
     AgentTurnResponseStreamChunk,
@@ -18,14 +20,12 @@
     AgentTurnResponseStepCompletePayload,
 )
 from llama_stack_client.types.shared.tool_call import ToolCall
-from llama_stack_client.types.agents.turn import CompletionMessage
-from .client_tool import ClientTool
-from .tool_parser import ToolParser
-from datetime import datetime
-import uuid
 from llama_stack_client.types.tool_execution_step import ToolExecutionStep
 from llama_stack_client.types.tool_response import ToolResponse
 
+from .client_tool import ClientTool
+from .tool_parser import ToolParser
+
 DEFAULT_MAX_ITER = 10
 
 
@@ -55,7 +55,7 @@ def _create_agent(self, agent_config: AgentConfig) -> int:
         self.agent_id = agentic_system_create_response.agent_id
         return self.agent_id
 
-    def create_session(self, session_name: str) -> int:
+    def create_session(self, session_name: str) -> str:
         agentic_system_create_session_response = self.client.agents.session.create(
             agent_id=self.agent_id,
             session_name=session_name,
@@ -129,10 +129,14 @@ def create_turn(
         stream: bool = True,
     ) -> Iterator[AgentTurnResponseStreamChunk] | Turn:
         if stream:
-            return self._create_turn_streaming(messages, session_id, toolgroups, documents)
+            return self._create_turn_streaming(
+                messages, session_id, toolgroups, documents
+            )
         else:
             chunks = []
-            for chunk in self._create_turn_streaming(messages, session_id, toolgroups, documents):
+            for chunk in self._create_turn_streaming(
+                messages, session_id, toolgroups, documents
+            ):
                 if chunk.event.payload.event_type == "turn_complete":
                     chunks.append(chunk)
                 pass
@@ -144,12 +148,16 @@ def create_turn(
                 input_messages=chunks[0].event.payload.turn.input_messages,
                 output_message=chunks[-1].event.payload.turn.output_message,
                 session_id=chunks[0].event.payload.turn.session_id,
-                steps=[step for chunk in chunks for step in chunk.event.payload.turn.steps],
+                steps=[
+                    step for chunk in chunks for step in chunk.event.payload.turn.steps
+                ],
                 turn_id=chunks[0].event.payload.turn.turn_id,
                 started_at=chunks[0].event.payload.turn.started_at,
                 completed_at=chunks[-1].event.payload.turn.completed_at,
                 output_attachments=[
-                    attachment for chunk in chunks for attachment in chunk.event.payload.turn.output_attachments
+                    attachment
+                    for chunk in chunks
+                    for attachment in chunk.event.payload.turn.output_attachments
                 ],
             )
 
diff --git a/src/llama_stack_client/lib/agents/client_tool.py b/src/llama_stack_client/lib/agents/client_tool.py
index 9ffec4a0..e860d960 100644
--- a/src/llama_stack_client/lib/agents/client_tool.py
+++ b/src/llama_stack_client/lib/agents/client_tool.py
@@ -4,10 +4,19 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import inspect
 import json
 from abc import abstractmethod
-from typing import Callable, Dict, TypeVar, get_type_hints, Union, get_origin, get_args, List
-import inspect
+from typing import (
+    Callable,
+    Dict,
+    get_args,
+    get_origin,
+    get_type_hints,
+    List,
+    TypeVar,
+    Union,
+)
 
 from llama_stack_client.types import Message, ToolResponseMessage
 from llama_stack_client.types.tool_def_param import Parameter, ToolDefParam
@@ -47,7 +56,10 @@ def parameters_for_system_prompt(self) -> str:
             {
                 "name": self.get_name(),
                 "description": self.get_description(),
-                "parameters": {name: definition for name, definition in self.get_params_definition().items()},
+                "parameters": {
+                    name: definition
+                    for name, definition in self.get_params_definition().items()
+                },
             }
         )
 
@@ -146,16 +158,26 @@ def get_params_definition(self) -> Dict[str, Parameter]:
                     break
 
             if param_doc == "":
-                raise ValueError(f"No parameter description found for parameter {name}")
+                raise ValueError(
+                    f"No parameter description found for parameter {name}"
+                )
 
             param = sig.parameters[name]
-            is_optional_type = get_origin(type_hint) is Union and type(None) in get_args(type_hint)
-            is_required = param.default == inspect.Parameter.empty and not is_optional_type
+            is_optional_type = get_origin(type_hint) is Union and type(
+                None
+            ) in get_args(type_hint)
+            is_required = (
+                param.default == inspect.Parameter.empty and not is_optional_type
+            )
             params[name] = Parameter(
                 name=name,
                 description=param_doc or f"Parameter {name}",
                 parameter_type=type_hint.__name__,
-                default=param.default if param.default != inspect.Parameter.empty else None,
+                default=(
+                    param.default
+                    if param.default != inspect.Parameter.empty
+                    else None
+                ),
                 required=is_required,
             )
         return params
diff --git a/src/llama_stack_client/lib/agents/react/agent.py b/src/llama_stack_client/lib/agents/react/agent.py
index fafca9dc..78c7696c 100644
--- a/src/llama_stack_client/lib/agents/react/agent.py
+++ b/src/llama_stack_client/lib/agents/react/agent.py
@@ -3,17 +3,17 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+from typing import Any, Dict, Optional, Tuple
+
+from llama_stack_client import LlamaStackClient
+from llama_stack_client.types.agent_create_params import AgentConfig
 from pydantic import BaseModel
-from typing import Dict, Any
+
 from ..agent import Agent
-from .tool_parser import ReActToolParser
+from ..client_tool import ClientTool
 from ..tool_parser import ToolParser
 from .prompts import DEFAULT_REACT_AGENT_SYSTEM_PROMPT_TEMPLATE
-
-from typing import Tuple, Optional
-from llama_stack_client import LlamaStackClient
-from ..client_tool import ClientTool
-from llama_stack_client.types.agent_create_params import AgentConfig
+from .tool_parser import ReActToolParser
 
 
 class Action(BaseModel):
@@ -73,16 +73,18 @@ def get_tool_defs():
             tool_defs = get_tool_defs()
             tool_names = ", ".join([x["name"] for x in tool_defs])
             tool_descriptions = "\n".join([f"- {x['name']}: {x}" for x in tool_defs])
-            instruction = DEFAULT_REACT_AGENT_SYSTEM_PROMPT_TEMPLATE.replace("<<tool_names>>", tool_names).replace(
-                "<<tool_descriptions>>", tool_descriptions
-            )
+            instruction = DEFAULT_REACT_AGENT_SYSTEM_PROMPT_TEMPLATE.replace(
+                "<<tool_names>>", tool_names
+            ).replace("<<tool_descriptions>>", tool_descriptions)
 
             # user default toolgroups
             agent_config = AgentConfig(
                 model=model,
                 instructions=instruction,
                 toolgroups=builtin_toolgroups,
-                client_tools=[client_tool.get_tool_definition() for client_tool in client_tools],
+                client_tools=[
+                    client_tool.get_tool_definition() for client_tool in client_tools
+                ],
                 tool_config={
                     "tool_choice": "auto",
                     "tool_prompt_format": "json" if "3.1" in model else "python_list",
@@ -92,6 +94,8 @@ def get_tool_defs():
                 output_shields=[],
                 enable_session_persistence=False,
             )
+        else:
+            agent_config = custom_agent_config
 
         if json_response_format:
             agent_config.response_format = {

From 1366d2fe65e172f9d91a7578544e34a1c8e19e8f Mon Sep 17 00:00:00 2001
From: Hardik Shah
Date: Wed, 19 Feb 2025 16:49:19 -0800
Subject: [PATCH 2/2] lint fix

---
 src/llama_stack_client/lib/agents/agent.py   | 17 ++++-------
 .../lib/agents/client_tool.py                | 24 +++++--------------
 .../lib/agents/react/agent.py                | 11 ++++-----
 3 files changed, 16 insertions(+), 36 deletions(-)

diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py
index 0a9ec004..3b7bcc7f 100644
--- a/src/llama_stack_client/lib/agents/agent.py
+++ b/src/llama_stack_client/lib/agents/agent.py
@@ -8,6 +8,7 @@
 from typing import Iterator, List, Optional, Tuple, Union
 
 from llama_stack_client import LlamaStackClient
+
 from llama_stack_client.types import ToolResponseMessage, UserMessage
 from llama_stack_client.types.agent_create_params import AgentConfig
 from llama_stack_client.types.agents.turn import CompletionMessage, Turn
@@ -129,14 +130,10 @@ def create_turn(
         stream: bool = True,
     ) -> Iterator[AgentTurnResponseStreamChunk] | Turn:
         if stream:
-            return self._create_turn_streaming(
-                messages, session_id, toolgroups, documents
-            )
+            return self._create_turn_streaming(messages, session_id, toolgroups, documents)
         else:
             chunks = []
-            for chunk in self._create_turn_streaming(
-                messages, session_id, toolgroups, documents
-            ):
+            for chunk in self._create_turn_streaming(messages, session_id, toolgroups, documents):
                 if chunk.event.payload.event_type == "turn_complete":
                     chunks.append(chunk)
                 pass
@@ -148,16 +145,12 @@ def create_turn(
                 input_messages=chunks[0].event.payload.turn.input_messages,
                 output_message=chunks[-1].event.payload.turn.output_message,
                 session_id=chunks[0].event.payload.turn.session_id,
-                steps=[
-                    step for chunk in chunks for step in chunk.event.payload.turn.steps
-                ],
+                steps=[step for chunk in chunks for step in chunk.event.payload.turn.steps],
                 turn_id=chunks[0].event.payload.turn.turn_id,
                 started_at=chunks[0].event.payload.turn.started_at,
                 completed_at=chunks[-1].event.payload.turn.completed_at,
                 output_attachments=[
-                    attachment
-                    for chunk in chunks
-                    for attachment in chunk.event.payload.turn.output_attachments
+                    attachment for chunk in chunks for attachment in chunk.event.payload.turn.output_attachments
                 ],
             )
 
diff --git a/src/llama_stack_client/lib/agents/client_tool.py b/src/llama_stack_client/lib/agents/client_tool.py
index e860d960..f672268d 100644
--- a/src/llama_stack_client/lib/agents/client_tool.py
+++ b/src/llama_stack_client/lib/agents/client_tool.py
@@ -56,10 +56,7 @@ def parameters_for_system_prompt(self) -> str:
             {
                 "name": self.get_name(),
                 "description": self.get_description(),
-                "parameters": {
-                    name: definition
-                    for name, definition in self.get_params_definition().items()
-                },
+                "parameters": {name: definition for name, definition in self.get_params_definition().items()},
             }
         )
 
@@ -158,28 +155,19 @@ def get_params_definition(self) -> Dict[str, Parameter]:
                     break
 
             if param_doc == "":
-                raise ValueError(
-                    f"No parameter description found for parameter {name}"
-                )
+                raise ValueError(f"No parameter description found for parameter {name}")
 
             param = sig.parameters[name]
-            is_optional_type = get_origin(type_hint) is Union and type(
-                None
-            ) in get_args(type_hint)
-            is_required = (
-                param.default == inspect.Parameter.empty and not is_optional_type
-            )
+            is_optional_type = get_origin(type_hint) is Union and type(None) in get_args(type_hint)
+            is_required = param.default == inspect.Parameter.empty and not is_optional_type
             params[name] = Parameter(
                 name=name,
                 description=param_doc or f"Parameter {name}",
                 parameter_type=type_hint.__name__,
-                default=(
-                    param.default
-                    if param.default != inspect.Parameter.empty
-                    else None
-                ),
+                default=(param.default if param.default != inspect.Parameter.empty else None),
                 required=is_required,
             )
+
         return params
 
     def run_impl(self, **kwargs):
diff --git a/src/llama_stack_client/lib/agents/react/agent.py b/src/llama_stack_client/lib/agents/react/agent.py
index 78c7696c..bad7e46e 100644
--- a/src/llama_stack_client/lib/agents/react/agent.py
+++ b/src/llama_stack_client/lib/agents/react/agent.py
@@ -13,6 +13,7 @@
 from ..client_tool import ClientTool
 from ..tool_parser import ToolParser
 from .prompts import DEFAULT_REACT_AGENT_SYSTEM_PROMPT_TEMPLATE
+
 from .tool_parser import ReActToolParser
 
 
@@ -73,18 +74,16 @@ def get_tool_defs():
             tool_defs = get_tool_defs()
             tool_names = ", ".join([x["name"] for x in tool_defs])
             tool_descriptions = "\n".join([f"- {x['name']}: {x}" for x in tool_defs])
-            instruction = DEFAULT_REACT_AGENT_SYSTEM_PROMPT_TEMPLATE.replace(
-                "<<tool_names>>", tool_names
-            ).replace("<<tool_descriptions>>", tool_descriptions)
+            instruction = DEFAULT_REACT_AGENT_SYSTEM_PROMPT_TEMPLATE.replace("<<tool_names>>", tool_names).replace(
+                "<<tool_descriptions>>", tool_descriptions
+            )
 
             # user default toolgroups
             agent_config = AgentConfig(
                 model=model,
                 instructions=instruction,
                 toolgroups=builtin_toolgroups,
-                client_tools=[
-                    client_tool.get_tool_definition() for client_tool in client_tools
-                ],
+                client_tools=[client_tool.get_tool_definition() for client_tool in client_tools],
                 tool_config={
                     "tool_choice": "auto",
                     "tool_prompt_format": "json" if "3.1" in model else "python_list",