diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py
index 19e16d3b..ea87c1df 100644
--- a/src/llama_stack_client/lib/agents/agent.py
+++ b/src/llama_stack_client/lib/agents/agent.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 import logging
-from typing import Iterator, List, Optional, Tuple, Union
+from typing import AsyncIterator, Iterator, List, Optional, Tuple, Union
 
 from llama_stack_client import LlamaStackClient
 
@@ -26,6 +26,91 @@
 logger = logging.getLogger(__name__)
 
 
+class AgentUtils:
+    @staticmethod
+    def get_client_tools(tools: Optional[List[Union[Toolgroup, ClientTool]]]) -> List[ClientTool]:
+        if not tools:
+            return []
+
+        return [tool for tool in tools if isinstance(tool, ClientTool)]
+
+    @staticmethod
+    def get_tool_calls(chunk: AgentTurnResponseStreamChunk, tool_parser: Optional[ToolParser] = None) -> List[ToolCall]:
+        if chunk.event.payload.event_type not in {"turn_complete", "turn_awaiting_input"}:
+            return []
+
+        message = chunk.event.payload.turn.output_message
+        if message.stop_reason == "out_of_tokens":
+            return []
+
+        if tool_parser:
+            return tool_parser.get_tool_calls(message)
+
+        return message.tool_calls
+
+    @staticmethod
+    def get_turn_id(chunk: AgentTurnResponseStreamChunk) -> Optional[str]:
+        if chunk.event.payload.event_type not in ["turn_complete", "turn_awaiting_input"]:
+            return None
+
+        return chunk.event.payload.turn.turn_id
+
+    @staticmethod
+    def get_agent_config(
+        model: Optional[str] = None,
+        instructions: Optional[str] = None,
+        tools: Optional[List[Union[Toolgroup, ClientTool]]] = None,
+        tool_config: Optional[ToolConfig] = None,
+        sampling_params: Optional[SamplingParams] = None,
+        max_infer_iters: Optional[int] = None,
+        input_shields: Optional[List[str]] = None,
+        output_shields: Optional[List[str]] = None,
+        response_format: Optional[ResponseFormat] = None,
+        enable_session_persistence: Optional[bool] = None,
+    ) -> AgentConfig:
+        # Create a minimal valid AgentConfig with required fields
+        if model is None or instructions is None:
+            raise ValueError("Both 'model' and 'instructions' are required when agent_config is not provided")
+
+        agent_config = {
+            "model": model,
+            "instructions": instructions,
+            "toolgroups": [],
+            "client_tools": [],
+        }
+
+        # Add optional parameters if provided
+        if enable_session_persistence is not None:
+            agent_config["enable_session_persistence"] = enable_session_persistence
+        if max_infer_iters is not None:
+            agent_config["max_infer_iters"] = max_infer_iters
+        if input_shields is not None:
+            agent_config["input_shields"] = input_shields
+        if output_shields is not None:
+            agent_config["output_shields"] = output_shields
+        if response_format is not None:
+            agent_config["response_format"] = response_format
+        if sampling_params is not None:
+            agent_config["sampling_params"] = sampling_params
+        if tool_config is not None:
+            agent_config["tool_config"] = tool_config
+        if tools is not None:
+            toolgroups: List[Toolgroup] = []
+            client_tools: List[ClientTool] = []
+
+            for tool in tools:
+                if isinstance(tool, str) or isinstance(tool, dict):
+                    toolgroups.append(tool)
+                else:
+                    client_tools.append(tool)
+
+            agent_config["toolgroups"] = toolgroups
+            agent_config["client_tools"] = [tool.get_tool_definition() for tool in client_tools]
+
+        agent_config = AgentConfig(**agent_config)
+        return agent_config
+
+
 class Agent:
     def __init__(
         self,
@@ -79,63 +164,35 @@ def __init__(
         # Construct agent_config from parameters if not provided
         if agent_config is None:
-            # Create a minimal valid AgentConfig with required fields
-            if model is None or instructions is None:
-                raise ValueError("Both 'model' and 'instructions' are required when agent_config is not provided")
-
-            agent_config = {
-                "model": model,
-                "instructions": instructions,
-                "toolgroups": [],
-                "client_tools": [],
-            }
-
-            # Add optional parameters if provided
-            if enable_session_persistence is not None:
-                agent_config["enable_session_persistence"] = enable_session_persistence
-            if max_infer_iters is not None:
-                agent_config["max_infer_iters"] = max_infer_iters
-            if input_shields is not None:
-                agent_config["input_shields"] = input_shields
-            if output_shields is not None:
-                agent_config["output_shields"] = output_shields
-            if response_format is not None:
-                agent_config["response_format"] = response_format
-            if sampling_params is not None:
-                agent_config["sampling_params"] = sampling_params
-            if tool_config is not None:
-                agent_config["tool_config"] = tool_config
-            if tools is not None:
-                toolgroups: List[Toolgroup] = []
-                client_tools: List[ClientTool] = []
-
-                for tool in tools:
-                    if isinstance(tool, str) or isinstance(tool, dict):
-                        toolgroups.append(tool)
-                    else:
-                        client_tools.append(tool)
-
-                agent_config["toolgroups"] = toolgroups
-                agent_config["client_tools"] = [tool.get_tool_definition() for tool in client_tools]
-
-            agent_config = AgentConfig(**agent_config)
+            agent_config = AgentUtils.get_agent_config(
+                model=model,
+                instructions=instructions,
+                tools=tools,
+                tool_config=tool_config,
+                sampling_params=sampling_params,
+                max_infer_iters=max_infer_iters,
+                input_shields=input_shields,
+                output_shields=output_shields,
+                response_format=response_format,
+                enable_session_persistence=enable_session_persistence,
+            )
+            client_tools = AgentUtils.get_client_tools(tools)
 
         self.agent_config = agent_config
-        self.agent_id = self._create_agent(agent_config)
         self.client_tools = {t.get_name(): t for t in client_tools}
         self.sessions = []
         self.tool_parser = tool_parser
         self.builtin_tools = {}
-        for tg in agent_config["toolgroups"]:
-            for tool in self.client.tools.list(toolgroup_id=tg):
-                self.builtin_tools[tool.identifier] = tool
+        self.initialize()
 
-    def _create_agent(self, agent_config: AgentConfig) -> int:
+    def initialize(self) -> None:
         agentic_system_create_response = self.client.agents.create(
-            agent_config=agent_config,
+            agent_config=self.agent_config,
         )
         self.agent_id = agentic_system_create_response.agent_id
-        return self.agent_id
+        for tg in self.agent_config["toolgroups"]:
+            for tool in self.client.tools.list(toolgroup_id=tg):
+                self.builtin_tools[tool.identifier] = tool
 
     def create_session(self, session_name: str) -> str:
         agentic_system_create_session_response = self.client.agents.session.create(
@@ -146,25 +203,6 @@ def create_session(self, session_name: str) -> str:
         self.sessions.append(self.session_id)
         return self.session_id
 
-    def _get_tool_calls(self, chunk: AgentTurnResponseStreamChunk) -> List[ToolCall]:
-        if chunk.event.payload.event_type not in {"turn_complete", "turn_awaiting_input"}:
-            return []
-
-        message = chunk.event.payload.turn.output_message
-        if message.stop_reason == "out_of_tokens":
-            return []
-
-        if self.tool_parser:
-            return self.tool_parser.get_tool_calls(message)
-
-        return message.tool_calls
-
-    def _get_turn_id(self, chunk: AgentTurnResponseStreamChunk) -> Optional[str]:
-        if chunk.event.payload.event_type not in ["turn_complete", "turn_awaiting_input"]:
-            return None
-
-        return chunk.event.payload.turn.turn_id
-
     def _run_tool(self, tool_calls: List[ToolCall]) -> ToolResponseParam:
         assert len(tool_calls) == 1, "Only one tool call is supported"
         tool_call = tool_calls[0]
@@ -250,7 +288,7 @@ def _create_turn_streaming(
                 if hasattr(chunk, "error"):
                     yield chunk
                     return
-                tool_calls = self._get_tool_calls(chunk)
+                tool_calls = AgentUtils.get_tool_calls(chunk, self.tool_parser)
                 if not tool_calls:
                     yield chunk
                 else:
@@ -262,7 +300,7 @@ def _create_turn_streaming(
                         yield chunk
                         break
 
-                    turn_id = self._get_turn_id(chunk)
+                    turn_id = AgentUtils.get_turn_id(chunk)
                     if n_iter == 0:
                         yield chunk
@@ -281,3 +319,225 @@ def _create_turn_streaming(
 
         if self.tool_parser and n_iter > self.agent_config.get("max_infer_iters", DEFAULT_MAX_ITER):
             raise Exception("Max inference iterations reached")
+
+
+class AsyncAgent:
+    def __init__(
+        self,
+        client: LlamaStackClient,
+        # begin deprecated
+        agent_config: Optional[AgentConfig] = None,
+        client_tools: Tuple[ClientTool, ...] = (),
+        # end deprecated
+        tool_parser: Optional[ToolParser] = None,
+        model: Optional[str] = None,
+        instructions: Optional[str] = None,
+        tools: Optional[List[Union[Toolgroup, ClientTool]]] = None,
+        tool_config: Optional[ToolConfig] = None,
+        sampling_params: Optional[SamplingParams] = None,
+        max_infer_iters: Optional[int] = None,
+        input_shields: Optional[List[str]] = None,
+        output_shields: Optional[List[str]] = None,
+        response_format: Optional[ResponseFormat] = None,
+        enable_session_persistence: Optional[bool] = None,
+    ):
+        """Construct an AsyncAgent with the given parameters.
+
+        :param client: The AsyncLlamaStackClient instance.
+        :param agent_config: The AgentConfig instance.
+            ::deprecated: use other parameters instead
+        :param client_tools: A tuple of ClientTool instances.
+            ::deprecated: use tools instead
+        :param tool_parser: Custom logic that parses tool calls from a message.
+        :param model: The model to use for the agent.
+        :param instructions: The instructions for the agent.
+        :param tools: A list of tools for the agent. Values can be one of the following:
+            - dict representing a toolgroup/tool with arguments: e.g. {"name": "builtin::rag/knowledge_search", "args": {"vector_db_ids": [123]}}
+            - a python function decorated with @client_tool
+            - str representing a tool within a toolgroup: e.g. "builtin::rag/knowledge_search"
+            - str representing a toolgroup_id: e.g. "builtin::rag", "builtin::code_interpreter", where all tools in the toolgroup will be added to the agent
+            - an instance of ClientTool: A client tool object.
+        :param tool_config: The tool configuration for the agent.
+        :param sampling_params: The sampling parameters for the agent.
+        :param max_infer_iters: The maximum number of inference iterations.
+        :param input_shields: The input shields for the agent.
+        :param output_shields: The output shields for the agent.
+        :param response_format: The response format for the agent.
+        :param enable_session_persistence: Whether to enable session persistence.
+        """
+        self.client = client
+
+        if agent_config is not None:
+            logger.warning("`agent_config` is deprecated. Use inlined parameters instead.")
+        if client_tools != ():
+            logger.warning("`client_tools` is deprecated. Use `tools` instead.")
Use `tools` instead.") + + # Construct agent_config from parameters if not provided + if agent_config is None: + agent_config = AgentUtils.get_agent_config( + model=model, + instructions=instructions, + tools=tools, + tool_config=tool_config, + sampling_params=sampling_params, + max_infer_iters=max_infer_iters, + input_shields=input_shields, + output_shields=output_shields, + response_format=response_format, + enable_session_persistence=enable_session_persistence, + ) + client_tools = AgentUtils.get_client_tools(tools) + + self.agent_config = agent_config + self.client_tools = {t.get_name(): t for t in client_tools} + self.sessions = [] + self.tool_parser = tool_parser + self.builtin_tools = {} + self._agent_id = None + + if isinstance(client, LlamaStackClient): + raise ValueError("AsyncAgent must be initialized with an AsyncLlamaStackClient") + + @property + def agent_id(self) -> str: + if not self._agent_id: + raise RuntimeError("Agent ID not initialized. Call initialize() first.") + return self._agent_id + + async def initialize(self) -> None: + if self._agent_id: + return + + agentic_system_create_response = await self.client.agents.create( + agent_config=self.agent_config, + ) + self._agent_id = agentic_system_create_response.agent_id + for tg in self.agent_config["toolgroups"]: + for tool in await self.client.tools.list(toolgroup_id=tg): + self.builtin_tools[tool.identifier] = tool + + async def create_session(self, session_name: str) -> str: + await self.initialize() + agentic_system_create_session_response = await self.client.agents.session.create( + agent_id=self.agent_id, + session_name=session_name, + ) + self.session_id = agentic_system_create_session_response.session_id + self.sessions.append(self.session_id) + return self.session_id + + async def create_turn( + self, + messages: List[Union[UserMessage, ToolResponseMessage]], + session_id: Optional[str] = None, + toolgroups: Optional[List[Toolgroup]] = None, + documents: Optional[List[Document]] = None, + stream: bool = True, + ) -> AsyncIterator[AgentTurnResponseStreamChunk] | Turn: + if stream: + return self._create_turn_streaming(messages, session_id, toolgroups, documents) + else: + chunks = [x async for x in self._create_turn_streaming(messages, session_id, toolgroups, documents)] + if not chunks: + raise Exception("Turn did not complete") + return chunks[-1].event.payload.turn + + async def _run_tool(self, tool_calls: List[ToolCall]) -> ToolResponseMessage: + assert len(tool_calls) == 1, "Only one tool call is supported" + tool_call = tool_calls[0] + + # custom client tools + if tool_call.tool_name in self.client_tools: + tool = self.client_tools[tool_call.tool_name] + result_message = await tool.async_run( + [ + CompletionMessage( + role="assistant", + content=tool_call.tool_name, + tool_calls=[tool_call], + stop_reason="end_of_turn", + ) + ] + ) + return result_message + + # builtin tools executed by tool_runtime + if tool_call.tool_name in self.builtin_tools: + tool_result = await self.client.tool_runtime.invoke_tool( + tool_name=tool_call.tool_name, + kwargs=tool_call.arguments, + ) + tool_response_message = ToolResponseMessage( + call_id=tool_call.call_id, + tool_name=tool_call.tool_name, + content=tool_result.content, + role="tool", + ) + return tool_response_message + + # cannot find tools + return ToolResponseMessage( + call_id=tool_call.call_id, + tool_name=tool_call.tool_name, + content=f"Unknown tool `{tool_call.tool_name}` was called.", + role="tool", + ) + + async def _create_turn_streaming( + self, + 
+        messages: List[Union[UserMessage, ToolResponseMessage]],
+        session_id: Optional[str] = None,
+        toolgroups: Optional[List[Toolgroup]] = None,
+        documents: Optional[List[Document]] = None,
+    ) -> AsyncIterator[AgentTurnResponseStreamChunk]:
+        n_iter = 0
+
+        # 1. create an agent turn
+        turn_response = await self.client.agents.turn.create(
+            agent_id=self.agent_id,
+            # use specified session_id or last session created
+            session_id=session_id or self.sessions[-1],
+            messages=messages,
+            stream=True,
+            documents=documents,
+            toolgroups=toolgroups,
+        )
+
+        # 2. process turn and resume if there's a tool call
+        is_turn_complete = False
+        while not is_turn_complete:
+            is_turn_complete = True
+            async for chunk in turn_response:
+                if hasattr(chunk, "error"):
+                    yield chunk
+                    return
+
+                tool_calls = AgentUtils.get_tool_calls(chunk, self.tool_parser)
+                if not tool_calls:
+                    yield chunk
+                else:
+                    is_turn_complete = False
+                    # End of turn is reached, do not resume even if there's a tool call
+                    if not self.tool_parser and chunk.event.payload.turn.output_message.stop_reason in {"end_of_turn"}:
+                        yield chunk
+                        break
+
+                    turn_id = AgentUtils.get_turn_id(chunk)
+                    if n_iter == 0:
+                        yield chunk
+
+                    # run the tools
+                    tool_response_message = await self._run_tool(tool_calls)
+
+                    # pass it to next iteration
+                    turn_response = await self.client.agents.turn.resume(
+                        agent_id=self.agent_id,
+                        session_id=session_id or self.sessions[-1],
+                        turn_id=turn_id,
+                        tool_responses=[tool_response_message],
+                        stream=True,
+                    )
+                    n_iter += 1
+
+        if self.tool_parser and n_iter > self.agent_config.get("max_infer_iters", DEFAULT_MAX_ITER):
+            raise Exception("Max inference iterations reached")
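For review context, a minimal usage sketch of the new AsyncAgent (a sketch, not part of the patch). It assumes an AsyncLlamaStackClient pointed at a running Llama Stack server; the model id is a placeholder, and the dict-shaped message assumes the client accepts plain dicts in place of typed UserMessage models:

import asyncio

from llama_stack_client import AsyncLlamaStackClient
from llama_stack_client.lib.agents.agent import AsyncAgent


async def main() -> None:
    client = AsyncLlamaStackClient()
    # model/instructions are forwarded to AgentUtils.get_agent_config
    agent = AsyncAgent(
        client,
        model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model id
        instructions="You are a helpful assistant.",
    )
    # create_session() calls initialize() lazily, which registers the agent
    session_id = await agent.create_session("demo-session")

    # stream=True returns an AsyncIterator of AgentTurnResponseStreamChunk
    stream = await agent.create_turn(
        messages=[{"role": "user", "content": "Hello!"}],
        session_id=session_id,
        stream=True,
    )
    async for chunk in stream:
        print(chunk.event.payload.event_type)


asyncio.run(main())

When tool_parser is unset, _create_turn_streaming resumes the turn automatically whenever the model emits tool calls, mirroring the synchronous Agent loop above.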
diff --git a/src/llama_stack_client/lib/agents/client_tool.py b/src/llama_stack_client/lib/agents/client_tool.py
index 2b9a15b1..71da2b6b 100644
--- a/src/llama_stack_client/lib/agents/client_tool.py
+++ b/src/llama_stack_client/lib/agents/client_tool.py
@@ -9,7 +9,7 @@
 from abc import abstractmethod
 from typing import Any, Callable, Dict, get_args, get_origin, get_type_hints, List, TypeVar, Union
 
-from llama_stack_client.types import Message, CompletionMessage, ToolResponse
+from llama_stack_client.types import CompletionMessage, Message, ToolResponse
 from llama_stack_client.types.tool_def_param import Parameter, ToolDefParam
 
 
@@ -87,6 +87,32 @@ def run(
             metadata=metadata,
         )
 
+    async def async_run(
+        self,
+        message_history: List[Message],
+    ) -> ToolResponse:
+        last_message = message_history[-1]
+
+        assert len(last_message.tool_calls) == 1, "Expected single tool call"
+        tool_call = last_message.tool_calls[0]
+        metadata = {}
+        try:
+            response = await self.async_run_impl(**tool_call.arguments)
+            if isinstance(response, dict) and "content" in response:
+                content = json.dumps(response["content"], ensure_ascii=False)
+                metadata = response.get("metadata", {})
+            else:
+                content = json.dumps(response, ensure_ascii=False)
+        except Exception as e:
+            content = f"Error when running tool: {e}"
+
+        return ToolResponse(
+            call_id=tool_call.call_id,
+            tool_name=tool_call.tool_name,
+            content=content,
+            metadata=metadata,
+        )
+
     @abstractmethod
     def run_impl(self, **kwargs) -> Any:
         """
@@ -96,6 +122,10 @@ def run_impl(self, **kwargs) -> Any:
         """
         raise NotImplementedError
 
+    @abstractmethod
+    def async_run_impl(self, **kwargs):
+        raise NotImplementedError
+
 
 T = TypeVar("T", bound=Callable)
 
@@ -176,6 +206,14 @@ def get_params_definition(self) -> Dict[str, Parameter]:
         return params
 
     def run_impl(self, **kwargs) -> Any:
+        if inspect.iscoroutinefunction(func):
+            raise NotImplementedError("Tool is async but run_impl is not async")
         return func(**kwargs)
 
+    async def async_run_impl(self, **kwargs):
+        if inspect.iscoroutinefunction(func):
+            return await func(**kwargs)
+        else:
+            return func(**kwargs)
+
     return _WrappedTool()
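To close the loop on the client_tool changes, a sketch of an async function flowing through the decorator (not part of the patch; it assumes this module's client_tool decorator, which derives the tool definition from the function's signature and docstring):

from llama_stack_client.lib.agents.client_tool import client_tool


@client_tool
async def get_weather(city: str) -> str:
    """Return a toy weather report for a city.

    :param city: Name of the city to look up.
    """
    return f"Sunny in {city}"

With this patch, the wrapped tool dispatches on the coroutine check above: calling get_weather.run_impl(...) raises NotImplementedError because the underlying function is async, while await get_weather.async_run_impl(...) awaits it. AsyncAgent._run_tool reaches this path through async_run.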