diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py index 0a8ab226..3b7bcc7f 100644 --- a/src/llama_stack_client/lib/agents/agent.py +++ b/src/llama_stack_client/lib/agents/agent.py @@ -3,12 +3,15 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import uuid +from datetime import datetime from typing import Iterator, List, Optional, Tuple, Union from llama_stack_client import LlamaStackClient + from llama_stack_client.types import ToolResponseMessage, UserMessage from llama_stack_client.types.agent_create_params import AgentConfig -from llama_stack_client.types.agents.turn import Turn +from llama_stack_client.types.agents.turn import CompletionMessage, Turn from llama_stack_client.types.agents.turn_create_params import Document, Toolgroup from llama_stack_client.types.agents.turn_create_response import ( AgentTurnResponseStreamChunk, @@ -18,14 +21,12 @@ AgentTurnResponseStepCompletePayload, ) from llama_stack_client.types.shared.tool_call import ToolCall -from llama_stack_client.types.agents.turn import CompletionMessage -from .client_tool import ClientTool -from .tool_parser import ToolParser -from datetime import datetime -import uuid from llama_stack_client.types.tool_execution_step import ToolExecutionStep from llama_stack_client.types.tool_response import ToolResponse +from .client_tool import ClientTool +from .tool_parser import ToolParser + DEFAULT_MAX_ITER = 10 @@ -55,7 +56,7 @@ def _create_agent(self, agent_config: AgentConfig) -> int: self.agent_id = agentic_system_create_response.agent_id return self.agent_id - def create_session(self, session_name: str) -> int: + def create_session(self, session_name: str) -> str: agentic_system_create_session_response = self.client.agents.session.create( agent_id=self.agent_id, session_name=session_name, diff --git a/src/llama_stack_client/lib/agents/client_tool.py b/src/llama_stack_client/lib/agents/client_tool.py index 9ffec4a0..f672268d 100644 --- a/src/llama_stack_client/lib/agents/client_tool.py +++ b/src/llama_stack_client/lib/agents/client_tool.py @@ -4,10 +4,19 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import inspect import json from abc import abstractmethod -from typing import Callable, Dict, TypeVar, get_type_hints, Union, get_origin, get_args, List -import inspect +from typing import ( + Callable, + Dict, + get_args, + get_origin, + get_type_hints, + List, + TypeVar, + Union, +) from llama_stack_client.types import Message, ToolResponseMessage from llama_stack_client.types.tool_def_param import Parameter, ToolDefParam @@ -155,9 +164,10 @@ def get_params_definition(self) -> Dict[str, Parameter]: name=name, description=param_doc or f"Parameter {name}", parameter_type=type_hint.__name__, - default=param.default if param.default != inspect.Parameter.empty else None, + default=(param.default if param.default != inspect.Parameter.empty else None), required=is_required, ) + return params def run_impl(self, **kwargs): diff --git a/src/llama_stack_client/lib/agents/react/agent.py b/src/llama_stack_client/lib/agents/react/agent.py index fafca9dc..bad7e46e 100644 --- a/src/llama_stack_client/lib/agents/react/agent.py +++ b/src/llama_stack_client/lib/agents/react/agent.py @@ -3,17 +3,18 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from typing import Any, Dict, Optional, Tuple + +from llama_stack_client import LlamaStackClient +from llama_stack_client.types.agent_create_params import AgentConfig from pydantic import BaseModel -from typing import Dict, Any + from ..agent import Agent -from .tool_parser import ReActToolParser +from ..client_tool import ClientTool from ..tool_parser import ToolParser from .prompts import DEFAULT_REACT_AGENT_SYSTEM_PROMPT_TEMPLATE -from typing import Tuple, Optional -from llama_stack_client import LlamaStackClient -from ..client_tool import ClientTool -from llama_stack_client.types.agent_create_params import AgentConfig +from .tool_parser import ReActToolParser class Action(BaseModel): @@ -92,6 +93,8 @@ def get_tool_defs(): output_shields=[], enable_session_persistence=False, ) + else: + agent_config = custom_agent_config if json_response_format: agent_config.response_format = {