33#
44# This source code is licensed under the terms described in the LICENSE file in
55# the root directory of this source tree.
6- from typing import List , Optional , Tuple , Union
6+ from typing import Iterator , List , Optional , Tuple , Union
77
88from llama_stack_client import LlamaStackClient
99from llama_stack_client .types import ToolResponseMessage , UserMessage
1010from llama_stack_client .types .agent_create_params import AgentConfig
11+ from llama_stack_client .types .agents .turn import Turn
1112from llama_stack_client .types .agents .turn_create_params import (Document ,
1213 Toolgroup )
14+ from llama_stack_client .types .agents .turn_create_response import \
15+ AgentTurnResponseStreamChunk
1316
1417from .client_tool import ClientTool
1518
@@ -46,15 +49,15 @@ def create_session(self, session_name: str) -> int:
4649 self .sessions .append (self .session_id )
4750 return self .session_id
4851
49- def _has_tool_call (self , chunk ) :
52+ def _has_tool_call (self , chunk : AgentTurnResponseStreamChunk ) -> bool :
5053 if chunk .event .payload .event_type != "turn_complete" :
5154 return False
5255 message = chunk .event .payload .turn .output_message
5356 if message .stop_reason == "out_of_tokens" :
5457 return False
5558 return len (message .tool_calls ) > 0
5659
57- def _run_tool (self , chunk ) :
60+ def _run_tool (self , chunk : AgentTurnResponseStreamChunk ) -> ToolResponseMessage :
5861 message = chunk .event .payload .turn .output_message
5962 tool_call = message .tool_calls [0 ]
6063 if tool_call .tool_name not in self .client_tools :
@@ -75,7 +78,28 @@ def create_turn(
7578 session_id : Optional [str ] = None ,
7679 toolgroups : Optional [List [Toolgroup ]] = None ,
7780 documents : Optional [List [Document ]] = None ,
78- ):
81+ stream : bool = True ,
82+ ) -> Iterator [AgentTurnResponseStreamChunk ] | Turn :
83+ if stream :
84+ return self ._create_turn_streaming (messages , session_id , toolgroups , documents , stream )
85+ else :
86+ chunk = None
87+ for chunk in self ._create_turn_streaming (messages , session_id , toolgroups , documents , stream ):
88+ pass
89+ if not chunk :
90+ raise Exception ("No chunk returned" )
91+ if chunk .event .payload .event_type != 'turn_complete' :
92+ raise Exception ("Turn did not complete" )
93+ return chunk .event .payload .turn
94+
95+ def _create_turn_streaming (
96+ self ,
97+ messages : List [Union [UserMessage , ToolResponseMessage ]],
98+ session_id : Optional [str ] = None ,
99+ toolgroups : Optional [List [Toolgroup ]] = None ,
100+ documents : Optional [List [Document ]] = None ,
101+ stream : bool = True ,
102+ ) -> Iterator [AgentTurnResponseStreamChunk ]:
79103 stop = False
80104 n_iter = 0
81105 max_iter = self .agent_config .get ('max_infer_iters' , DEFAULT_MAX_ITER )
0 commit comments