diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py index 0a8ab226..5a375275 100644 --- a/src/llama_stack_client/lib/agents/agent.py +++ b/src/llama_stack_client/lib/agents/agent.py @@ -13,18 +13,10 @@ from llama_stack_client.types.agents.turn_create_response import ( AgentTurnResponseStreamChunk, ) -from llama_stack_client.types.agents.turn_response_event import TurnResponseEvent -from llama_stack_client.types.agents.turn_response_event_payload import ( - AgentTurnResponseStepCompletePayload, -) from llama_stack_client.types.shared.tool_call import ToolCall from llama_stack_client.types.agents.turn import CompletionMessage from .client_tool import ClientTool from .tool_parser import ToolParser -from datetime import datetime -import uuid -from llama_stack_client.types.tool_execution_step import ToolExecutionStep -from llama_stack_client.types.tool_response import ToolResponse DEFAULT_MAX_ITER = 10 @@ -65,7 +57,7 @@ def create_session(self, session_name: str) -> int: return self.session_id def _get_tool_calls(self, chunk: AgentTurnResponseStreamChunk) -> List[ToolCall]: - if chunk.event.payload.event_type != "turn_complete": + if chunk.event.payload.event_type not in {"turn_complete", "turn_awaiting_input"}: return [] message = chunk.event.payload.turn.output_message @@ -77,6 +69,12 @@ def _get_tool_calls(self, chunk: AgentTurnResponseStreamChunk) -> List[ToolCall] return message.tool_calls + def _get_turn_id(self, chunk: AgentTurnResponseStreamChunk) -> Optional[str]: + if chunk.event.payload.event_type not in ["turn_complete", "turn_awaiting_input"]: + return None + + return chunk.event.payload.turn.turn_id + def _run_tool(self, tool_calls: List[ToolCall]) -> ToolResponseMessage: assert len(tool_calls) == 1, "Only one tool call is supported" tool_call = tool_calls[0] @@ -131,27 +129,10 @@ def create_turn( if stream: return self._create_turn_streaming(messages, session_id, toolgroups, documents) else: - chunks = [] - for chunk in self._create_turn_streaming(messages, session_id, toolgroups, documents): - if chunk.event.payload.event_type == "turn_complete": - chunks.append(chunk) - pass + chunks = [x for x in self._create_turn_streaming(messages, session_id, toolgroups, documents)] if not chunks: raise Exception("Turn did not complete") - - # merge chunks - return Turn( - input_messages=chunks[0].event.payload.turn.input_messages, - output_message=chunks[-1].event.payload.turn.output_message, - session_id=chunks[0].event.payload.turn.session_id, - steps=[step for chunk in chunks for step in chunk.event.payload.turn.steps], - turn_id=chunks[0].event.payload.turn.turn_id, - started_at=chunks[0].event.payload.turn.started_at, - completed_at=chunks[-1].event.payload.turn.completed_at, - output_attachments=[ - attachment for chunk in chunks for attachment in chunk.event.payload.turn.output_attachments - ], - ) + return chunks[-1].event.payload.turn def _create_turn_streaming( self, @@ -160,22 +141,26 @@ def _create_turn_streaming( toolgroups: Optional[List[Toolgroup]] = None, documents: Optional[List[Document]] = None, ) -> Iterator[AgentTurnResponseStreamChunk]: - stop = False n_iter = 0 max_iter = self.agent_config.get("max_infer_iters", DEFAULT_MAX_ITER) - while not stop and n_iter < max_iter: - response = self.client.agents.turn.create( - agent_id=self.agent_id, - # use specified session_id or last session created - session_id=session_id or self.session_id[-1], - messages=messages, - stream=True, - documents=documents, - toolgroups=toolgroups, - ) - # by default, we stop after the first turn - stop = True - for chunk in response: + + # 1. create an agent turn + turn_response = self.client.agents.turn.create( + agent_id=self.agent_id, + # use specified session_id or last session created + session_id=session_id or self.session_id[-1], + messages=messages, + stream=True, + documents=documents, + toolgroups=toolgroups, + allow_turn_resume=True, + ) + + # 2. process turn and resume if there's a tool call + is_turn_complete = False + while not is_turn_complete: + is_turn_complete = True + for chunk in turn_response: tool_calls = self._get_tool_calls(chunk) if hasattr(chunk, "error"): yield chunk @@ -183,39 +168,23 @@ def _create_turn_streaming( elif not tool_calls: yield chunk else: - tool_execution_start_time = datetime.now() + is_turn_complete = False + turn_id = self._get_turn_id(chunk) + if n_iter == 0: + yield chunk + + # run the tools tool_response_message = self._run_tool(tool_calls) - tool_execution_step = ToolExecutionStep( - step_type="tool_execution", - step_id=str(uuid.uuid4()), - tool_calls=tool_calls, - tool_responses=[ - ToolResponse( - tool_name=tool_response_message.tool_name, - content=tool_response_message.content, - call_id=tool_response_message.call_id, - ) - ], - turn_id=chunk.event.payload.turn.turn_id, - completed_at=datetime.now(), - started_at=tool_execution_start_time, - ) - yield AgentTurnResponseStreamChunk( - event=TurnResponseEvent( - payload=AgentTurnResponseStepCompletePayload( - event_type="step_complete", - step_id=tool_execution_step.step_id, - step_type="tool_execution", - step_details=tool_execution_step, - ) - ) + # pass it to next iteration + turn_response = self.client.agents.turn.resume( + agent_id=self.agent_id, + session_id=session_id or self.session_id[-1], + turn_id=turn_id, + tool_responses=[tool_response_message], + stream=True, ) - - # HACK: append the tool execution step to the turn - chunk.event.payload.turn.steps.append(tool_execution_step) - yield chunk - - # continue the turn when there's a tool call - stop = False - messages = [tool_response_message] n_iter += 1 + break + + if n_iter >= max_iter: + raise Exception(f"Turn did not complete in {max_iter} iterations") diff --git a/src/llama_stack_client/lib/agents/event_logger.py b/src/llama_stack_client/lib/agents/event_logger.py index d7fa514a..40a1d359 100644 --- a/src/llama_stack_client/lib/agents/event_logger.py +++ b/src/llama_stack_client/lib/agents/event_logger.py @@ -75,7 +75,7 @@ def _yield_printable_events( event = chunk.event event_type = event.payload.event_type - if event_type in {"turn_start", "turn_complete"}: + if event_type in {"turn_start", "turn_complete", "turn_awaiting_input"}: # Currently not logging any turn realted info yield TurnStreamPrintableEvent(role=None, content="", end="", color="grey") return @@ -149,7 +149,9 @@ def _get_event_type_step_type(self, chunk: Any) -> Tuple[Optional[str], Optional if hasattr(chunk, "event"): previous_event_type = chunk.event.payload.event_type if hasattr(chunk, "event") else None previous_step_type = ( - chunk.event.payload.step_type if previous_event_type not in {"turn_start", "turn_complete"} else None + chunk.event.payload.step_type + if previous_event_type not in {"turn_start", "turn_complete", "turn_awaiting_input"} + else None ) return previous_event_type, previous_step_type return None, None