52 changes: 30 additions & 22 deletions src/llama_stack_client/lib/agents/agent.py
@@ -25,6 +25,7 @@
 import uuid
 from llama_stack_client.types.tool_execution_step import ToolExecutionStep
 from llama_stack_client.types.tool_response import ToolResponse
+from llama_stack_client.types.tool_invocation_result import ToolInvocationResult
 
 DEFAULT_MAX_ITER = 10
 
@@ -77,16 +78,14 @@ def _get_tool_calls(self, chunk: AgentTurnResponseStreamChunk) -> List[ToolCall]
 
         return message.tool_calls
 
-    def _run_tool(self, tool_calls: List[ToolCall]) -> ToolResponseMessage:
-        assert len(tool_calls) == 1, "Only one tool call is supported"
-        tool_call = tool_calls[0]
-
+    def _run_tool(self, tool_call: ToolCall) -> ToolInvocationResult:
+        tool_result = None
         # custom client tools
         if tool_call.tool_name in self.client_tools:
             tool = self.client_tools[tool_call.tool_name]
             # NOTE: tool.run() expects a list of messages, we only pass in last message here
             # but we could pass in the entire message history
-            result_message = tool.run(
+            tool_result = tool.run(
                 [
                     CompletionMessage(
                         role="assistant",
@@ -96,29 +95,14 @@ def _run_tool(self, tool_calls: List[ToolCall]) -> ToolResponseMessage:
                     )
                 ]
             )
-            return result_message
 
         # builtin tools executed by tool_runtime
         if tool_call.tool_name in self.builtin_tools:
            tool_result = self.client.tool_runtime.invoke_tool(
                tool_name=tool_call.tool_name,
                kwargs=tool_call.arguments,
            )
-            tool_response_message = ToolResponseMessage(
-                call_id=tool_call.call_id,
-                tool_name=tool_call.tool_name,
-                content=tool_result.content,
-                role="tool",
-            )
-            return tool_response_message
-
-        # cannot find tools
-        return ToolResponseMessage(
-            call_id=tool_call.call_id,
-            tool_name=tool_call.tool_name,
-            content=f"Unknown tool `{tool_call.tool_name}` was called.",
-            role="tool",
-        )
+        return tool_result
 
@@ -184,7 +168,30 @@ def _create_turn_streaming(
                 yield chunk
             else:
                 tool_execution_start_time = datetime.now()
-                tool_response_message = self._run_tool(tool_calls)
+
+                assert len(tool_calls) == 1, "Only one tool call is supported"
+                tool_call = tool_calls[0]
+                try:
+                    tool_result = self._run_tool(tool_call)
+                except Exception as e:
+                    tool_result = ToolInvocationResult(
+                        content=f"Error when running tool: {e}",
+                    )
+                if tool_result:
+                    tool_response_message = ToolResponseMessage(
+                        call_id=tool_call.call_id,
+                        tool_name=tool_call.tool_name,
+                        content=tool_result.content,
+                        role="tool",
+                    )
+                else:
+                    # cannot find tools
+                    tool_response_message = ToolResponseMessage(
+                        call_id=tool_call.call_id,
+                        tool_name=tool_call.tool_name,
+                        content=f"Unknown tool `{tool_call.tool_name}` was called.",
+                        role="tool",
+                    )
                 tool_execution_step = ToolExecutionStep(
                     step_type="tool_execution",
                     step_id=str(uuid.uuid4()),
@@ -194,6 +201,7 @@
                            tool_name=tool_response_message.tool_name,
                            content=tool_response_message.content,
                            call_id=tool_response_message.call_id,
+                           metadata=tool_result.metadata if tool_result else None,
                        )
                    ],
                    turn_id=chunk.event.payload.turn.turn_id,
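For reference, the None-versus-result branching the new streaming loop performs can be read as a small standalone helper. This is an illustrative sketch, not code from the diff; to_tool_response_message is a hypothetical name:

from typing import Optional

from llama_stack_client.types import ToolResponseMessage
from llama_stack_client.types.tool_invocation_result import ToolInvocationResult


def to_tool_response_message(
    tool_result: Optional[ToolInvocationResult],
    call_id: str,
    tool_name: str,
) -> ToolResponseMessage:
    # A resolved tool yields a ToolInvocationResult (or one is synthesized
    # from the raised exception); None means the tool name matched neither
    # client_tools nor builtin_tools.
    if tool_result:
        return ToolResponseMessage(
            call_id=call_id,
            tool_name=tool_name,
            content=tool_result.content,
            role="tool",
        )
    return ToolResponseMessage(
        call_id=call_id,
        tool_name=tool_name,
        content=f"Unknown tool `{tool_name}` was called.",
        role="tool",
    )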
40 changes: 19 additions & 21 deletions src/llama_stack_client/lib/agents/client_tool.py
@@ -6,11 +6,12 @@
 
 import json
 from abc import abstractmethod
-from typing import Callable, Dict, TypeVar, get_type_hints, Union, get_origin, get_args, List
+from typing import Callable, Dict, TypeVar, get_type_hints, Union, get_origin, get_args, List, Any
 import inspect
 
-from llama_stack_client.types import Message, ToolResponseMessage
+from llama_stack_client.types import Message, CompletionMessage, ToolCall
 from llama_stack_client.types.tool_def_param import Parameter, ToolDefParam
+from llama_stack_client.types.tool_invocation_result import ToolInvocationResult
 
 
 class ClientTool:
@@ -63,28 +64,17 @@ def get_tool_definition(self) -> ToolDefParam:
     def run(
         self,
         message_history: List[Message],
-    ) -> ToolResponseMessage:
+    ) -> ToolInvocationResult:
         # NOTE: we could override this method to use the entire message history for advanced tools
-        last_message = message_history[-1]
+        last_message: CompletionMessage = message_history[-1]
 
         assert len(last_message.tool_calls) == 1, "Expected single tool call"
-        tool_call = last_message.tool_calls[0]
-
-        try:
-            response = self.run_impl(**tool_call.arguments)
-            response_str = json.dumps(response, ensure_ascii=False)
-        except Exception as e:
-            response_str = f"Error when running tool: {e}"
-
-        return ToolResponseMessage(
-            call_id=tool_call.call_id,
-            tool_name=tool_call.tool_name,
-            content=response_str,
-            role="tool",
-        )
+        tool_call: ToolCall = last_message.tool_calls[0]
+
+        return self.run_impl(**tool_call.arguments)
 
     @abstractmethod
-    def run_impl(self, **kwargs):
+    def run_impl(self, **kwargs: Any) -> ToolInvocationResult:
         raise NotImplementedError
 
 
Expand All @@ -107,6 +97,8 @@ def add(x: int, y: int) -> int:

Note that you must use RST-style docstrings with :param tags for each parameter. These will be used for prompting model to use tools correctly.
:returns: tags in the docstring is optional as it would not be used for the tool's description.

You can return any value, which will be serialized to json and fed into the model, or a ToolInvocationResult object/dict where you can include metadata about the output.
"""

class _WrappedTool(ClientTool):
@@ -160,7 +152,13 @@ def get_params_definition(self) -> Dict[str, Parameter]:
                 )
             return params
 
-        def run_impl(self, **kwargs):
-            return func(**kwargs)
+        def run_impl(self, **kwargs: Any) -> ToolInvocationResult:
+            result = func(**kwargs)
+            try:
+                return ToolInvocationResult(**result)
+            except Exception:
+                return ToolInvocationResult(
+                    content=json.dumps(result, ensure_ascii=False),
+                )
 
     return _WrappedTool()
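Under the new run_impl contract, a decorated function can either return a plain value (JSON-serialized into content) or a dict matching ToolInvocationResult's fields so that metadata is carried through. A minimal sketch, assuming the decorator exported by this module is named client_tool and using stand-in weather data:

from llama_stack_client.lib.agents.client_tool import client_tool


@client_tool
def get_weather(city: str) -> dict:
    """Look up the current weather for a city.

    :param city: name of the city to query
    """
    # A dict whose keys match ToolInvocationResult is promoted to one by the
    # wrapper, so metadata survives; any other return value falls back to
    # json.dumps into `content`.
    return {
        "content": f"Sunny and 22C in {city}",
        "metadata": {"source": "demo-stub", "confidence": 0.9},
    }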
4 changes: 3 additions & 1 deletion src/llama_stack_client/types/tool_invocation_result.py
@@ -1,6 +1,6 @@
 # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
 
-from typing import Optional
+from typing import Optional, Dict, List, Union
 
 from .._models import BaseModel
 from .shared.interleaved_content import InterleavedContent
@@ -15,3 +15,5 @@ class ToolInvocationResult(BaseModel):
     error_code: Optional[int] = None
 
     error_message: Optional[str] = None
+
+    metadata: Optional[Dict[str, Union[bool, float, str, List[object], object, None]]] = None
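The widened model accepts JSON-style values (bools, numbers, strings, lists, nested objects) under metadata. A quick construction sketch with placeholder values:

from llama_stack_client.types.tool_invocation_result import ToolInvocationResult

result = ToolInvocationResult(
    content="3 documents retrieved",
    metadata={
        "document_ids": ["doc-1", "doc-2", "doc-3"],  # lists of objects are allowed
        "truncated": False,
        "score_cutoff": 0.5,
    },
)
assert result.metadata is not None and result.error_code is None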
4 changes: 3 additions & 1 deletion src/llama_stack_client/types/tool_response.py
@@ -1,6 +1,6 @@
 # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
 
-from typing import Union
+from typing import Union, Optional, Dict, List
 from typing_extensions import Literal
 
 from .._models import BaseModel
@@ -16,3 +16,5 @@ class ToolResponse(BaseModel):
     """A image content item"""
 
     tool_name: Union[Literal["brave_search", "wolfram_alpha", "photogen", "code_interpreter"], str]
+
+    metadata: Optional[Dict[str, Union[bool, float, str, List[object], object, None]]] = None
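The mirror change on ToolResponse lets metadata recorded in a ToolExecutionStep round-trip through turn payloads, matching the constructor call in agent.py above. A minimal construction sketch with placeholder values:

from llama_stack_client.types.tool_response import ToolResponse

tool_response = ToolResponse(
    call_id="call-123",       # placeholder id
    tool_name="get_weather",  # hypothetical client tool from the earlier sketch
    content="Sunny and 22C in Paris",
    metadata={"source": "demo-stub"},
)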