1313from llama_stack_client .types .agents .turn_create_response import (
1414 AgentTurnResponseStreamChunk ,
1515)
16- from llama_stack_client .types .agents .turn_response_event import TurnResponseEvent
17- from llama_stack_client .types .agents .turn_response_event_payload import (
18- AgentTurnResponseStepCompletePayload ,
19- )
2016from llama_stack_client .types .shared .tool_call import ToolCall
2117from llama_stack_client .types .agents .turn import CompletionMessage
2218from .client_tool import ClientTool
2319from .tool_parser import ToolParser
24- from datetime import datetime
25- import uuid
26- from llama_stack_client .types .tool_execution_step import ToolExecutionStep
27- from llama_stack_client .types .tool_response import ToolResponse
2820
2921DEFAULT_MAX_ITER = 10
3022
@@ -65,7 +57,7 @@ def create_session(self, session_name: str) -> int:
6557 return self .session_id
6658
6759 def _get_tool_calls (self , chunk : AgentTurnResponseStreamChunk ) -> List [ToolCall ]:
68- if chunk .event .payload .event_type != "turn_complete" :
60+ if chunk .event .payload .event_type not in { "turn_complete" , "turn_awaiting_input" } :
6961 return []
7062
7163 message = chunk .event .payload .turn .output_message
@@ -77,6 +69,12 @@ def _get_tool_calls(self, chunk: AgentTurnResponseStreamChunk) -> List[ToolCall]
7769
7870 return message .tool_calls
7971
72+ def _get_turn_id (self , chunk : AgentTurnResponseStreamChunk ) -> Optional [str ]:
73+ if chunk .event .payload .event_type not in ["turn_complete" , "turn_awaiting_input" ]:
74+ return None
75+
76+ return chunk .event .payload .turn .turn_id
77+
8078 def _run_tool (self , tool_calls : List [ToolCall ]) -> ToolResponseMessage :
8179 assert len (tool_calls ) == 1 , "Only one tool call is supported"
8280 tool_call = tool_calls [0 ]
@@ -131,27 +129,10 @@ def create_turn(
131129 if stream :
132130 return self ._create_turn_streaming (messages , session_id , toolgroups , documents )
133131 else :
134- chunks = []
135- for chunk in self ._create_turn_streaming (messages , session_id , toolgroups , documents ):
136- if chunk .event .payload .event_type == "turn_complete" :
137- chunks .append (chunk )
138- pass
132+ chunks = [x for x in self ._create_turn_streaming (messages , session_id , toolgroups , documents )]
139133 if not chunks :
140134 raise Exception ("Turn did not complete" )
141-
142- # merge chunks
143- return Turn (
144- input_messages = chunks [0 ].event .payload .turn .input_messages ,
145- output_message = chunks [- 1 ].event .payload .turn .output_message ,
146- session_id = chunks [0 ].event .payload .turn .session_id ,
147- steps = [step for chunk in chunks for step in chunk .event .payload .turn .steps ],
148- turn_id = chunks [0 ].event .payload .turn .turn_id ,
149- started_at = chunks [0 ].event .payload .turn .started_at ,
150- completed_at = chunks [- 1 ].event .payload .turn .completed_at ,
151- output_attachments = [
152- attachment for chunk in chunks for attachment in chunk .event .payload .turn .output_attachments
153- ],
154- )
135+ return chunks [- 1 ].event .payload .turn
155136
156137 def _create_turn_streaming (
157138 self ,
@@ -160,62 +141,50 @@ def _create_turn_streaming(
160141 toolgroups : Optional [List [Toolgroup ]] = None ,
161142 documents : Optional [List [Document ]] = None ,
162143 ) -> Iterator [AgentTurnResponseStreamChunk ]:
163- stop = False
164144 n_iter = 0
165145 max_iter = self .agent_config .get ("max_infer_iters" , DEFAULT_MAX_ITER )
166- while not stop and n_iter < max_iter :
167- response = self .client .agents .turn .create (
168- agent_id = self .agent_id ,
169- # use specified session_id or last session created
170- session_id = session_id or self .session_id [- 1 ],
171- messages = messages ,
172- stream = True ,
173- documents = documents ,
174- toolgroups = toolgroups ,
175- )
176- # by default, we stop after the first turn
177- stop = True
178- for chunk in response :
146+
147+ # 1. create an agent turn
148+ turn_response = self .client .agents .turn .create (
149+ agent_id = self .agent_id ,
150+ # use specified session_id or last session created
151+ session_id = session_id or self .session_id [- 1 ],
152+ messages = messages ,
153+ stream = True ,
154+ documents = documents ,
155+ toolgroups = toolgroups ,
156+ allow_turn_resume = True ,
157+ )
158+
159+ # 2. process turn and resume if there's a tool call
160+ is_turn_complete = False
161+ while not is_turn_complete :
162+ is_turn_complete = True
163+ for chunk in turn_response :
179164 tool_calls = self ._get_tool_calls (chunk )
180165 if hasattr (chunk , "error" ):
181166 yield chunk
182167 return
183168 elif not tool_calls :
184169 yield chunk
185170 else :
186- tool_execution_start_time = datetime .now ()
171+ is_turn_complete = False
172+ turn_id = self ._get_turn_id (chunk )
173+ if n_iter == 0 :
174+ yield chunk
175+
176+ # run the tools
187177 tool_response_message = self ._run_tool (tool_calls )
188- tool_execution_step = ToolExecutionStep (
189- step_type = "tool_execution" ,
190- step_id = str (uuid .uuid4 ()),
191- tool_calls = tool_calls ,
192- tool_responses = [
193- ToolResponse (
194- tool_name = tool_response_message .tool_name ,
195- content = tool_response_message .content ,
196- call_id = tool_response_message .call_id ,
197- )
198- ],
199- turn_id = chunk .event .payload .turn .turn_id ,
200- completed_at = datetime .now (),
201- started_at = tool_execution_start_time ,
202- )
203- yield AgentTurnResponseStreamChunk (
204- event = TurnResponseEvent (
205- payload = AgentTurnResponseStepCompletePayload (
206- event_type = "step_complete" ,
207- step_id = tool_execution_step .step_id ,
208- step_type = "tool_execution" ,
209- step_details = tool_execution_step ,
210- )
211- )
178+ # pass it to next iteration
179+ turn_response = self .client .agents .turn .resume (
180+ agent_id = self .agent_id ,
181+ session_id = session_id or self .session_id [- 1 ],
182+ turn_id = turn_id ,
183+ tool_responses = [tool_response_message ],
184+ stream = True ,
212185 )
213-
214- # HACK: append the tool execution step to the turn
215- chunk .event .payload .turn .steps .append (tool_execution_step )
216- yield chunk
217-
218- # continue the turn when there's a tool call
219- stop = False
220- messages = [tool_response_message ]
221186 n_iter += 1
187+ break
188+
189+ if n_iter >= max_iter :
190+ raise Exception (f"Turn did not complete in { max_iter } iterations" )
0 commit comments