
Commit 3cf8ac8

feat (1/n): agents resume turn (Sync updates from stainless branch: yanxi0830/dev) (#157)
# What does this PR do?
- Adapt to llamastack/llama-stack#1178

## Test Plan
- test in following PR
1 parent 08ab5df commit 3cf8ac8

File tree

14 files changed: +859 −96 lines

src/llama_stack_client/_base_client.py

Lines changed: 1 addition & 1 deletion
@@ -518,7 +518,7 @@ def _build_request(
             # so that passing a `TypedDict` doesn't cause an error.
             # https://github.com/microsoft/pyright/issues/3526#event-6715453066
             params=self.qs.stringify(cast(Mapping[str, Any], params)) if params else None,
-            json=json_data,
+            json=json_data if is_given(json_data) else None,
             files=files,
             **kwargs,
         )
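This one-line change makes `_build_request` attach a JSON body only when the caller actually supplied one. Below is a minimal, self-contained sketch of the sentinel pattern the `is_given` check relies on; it is assumed to mirror the client's internal `NotGiven` helper, not copied from it:

```python
from typing import Any, Union


class NotGiven:
    """Sentinel distinguishing 'argument omitted' from an explicit None."""

    def __bool__(self) -> bool:
        return False


NOT_GIVEN = NotGiven()


def is_given(value: Union[Any, NotGiven]) -> bool:
    return not isinstance(value, NotGiven)


def build_request(json_data: Union[dict, NotGiven] = NOT_GIVEN) -> dict:
    # Mirrors the fix above: an omitted body becomes None instead of the
    # sentinel object leaking into the serialized request.
    return {"json": json_data if is_given(json_data) else None}


print(build_request())              # {'json': None}  (no body sent)
print(build_request(json_data={}))  # {'json': {}}    (explicit empty body)
```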

src/llama_stack_client/_client.py

Lines changed: 8 additions & 8 deletions
@@ -126,16 +126,16 @@ def __init__(
     ) -> None:
         """Construct a new synchronous llama-stack-client client instance.

-        This automatically infers the `api_key` argument from the `LLAMA_STACK_CLIENT_API_KEY` environment variable if it is not provided.
+        This automatically infers the `api_key` argument from the `LLAMA_STACK_API_KEY` environment variable if it is not provided.
         """
         if api_key is None:
-            api_key = os.environ.get("LLAMA_STACK_CLIENT_API_KEY")
+            api_key = os.environ.get("LLAMA_STACK_API_KEY")
         self.api_key = api_key

         if base_url is None:
-            base_url = os.environ.get("LLAMA_STACK_CLIENT_BASE_URL")
+            base_url = os.environ.get("LLAMA_STACK_BASE_URL")
         if base_url is None:
-            base_url = f"http://any-hosted-llama-stack.com"
+            base_url = "http://any-hosted-llama-stack.com"

         custom_headers = default_headers or {}
         custom_headers["X-LlamaStack-Client-Version"] = __version__

@@ -342,16 +342,16 @@ def __init__(
     ) -> None:
         """Construct a new async llama-stack-client client instance.

-        This automatically infers the `api_key` argument from the `LLAMA_STACK_CLIENT_API_KEY` environment variable if it is not provided.
+        This automatically infers the `api_key` argument from the `LLAMA_STACK_API_KEY` environment variable if it is not provided.
         """
         if api_key is None:
-            api_key = os.environ.get("LLAMA_STACK_CLIENT_API_KEY")
+            api_key = os.environ.get("LLAMA_STACK_API_KEY")
         self.api_key = api_key

         if base_url is None:
-            base_url = os.environ.get("LLAMA_STACK_CLIENT_BASE_URL")
+            base_url = os.environ.get("LLAMA_STACK_BASE_URL")
         if base_url is None:
-            base_url = f"http://any-hosted-llama-stack.com"
+            base_url = "http://any-hosted-llama-stack.com"

         custom_headers = default_headers or {}
         custom_headers["X-LlamaStack-Client-Version"] = __version__
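Both the sync and async clients now consult the shorter `LLAMA_STACK_*` names. A usage sketch; the key and URL values are placeholders, not real endpoints:

```python
import os

from llama_stack_client import LlamaStackClient

# After this change the renamed variables are the ones consulted;
# LLAMA_STACK_CLIENT_API_KEY / LLAMA_STACK_CLIENT_BASE_URL are no longer read.
os.environ["LLAMA_STACK_API_KEY"] = "my-api-key"              # placeholder
os.environ["LLAMA_STACK_BASE_URL"] = "http://localhost:8321"  # placeholder

client = LlamaStackClient()  # api_key and base_url inferred from the environment
```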

src/llama_stack_client/_files.py

Lines changed: 2 additions & 2 deletions
@@ -71,7 +71,7 @@ def _transform_file(file: FileTypes) -> HttpxFileTypes:
     if is_tuple_t(file):
         return (file[0], _read_file_content(file[1]), *file[2:])

-    raise TypeError(f"Expected file types input to be a FileContent type or to be a tuple")
+    raise TypeError("Expected file types input to be a FileContent type or to be a tuple")


 def _read_file_content(file: FileContent) -> HttpxFileContent:

@@ -113,7 +113,7 @@ async def _async_transform_file(file: FileTypes) -> HttpxFileTypes:
     if is_tuple_t(file):
         return (file[0], await _async_read_file_content(file[1]), *file[2:])

-    raise TypeError(f"Expected file types input to be a FileContent type or to be a tuple")
+    raise TypeError("Expected file types input to be a FileContent type or to be a tuple")


 async def _async_read_file_content(file: FileContent) -> HttpxFileContent:

src/llama_stack_client/_response.py

Lines changed: 4 additions & 4 deletions
@@ -229,7 +229,7 @@ def _parse(self, *, to: type[_T] | None = None) -> R | _T:
             # the response class ourselves but that is something that should be supported directly in httpx
             # as it would be easy to incorrectly construct the Response object due to the multitude of arguments.
             if cast_to != httpx.Response:
-                raise ValueError(f"Subclasses of httpx.Response cannot be passed to `cast_to`")
+                raise ValueError("Subclasses of httpx.Response cannot be passed to `cast_to`")
             return cast(R, response)

         if (

@@ -245,9 +245,9 @@ def _parse(self, *, to: type[_T] | None = None) -> R | _T:

         if (
             cast_to is not object
-            and not origin is list
-            and not origin is dict
-            and not origin is Union
+            and origin is not list
+            and origin is not dict
+            and origin is not Union
             and not issubclass(origin, BaseModel)
         ):
             raise RuntimeError(

src/llama_stack_client/_utils/_logs.py

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@ def _basic_config() -> None:


 def setup_logging() -> None:
-    env = os.environ.get("LLAMA_STACK_CLIENT_LOG")
+    env = os.environ.get("LLAMA_STACK_LOG")
     if env == "debug":
         _basic_config()
         logger.setLevel(logging.DEBUG)
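The debug-logging switch follows the same rename. A sketch of turning it on, assuming (as in comparable Stainless-generated clients) that `setup_logging()` runs when the package is imported, so the variable must be set first:

```python
import os

# Set before importing the package; "debug" is the only value checked here.
os.environ["LLAMA_STACK_LOG"] = "debug"

import llama_stack_client  # assumed to trigger setup_logging() on import
```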

src/llama_stack_client/lib/agents/agent.py

Lines changed: 44 additions & 75 deletions
@@ -3,8 +3,6 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-import uuid
-from datetime import datetime
 from typing import Iterator, List, Optional, Tuple, Union

 from llama_stack_client import LlamaStackClient

@@ -16,13 +14,7 @@
 from llama_stack_client.types.agents.turn_create_response import (
     AgentTurnResponseStreamChunk,
 )
-from llama_stack_client.types.agents.turn_response_event import TurnResponseEvent
-from llama_stack_client.types.agents.turn_response_event_payload import (
-    AgentTurnResponseStepCompletePayload,
-)
 from llama_stack_client.types.shared.tool_call import ToolCall
-from llama_stack_client.types.tool_execution_step import ToolExecutionStep
-from llama_stack_client.types.tool_response import ToolResponse

 from .client_tool import ClientTool
 from .tool_parser import ToolParser

@@ -66,7 +58,7 @@ def create_session(self, session_name: str) -> str:
         return self.session_id

     def _get_tool_calls(self, chunk: AgentTurnResponseStreamChunk) -> List[ToolCall]:
-        if chunk.event.payload.event_type != "turn_complete":
+        if chunk.event.payload.event_type not in {"turn_complete", "turn_awaiting_input"}:
             return []

         message = chunk.event.payload.turn.output_message

@@ -78,6 +70,12 @@ def _get_tool_calls(self, chunk: AgentTurnResponseStreamChunk) -> List[ToolCall]

         return message.tool_calls

+    def _get_turn_id(self, chunk: AgentTurnResponseStreamChunk) -> Optional[str]:
+        if chunk.event.payload.event_type not in ["turn_complete", "turn_awaiting_input"]:
+            return None
+
+        return chunk.event.payload.turn.turn_id
+
     def _run_tool(self, tool_calls: List[ToolCall]) -> ToolResponseMessage:
         assert len(tool_calls) == 1, "Only one tool call is supported"
         tool_call = tool_calls[0]

@@ -132,27 +130,10 @@ def create_turn(
         if stream:
             return self._create_turn_streaming(messages, session_id, toolgroups, documents)
         else:
-            chunks = []
-            for chunk in self._create_turn_streaming(messages, session_id, toolgroups, documents):
-                if chunk.event.payload.event_type == "turn_complete":
-                    chunks.append(chunk)
-                pass
+            chunks = [x for x in self._create_turn_streaming(messages, session_id, toolgroups, documents)]
             if not chunks:
                 raise Exception("Turn did not complete")
-
-            # merge chunks
-            return Turn(
-                input_messages=chunks[0].event.payload.turn.input_messages,
-                output_message=chunks[-1].event.payload.turn.output_message,
-                session_id=chunks[0].event.payload.turn.session_id,
-                steps=[step for chunk in chunks for step in chunk.event.payload.turn.steps],
-                turn_id=chunks[0].event.payload.turn.turn_id,
-                started_at=chunks[0].event.payload.turn.started_at,
-                completed_at=chunks[-1].event.payload.turn.completed_at,
-                output_attachments=[
-                    attachment for chunk in chunks for attachment in chunk.event.payload.turn.output_attachments
-                ],
-            )
+            return chunks[-1].event.payload.turn

     def _create_turn_streaming(
         self,

@@ -161,62 +142,50 @@ def _create_turn_streaming(
         toolgroups: Optional[List[Toolgroup]] = None,
         documents: Optional[List[Document]] = None,
     ) -> Iterator[AgentTurnResponseStreamChunk]:
-        stop = False
         n_iter = 0
         max_iter = self.agent_config.get("max_infer_iters", DEFAULT_MAX_ITER)
-        while not stop and n_iter < max_iter:
-            response = self.client.agents.turn.create(
-                agent_id=self.agent_id,
-                # use specified session_id or last session created
-                session_id=session_id or self.session_id[-1],
-                messages=messages,
-                stream=True,
-                documents=documents,
-                toolgroups=toolgroups,
-            )
-            # by default, we stop after the first turn
-            stop = True
-            for chunk in response:
+
+        # 1. create an agent turn
+        turn_response = self.client.agents.turn.create(
+            agent_id=self.agent_id,
+            # use specified session_id or last session created
+            session_id=session_id or self.session_id[-1],
+            messages=messages,
+            stream=True,
+            documents=documents,
+            toolgroups=toolgroups,
+            allow_turn_resume=True,
+        )
+
+        # 2. process turn and resume if there's a tool call
+        is_turn_complete = False
+        while not is_turn_complete:
+            is_turn_complete = True
+            for chunk in turn_response:
                 tool_calls = self._get_tool_calls(chunk)
                 if hasattr(chunk, "error"):
                     yield chunk
                     return
                 elif not tool_calls:
                     yield chunk
                 else:
-                    tool_execution_start_time = datetime.now()
+                    is_turn_complete = False
+                    turn_id = self._get_turn_id(chunk)
+                    if n_iter == 0:
+                        yield chunk
+
+                    # run the tools
                     tool_response_message = self._run_tool(tool_calls)
-                    tool_execution_step = ToolExecutionStep(
-                        step_type="tool_execution",
-                        step_id=str(uuid.uuid4()),
-                        tool_calls=tool_calls,
-                        tool_responses=[
-                            ToolResponse(
-                                tool_name=tool_response_message.tool_name,
-                                content=tool_response_message.content,
-                                call_id=tool_response_message.call_id,
-                            )
-                        ],
-                        turn_id=chunk.event.payload.turn.turn_id,
-                        completed_at=datetime.now(),
-                        started_at=tool_execution_start_time,
-                    )
-                    yield AgentTurnResponseStreamChunk(
-                        event=TurnResponseEvent(
-                            payload=AgentTurnResponseStepCompletePayload(
-                                event_type="step_complete",
-                                step_id=tool_execution_step.step_id,
-                                step_type="tool_execution",
-                                step_details=tool_execution_step,
-                            )
-                        )
+                    # pass it to next iteration
+                    turn_response = self.client.agents.turn.resume(
+                        agent_id=self.agent_id,
+                        session_id=session_id or self.session_id[-1],
+                        turn_id=turn_id,
+                        tool_responses=[tool_response_message],
+                        stream=True,
                     )
-
-                    # HACK: append the tool execution step to the turn
-                    chunk.event.payload.turn.steps.append(tool_execution_step)
-                    yield chunk
-
-                    # continue the turn when there's a tool call
-                    stop = False
-                    messages = [tool_response_message]
                     n_iter += 1
+                    break
+
+        if n_iter >= max_iter:
+            raise Exception(f"Turn did not complete in {max_iter} iterations")

src/llama_stack_client/lib/agents/event_logger.py

Lines changed: 4 additions & 2 deletions
@@ -75,7 +75,7 @@ def _yield_printable_events(
         event = chunk.event
         event_type = event.payload.event_type

-        if event_type in {"turn_start", "turn_complete"}:
+        if event_type in {"turn_start", "turn_complete", "turn_awaiting_input"}:
             # Currently not logging any turn related info
             yield TurnStreamPrintableEvent(role=None, content="", end="", color="grey")
             return

@@ -149,7 +149,9 @@ def _get_event_type_step_type(self, chunk: Any) -> Tuple[Optional[str], Optional[str]]:
         if hasattr(chunk, "event"):
             previous_event_type = chunk.event.payload.event_type if hasattr(chunk, "event") else None
             previous_step_type = (
-                chunk.event.payload.step_type if previous_event_type not in {"turn_start", "turn_complete"} else None
+                chunk.event.payload.step_type
+                if previous_event_type not in {"turn_start", "turn_complete", "turn_awaiting_input"}
+                else None
             )
             return previous_event_type, previous_step_type
         return None, None
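Since a stream can now pause at `turn_awaiting_input` instead of ending with `turn_complete`, the logger treats it as just another turn-boundary event. A sketch of logging a streamed turn, reusing the `agent` and `session_id` assumed set up earlier:

```python
from llama_stack_client.lib.agents.event_logger import EventLogger

response = agent.create_turn(
    messages=[{"role": "user", "content": "What's the weather in SF?"}],
    session_id=session_id,
    stream=True,
)
for log in EventLogger().log(response):
    log.print()  # turn_awaiting_input now renders like other turn boundaries
```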
