diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py
index c40ef4c8..0a8ab226 100644
--- a/src/llama_stack_client/lib/agents/agent.py
+++ b/src/llama_stack_client/lib/agents/agent.py
@@ -10,11 +10,21 @@
 from llama_stack_client.types.agent_create_params import AgentConfig
 from llama_stack_client.types.agents.turn import Turn
 from llama_stack_client.types.agents.turn_create_params import Document, Toolgroup
-from llama_stack_client.types.agents.turn_create_response import AgentTurnResponseStreamChunk
+from llama_stack_client.types.agents.turn_create_response import (
+    AgentTurnResponseStreamChunk,
+)
+from llama_stack_client.types.agents.turn_response_event import TurnResponseEvent
+from llama_stack_client.types.agents.turn_response_event_payload import (
+    AgentTurnResponseStepCompletePayload,
+)
 from llama_stack_client.types.shared.tool_call import ToolCall
 from llama_stack_client.types.agents.turn import CompletionMessage
 
 from .client_tool import ClientTool
 from .tool_parser import ToolParser
+from datetime import datetime
+import uuid
+from llama_stack_client.types.tool_execution_step import ToolExecutionStep
+from llama_stack_client.types.tool_response import ToolResponse
 
 DEFAULT_MAX_ITER = 10
@@ -119,16 +129,29 @@ def create_turn(
         stream: bool = True,
     ) -> Iterator[AgentTurnResponseStreamChunk] | Turn:
         if stream:
-            return self._create_turn_streaming(messages, session_id, toolgroups, documents, stream)
+            return self._create_turn_streaming(messages, session_id, toolgroups, documents)
         else:
-            chunk = None
-            for chunk in self._create_turn_streaming(messages, session_id, toolgroups, documents, stream):
+            chunks = []
+            for chunk in self._create_turn_streaming(messages, session_id, toolgroups, documents):
+                if chunk.event.payload.event_type == "turn_complete":
+                    chunks.append(chunk)
                 pass
-            if not chunk:
-                raise Exception("No chunk returned")
-            if chunk.event.payload.event_type != "turn_complete":
+            if not chunks:
                 raise Exception("Turn did not complete")
-            return chunk.event.payload.turn
+
+            # merge chunks
+            return Turn(
+                input_messages=chunks[0].event.payload.turn.input_messages,
+                output_message=chunks[-1].event.payload.turn.output_message,
+                session_id=chunks[0].event.payload.turn.session_id,
+                steps=[step for chunk in chunks for step in chunk.event.payload.turn.steps],
+                turn_id=chunks[0].event.payload.turn.turn_id,
+                started_at=chunks[0].event.payload.turn.started_at,
+                completed_at=chunks[-1].event.payload.turn.completed_at,
+                output_attachments=[
+                    attachment for chunk in chunks for attachment in chunk.event.payload.turn.output_attachments
+                ],
+            )
 
     def _create_turn_streaming(
         self,
@@ -136,7 +159,6 @@ def _create_turn_streaming(
         session_id: Optional[str] = None,
         toolgroups: Optional[List[Toolgroup]] = None,
         documents: Optional[List[Document]] = None,
-        stream: bool = True,
     ) -> Iterator[AgentTurnResponseStreamChunk]:
         stop = False
         n_iter = 0
@@ -161,10 +183,39 @@
                 elif not tool_calls:
                     yield chunk
                 else:
-                    next_message = self._run_tool(tool_calls)
-                    yield next_message
+                    tool_execution_start_time = datetime.now()
+                    tool_response_message = self._run_tool(tool_calls)
+                    tool_execution_step = ToolExecutionStep(
+                        step_type="tool_execution",
+                        step_id=str(uuid.uuid4()),
+                        tool_calls=tool_calls,
+                        tool_responses=[
+                            ToolResponse(
+                                tool_name=tool_response_message.tool_name,
+                                content=tool_response_message.content,
+                                call_id=tool_response_message.call_id,
+                            )
+                        ],
+                        turn_id=chunk.event.payload.turn.turn_id,
+                        completed_at=datetime.now(),
+                        started_at=tool_execution_start_time,
+                    )
+                    yield AgentTurnResponseStreamChunk(
+                        event=TurnResponseEvent(
+                            payload=AgentTurnResponseStepCompletePayload(
+                                event_type="step_complete",
+                                step_id=tool_execution_step.step_id,
+                                step_type="tool_execution",
+                                step_details=tool_execution_step,
+                            )
+                        )
+                    )
+
+                    # HACK: append the tool execution step to the turn
+                    chunk.event.payload.turn.steps.append(tool_execution_step)
+                    yield chunk
 
                     # continue the turn when there's a tool call
                     stop = False
-                    messages = [next_message]
+                    messages = [tool_response_message]
             n_iter += 1
diff --git a/src/llama_stack_client/lib/agents/event_logger.py b/src/llama_stack_client/lib/agents/event_logger.py
index dff81994..fbf627f2 100644
--- a/src/llama_stack_client/lib/agents/event_logger.py
+++ b/src/llama_stack_client/lib/agents/event_logger.py
@@ -8,7 +8,7 @@
 
 from termcolor import cprint
 
-from llama_stack_client.types import InterleavedContent, ToolResponseMessage
+from llama_stack_client.types import InterleavedContent
 
 
 def interleaved_content_as_str(content: InterleavedContent, sep: str = " ") -> str:
@@ -70,14 +70,6 @@ def _yield_printable_events(self, chunk, previous_event_type=None, previous_step
             yield TurnStreamPrintableEvent(role=None, content=chunk.error["message"], color="red")
             return
 
-        if not hasattr(chunk, "event"):
-            # Need to check for custom tool first
-            # since it does not produce event but instead
-            # a Message
-            if isinstance(chunk, ToolResponseMessage):
-                yield TurnStreamPrintableEvent(role="CustomTool", content=chunk.content, color="green")
-                return
-
         event = chunk.event
         event_type = event.payload.event_type
 
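For reference, a minimal usage sketch (not part of the patch) of the reworked non-streaming path: with stream=False, create_turn now collects every turn_complete chunk and merges them into one Turn, so the synthetic client-side tool_execution step appears in turn.steps. The server URL, model name, and agent config below are assumptions for illustration; Agent, create_session, and create_turn are the APIs touched by this change.

from llama_stack_client import LlamaStackClient
from llama_stack_client.lib.agents.agent import Agent

# Assumed local llama-stack server and model; adjust for your deployment.
client = LlamaStackClient(base_url="http://localhost:8321")
agent = Agent(
    client,
    agent_config={
        "model": "meta-llama/Llama-3.1-8B-Instruct",
        "instructions": "You are a helpful assistant.",
        "enable_session_persistence": False,
    },
)
session_id = agent.create_session("demo-session")

# Non-streaming: a single merged Turn is returned even when a client tool
# call forces the turn to continue for another inference iteration.
turn = agent.create_turn(
    messages=[{"role": "user", "content": "Hello"}],
    session_id=session_id,
    stream=False,
)
for step in turn.steps:
    print(step.step_type)  # includes "tool_execution" when a client tool ran
print(turn.output_message.content)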