Skip to content

Commit d24aaa1

Browse files
committed
feat: new Agent API
Summary: Test Plan:
1 parent 303054b commit d24aaa1

File tree

1 file changed

+91
-2
lines changed
  • src/llama_stack_client/lib/agents

1 file changed

+91
-2
lines changed

src/llama_stack_client/lib/agents/agent.py

Lines changed: 91 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,29 +6,118 @@
66
from typing import Iterator, List, Optional, Tuple, Union
77

88
from llama_stack_client import LlamaStackClient
9+
import logging
910

1011
from llama_stack_client.types import ToolResponseMessage, ToolResponseParam, UserMessage
1112
from llama_stack_client.types.agent_create_params import AgentConfig
1213
from llama_stack_client.types.agents.turn import CompletionMessage, Turn
1314
from llama_stack_client.types.agents.turn_create_params import Document, Toolgroup
1415
from llama_stack_client.types.agents.turn_create_response import AgentTurnResponseStreamChunk
1516
from llama_stack_client.types.shared.tool_call import ToolCall
17+
from llama_stack_client.types.shared_params.response_format import ResponseFormat
18+
from llama_stack_client.types.shared_params.sampling_params import SamplingParams
19+
from llama_stack_client.types.shared_params.agent_config import ToolConfig
1620

1721
from .client_tool import ClientTool
1822
from .tool_parser import ToolParser
1923

2024
DEFAULT_MAX_ITER = 10
2125

26+
logger = logging.getLogger(__name__)
27+
2228

2329
class Agent:
2430
def __init__(
2531
self,
2632
client: LlamaStackClient,
27-
agent_config: AgentConfig,
28-
client_tools: Tuple[ClientTool] = (),
33+
# begin deprecated
34+
agent_config: Optional[AgentConfig] = None,
35+
client_tools: Tuple[ClientTool, ...] = (),
36+
# end deprecated
2937
tool_parser: Optional[ToolParser] = None,
38+
model: Optional[str] = None,
39+
instructions: Optional[str] = None,
40+
tools: Optional[List[Union[Toolgroup, ClientTool]]] = None,
41+
tool_config: Optional[ToolConfig] = None,
42+
sampling_params: Optional[SamplingParams] = None,
43+
max_infer_iters: Optional[int] = None,
44+
input_shields: Optional[List[str]] = None,
45+
output_shields: Optional[List[str]] = None,
46+
response_format: Optional[ResponseFormat] = None,
47+
enable_session_persistence: Optional[bool] = None,
3048
):
49+
"""Construct an Agent with the given parameters.
50+
51+
:param client: The LlamaStackClient instance.
52+
:param agent_config: The AgentConfig instance.
53+
::deprecated: use other parameters instead
54+
:param client_tools: A tuple of ClientTool instances.
55+
::deprecated: use tools instead
56+
:param tool_parser: Custom logic that parses tool calls from a message.
57+
:param model: The model to use for the agent.
58+
:param instructions: The instructions for the agent.
59+
:param tools: A list of tools for the agent. Values can be one of the following:
60+
- dict representing a toolgroup/tool with arguments: e.g. {"name": "builtin::rag/knowledge_search", "args": {"vector_db_ids": [123]}}
61+
- a python function decorated with @client_tool
62+
- str representing a tool within a toolgroup: e.g. "builtin::rag/knowledge_search"
63+
- str representing a toolgroup_id: e.g. "builtin::rag", "builtin::code_interpreter", where all tools in the toolgroup will be added to the agent
64+
- an instance of ClientTool: A client tool object.
65+
:param tool_config: The tool configuration for the agent.
66+
:param sampling_params: The sampling parameters for the agent.
67+
:param max_infer_iters: The maximum number of inference iterations.
68+
:param input_shields: The input shields for the agent.
69+
:param output_shields: The output shields for the agent.
70+
:param response_format: The response format for the agent.
71+
:param enable_session_persistence: Whether to enable session persistence.
72+
"""
3173
self.client = client
74+
75+
if agent_config is not None:
76+
logger.warning("`agent_config` is deprecated. Use inlined parameters instead.")
77+
if client_tools != ():
78+
logger.warning("`client_tools` is deprecated. Use `tools` instead.")
79+
80+
# Construct agent_config from parameters if not provided
81+
if agent_config is None:
82+
# Create a minimal valid AgentConfig with required fields
83+
if model is None or instructions is None:
84+
raise ValueError("Both 'model' and 'instructions' are required when agent_config is not provided")
85+
86+
agent_config = {
87+
"model": model,
88+
"instructions": instructions,
89+
}
90+
91+
# Add optional parameters if provided
92+
if enable_session_persistence is not None:
93+
agent_config["enable_session_persistence"] = enable_session_persistence
94+
if input_shields is not None:
95+
agent_config["input_shields"] = input_shields
96+
if max_infer_iters is not None:
97+
agent_config["max_infer_iters"] = max_infer_iters
98+
if output_shields is not None:
99+
agent_config["output_shields"] = output_shields
100+
if response_format is not None:
101+
agent_config["response_format"] = response_format
102+
if sampling_params is not None:
103+
agent_config["sampling_params"] = sampling_params
104+
if tool_config is not None:
105+
agent_config["tool_config"] = tool_config
106+
if tools is not None:
107+
toolgroups: List[Toolgroup] = []
108+
client_tools: List[ClientTool] = []
109+
110+
for tool in tools:
111+
if isinstance(tool, str) or isinstance(tool, dict):
112+
toolgroups.append(tool)
113+
else:
114+
client_tools.append(tool)
115+
116+
agent_config["toolgroups"] = toolgroups
117+
agent_config["client_tools"] = [tool.get_tool_definition() for tool in client_tools]
118+
119+
agent_config = AgentConfig(**agent_config)
120+
32121
self.agent_config = agent_config
33122
self.agent_id = self._create_agent(agent_config)
34123
self.client_tools = {t.get_name(): t for t in client_tools}

0 commit comments

Comments
 (0)