Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions src/llama_stack_client/lib/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,7 @@
from llama_stack_client.types.agent_create_params import AgentConfig
from llama_stack_client.types.agents.turn import CompletionMessage, Turn
from llama_stack_client.types.agents.turn_create_params import Document, Toolgroup
from llama_stack_client.types.agents.turn_create_response import (
AgentTurnResponseStreamChunk,
)
from llama_stack_client.types.agents.turn_create_response import AgentTurnResponseStreamChunk
from llama_stack_client.types.shared.tool_call import ToolCall

from .client_tool import ClientTool
Expand Down Expand Up @@ -143,7 +141,6 @@ def _create_turn_streaming(
documents: Optional[List[Document]] = None,
) -> Iterator[AgentTurnResponseStreamChunk]:
n_iter = 0
max_iter = self.agent_config.get("max_infer_iters", DEFAULT_MAX_ITER)

# 1. create an agent turn
turn_response = self.client.agents.turn.create(
Expand All @@ -170,12 +167,18 @@ def _create_turn_streaming(
yield chunk
else:
is_turn_complete = False
# End of turn is reached, do not resume even if there's a tool call
if chunk.event.payload.turn.output_message.stop_reason in {"end_of_turn"}:
yield chunk
break

turn_id = self._get_turn_id(chunk)
if n_iter == 0:
yield chunk

# run the tools
tool_response_message = self._run_tool(tool_calls)

# pass it to next iteration
turn_response = self.client.agents.turn.resume(
agent_id=self.agent_id,
Expand All @@ -185,7 +188,3 @@ def _create_turn_streaming(
stream=True,
)
n_iter += 1
break

if n_iter >= max_iter:
raise Exception(f"Turn did not complete in {max_iter} iterations")
11 changes: 1 addition & 10 deletions src/llama_stack_client/lib/agents/client_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,7 @@
import inspect
import json
from abc import abstractmethod
from typing import (
Callable,
Dict,
get_args,
get_origin,
get_type_hints,
List,
TypeVar,
Union,
)
from typing import Callable, Dict, get_args, get_origin, get_type_hints, List, TypeVar, Union

from llama_stack_client.types import Message, ToolResponseMessage
from llama_stack_client.types.tool_def_param import Parameter, ToolDefParam
Expand Down
3 changes: 2 additions & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.