diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py index 56d8907b..315a0641 100644 --- a/src/llama_stack_client/lib/agents/agent.py +++ b/src/llama_stack_client/lib/agents/agent.py @@ -60,8 +60,7 @@ def _process_chunk(self, chunk: AgentTurnResponseStreamChunk) -> None: message = chunk.event.payload.turn.output_message if self.output_parser: - parsed_message = self.output_parser.parse(message) - message = parsed_message + self.output_parser.parse(message) def _has_tool_call(self, chunk: AgentTurnResponseStreamChunk) -> bool: if chunk.event.payload.event_type != "turn_complete": diff --git a/src/llama_stack_client/lib/agents/output_parser.py b/src/llama_stack_client/lib/agents/output_parser.py index 1097d6d5..20c8468e 100644 --- a/src/llama_stack_client/lib/agents/output_parser.py +++ b/src/llama_stack_client/lib/agents/output_parser.py @@ -39,10 +39,10 @@ def parse(self, output_message: CompletionMessage) -> CompletionMessage: Args: output_message (CompletionMessage): The response message from agent turn - Returns: - CompletionMessage: The processed/transformed response message + Returns: None + Modifies the output_message in place """ @abstractmethod - def parse(self, output_message: CompletionMessage) -> CompletionMessage: + def parse(self, output_message: CompletionMessage) -> None: raise NotImplementedError diff --git a/src/llama_stack_client/lib/agents/react/output_parser.py b/src/llama_stack_client/lib/agents/react/output_parser.py index 6e4861a9..71177a6f 100644 --- a/src/llama_stack_client/lib/agents/react/output_parser.py +++ b/src/llama_stack_client/lib/agents/react/output_parser.py @@ -25,16 +25,16 @@ class ReActOutput(BaseModel): class ReActOutputParser(OutputParser): - def parse(self, output_message: CompletionMessage) -> CompletionMessage: + def parse(self, output_message: CompletionMessage) -> None: response_text = str(output_message.content) try: react_output = ReActOutput.model_validate_json(response_text) except ValidationError as e: print(f"Error parsing action: {e}") - return output_message + return if react_output.answer: - return output_message + return if react_output.action: tool_name = react_output.action.tool_name @@ -43,4 +43,4 @@ def parse(self, output_message: CompletionMessage) -> CompletionMessage: call_id = str(uuid.uuid4()) output_message.tool_calls = [ToolCall(call_id=call_id, tool_name=tool_name, arguments=tool_params)] - return output_message + return