diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py
index 504ef5e8..e323cde2 100644
--- a/src/llama_stack_client/lib/agents/agent.py
+++ b/src/llama_stack_client/lib/agents/agent.py
@@ -11,9 +11,7 @@
 from llama_stack_client.types.agent_create_params import AgentConfig
 from llama_stack_client.types.agents.turn import CompletionMessage, Turn
 from llama_stack_client.types.agents.turn_create_params import Document, Toolgroup
-from llama_stack_client.types.agents.turn_create_response import (
-    AgentTurnResponseStreamChunk,
-)
+from llama_stack_client.types.agents.turn_create_response import AgentTurnResponseStreamChunk
 from llama_stack_client.types.shared.tool_call import ToolCall
 
 from .client_tool import ClientTool
@@ -143,7 +141,6 @@ def _create_turn_streaming(
         documents: Optional[List[Document]] = None,
     ) -> Iterator[AgentTurnResponseStreamChunk]:
         n_iter = 0
-        max_iter = self.agent_config.get("max_infer_iters", DEFAULT_MAX_ITER)
 
         # 1. create an agent turn
         turn_response = self.client.agents.turn.create(
@@ -170,12 +167,18 @@
                     yield chunk
                 else:
                     is_turn_complete = False
+                    # End of turn is reached, do not resume even if there's a tool call
+                    if chunk.event.payload.turn.output_message.stop_reason in {"end_of_turn"}:
+                        yield chunk
+                        break
+
                     turn_id = self._get_turn_id(chunk)
                     if n_iter == 0:
                         yield chunk
 
                     # run the tools
                     tool_response_message = self._run_tool(tool_calls)
+
                     # pass it to next iteration
                     turn_response = self.client.agents.turn.resume(
                         agent_id=self.agent_id,
@@ -185,7 +188,3 @@
                         stream=True,
                     )
                     n_iter += 1
-                break
-
-        if n_iter >= max_iter:
-            raise Exception(f"Turn did not complete in {max_iter} iterations")
diff --git a/src/llama_stack_client/lib/agents/client_tool.py b/src/llama_stack_client/lib/agents/client_tool.py
index 6fc811e9..a1066616 100644
--- a/src/llama_stack_client/lib/agents/client_tool.py
+++ b/src/llama_stack_client/lib/agents/client_tool.py
@@ -7,16 +7,7 @@
 import inspect
 import json
 from abc import abstractmethod
-from typing import (
-    Callable,
-    Dict,
-    get_args,
-    get_origin,
-    get_type_hints,
-    List,
-    TypeVar,
-    Union,
-)
+from typing import Callable, Dict, get_args, get_origin, get_type_hints, List, TypeVar, Union
 
 from llama_stack_client.types import Message, ToolResponseMessage
 from llama_stack_client.types.tool_def_param import Parameter, ToolDefParam
diff --git a/uv.lock b/uv.lock
index bdcd805c..abb3e6da 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1,4 +1,5 @@
 version = 1
+revision = 1
 requires-python = ">=3.7"
 resolution-markers = [
     "python_full_version >= '3.12'",
@@ -288,7 +289,7 @@
 
 [[package]]
 name = "llama-stack-client"
-version = "0.1.3"
+version = "0.1.4"
 source = { editable = "." }
 dependencies = [
     { name = "anyio", version = "3.7.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.8'" },
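
For reference, the control flow this patch leaves in `_create_turn_streaming` can be sketched as below. This is a minimal, simplified illustration, not the library's verbatim code: `Chunk`, `stream_turn`, `resume_turn`, and `run_tools` are hypothetical stand-ins, whereas the real method inspects `chunk.event.payload.turn.output_message` and calls `client.agents.turn.create`/`resume` as shown in the diff.

```python
from dataclasses import dataclass, field
from typing import Callable, Iterator, List, Optional


@dataclass
class Chunk:
    """Hypothetical stand-in for AgentTurnResponseStreamChunk."""
    stop_reason: Optional[str] = None      # e.g. "end_of_turn" or "end_of_message"
    tool_calls: List[dict] = field(default_factory=list)


def stream_with_tool_resume(
    stream_turn: Callable[[], Iterator[Chunk]],      # starts the turn (hypothetical)
    resume_turn: Callable[[dict], Iterator[Chunk]],  # resumes it with tool output (hypothetical)
    run_tools: Callable[[List[dict]], dict],         # executes client tools (hypothetical)
) -> Iterator[Chunk]:
    """Mirror of the loop after this patch: resume on tool calls, but stop
    resuming once stop_reason is "end_of_turn"."""
    turn_response = stream_turn()
    is_turn_complete = False
    while not is_turn_complete:
        is_turn_complete = True
        for chunk in turn_response:
            if not chunk.tool_calls:
                yield chunk
            else:
                is_turn_complete = False
                # End of turn is reached: yield and stop, even though a tool call is present
                if chunk.stop_reason == "end_of_turn":
                    is_turn_complete = True
                    yield chunk
                    break
                yield chunk
                # Run the tools, then resume the same turn with their responses;
                # the while loop picks up the new stream on its next pass
                tool_response = run_tools(chunk.tool_calls)
                turn_response = resume_turn(tool_response)
                break
```

Note that the removed `max_infer_iters` check means a turn that keeps issuing tool calls is no longer cut off client-side; any iteration limit is left to the server.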