77import inspect
88import json
99from abc import abstractmethod
10- from typing import Callable , Dict , get_args , get_origin , get_type_hints , List , TypeVar , Union
10+ from typing import Any , Callable , Dict , get_args , get_origin , get_type_hints , List , TypeVar , Union
1111
12- from llama_stack_client .types import Message , ToolResponseMessage
12+ from llama_stack_client .types import Message , CompletionMessage , ToolResponse
1313from llama_stack_client .types .tool_def_param import Parameter , ToolDefParam
1414
1515
@@ -63,28 +63,37 @@ def get_tool_definition(self) -> ToolDefParam:
6363 def run (
6464 self ,
6565 message_history : List [Message ],
66- ) -> ToolResponseMessage :
66+ ) -> ToolResponse :
6767 # NOTE: we could override this method to use the entire message history for advanced tools
6868 last_message = message_history [- 1 ]
69-
69+ assert isinstance ( last_message , CompletionMessage ), "Expected CompletionMessage"
7070 assert len (last_message .tool_calls ) == 1 , "Expected single tool call"
7171 tool_call = last_message .tool_calls [0 ]
7272
73+ metadata = {}
7374 try :
7475 response = self .run_impl (** tool_call .arguments )
75- response_str = json .dumps (response , ensure_ascii = False )
76+ if isinstance (response , dict ) and "content" in response :
77+ content = json .dumps (response ["content" ], ensure_ascii = False )
78+ metadata = response .get ("metadata" , {})
79+ else :
80+ content = json .dumps (response , ensure_ascii = False )
7681 except Exception as e :
77- response_str = f"Error when running tool: { e } "
78-
79- return ToolResponseMessage (
82+ content = f"Error when running tool: { e } "
83+ return ToolResponse (
8084 call_id = tool_call .call_id ,
8185 tool_name = tool_call .tool_name ,
82- content = response_str ,
83- role = "tool" ,
86+ content = content ,
87+ metadata = metadata ,
8488 )
8589
8690 @abstractmethod
87- def run_impl (self , ** kwargs ):
91+ def run_impl (self , ** kwargs ) -> Any :
92+ """
93+ Can return any json serializable object.
94+ To return metadata along with the response, return a dict with a "content" key, and a "metadata" key, where the "content" is the response that'll
95+ be serialized and passed to the model, and the "metadata" will be logged as metadata in the tool execution step within the Agent execution trace.
96+ """
8897 raise NotImplementedError
8998
9099
@@ -107,6 +116,10 @@ def add(x: int, y: int) -> int:
107116
108117 Note that you must use RST-style docstrings with :param tags for each parameter. These will be used for prompting model to use tools correctly.
109118 :returns: tags in the docstring is optional as it would not be used for the tool's description.
119+
120+ Your function can return any json serializable object.
121+ To return metadata along with the response, return a dict with a "content" key, and a "metadata" key, where the "content" is the response that'll
122+ be serialized and passed to the model, and the "metadata" will be logged as metadata in the tool execution step within the Agent execution trace.
110123 """
111124
112125 class _WrappedTool (ClientTool ):
@@ -162,7 +175,7 @@ def get_params_definition(self) -> Dict[str, Parameter]:
162175
163176 return params
164177
165- def run_impl (self , ** kwargs ):
178+ def run_impl (self , ** kwargs ) -> Any :
166179 return func (** kwargs )
167180
168181 return _WrappedTool ()
0 commit comments