Skip to content

Commit d2657c4

Browse files
committed
refactor
1 parent 276d7fe commit d2657c4

File tree

1 file changed

+68
-41
lines changed
  • src/llama_stack_client/lib/agents

1 file changed

+68
-41
lines changed

src/llama_stack_client/lib/agents/agent.py

Lines changed: 68 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,61 @@ def get_turn_id(chunk: AgentTurnResponseStreamChunk) -> Optional[str]:
4848

4949
return chunk.event.payload.turn.turn_id
5050

51+
@staticmethod
52+
def get_agent_config(
53+
model: Optional[str] = None,
54+
instructions: Optional[str] = None,
55+
tools: Optional[List[Union[Toolgroup, ClientTool]]] = None,
56+
tool_config: Optional[ToolConfig] = None,
57+
sampling_params: Optional[SamplingParams] = None,
58+
max_infer_iters: Optional[int] = None,
59+
input_shields: Optional[List[str]] = None,
60+
output_shields: Optional[List[str]] = None,
61+
response_format: Optional[ResponseFormat] = None,
62+
enable_session_persistence: Optional[bool] = None,
63+
) -> AgentConfig:
64+
# Create a minimal valid AgentConfig with required fields
65+
if model is None or instructions is None:
66+
raise ValueError("Both 'model' and 'instructions' are required when agent_config is not provided")
67+
68+
agent_config = {
69+
"model": model,
70+
"instructions": instructions,
71+
"toolgroups": [],
72+
"client_tools": [],
73+
}
74+
75+
# Add optional parameters if provided
76+
if enable_session_persistence is not None:
77+
agent_config["enable_session_persistence"] = enable_session_persistence
78+
if max_infer_iters is not None:
79+
agent_config["max_infer_iters"] = max_infer_iters
80+
if input_shields is not None:
81+
agent_config["input_shields"] = input_shields
82+
if output_shields is not None:
83+
agent_config["output_shields"] = output_shields
84+
if response_format is not None:
85+
agent_config["response_format"] = response_format
86+
if sampling_params is not None:
87+
agent_config["sampling_params"] = sampling_params
88+
if tool_config is not None:
89+
agent_config["tool_config"] = tool_config
90+
if tools is not None:
91+
toolgroups: List[Toolgroup] = []
92+
client_tools: List[ClientTool] = []
93+
94+
for tool in tools:
95+
if isinstance(tool, str) or isinstance(tool, dict):
96+
toolgroups.append(tool)
97+
else:
98+
client_tools.append(tool)
99+
100+
agent_config["toolgroups"] = toolgroups
101+
agent_config["client_tools"] = [tool.get_tool_definition() for tool in client_tools]
102+
103+
agent_config = AgentConfig(**agent_config)
104+
return agent_config
105+
51106

52107
class Agent:
53108
def __init__(
@@ -102,46 +157,18 @@ def __init__(
102157

103158
# Construct agent_config from parameters if not provided
104159
if agent_config is None:
105-
# Create a minimal valid AgentConfig with required fields
106-
if model is None or instructions is None:
107-
raise ValueError("Both 'model' and 'instructions' are required when agent_config is not provided")
108-
109-
agent_config = {
110-
"model": model,
111-
"instructions": instructions,
112-
"toolgroups": [],
113-
"client_tools": [],
114-
}
115-
116-
# Add optional parameters if provided
117-
if enable_session_persistence is not None:
118-
agent_config["enable_session_persistence"] = enable_session_persistence
119-
if max_infer_iters is not None:
120-
agent_config["max_infer_iters"] = max_infer_iters
121-
if input_shields is not None:
122-
agent_config["input_shields"] = input_shields
123-
if output_shields is not None:
124-
agent_config["output_shields"] = output_shields
125-
if response_format is not None:
126-
agent_config["response_format"] = response_format
127-
if sampling_params is not None:
128-
agent_config["sampling_params"] = sampling_params
129-
if tool_config is not None:
130-
agent_config["tool_config"] = tool_config
131-
if tools is not None:
132-
toolgroups: List[Toolgroup] = []
133-
client_tools: List[ClientTool] = []
134-
135-
for tool in tools:
136-
if isinstance(tool, str) or isinstance(tool, dict):
137-
toolgroups.append(tool)
138-
else:
139-
client_tools.append(tool)
140-
141-
agent_config["toolgroups"] = toolgroups
142-
agent_config["client_tools"] = [tool.get_tool_definition() for tool in client_tools]
143-
144-
agent_config = AgentConfig(**agent_config)
160+
agent_config = AgentUtils.get_agent_config(
161+
model=model,
162+
instructions=instructions,
163+
tools=tools,
164+
tool_config=tool_config,
165+
sampling_params=sampling_params,
166+
max_infer_iters=max_infer_iters,
167+
input_shields=input_shields,
168+
output_shields=output_shields,
169+
response_format=response_format,
170+
enable_session_persistence=enable_session_persistence,
171+
)
145172

146173
self.agent_config = agent_config
147174
self.agent_id = self._create_agent(agent_config)
@@ -420,7 +447,7 @@ async def _create_turn_streaming(
420447
yield chunk
421448
break
422449

423-
turn_id = AgentUtils.get_turn_id(chunk)
450+
turn_id = self._get_turn_id(chunk)
424451
if n_iter == 0:
425452
yield chunk
426453

0 commit comments

Comments
 (0)