diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py index 6ac46878..eae95f2c 100644 --- a/src/llama_stack_client/lib/agents/agent.py +++ b/src/llama_stack_client/lib/agents/agent.py @@ -3,13 +3,16 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import List, Optional, Tuple, Union +from typing import Iterator, List, Optional, Tuple, Union from llama_stack_client import LlamaStackClient from llama_stack_client.types import ToolResponseMessage, UserMessage from llama_stack_client.types.agent_create_params import AgentConfig +from llama_stack_client.types.agents.turn import Turn from llama_stack_client.types.agents.turn_create_params import (Document, Toolgroup) +from llama_stack_client.types.agents.turn_create_response import \ + AgentTurnResponseStreamChunk from .client_tool import ClientTool @@ -46,7 +49,7 @@ def create_session(self, session_name: str) -> int: self.sessions.append(self.session_id) return self.session_id - def _has_tool_call(self, chunk): + def _has_tool_call(self, chunk: AgentTurnResponseStreamChunk) -> bool: if chunk.event.payload.event_type != "turn_complete": return False message = chunk.event.payload.turn.output_message @@ -54,7 +57,7 @@ def _has_tool_call(self, chunk): return False return len(message.tool_calls) > 0 - def _run_tool(self, chunk): + def _run_tool(self, chunk: AgentTurnResponseStreamChunk) -> ToolResponseMessage: message = chunk.event.payload.turn.output_message tool_call = message.tool_calls[0] if tool_call.tool_name not in self.client_tools: @@ -75,7 +78,28 @@ def create_turn( session_id: Optional[str] = None, toolgroups: Optional[List[Toolgroup]] = None, documents: Optional[List[Document]] = None, - ): + stream: bool = True, + ) -> Iterator[AgentTurnResponseStreamChunk] | Turn: + if stream: + return self._create_turn_streaming(messages, session_id, toolgroups, documents, stream) + else: + chunk = None + for chunk in self._create_turn_streaming(messages, session_id, toolgroups, documents, stream): + pass + if not chunk: + raise Exception("No chunk returned") + if chunk.event.payload.event_type != 'turn_complete': + raise Exception("Turn did not complete") + return chunk.event.payload.turn + + def _create_turn_streaming( + self, + messages: List[Union[UserMessage, ToolResponseMessage]], + session_id: Optional[str] = None, + toolgroups: Optional[List[Toolgroup]] = None, + documents: Optional[List[Document]] = None, + stream: bool = True, + ) -> Iterator[AgentTurnResponseStreamChunk]: stop = False n_iter = 0 max_iter = self.agent_config.get('max_infer_iters', DEFAULT_MAX_ITER)