Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 7 additions & 9 deletions src/llama_stack_client/lib/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from llama_stack_client import LlamaStackClient

from llama_stack_client.types import ToolResponseMessage, UserMessage
from llama_stack_client.types import ToolResponseMessage, ToolResponseParam, UserMessage
from llama_stack_client.types.agent_create_params import AgentConfig
from llama_stack_client.types.agents.turn import CompletionMessage, Turn
from llama_stack_client.types.agents.turn_create_params import Document, Toolgroup
Expand Down Expand Up @@ -74,7 +74,7 @@ def _get_turn_id(self, chunk: AgentTurnResponseStreamChunk) -> Optional[str]:

return chunk.event.payload.turn.turn_id

def _run_tool(self, tool_calls: List[ToolCall]) -> ToolResponseMessage:
def _run_tool(self, tool_calls: List[ToolCall]) -> ToolResponseParam:
assert len(tool_calls) == 1, "Only one tool call is supported"
tool_call = tool_calls[0]

Expand All @@ -101,20 +101,18 @@ def _run_tool(self, tool_calls: List[ToolCall]) -> ToolResponseMessage:
tool_name=tool_call.tool_name,
kwargs=tool_call.arguments,
)
tool_response_message = ToolResponseMessage(
tool_response = ToolResponseParam(
call_id=tool_call.call_id,
tool_name=tool_call.tool_name,
content=tool_result.content,
role="tool",
)
return tool_response_message
return tool_response

# cannot find tools
return ToolResponseMessage(
return ToolResponseParam(
call_id=tool_call.call_id,
tool_name=tool_call.tool_name,
content=f"Unknown tool `{tool_call.tool_name}` was called.",
role="tool",
)

def create_turn(
Expand Down Expand Up @@ -176,14 +174,14 @@ def _create_turn_streaming(
yield chunk

# run the tools
tool_response_message = self._run_tool(tool_calls)
tool_response = self._run_tool(tool_calls)

# pass it to next iteration
turn_response = self.client.agents.turn.resume(
agent_id=self.agent_id,
session_id=session_id or self.session_id[-1],
turn_id=turn_id,
tool_responses=[tool_response_message],
tool_responses=[tool_response],
stream=True,
)
n_iter += 1
37 changes: 25 additions & 12 deletions src/llama_stack_client/lib/agents/client_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
import inspect
import json
from abc import abstractmethod
from typing import Callable, Dict, get_args, get_origin, get_type_hints, List, TypeVar, Union
from typing import Any, Callable, Dict, get_args, get_origin, get_type_hints, List, TypeVar, Union

from llama_stack_client.types import Message, ToolResponseMessage
from llama_stack_client.types import Message, CompletionMessage, ToolResponse
from llama_stack_client.types.tool_def_param import Parameter, ToolDefParam


Expand Down Expand Up @@ -63,28 +63,37 @@ def get_tool_definition(self) -> ToolDefParam:
def run(
self,
message_history: List[Message],
) -> ToolResponseMessage:
) -> ToolResponse:
# NOTE: we could override this method to use the entire message history for advanced tools
last_message = message_history[-1]

assert isinstance(last_message, CompletionMessage), "Expected CompletionMessage"
assert len(last_message.tool_calls) == 1, "Expected single tool call"
tool_call = last_message.tool_calls[0]

metadata = {}
try:
response = self.run_impl(**tool_call.arguments)
response_str = json.dumps(response, ensure_ascii=False)
if isinstance(response, dict) and "content" in response:
content = json.dumps(response["content"], ensure_ascii=False)
metadata = response.get("metadata", {})
else:
content = json.dumps(response, ensure_ascii=False)
except Exception as e:
response_str = f"Error when running tool: {e}"

return ToolResponseMessage(
content = f"Error when running tool: {e}"
return ToolResponse(
call_id=tool_call.call_id,
tool_name=tool_call.tool_name,
content=response_str,
role="tool",
content=content,
metadata=metadata,
)

@abstractmethod
def run_impl(self, **kwargs):
def run_impl(self, **kwargs) -> Any:
"""
Can return any JSON-serializable object.
To return metadata along with the response, return a dict with a "content" key and a "metadata" key, where "content" is the response that will
be serialized and passed to the model, and "metadata" will be logged as metadata in the tool execution step within the Agent execution trace.
"""
raise NotImplementedError


Expand All @@ -107,6 +116,10 @@ def add(x: int, y: int) -> int:

Note that you must use RST-style docstrings with :param tags for each parameter. These will be used for prompting the model to use tools correctly.
The :returns: tag in the docstring is optional, as it is not used for the tool's description.

Your function can return any JSON-serializable object.
To return metadata along with the response, return a dict with a "content" key and a "metadata" key, where "content" is the response that will
be serialized and passed to the model, and "metadata" will be logged as metadata in the tool execution step within the Agent execution trace.
"""

class _WrappedTool(ClientTool):
Expand Down Expand Up @@ -162,7 +175,7 @@ def get_params_definition(self) -> Dict[str, Parameter]:

return params

def run_impl(self, **kwargs):
def run_impl(self, **kwargs) -> Any:
return func(**kwargs)

return _WrappedTool()