From 226df10bc19a4695d2d16bfbed04ccf3c2c88ac1 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 6 Feb 2025 12:06:43 -0800 Subject: [PATCH 01/14] tool decorator working --- .../lib/agents/client_tool.py | 107 +++++++++++++++++- .../lib/agents/react/agent.py | 5 +- 2 files changed, 108 insertions(+), 4 deletions(-) diff --git a/src/llama_stack_client/lib/agents/client_tool.py b/src/llama_stack_client/lib/agents/client_tool.py index 3559793a..60f66452 100644 --- a/src/llama_stack_client/lib/agents/client_tool.py +++ b/src/llama_stack_client/lib/agents/client_tool.py @@ -6,7 +6,9 @@ import json from abc import abstractmethod -from typing import Dict, List, Union +from functools import wraps +from typing import Callable, Dict, Type, TypeVar, get_type_hints, List, Union, get_origin, get_args +import inspect from llama_stack_client.types import ToolResponseMessage, UserMessage from llama_stack_client.types.tool_def_param import Parameter, ToolDefParam @@ -46,7 +48,7 @@ def parameters_for_system_prompt(self) -> str: { "name": self.get_name(), "description": self.get_description(), - "parameters": {name: definition.__dict__ for name, definition in self.get_params_definition().items()}, + "parameters": {name: definition for name, definition in self.get_params_definition().items()}, } ) @@ -64,3 +66,104 @@ def run( self, messages: List[Union[UserMessage, ToolResponseMessage]] ) -> List[Union[UserMessage, ToolResponseMessage]]: raise NotImplementedError + + +class SingleMessageClientTool(ClientTool): + """ + Helper class to handle custom tools that take a single message + Extending this class and implementing the `run_impl` method will + allow for the tool be called by the model and the necessary plumbing. + """ + + def run(self, messages: List[Union[UserMessage, ToolResponseMessage]]) -> List[ToolResponseMessage]: + assert len(messages) == 1, "Expected single message" + + message = messages[0] + + tool_call = message.tool_calls[0] + + try: + response = self.run_impl(**tool_call.arguments) + response_str = json.dumps(response, ensure_ascii=False) + except Exception as e: + response_str = f"Error when running tool: {e}" + + message = ToolResponseMessage( + call_id=tool_call.call_id, + tool_name=tool_call.tool_name, + content=response_str, + role="tool", + ) + return [message] + + @abstractmethod + def run_impl(self, *args, **kwargs): + raise NotImplementedError() + + +T = TypeVar('T', bound=Callable) + +def tool(func: T) -> ClientTool: + """ + Decorator to convert a function into a ClientTool. + Usage: + @tool + def add(x: int, y: int) -> int: + '''Add 2 integer numbers + + :param x: integer 1 + :param y: integer 2 + :returns: sum of x + y + ''' + return x + y + """ + class WrappedTool(SingleMessageClientTool): + __name__ = func.__name__ + __doc__ = func.__doc__ + __module__ = func.__module__ + + def get_name(self) -> str: + return func.__name__ + + def get_description(self) -> str: + doc = inspect.getdoc(func) + if doc: + # Get everything before the first :param + return doc.split(':param')[0].strip() + return "" + + def get_params_definition(self) -> Dict[str, Parameter]: + hints = get_type_hints(func) + # Remove return annotation if present + hints.pop('return', None) + + # Get parameter descriptions from docstring + params = {} + sig = inspect.signature(func) + doc = inspect.getdoc(func) or "" + + for name, type_hint in hints.items(): + # Look for :param name: in docstring + param_doc = "" + for line in doc.split('\n'): + if line.strip().startswith(f':param {name}:'): + param_doc = line.split(':', 2)[2].strip() + break + + param = sig.parameters[name] + is_optional_type = get_origin(type_hint) is Union and type(None) in get_args(type_hint) + is_required = param.default == inspect.Parameter.empty and not is_optional_type + params[name] = Parameter( + name=name, + description=param_doc or f"Parameter {name}", + parameter_type=type_hint.__name__, + default=param.default, + required=is_required + ) + return params + + + def run_impl(self, **kwargs): + return func(**kwargs) + + return WrappedTool() \ No newline at end of file diff --git a/src/llama_stack_client/lib/agents/react/agent.py b/src/llama_stack_client/lib/agents/react/agent.py index 3d40a08b..b3a40648 100644 --- a/src/llama_stack_client/lib/agents/react/agent.py +++ b/src/llama_stack_client/lib/agents/react/agent.py @@ -67,7 +67,7 @@ def get_tool_defs(): ] ) return tool_defs - + if custom_agent_config is None: tool_names, tool_descriptions = "", "" tool_defs = get_tool_defs() @@ -82,7 +82,8 @@ def get_tool_defs(): model=model, instructions=instruction, toolgroups=builtin_toolgroups, - client_tools=[client_tool.get_tool_definition() for client_tool in client_tools], + client_tools=[], + # client_tools=[client_tool.get_tool_definition() for client_tool in client_tools], tool_choice="auto", # TODO: refactor this to use SystemMessageBehaviour.replace tool_prompt_format="json", From a2d6d7004740ac16302e2e808a34ea27ec27d9ee Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 6 Feb 2025 12:11:44 -0800 Subject: [PATCH 02/14] revert debugging --- src/llama_stack_client/lib/agents/react/agent.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/llama_stack_client/lib/agents/react/agent.py b/src/llama_stack_client/lib/agents/react/agent.py index b3a40648..046164b6 100644 --- a/src/llama_stack_client/lib/agents/react/agent.py +++ b/src/llama_stack_client/lib/agents/react/agent.py @@ -82,8 +82,7 @@ def get_tool_defs(): model=model, instructions=instruction, toolgroups=builtin_toolgroups, - client_tools=[], - # client_tools=[client_tool.get_tool_definition() for client_tool in client_tools], + client_tools=[client_tool.get_tool_definition() for client_tool in client_tools], tool_choice="auto", # TODO: refactor this to use SystemMessageBehaviour.replace tool_prompt_format="json", From 9a538d179eecd6fa78a892ec3aa631707121768e Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 6 Feb 2025 12:13:51 -0800 Subject: [PATCH 03/14] precommit --- .../lib/agents/client_tool.py | 28 +++++++++---------- .../lib/agents/react/agent.py | 2 +- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/src/llama_stack_client/lib/agents/client_tool.py b/src/llama_stack_client/lib/agents/client_tool.py index 60f66452..089bc4bc 100644 --- a/src/llama_stack_client/lib/agents/client_tool.py +++ b/src/llama_stack_client/lib/agents/client_tool.py @@ -6,8 +6,7 @@ import json from abc import abstractmethod -from functools import wraps -from typing import Callable, Dict, Type, TypeVar, get_type_hints, List, Union, get_origin, get_args +from typing import Callable, Dict, TypeVar, get_type_hints, List, Union, get_origin, get_args import inspect from llama_stack_client.types import ToolResponseMessage, UserMessage @@ -101,7 +100,8 @@ def run_impl(self, *args, **kwargs): raise NotImplementedError() -T = TypeVar('T', bound=Callable) +T = TypeVar("T", bound=Callable) + def tool(func: T) -> ClientTool: """ @@ -110,13 +110,14 @@ def tool(func: T) -> ClientTool: @tool def add(x: int, y: int) -> int: '''Add 2 integer numbers - + :param x: integer 1 :param y: integer 2 :returns: sum of x + y ''' return x + y """ + class WrappedTool(SingleMessageClientTool): __name__ = func.__name__ __doc__ = func.__doc__ @@ -129,14 +130,14 @@ def get_description(self) -> str: doc = inspect.getdoc(func) if doc: # Get everything before the first :param - return doc.split(':param')[0].strip() + return doc.split(":param")[0].strip() return "" def get_params_definition(self) -> Dict[str, Parameter]: hints = get_type_hints(func) # Remove return annotation if present - hints.pop('return', None) - + hints.pop("return", None) + # Get parameter descriptions from docstring params = {} sig = inspect.signature(func) @@ -145,11 +146,11 @@ def get_params_definition(self) -> Dict[str, Parameter]: for name, type_hint in hints.items(): # Look for :param name: in docstring param_doc = "" - for line in doc.split('\n'): - if line.strip().startswith(f':param {name}:'): - param_doc = line.split(':', 2)[2].strip() + for line in doc.split("\n"): + if line.strip().startswith(f":param {name}:"): + param_doc = line.split(":", 2)[2].strip() break - + param = sig.parameters[name] is_optional_type = get_origin(type_hint) is Union and type(None) in get_args(type_hint) is_required = param.default == inspect.Parameter.empty and not is_optional_type @@ -158,12 +159,11 @@ def get_params_definition(self) -> Dict[str, Parameter]: description=param_doc or f"Parameter {name}", parameter_type=type_hint.__name__, default=param.default, - required=is_required + required=is_required, ) return params - def run_impl(self, **kwargs): return func(**kwargs) - return WrappedTool() \ No newline at end of file + return WrappedTool() diff --git a/src/llama_stack_client/lib/agents/react/agent.py b/src/llama_stack_client/lib/agents/react/agent.py index 046164b6..3d40a08b 100644 --- a/src/llama_stack_client/lib/agents/react/agent.py +++ b/src/llama_stack_client/lib/agents/react/agent.py @@ -67,7 +67,7 @@ def get_tool_defs(): ] ) return tool_defs - + if custom_agent_config is None: tool_names, tool_descriptions = "", "" tool_defs = get_tool_defs() From a32996cb50d015037a34a9c6da7eb59c01b59fcf Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 6 Feb 2025 12:48:45 -0800 Subject: [PATCH 04/14] add validation --- src/llama_stack_client/lib/agents/client_tool.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/llama_stack_client/lib/agents/client_tool.py b/src/llama_stack_client/lib/agents/client_tool.py index 089bc4bc..277985aa 100644 --- a/src/llama_stack_client/lib/agents/client_tool.py +++ b/src/llama_stack_client/lib/agents/client_tool.py @@ -116,8 +116,12 @@ def add(x: int, y: int) -> int: :returns: sum of x + y ''' return x + y + + Note that you must use RST-style docstrings with :param, tags. + :returns: tags in the docstring is optional as it would not be used for the tool's description. """ + class WrappedTool(SingleMessageClientTool): __name__ = func.__name__ __doc__ = func.__doc__ @@ -150,6 +154,9 @@ def get_params_definition(self) -> Dict[str, Parameter]: if line.strip().startswith(f":param {name}:"): param_doc = line.split(":", 2)[2].strip() break + + if param_doc == "": + raise ValueError(f"No parameter description found for parameter {name}") param = sig.parameters[name] is_optional_type = get_origin(type_hint) is Union and type(None) in get_args(type_hint) From e9eb8867ce72a50b7fb9ae661437e7d0b5646e2c Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 6 Feb 2025 12:49:49 -0800 Subject: [PATCH 05/14] doc --- src/llama_stack_client/lib/agents/client_tool.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama_stack_client/lib/agents/client_tool.py b/src/llama_stack_client/lib/agents/client_tool.py index 277985aa..3d771ed1 100644 --- a/src/llama_stack_client/lib/agents/client_tool.py +++ b/src/llama_stack_client/lib/agents/client_tool.py @@ -117,7 +117,7 @@ def add(x: int, y: int) -> int: ''' return x + y - Note that you must use RST-style docstrings with :param, tags. + Note that you must use RST-style docstrings with :param tags for each parameter. These will be used for prompting model to use tools correctly. :returns: tags in the docstring is optional as it would not be used for the tool's description. """ From b5ff0588c1dc82ecbb0a24f92139db57e631fd40 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 6 Feb 2025 12:53:48 -0800 Subject: [PATCH 06/14] nit --- src/llama_stack_client/lib/agents/client_tool.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/llama_stack_client/lib/agents/client_tool.py b/src/llama_stack_client/lib/agents/client_tool.py index 3d771ed1..edf26cbe 100644 --- a/src/llama_stack_client/lib/agents/client_tool.py +++ b/src/llama_stack_client/lib/agents/client_tool.py @@ -122,7 +122,7 @@ def add(x: int, y: int) -> int: """ - class WrappedTool(SingleMessageClientTool): + class _WrappedTool(SingleMessageClientTool): __name__ = func.__name__ __doc__ = func.__doc__ __module__ = func.__module__ @@ -173,4 +173,4 @@ def get_params_definition(self) -> Dict[str, Parameter]: def run_impl(self, **kwargs): return func(**kwargs) - return WrappedTool() + return _WrappedTool() From 0cfae80b5dcf21e2c45b2678ab3ef2b75a01e428 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 6 Feb 2025 12:58:05 -0800 Subject: [PATCH 07/14] remove single message client tool --- src/llama_stack_client/lib/agents/agent.py | 5 ++- .../lib/agents/client_tool.py | 36 ++++++------------- 2 files changed, 12 insertions(+), 29 deletions(-) diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py index 56d8907b..520b53d3 100644 --- a/src/llama_stack_client/lib/agents/agent.py +++ b/src/llama_stack_client/lib/agents/agent.py @@ -79,9 +79,8 @@ def _run_tool(self, chunk: AgentTurnResponseStreamChunk) -> ToolResponseMessage: # custom client tools if tool_call.tool_name in self.client_tools: tool = self.client_tools[tool_call.tool_name] - result_messages = tool.run([message]) - next_message = result_messages[0] - return next_message + result_messages = tool.run(message) + return result_messages # builtin tools executed by tool_runtime if tool_call.tool_name in self.builtin_tools: diff --git a/src/llama_stack_client/lib/agents/client_tool.py b/src/llama_stack_client/lib/agents/client_tool.py index edf26cbe..46a6a2a9 100644 --- a/src/llama_stack_client/lib/agents/client_tool.py +++ b/src/llama_stack_client/lib/agents/client_tool.py @@ -9,7 +9,7 @@ from typing import Callable, Dict, TypeVar, get_type_hints, List, Union, get_origin, get_args import inspect -from llama_stack_client.types import ToolResponseMessage, UserMessage +from llama_stack_client.types import CompletionMessage, ToolResponseMessage from llama_stack_client.types.tool_def_param import Parameter, ToolDefParam @@ -60,25 +60,10 @@ def get_tool_definition(self) -> ToolDefParam: tool_prompt_format="python_list", ) - @abstractmethod def run( - self, messages: List[Union[UserMessage, ToolResponseMessage]] - ) -> List[Union[UserMessage, ToolResponseMessage]]: - raise NotImplementedError - - -class SingleMessageClientTool(ClientTool): - """ - Helper class to handle custom tools that take a single message - Extending this class and implementing the `run_impl` method will - allow for the tool be called by the model and the necessary plumbing. - """ - - def run(self, messages: List[Union[UserMessage, ToolResponseMessage]]) -> List[ToolResponseMessage]: - assert len(messages) == 1, "Expected single message" - - message = messages[0] - + self, message: CompletionMessage, + ) -> ToolResponseMessage: + assert len(message.tool_calls) == 1, "Expected single tool call" tool_call = message.tool_calls[0] try: @@ -87,27 +72,26 @@ def run(self, messages: List[Union[UserMessage, ToolResponseMessage]]) -> List[T except Exception as e: response_str = f"Error when running tool: {e}" - message = ToolResponseMessage( + return ToolResponseMessage( call_id=tool_call.call_id, tool_name=tool_call.tool_name, content=response_str, role="tool", ) - return [message] @abstractmethod - def run_impl(self, *args, **kwargs): - raise NotImplementedError() + def run_impl(self, **kwargs): + raise NotImplementedError T = TypeVar("T", bound=Callable) -def tool(func: T) -> ClientTool: +def client_tool(func: T) -> ClientTool: """ Decorator to convert a function into a ClientTool. Usage: - @tool + @client_tool def add(x: int, y: int) -> int: '''Add 2 integer numbers @@ -122,7 +106,7 @@ def add(x: int, y: int) -> int: """ - class _WrappedTool(SingleMessageClientTool): + class _WrappedTool(ClientTool): __name__ = func.__name__ __doc__ = func.__doc__ __module__ = func.__module__ From aae72bca7c230ca0581355d7a7e10ce48daec5e6 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 6 Feb 2025 13:00:26 -0800 Subject: [PATCH 08/14] pre --- src/llama_stack_client/lib/agents/client_tool.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/llama_stack_client/lib/agents/client_tool.py b/src/llama_stack_client/lib/agents/client_tool.py index 46a6a2a9..7e5fe4a7 100644 --- a/src/llama_stack_client/lib/agents/client_tool.py +++ b/src/llama_stack_client/lib/agents/client_tool.py @@ -6,7 +6,7 @@ import json from abc import abstractmethod -from typing import Callable, Dict, TypeVar, get_type_hints, List, Union, get_origin, get_args +from typing import Callable, Dict, TypeVar, get_type_hints, Union, get_origin, get_args import inspect from llama_stack_client.types import CompletionMessage, ToolResponseMessage @@ -61,7 +61,8 @@ def get_tool_definition(self) -> ToolDefParam: ) def run( - self, message: CompletionMessage, + self, + message: CompletionMessage, ) -> ToolResponseMessage: assert len(message.tool_calls) == 1, "Expected single tool call" tool_call = message.tool_calls[0] @@ -100,12 +101,11 @@ def add(x: int, y: int) -> int: :returns: sum of x + y ''' return x + y - - Note that you must use RST-style docstrings with :param tags for each parameter. These will be used for prompting model to use tools correctly. + + Note that you must use RST-style docstrings with :param tags for each parameter. These will be used for prompting model to use tools correctly. :returns: tags in the docstring is optional as it would not be used for the tool's description. """ - class _WrappedTool(ClientTool): __name__ = func.__name__ __doc__ = func.__doc__ @@ -138,7 +138,7 @@ def get_params_definition(self) -> Dict[str, Parameter]: if line.strip().startswith(f":param {name}:"): param_doc = line.split(":", 2)[2].strip() break - + if param_doc == "": raise ValueError(f"No parameter description found for parameter {name}") From abd527a0abfa668fa62ac0ea5432007762a3c704 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 6 Feb 2025 13:33:04 -0800 Subject: [PATCH 09/14] update tools to use message list --- src/llama_stack_client/lib/agents/agent.py | 4 +++- src/llama_stack_client/lib/agents/client_tool.py | 10 ++++++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py index 520b53d3..c5e368d8 100644 --- a/src/llama_stack_client/lib/agents/agent.py +++ b/src/llama_stack_client/lib/agents/agent.py @@ -79,7 +79,9 @@ def _run_tool(self, chunk: AgentTurnResponseStreamChunk) -> ToolResponseMessage: # custom client tools if tool_call.tool_name in self.client_tools: tool = self.client_tools[tool_call.tool_name] - result_messages = tool.run(message) + # NOTE: tool.run() expects a list of messages, we only pass in last message here + # but we could pass in the entire message history + result_messages = tool.run([message]) return result_messages # builtin tools executed by tool_runtime diff --git a/src/llama_stack_client/lib/agents/client_tool.py b/src/llama_stack_client/lib/agents/client_tool.py index 7e5fe4a7..50517085 100644 --- a/src/llama_stack_client/lib/agents/client_tool.py +++ b/src/llama_stack_client/lib/agents/client_tool.py @@ -6,7 +6,7 @@ import json from abc import abstractmethod -from typing import Callable, Dict, TypeVar, get_type_hints, Union, get_origin, get_args +from typing import Callable, Dict, TypeVar, get_type_hints, Union, get_origin, get_args, List import inspect from llama_stack_client.types import CompletionMessage, ToolResponseMessage @@ -62,10 +62,12 @@ def get_tool_definition(self) -> ToolDefParam: def run( self, - message: CompletionMessage, + message_history: List[CompletionMessage], ) -> ToolResponseMessage: - assert len(message.tool_calls) == 1, "Expected single tool call" - tool_call = message.tool_calls[0] + last_message = message_history[-1] + + assert len(last_message.tool_calls) == 1, "Expected single tool call" + tool_call = last_message.tool_calls[0] try: response = self.run_impl(**tool_call.arguments) From e8ac4c7236908f57337e9e24e06b70b9f0ffcd1e Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 6 Feb 2025 13:34:47 -0800 Subject: [PATCH 10/14] add comments --- src/llama_stack_client/lib/agents/client_tool.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/llama_stack_client/lib/agents/client_tool.py b/src/llama_stack_client/lib/agents/client_tool.py index 50517085..a4046e7f 100644 --- a/src/llama_stack_client/lib/agents/client_tool.py +++ b/src/llama_stack_client/lib/agents/client_tool.py @@ -64,6 +64,7 @@ def run( self, message_history: List[CompletionMessage], ) -> ToolResponseMessage: + # NOTE: we could override this method to use the entire message history for advanced tools last_message = message_history[-1] assert len(last_message.tool_calls) == 1, "Expected single tool call" From 50c6ff5da2437db585685fe7c1bcd26d27f49087 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 6 Feb 2025 13:36:01 -0800 Subject: [PATCH 11/14] add comments --- src/llama_stack_client/lib/agents/client_tool.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama_stack_client/lib/agents/client_tool.py b/src/llama_stack_client/lib/agents/client_tool.py index a4046e7f..eeea15db 100644 --- a/src/llama_stack_client/lib/agents/client_tool.py +++ b/src/llama_stack_client/lib/agents/client_tool.py @@ -66,7 +66,7 @@ def run( ) -> ToolResponseMessage: # NOTE: we could override this method to use the entire message history for advanced tools last_message = message_history[-1] - + assert len(last_message.tool_calls) == 1, "Expected single tool call" tool_call = last_message.tool_calls[0] From 2adef1f600d4e965b06e92f82c1dc817ea019ecb Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 6 Feb 2025 13:39:29 -0800 Subject: [PATCH 12/14] raise if no docs --- src/llama_stack_client/lib/agents/agent.py | 4 ++-- src/llama_stack_client/lib/agents/client_tool.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py index c5e368d8..92f9ac82 100644 --- a/src/llama_stack_client/lib/agents/agent.py +++ b/src/llama_stack_client/lib/agents/agent.py @@ -81,8 +81,8 @@ def _run_tool(self, chunk: AgentTurnResponseStreamChunk) -> ToolResponseMessage: tool = self.client_tools[tool_call.tool_name] # NOTE: tool.run() expects a list of messages, we only pass in last message here # but we could pass in the entire message history - result_messages = tool.run([message]) - return result_messages + result_message = tool.run([message]) + return result_message # builtin tools executed by tool_runtime if tool_call.tool_name in self.builtin_tools: diff --git a/src/llama_stack_client/lib/agents/client_tool.py b/src/llama_stack_client/lib/agents/client_tool.py index eeea15db..a96731ec 100644 --- a/src/llama_stack_client/lib/agents/client_tool.py +++ b/src/llama_stack_client/lib/agents/client_tool.py @@ -122,7 +122,8 @@ def get_description(self) -> str: if doc: # Get everything before the first :param return doc.split(":param")[0].strip() - return "" + else: + raise ValueError(f"No description found for client tool {__name__}. Please provide a RST-style docstring with description and :param tags for each parameter.") def get_params_definition(self) -> Dict[str, Parameter]: hints = get_type_hints(func) From 99244f93e455ac7f469bb60f6a9a40c030406fc2 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 6 Feb 2025 13:41:07 -0800 Subject: [PATCH 13/14] precommit --- src/llama_stack_client/lib/agents/client_tool.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/llama_stack_client/lib/agents/client_tool.py b/src/llama_stack_client/lib/agents/client_tool.py index a96731ec..b0599e6f 100644 --- a/src/llama_stack_client/lib/agents/client_tool.py +++ b/src/llama_stack_client/lib/agents/client_tool.py @@ -123,7 +123,9 @@ def get_description(self) -> str: # Get everything before the first :param return doc.split(":param")[0].strip() else: - raise ValueError(f"No description found for client tool {__name__}. Please provide a RST-style docstring with description and :param tags for each parameter.") + raise ValueError( + f"No description found for client tool {__name__}. Please provide a RST-style docstring with description and :param tags for each parameter." + ) def get_params_definition(self) -> Dict[str, Parameter]: hints = get_type_hints(func) From 0f36d454c7ed03691aaf74827867611e274622fb Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 6 Feb 2025 13:42:25 -0800 Subject: [PATCH 14/14] update types --- src/llama_stack_client/lib/agents/client_tool.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/llama_stack_client/lib/agents/client_tool.py b/src/llama_stack_client/lib/agents/client_tool.py index b0599e6f..4d5ca5fd 100644 --- a/src/llama_stack_client/lib/agents/client_tool.py +++ b/src/llama_stack_client/lib/agents/client_tool.py @@ -9,7 +9,7 @@ from typing import Callable, Dict, TypeVar, get_type_hints, Union, get_origin, get_args, List import inspect -from llama_stack_client.types import CompletionMessage, ToolResponseMessage +from llama_stack_client.types import Message, ToolResponseMessage from llama_stack_client.types.tool_def_param import Parameter, ToolDefParam @@ -62,7 +62,7 @@ def get_tool_definition(self) -> ToolDefParam: def run( self, - message_history: List[CompletionMessage], + message_history: List[Message], ) -> ToolResponseMessage: # NOTE: we could override this method to use the entire message history for advanced tools last_message = message_history[-1]