diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py index 87badd46..504ef5e8 100644 --- a/src/llama_stack_client/lib/agents/agent.py +++ b/src/llama_stack_client/lib/agents/agent.py @@ -162,11 +162,11 @@ def _create_turn_streaming( while not is_turn_complete: is_turn_complete = True for chunk in turn_response: - tool_calls = self._get_tool_calls(chunk) if hasattr(chunk, "error"): yield chunk return - elif not tool_calls: + tool_calls = self._get_tool_calls(chunk) + if not tool_calls: yield chunk else: is_turn_complete = False diff --git a/src/llama_stack_client/lib/agents/client_tool.py b/src/llama_stack_client/lib/agents/client_tool.py index f672268d..6fc811e9 100644 --- a/src/llama_stack_client/lib/agents/client_tool.py +++ b/src/llama_stack_client/lib/agents/client_tool.py @@ -163,7 +163,8 @@ def get_params_definition(self) -> Dict[str, Parameter]: params[name] = Parameter( name=name, description=param_doc or f"Parameter {name}", - parameter_type=type_hint.__name__, + # Hack: litellm/openai expects "string" for str type + parameter_type=type_hint.__name__ if type_hint.__name__ != "str" else "string", default=(param.default if param.default != inspect.Parameter.empty else None), required=is_required, ) diff --git a/src/llama_stack_client/lib/agents/event_logger.py b/src/llama_stack_client/lib/agents/event_logger.py index 40a1d359..731c7b2f 100644 --- a/src/llama_stack_client/lib/agents/event_logger.py +++ b/src/llama_stack_client/lib/agents/event_logger.py @@ -63,7 +63,8 @@ def yield_printable_events(self, chunk: Any) -> Iterator[TurnStreamPrintableEven for printable_event in self._yield_printable_events(chunk, self.previous_event_type, self.previous_step_type): yield printable_event - self.previous_event_type, self.previous_step_type = self._get_event_type_step_type(chunk) + if not hasattr(chunk, "error"): + self.previous_event_type, self.previous_step_type = self._get_event_type_step_type(chunk) def _yield_printable_events( self, chunk: Any, previous_event_type: Optional[str] = None, previous_step_type: Optional[str] = None diff --git a/src/llama_stack_client/lib/agents/react/agent.py b/src/llama_stack_client/lib/agents/react/agent.py index bad7e46e..63f37493 100644 --- a/src/llama_stack_client/lib/agents/react/agent.py +++ b/src/llama_stack_client/lib/agents/react/agent.py @@ -3,29 +3,18 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict, Optional, Tuple +from typing import Optional, Tuple from llama_stack_client import LlamaStackClient from llama_stack_client.types.agent_create_params import AgentConfig -from pydantic import BaseModel + from ..agent import Agent from ..client_tool import ClientTool from ..tool_parser import ToolParser from .prompts import DEFAULT_REACT_AGENT_SYSTEM_PROMPT_TEMPLATE -from .tool_parser import ReActToolParser - - -class Action(BaseModel): - tool_name: str - tool_params: Dict[str, Any] - - -class ReActOutput(BaseModel): - thought: str - action: Optional[Action] = None - answer: Optional[str] = None +from .tool_parser import ReActToolParser, ReActOutput class ReActAgent(Agent): @@ -97,7 +86,7 @@ def get_tool_defs(): agent_config = custom_agent_config if json_response_format: - agent_config.response_format = { + agent_config["response_format"] = { "type": "json_schema", "json_schema": ReActOutput.model_json_schema(), } diff --git a/src/llama_stack_client/lib/agents/react/tool_parser.py b/src/llama_stack_client/lib/agents/react/tool_parser.py index e668d28d..9aaa30cc 100644 --- a/src/llama_stack_client/lib/agents/react/tool_parser.py +++ b/src/llama_stack_client/lib/agents/react/tool_parser.py @@ -5,7 +5,7 @@ # the root directory of this source tree. from pydantic import BaseModel, ValidationError -from typing import Dict, Any, Optional, List +from typing import Optional, List, Union from ..tool_parser import ToolParser from llama_stack_client.types.shared.completion_message import CompletionMessage from llama_stack_client.types.shared.tool_call import ToolCall @@ -13,15 +13,20 @@ import uuid +class Param(BaseModel): + name: str + value: Union[str, int, float, bool] + + class Action(BaseModel): tool_name: str - tool_params: Dict[str, Any] + tool_params: List[Param] class ReActOutput(BaseModel): thought: str - action: Optional[Action] = None - answer: Optional[str] = None + action: Optional[Action] + answer: Optional[str] class ReActToolParser(ToolParser): @@ -40,8 +45,9 @@ def get_tool_calls(self, output_message: CompletionMessage) -> List[ToolCall]: if react_output.action: tool_name = react_output.action.tool_name tool_params = react_output.action.tool_params + params = {param.name: param.value for param in tool_params} if tool_name and tool_params: call_id = str(uuid.uuid4()) - tool_calls = [ToolCall(call_id=call_id, tool_name=tool_name, arguments=tool_params)] + tool_calls = [ToolCall(call_id=call_id, tool_name=tool_name, arguments=params)] return tool_calls