diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py
index bb6bb26c..0b44eca5 100644
--- a/src/llama_stack_client/lib/agents/agent.py
+++ b/src/llama_stack_client/lib/agents/agent.py
@@ -7,7 +7,7 @@
 from llama_stack_client import LlamaStackClient
-from llama_stack_client.types import ToolResponseMessage, UserMessage
+from llama_stack_client.types import ToolResponseMessage, ToolResponseParam, UserMessage
 from llama_stack_client.types.agent_create_params import AgentConfig
 from llama_stack_client.types.agents.turn import CompletionMessage, Turn
 from llama_stack_client.types.agents.turn_create_params import Document, Toolgroup
@@ -74,7 +74,7 @@ def _get_turn_id(self, chunk: AgentTurnResponseStreamChunk) -> Optional[str]:
         return chunk.event.payload.turn.turn_id
 
-    def _run_tool(self, tool_calls: List[ToolCall]) -> ToolResponseMessage:
+    def _run_tool(self, tool_calls: List[ToolCall]) -> ToolResponseParam:
         assert len(tool_calls) == 1, "Only one tool call is supported"
         tool_call = tool_calls[0]
 
@@ -101,20 +101,18 @@ def _run_tool(self, tool_calls: List[ToolCall]) -> ToolResponseMessage:
                 tool_name=tool_call.tool_name,
                 kwargs=tool_call.arguments,
             )
-            tool_response_message = ToolResponseMessage(
+            tool_response = ToolResponseParam(
                 call_id=tool_call.call_id,
                 tool_name=tool_call.tool_name,
                 content=tool_result.content,
-                role="tool",
             )
-            return tool_response_message
+            return tool_response
 
         # cannot find tools
-        return ToolResponseMessage(
+        return ToolResponseParam(
             call_id=tool_call.call_id,
             tool_name=tool_call.tool_name,
             content=f"Unknown tool `{tool_call.tool_name}` was called.",
-            role="tool",
         )
 
     def create_turn(
@@ -176,14 +174,14 @@ def _create_turn_streaming(
                 yield chunk
 
                 # run the tools
-                tool_response_message = self._run_tool(tool_calls)
+                tool_response = self._run_tool(tool_calls)
 
                 # pass it to next iteration
                 turn_response = self.client.agents.turn.resume(
                     agent_id=self.agent_id,
                     session_id=session_id or self.session_id[-1],
                     turn_id=turn_id,
-                    tool_responses=[tool_response_message],
+                    tool_responses=[tool_response],
                     stream=True,
                 )
                 n_iter += 1
diff --git a/src/llama_stack_client/lib/agents/client_tool.py b/src/llama_stack_client/lib/agents/client_tool.py
index a1066616..2b9a15b1 100644
--- a/src/llama_stack_client/lib/agents/client_tool.py
+++ b/src/llama_stack_client/lib/agents/client_tool.py
@@ -7,9 +7,9 @@
 import inspect
 import json
 from abc import abstractmethod
-from typing import Callable, Dict, get_args, get_origin, get_type_hints, List, TypeVar, Union
+from typing import Any, Callable, Dict, get_args, get_origin, get_type_hints, List, TypeVar, Union
 
-from llama_stack_client.types import Message, ToolResponseMessage
+from llama_stack_client.types import Message, CompletionMessage, ToolResponse
 from llama_stack_client.types.tool_def_param import Parameter, ToolDefParam
@@ -63,28 +63,37 @@ def get_tool_definition(self) -> ToolDefParam:
     def run(
         self,
         message_history: List[Message],
-    ) -> ToolResponseMessage:
+    ) -> ToolResponse:
         # NOTE: we could override this method to use the entire message history for advanced tools
         last_message = message_history[-1]
-
+        assert isinstance(last_message, CompletionMessage), "Expected CompletionMessage"
         assert len(last_message.tool_calls) == 1, "Expected single tool call"
         tool_call = last_message.tool_calls[0]
 
+        metadata = {}
         try:
             response = self.run_impl(**tool_call.arguments)
-            response_str = json.dumps(response, ensure_ascii=False)
+            if isinstance(response, dict) and "content" in response:
+                content = json.dumps(response["content"], ensure_ascii=False)
+                metadata = response.get("metadata", {})
+            else:
+                content = json.dumps(response, ensure_ascii=False)
         except Exception as e:
-            response_str = f"Error when running tool: {e}"
-
-        return ToolResponseMessage(
+            content = f"Error when running tool: {e}"
+        return ToolResponse(
             call_id=tool_call.call_id,
             tool_name=tool_call.tool_name,
-            content=response_str,
-            role="tool",
+            content=content,
+            metadata=metadata,
         )
 
     @abstractmethod
-    def run_impl(self, **kwargs):
+    def run_impl(self, **kwargs) -> Any:
+        """
+        Can return any JSON-serializable object.
+        To return metadata along with the response, return a dict with a "content" key and a "metadata" key, where
+        "content" is the response that will be serialized and passed to the model, and "metadata" will be logged
+        in the tool execution step within the agent execution trace.
+        """
         raise NotImplementedError
@@ -107,6 +116,10 @@ def add(x: int, y: int) -> int:
     Note that you must use RST-style docstrings with :param tags for each parameter. These
     will be used for prompting the model to use tools correctly.
     :returns: tags in the docstring are optional, as they are not used for the tool's description.
+
+    Your function can return any JSON-serializable object.
+    To return metadata along with the response, return a dict with a "content" key and a "metadata" key, where
+    "content" is the response that will be serialized and passed to the model, and "metadata" will be logged in the tool execution step within the agent execution trace.
     """
 
     class _WrappedTool(ClientTool):
@@ -162,7 +175,7 @@ def get_params_definition(self) -> Dict[str, Parameter]:
             return params
 
-        def run_impl(self, **kwargs):
+        def run_impl(self, **kwargs) -> Any:
             return func(**kwargs)
 
     return _WrappedTool()
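
A minimal usage sketch of the return convention this diff introduces, assuming the `client_tool` decorator exported by client_tool.py; the tool name, its docstring, and the static price table are hypothetical stand-ins:

from llama_stack_client.lib.agents.client_tool import client_tool


@client_tool
def get_ticker_price(ticker_symbol: str) -> dict:
    """Look up the latest closing price for a ticker symbol.

    :param ticker_symbol: the symbol to look up, e.g. "GOOG"
    """
    prices = {"GOOG": 2752.88}  # hypothetical stand-in for a real data source
    return {
        # run() serializes this value with json.dumps and passes it to the model
        "content": prices.get(ticker_symbol, "unknown ticker"),
        # run() attaches this dict to the ToolResponse; it is logged in the
        # tool execution step of the agent's execution trace
        "metadata": {"source": "static-table", "ticker": ticker_symbol},
    }

Any other JSON-serializable return value keeps the previous behavior: the whole value is serialized into `content`, and `metadata` stays `{}`.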