Skip to content

Commit 995ff88

Browse files
committed
feat: new Agent API
Summary: Test Plan:
1 parent cd3b2b8 commit 995ff88

File tree

1 file changed

+92
-3
lines changed
  • src/llama_stack_client/lib/agents

1 file changed

+92
-3
lines changed

src/llama_stack_client/lib/agents/agent.py

Lines changed: 92 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,32 +3,121 @@
33
#
44
# This source code is licensed under the terms described in the LICENSE file in
55
# the root directory of this source tree.
6-
from typing import Iterator, List, Optional, Tuple, Union
6+
from typing import Dict, Iterable, Iterator, List, Optional, Tuple, Union
7+
from typing_extensions import Literal
78

89
from llama_stack_client import LlamaStackClient
10+
import logging
911

1012
from llama_stack_client.types import ToolResponseMessage, UserMessage
1113
from llama_stack_client.types.agent_create_params import AgentConfig
1214
from llama_stack_client.types.agents.turn import CompletionMessage, Turn
1315
from llama_stack_client.types.agents.turn_create_params import Document, Toolgroup
1416
from llama_stack_client.types.agents.turn_create_response import AgentTurnResponseStreamChunk
1517
from llama_stack_client.types.shared.tool_call import ToolCall
18+
from llama_stack_client.types.tool_def_param import ToolDefParam
19+
from llama_stack_client.types.shared_params.response_format import ResponseFormat
20+
from llama_stack_client.types.shared_params.sampling_params import SamplingParams
21+
from llama_stack_client.types.shared_params.agent_config import ToolConfig
1622

1723
from .client_tool import ClientTool
1824
from .tool_parser import ToolParser
1925

2026
DEFAULT_MAX_ITER = 10
2127

28+
logger = logging.getLogger(__name__)
29+
2230

2331
class Agent:
2432
def __init__(
2533
self,
2634
client: LlamaStackClient,
27-
agent_config: AgentConfig,
28-
client_tools: Tuple[ClientTool] = (),
35+
# begin deprecated
36+
agent_config: Optional[AgentConfig] = None,
37+
client_tools: Tuple[ClientTool, ...] = (),
38+
# end deprecated
2939
tool_parser: Optional[ToolParser] = None,
40+
model: Optional[str] = None,
41+
instructions: Optional[str] = None,
42+
tools: Optional[List[Union[Toolgroup, ClientTool]]] = None,
43+
enable_session_persistence: Optional[bool] = None,
44+
max_infer_iters: Optional[int] = None,
45+
input_shields: Optional[List[str]] = None,
46+
output_shields: Optional[List[str]] = None,
47+
response_format: Optional[ResponseFormat] = None,
48+
sampling_params: Optional[SamplingParams] = None,
49+
tool_config: Optional[ToolConfig] = None,
3050
):
51+
"""Construct an Agent with the given parameters.
52+
53+
:param client: The LlamaStackClient instance.
54+
:param agent_config: The AgentConfig instance.
55+
::deprecated: use other parameters instead
56+
:param client_tools: A tuple of ClientTool instances.
57+
::deprecated: use tools instead
58+
:param tool_parser: Custom logic that parses tool calls from a message.
59+
:param model: The model to use for the agent.
60+
:param instructions: The instructions for the agent.
61+
:param tools: A list of tools for the agent. Values can be one of the following:
62+
- str representing a toolgroup_id: e.g. "builtin::rag", "builtin::code_interpreter"
63+
- str representing a tool within a toolgroup: e.g. "builtin::rag/knowledge_search"
64+
- dict representing a toolgroup/tool with arguments: e.g. {"name": "builtin::rag/knowledge_search", "args": {"vector_db_ids": [123]}}
65+
- a python function decorated with @client_tool
66+
- an instance of ClientTool: A client tool object.
67+
:param enable_session_persistence: Whether to enable session persistence.
68+
:param max_infer_iters: The maximum number of inference iterations.
69+
:param input_shields: The input shields for the agent.
70+
:param output_shields: The output shields for the agent.
71+
:param response_format: The response format for the agent.
72+
:param sampling_params: The sampling parameters for the agent.
73+
:param tool_config: The tool configuration for the agent.
74+
"""
3175
self.client = client
76+
77+
if agent_config is not None or client_tools is not ():
78+
logger.warning("agent_config and client_tools are deprecated. Use tools instead.")
79+
80+
# Construct agent_config from parameters if not provided
81+
if agent_config is None:
82+
# Create a minimal valid AgentConfig with required fields
83+
if model is None or instructions is None:
84+
raise ValueError("Both 'model' and 'instructions' are required when agent_config is not provided")
85+
86+
agent_config = {
87+
"model": model,
88+
"instructions": instructions,
89+
}
90+
91+
# Add optional parameters if provided
92+
if enable_session_persistence is not None:
93+
agent_config["enable_session_persistence"] = enable_session_persistence
94+
if input_shields is not None:
95+
agent_config["input_shields"] = input_shields
96+
if max_infer_iters is not None:
97+
agent_config["max_infer_iters"] = max_infer_iters
98+
if output_shields is not None:
99+
agent_config["output_shields"] = output_shields
100+
if response_format is not None:
101+
agent_config["response_format"] = response_format
102+
if sampling_params is not None:
103+
agent_config["sampling_params"] = sampling_params
104+
if tool_config is not None:
105+
agent_config["tool_config"] = tool_config
106+
if tools is not None:
107+
toolgroups: List[Toolgroup] = []
108+
client_tools: List[ClientTool] = []
109+
110+
for tool in tools:
111+
if isinstance(tool, str) or isinstance(tool, dict):
112+
toolgroups.append(tool)
113+
else:
114+
client_tools.append(tool)
115+
116+
agent_config["toolgroups"] = toolgroups
117+
agent_config["client_tools"] = [tool.get_tool_definition() for tool in client_tools]
118+
119+
agent_config = AgentConfig(**agent_config)
120+
32121
self.agent_config = agent_config
33122
self.agent_id = self._create_agent(agent_config)
34123
self.client_tools = {t.get_name(): t for t in client_tools}

0 commit comments

Comments
 (0)