Skip to content

Commit ba68957

Browse files
authored
[SDK] add tool decorator helper for custom client tools (#127)
# What does this PR do? - Add decorator for on callables for defining client side custom tools - Addresses: llamastack/llama-stack#948 ## Test Plan Usage: ```python @client_tool def add(x: int, y: int) -> int: '''Add 2 integer numbers :param x: integer 1 :param y: integer 2 :returns: sum of x + y ''' return x + y ``` `add` will be a ClientTool that can be passed - Working example in: llamastack/llama-stack-apps#166 ## Sources Please link relevant resources if necessary. ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Ran pre-commit to handle lint / formatting issues. - [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section? - [ ] Updated relevant documentation. - [ ] Wrote necessary unit or integration tests.
1 parent b183fb6 commit ba68957

File tree

2 files changed

+110
-9
lines changed

2 files changed

+110
-9
lines changed

src/llama_stack_client/lib/agents/agent.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,10 @@ def _run_tool(self, chunk: AgentTurnResponseStreamChunk) -> ToolResponseMessage:
7878
# custom client tools
7979
if tool_call.tool_name in self.client_tools:
8080
tool = self.client_tools[tool_call.tool_name]
81-
result_messages = tool.run([message])
82-
next_message = result_messages[0]
83-
return next_message
81+
# NOTE: tool.run() expects a list of messages, we only pass in last message here
82+
# but we could pass in the entire message history
83+
result_message = tool.run([message])
84+
return result_message
8485

8586
# builtin tools executed by tool_runtime
8687
if tool_call.tool_name in self.builtin_tools:

src/llama_stack_client/lib/agents/client_tool.py

Lines changed: 106 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@
66

77
import json
88
from abc import abstractmethod
9-
from typing import Dict, List, Union
9+
from typing import Callable, Dict, TypeVar, get_type_hints, Union, get_origin, get_args, List
10+
import inspect
1011

11-
from llama_stack_client.types import ToolResponseMessage, UserMessage
12+
from llama_stack_client.types import Message, ToolResponseMessage
1213
from llama_stack_client.types.tool_def_param import Parameter, ToolDefParam
1314

1415

@@ -46,7 +47,7 @@ def parameters_for_system_prompt(self) -> str:
4647
{
4748
"name": self.get_name(),
4849
"description": self.get_description(),
49-
"parameters": {name: definition.__dict__ for name, definition in self.get_params_definition().items()},
50+
"parameters": {name: definition for name, definition in self.get_params_definition().items()},
5051
}
5152
)
5253

@@ -59,8 +60,107 @@ def get_tool_definition(self) -> ToolDefParam:
5960
tool_prompt_format="python_list",
6061
)
6162

62-
@abstractmethod
6363
def run(
64-
self, messages: List[Union[UserMessage, ToolResponseMessage]]
65-
) -> List[Union[UserMessage, ToolResponseMessage]]:
64+
self,
65+
message_history: List[Message],
66+
) -> ToolResponseMessage:
67+
# NOTE: we could override this method to use the entire message history for advanced tools
68+
last_message = message_history[-1]
69+
70+
assert len(last_message.tool_calls) == 1, "Expected single tool call"
71+
tool_call = last_message.tool_calls[0]
72+
73+
try:
74+
response = self.run_impl(**tool_call.arguments)
75+
response_str = json.dumps(response, ensure_ascii=False)
76+
except Exception as e:
77+
response_str = f"Error when running tool: {e}"
78+
79+
return ToolResponseMessage(
80+
call_id=tool_call.call_id,
81+
tool_name=tool_call.tool_name,
82+
content=response_str,
83+
role="tool",
84+
)
85+
86+
@abstractmethod
87+
def run_impl(self, **kwargs):
6688
raise NotImplementedError
89+
90+
91+
T = TypeVar("T", bound=Callable)
92+
93+
94+
def client_tool(func: T) -> ClientTool:
95+
"""
96+
Decorator to convert a function into a ClientTool.
97+
Usage:
98+
@client_tool
99+
def add(x: int, y: int) -> int:
100+
'''Add 2 integer numbers
101+
102+
:param x: integer 1
103+
:param y: integer 2
104+
:returns: sum of x + y
105+
'''
106+
return x + y
107+
108+
Note that you must use RST-style docstrings with :param tags for each parameter. These will be used for prompting model to use tools correctly.
109+
:returns: tags in the docstring is optional as it would not be used for the tool's description.
110+
"""
111+
112+
class _WrappedTool(ClientTool):
113+
__name__ = func.__name__
114+
__doc__ = func.__doc__
115+
__module__ = func.__module__
116+
117+
def get_name(self) -> str:
118+
return func.__name__
119+
120+
def get_description(self) -> str:
121+
doc = inspect.getdoc(func)
122+
if doc:
123+
# Get everything before the first :param
124+
return doc.split(":param")[0].strip()
125+
else:
126+
raise ValueError(
127+
f"No description found for client tool {__name__}. Please provide a RST-style docstring with description and :param tags for each parameter."
128+
)
129+
130+
def get_params_definition(self) -> Dict[str, Parameter]:
131+
hints = get_type_hints(func)
132+
# Remove return annotation if present
133+
hints.pop("return", None)
134+
135+
# Get parameter descriptions from docstring
136+
params = {}
137+
sig = inspect.signature(func)
138+
doc = inspect.getdoc(func) or ""
139+
140+
for name, type_hint in hints.items():
141+
# Look for :param name: in docstring
142+
param_doc = ""
143+
for line in doc.split("\n"):
144+
if line.strip().startswith(f":param {name}:"):
145+
param_doc = line.split(":", 2)[2].strip()
146+
break
147+
148+
if param_doc == "":
149+
raise ValueError(f"No parameter description found for parameter {name}")
150+
151+
param = sig.parameters[name]
152+
is_optional_type = get_origin(type_hint) is Union and type(None) in get_args(type_hint)
153+
is_required = param.default == inspect.Parameter.empty and not is_optional_type
154+
params[name] = Parameter(
155+
name=name,
156+
description=param_doc or f"Parameter {name}",
157+
parameter_type=type_hint.__name__,
158+
default=param.default,
159+
required=is_required,
160+
)
161+
return params
162+
163+
def run_impl(self, **kwargs):
164+
return func(**kwargs)
165+
166+
return _WrappedTool()

0 commit comments

Comments
 (0)