Skip to content

Commit cccd240

Browse files
authored
feat (2/n): Agent lib using resume_turn (#158)
# What does this PR do? - #157 - Server change: llamastack/llama-stack#1194 ## Test Plan - See llamastack/llama-stack#1194 <img width="1080" alt="image" src="https://github.com/user-attachments/assets/fb4cf70d-1c3d-423d-ac75-432c2a3505d7" /> [//]: # (## Documentation) [//]: # (- [ ] Added a Changelog entry if the change is significant)
1 parent 250676f commit cccd240

File tree

2 files changed

+48
-77
lines changed

2 files changed

+48
-77
lines changed

src/llama_stack_client/lib/agents/agent.py

Lines changed: 44 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -13,18 +13,10 @@
1313
from llama_stack_client.types.agents.turn_create_response import (
1414
AgentTurnResponseStreamChunk,
1515
)
16-
from llama_stack_client.types.agents.turn_response_event import TurnResponseEvent
17-
from llama_stack_client.types.agents.turn_response_event_payload import (
18-
AgentTurnResponseStepCompletePayload,
19-
)
2016
from llama_stack_client.types.shared.tool_call import ToolCall
2117
from llama_stack_client.types.agents.turn import CompletionMessage
2218
from .client_tool import ClientTool
2319
from .tool_parser import ToolParser
24-
from datetime import datetime
25-
import uuid
26-
from llama_stack_client.types.tool_execution_step import ToolExecutionStep
27-
from llama_stack_client.types.tool_response import ToolResponse
2820

2921
DEFAULT_MAX_ITER = 10
3022

@@ -65,7 +57,7 @@ def create_session(self, session_name: str) -> int:
6557
return self.session_id
6658

6759
def _get_tool_calls(self, chunk: AgentTurnResponseStreamChunk) -> List[ToolCall]:
68-
if chunk.event.payload.event_type != "turn_complete":
60+
if chunk.event.payload.event_type not in {"turn_complete", "turn_awaiting_input"}:
6961
return []
7062

7163
message = chunk.event.payload.turn.output_message
@@ -77,6 +69,12 @@ def _get_tool_calls(self, chunk: AgentTurnResponseStreamChunk) -> List[ToolCall]
7769

7870
return message.tool_calls
7971

72+
def _get_turn_id(self, chunk: AgentTurnResponseStreamChunk) -> Optional[str]:
73+
if chunk.event.payload.event_type not in ["turn_complete", "turn_awaiting_input"]:
74+
return None
75+
76+
return chunk.event.payload.turn.turn_id
77+
8078
def _run_tool(self, tool_calls: List[ToolCall]) -> ToolResponseMessage:
8179
assert len(tool_calls) == 1, "Only one tool call is supported"
8280
tool_call = tool_calls[0]
@@ -131,27 +129,10 @@ def create_turn(
131129
if stream:
132130
return self._create_turn_streaming(messages, session_id, toolgroups, documents)
133131
else:
134-
chunks = []
135-
for chunk in self._create_turn_streaming(messages, session_id, toolgroups, documents):
136-
if chunk.event.payload.event_type == "turn_complete":
137-
chunks.append(chunk)
138-
pass
132+
chunks = [x for x in self._create_turn_streaming(messages, session_id, toolgroups, documents)]
139133
if not chunks:
140134
raise Exception("Turn did not complete")
141-
142-
# merge chunks
143-
return Turn(
144-
input_messages=chunks[0].event.payload.turn.input_messages,
145-
output_message=chunks[-1].event.payload.turn.output_message,
146-
session_id=chunks[0].event.payload.turn.session_id,
147-
steps=[step for chunk in chunks for step in chunk.event.payload.turn.steps],
148-
turn_id=chunks[0].event.payload.turn.turn_id,
149-
started_at=chunks[0].event.payload.turn.started_at,
150-
completed_at=chunks[-1].event.payload.turn.completed_at,
151-
output_attachments=[
152-
attachment for chunk in chunks for attachment in chunk.event.payload.turn.output_attachments
153-
],
154-
)
135+
return chunks[-1].event.payload.turn
155136

156137
def _create_turn_streaming(
157138
self,
@@ -160,62 +141,50 @@ def _create_turn_streaming(
160141
toolgroups: Optional[List[Toolgroup]] = None,
161142
documents: Optional[List[Document]] = None,
162143
) -> Iterator[AgentTurnResponseStreamChunk]:
163-
stop = False
164144
n_iter = 0
165145
max_iter = self.agent_config.get("max_infer_iters", DEFAULT_MAX_ITER)
166-
while not stop and n_iter < max_iter:
167-
response = self.client.agents.turn.create(
168-
agent_id=self.agent_id,
169-
# use specified session_id or last session created
170-
session_id=session_id or self.session_id[-1],
171-
messages=messages,
172-
stream=True,
173-
documents=documents,
174-
toolgroups=toolgroups,
175-
)
176-
# by default, we stop after the first turn
177-
stop = True
178-
for chunk in response:
146+
147+
# 1. create an agent turn
148+
turn_response = self.client.agents.turn.create(
149+
agent_id=self.agent_id,
150+
# use specified session_id or last session created
151+
session_id=session_id or self.session_id[-1],
152+
messages=messages,
153+
stream=True,
154+
documents=documents,
155+
toolgroups=toolgroups,
156+
allow_turn_resume=True,
157+
)
158+
159+
# 2. process turn and resume if there's a tool call
160+
is_turn_complete = False
161+
while not is_turn_complete:
162+
is_turn_complete = True
163+
for chunk in turn_response:
179164
tool_calls = self._get_tool_calls(chunk)
180165
if hasattr(chunk, "error"):
181166
yield chunk
182167
return
183168
elif not tool_calls:
184169
yield chunk
185170
else:
186-
tool_execution_start_time = datetime.now()
171+
is_turn_complete = False
172+
turn_id = self._get_turn_id(chunk)
173+
if n_iter == 0:
174+
yield chunk
175+
176+
# run the tools
187177
tool_response_message = self._run_tool(tool_calls)
188-
tool_execution_step = ToolExecutionStep(
189-
step_type="tool_execution",
190-
step_id=str(uuid.uuid4()),
191-
tool_calls=tool_calls,
192-
tool_responses=[
193-
ToolResponse(
194-
tool_name=tool_response_message.tool_name,
195-
content=tool_response_message.content,
196-
call_id=tool_response_message.call_id,
197-
)
198-
],
199-
turn_id=chunk.event.payload.turn.turn_id,
200-
completed_at=datetime.now(),
201-
started_at=tool_execution_start_time,
202-
)
203-
yield AgentTurnResponseStreamChunk(
204-
event=TurnResponseEvent(
205-
payload=AgentTurnResponseStepCompletePayload(
206-
event_type="step_complete",
207-
step_id=tool_execution_step.step_id,
208-
step_type="tool_execution",
209-
step_details=tool_execution_step,
210-
)
211-
)
178+
# pass it to next iteration
179+
turn_response = self.client.agents.turn.resume(
180+
agent_id=self.agent_id,
181+
session_id=session_id or self.session_id[-1],
182+
turn_id=turn_id,
183+
tool_responses=[tool_response_message],
184+
stream=True,
212185
)
213-
214-
# HACK: append the tool execution step to the turn
215-
chunk.event.payload.turn.steps.append(tool_execution_step)
216-
yield chunk
217-
218-
# continue the turn when there's a tool call
219-
stop = False
220-
messages = [tool_response_message]
221186
n_iter += 1
187+
break
188+
189+
if n_iter >= max_iter:
190+
raise Exception(f"Turn did not complete in {max_iter} iterations")

src/llama_stack_client/lib/agents/event_logger.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def _yield_printable_events(
7575
event = chunk.event
7676
event_type = event.payload.event_type
7777

78-
if event_type in {"turn_start", "turn_complete"}:
78+
if event_type in {"turn_start", "turn_complete", "turn_awaiting_input"}:
7979
# Currently not logging any turn realted info
8080
yield TurnStreamPrintableEvent(role=None, content="", end="", color="grey")
8181
return
@@ -149,7 +149,9 @@ def _get_event_type_step_type(self, chunk: Any) -> Tuple[Optional[str], Optional
149149
if hasattr(chunk, "event"):
150150
previous_event_type = chunk.event.payload.event_type if hasattr(chunk, "event") else None
151151
previous_step_type = (
152-
chunk.event.payload.step_type if previous_event_type not in {"turn_start", "turn_complete"} else None
152+
chunk.event.payload.step_type
153+
if previous_event_type not in {"turn_start", "turn_complete", "turn_awaiting_input"}
154+
else None
153155
)
154156
return previous_event_type, previous_step_type
155157
return None, None

0 commit comments

Comments
 (0)