Skip to content

Commit 2cc1782

Browse files
ehhuangEric Huang (AI Platform)
andauthored
Resume turn after calling client tool (#102)
# What does this PR do? When client tools are used, `create_turn` currently terminates after executing the client tool requested by the model. This PR proposes to mirror the behavior when only server side tools are used where the tool response is passed back to the model for a final response. ## Test Plan Before: <img width="1352" alt="image" src="https://github.com/user-attachments/assets/90dc6bf4-9fa0-48e0-9e08-c31d3cc8ca44" /> After: <img width="1372" alt="image" src="https://github.com/user-attachments/assets/08e6bf64-1202-4ea8-a293-25b331488949" /> ## Sources Please link relevant resources if necessary. ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Ran pre-commit to handle lint / formatting issues. - [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section? - [ ] Updated relevant documentation. - [ ] Wrote necessary unit or integration tests. --------- Co-authored-by: Eric Huang (AI Platform) <erichuang@fb.com>
1 parent e65fa28 commit 2cc1782

File tree

1 file changed

+30
-18
lines changed
  • src/llama_stack_client/lib/agents

1 file changed

+30
-18
lines changed

src/llama_stack_client/lib/agents/agent.py

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
from .client_tool import ClientTool
1515

16+
DEFAULT_MAX_ITER = 10
1617

1718
class Agent:
1819
def __init__(
@@ -75,21 +76,32 @@ def create_turn(
7576
toolgroups: Optional[List[Toolgroup]] = None,
7677
documents: Optional[List[Document]] = None,
7778
):
78-
response = self.client.agents.turn.create(
79-
agent_id=self.agent_id,
80-
# use specified session_id or last session created
81-
session_id=session_id or self.session_id[-1],
82-
messages=messages,
83-
stream=True,
84-
documents=documents,
85-
toolgroups=toolgroups,
86-
)
87-
for chunk in response:
88-
if hasattr(chunk, "error"):
89-
yield chunk
90-
return
91-
elif not self._has_tool_call(chunk):
92-
yield chunk
93-
else:
94-
next_message = self._run_tool(chunk)
95-
yield next_message
79+
stop = False
80+
n_iter = 0
81+
max_iter = self.agent_config.get('max_infer_iters', DEFAULT_MAX_ITER)
82+
while not stop and n_iter < max_iter:
83+
response = self.client.agents.turn.create(
84+
agent_id=self.agent_id,
85+
# use specified session_id or last session created
86+
session_id=session_id or self.session_id[-1],
87+
messages=messages,
88+
stream=True,
89+
documents=documents,
90+
toolgroups=toolgroups,
91+
)
92+
# by default, we stop after the first turn
93+
stop = True
94+
for chunk in response:
95+
if hasattr(chunk, "error"):
96+
yield chunk
97+
return
98+
elif not self._has_tool_call(chunk):
99+
yield chunk
100+
else:
101+
next_message = self._run_tool(chunk)
102+
yield next_message
103+
104+
# continue the turn when there's a tool call
105+
stop = False
106+
messages = [next_message]
107+
n_iter += 1

0 commit comments

Comments
 (0)