33#
44# This source code is licensed under the terms described in the LICENSE file in
55# the root directory of this source tree.
6- import uuid
7- from datetime import datetime
86from typing import Iterator , List , Optional , Tuple , Union
97
108from llama_stack_client import LlamaStackClient
1614from llama_stack_client .types .agents .turn_create_response import (
1715 AgentTurnResponseStreamChunk ,
1816)
19- from llama_stack_client .types .agents .turn_response_event import TurnResponseEvent
20- from llama_stack_client .types .agents .turn_response_event_payload import (
21- AgentTurnResponseStepCompletePayload ,
22- )
2317from llama_stack_client .types .shared .tool_call import ToolCall
24- from llama_stack_client .types .tool_execution_step import ToolExecutionStep
25- from llama_stack_client .types .tool_response import ToolResponse
2618
2719from .client_tool import ClientTool
2820from .tool_parser import ToolParser
@@ -66,7 +58,7 @@ def create_session(self, session_name: str) -> str:
6658 return self .session_id
6759
6860 def _get_tool_calls (self , chunk : AgentTurnResponseStreamChunk ) -> List [ToolCall ]:
69- if chunk .event .payload .event_type != "turn_complete" :
61+ if chunk .event .payload .event_type not in { "turn_complete" , "turn_awaiting_input" } :
7062 return []
7163
7264 message = chunk .event .payload .turn .output_message
@@ -78,6 +70,12 @@ def _get_tool_calls(self, chunk: AgentTurnResponseStreamChunk) -> List[ToolCall]
7870
7971 return message .tool_calls
8072
73+ def _get_turn_id (self , chunk : AgentTurnResponseStreamChunk ) -> Optional [str ]:
74+ if chunk .event .payload .event_type not in ["turn_complete" , "turn_awaiting_input" ]:
75+ return None
76+
77+ return chunk .event .payload .turn .turn_id
78+
8179 def _run_tool (self , tool_calls : List [ToolCall ]) -> ToolResponseMessage :
8280 assert len (tool_calls ) == 1 , "Only one tool call is supported"
8381 tool_call = tool_calls [0 ]
@@ -132,27 +130,10 @@ def create_turn(
132130 if stream :
133131 return self ._create_turn_streaming (messages , session_id , toolgroups , documents )
134132 else :
135- chunks = []
136- for chunk in self ._create_turn_streaming (messages , session_id , toolgroups , documents ):
137- if chunk .event .payload .event_type == "turn_complete" :
138- chunks .append (chunk )
139- pass
133+ chunks = [x for x in self ._create_turn_streaming (messages , session_id , toolgroups , documents )]
140134 if not chunks :
141135 raise Exception ("Turn did not complete" )
142-
143- # merge chunks
144- return Turn (
145- input_messages = chunks [0 ].event .payload .turn .input_messages ,
146- output_message = chunks [- 1 ].event .payload .turn .output_message ,
147- session_id = chunks [0 ].event .payload .turn .session_id ,
148- steps = [step for chunk in chunks for step in chunk .event .payload .turn .steps ],
149- turn_id = chunks [0 ].event .payload .turn .turn_id ,
150- started_at = chunks [0 ].event .payload .turn .started_at ,
151- completed_at = chunks [- 1 ].event .payload .turn .completed_at ,
152- output_attachments = [
153- attachment for chunk in chunks for attachment in chunk .event .payload .turn .output_attachments
154- ],
155- )
136+ return chunks [- 1 ].event .payload .turn
156137
157138 def _create_turn_streaming (
158139 self ,
@@ -161,62 +142,50 @@ def _create_turn_streaming(
161142 toolgroups : Optional [List [Toolgroup ]] = None ,
162143 documents : Optional [List [Document ]] = None ,
163144 ) -> Iterator [AgentTurnResponseStreamChunk ]:
164- stop = False
165145 n_iter = 0
166146 max_iter = self .agent_config .get ("max_infer_iters" , DEFAULT_MAX_ITER )
167- while not stop and n_iter < max_iter :
168- response = self .client .agents .turn .create (
169- agent_id = self .agent_id ,
170- # use specified session_id or last session created
171- session_id = session_id or self .session_id [- 1 ],
172- messages = messages ,
173- stream = True ,
174- documents = documents ,
175- toolgroups = toolgroups ,
176- )
177- # by default, we stop after the first turn
178- stop = True
179- for chunk in response :
147+
148+ # 1. create an agent turn
149+ turn_response = self .client .agents .turn .create (
150+ agent_id = self .agent_id ,
151+ # use specified session_id or last session created
152+ session_id = session_id or self .session_id [- 1 ],
153+ messages = messages ,
154+ stream = True ,
155+ documents = documents ,
156+ toolgroups = toolgroups ,
157+ allow_turn_resume = True ,
158+ )
159+
160+ # 2. process turn and resume if there's a tool call
161+ is_turn_complete = False
162+ while not is_turn_complete :
163+ is_turn_complete = True
164+ for chunk in turn_response :
180165 tool_calls = self ._get_tool_calls (chunk )
181166 if hasattr (chunk , "error" ):
182167 yield chunk
183168 return
184169 elif not tool_calls :
185170 yield chunk
186171 else :
187- tool_execution_start_time = datetime .now ()
172+ is_turn_complete = False
173+ turn_id = self ._get_turn_id (chunk )
174+ if n_iter == 0 :
175+ yield chunk
176+
177+ # run the tools
188178 tool_response_message = self ._run_tool (tool_calls )
189- tool_execution_step = ToolExecutionStep (
190- step_type = "tool_execution" ,
191- step_id = str (uuid .uuid4 ()),
192- tool_calls = tool_calls ,
193- tool_responses = [
194- ToolResponse (
195- tool_name = tool_response_message .tool_name ,
196- content = tool_response_message .content ,
197- call_id = tool_response_message .call_id ,
198- )
199- ],
200- turn_id = chunk .event .payload .turn .turn_id ,
201- completed_at = datetime .now (),
202- started_at = tool_execution_start_time ,
203- )
204- yield AgentTurnResponseStreamChunk (
205- event = TurnResponseEvent (
206- payload = AgentTurnResponseStepCompletePayload (
207- event_type = "step_complete" ,
208- step_id = tool_execution_step .step_id ,
209- step_type = "tool_execution" ,
210- step_details = tool_execution_step ,
211- )
212- )
179+ # pass it to next iteration
180+ turn_response = self .client .agents .turn .resume (
181+ agent_id = self .agent_id ,
182+ session_id = session_id or self .session_id [- 1 ],
183+ turn_id = turn_id ,
184+ tool_responses = [tool_response_message ],
185+ stream = True ,
213186 )
214-
215- # HACK: append the tool execution step to the turn
216- chunk .event .payload .turn .steps .append (tool_execution_step )
217- yield chunk
218-
219- # continue the turn when there's a tool call
220- stop = False
221- messages = [tool_response_message ]
222187 n_iter += 1
188+ break
189+
190+ if n_iter >= max_iter :
191+ raise Exception (f"Turn did not complete in { max_iter } iterations" )
0 commit comments