diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py index 56d8907b..92f9ac82 100644 --- a/src/llama_stack_client/lib/agents/agent.py +++ b/src/llama_stack_client/lib/agents/agent.py @@ -79,9 +79,10 @@ def _run_tool(self, chunk: AgentTurnResponseStreamChunk) -> ToolResponseMessage: # custom client tools if tool_call.tool_name in self.client_tools: tool = self.client_tools[tool_call.tool_name] - result_messages = tool.run([message]) - next_message = result_messages[0] - return next_message + # NOTE: tool.run() expects a list of messages, we only pass in last message here + # but we could pass in the entire message history + result_message = tool.run([message]) + return result_message # builtin tools executed by tool_runtime if tool_call.tool_name in self.builtin_tools: diff --git a/src/llama_stack_client/lib/agents/client_tool.py b/src/llama_stack_client/lib/agents/client_tool.py index 3559793a..4d5ca5fd 100644 --- a/src/llama_stack_client/lib/agents/client_tool.py +++ b/src/llama_stack_client/lib/agents/client_tool.py @@ -6,9 +6,10 @@ import json from abc import abstractmethod -from typing import Dict, List, Union +from typing import Callable, Dict, TypeVar, get_type_hints, Union, get_origin, get_args, List +import inspect -from llama_stack_client.types import ToolResponseMessage, UserMessage +from llama_stack_client.types import Message, ToolResponseMessage from llama_stack_client.types.tool_def_param import Parameter, ToolDefParam @@ -46,7 +47,7 @@ def parameters_for_system_prompt(self) -> str: { "name": self.get_name(), "description": self.get_description(), - "parameters": {name: definition.__dict__ for name, definition in self.get_params_definition().items()}, + "parameters": {name: definition for name, definition in self.get_params_definition().items()}, } ) @@ -59,8 +60,107 @@ def get_tool_definition(self) -> ToolDefParam: tool_prompt_format="python_list", ) - @abstractmethod def run( - self, messages: List[Union[UserMessage, ToolResponseMessage]] - ) -> List[Union[UserMessage, ToolResponseMessage]]: + self, + message_history: List[Message], + ) -> ToolResponseMessage: + # NOTE: we could override this method to use the entire message history for advanced tools + last_message = message_history[-1] + + assert len(last_message.tool_calls) == 1, "Expected single tool call" + tool_call = last_message.tool_calls[0] + + try: + response = self.run_impl(**tool_call.arguments) + response_str = json.dumps(response, ensure_ascii=False) + except Exception as e: + response_str = f"Error when running tool: {e}" + + return ToolResponseMessage( + call_id=tool_call.call_id, + tool_name=tool_call.tool_name, + content=response_str, + role="tool", + ) + + @abstractmethod + def run_impl(self, **kwargs): raise NotImplementedError + + +T = TypeVar("T", bound=Callable) + + +def client_tool(func: T) -> ClientTool: + """ + Decorator to convert a function into a ClientTool. + Usage: + @client_tool + def add(x: int, y: int) -> int: + '''Add 2 integer numbers + + :param x: integer 1 + :param y: integer 2 + :returns: sum of x + y + ''' + return x + y + + Note that you must use RST-style docstrings with :param tags for each parameter. These will be used for prompting model to use tools correctly. + :returns: tags in the docstring is optional as it would not be used for the tool's description. + """ + + class _WrappedTool(ClientTool): + __name__ = func.__name__ + __doc__ = func.__doc__ + __module__ = func.__module__ + + def get_name(self) -> str: + return func.__name__ + + def get_description(self) -> str: + doc = inspect.getdoc(func) + if doc: + # Get everything before the first :param + return doc.split(":param")[0].strip() + else: + raise ValueError( + f"No description found for client tool {__name__}. Please provide a RST-style docstring with description and :param tags for each parameter." + ) + + def get_params_definition(self) -> Dict[str, Parameter]: + hints = get_type_hints(func) + # Remove return annotation if present + hints.pop("return", None) + + # Get parameter descriptions from docstring + params = {} + sig = inspect.signature(func) + doc = inspect.getdoc(func) or "" + + for name, type_hint in hints.items(): + # Look for :param name: in docstring + param_doc = "" + for line in doc.split("\n"): + if line.strip().startswith(f":param {name}:"): + param_doc = line.split(":", 2)[2].strip() + break + + if param_doc == "": + raise ValueError(f"No parameter description found for parameter {name}") + + param = sig.parameters[name] + is_optional_type = get_origin(type_hint) is Union and type(None) in get_args(type_hint) + is_required = param.default == inspect.Parameter.empty and not is_optional_type + params[name] = Parameter( + name=name, + description=param_doc or f"Parameter {name}", + parameter_type=type_hint.__name__, + default=param.default, + required=is_required, + ) + return params + + def run_impl(self, **kwargs): + return func(**kwargs) + + return _WrappedTool()