44# This source code is licensed under the terms described in the LICENSE file in
55# the root directory of this source tree.
66import logging
7- from typing import AsyncIterator , Iterator , List , Optional , Tuple , Union
7+ from typing import Any , AsyncIterator , Callable , Iterator , List , Optional , Tuple , Union
88
99from llama_stack_client import LlamaStackClient
10-
1110from llama_stack_client .types import ToolResponseMessage , ToolResponseParam , UserMessage
1211from llama_stack_client .types .agent_create_params import AgentConfig
1312from llama_stack_client .types .agents .turn import CompletionMessage , Turn
1817from llama_stack_client .types .shared_params .response_format import ResponseFormat
1918from llama_stack_client .types .shared_params .sampling_params import SamplingParams
2019
21- from .client_tool import ClientTool
20+ from .client_tool import ClientTool , client_tool
2221from .tool_parser import ToolParser
2322
2423DEFAULT_MAX_ITER = 10
2827
2928class AgentUtils :
3029 @staticmethod
31- def get_client_tools (tools : Optional [List [Union [Toolgroup , ClientTool ]]]) -> List [ClientTool ]:
30+ def get_client_tools (tools : Optional [List [Union [Toolgroup , ClientTool , Callable [..., Any ] ]]]) -> List [ClientTool ]:
3231 if not tools :
3332 return []
3433
34+ # Wrap any function in client_tool decorator
35+ tools = [client_tool (tool ) if (callable (tool ) and not isinstance (tool , ClientTool )) else tool for tool in tools ]
3536 return [tool for tool in tools if isinstance (tool , ClientTool )]
3637
3738 @staticmethod
@@ -59,7 +60,7 @@ def get_turn_id(chunk: AgentTurnResponseStreamChunk) -> Optional[str]:
5960 def get_agent_config (
6061 model : Optional [str ] = None ,
6162 instructions : Optional [str ] = None ,
62- tools : Optional [List [Union [Toolgroup , ClientTool ]]] = None ,
63+ tools : Optional [List [Union [Toolgroup , ClientTool , Callable [..., Any ] ]]] = None ,
6364 tool_config : Optional [ToolConfig ] = None ,
6465 sampling_params : Optional [SamplingParams ] = None ,
6566 max_infer_iters : Optional [int ] = None ,
@@ -96,16 +97,12 @@ def get_agent_config(
9697 agent_config ["tool_config" ] = tool_config
9798 if tools is not None :
9899 toolgroups : List [Toolgroup ] = []
99- client_tools : List [ClientTool ] = []
100-
101100 for tool in tools :
102101 if isinstance (tool , str ) or isinstance (tool , dict ):
103102 toolgroups .append (tool )
104- else :
105- client_tools .append (tool )
106103
107104 agent_config ["toolgroups" ] = toolgroups
108- agent_config ["client_tools" ] = [tool .get_tool_definition () for tool in client_tools ]
105+ agent_config ["client_tools" ] = [tool .get_tool_definition () for tool in AgentUtils . get_client_tools ( tools ) ]
109106
110107 agent_config = AgentConfig (** agent_config )
111108 return agent_config
@@ -122,7 +119,7 @@ def __init__(
122119 tool_parser : Optional [ToolParser ] = None ,
123120 model : Optional [str ] = None ,
124121 instructions : Optional [str ] = None ,
125- tools : Optional [List [Union [Toolgroup , ClientTool ]]] = None ,
122+ tools : Optional [List [Union [Toolgroup , ClientTool , Callable [..., Any ] ]]] = None ,
126123 tool_config : Optional [ToolConfig ] = None ,
127124 sampling_params : Optional [SamplingParams ] = None ,
128125 max_infer_iters : Optional [int ] = None ,
@@ -143,7 +140,7 @@ def __init__(
143140 :param instructions: The instructions for the agent.
144141 :param tools: A list of tools for the agent. Values can be one of the following:
145142 - dict representing a toolgroup/tool with arguments: e.g. {"name": "builtin::rag/knowledge_search", "args": {"vector_db_ids": [123]}}
146- - a python function decorated with @client_tool
143+ - a python function with a docstring. See @client_tool for more details.
147144 - str representing a tool within a toolgroup: e.g. "builtin::rag/knowledge_search"
148145 - str representing a toolgroup_id: e.g. "builtin::rag", "builtin::code_interpreter", where all tools in the toolgroup will be added to the agent
149146 - an instance of ClientTool: A client tool object.
@@ -332,7 +329,7 @@ def __init__(
332329 tool_parser : Optional [ToolParser ] = None ,
333330 model : Optional [str ] = None ,
334331 instructions : Optional [str ] = None ,
335- tools : Optional [List [Union [Toolgroup , ClientTool ]]] = None ,
332+ tools : Optional [List [Union [Toolgroup , ClientTool , Callable [..., Any ] ]]] = None ,
336333 tool_config : Optional [ToolConfig ] = None ,
337334 sampling_params : Optional [SamplingParams ] = None ,
338335 max_infer_iters : Optional [int ] = None ,
@@ -353,7 +350,7 @@ def __init__(
353350 :param instructions: The instructions for the agent.
354351 :param tools: A list of tools for the agent. Values can be one of the following:
355352 - dict representing a toolgroup/tool with arguments: e.g. {"name": "builtin::rag/knowledge_search", "args": {"vector_db_ids": [123]}}
356- - a python function decorated with @client_tool
353+ - a python function with a docstring. See @client_tool for more details.
357354 - str representing a tool within a toolgroup: e.g. "builtin::rag/knowledge_search"
358355 - str representing a toolgroup_id: e.g. "builtin::rag", "builtin::code_interpreter", where all tools in the toolgroup will be added to the agent
359356 - an instance of ClientTool: A client tool object.
0 commit comments