Skip to content

Commit 846ba85

Browse files
committed
feat(agent): support plain function as client_tool
Summary: Test Plan:
1 parent fc9907c commit 846ba85

File tree

1 file changed

+16
-14
lines changed
  • src/llama_stack_client/lib/agents

1 file changed

+16
-14
lines changed

src/llama_stack_client/lib/agents/agent.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# This source code is licensed under the terms described in the LICENSE file in
55
# the root directory of this source tree.
66
import logging
7-
from typing import AsyncIterator, Iterator, List, Optional, Tuple, Union
7+
from typing import AsyncIterator, Callable, Iterator, List, Optional, Tuple, Union, Any
88

99
from llama_stack_client import LlamaStackClient
1010

@@ -18,7 +18,7 @@
1818
from llama_stack_client.types.shared_params.response_format import ResponseFormat
1919
from llama_stack_client.types.shared_params.sampling_params import SamplingParams
2020

21-
from .client_tool import ClientTool
21+
from .client_tool import ClientTool, client_tool
2222
from .tool_parser import ToolParser
2323

2424
DEFAULT_MAX_ITER = 10
@@ -28,10 +28,12 @@
2828

2929
class AgentUtils:
3030
@staticmethod
31-
def get_client_tools(tools: Optional[List[Union[Toolgroup, ClientTool]]]) -> List[ClientTool]:
31+
def get_client_tools(tools: Optional[List[Union[Toolgroup, ClientTool, Callable[..., Any]]]]) -> List[ClientTool]:
3232
if not tools:
3333
return []
3434

35+
# Wrap any function in client_tool decorator
36+
tools = [client_tool(tool) if (callable(tool) and not isinstance(tool, ClientTool)) else tool for tool in tools]
3537
return [tool for tool in tools if isinstance(tool, ClientTool)]
3638

3739
@staticmethod
@@ -59,7 +61,8 @@ def get_turn_id(chunk: AgentTurnResponseStreamChunk) -> Optional[str]:
5961
def get_agent_config(
6062
model: Optional[str] = None,
6163
instructions: Optional[str] = None,
62-
tools: Optional[List[Union[Toolgroup, ClientTool]]] = None,
64+
tools: Optional[List[Union[Toolgroup, ClientTool, Callable[..., Any]]]] = None,
65+
client_tools: Optional[List[ClientTool]] = None,
6366
tool_config: Optional[ToolConfig] = None,
6467
sampling_params: Optional[SamplingParams] = None,
6568
max_infer_iters: Optional[int] = None,
@@ -96,15 +99,12 @@ def get_agent_config(
9699
agent_config["tool_config"] = tool_config
97100
if tools is not None:
98101
toolgroups: List[Toolgroup] = []
99-
client_tools: List[ClientTool] = []
100-
101102
for tool in tools:
102103
if isinstance(tool, str) or isinstance(tool, dict):
103104
toolgroups.append(tool)
104-
else:
105-
client_tools.append(tool)
106105

107106
agent_config["toolgroups"] = toolgroups
107+
if client_tools:
108108
agent_config["client_tools"] = [tool.get_tool_definition() for tool in client_tools]
109109

110110
agent_config = AgentConfig(**agent_config)
@@ -122,7 +122,7 @@ def __init__(
122122
tool_parser: Optional[ToolParser] = None,
123123
model: Optional[str] = None,
124124
instructions: Optional[str] = None,
125-
tools: Optional[List[Union[Toolgroup, ClientTool]]] = None,
125+
tools: Optional[List[Union[Toolgroup, ClientTool, Callable[..., Any]]]] = None,
126126
tool_config: Optional[ToolConfig] = None,
127127
sampling_params: Optional[SamplingParams] = None,
128128
max_infer_iters: Optional[int] = None,
@@ -143,7 +143,7 @@ def __init__(
143143
:param instructions: The instructions for the agent.
144144
:param tools: A list of tools for the agent. Values can be one of the following:
145145
- dict representing a toolgroup/tool with arguments: e.g. {"name": "builtin::rag/knowledge_search", "args": {"vector_db_ids": [123]}}
146-
- a python function decorated with @client_tool
146+
- a python function with a docstring. See @client_tool for more details.
147147
- str representing a tool within a toolgroup: e.g. "builtin::rag/knowledge_search"
148148
- str representing a toolgroup_id: e.g. "builtin::rag", "builtin::code_interpreter", where all tools in the toolgroup will be added to the agent
149149
- an instance of ClientTool: A client tool object.
@@ -164,10 +164,12 @@ def __init__(
164164

165165
# Construct agent_config from parameters if not provided
166166
if agent_config is None:
167+
client_tools = AgentUtils.get_client_tools(tools)
167168
agent_config = AgentUtils.get_agent_config(
168169
model=model,
169170
instructions=instructions,
170171
tools=tools,
172+
client_tools=client_tools,
171173
tool_config=tool_config,
172174
sampling_params=sampling_params,
173175
max_infer_iters=max_infer_iters,
@@ -176,7 +178,6 @@ def __init__(
176178
response_format=response_format,
177179
enable_session_persistence=enable_session_persistence,
178180
)
179-
client_tools = AgentUtils.get_client_tools(tools)
180181

181182
self.agent_config = agent_config
182183
self.client_tools = {t.get_name(): t for t in client_tools}
@@ -332,7 +333,7 @@ def __init__(
332333
tool_parser: Optional[ToolParser] = None,
333334
model: Optional[str] = None,
334335
instructions: Optional[str] = None,
335-
tools: Optional[List[Union[Toolgroup, ClientTool]]] = None,
336+
tools: Optional[List[Union[Toolgroup, ClientTool, Callable[..., Any]]]] = None,
336337
tool_config: Optional[ToolConfig] = None,
337338
sampling_params: Optional[SamplingParams] = None,
338339
max_infer_iters: Optional[int] = None,
@@ -353,7 +354,7 @@ def __init__(
353354
:param instructions: The instructions for the agent.
354355
:param tools: A list of tools for the agent. Values can be one of the following:
355356
- dict representing a toolgroup/tool with arguments: e.g. {"name": "builtin::rag/knowledge_search", "args": {"vector_db_ids": [123]}}
356-
- a python function decorated with @client_tool
357+
- a python function with a docstring. See @client_tool for more details.
357358
- str representing a tool within a toolgroup: e.g. "builtin::rag/knowledge_search"
358359
- str representing a toolgroup_id: e.g. "builtin::rag", "builtin::code_interpreter", where all tools in the toolgroup will be added to the agent
359360
- an instance of ClientTool: A client tool object.
@@ -374,10 +375,12 @@ def __init__(
374375

375376
# Construct agent_config from parameters if not provided
376377
if agent_config is None:
378+
client_toolss = AgentUtils.get_client_tools(tools)
377379
agent_config = AgentUtils.get_agent_config(
378380
model=model,
379381
instructions=instructions,
380382
tools=tools,
383+
client_tools=client_toolss,
381384
tool_config=tool_config,
382385
sampling_params=sampling_params,
383386
max_infer_iters=max_infer_iters,
@@ -386,7 +389,6 @@ def __init__(
386389
response_format=response_format,
387390
enable_session_persistence=enable_session_persistence,
388391
)
389-
client_tools = AgentUtils.get_client_tools(tools)
390392

391393
self.agent_config = agent_config
392394
self.client_tools = {t.get_name(): t for t in client_tools}

0 commit comments

Comments
 (0)