Skip to content

Commit 08ab5df

Browse files
hardikjshahHardik Shah
andauthored
fix: React agent should be able to work with provided config (#146)
passing a custom_config was failing ``` python -m examples.agents.react_agent ``` --------- Co-authored-by: Hardik Shah <hjshah@fb.com>
1 parent c645726 commit 08ab5df

File tree

3 files changed

+30
-16
lines changed

3 files changed

+30
-16
lines changed

src/llama_stack_client/lib/agents/agent.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,15 @@
33
#
44
# This source code is licensed under the terms described in the LICENSE file in
55
# the root directory of this source tree.
6+
import uuid
7+
from datetime import datetime
68
from typing import Iterator, List, Optional, Tuple, Union
79

810
from llama_stack_client import LlamaStackClient
11+
912
from llama_stack_client.types import ToolResponseMessage, UserMessage
1013
from llama_stack_client.types.agent_create_params import AgentConfig
11-
from llama_stack_client.types.agents.turn import Turn
14+
from llama_stack_client.types.agents.turn import CompletionMessage, Turn
1215
from llama_stack_client.types.agents.turn_create_params import Document, Toolgroup
1316
from llama_stack_client.types.agents.turn_create_response import (
1417
AgentTurnResponseStreamChunk,
@@ -18,14 +21,12 @@
1821
AgentTurnResponseStepCompletePayload,
1922
)
2023
from llama_stack_client.types.shared.tool_call import ToolCall
21-
from llama_stack_client.types.agents.turn import CompletionMessage
22-
from .client_tool import ClientTool
23-
from .tool_parser import ToolParser
24-
from datetime import datetime
25-
import uuid
2624
from llama_stack_client.types.tool_execution_step import ToolExecutionStep
2725
from llama_stack_client.types.tool_response import ToolResponse
2826

27+
from .client_tool import ClientTool
28+
from .tool_parser import ToolParser
29+
2930
DEFAULT_MAX_ITER = 10
3031

3132

@@ -55,7 +56,7 @@ def _create_agent(self, agent_config: AgentConfig) -> int:
5556
self.agent_id = agentic_system_create_response.agent_id
5657
return self.agent_id
5758

58-
def create_session(self, session_name: str) -> int:
59+
def create_session(self, session_name: str) -> str:
5960
agentic_system_create_session_response = self.client.agents.session.create(
6061
agent_id=self.agent_id,
6162
session_name=session_name,

src/llama_stack_client/lib/agents/client_tool.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,19 @@
44
# This source code is licensed under the terms described in the LICENSE file in
55
# the root directory of this source tree.
66

7+
import inspect
78
import json
89
from abc import abstractmethod
9-
from typing import Callable, Dict, TypeVar, get_type_hints, Union, get_origin, get_args, List
10-
import inspect
10+
from typing import (
11+
Callable,
12+
Dict,
13+
get_args,
14+
get_origin,
15+
get_type_hints,
16+
List,
17+
TypeVar,
18+
Union,
19+
)
1120

1221
from llama_stack_client.types import Message, ToolResponseMessage
1322
from llama_stack_client.types.tool_def_param import Parameter, ToolDefParam
@@ -155,9 +164,10 @@ def get_params_definition(self) -> Dict[str, Parameter]:
155164
name=name,
156165
description=param_doc or f"Parameter {name}",
157166
parameter_type=type_hint.__name__,
158-
default=param.default if param.default != inspect.Parameter.empty else None,
167+
default=(param.default if param.default != inspect.Parameter.empty else None),
159168
required=is_required,
160169
)
170+
161171
return params
162172

163173
def run_impl(self, **kwargs):

src/llama_stack_client/lib/agents/react/agent.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,18 @@
33
#
44
# This source code is licensed under the terms described in the LICENSE file in
55
# the root directory of this source tree.
6+
from typing import Any, Dict, Optional, Tuple
7+
8+
from llama_stack_client import LlamaStackClient
9+
from llama_stack_client.types.agent_create_params import AgentConfig
610
from pydantic import BaseModel
7-
from typing import Dict, Any
11+
812
from ..agent import Agent
9-
from .tool_parser import ReActToolParser
13+
from ..client_tool import ClientTool
1014
from ..tool_parser import ToolParser
1115
from .prompts import DEFAULT_REACT_AGENT_SYSTEM_PROMPT_TEMPLATE
1216

13-
from typing import Tuple, Optional
14-
from llama_stack_client import LlamaStackClient
15-
from ..client_tool import ClientTool
16-
from llama_stack_client.types.agent_create_params import AgentConfig
17+
from .tool_parser import ReActToolParser
1718

1819

1920
class Action(BaseModel):
@@ -92,6 +93,8 @@ def get_tool_defs():
9293
output_shields=[],
9394
enable_session_persistence=False,
9495
)
96+
else:
97+
agent_config = custom_agent_config
9598

9699
if json_response_format:
97100
agent_config.response_format = {

0 commit comments

Comments
 (0)