44# This source code is licensed under the terms described in the LICENSE file in
55# the root directory of this source tree.
66import logging
7- from typing import AsyncIterator , Iterator , List , Optional , Tuple , Union
7+ from typing import AsyncIterator , Callable , Iterator , List , Optional , Tuple , Union , Any
88
99from llama_stack_client import LlamaStackClient
1010
1818from llama_stack_client .types .shared_params .response_format import ResponseFormat
1919from llama_stack_client .types .shared_params .sampling_params import SamplingParams
2020
21- from .client_tool import ClientTool
21+ from .client_tool import ClientTool , client_tool
2222from .tool_parser import ToolParser
2323
2424DEFAULT_MAX_ITER = 10
2828
2929class AgentUtils :
3030 @staticmethod
31- def get_client_tools (tools : Optional [List [Union [Toolgroup , ClientTool ]]]) -> List [ClientTool ]:
31+ def get_client_tools (tools : Optional [List [Union [Toolgroup , ClientTool , Callable [..., Any ] ]]]) -> List [ClientTool ]:
3232 if not tools :
3333 return []
3434
35+ # Wrap any function in client_tool decorator
36+ tools = [client_tool (tool ) if (callable (tool ) and not isinstance (tool , ClientTool )) else tool for tool in tools ]
3537 return [tool for tool in tools if isinstance (tool , ClientTool )]
3638
3739 @staticmethod
@@ -59,7 +61,8 @@ def get_turn_id(chunk: AgentTurnResponseStreamChunk) -> Optional[str]:
5961 def get_agent_config (
6062 model : Optional [str ] = None ,
6163 instructions : Optional [str ] = None ,
62- tools : Optional [List [Union [Toolgroup , ClientTool ]]] = None ,
64+ tools : Optional [List [Union [Toolgroup , ClientTool , Callable [..., Any ]]]] = None ,
65+ client_tools : Optional [List [ClientTool ]] = None ,
6366 tool_config : Optional [ToolConfig ] = None ,
6467 sampling_params : Optional [SamplingParams ] = None ,
6568 max_infer_iters : Optional [int ] = None ,
@@ -96,15 +99,12 @@ def get_agent_config(
9699 agent_config ["tool_config" ] = tool_config
97100 if tools is not None :
98101 toolgroups : List [Toolgroup ] = []
99- client_tools : List [ClientTool ] = []
100-
101102 for tool in tools :
102103 if isinstance (tool , str ) or isinstance (tool , dict ):
103104 toolgroups .append (tool )
104- else :
105- client_tools .append (tool )
106105
107106 agent_config ["toolgroups" ] = toolgroups
107+ if client_tools :
108108 agent_config ["client_tools" ] = [tool .get_tool_definition () for tool in client_tools ]
109109
110110 agent_config = AgentConfig (** agent_config )
@@ -122,7 +122,7 @@ def __init__(
122122 tool_parser : Optional [ToolParser ] = None ,
123123 model : Optional [str ] = None ,
124124 instructions : Optional [str ] = None ,
125- tools : Optional [List [Union [Toolgroup , ClientTool ]]] = None ,
125+ tools : Optional [List [Union [Toolgroup , ClientTool , Callable [..., Any ] ]]] = None ,
126126 tool_config : Optional [ToolConfig ] = None ,
127127 sampling_params : Optional [SamplingParams ] = None ,
128128 max_infer_iters : Optional [int ] = None ,
@@ -143,7 +143,7 @@ def __init__(
143143 :param instructions: The instructions for the agent.
144144 :param tools: A list of tools for the agent. Values can be one of the following:
145145 - dict representing a toolgroup/tool with arguments: e.g. {"name": "builtin::rag/knowledge_search", "args": {"vector_db_ids": [123]}}
146- - a python function decorated with @client_tool
146+ - a python function with a docstring. See @client_tool for more details.
147147 - str representing a tool within a toolgroup: e.g. "builtin::rag/knowledge_search"
148148 - str representing a toolgroup_id: e.g. "builtin::rag", "builtin::code_interpreter", where all tools in the toolgroup will be added to the agent
149149 - an instance of ClientTool: A client tool object.
@@ -164,10 +164,12 @@ def __init__(
164164
165165 # Construct agent_config from parameters if not provided
166166 if agent_config is None :
167+ client_tools = AgentUtils .get_client_tools (tools )
167168 agent_config = AgentUtils .get_agent_config (
168169 model = model ,
169170 instructions = instructions ,
170171 tools = tools ,
172+ client_tools = client_tools ,
171173 tool_config = tool_config ,
172174 sampling_params = sampling_params ,
173175 max_infer_iters = max_infer_iters ,
@@ -176,7 +178,6 @@ def __init__(
176178 response_format = response_format ,
177179 enable_session_persistence = enable_session_persistence ,
178180 )
179- client_tools = AgentUtils .get_client_tools (tools )
180181
181182 self .agent_config = agent_config
182183 self .client_tools = {t .get_name (): t for t in client_tools }
@@ -332,7 +333,7 @@ def __init__(
332333 tool_parser : Optional [ToolParser ] = None ,
333334 model : Optional [str ] = None ,
334335 instructions : Optional [str ] = None ,
335- tools : Optional [List [Union [Toolgroup , ClientTool ]]] = None ,
336+ tools : Optional [List [Union [Toolgroup , ClientTool , Callable [..., Any ] ]]] = None ,
336337 tool_config : Optional [ToolConfig ] = None ,
337338 sampling_params : Optional [SamplingParams ] = None ,
338339 max_infer_iters : Optional [int ] = None ,
@@ -353,7 +354,7 @@ def __init__(
353354 :param instructions: The instructions for the agent.
354355 :param tools: A list of tools for the agent. Values can be one of the following:
355356 - dict representing a toolgroup/tool with arguments: e.g. {"name": "builtin::rag/knowledge_search", "args": {"vector_db_ids": [123]}}
356- - a python function decorated with @client_tool
357+ - a python function with a docstring. See @client_tool for more details.
357358 - str representing a tool within a toolgroup: e.g. "builtin::rag/knowledge_search"
358359 - str representing a toolgroup_id: e.g. "builtin::rag", "builtin::code_interpreter", where all tools in the toolgroup will be added to the agent
359360 - an instance of ClientTool: A client tool object.
@@ -374,10 +375,12 @@ def __init__(
374375
375376 # Construct agent_config from parameters if not provided
376377 if agent_config is None :
378+ client_toolss = AgentUtils .get_client_tools (tools )
377379 agent_config = AgentUtils .get_agent_config (
378380 model = model ,
379381 instructions = instructions ,
380382 tools = tools ,
383+ client_tools = client_toolss ,
381384 tool_config = tool_config ,
382385 sampling_params = sampling_params ,
383386 max_infer_iters = max_infer_iters ,
@@ -386,7 +389,6 @@ def __init__(
386389 response_format = response_format ,
387390 enable_session_persistence = enable_session_persistence ,
388391 )
389- client_tools = AgentUtils .get_client_tools (tools )
390392
391393 self .agent_config = agent_config
392394 self .client_tools = {t .get_name (): t for t in client_tools }
0 commit comments