15 changes: 10 additions & 5 deletions src/llama_stack_client/lib/agents/agent.py
@@ -3,20 +3,20 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+import logging
 from typing import Iterator, List, Optional, Tuple, Union
 
 from llama_stack_client import LlamaStackClient
-import logging
 
 from llama_stack_client.types import ToolResponseMessage, ToolResponseParam, UserMessage
 from llama_stack_client.types.agent_create_params import AgentConfig
 from llama_stack_client.types.agents.turn import CompletionMessage, Turn
 from llama_stack_client.types.agents.turn_create_params import Document, Toolgroup
 from llama_stack_client.types.agents.turn_create_response import AgentTurnResponseStreamChunk
 from llama_stack_client.types.shared.tool_call import ToolCall
+from llama_stack_client.types.shared_params.agent_config import ToolConfig
 from llama_stack_client.types.shared_params.response_format import ResponseFormat
 from llama_stack_client.types.shared_params.sampling_params import SamplingParams
-from llama_stack_client.types.shared_params.agent_config import ToolConfig
 
 from .client_tool import ClientTool
 from .tool_parser import ToolParser
@@ -91,10 +91,10 @@ def __init__(
         # Add optional parameters if provided
         if enable_session_persistence is not None:
             agent_config["enable_session_persistence"] = enable_session_persistence
+        if input_shields is not None:
+            agent_config["input_shields"] = input_shields
         if max_infer_iters is not None:
             agent_config["max_infer_iters"] = max_infer_iters
-        if input_shields is not None:
-            agent_config["input_shields"] = input_shields
         if output_shields is not None:
             agent_config["output_shields"] = output_shields
         if response_format is not None:
Expand Down Expand Up @@ -254,7 +254,9 @@ def _create_turn_streaming(
else:
is_turn_complete = False
# End of turn is reached, do not resume even if there's a tool call
if chunk.event.payload.turn.output_message.stop_reason in {"end_of_turn"}:
# We only check for this if tool_parser is not set, because otherwise
# tool call will be parsed on client side, and server will always return "end_of_turn"
if not self.tool_parser and chunk.event.payload.turn.output_message.stop_reason in {"end_of_turn"}:
yield chunk
break

@@ -274,3 +276,6 @@ def _create_turn_streaming(
                 stream=True,
             )
             n_iter += 1
+
+        if self.tool_parser and n_iter > self.agent_config.get("max_infer_iters", DEFAULT_MAX_ITER):
+            raise Exception("Max inference iterations reached")
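
Taken together, the two hunks above change the turn-resume loop as follows: when no client-side `tool_parser` is set, the server's `"end_of_turn"` stop reason is final; when one is set, the server always reports `"end_of_turn"` (tool calls are parsed on the client), so the loop keeps resuming the turn and now needs its own iteration bound. A minimal, self-contained sketch of that control flow, with simplified stand-in names rather than the library's real types:

```python
# Sketch of the resume loop's shape after this PR (names are illustrative).
DEFAULT_MAX_ITER = 10


def run_streaming_turn(turn_stop_reasons, tool_parser=None, max_infer_iters=DEFAULT_MAX_ITER):
    n_iter = 0
    for stop_reason in turn_stop_reasons:  # one entry per completed server turn
        if tool_parser is None and stop_reason == "end_of_turn":
            break  # server-side parsing: "end_of_turn" is authoritative
        # Client-side parsing: tool calls are extracted locally and the turn
        # is resumed, so "end_of_turn" from the server cannot stop the loop.
        n_iter += 1

    if tool_parser is not None and n_iter > max_infer_iters:
        raise Exception("Max inference iterations reached")


# A runaway client-parsed turn now fails fast instead of looping forever:
try:
    run_streaming_turn(["end_of_turn"] * 20, tool_parser=object())
except Exception as e:
    print(e)  # Max inference iterations reached
```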
9 changes: 5 additions & 4 deletions src/llama_stack_client/lib/agents/react/tool_parser.py
@@ -4,13 +4,14 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import uuid
+from typing import List, Optional, Union
+
 from pydantic import BaseModel, ValidationError
-from typing import Optional, List, Union
-from ..tool_parser import ToolParser
 
 from llama_stack_client.types.shared.completion_message import CompletionMessage
 from llama_stack_client.types.shared.tool_call import ToolCall
 
-import uuid
+from ..tool_parser import ToolParser
 
 
 class Param(BaseModel):
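
The parser body below these imports is collapsed in this view. For orientation, a hedged sketch of the pattern the imports point at: validating the model's ReAct-formatted JSON with pydantic and minting a tool call with a fresh `uuid` call id. The `ReActOutput`/`Action` schema and the returned dict shape are illustrative assumptions, not necessarily this file's exact code:

```python
import uuid
from typing import List, Optional, Union

from pydantic import BaseModel, ValidationError


class Param(BaseModel):
    name: str
    value: Union[str, int, float, bool]


class Action(BaseModel):
    tool_name: str
    tool_params: List[Param]


class ReActOutput(BaseModel):  # assumed schema, for illustration only
    thought: str
    action: Optional[Action] = None
    answer: Optional[str] = None


def get_tool_calls(response_text: str) -> List[dict]:
    """Parse a ReAct completion into tool-call dicts (sketch, pydantic v2)."""
    try:
        output = ReActOutput.model_validate_json(response_text)
    except ValidationError:
        return []  # not valid ReAct JSON; treat as a plain answer
    if output.action is None:
        return []  # final answer, no tool to invoke
    return [{
        "call_id": str(uuid.uuid4()),  # client-generated, hence the uuid import
        "tool_name": output.action.tool_name,
        "arguments": {p.name: p.value for p in output.action.tool_params},
    }]
```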