diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py
index 0a8ab226..5e537e01 100644
--- a/src/llama_stack_client/lib/agents/agent.py
+++ b/src/llama_stack_client/lib/agents/agent.py
@@ -25,6 +25,7 @@
 import uuid
 
 from llama_stack_client.types.tool_execution_step import ToolExecutionStep
 from llama_stack_client.types.tool_response import ToolResponse
+from llama_stack_client.types.tool_invocation_result import ToolInvocationResult
 
 DEFAULT_MAX_ITER = 10
@@ -77,16 +78,14 @@ def _get_tool_calls(self, chunk: AgentTurnResponseStreamChunk) -> List[ToolCall]
         return message.tool_calls
 
-    def _run_tool(self, tool_calls: List[ToolCall]) -> ToolResponseMessage:
-        assert len(tool_calls) == 1, "Only one tool call is supported"
-        tool_call = tool_calls[0]
-
+    def _run_tool(self, tool_call: ToolCall) -> ToolInvocationResult:
+        tool_result = None
         # custom client tools
         if tool_call.tool_name in self.client_tools:
             tool = self.client_tools[tool_call.tool_name]
             # NOTE: tool.run() expects a list of messages, we only pass in last message here
             # but we could pass in the entire message history
-            result_message = tool.run(
+            tool_result = tool.run(
                 [
                     CompletionMessage(
                         role="assistant",
@@ -96,7 +95,6 @@ def _run_tool(self, tool_calls: List[ToolCall]) -> ToolResponseMessage:
                     )
                 ]
             )
-            return result_message
 
         # builtin tools executed by tool_runtime
         if tool_call.tool_name in self.builtin_tools:
@@ -104,21 +102,7 @@
                 tool_name=tool_call.tool_name,
                 kwargs=tool_call.arguments,
             )
-            tool_response_message = ToolResponseMessage(
-                call_id=tool_call.call_id,
-                tool_name=tool_call.tool_name,
-                content=tool_result.content,
-                role="tool",
-            )
-            return tool_response_message
-
-        # cannot find tools
-        return ToolResponseMessage(
-            call_id=tool_call.call_id,
-            tool_name=tool_call.tool_name,
-            content=f"Unknown tool `{tool_call.tool_name}` was called.",
-            role="tool",
-        )
+        return tool_result
 
     def create_turn(
         self,
@@ -184,7 +168,30 @@ def _create_turn_streaming(
                     yield chunk
             else:
                 tool_execution_start_time = datetime.now()
-                tool_response_message = self._run_tool(tool_calls)
+
+                assert len(tool_calls) == 1, "Only one tool call is supported"
+                tool_call = tool_calls[0]
+                try:
+                    tool_result = self._run_tool(tool_call)
+                except Exception as e:
+                    tool_result = ToolInvocationResult(
+                        content=f"Error when running tool: {e}",
+                    )
+                if tool_result:
+                    tool_response_message = ToolResponseMessage(
+                        call_id=tool_call.call_id,
+                        tool_name=tool_call.tool_name,
+                        content=tool_result.content,
+                        role="tool",
+                    )
+                else:
+                    # cannot find tools
+                    tool_response_message = ToolResponseMessage(
+                        call_id=tool_call.call_id,
+                        tool_name=tool_call.tool_name,
+                        content=f"Unknown tool `{tool_call.tool_name}` was called.",
+                        role="tool",
+                    )
                 tool_execution_step = ToolExecutionStep(
                     step_type="tool_execution",
                     step_id=str(uuid.uuid4()),
@@ -194,6 +201,7 @@
                             tool_name=tool_response_message.tool_name,
                             content=tool_response_message.content,
                             call_id=tool_response_message.call_id,
+                            metadata=tool_result.metadata if tool_result else None,
                         )
                     ],
                     turn_id=chunk.event.payload.turn.turn_id,
diff --git a/src/llama_stack_client/lib/agents/client_tool.py b/src/llama_stack_client/lib/agents/client_tool.py
index 9ffec4a0..e4362e3d 100644
--- a/src/llama_stack_client/lib/agents/client_tool.py
+++ b/src/llama_stack_client/lib/agents/client_tool.py
@@ -6,11 +6,12 @@
 
 import json
 from abc import abstractmethod
-from typing import Callable, Dict, TypeVar, get_type_hints, Union, get_origin, get_args, List
+from typing import Callable, Dict, TypeVar, get_type_hints, Union, get_origin, get_args, List, Any
 import inspect
 
-from llama_stack_client.types import Message, ToolResponseMessage
+from llama_stack_client.types import Message, CompletionMessage, ToolCall
 from llama_stack_client.types.tool_def_param import Parameter, ToolDefParam
+from llama_stack_client.types.tool_invocation_result import ToolInvocationResult
 
 
 class ClientTool:
@@ -63,28 +64,17 @@ def get_tool_definition(self) -> ToolDefParam:
 
     def run(
         self,
         message_history: List[Message],
-    ) -> ToolResponseMessage:
+    ) -> ToolInvocationResult:
         # NOTE: we could override this method to use the entire message history for advanced tools
-        last_message = message_history[-1]
+        last_message: CompletionMessage = message_history[-1]
         assert len(last_message.tool_calls) == 1, "Expected single tool call"
-        tool_call = last_message.tool_calls[0]
-
-        try:
-            response = self.run_impl(**tool_call.arguments)
-            response_str = json.dumps(response, ensure_ascii=False)
-        except Exception as e:
-            response_str = f"Error when running tool: {e}"
-
-        return ToolResponseMessage(
-            call_id=tool_call.call_id,
-            tool_name=tool_call.tool_name,
-            content=response_str,
-            role="tool",
-        )
+        tool_call: ToolCall = last_message.tool_calls[0]
+
+        return self.run_impl(**tool_call.arguments)
 
     @abstractmethod
-    def run_impl(self, **kwargs):
+    def run_impl(self, **kwargs: Any) -> ToolInvocationResult:
         raise NotImplementedError
 
 
@@ -107,6 +97,8 @@ def add(x: int, y: int) -> int:
     Note that you must use RST-style docstrings with :param tags for each parameter. These
     will be used for prompting model to use tools correctly.
     :returns: tags in the docstring is optional as it would not be used for the tool's description.
+
+    You can return any JSON-serializable value, which will be serialized and fed into the model, or a ToolInvocationResult object/dict, in which you can also include metadata about the output.
     """
 
     class _WrappedTool(ClientTool):
@@ -160,7 +152,16 @@ def get_params_definition(self) -> Dict[str, Parameter]:
             )
         return params
 
-    def run_impl(self, **kwargs):
-        return func(**kwargs)
+    def run_impl(self, **kwargs: Any) -> ToolInvocationResult:
+        result = func(**kwargs)
+        # pass through results that already are ToolInvocationResult objects
+        if isinstance(result, ToolInvocationResult):
+            return result
+        try:
+            return ToolInvocationResult(**result)
+        except Exception:
+            return ToolInvocationResult(
+                content=json.dumps(result, ensure_ascii=False),
+            )
 
     return _WrappedTool()
diff --git a/src/llama_stack_client/types/tool_invocation_result.py b/src/llama_stack_client/types/tool_invocation_result.py
index 4ecc3d03..8cbaf1f2 100644
--- a/src/llama_stack_client/types/tool_invocation_result.py
+++ b/src/llama_stack_client/types/tool_invocation_result.py
@@ -1,6 +1,6 @@
 # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
 
-from typing import Optional
+from typing import Optional, Dict, List, Union
 
 from .._models import BaseModel
 from .shared.interleaved_content import InterleavedContent
@@ -15,3 +15,5 @@ class ToolInvocationResult(BaseModel):
     error_code: Optional[int] = None
 
     error_message: Optional[str] = None
+
+    metadata: Optional[Dict[str, Union[bool, float, str, List[object], object, None]]] = None
diff --git a/src/llama_stack_client/types/tool_response.py b/src/llama_stack_client/types/tool_response.py
index 2617f6e3..b67f47dd 100644
--- a/src/llama_stack_client/types/tool_response.py
+++ b/src/llama_stack_client/types/tool_response.py
@@ -1,6 +1,6 @@
 # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
 
-from typing import Union
+from typing import Union, Optional, Dict, List
 from typing_extensions import Literal
 
 from .._models import BaseModel
@@ -16,3 +16,5 @@ class ToolResponse(BaseModel):
     """A image content item"""
 
     tool_name: Union[Literal["brave_search", "wolfram_alpha", "photogen", "code_interpreter"], str]
+
+    metadata: Optional[Dict[str, Union[bool, float, str, List[object], object, None]]] = None
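
Usage sketch (not part of the patch): with the client_tool.py changes above, a decorated tool can attach metadata that flows through ToolInvocationResult onto the ToolResponse recorded in the turn's tool_execution step. A minimal example, assuming the `client_tool` decorator defined in client_tool.py; the tool name, content, and metadata keys are illustrative:

from llama_stack_client.lib.agents.client_tool import client_tool


@client_tool
def lookup_order(order_id: str) -> dict:
    """Look up the shipping status of an order.

    :param order_id: unique identifier of the order
    """
    # A dict whose keys match ToolInvocationResult's fields is unpacked by the
    # wrapper via ToolInvocationResult(**result): `content` is what the model
    # sees, while `metadata` rides along on the turn's tool_execution step.
    return {
        "content": '{"status": "shipped"}',
        "metadata": {"source": "orders-db", "latency_ms": 12},
    }

Plain return values (e.g. an int or a list) still work: unpacking a non-mapping into ToolInvocationResult raises, and the wrapper falls back to json.dumps, recording the value as content with no metadata.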
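
On the consuming side, the metadata surfaces on ToolResponse. A sketch, where `turn` is a hypothetical completed turn produced by the agent; the step and field names come from the agent.py and tool_response.py changes above:

for step in turn.steps:
    if step.step_type == "tool_execution":
        for tool_response in step.tool_responses:
            # metadata is Optional: the unknown-tool and error paths in
            # agent.py produce responses without it
            print(tool_response.tool_name, tool_response.metadata)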