From b9f7b5803a0813f6f9ac886008280d2332b7f1df Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Wed, 26 Feb 2025 17:38:30 -0800 Subject: [PATCH 01/19] async agent wip --- src/llama_stack_client/lib/agents/agent.py | 118 ++++++++++++++++----- 1 file changed, 93 insertions(+), 25 deletions(-) diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py index 87badd46..432db1d5 100644 --- a/src/llama_stack_client/lib/agents/agent.py +++ b/src/llama_stack_client/lib/agents/agent.py @@ -3,17 +3,15 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Iterator, List, Optional, Tuple, Union +from typing import AsyncIterator, Iterator, List, Optional, Tuple, Union -from llama_stack_client import LlamaStackClient +from llama_stack_client import AsyncLlamaStackClient, LlamaStackClient from llama_stack_client.types import ToolResponseMessage, UserMessage from llama_stack_client.types.agent_create_params import AgentConfig from llama_stack_client.types.agents.turn import CompletionMessage, Turn from llama_stack_client.types.agents.turn_create_params import Document, Toolgroup -from llama_stack_client.types.agents.turn_create_response import ( - AgentTurnResponseStreamChunk, -) +from llama_stack_client.types.agents.turn_create_response import AgentTurnResponseStreamChunk from llama_stack_client.types.shared.tool_call import ToolCall from .client_tool import ClientTool @@ -22,7 +20,28 @@ DEFAULT_MAX_ITER = 10 -class Agent: +class AgentMixin: + def _get_tool_calls(self, chunk: AgentTurnResponseStreamChunk) -> List[ToolCall]: + if chunk.event.payload.event_type not in {"turn_complete", "turn_awaiting_input"}: + return [] + + message = chunk.event.payload.turn.output_message + if message.stop_reason == "out_of_tokens": + return [] + + if self.tool_parser: + return self.tool_parser.get_tool_calls(message) + + return message.tool_calls + + def _get_turn_id(self, chunk: AgentTurnResponseStreamChunk) -> Optional[str]: + if chunk.event.payload.event_type not in ["turn_complete", "turn_awaiting_input"]: + return None + + return chunk.event.payload.turn.turn_id + + +class Agent(AgentMixin): def __init__( self, client: LlamaStackClient, @@ -57,25 +76,6 @@ def create_session(self, session_name: str) -> str: self.sessions.append(self.session_id) return self.session_id - def _get_tool_calls(self, chunk: AgentTurnResponseStreamChunk) -> List[ToolCall]: - if chunk.event.payload.event_type not in {"turn_complete", "turn_awaiting_input"}: - return [] - - message = chunk.event.payload.turn.output_message - if message.stop_reason == "out_of_tokens": - return [] - - if self.tool_parser: - return self.tool_parser.get_tool_calls(message) - - return message.tool_calls - - def _get_turn_id(self, chunk: AgentTurnResponseStreamChunk) -> Optional[str]: - if chunk.event.payload.event_type not in ["turn_complete", "turn_awaiting_input"]: - return None - - return chunk.event.payload.turn.turn_id - def _run_tool(self, tool_calls: List[ToolCall]) -> ToolResponseMessage: assert len(tool_calls) == 1, "Only one tool call is supported" tool_call = tool_calls[0] @@ -189,3 +189,71 @@ def _create_turn_streaming( if n_iter >= max_iter: raise Exception(f"Turn did not complete in {max_iter} iterations") + + +class AsyncAgent: + def __init__( + self, + client: AsyncLlamaStackClient, + agent_config: AgentConfig, + client_tools: Tuple[ClientTool] = (), + tool_parser: Optional[ToolParser] = None, + ): + self.client = client 
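+        # NOTE: unlike the synchronous Agent, no network call happens in
+        # __init__; agent registration is deferred to the awaitable
+        # initialize() below so construction is safe inside a running loop.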
+ self.agent_config = agent_config + self.client_tools = {t.get_name(): t for t in client_tools} + self.sessions = [] + self.tool_parser = tool_parser + self.builtin_tools = {} + + async def initialize(self) -> None: + agentic_system_create_response = await self.client.agents.create( + agent_config=self.agent_config, + ) + self.agent_id = agentic_system_create_response.agent_id + for tg in self.agent_config["toolgroups"]: + for tool in await self.client.tools.list(toolgroup_id=tg): + self.builtin_tools[tool.identifier] = tool + + async def create_session(self, session_name: str) -> str: + agentic_system_create_session_response = await self.client.agents.session.create( + agent_id=self.agent_id, + session_name=session_name, + ) + self.session_id = agentic_system_create_session_response.session_id + self.sessions.append(self.session_id) + return self.session_id + + async def create_turn( + self, + messages: List[Union[UserMessage, ToolResponseMessage]], + session_id: Optional[str] = None, + toolgroups: Optional[List[Toolgroup]] = None, + documents: Optional[List[Document]] = None, + stream: bool = True, + ) -> AsyncIterator[AgentTurnResponseStreamChunk] | Turn: + if stream: + return self._create_turn_streaming(messages, session_id, toolgroups, documents) + else: + chunks = [x async for x in self._create_turn_streaming(messages, session_id, toolgroups, documents)] + if not chunks: + raise Exception("Turn did not complete") + return chunks[-1].event.payload.turn + + async def _create_turn_streaming( + self, + messages: List[Union[UserMessage, ToolResponseMessage]], + session_id: Optional[str] = None, + toolgroups: Optional[List[Toolgroup]] = None, + documents: Optional[List[Document]] = None, + ) -> AsyncIterator[AgentTurnResponseStreamChunk]: + turn_response = await self.client.agents.turn.create( + agent_id=self.agent_id, + session_id=session_id or self.session_id[-1], + messages=messages, + stream=True, + documents=documents, + ) + + async for chunk in turn_response: + yield chunk From f10b826a768301961a5c5d77a34c9022d114a2fa Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Wed, 26 Feb 2025 17:40:19 -0800 Subject: [PATCH 02/19] mixin --- src/llama_stack_client/lib/agents/agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py index 432db1d5..8f78bb59 100644 --- a/src/llama_stack_client/lib/agents/agent.py +++ b/src/llama_stack_client/lib/agents/agent.py @@ -191,7 +191,7 @@ def _create_turn_streaming( raise Exception(f"Turn did not complete in {max_iter} iterations") -class AsyncAgent: +class AsyncAgent(AgentMixin): def __init__( self, client: AsyncLlamaStackClient, From 42b664eeb11c6eaaaac7f8c53ded7da70c852ad9 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Wed, 26 Feb 2025 17:52:19 -0800 Subject: [PATCH 03/19] tools --- src/llama_stack_client/lib/agents/agent.py | 79 +++++++++++++++++++++- 1 file changed, 77 insertions(+), 2 deletions(-) diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py index 8f78bb59..ac6d61f7 100644 --- a/src/llama_stack_client/lib/agents/agent.py +++ b/src/llama_stack_client/lib/agents/agent.py @@ -240,6 +240,48 @@ async def create_turn( raise Exception("Turn did not complete") return chunks[-1].event.payload.turn + async def _run_tool(self, tool_calls: List[ToolCall]) -> ToolResponseMessage: + assert len(tool_calls) == 1, "Only one tool call is supported" + tool_call = tool_calls[0] + + # custom client tools + if 
tool_call.tool_name in self.client_tools: + tool = self.client_tools[tool_call.tool_name] + # TODO: make the client tool async + result_message = tool.run( + [ + CompletionMessage( + role="assistant", + content=tool_call.tool_name, + tool_calls=[tool_call], + stop_reason="end_of_turn", + ) + ] + ) + return result_message + + # builtin tools executed by tool_runtime + if tool_call.tool_name in self.builtin_tools: + tool_result = await self.client.tool_runtime.invoke_tool( + tool_name=tool_call.tool_name, + kwargs=tool_call.arguments, + ) + tool_response_message = ToolResponseMessage( + call_id=tool_call.call_id, + tool_name=tool_call.tool_name, + content=tool_result.content, + role="tool", + ) + return tool_response_message + + # cannot find tools + return ToolResponseMessage( + call_id=tool_call.call_id, + tool_name=tool_call.tool_name, + content=f"Unknown tool `{tool_call.tool_name}` was called.", + role="tool", + ) + async def _create_turn_streaming( self, messages: List[Union[UserMessage, ToolResponseMessage]], @@ -247,6 +289,7 @@ async def _create_turn_streaming( toolgroups: Optional[List[Toolgroup]] = None, documents: Optional[List[Document]] = None, ) -> AsyncIterator[AgentTurnResponseStreamChunk]: + # 1. create an agent turn turn_response = await self.client.agents.turn.create( agent_id=self.agent_id, session_id=session_id or self.session_id[-1], @@ -255,5 +298,37 @@ async def _create_turn_streaming( documents=documents, ) - async for chunk in turn_response: - yield chunk + # 2. process turn and resume if there's a tool call + n_iter = 0 + max_iter = self.agent_config.get("max_infer_iters", DEFAULT_MAX_ITER) + is_turn_complete = False + while not is_turn_complete: + async for chunk in turn_response: + tool_calls = self._get_tool_calls(chunk) + if hasattr(chunk, "error"): + yield chunk + return + elif not tool_calls: + yield chunk + else: + is_turn_complete = False + turn_id = self._get_turn_id(chunk) + if n_iter == 0: + yield chunk + + # run the tools + tool_response_message = await self._run_tool(tool_calls) + # pass it to next iteration + turn_response = await self.client.agents.turn.resume( + agent_id=self.agent_id, + session_id=session_id or self.session_id[-1], + turn_id=turn_id, + tool_responses=[tool_response_message], + stream=True, + ) + + n_iter += 1 + break + + if n_iter >= max_iter: + raise Exception(f"Turn did not complete in {max_iter} iterations") From 3eba2c5b45efa359729d92a584f7a6d5f760f57c Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Wed, 26 Feb 2025 21:04:22 -0800 Subject: [PATCH 04/19] async tool --- src/llama_stack_client/lib/agents/agent.py | 4 +- .../lib/agents/client_tool.py | 44 ++++++++++++++----- 2 files changed, 36 insertions(+), 12 deletions(-) diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py index ac6d61f7..595adfe4 100644 --- a/src/llama_stack_client/lib/agents/agent.py +++ b/src/llama_stack_client/lib/agents/agent.py @@ -247,8 +247,7 @@ async def _run_tool(self, tool_calls: List[ToolCall]) -> ToolResponseMessage: # custom client tools if tool_call.tool_name in self.client_tools: tool = self.client_tools[tool_call.tool_name] - # TODO: make the client tool async - result_message = tool.run( + result_message = await tool.async_run( [ CompletionMessage( role="assistant", @@ -303,6 +302,7 @@ async def _create_turn_streaming( max_iter = self.agent_config.get("max_infer_iters", DEFAULT_MAX_ITER) is_turn_complete = False while not is_turn_complete: + is_turn_complete = True async for chunk in turn_response: 
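                # A chunk that carries tool calls means the turn is awaiting
                # input: run the requested tool client-side, then resume the
                # turn with the tool response and keep streaming.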
tool_calls = self._get_tool_calls(chunk) if hasattr(chunk, "error"): diff --git a/src/llama_stack_client/lib/agents/client_tool.py b/src/llama_stack_client/lib/agents/client_tool.py index f672268d..c1c48f54 100644 --- a/src/llama_stack_client/lib/agents/client_tool.py +++ b/src/llama_stack_client/lib/agents/client_tool.py @@ -7,16 +7,7 @@ import inspect import json from abc import abstractmethod -from typing import ( - Callable, - Dict, - get_args, - get_origin, - get_type_hints, - List, - TypeVar, - Union, -) +from typing import Callable, Dict, get_args, get_origin, get_type_hints, List, TypeVar, Union from llama_stack_client.types import Message, ToolResponseMessage from llama_stack_client.types.tool_def_param import Parameter, ToolDefParam @@ -92,10 +83,35 @@ def run( role="tool", ) + async def async_run( + self, + message_history: List[Message], + ) -> ToolResponseMessage: + last_message = message_history[-1] + + assert len(last_message.tool_calls) == 1, "Expected single tool call" + tool_call = last_message.tool_calls[0] + try: + response = await self.async_run_impl(**tool_call.arguments) + response_str = json.dumps(response, ensure_ascii=False) + except Exception as e: + response_str = f"Error when running tool: {e}" + + return ToolResponseMessage( + call_id=tool_call.call_id, + tool_name=tool_call.tool_name, + content=response_str, + role="tool", + ) + @abstractmethod def run_impl(self, **kwargs): raise NotImplementedError + @abstractmethod + def async_run_impl(self, **kwargs): + raise NotImplementedError + T = TypeVar("T", bound=Callable) @@ -171,6 +187,14 @@ def get_params_definition(self) -> Dict[str, Parameter]: return params def run_impl(self, **kwargs): + if inspect.iscoroutinefunction(func): + raise NotImplementedError("Tool is async but run_impl is not async") return func(**kwargs) + async def async_run_impl(self, **kwargs): + if inspect.iscoroutinefunction(func): + return await func(**kwargs) + else: + return func(**kwargs) + return _WrappedTool() From c5cf89e97f2dd09e10938bda3caa770d64ba54e0 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 27 Feb 2025 10:21:02 -0800 Subject: [PATCH 05/19] refactor async agent --- src/llama_stack_client/lib/agents/agent.py | 348 ++++++++++----------- 1 file changed, 163 insertions(+), 185 deletions(-) diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py index 595adfe4..47c08273 100644 --- a/src/llama_stack_client/lib/agents/agent.py +++ b/src/llama_stack_client/lib/agents/agent.py @@ -3,6 +3,7 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
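Patch 04's decorator change hinges on one idiom: `async_run_impl` awaits the wrapped function only when `inspect.iscoroutinefunction` reports a coroutine function, so a single awaitable entry point serves both sync and async tools. A minimal self-contained sketch of that dispatch (the `make_async_runner`, `add`, and `slow_add` names are illustrative, not SDK API):

import asyncio
import inspect

def make_async_runner(func):
    # Mirrors async_run_impl: await coroutine functions, call plain ones.
    async def async_run_impl(**kwargs):
        if inspect.iscoroutinefunction(func):
            return await func(**kwargs)
        return func(**kwargs)
    return async_run_impl

def add(a: int, b: int) -> int:
    return a + b

async def slow_add(a: int, b: int) -> int:
    await asyncio.sleep(0)
    return a + b

# Both tool styles run behind the same awaitable entry point.
assert asyncio.run(make_async_runner(add)(a=1, b=2)) == 3
assert asyncio.run(make_async_runner(slow_add)(a=1, b=2)) == 3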
+import asyncio from typing import AsyncIterator, Iterator, List, Optional, Tuple, Union from llama_stack_client import AsyncLlamaStackClient, LlamaStackClient @@ -41,83 +42,24 @@ def _get_turn_id(self, chunk: AgentTurnResponseStreamChunk) -> Optional[str]: return chunk.event.payload.turn.turn_id -class Agent(AgentMixin): +class Agent: def __init__( self, - client: LlamaStackClient, + client: Union[AsyncLlamaStackClient, LlamaStackClient], agent_config: AgentConfig, client_tools: Tuple[ClientTool] = (), tool_parser: Optional[ToolParser] = None, ): - self.client = client - self.agent_config = agent_config - self.agent_id = self._create_agent(agent_config) - self.client_tools = {t.get_name(): t for t in client_tools} - self.sessions = [] - self.tool_parser = tool_parser - self.builtin_tools = {} - for tg in agent_config["toolgroups"]: - for tool in self.client.tools.list(toolgroup_id=tg): - self.builtin_tools[tool.identifier] = tool - - def _create_agent(self, agent_config: AgentConfig) -> int: - agentic_system_create_response = self.client.agents.create( - agent_config=agent_config, - ) - self.agent_id = agentic_system_create_response.agent_id - return self.agent_id + self.async_agent = AsyncAgent(client, agent_config, client_tools, tool_parser) + asyncio.run(self.async_agent.initialize()) + self.sessions = self.async_agent.sessions + self.client_tools = self.async_agent.client_tools + self.tool_parser = self.async_agent.tool_parser + self.builtin_tools = self.async_agent.builtin_tools + self.agent_id = self.async_agent.agent_id def create_session(self, session_name: str) -> str: - agentic_system_create_session_response = self.client.agents.session.create( - agent_id=self.agent_id, - session_name=session_name, - ) - self.session_id = agentic_system_create_session_response.session_id - self.sessions.append(self.session_id) - return self.session_id - - def _run_tool(self, tool_calls: List[ToolCall]) -> ToolResponseMessage: - assert len(tool_calls) == 1, "Only one tool call is supported" - tool_call = tool_calls[0] - - # custom client tools - if tool_call.tool_name in self.client_tools: - tool = self.client_tools[tool_call.tool_name] - # NOTE: tool.run() expects a list of messages, we only pass in last message here - # but we could pass in the entire message history - result_message = tool.run( - [ - CompletionMessage( - role="assistant", - content=tool_call.tool_name, - tool_calls=[tool_call], - stop_reason="end_of_turn", - ) - ] - ) - return result_message - - # builtin tools executed by tool_runtime - if tool_call.tool_name in self.builtin_tools: - tool_result = self.client.tool_runtime.invoke_tool( - tool_name=tool_call.tool_name, - kwargs=tool_call.arguments, - ) - tool_response_message = ToolResponseMessage( - call_id=tool_call.call_id, - tool_name=tool_call.tool_name, - content=tool_result.content, - role="tool", - ) - return tool_response_message - - # cannot find tools - return ToolResponseMessage( - call_id=tool_call.call_id, - tool_name=tool_call.tool_name, - content=f"Unknown tool `{tool_call.tool_name}` was called.", - role="tool", - ) + return asyncio.run(self.async_agent.create_session(session_name)) def create_turn( self, @@ -127,74 +69,33 @@ def create_turn( documents: Optional[List[Document]] = None, stream: bool = True, ) -> Iterator[AgentTurnResponseStreamChunk] | Turn: - if stream: - return self._create_turn_streaming(messages, session_id, toolgroups, documents) - else: - chunks = [x for x in self._create_turn_streaming(messages, session_id, toolgroups, documents)] - if 
not chunks: - raise Exception("Turn did not complete") - return chunks[-1].event.payload.turn - - def _create_turn_streaming( - self, - messages: List[Union[UserMessage, ToolResponseMessage]], - session_id: Optional[str] = None, - toolgroups: Optional[List[Toolgroup]] = None, - documents: Optional[List[Document]] = None, - ) -> Iterator[AgentTurnResponseStreamChunk]: - n_iter = 0 - max_iter = self.agent_config.get("max_infer_iters", DEFAULT_MAX_ITER) - - # 1. create an agent turn - turn_response = self.client.agents.turn.create( - agent_id=self.agent_id, - # use specified session_id or last session created - session_id=session_id or self.session_id[-1], - messages=messages, - stream=True, - documents=documents, - toolgroups=toolgroups, - allow_turn_resume=True, - ) - # 2. process turn and resume if there's a tool call - is_turn_complete = False - while not is_turn_complete: - is_turn_complete = True - for chunk in turn_response: - tool_calls = self._get_tool_calls(chunk) - if hasattr(chunk, "error"): - yield chunk - return - elif not tool_calls: - yield chunk - else: - is_turn_complete = False - turn_id = self._get_turn_id(chunk) - if n_iter == 0: - yield chunk + if stream: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) - # run the tools - tool_response_message = self._run_tool(tool_calls) - # pass it to next iteration - turn_response = self.client.agents.turn.resume( - agent_id=self.agent_id, - session_id=session_id or self.session_id[-1], - turn_id=turn_id, - tool_responses=[tool_response_message], - stream=True, + def sync_generator(): + try: + async_stream = loop.run_until_complete( + self.async_agent.create_turn(messages, session_id, toolgroups, documents, stream) ) - n_iter += 1 - break + while True: + chunk = loop.run_until_complete(async_stream.__anext__()) + yield chunk + except StopAsyncIteration: + pass + finally: + loop.close() - if n_iter >= max_iter: - raise Exception(f"Turn did not complete in {max_iter} iterations") + return sync_generator() + else: + return asyncio.run(self.async_agent.create_turn(messages, session_id, toolgroups, documents, stream)) class AsyncAgent(AgentMixin): def __init__( self, - client: AsyncLlamaStackClient, + client: Union[AsyncLlamaStackClient, LlamaStackClient], agent_config: AgentConfig, client_tools: Tuple[ClientTool] = (), tool_parser: Optional[ToolParser] = None, @@ -206,20 +107,41 @@ def __init__( self.tool_parser = tool_parser self.builtin_tools = {} + self.is_async = True + if isinstance(client, LlamaStackClient): + self.is_async = False + async def initialize(self) -> None: - agentic_system_create_response = await self.client.agents.create( - agent_config=self.agent_config, - ) + if self.is_async: + agentic_system_create_response = await self.client.agents.create( + agent_config=self.agent_config, + ) + else: + agentic_system_create_response = self.client.agents.create( + agent_config=self.agent_config, + ) + self.agent_id = agentic_system_create_response.agent_id for tg in self.agent_config["toolgroups"]: - for tool in await self.client.tools.list(toolgroup_id=tg): - self.builtin_tools[tool.identifier] = tool + if self.is_async: + for tool in await self.client.tools.list(toolgroup_id=tg): + self.builtin_tools[tool.identifier] = tool + else: + for tool in self.client.tools.list(toolgroup_id=tg): + self.builtin_tools[tool.identifier] = tool async def create_session(self, session_name: str) -> str: - agentic_system_create_session_response = await self.client.agents.session.create( - agent_id=self.agent_id, - 
session_name=session_name, - ) + if self.is_async: + agentic_system_create_session_response = await self.client.agents.session.create( + agent_id=self.agent_id, + session_name=session_name, + ) + else: + agentic_system_create_session_response = self.client.agents.session.create( + agent_id=self.agent_id, + session_name=session_name, + ) + self.session_id = agentic_system_create_session_response.session_id self.sessions.append(self.session_id) return self.session_id @@ -247,24 +169,35 @@ async def _run_tool(self, tool_calls: List[ToolCall]) -> ToolResponseMessage: # custom client tools if tool_call.tool_name in self.client_tools: tool = self.client_tools[tool_call.tool_name] - result_message = await tool.async_run( - [ - CompletionMessage( - role="assistant", - content=tool_call.tool_name, - tool_calls=[tool_call], - stop_reason="end_of_turn", - ) - ] - ) + # NOTE: tool.run() expects a list of messages, we only pass in last message here + # but we could pass in the entire message history + messages = [ + CompletionMessage( + role="assistant", + content=tool_call.tool_name, + tool_calls=[tool_call], + stop_reason="end_of_turn", + ) + ] + if self.is_async: + result_message = await tool.async_run(messages) + else: + result_message = tool.run(messages) + return result_message # builtin tools executed by tool_runtime if tool_call.tool_name in self.builtin_tools: - tool_result = await self.client.tool_runtime.invoke_tool( - tool_name=tool_call.tool_name, - kwargs=tool_call.arguments, - ) + if self.is_async: + tool_result = await self.client.tool_runtime.invoke_tool( + tool_name=tool_call.tool_name, + kwargs=tool_call.arguments, + ) + else: + tool_result = self.client.tool_runtime.invoke_tool( + tool_name=tool_call.tool_name, + kwargs=tool_call.arguments, + ) tool_response_message = ToolResponseMessage( call_id=tool_call.call_id, tool_name=tool_call.tool_name, @@ -288,47 +221,92 @@ async def _create_turn_streaming( toolgroups: Optional[List[Toolgroup]] = None, documents: Optional[List[Document]] = None, ) -> AsyncIterator[AgentTurnResponseStreamChunk]: + n_iter = 0 + max_iter = self.agent_config.get("max_infer_iters", DEFAULT_MAX_ITER) + # 1. create an agent turn - turn_response = await self.client.agents.turn.create( - agent_id=self.agent_id, - session_id=session_id or self.session_id[-1], - messages=messages, - stream=True, - documents=documents, - ) + if self.is_async: + turn_response = await self.client.agents.turn.create( + agent_id=self.agent_id, + # use specified session_id or last session created + session_id=session_id or self.session_id[-1], + messages=messages, + stream=True, + documents=documents, + toolgroups=toolgroups, + allow_turn_resume=True, + ) + else: + turn_response = self.client.agents.turn.create( + agent_id=self.agent_id, + # use specified session_id or last session created + session_id=session_id or self.session_id[-1], + messages=messages, + stream=True, + documents=documents, + toolgroups=toolgroups, + allow_turn_resume=True, + ) # 2. 
process turn and resume if there's a tool call - n_iter = 0 - max_iter = self.agent_config.get("max_infer_iters", DEFAULT_MAX_ITER) is_turn_complete = False while not is_turn_complete: is_turn_complete = True - async for chunk in turn_response: - tool_calls = self._get_tool_calls(chunk) - if hasattr(chunk, "error"): - yield chunk - return - elif not tool_calls: - yield chunk - else: - is_turn_complete = False - turn_id = self._get_turn_id(chunk) - if n_iter == 0: - yield chunk - # run the tools - tool_response_message = await self._run_tool(tool_calls) - # pass it to next iteration - turn_response = await self.client.agents.turn.resume( - agent_id=self.agent_id, - session_id=session_id or self.session_id[-1], - turn_id=turn_id, - tool_responses=[tool_response_message], - stream=True, - ) - - n_iter += 1 - break + if self.is_async: + async for chunk in turn_response: + tool_calls = self._get_tool_calls(chunk) + if hasattr(chunk, "error"): + yield chunk + return + elif not tool_calls: + yield chunk + else: + is_turn_complete = False + turn_id = self._get_turn_id(chunk) + if n_iter == 0: + yield chunk + + # run the tools + tool_response_message = await self._run_tool(tool_calls) + + # pass it to next iteration + turn_response = await self.client.agents.turn.resume( + agent_id=self.agent_id, + session_id=session_id or self.session_id[-1], + turn_id=turn_id, + tool_responses=[tool_response_message], + stream=True, + ) + + n_iter += 1 + else: + for chunk in turn_response: + tool_calls = self._get_tool_calls(chunk) + if hasattr(chunk, "error"): + yield chunk + return + elif not tool_calls: + yield chunk + else: + is_turn_complete = False + turn_id = self._get_turn_id(chunk) + if n_iter == 0: + yield chunk + + # run the tools + tool_response_message = await self._run_tool(tool_calls) + + # pass it to next iteration + turn_response = self.client.agents.turn.resume( + agent_id=self.agent_id, + session_id=session_id or self.session_id[-1], + turn_id=turn_id, + tool_responses=[tool_response_message], + stream=True, + ) + + n_iter += 1 if n_iter >= max_iter: raise Exception(f"Turn did not complete in {max_iter} iterations") From b3ca53b67eed0cedec7769d745811c8913b94505 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 27 Feb 2025 10:27:29 -0800 Subject: [PATCH 06/19] refactor async agent --- src/llama_stack_client/lib/agents/agent.py | 43 +++++++++++----------- 1 file changed, 21 insertions(+), 22 deletions(-) diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py index 47c08273..dbaea2e3 100644 --- a/src/llama_stack_client/lib/agents/agent.py +++ b/src/llama_stack_client/lib/agents/agent.py @@ -21,27 +21,6 @@ DEFAULT_MAX_ITER = 10 -class AgentMixin: - def _get_tool_calls(self, chunk: AgentTurnResponseStreamChunk) -> List[ToolCall]: - if chunk.event.payload.event_type not in {"turn_complete", "turn_awaiting_input"}: - return [] - - message = chunk.event.payload.turn.output_message - if message.stop_reason == "out_of_tokens": - return [] - - if self.tool_parser: - return self.tool_parser.get_tool_calls(message) - - return message.tool_calls - - def _get_turn_id(self, chunk: AgentTurnResponseStreamChunk) -> Optional[str]: - if chunk.event.payload.event_type not in ["turn_complete", "turn_awaiting_input"]: - return None - - return chunk.event.payload.turn.turn_id - - class Agent: def __init__( self, @@ -92,7 +71,7 @@ def sync_generator(): return asyncio.run(self.async_agent.create_turn(messages, session_id, toolgroups, documents, stream)) -class 
AsyncAgent(AgentMixin): +class AsyncAgent: def __init__( self, client: Union[AsyncLlamaStackClient, LlamaStackClient], @@ -108,6 +87,7 @@ def __init__( self.builtin_tools = {} self.is_async = True + if isinstance(client, LlamaStackClient): self.is_async = False @@ -310,3 +290,22 @@ async def _create_turn_streaming( if n_iter >= max_iter: raise Exception(f"Turn did not complete in {max_iter} iterations") + + def _get_tool_calls(self, chunk: AgentTurnResponseStreamChunk) -> List[ToolCall]: + if chunk.event.payload.event_type not in {"turn_complete", "turn_awaiting_input"}: + return [] + + message = chunk.event.payload.turn.output_message + if message.stop_reason == "out_of_tokens": + return [] + + if self.tool_parser: + return self.tool_parser.get_tool_calls(message) + + return message.tool_calls + + def _get_turn_id(self, chunk: AgentTurnResponseStreamChunk) -> Optional[str]: + if chunk.event.payload.event_type not in ["turn_complete", "turn_awaiting_input"]: + return None + + return chunk.event.payload.turn.turn_id From 152c1430474ba7ca4a68a744fbd1f407bb434de3 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 27 Feb 2025 10:29:49 -0800 Subject: [PATCH 07/19] pre --- src/llama_stack_client/lib/agents/agent.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py index dbaea2e3..6fcac1f5 100644 --- a/src/llama_stack_client/lib/agents/agent.py +++ b/src/llama_stack_client/lib/agents/agent.py @@ -48,7 +48,6 @@ def create_turn( documents: Optional[List[Document]] = None, stream: bool = True, ) -> Iterator[AgentTurnResponseStreamChunk] | Turn: - if stream: loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) From c3fe4b70dd373ecb2eb2b1846c33a81845d8a926 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 27 Feb 2025 11:42:54 -0800 Subject: [PATCH 08/19] revert async_agent wrapper --- src/llama_stack_client/lib/agents/agent.py | 384 +++++++++++---------- 1 file changed, 206 insertions(+), 178 deletions(-) diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py index 6fcac1f5..aa7e6f34 100644 --- a/src/llama_stack_client/lib/agents/agent.py +++ b/src/llama_stack_client/lib/agents/agent.py @@ -3,7 +3,6 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
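Patches 05-07 had the synchronous `Agent` delegate to `AsyncAgent`, driving the async generator from blocking code on a private event loop; patch 08 reverts that wrapper in favor of two parallel implementations. The bridging idiom being removed is still worth recording. A standalone sketch, independent of the SDK (`countdown` and `as_sync_iter` are illustrative names):

import asyncio
from typing import AsyncIterator, Iterator, TypeVar

T = TypeVar("T")

async def countdown(n: int) -> AsyncIterator[int]:
    while n > 0:
        yield n
        n -= 1

def as_sync_iter(agen: AsyncIterator[T]) -> Iterator[T]:
    # Step an async generator from blocking code on a dedicated loop,
    # the same trick the (now reverted) sync_generator wrapper used.
    loop = asyncio.new_event_loop()
    try:
        while True:
            yield loop.run_until_complete(agen.__anext__())
    except StopAsyncIteration:
        pass
    finally:
        loop.close()

print(list(as_sync_iter(countdown(3))))  # [3, 2, 1]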
-import asyncio from typing import AsyncIterator, Iterator, List, Optional, Tuple, Union from llama_stack_client import AsyncLlamaStackClient, LlamaStackClient @@ -21,24 +20,104 @@ DEFAULT_MAX_ITER = 10 -class Agent: +class AgentMixin: + def _get_tool_calls(self, chunk: AgentTurnResponseStreamChunk) -> List[ToolCall]: + if chunk.event.payload.event_type not in {"turn_complete", "turn_awaiting_input"}: + return [] + + message = chunk.event.payload.turn.output_message + if message.stop_reason == "out_of_tokens": + return [] + + if self.tool_parser: + return self.tool_parser.get_tool_calls(message) + + return message.tool_calls + + def _get_turn_id(self, chunk: AgentTurnResponseStreamChunk) -> Optional[str]: + if chunk.event.payload.event_type not in ["turn_complete", "turn_awaiting_input"]: + return None + + return chunk.event.payload.turn.turn_id + + +class Agent(AgentMixin): def __init__( self, - client: Union[AsyncLlamaStackClient, LlamaStackClient], + client: LlamaStackClient, agent_config: AgentConfig, client_tools: Tuple[ClientTool] = (), tool_parser: Optional[ToolParser] = None, ): - self.async_agent = AsyncAgent(client, agent_config, client_tools, tool_parser) - asyncio.run(self.async_agent.initialize()) - self.sessions = self.async_agent.sessions - self.client_tools = self.async_agent.client_tools - self.tool_parser = self.async_agent.tool_parser - self.builtin_tools = self.async_agent.builtin_tools - self.agent_id = self.async_agent.agent_id + self.client = client + self.agent_config = agent_config + self.agent_id = self._create_agent(agent_config) + self.client_tools = {t.get_name(): t for t in client_tools} + self.sessions = [] + self.tool_parser = tool_parser + self.builtin_tools = {} + for tg in agent_config["toolgroups"]: + for tool in self.client.tools.list(toolgroup_id=tg): + self.builtin_tools[tool.identifier] = tool + + def _create_agent(self, agent_config: AgentConfig) -> int: + agentic_system_create_response = self.client.agents.create( + agent_config=agent_config, + ) + self.agent_id = agentic_system_create_response.agent_id + return self.agent_id def create_session(self, session_name: str) -> str: - return asyncio.run(self.async_agent.create_session(session_name)) + agentic_system_create_session_response = self.client.agents.session.create( + agent_id=self.agent_id, + session_name=session_name, + ) + self.session_id = agentic_system_create_session_response.session_id + self.sessions.append(self.session_id) + return self.session_id + + def _run_tool(self, tool_calls: List[ToolCall]) -> ToolResponseMessage: + assert len(tool_calls) == 1, "Only one tool call is supported" + tool_call = tool_calls[0] + + # custom client tools + if tool_call.tool_name in self.client_tools: + tool = self.client_tools[tool_call.tool_name] + # NOTE: tool.run() expects a list of messages, we only pass in last message here + # but we could pass in the entire message history + result_message = tool.run( + [ + CompletionMessage( + role="assistant", + content=tool_call.tool_name, + tool_calls=[tool_call], + stop_reason="end_of_turn", + ) + ] + ) + return result_message + + # builtin tools executed by tool_runtime + if tool_call.tool_name in self.builtin_tools: + tool_result = self.client.tool_runtime.invoke_tool( + tool_name=tool_call.tool_name, + kwargs=tool_call.arguments, + ) + tool_response_message = ToolResponseMessage( + call_id=tool_call.call_id, + tool_name=tool_call.tool_name, + content=tool_result.content, + role="tool", + ) + return tool_response_message + + # cannot find tools + return 
ToolResponseMessage( + call_id=tool_call.call_id, + tool_name=tool_call.tool_name, + content=f"Unknown tool `{tool_call.tool_name}` was called.", + role="tool", + ) def create_turn( self, @@ -49,31 +128,73 @@ def create_turn( stream: bool = True, ) -> Iterator[AgentTurnResponseStreamChunk] | Turn: if stream: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) + return self._create_turn_streaming(messages, session_id, toolgroups, documents) + else: + chunks = [x for x in self._create_turn_streaming(messages, session_id, toolgroups, documents)] + if not chunks: + raise Exception("Turn did not complete") + return chunks[-1].event.payload.turn - def sync_generator(): - try: - async_stream = loop.run_until_complete( - self.async_agent.create_turn(messages, session_id, toolgroups, documents, stream) - ) - while True: - chunk = loop.run_until_complete(async_stream.__anext__()) + def _create_turn_streaming( + self, + messages: List[Union[UserMessage, ToolResponseMessage]], + session_id: Optional[str] = None, + toolgroups: Optional[List[Toolgroup]] = None, + documents: Optional[List[Document]] = None, + ) -> Iterator[AgentTurnResponseStreamChunk]: + n_iter = 0 + max_iter = self.agent_config.get("max_infer_iters", DEFAULT_MAX_ITER) + + # 1. create an agent turn + turn_response = self.client.agents.turn.create( + agent_id=self.agent_id, + # use specified session_id or last session created + session_id=session_id or self.session_id[-1], + messages=messages, + stream=True, + documents=documents, + toolgroups=toolgroups, + allow_turn_resume=True, + ) + + # 2. process turn and resume if there's a tool call + is_turn_complete = False + while not is_turn_complete: + is_turn_complete = True + for chunk in turn_response: + tool_calls = self._get_tool_calls(chunk) + if hasattr(chunk, "error"): + yield chunk + return + elif not tool_calls: + yield chunk + else: + is_turn_complete = False + turn_id = self._get_turn_id(chunk) + if n_iter == 0: yield chunk - except StopAsyncIteration: - pass - finally: - loop.close() - return sync_generator() - else: - return asyncio.run(self.async_agent.create_turn(messages, session_id, toolgroups, documents, stream)) + # run the tools + tool_response_message = self._run_tool(tool_calls) + # pass it to next iteration + turn_response = self.client.agents.turn.resume( + agent_id=self.agent_id, + session_id=session_id or self.session_id[-1], + turn_id=turn_id, + tool_responses=[tool_response_message], + stream=True, + ) + n_iter += 1 + break + + if n_iter >= max_iter: + raise Exception(f"Turn did not complete in {max_iter} iterations") -class AsyncAgent: +class AsyncAgent(AgentMixin): def __init__( self, - client: Union[AsyncLlamaStackClient, LlamaStackClient], + client: AsyncLlamaStackClient, agent_config: AgentConfig, client_tools: Tuple[ClientTool] = (), tool_parser: Optional[ToolParser] = None, @@ -85,42 +206,23 @@ def __init__( self.tool_parser = tool_parser self.builtin_tools = {} - self.is_async = True - if isinstance(client, LlamaStackClient): - self.is_async = False + raise ValueError("AsyncAgent must be initialized with an AsyncLlamaStackClient") async def initialize(self) -> None: - if self.is_async: - agentic_system_create_response = await self.client.agents.create( - agent_config=self.agent_config, - ) - else: - agentic_system_create_response = self.client.agents.create( - agent_config=self.agent_config, - ) - + agentic_system_create_response = await self.client.agents.create( + agent_config=self.agent_config, + ) self.agent_id = 
agentic_system_create_response.agent_id for tg in self.agent_config["toolgroups"]: - if self.is_async: - for tool in await self.client.tools.list(toolgroup_id=tg): - self.builtin_tools[tool.identifier] = tool - else: - for tool in self.client.tools.list(toolgroup_id=tg): - self.builtin_tools[tool.identifier] = tool + for tool in await self.client.tools.list(toolgroup_id=tg): + self.builtin_tools[tool.identifier] = tool async def create_session(self, session_name: str) -> str: - if self.is_async: - agentic_system_create_session_response = await self.client.agents.session.create( - agent_id=self.agent_id, - session_name=session_name, - ) - else: - agentic_system_create_session_response = self.client.agents.session.create( - agent_id=self.agent_id, - session_name=session_name, - ) - + agentic_system_create_session_response = await self.client.agents.session.create( + agent_id=self.agent_id, + session_name=session_name, + ) self.session_id = agentic_system_create_session_response.session_id self.sessions.append(self.session_id) return self.session_id @@ -148,35 +250,24 @@ async def _run_tool(self, tool_calls: List[ToolCall]) -> ToolResponseMessage: # custom client tools if tool_call.tool_name in self.client_tools: tool = self.client_tools[tool_call.tool_name] - # NOTE: tool.run() expects a list of messages, we only pass in last message here - # but we could pass in the entire message history - messages = [ - CompletionMessage( - role="assistant", - content=tool_call.tool_name, - tool_calls=[tool_call], - stop_reason="end_of_turn", - ) - ] - if self.is_async: - result_message = await tool.async_run(messages) - else: - result_message = tool.run(messages) - + result_message = await tool.async_run( + [ + CompletionMessage( + role="assistant", + content=tool_call.tool_name, + tool_calls=[tool_call], + stop_reason="end_of_turn", + ) + ] + ) return result_message # builtin tools executed by tool_runtime if tool_call.tool_name in self.builtin_tools: - if self.is_async: - tool_result = await self.client.tool_runtime.invoke_tool( - tool_name=tool_call.tool_name, - kwargs=tool_call.arguments, - ) - else: - tool_result = self.client.tool_runtime.invoke_tool( - tool_name=tool_call.tool_name, - kwargs=tool_call.arguments, - ) + tool_result = await self.client.tool_runtime.invoke_tool( + tool_name=tool_call.tool_name, + kwargs=tool_call.arguments, + ) tool_response_message = ToolResponseMessage( call_id=tool_call.call_id, tool_name=tool_call.tool_name, @@ -204,107 +295,44 @@ async def _create_turn_streaming( max_iter = self.agent_config.get("max_infer_iters", DEFAULT_MAX_ITER) # 1. create an agent turn - if self.is_async: - turn_response = await self.client.agents.turn.create( - agent_id=self.agent_id, - # use specified session_id or last session created - session_id=session_id or self.session_id[-1], - messages=messages, - stream=True, - documents=documents, - toolgroups=toolgroups, - allow_turn_resume=True, - ) - else: - turn_response = self.client.agents.turn.create( - agent_id=self.agent_id, - # use specified session_id or last session created - session_id=session_id or self.session_id[-1], - messages=messages, - stream=True, - documents=documents, - toolgroups=toolgroups, - allow_turn_resume=True, - ) + turn_response = await self.client.agents.turn.create( + agent_id=self.agent_id, + session_id=session_id or self.session_id[-1], + messages=messages, + stream=True, + documents=documents, + ) # 2. 
process turn and resume if there's a tool call is_turn_complete = False while not is_turn_complete: is_turn_complete = True - - if self.is_async: - async for chunk in turn_response: - tool_calls = self._get_tool_calls(chunk) - if hasattr(chunk, "error"): + async for chunk in turn_response: + tool_calls = self._get_tool_calls(chunk) + if hasattr(chunk, "error"): + yield chunk + return + elif not tool_calls: + yield chunk + else: + is_turn_complete = False + turn_id = self._get_turn_id(chunk) + if n_iter == 0: yield chunk - return - elif not tool_calls: - yield chunk - else: - is_turn_complete = False - turn_id = self._get_turn_id(chunk) - if n_iter == 0: - yield chunk - - # run the tools - tool_response_message = await self._run_tool(tool_calls) - - # pass it to next iteration - turn_response = await self.client.agents.turn.resume( - agent_id=self.agent_id, - session_id=session_id or self.session_id[-1], - turn_id=turn_id, - tool_responses=[tool_response_message], - stream=True, - ) - - n_iter += 1 - else: - for chunk in turn_response: - tool_calls = self._get_tool_calls(chunk) - if hasattr(chunk, "error"): - yield chunk - return - elif not tool_calls: - yield chunk - else: - is_turn_complete = False - turn_id = self._get_turn_id(chunk) - if n_iter == 0: - yield chunk - - # run the tools - tool_response_message = await self._run_tool(tool_calls) - - # pass it to next iteration - turn_response = self.client.agents.turn.resume( - agent_id=self.agent_id, - session_id=session_id or self.session_id[-1], - turn_id=turn_id, - tool_responses=[tool_response_message], - stream=True, - ) - - n_iter += 1 - - if n_iter >= max_iter: - raise Exception(f"Turn did not complete in {max_iter} iterations") - - def _get_tool_calls(self, chunk: AgentTurnResponseStreamChunk) -> List[ToolCall]: - if chunk.event.payload.event_type not in {"turn_complete", "turn_awaiting_input"}: - return [] - - message = chunk.event.payload.turn.output_message - if message.stop_reason == "out_of_tokens": - return [] - - if self.tool_parser: - return self.tool_parser.get_tool_calls(message) - return message.tool_calls + # run the tools + tool_response_message = await self._run_tool(tool_calls) + # pass it to next iteration + turn_response = await self.client.agents.turn.resume( + agent_id=self.agent_id, + session_id=session_id or self.session_id[-1], + turn_id=turn_id, + tool_responses=[tool_response_message], + stream=True, + ) - def _get_turn_id(self, chunk: AgentTurnResponseStreamChunk) -> Optional[str]: - if chunk.event.payload.event_type not in ["turn_complete", "turn_awaiting_input"]: - return None + n_iter += 1 + break - return chunk.event.payload.turn.turn_id + if n_iter >= max_iter: + raise Exception(f"Turn did not complete in {max_iter} iterations") From 9c762ec1d37acae9544a0204700d805f1e160fa7 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Mon, 3 Mar 2025 10:18:55 -0800 Subject: [PATCH 09/19] rebase --- src/llama_stack_client/lib/agents/agent.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py index 8e405eb1..21b86ed3 100644 --- a/src/llama_stack_client/lib/agents/agent.py +++ b/src/llama_stack_client/lib/agents/agent.py @@ -293,15 +293,17 @@ async def _create_turn_streaming( documents: Optional[List[Document]] = None, ) -> AsyncIterator[AgentTurnResponseStreamChunk]: n_iter = 0 - max_iter = self.agent_config.get("max_infer_iters", DEFAULT_MAX_ITER) # 1. 
create an agent turn turn_response = await self.client.agents.turn.create( agent_id=self.agent_id, + # use specified session_id or last session created session_id=session_id or self.session_id[-1], messages=messages, stream=True, documents=documents, + toolgroups=toolgroups, + allow_turn_resume=True, ) # 2. process turn and resume if there's a tool call @@ -309,20 +311,27 @@ async def _create_turn_streaming( while not is_turn_complete: is_turn_complete = True async for chunk in turn_response: - tool_calls = self._get_tool_calls(chunk) if hasattr(chunk, "error"): yield chunk return - elif not tool_calls: + + tool_calls = self._get_tool_calls(chunk) + if not tool_calls: yield chunk else: is_turn_complete = False + # End of turn is reached, do not resume even if there's a tool call + if chunk.event.payload.turn.output_message.stop_reason in {"end_of_turn"}: + yield chunk + break + turn_id = self._get_turn_id(chunk) if n_iter == 0: yield chunk # run the tools tool_response_message = await self._run_tool(tool_calls) + # pass it to next iteration turn_response = await self.client.agents.turn.resume( agent_id=self.agent_id, @@ -331,9 +340,4 @@ async def _create_turn_streaming( tool_responses=[tool_response_message], stream=True, ) - n_iter += 1 - break - - if n_iter >= max_iter: - raise Exception(f"Turn did not complete in {max_iter} iterations") From 3641027acae313b1777d0295752795a3f1ff0ad8 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 6 Mar 2025 14:23:43 -0800 Subject: [PATCH 10/19] precommit --- src/llama_stack_client/lib/agents/agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py index 571525fb..7626c445 100644 --- a/src/llama_stack_client/lib/agents/agent.py +++ b/src/llama_stack_client/lib/agents/agent.py @@ -278,7 +278,7 @@ def _create_turn_streaming( stream=True, ) n_iter += 1 - + if self.tool_parser and n_iter > self.agent_config.get("max_infer_iters", DEFAULT_MAX_ITER): raise Exception("Max inference iterations reached") From 276d7fe9d84042db8d85094b39082bc45d4359c3 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 6 Mar 2025 16:35:12 -0800 Subject: [PATCH 11/19] refactor to utils --- src/llama_stack_client/lib/agents/agent.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py index b427b413..e7a0d130 100644 --- a/src/llama_stack_client/lib/agents/agent.py +++ b/src/llama_stack_client/lib/agents/agent.py @@ -26,8 +26,9 @@ logger = logging.getLogger(__name__) -class AgentMixin: - def _get_tool_calls(self, chunk: AgentTurnResponseStreamChunk) -> List[ToolCall]: +class AgentUtils: + @staticmethod + def get_tool_calls(chunk: AgentTurnResponseStreamChunk, tool_parser: Optional[ToolParser] = None) -> List[ToolCall]: if chunk.event.payload.event_type not in {"turn_complete", "turn_awaiting_input"}: return [] @@ -35,19 +36,20 @@ def _get_tool_calls(self, chunk: AgentTurnResponseStreamChunk) -> List[ToolCall] if message.stop_reason == "out_of_tokens": return [] - if self.tool_parser: - return self.tool_parser.get_tool_calls(message) + if tool_parser: + return tool_parser.get_tool_calls(message) return message.tool_calls - def _get_turn_id(self, chunk: AgentTurnResponseStreamChunk) -> Optional[str]: + @staticmethod + def get_turn_id(chunk: AgentTurnResponseStreamChunk) -> Optional[str]: if chunk.event.payload.event_type not in ["turn_complete", 
"turn_awaiting_input"]: return None return chunk.event.payload.turn.turn_id -class Agent(AgentMixin): +class Agent: def __init__( self, client: LlamaStackClient, @@ -252,7 +254,7 @@ def _create_turn_streaming( if hasattr(chunk, "error"): yield chunk return - tool_calls = self._get_tool_calls(chunk) + tool_calls = AgentUtils.get_tool_calls(chunk, self.tool_parser) if not tool_calls: yield chunk else: @@ -264,7 +266,7 @@ def _create_turn_streaming( yield chunk break - turn_id = self._get_turn_id(chunk) + turn_id = AgentUtils.get_turn_id(chunk) if n_iter == 0: yield chunk @@ -408,7 +410,7 @@ async def _create_turn_streaming( yield chunk return - tool_calls = self._get_tool_calls(chunk) + tool_calls = AgentUtils.get_tool_calls(chunk, self.tool_parser) if not tool_calls: yield chunk else: @@ -418,7 +420,7 @@ async def _create_turn_streaming( yield chunk break - turn_id = self._get_turn_id(chunk) + turn_id = AgentUtils.get_turn_id(chunk) if n_iter == 0: yield chunk From d2657c4cbaf59efce9abb4c635d79d44acbf0ec3 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 6 Mar 2025 17:10:55 -0800 Subject: [PATCH 12/19] refactor --- src/llama_stack_client/lib/agents/agent.py | 109 +++++++++++++-------- 1 file changed, 68 insertions(+), 41 deletions(-) diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py index e7a0d130..30af6746 100644 --- a/src/llama_stack_client/lib/agents/agent.py +++ b/src/llama_stack_client/lib/agents/agent.py @@ -48,6 +48,61 @@ def get_turn_id(chunk: AgentTurnResponseStreamChunk) -> Optional[str]: return chunk.event.payload.turn.turn_id + @staticmethod + def get_agent_config( + model: Optional[str] = None, + instructions: Optional[str] = None, + tools: Optional[List[Union[Toolgroup, ClientTool]]] = None, + tool_config: Optional[ToolConfig] = None, + sampling_params: Optional[SamplingParams] = None, + max_infer_iters: Optional[int] = None, + input_shields: Optional[List[str]] = None, + output_shields: Optional[List[str]] = None, + response_format: Optional[ResponseFormat] = None, + enable_session_persistence: Optional[bool] = None, + ) -> AgentConfig: + # Create a minimal valid AgentConfig with required fields + if model is None or instructions is None: + raise ValueError("Both 'model' and 'instructions' are required when agent_config is not provided") + + agent_config = { + "model": model, + "instructions": instructions, + "toolgroups": [], + "client_tools": [], + } + + # Add optional parameters if provided + if enable_session_persistence is not None: + agent_config["enable_session_persistence"] = enable_session_persistence + if max_infer_iters is not None: + agent_config["max_infer_iters"] = max_infer_iters + if input_shields is not None: + agent_config["input_shields"] = input_shields + if output_shields is not None: + agent_config["output_shields"] = output_shields + if response_format is not None: + agent_config["response_format"] = response_format + if sampling_params is not None: + agent_config["sampling_params"] = sampling_params + if tool_config is not None: + agent_config["tool_config"] = tool_config + if tools is not None: + toolgroups: List[Toolgroup] = [] + client_tools: List[ClientTool] = [] + + for tool in tools: + if isinstance(tool, str) or isinstance(tool, dict): + toolgroups.append(tool) + else: + client_tools.append(tool) + + agent_config["toolgroups"] = toolgroups + agent_config["client_tools"] = [tool.get_tool_definition() for tool in client_tools] + + agent_config = AgentConfig(**agent_config) + return agent_config + 
class Agent: def __init__( @@ -102,46 +157,18 @@ def __init__( # Construct agent_config from parameters if not provided if agent_config is None: - # Create a minimal valid AgentConfig with required fields - if model is None or instructions is None: - raise ValueError("Both 'model' and 'instructions' are required when agent_config is not provided") - - agent_config = { - "model": model, - "instructions": instructions, - "toolgroups": [], - "client_tools": [], - } - - # Add optional parameters if provided - if enable_session_persistence is not None: - agent_config["enable_session_persistence"] = enable_session_persistence - if max_infer_iters is not None: - agent_config["max_infer_iters"] = max_infer_iters - if input_shields is not None: - agent_config["input_shields"] = input_shields - if output_shields is not None: - agent_config["output_shields"] = output_shields - if response_format is not None: - agent_config["response_format"] = response_format - if sampling_params is not None: - agent_config["sampling_params"] = sampling_params - if tool_config is not None: - agent_config["tool_config"] = tool_config - if tools is not None: - toolgroups: List[Toolgroup] = [] - client_tools: List[ClientTool] = [] - - for tool in tools: - if isinstance(tool, str) or isinstance(tool, dict): - toolgroups.append(tool) - else: - client_tools.append(tool) - - agent_config["toolgroups"] = toolgroups - agent_config["client_tools"] = [tool.get_tool_definition() for tool in client_tools] - - agent_config = AgentConfig(**agent_config) + agent_config = AgentUtils.get_agent_config( + model=model, + instructions=instructions, + tools=tools, + tool_config=tool_config, + sampling_params=sampling_params, + max_infer_iters=max_infer_iters, + input_shields=input_shields, + output_shields=output_shields, + response_format=response_format, + enable_session_persistence=enable_session_persistence, + ) self.agent_config = agent_config self.agent_id = self._create_agent(agent_config) @@ -420,7 +447,7 @@ async def _create_turn_streaming( yield chunk break - turn_id = AgentUtils.get_turn_id(chunk) + turn_id = self._get_turn_id(chunk) if n_iter == 0: yield chunk From 22bc71d7ee6c3c09f9f4f12d6cf64a04c5b47486 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 6 Mar 2025 17:17:18 -0800 Subject: [PATCH 13/19] async --- src/llama_stack_client/lib/agents/agent.py | 2 +- .../lib/agents/client_tool.py | 19 ++++++++++++------- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py index 30af6746..7e2dc782 100644 --- a/src/llama_stack_client/lib/agents/agent.py +++ b/src/llama_stack_client/lib/agents/agent.py @@ -314,7 +314,7 @@ def _create_turn_streaming( raise Exception("Max inference iterations reached") -class AsyncAgent(AgentMixin): +class AsyncAgent: def __init__( self, client: AsyncLlamaStackClient, diff --git a/src/llama_stack_client/lib/agents/client_tool.py b/src/llama_stack_client/lib/agents/client_tool.py index fb7083c9..71da2b6b 100644 --- a/src/llama_stack_client/lib/agents/client_tool.py +++ b/src/llama_stack_client/lib/agents/client_tool.py @@ -9,7 +9,7 @@ from abc import abstractmethod from typing import Any, Callable, Dict, get_args, get_origin, get_type_hints, List, TypeVar, Union -from llama_stack_client.types import Message, CompletionMessage, ToolResponse +from llama_stack_client.types import CompletionMessage, Message, ToolResponse from llama_stack_client.types.tool_def_param import Parameter, ToolDefParam @@ -90,22 +90,27 
@@ def run( async def async_run( self, message_history: List[Message], - ) -> ToolResponseMessage: + ) -> ToolResponse: last_message = message_history[-1] assert len(last_message.tool_calls) == 1, "Expected single tool call" tool_call = last_message.tool_calls[0] + metadata = {} try: response = await self.async_run_impl(**tool_call.arguments) - response_str = json.dumps(response, ensure_ascii=False) + if isinstance(response, dict) and "content" in response: + content = json.dumps(response["content"], ensure_ascii=False) + metadata = response.get("metadata", {}) + else: + content = json.dumps(response, ensure_ascii=False) except Exception as e: - response_str = f"Error when running tool: {e}" + content = f"Error when running tool: {e}" - return ToolResponseMessage( + return ToolResponse( call_id=tool_call.call_id, tool_name=tool_call.tool_name, - content=response_str, - role="tool", + content=content, + metadata=metadata, ) @abstractmethod From 32062de2b51eb9a4a2fe263e6caa5637150c8af1 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 6 Mar 2025 17:20:07 -0800 Subject: [PATCH 14/19] wip --- src/llama_stack_client/lib/agents/agent.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py index 7e2dc782..c9efac64 100644 --- a/src/llama_stack_client/lib/agents/agent.py +++ b/src/llama_stack_client/lib/agents/agent.py @@ -171,21 +171,20 @@ def __init__( ) self.agent_config = agent_config - self.agent_id = self._create_agent(agent_config) self.client_tools = {t.get_name(): t for t in client_tools} self.sessions = [] self.tool_parser = tool_parser self.builtin_tools = {} - for tg in agent_config["toolgroups"]: - for tool in self.client.tools.list(toolgroup_id=tg): - self.builtin_tools[tool.identifier] = tool + self.initialize() - def _create_agent(self, agent_config: AgentConfig) -> int: + def initialize(self) -> None: agentic_system_create_response = self.client.agents.create( - agent_config=agent_config, + agent_config=self.agent_config, ) self.agent_id = agentic_system_create_response.agent_id - return self.agent_id + for tg in self.agent_config["toolgroups"]: + for tool in self.client.tools.list(toolgroup_id=tg): + self.builtin_tools[tool.identifier] = tool def create_session(self, session_name: str) -> str: agentic_system_create_session_response = self.client.agents.session.create( From d1e35d9de7fbce8436c4462989ad3d3f971326f6 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 6 Mar 2025 17:41:29 -0800 Subject: [PATCH 15/19] comments --- src/llama_stack_client/lib/agents/agent.py | 64 +++++++++++++++++++++- 1 file changed, 61 insertions(+), 3 deletions(-) diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py index c9efac64..73d0e639 100644 --- a/src/llama_stack_client/lib/agents/agent.py +++ b/src/llama_stack_client/lib/agents/agent.py @@ -316,12 +316,69 @@ def _create_turn_streaming( class AsyncAgent: def __init__( self, - client: AsyncLlamaStackClient, - agent_config: AgentConfig, - client_tools: Tuple[ClientTool] = (), + client: LlamaStackClient, + # begin deprecated + agent_config: Optional[AgentConfig] = None, + client_tools: Tuple[ClientTool, ...] 
= (),
+        # end deprecated
         tool_parser: Optional[ToolParser] = None,
+        model: Optional[str] = None,
+        instructions: Optional[str] = None,
+        tools: Optional[List[Union[Toolgroup, ClientTool]]] = None,
+        tool_config: Optional[ToolConfig] = None,
+        sampling_params: Optional[SamplingParams] = None,
+        max_infer_iters: Optional[int] = None,
+        input_shields: Optional[List[str]] = None,
+        output_shields: Optional[List[str]] = None,
+        response_format: Optional[ResponseFormat] = None,
+        enable_session_persistence: Optional[bool] = None,
     ):
+        """Construct an AsyncAgent with the given parameters.
+
+        :param client: The LlamaStackClient instance.
+        :param agent_config: The AgentConfig instance.
+            (deprecated: use the inlined parameters instead)
+        :param client_tools: A tuple of ClientTool instances.
+            (deprecated: use `tools` instead)
+        :param tool_parser: Custom logic that parses tool calls from a message.
+        :param model: The model to use for the agent.
+        :param instructions: The instructions for the agent.
+        :param tools: A list of tools for the agent. Each value can be one of the following:
+            - a dict representing a toolgroup/tool with arguments, e.g. {"name": "builtin::rag/knowledge_search", "args": {"vector_db_ids": [123]}}
+            - a Python function decorated with @client_tool
+            - a str naming a tool within a toolgroup, e.g. "builtin::rag/knowledge_search"
+            - a str naming a toolgroup_id, e.g. "builtin::rag" or "builtin::code_interpreter", in which case all tools in the toolgroup are added to the agent
+            - an instance of ClientTool
+        :param tool_config: The tool configuration for the agent.
+        :param sampling_params: The sampling parameters for the agent.
+        :param max_infer_iters: The maximum number of inference iterations.
+        :param input_shields: The input shields for the agent.
+        :param output_shields: The output shields for the agent.
+        :param response_format: The response format for the agent.
+        :param enable_session_persistence: Whether to enable session persistence.
+        """
         self.client = client
+
+        if agent_config is not None:
+            logger.warning("`agent_config` is deprecated. Use inlined parameters instead.")
+        if client_tools != ():
+            logger.warning("`client_tools` is deprecated. 
Use `tools` instead.") + + # Construct agent_config from parameters if not provided + if agent_config is None: + agent_config = AgentUtils.get_agent_config( + model=model, + instructions=instructions, + tools=tools, + tool_config=tool_config, + sampling_params=sampling_params, + max_infer_iters=max_infer_iters, + input_shields=input_shields, + output_shields=output_shields, + response_format=response_format, + enable_session_persistence=enable_session_persistence, + ) + self.agent_config = agent_config self.client_tools = {t.get_name(): t for t in client_tools} self.sessions = [] @@ -341,6 +398,7 @@ async def initialize(self) -> None: self.builtin_tools[tool.identifier] = tool async def create_session(self, session_name: str) -> str: + await self.initialize() agentic_system_create_session_response = await self.client.agents.session.create( agent_id=self.agent_id, session_name=session_name, From 7858cafb7bd31e7d166938dc9debc44fa3d5ee39 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 6 Mar 2025 17:43:09 -0800 Subject: [PATCH 16/19] comments --- src/llama_stack_client/lib/agents/agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py index 73d0e639..e732e575 100644 --- a/src/llama_stack_client/lib/agents/agent.py +++ b/src/llama_stack_client/lib/agents/agent.py @@ -6,7 +6,7 @@ import logging from typing import AsyncIterator, Iterator, List, Optional, Tuple, Union -from llama_stack_client import AsyncLlamaStackClient, LlamaStackClient +from llama_stack_client import LlamaStackClient from llama_stack_client.types import ToolResponseMessage, ToolResponseParam, UserMessage from llama_stack_client.types.agent_create_params import AgentConfig From 5914e3f76448d19c1cd2c93a17b142e2f4c9d8e5 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 6 Mar 2025 18:00:26 -0800 Subject: [PATCH 17/19] client tools --- src/llama_stack_client/lib/agents/agent.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py index e732e575..637ef923 100644 --- a/src/llama_stack_client/lib/agents/agent.py +++ b/src/llama_stack_client/lib/agents/agent.py @@ -27,6 +27,13 @@ class AgentUtils: + @staticmethod + def get_client_tools(tools: Optional[List[Union[Toolgroup, ClientTool]]]) -> List[ClientTool]: + if not tools: + return [] + + return [tool for tool in tools if isinstance(tool, ClientTool)] + @staticmethod def get_tool_calls(chunk: AgentTurnResponseStreamChunk, tool_parser: Optional[ToolParser] = None) -> List[ToolCall]: if chunk.event.payload.event_type not in {"turn_complete", "turn_awaiting_input"}: @@ -169,6 +176,7 @@ def __init__( response_format=response_format, enable_session_persistence=enable_session_persistence, ) + client_tools = AgentUtils.get_client_tools(tools) self.agent_config = agent_config self.client_tools = {t.get_name(): t for t in client_tools} From c00e43a32a87c8b9de0acddce9318c7693bc0fb6 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 6 Mar 2025 18:00:48 -0800 Subject: [PATCH 18/19] async --- src/llama_stack_client/lib/agents/agent.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py index 637ef923..3a5da669 100644 --- a/src/llama_stack_client/lib/agents/agent.py +++ b/src/llama_stack_client/lib/agents/agent.py @@ -386,6 +386,7 @@ def __init__( response_format=response_format, 
enable_session_persistence=enable_session_persistence, ) + client_tools = AgentUtils.get_client_tools(tools) self.agent_config = agent_config self.client_tools = {t.get_name(): t for t in client_tools} From 72a4306348f1672a5daa9c46edf75e79e5f3a4c6 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Fri, 7 Mar 2025 09:52:16 -0800 Subject: [PATCH 19/19] fix initialization --- src/llama_stack_client/lib/agents/agent.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py index 3a5da669..ea87c1df 100644 --- a/src/llama_stack_client/lib/agents/agent.py +++ b/src/llama_stack_client/lib/agents/agent.py @@ -393,15 +393,25 @@ def __init__( self.sessions = [] self.tool_parser = tool_parser self.builtin_tools = {} + self._agent_id = None if isinstance(client, LlamaStackClient): raise ValueError("AsyncAgent must be initialized with an AsyncLlamaStackClient") + @property + def agent_id(self) -> str: + if not self._agent_id: + raise RuntimeError("Agent ID not initialized. Call initialize() first.") + return self._agent_id + async def initialize(self) -> None: + if self._agent_id: + return + agentic_system_create_response = await self.client.agents.create( agent_config=self.agent_config, ) - self.agent_id = agentic_system_create_response.agent_id + self._agent_id = agentic_system_create_response.agent_id for tg in self.agent_config["toolgroups"]: for tool in await self.client.tools.list(toolgroup_id=tg): self.builtin_tools[tool.identifier] = tool @@ -491,7 +501,6 @@ async def _create_turn_streaming( stream=True, documents=documents, toolgroups=toolgroups, - allow_turn_resume=True, ) # 2. process turn and resume if there's a tool call @@ -513,7 +522,7 @@ async def _create_turn_streaming( yield chunk break - turn_id = self._get_turn_id(chunk) + turn_id = AgentUtils.get_turn_id(chunk) if n_iter == 0: yield chunk
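A few usage notes follow. First, a minimal sketch of the inlined-parameter constructor that the AgentUtils.get_agent_config refactor enables; the base_url and model id below are placeholders, not values taken from the patches:

    from llama_stack_client import LlamaStackClient
    from llama_stack_client.lib.agents.agent import Agent

    # Placeholder server endpoint and model id.
    client = LlamaStackClient(base_url="http://localhost:8321")

    # agent_config is no longer required: model/instructions/tools are passed
    # inline and assembled into an AgentConfig by AgentUtils.get_agent_config.
    agent = Agent(
        client,
        model="meta-llama/Llama-3.1-8B-Instruct",
        instructions="You are a helpful assistant.",
    )

    session_id = agent.create_session("demo-session")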
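Second, the mixed `tools` list documented in the PATCH 15 docstring, sketched under the assumption that the @client_tool decorator in client_tool.py wraps a plain function into a ClientTool instance; the vector_db id is a placeholder:

    from llama_stack_client import LlamaStackClient
    from llama_stack_client.lib.agents.agent import Agent
    from llama_stack_client.lib.agents.client_tool import client_tool

    @client_tool
    def add_numbers(a: int, b: int) -> int:
        """Add two integers.

        :param a: The first addend.
        :param b: The second addend.
        """
        return a + b

    client = LlamaStackClient(base_url="http://localhost:8321")
    agent = Agent(
        client,
        model="meta-llama/Llama-3.1-8B-Instruct",
        instructions="Use the available tools to answer.",
        tools=[
            # toolgroup id: every tool in the group is attached
            "builtin::code_interpreter",
            # toolgroup with arguments
            {"name": "builtin::rag/knowledge_search", "args": {"vector_db_ids": ["my-db"]}},
            # client tool executed locally; AgentUtils.get_client_tools picks
            # it out of the list because it is a ClientTool instance
            add_numbers,
        ],
    )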
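Third, PATCH 13 has async_run return a ToolResponse and special-case dict results that carry a "content" key; a sketch of a tool using that shape to attach metadata (stub data only):

    from llama_stack_client.lib.agents.client_tool import client_tool

    @client_tool
    def get_weather(city: str) -> dict:
        """Look up current weather for a city.

        :param city: Name of the city.
        """
        # A bare return value would be json.dumps'ed wholesale. Returning
        # {"content": ..., "metadata": ...} instead lets async_run forward
        # the metadata on the ToolResponse alongside the serialized content.
        return {
            "content": {"city": city, "temp_c": 21},
            "metadata": {"source": "stub-data"},
        }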
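Last, the lazy initialization from PATCH 19 as seen from the caller. This assumes AsyncAgent.create_turn mirrors the sync create_turn signature, including the stream flag; treat it as a sketch rather than a verified API:

    import asyncio

    from llama_stack_client import AsyncLlamaStackClient
    from llama_stack_client.lib.agents.agent import AsyncAgent
    from llama_stack_client.types import UserMessage

    async def main() -> None:
        client = AsyncLlamaStackClient(base_url="http://localhost:8321")
        agent = AsyncAgent(
            client,
            model="meta-llama/Llama-3.1-8B-Instruct",
            instructions="You are a helpful assistant.",
        )

        # Accessing agent.agent_id here would raise RuntimeError: the agent
        # is only registered with the server once initialize() runs, which
        # create_session now awaits internally.
        session_id = await agent.create_session("streaming-demo")

        # initialize() is idempotent; _agent_id short-circuits a second call.
        await agent.initialize()

        response = await agent.create_turn(
            messages=[UserMessage(content="Hello!", role="user")],
            session_id=session_id,
            stream=True,
        )
        async for chunk in response:
            print(chunk.event.payload.event_type)

    asyncio.run(main())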