Skip to content

Commit 2caed90

Browse files
committed
feat: new Agent API
Summary: Test Plan:
1 parent cd3b2b8 commit 2caed90

File tree

1 file changed

+94
-3
lines changed
  • src/llama_stack_client/lib/agents

1 file changed

+94
-3
lines changed

src/llama_stack_client/lib/agents/agent.py

Lines changed: 94 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,32 +3,123 @@
33
#
44
# This source code is licensed under the terms described in the LICENSE file in
55
# the root directory of this source tree.
6-
from typing import Iterator, List, Optional, Tuple, Union
6+
from typing import Dict, Iterable, Iterator, List, Optional, Tuple, Union
7+
from typing_extensions import Literal
78

89
from llama_stack_client import LlamaStackClient
10+
import logging
911

1012
from llama_stack_client.types import ToolResponseMessage, UserMessage
1113
from llama_stack_client.types.agent_create_params import AgentConfig
1214
from llama_stack_client.types.agents.turn import CompletionMessage, Turn
1315
from llama_stack_client.types.agents.turn_create_params import Document, Toolgroup
1416
from llama_stack_client.types.agents.turn_create_response import AgentTurnResponseStreamChunk
1517
from llama_stack_client.types.shared.tool_call import ToolCall
18+
from llama_stack_client.types.tool_def_param import ToolDefParam
19+
from llama_stack_client.types.shared_params.response_format import ResponseFormat
20+
from llama_stack_client.types.shared_params.sampling_params import SamplingParams
21+
from llama_stack_client.types.shared_params.agent_config import ToolConfig
1622

1723
from .client_tool import ClientTool
1824
from .tool_parser import ToolParser
1925

2026
DEFAULT_MAX_ITER = 10
2127

28+
logger = logging.getLogger(__name__)
29+
2230

2331
class Agent:
2432
def __init__(
2533
self,
2634
client: LlamaStackClient,
27-
agent_config: AgentConfig,
28-
client_tools: Tuple[ClientTool] = (),
35+
# begin deprecated
36+
agent_config: Optional[AgentConfig] = None,
37+
client_tools: Tuple[ClientTool, ...] = (),
38+
# end deprecated
2939
tool_parser: Optional[ToolParser] = None,
40+
model: Optional[str] = None,
41+
instructions: Optional[str] = None,
42+
tools: Optional[List[Union[Toolgroup, ClientTool]]] = None,
43+
enable_session_persistence: Optional[bool] = None,
44+
max_infer_iters: Optional[int] = None,
45+
input_shields: Optional[List[str]] = None,
46+
output_shields: Optional[List[str]] = None,
47+
response_format: Optional[ResponseFormat] = None,
48+
sampling_params: Optional[SamplingParams] = None,
49+
tool_config: Optional[ToolConfig] = None,
3050
):
51+
"""Construct an Agent with the given parameters.
52+
53+
:param client: The LlamaStackClient instance.
54+
:param agent_config: The AgentConfig instance.
55+
::deprecated: use other parameters instead
56+
:param client_tools: A tuple of ClientTool instances.
57+
::deprecated: use tools instead
58+
:param tool_parser: Custom logic that parses tool calls from a message.
59+
:param model: The model to use for the agent.
60+
:param instructions: The instructions for the agent.
61+
:param tools: A list of tools for the agent. Values can be one of the following:
62+
- str representing a toolgroup_id: e.g. "builtin::rag", "builtin::code_interpreter"
63+
- str representing a tool within a toolgroup: e.g. "builtin::rag/knowledge_search"
64+
- dict representing a toolgroup/tool with arguments: e.g. {"name": "builtin::rag/knowledge_search", "args": {"vector_db_ids": [123]}}
65+
- a python function decorated with @client_tool
66+
- an instance of ClientTool: A client tool object.
67+
:param enable_session_persistence: Whether to enable session persistence.
68+
:param max_infer_iters: The maximum number of inference iterations.
69+
:param input_shields: The input shields for the agent.
70+
:param output_shields: The output shields for the agent.
71+
:param response_format: The response format for the agent.
72+
:param sampling_params: The sampling parameters for the agent.
73+
:param tool_config: The tool configuration for the agent.
74+
"""
3175
self.client = client
76+
77+
if agent_config is not None:
78+
logger.warning("`agent_config` is deprecated. Use inlined parameters instead.")
79+
if client_tools != ():
80+
logger.warning("`client_tools` is deprecated. Use `tools` instead.")
81+
82+
# Construct agent_config from parameters if not provided
83+
if agent_config is None:
84+
# Create a minimal valid AgentConfig with required fields
85+
if model is None or instructions is None:
86+
raise ValueError("Both 'model' and 'instructions' are required when agent_config is not provided")
87+
88+
agent_config = {
89+
"model": model,
90+
"instructions": instructions,
91+
}
92+
93+
# Add optional parameters if provided
94+
if enable_session_persistence is not None:
95+
agent_config["enable_session_persistence"] = enable_session_persistence
96+
if input_shields is not None:
97+
agent_config["input_shields"] = input_shields
98+
if max_infer_iters is not None:
99+
agent_config["max_infer_iters"] = max_infer_iters
100+
if output_shields is not None:
101+
agent_config["output_shields"] = output_shields
102+
if response_format is not None:
103+
agent_config["response_format"] = response_format
104+
if sampling_params is not None:
105+
agent_config["sampling_params"] = sampling_params
106+
if tool_config is not None:
107+
agent_config["tool_config"] = tool_config
108+
if tools is not None:
109+
toolgroups: List[Toolgroup] = []
110+
client_tools: List[ClientTool] = []
111+
112+
for tool in tools:
113+
if isinstance(tool, str) or isinstance(tool, dict):
114+
toolgroups.append(tool)
115+
else:
116+
client_tools.append(tool)
117+
118+
agent_config["toolgroups"] = toolgroups
119+
agent_config["client_tools"] = [tool.get_tool_definition() for tool in client_tools]
120+
121+
agent_config = AgentConfig(**agent_config)
122+
32123
self.agent_config = agent_config
33124
self.agent_id = self._create_agent(agent_config)
34125
self.client_tools = {t.get_name(): t for t in client_tools}

0 commit comments

Comments
 (0)