diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py
index 90356b0e..19c81d85 100644
--- a/src/llama_stack_client/lib/agents/agent.py
+++ b/src/llama_stack_client/lib/agents/agent.py
@@ -3,10 +3,10 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+import logging
 from typing import Iterator, List, Optional, Tuple, Union
 
 from llama_stack_client import LlamaStackClient
-import logging
 
 from llama_stack_client.types import ToolResponseMessage, ToolResponseParam, UserMessage
 from llama_stack_client.types.agent_create_params import AgentConfig
@@ -14,9 +14,9 @@
 from llama_stack_client.types.agents.turn_create_params import Document, Toolgroup
 from llama_stack_client.types.agents.turn_create_response import AgentTurnResponseStreamChunk
 from llama_stack_client.types.shared.tool_call import ToolCall
+from llama_stack_client.types.shared_params.agent_config import ToolConfig
 from llama_stack_client.types.shared_params.response_format import ResponseFormat
 from llama_stack_client.types.shared_params.sampling_params import SamplingParams
-from llama_stack_client.types.shared_params.agent_config import ToolConfig
 
 from .client_tool import ClientTool
 from .tool_parser import ToolParser
@@ -91,10 +91,10 @@ def __init__(
         # Add optional parameters if provided
         if enable_session_persistence is not None:
             agent_config["enable_session_persistence"] = enable_session_persistence
-        if input_shields is not None:
-            agent_config["input_shields"] = input_shields
         if max_infer_iters is not None:
             agent_config["max_infer_iters"] = max_infer_iters
+        if input_shields is not None:
+            agent_config["input_shields"] = input_shields
         if output_shields is not None:
             agent_config["output_shields"] = output_shields
         if response_format is not None:
@@ -254,7 +254,9 @@ def _create_turn_streaming(
             else:
                 is_turn_complete = False
             # End of turn is reached, do not resume even if there's a tool call
-            if chunk.event.payload.turn.output_message.stop_reason in {"end_of_turn"}:
+            # We only check for this if tool_parser is not set, because otherwise the
+            # tool call will be parsed client-side and the server will always return "end_of_turn"
+            if not self.tool_parser and chunk.event.payload.turn.output_message.stop_reason in {"end_of_turn"}:
                 yield chunk
                 break
 
@@ -274,3 +276,6 @@ def _create_turn_streaming(
                 stream=True,
             )
             n_iter += 1
+
+        if self.tool_parser and n_iter > self.agent_config.get("max_infer_iters", DEFAULT_MAX_ITER):
+            raise Exception("Max inference iterations reached")
diff --git a/src/llama_stack_client/lib/agents/react/tool_parser.py b/src/llama_stack_client/lib/agents/react/tool_parser.py
index 9aaa30cc..0dfcfe48 100644
--- a/src/llama_stack_client/lib/agents/react/tool_parser.py
+++ b/src/llama_stack_client/lib/agents/react/tool_parser.py
@@ -4,13 +4,14 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import uuid
+from typing import List, Optional, Union
+
 from pydantic import BaseModel, ValidationError
-from typing import Optional, List, Union
-from ..tool_parser import ToolParser
+
 from llama_stack_client.types.shared.completion_message import CompletionMessage
 from llama_stack_client.types.shared.tool_call import ToolCall
-
-import uuid
+from ..tool_parser import ToolParser
 
 
 class Param(BaseModel):
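
Note on the agent.py change: with a client-side `tool_parser` configured, the server never sees the tool call embedded in the model's text output, so it always reports `stop_reason == "end_of_turn"`. The streaming loop therefore ignores that signal when a parser is set and instead relies on the new client-side `max_infer_iters` guard to stop resuming the turn. As an illustration of the kind of parser this enables, here is a minimal sketch; it assumes the `ToolParser` base class (from `lib/agents/tool_parser.py`) exposes a `get_tool_calls(output_message)` hook, as the ReAct parser's imports above suggest, and the `JSONToolParser` name and JSON format are hypothetical:

```python
import json
import uuid
from typing import List

from llama_stack_client.lib.agents.tool_parser import ToolParser
from llama_stack_client.types.shared.completion_message import CompletionMessage
from llama_stack_client.types.shared.tool_call import ToolCall


class JSONToolParser(ToolParser):
    """Hypothetical parser: pulls a tool call out of model output such as
    {"tool": "get_weather", "args": {"city": "Paris"}}."""

    def get_tool_calls(self, output_message: CompletionMessage) -> List[ToolCall]:
        try:
            payload = json.loads(output_message.content)
        except (TypeError, ValueError):
            return []  # plain text: no tool call, the turn really is over
        if "tool" not in payload:
            return []
        # The server only saw free-form text (hence stop_reason "end_of_turn"),
        # so the client mints its own call_id before dispatching the tool.
        return [
            ToolCall(
                call_id=str(uuid.uuid4()),
                tool_name=payload["tool"],
                arguments=payload.get("args", {}),
            )
        ]
```

Because such a parser can keep producing tool calls on every iteration, the `n_iter > max_infer_iters` check added to `_create_turn_streaming` is what prevents the turn from resuming indefinitely.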