From c9315efdb3fac6f09e2bf445b0dad3017b2c3d65 Mon Sep 17 00:00:00 2001 From: "Eric Huang (AI Platform)" Date: Thu, 30 Jan 2025 11:47:42 -0800 Subject: [PATCH] non-streaming support for agent create turn # What does this PR do? Adds non-streaming support for agent.create_turn, which returns the turn object, and the user can access the output_message like so: ``` response = rag_agent.create_turn( messages=[{"role": "user", "content": prompt}], session_id=session_id, stream=False, ) response.output_message ``` ## Test Plan Modified hello.py example to use stream=False and print output_message. ## Sources Please link relevant resources if necessary. ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Ran pre-commit to handle lint / formatting issues. - [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section? - [ ] Updated relevant documentation. - [ ] Wrote necessary unit or integration tests. --- src/llama_stack_client/lib/agents/agent.py | 32 +++++++++++++++++++--- 1 file changed, 28 insertions(+), 4 deletions(-) diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py index 6ac46878..eae95f2c 100644 --- a/src/llama_stack_client/lib/agents/agent.py +++ b/src/llama_stack_client/lib/agents/agent.py @@ -3,13 +3,16 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-from typing import List, Optional, Tuple, Union +from typing import Iterator, List, Optional, Tuple, Union from llama_stack_client import LlamaStackClient from llama_stack_client.types import ToolResponseMessage, UserMessage from llama_stack_client.types.agent_create_params import AgentConfig +from llama_stack_client.types.agents.turn import Turn from llama_stack_client.types.agents.turn_create_params import (Document, Toolgroup) +from llama_stack_client.types.agents.turn_create_response import \ + AgentTurnResponseStreamChunk from .client_tool import ClientTool @@ -46,7 +49,7 @@ def create_session(self, session_name: str) -> int: self.sessions.append(self.session_id) return self.session_id - def _has_tool_call(self, chunk): + def _has_tool_call(self, chunk: AgentTurnResponseStreamChunk) -> bool: if chunk.event.payload.event_type != "turn_complete": return False message = chunk.event.payload.turn.output_message @@ -54,7 +57,7 @@ def _has_tool_call(self, chunk): return False return len(message.tool_calls) > 0 - def _run_tool(self, chunk): + def _run_tool(self, chunk: AgentTurnResponseStreamChunk) -> ToolResponseMessage: message = chunk.event.payload.turn.output_message tool_call = message.tool_calls[0] if tool_call.tool_name not in self.client_tools: @@ -75,7 +78,28 @@ def create_turn( session_id: Optional[str] = None, toolgroups: Optional[List[Toolgroup]] = None, documents: Optional[List[Document]] = None, - ): + stream: bool = True, + ) -> Iterator[AgentTurnResponseStreamChunk] | Turn: + if stream: + return self._create_turn_streaming(messages, session_id, toolgroups, documents, stream) + else: + chunk = None + for chunk in self._create_turn_streaming(messages, session_id, toolgroups, documents, stream): + pass + if not chunk: + raise Exception("No chunk returned") + if chunk.event.payload.event_type != 'turn_complete': + raise Exception("Turn did not complete") + return chunk.event.payload.turn + + def _create_turn_streaming( + self, + messages: 
List[Union[UserMessage, ToolResponseMessage]], + session_id: Optional[str] = None, + toolgroups: Optional[List[Toolgroup]] = None, + documents: Optional[List[Document]] = None, + stream: bool = True, + ) -> Iterator[AgentTurnResponseStreamChunk]: stop = False n_iter = 0 max_iter = self.agent_config.get('max_infer_iters', DEFAULT_MAX_ITER)