Commit cc7824a

Eric Huang (AI Platform) authored and committed
non-streaming support for agent create turn
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
1 parent 2cc1782 commit cc7824a

File tree

  • src/llama_stack_client/lib/agents

1 file changed: +28 −4 lines

src/llama_stack_client/lib/agents/agent.py

Lines changed: 28 additions & 4 deletions
@@ -3,13 +3,16 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import List, Optional, Tuple, Union
+from typing import Iterator, List, Optional, Tuple, Union
 
 from llama_stack_client import LlamaStackClient
 from llama_stack_client.types import ToolResponseMessage, UserMessage
 from llama_stack_client.types.agent_create_params import AgentConfig
+from llama_stack_client.types.agents.turn import Turn
 from llama_stack_client.types.agents.turn_create_params import (Document,
                                                                 Toolgroup)
+from llama_stack_client.types.agents.turn_create_response import \
+    AgentTurnResponseStreamChunk
 
 from .client_tool import ClientTool
 
@@ -46,15 +49,15 @@ def create_session(self, session_name: str) -> int:
         self.sessions.append(self.session_id)
         return self.session_id
 
-    def _has_tool_call(self, chunk):
+    def _has_tool_call(self, chunk: AgentTurnResponseStreamChunk) -> bool:
         if chunk.event.payload.event_type != "turn_complete":
             return False
         message = chunk.event.payload.turn.output_message
         if message.stop_reason == "out_of_tokens":
             return False
         return len(message.tool_calls) > 0
 
-    def _run_tool(self, chunk):
+    def _run_tool(self, chunk: AgentTurnResponseStreamChunk) -> ToolResponseMessage:
         message = chunk.event.payload.turn.output_message
         tool_call = message.tool_calls[0]
         if tool_call.tool_name not in self.client_tools:
@@ -75,7 +78,28 @@ def create_turn(
         session_id: Optional[str] = None,
         toolgroups: Optional[List[Toolgroup]] = None,
         documents: Optional[List[Document]] = None,
-    ):
+        stream: bool = True,
+    ) -> Iterator[AgentTurnResponseStreamChunk] | Turn:
+        if stream:
+            return self._create_turn_streaming(messages, session_id, toolgroups, documents, stream)
+        else:
+            chunk = None
+            for chunk in self._create_turn_streaming(messages, session_id, toolgroups, documents, stream):
+                pass
+            if not chunk:
+                raise Exception("No chunk returned")
+            if chunk.event.payload.event_type != 'turn_complete':
+                raise Exception("Turn did not complete")
+            return chunk.event.payload.turn
+
+    def _create_turn_streaming(
+        self,
+        messages: List[Union[UserMessage, ToolResponseMessage]],
+        session_id: Optional[str] = None,
+        toolgroups: Optional[List[Toolgroup]] = None,
+        documents: Optional[List[Document]] = None,
+        stream: bool = True,
+    ) -> Iterator[AgentTurnResponseStreamChunk]:
         stop = False
         n_iter = 0
         max_iter = self.agent_config.get('max_infer_iters', DEFAULT_MAX_ITER)
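With this change, create_turn keeps returning an iterator of AgentTurnResponseStreamChunk objects by default, and returns the completed Turn directly when stream=False is passed. A minimal usage sketch follows; the client URL, agent_config contents, and the plain-dict user message are illustrative assumptions and are not part of this commit.

    from llama_stack_client import LlamaStackClient
    from llama_stack_client.lib.agents.agent import Agent

    # Assumed setup for illustration only: endpoint, model name, and instructions
    # are placeholders; adapt them to your own deployment.
    client = LlamaStackClient(base_url="http://localhost:5000")
    agent_config = {"model": "Llama3.1-8B-Instruct", "instructions": "You are a helpful assistant."}
    agent = Agent(client, agent_config)
    session_id = agent.create_session("demo-session")

    # UserMessage-shaped dict, used here for brevity.
    message = {"role": "user", "content": "Hello"}

    # Default behavior (stream=True): iterate over stream chunks as they arrive.
    for chunk in agent.create_turn(messages=[message], session_id=session_id):
        print(chunk.event.payload.event_type)

    # New non-streaming path (stream=False): the stream is drained internally and
    # the Turn carried by the final "turn_complete" chunk is returned.
    turn = agent.create_turn(messages=[message], session_id=session_id, stream=False)
    print(turn.output_message.content)

Since the non-streaming branch simply drains _create_turn_streaming, tool calls and inference iterations behave the same in both modes; only the return type differs.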
