Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 44 additions & 75 deletions src/llama_stack_client/lib/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,10 @@
from llama_stack_client.types.agents.turn_create_response import (
AgentTurnResponseStreamChunk,
)
from llama_stack_client.types.agents.turn_response_event import TurnResponseEvent
from llama_stack_client.types.agents.turn_response_event_payload import (
AgentTurnResponseStepCompletePayload,
)
from llama_stack_client.types.shared.tool_call import ToolCall
from llama_stack_client.types.agents.turn import CompletionMessage
from .client_tool import ClientTool
from .tool_parser import ToolParser
from datetime import datetime
import uuid
from llama_stack_client.types.tool_execution_step import ToolExecutionStep
from llama_stack_client.types.tool_response import ToolResponse

DEFAULT_MAX_ITER = 10

Expand Down Expand Up @@ -65,7 +57,7 @@ def create_session(self, session_name: str) -> int:
return self.session_id

def _get_tool_calls(self, chunk: AgentTurnResponseStreamChunk) -> List[ToolCall]:
if chunk.event.payload.event_type != "turn_complete":
if chunk.event.payload.event_type not in {"turn_complete", "turn_awaiting_input"}:
return []

message = chunk.event.payload.turn.output_message
Expand All @@ -77,6 +69,12 @@ def _get_tool_calls(self, chunk: AgentTurnResponseStreamChunk) -> List[ToolCall]

return message.tool_calls

def _get_turn_id(self, chunk: AgentTurnResponseStreamChunk) -> Optional[str]:
if chunk.event.payload.event_type not in ["turn_complete", "turn_awaiting_input"]:
return None

return chunk.event.payload.turn.turn_id

def _run_tool(self, tool_calls: List[ToolCall]) -> ToolResponseMessage:
assert len(tool_calls) == 1, "Only one tool call is supported"
tool_call = tool_calls[0]
Expand Down Expand Up @@ -131,27 +129,10 @@ def create_turn(
if stream:
return self._create_turn_streaming(messages, session_id, toolgroups, documents)
else:
chunks = []
for chunk in self._create_turn_streaming(messages, session_id, toolgroups, documents):
if chunk.event.payload.event_type == "turn_complete":
chunks.append(chunk)
pass
chunks = [x for x in self._create_turn_streaming(messages, session_id, toolgroups, documents)]
if not chunks:
raise Exception("Turn did not complete")

# merge chunks
return Turn(
input_messages=chunks[0].event.payload.turn.input_messages,
output_message=chunks[-1].event.payload.turn.output_message,
session_id=chunks[0].event.payload.turn.session_id,
steps=[step for chunk in chunks for step in chunk.event.payload.turn.steps],
turn_id=chunks[0].event.payload.turn.turn_id,
started_at=chunks[0].event.payload.turn.started_at,
completed_at=chunks[-1].event.payload.turn.completed_at,
output_attachments=[
attachment for chunk in chunks for attachment in chunk.event.payload.turn.output_attachments
],
)
return chunks[-1].event.payload.turn

def _create_turn_streaming(
self,
Expand All @@ -160,62 +141,50 @@ def _create_turn_streaming(
toolgroups: Optional[List[Toolgroup]] = None,
documents: Optional[List[Document]] = None,
) -> Iterator[AgentTurnResponseStreamChunk]:
stop = False
n_iter = 0
max_iter = self.agent_config.get("max_infer_iters", DEFAULT_MAX_ITER)
while not stop and n_iter < max_iter:
response = self.client.agents.turn.create(
agent_id=self.agent_id,
# use specified session_id or last session created
session_id=session_id or self.session_id[-1],
messages=messages,
stream=True,
documents=documents,
toolgroups=toolgroups,
)
# by default, we stop after the first turn
stop = True
for chunk in response:

# 1. create an agent turn
turn_response = self.client.agents.turn.create(
agent_id=self.agent_id,
# use specified session_id or last session created
session_id=session_id or self.session_id[-1],
messages=messages,
stream=True,
documents=documents,
toolgroups=toolgroups,
allow_turn_resume=True,
)

# 2. process turn and resume if there's a tool call
is_turn_complete = False
while not is_turn_complete:
is_turn_complete = True
for chunk in turn_response:
tool_calls = self._get_tool_calls(chunk)
if hasattr(chunk, "error"):
yield chunk
return
elif not tool_calls:
yield chunk
else:
tool_execution_start_time = datetime.now()
is_turn_complete = False
turn_id = self._get_turn_id(chunk)
if n_iter == 0:
yield chunk
Comment on lines +173 to +174
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why yield when n_iter == 0?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

b/c when n_iter == 0, we will get back the actual Turn with turn_awaiting_input. e.g. https://gist.github.com/yanxi0830/17754d56d08ccbeaec419b693137500c

AgentTurnResponseStreamChunk(
│   event=TurnResponseEvent(
│   │   payload=AgentTurnResponseTurnAwaitingInputPayload(
│   │   │   event_type='turn_awaiting_input',
│   │   │   turn=Turn(
│   │   │   │   input_messages=[
│   │   │   │   │   UserMessage(
│   │   │   │   │   │   content='What is the boiling point of polyjuice in Celcius?',
│   │   │   │   │   │   role='user',
│   │   │   │   │   │   context=None
│   │   │   │   │   )
│   │   │   │   ],
│   │   │   │   output_message=CompletionMessage(
│   │   │   │   │   content='',
│   │   │   │   │   role='assistant',
│   │   │   │   │   stop_reason='end_of_turn',
│   │   │   │   │   tool_calls=[
│   │   │   │   │   │   ToolCall(
│   │   │   │   │   │   │   arguments={'liquid_name': 'polyjuice', 'celcius': 'True'},
│   │   │   │   │   │   │   call_id='782f2ba1-976d-45a5-874d-791c79ececf6',
│   │   │   │   │   │   │   tool_name='get_boiling_point'
│   │   │   │   │   │   )
│   │   │   │   │   ]
│   │   │   │   ),
│   │   │   │   session_id='8221e2f3-c4d3-44e7-ab7e-e74fdecb10e3',
│   │   │   │   started_at=datetime.datetime(2025, 2, 20, 21, 38, 8, 13477),
│   │   │   │   steps=[
│   │   │   │   │   InferenceStep(
│   │   │   │   │   │   api_model_response=CompletionMessage(
│   │   │   │   │   │   │   content='',
│   │   │   │   │   │   │   role='assistant',
│   │   │   │   │   │   │   stop_reason='end_of_turn',
│   │   │   │   │   │   │   tool_calls=[
│   │   │   │   │   │   │   │   ToolCall(
│   │   │   │   │   │   │   │   │   arguments={'liquid_name': 'polyjuice', 'celcius': 'True'},
│   │   │   │   │   │   │   │   │   call_id='782f2ba1-976d-45a5-874d-791c79ececf6',
│   │   │   │   │   │   │   │   │   tool_name='get_boiling_point'
│   │   │   │   │   │   │   │   )
│   │   │   │   │   │   │   ]
│   │   │   │   │   │   ),
│   │   │   │   │   │   step_id='c8a34ebd-2902-4da8-a78e-a549009880bb',
│   │   │   │   │   │   step_type='inference',
│   │   │   │   │   │   turn_id='38c0cc58-5c3b-4d8e-adf6-cb1c1a7bf0ee',
│   │   │   │   │   │   completed_at=datetime.datetime(2025, 2, 20, 21, 38, 8, 427067),
│   │   │   │   │   │   started_at=datetime.datetime(2025, 2, 20, 21, 38, 8, 24697)
│   │   │   │   │   )
│   │   │   │   ],
│   │   │   │   turn_id='38c0cc58-5c3b-4d8e-adf6-cb1c1a7bf0ee',
│   │   │   │   completed_at=datetime.datetime(2025, 2, 20, 21, 38, 8, 461543),
│   │   │   │   output_attachments=[]
│   │   │   )
│   │   )
│   )
)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just for my understanding: sorry am still confused. What about subsequent iteration? Why dont we need to yield? Might be good to add a comment here.


# run the tools
tool_response_message = self._run_tool(tool_calls)
tool_execution_step = ToolExecutionStep(
step_type="tool_execution",
step_id=str(uuid.uuid4()),
tool_calls=tool_calls,
tool_responses=[
ToolResponse(
tool_name=tool_response_message.tool_name,
content=tool_response_message.content,
call_id=tool_response_message.call_id,
)
],
turn_id=chunk.event.payload.turn.turn_id,
completed_at=datetime.now(),
started_at=tool_execution_start_time,
)
yield AgentTurnResponseStreamChunk(
event=TurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
event_type="step_complete",
step_id=tool_execution_step.step_id,
step_type="tool_execution",
step_details=tool_execution_step,
)
)
# pass it to next iteration
turn_response = self.client.agents.turn.resume(
agent_id=self.agent_id,
session_id=session_id or self.session_id[-1],
turn_id=turn_id,
tool_responses=[tool_response_message],
stream=True,
)

# HACK: append the tool execution step to the turn
chunk.event.payload.turn.steps.append(tool_execution_step)
yield chunk

# continue the turn when there's a tool call
stop = False
messages = [tool_response_message]
n_iter += 1
break

if n_iter >= max_iter:
raise Exception(f"Turn did not complete in {max_iter} iterations")
6 changes: 4 additions & 2 deletions src/llama_stack_client/lib/agents/event_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def _yield_printable_events(
event = chunk.event
event_type = event.payload.event_type

if event_type in {"turn_start", "turn_complete"}:
if event_type in {"turn_start", "turn_complete", "turn_awaiting_input"}:
# Currently not logging any turn realted info
yield TurnStreamPrintableEvent(role=None, content="", end="", color="grey")
return
Expand Down Expand Up @@ -149,7 +149,9 @@ def _get_event_type_step_type(self, chunk: Any) -> Tuple[Optional[str], Optional
if hasattr(chunk, "event"):
previous_event_type = chunk.event.payload.event_type if hasattr(chunk, "event") else None
previous_step_type = (
chunk.event.payload.step_type if previous_event_type not in {"turn_start", "turn_complete"} else None
chunk.event.payload.step_type
if previous_event_type not in {"turn_start", "turn_complete", "turn_awaiting_input"}
else None
)
return previous_event_type, previous_step_type
return None, None
Expand Down