diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py index ea87c1df..34def541 100644 --- a/src/llama_stack_client/lib/agents/agent.py +++ b/src/llama_stack_client/lib/agents/agent.py @@ -4,10 +4,9 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. import logging -from typing import AsyncIterator, Iterator, List, Optional, Tuple, Union +from typing import Any, AsyncIterator, Callable, Iterator, List, Optional, Tuple, Union from llama_stack_client import LlamaStackClient - from llama_stack_client.types import ToolResponseMessage, ToolResponseParam, UserMessage from llama_stack_client.types.agent_create_params import AgentConfig from llama_stack_client.types.agents.turn import CompletionMessage, Turn @@ -18,7 +17,7 @@ from llama_stack_client.types.shared_params.response_format import ResponseFormat from llama_stack_client.types.shared_params.sampling_params import SamplingParams -from .client_tool import ClientTool +from .client_tool import ClientTool, client_tool from .tool_parser import ToolParser DEFAULT_MAX_ITER = 10 @@ -28,10 +27,12 @@ class AgentUtils: @staticmethod - def get_client_tools(tools: Optional[List[Union[Toolgroup, ClientTool]]]) -> List[ClientTool]: + def get_client_tools(tools: Optional[List[Union[Toolgroup, ClientTool, Callable[..., Any]]]]) -> List[ClientTool]: if not tools: return [] + # Wrap any function in client_tool decorator + tools = [client_tool(tool) if (callable(tool) and not isinstance(tool, ClientTool)) else tool for tool in tools] return [tool for tool in tools if isinstance(tool, ClientTool)] @staticmethod @@ -59,7 +60,7 @@ def get_turn_id(chunk: AgentTurnResponseStreamChunk) -> Optional[str]: def get_agent_config( model: Optional[str] = None, instructions: Optional[str] = None, - tools: Optional[List[Union[Toolgroup, ClientTool]]] = None, + tools: Optional[List[Union[Toolgroup, ClientTool, Callable[..., Any]]]] = None, tool_config: Optional[ToolConfig] = None, sampling_params: Optional[SamplingParams] = None, max_infer_iters: Optional[int] = None, @@ -96,16 +97,12 @@ def get_agent_config( agent_config["tool_config"] = tool_config if tools is not None: toolgroups: List[Toolgroup] = [] - client_tools: List[ClientTool] = [] - for tool in tools: if isinstance(tool, str) or isinstance(tool, dict): toolgroups.append(tool) - else: - client_tools.append(tool) agent_config["toolgroups"] = toolgroups - agent_config["client_tools"] = [tool.get_tool_definition() for tool in client_tools] + agent_config["client_tools"] = [tool.get_tool_definition() for tool in AgentUtils.get_client_tools(tools)] agent_config = AgentConfig(**agent_config) return agent_config @@ -122,7 +119,7 @@ def __init__( tool_parser: Optional[ToolParser] = None, model: Optional[str] = None, instructions: Optional[str] = None, - tools: Optional[List[Union[Toolgroup, ClientTool]]] = None, + tools: Optional[List[Union[Toolgroup, ClientTool, Callable[..., Any]]]] = None, tool_config: Optional[ToolConfig] = None, sampling_params: Optional[SamplingParams] = None, max_infer_iters: Optional[int] = None, @@ -143,7 +140,7 @@ def __init__( :param instructions: The instructions for the agent. :param tools: A list of tools for the agent. Values can be one of the following: - dict representing a toolgroup/tool with arguments: e.g. {"name": "builtin::rag/knowledge_search", "args": {"vector_db_ids": [123]}} - - a python function decorated with @client_tool + - a python function with a docstring. See @client_tool for more details. - str representing a tool within a toolgroup: e.g. "builtin::rag/knowledge_search" - str representing a toolgroup_id: e.g. "builtin::rag", "builtin::code_interpreter", where all tools in the toolgroup will be added to the agent - an instance of ClientTool: A client tool object. @@ -332,7 +329,7 @@ def __init__( tool_parser: Optional[ToolParser] = None, model: Optional[str] = None, instructions: Optional[str] = None, - tools: Optional[List[Union[Toolgroup, ClientTool]]] = None, + tools: Optional[List[Union[Toolgroup, ClientTool, Callable[..., Any]]]] = None, tool_config: Optional[ToolConfig] = None, sampling_params: Optional[SamplingParams] = None, max_infer_iters: Optional[int] = None, @@ -353,7 +350,7 @@ def __init__( :param instructions: The instructions for the agent. :param tools: A list of tools for the agent. Values can be one of the following: - dict representing a toolgroup/tool with arguments: e.g. {"name": "builtin::rag/knowledge_search", "args": {"vector_db_ids": [123]}} - - a python function decorated with @client_tool + - a python function with a docstring. See @client_tool for more details. - str representing a tool within a toolgroup: e.g. "builtin::rag/knowledge_search" - str representing a toolgroup_id: e.g. "builtin::rag", "builtin::code_interpreter", where all tools in the toolgroup will be added to the agent - an instance of ClientTool: A client tool object.