From b0de6d9341e0dc77bb427d649b6eecb9fd0c83b6 Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Thu, 27 Feb 2025 19:34:02 -0800
Subject: [PATCH 1/6] max infer iters
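
Drop the client-side max_iter check that raised once a turn failed to
complete within max_infer_iters. The streaming loop now inspects the
turn's stop_reason instead: when the model stops for anything other than
"end_of_message" (i.e. it is not waiting for tool results), remaining
tool-call chunks are yielded to the caller rather than resumed. The
iteration budget itself stays in AgentConfig's max_infer_iters for the
server to enforce. Also collapses two multi-line imports and advertises
"str" parameters as "string", the JSON Schema type name expected
downstream.

A minimal usage sketch (illustration only; it assumes a Llama Stack
server reachable at base_url and a valid model id):

    from llama_stack_client import LlamaStackClient
    from llama_stack_client.lib.agents.agent import Agent

    client = LlamaStackClient(base_url="http://localhost:8321")
    agent = Agent(
        client,
        agent_config={
            "model": "meta-llama/Llama-3.1-8B-Instruct",
            "instructions": "You are a helpful assistant.",
            # cap on inference iterations per turn, enforced server-side
            "max_infer_iters": 5,
        },
    )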
---
 src/llama_stack_client/lib/agents/agent.py | 25 +++++++++++++------
 .../lib/agents/client_tool.py              | 13 ++--------
 2 files changed, 20 insertions(+), 18 deletions(-)

diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py
index 87badd46..4cac97f0 100644
--- a/src/llama_stack_client/lib/agents/agent.py
+++ b/src/llama_stack_client/lib/agents/agent.py
@@ -11,9 +11,7 @@
 from llama_stack_client.types.agent_create_params import AgentConfig
 from llama_stack_client.types.agents.turn import CompletionMessage, Turn
 from llama_stack_client.types.agents.turn_create_params import Document, Toolgroup
-from llama_stack_client.types.agents.turn_create_response import (
-    AgentTurnResponseStreamChunk,
-)
+from llama_stack_client.types.agents.turn_create_response import AgentTurnResponseStreamChunk
 from llama_stack_client.types.shared.tool_call import ToolCall
 
 from .client_tool import ClientTool
@@ -76,6 +74,12 @@ def _get_turn_id(self, chunk: AgentTurnResponseStreamChunk) -> Optional[str]:
 
         return chunk.event.payload.turn.turn_id
 
+    def _is_turn_complete(self, chunk: AgentTurnResponseStreamChunk) -> bool:
+        if chunk.event.payload.event_type not in ["turn_complete", "turn_awaiting_input"]:
+            return False
+
+        return chunk.event.payload.turn.output_message.stop_reason == "end_of_turn"
+
     def _run_tool(self, tool_calls: List[ToolCall]) -> ToolResponseMessage:
         assert len(tool_calls) == 1, "Only one tool call is supported"
         tool_call = tool_calls[0]
@@ -143,7 +147,6 @@ def _create_turn_streaming(
         documents: Optional[List[Document]] = None,
     ) -> Iterator[AgentTurnResponseStreamChunk]:
         n_iter = 0
-        max_iter = self.agent_config.get("max_infer_iters", DEFAULT_MAX_ITER)
 
         # 1. create an agent turn
         turn_response = self.client.agents.turn.create(
@@ -159,6 +162,7 @@
 
         # 2. process turn and resume if there's a tool call
         is_turn_complete = False
+        is_max_iter_reached = False
         while not is_turn_complete:
             is_turn_complete = True
             for chunk in turn_response:
@@ -168,14 +172,21 @@
                     return
                 elif not tool_calls:
                     yield chunk
+                elif is_max_iter_reached:
+                    yield chunk
                 else:
                     is_turn_complete = False
+                    if chunk.event.payload.turn.output_message.stop_reason != "end_of_message":
+                        is_max_iter_reached = True
+
                     turn_id = self._get_turn_id(chunk)
                     if n_iter == 0:
                         yield chunk
 
                     # run the tools
                     tool_response_message = self._run_tool(tool_calls)
+                    print("tool_response_message", tool_response_message)
+
                     # pass it to next iteration
                     turn_response = self.client.agents.turn.resume(
                         agent_id=self.agent_id,
@@ -185,7 +196,7 @@
                         stream=True,
                     )
                     n_iter += 1
-                    break
 
-            if n_iter >= max_iter:
-                raise Exception(f"Turn did not complete in {max_iter} iterations")
+            # # if max iter is reached, raise an error
+            # if is_max_iter_reached:
+            #     raise Exception("Max iteration reached")
diff --git a/src/llama_stack_client/lib/agents/client_tool.py b/src/llama_stack_client/lib/agents/client_tool.py
index f672268d..2ab6d13a 100644
--- a/src/llama_stack_client/lib/agents/client_tool.py
+++ b/src/llama_stack_client/lib/agents/client_tool.py
@@ -7,16 +7,7 @@
 import inspect
 import json
 from abc import abstractmethod
-from typing import (
-    Callable,
-    Dict,
-    get_args,
-    get_origin,
-    get_type_hints,
-    List,
-    TypeVar,
-    Union,
-)
+from typing import Callable, Dict, get_args, get_origin, get_type_hints, List, TypeVar, Union
 
 from llama_stack_client.types import Message, ToolResponseMessage
 from llama_stack_client.types.tool_def_param import Parameter, ToolDefParam
@@ -163,7 +154,7 @@ def get_params_definition(self) -> Dict[str, Parameter]:
             params[name] = Parameter(
                 name=name,
                 description=param_doc or f"Parameter {name}",
-                parameter_type=type_hint.__name__,
+                parameter_type=type_hint.__name__ if type_hint.__name__ != "str" else "string",
                 default=(param.default if param.default != inspect.Parameter.empty else None),
                 required=is_required,
             )

From ee49838df6651b42876bf9fa8b6f44a7b5c47060 Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Fri, 28 Feb 2025 12:48:47 -0800
Subject: [PATCH 2/6] update client handling
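
Decide per chunk instead of carrying an is_max_iter_reached flag across
the stream: "end_of_message" means the model paused to receive tool
responses, so the turn is resumed with the tool results; any other stop
reason ends the turn, so the chunk is yielded as-is and the tool call is
not executed.

A hypothetical helper (illustration, not part of the library) spelling
out the stop_reason semantics this change relies on:

    def should_resume_with_tool_results(stop_reason: str) -> bool:
        # "end_of_message": the model stopped to wait for tool responses,
        # so the client should run the tools and resume the turn.
        # "end_of_turn" (or anything else): the turn is finished.
        return stop_reason == "end_of_message"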
---
 src/llama_stack_client/lib/agents/agent.py | 11 +++--------
 1 file changed, 3 insertions(+), 8 deletions(-)

diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py
index 49bb4b0a..85cdde8c 100644
--- a/src/llama_stack_client/lib/agents/agent.py
+++ b/src/llama_stack_client/lib/agents/agent.py
@@ -162,7 +162,6 @@ def _create_turn_streaming(
 
         # 2. process turn and resume if there's a tool call
         is_turn_complete = False
-        is_max_iter_reached = False
         while not is_turn_complete:
             is_turn_complete = True
             for chunk in turn_response:
@@ -172,12 +171,12 @@
                 tool_calls = self._get_tool_calls(chunk)
                 if not tool_calls:
                     yield chunk
-                elif is_max_iter_reached:
-                    yield chunk
                 else:
                     is_turn_complete = False
+                    # End of turn is reached, do not resume even if there's a tool call
                     if chunk.event.payload.turn.output_message.stop_reason != "end_of_message":
-                        is_max_iter_reached = True
+                        yield chunk
+                        continue
 
                     turn_id = self._get_turn_id(chunk)
                     if n_iter == 0:
@@ -196,7 +195,3 @@
                         stream=True,
                     )
                     n_iter += 1
-
-            # # if max iter is reached, raise an error
-            # if is_max_iter_reached:
-            #     raise Exception("Max iteration reached")

From cd87d9fa8a06843a6dc55bf5c01417a2ff553fe0 Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Fri, 28 Feb 2025 12:51:59 -0800
Subject: [PATCH 3/6] remove unused stuff

---
 src/llama_stack_client/lib/agents/agent.py | 7 -------
 1 file changed, 7 deletions(-)

diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py
index 85cdde8c..ff4a4374 100644
--- a/src/llama_stack_client/lib/agents/agent.py
+++ b/src/llama_stack_client/lib/agents/agent.py
@@ -74,12 +74,6 @@ def _get_turn_id(self, chunk: AgentTurnResponseStreamChunk) -> Optional[str]:
 
         return chunk.event.payload.turn.turn_id
 
-    def _is_turn_complete(self, chunk: AgentTurnResponseStreamChunk) -> bool:
-        if chunk.event.payload.event_type not in ["turn_complete", "turn_awaiting_input"]:
-            return False
-
-        return chunk.event.payload.turn.output_message.stop_reason == "end_of_turn"
-
     def _run_tool(self, tool_calls: List[ToolCall]) -> ToolResponseMessage:
         assert len(tool_calls) == 1, "Only one tool call is supported"
         tool_call = tool_calls[0]
@@ -184,7 +178,6 @@
 
                     # run the tools
                     tool_response_message = self._run_tool(tool_calls)
-                    print("tool_response_message", tool_response_message)
 
                     # pass it to next iteration
                     turn_response = self.client.agents.turn.resume(

From 81b2bbe08a884d546a504c03b9ce96d1b08148cb Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Fri, 28 Feb 2025 12:52:51 -0800
Subject: [PATCH 4/6] remove unused stuff
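
Keep the "string" special case from the first patch, but say why:
OpenAI-style tool schemas (as consumed by litellm) use JSON Schema type
names, so a Python "str" annotation must be advertised as "string".

An illustrative mapping (assumed, not library code) of how Python
builtin type names relate to JSON Schema; the patch implements only the
"str" case:

    PY_TO_JSON_SCHEMA = {
        "str": "string",
        "int": "integer",
        "float": "number",
        "bool": "boolean",
        "list": "array",
        "dict": "object",
    }

    def json_schema_type_name(py_name: str) -> str:
        # Fall back to the Python name when there is no known mapping.
        return PY_TO_JSON_SCHEMA.get(py_name, py_name)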
---
 src/llama_stack_client/lib/agents/client_tool.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/llama_stack_client/lib/agents/client_tool.py b/src/llama_stack_client/lib/agents/client_tool.py
index 2ab6d13a..a1066616 100644
--- a/src/llama_stack_client/lib/agents/client_tool.py
+++ b/src/llama_stack_client/lib/agents/client_tool.py
@@ -154,6 +154,7 @@ def get_params_definition(self) -> Dict[str, Parameter]:
             params[name] = Parameter(
                 name=name,
                 description=param_doc or f"Parameter {name}",
+                # Hack: litellm/openai expects "string" for str type
                 parameter_type=type_hint.__name__ if type_hint.__name__ != "str" else "string",
                 default=(param.default if param.default != inspect.Parameter.empty else None),
                 required=is_required,

From b3744305214d61091553ffd06517cfa0244357e4 Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Sun, 2 Mar 2025 16:09:30 -0800
Subject: [PATCH 5/6] comments

---
 src/llama_stack_client/lib/agents/agent.py | 2 +-
 uv.lock                                    | 3 ++-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py
index ff4a4374..bb2b98e3 100644
--- a/src/llama_stack_client/lib/agents/agent.py
+++ b/src/llama_stack_client/lib/agents/agent.py
@@ -168,7 +168,7 @@ def _create_turn_streaming(
                 else:
                     is_turn_complete = False
                     # End of turn is reached, do not resume even if there's a tool call
-                    if chunk.event.payload.turn.output_message.stop_reason != "end_of_message":
+                    if chunk.event.payload.turn.output_message.stop_reason in {"end_of_turn"}:
                         yield chunk
                         continue
 
diff --git a/uv.lock b/uv.lock
index bdcd805c..abb3e6da 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1,4 +1,5 @@
 version = 1
+revision = 1
 requires-python = ">=3.7"
 resolution-markers = [
     "python_full_version >= '3.12'",
@@ -288,7 +289,7 @@ wheels = [
 
 [[package]]
 name = "llama-stack-client"
-version = "0.1.3"
+version = "0.1.4"
 source = { editable = "." }
 dependencies = [
     { name = "anyio", version = "3.7.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.8'" },

From 2a87391c9e288a798fdeb8df4dc465b866b2fdfc Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Mon, 3 Mar 2025 10:07:35 -0800
Subject: [PATCH 6/6] continue to break

---
 src/llama_stack_client/lib/agents/agent.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py
index bb2b98e3..e323cde2 100644
--- a/src/llama_stack_client/lib/agents/agent.py
+++ b/src/llama_stack_client/lib/agents/agent.py
@@ -170,7 +170,7 @@
                     # End of turn is reached, do not resume even if there's a tool call
                     if chunk.event.payload.turn.output_message.stop_reason in {"end_of_turn"}:
                         yield chunk
-                        continue
+                        break
 
                     turn_id = self._get_turn_id(chunk)
                     if n_iter == 0:
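
The series ends by swapping `continue` for `break`: once a turn ends
with stop_reason "end_of_turn", the chunk loop is exited immediately
instead of continuing to scan the finished turn's remaining chunks for
tool calls; any trailing tool call is left for the caller to handle.

An end-to-end sketch of driving the reworked loop (illustration only;
it assumes a running Llama Stack server at base_url, a valid model id,
and that dict-shaped messages are accepted as in the typed API):

    from llama_stack_client import LlamaStackClient
    from llama_stack_client.lib.agents.agent import Agent

    client = LlamaStackClient(base_url="http://localhost:8321")
    agent = Agent(
        client,
        agent_config={
            "model": "meta-llama/Llama-3.1-8B-Instruct",
            "instructions": "You are a helpful assistant.",
        },
    )
    session_id = agent.create_session("demo-session")
    for chunk in agent.create_turn(
        messages=[{"role": "user", "content": "What is 2 + 2?"}],
        session_id=session_id,
    ):
        print(chunk.event.payload.event_type)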