From 10c34cab0593a77878609dfcb628a295a42dc8f0 Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Thu, 22 May 2025 10:37:55 +0000 Subject: [PATCH] feat: Add multi-agent example using a2a SDK This commit introduces a new example demonstrating a multi-agent system built with the a2a SDK. The system consists of: 1. HostAgent: The main orchestrator agent. 2. PlanAgent: An agent responsible for generating a plan. 3. SearchAgent: An agent responsible for performing searches. 4. ReportAgent: An agent responsible for compiling a report. The HostAgent coordinates these agents to process a task. It uses the A2AClient to communicate with the other agents over HTTP. The example includes: - Implementation of each agent. - A `main.py` script to launch all agents, each on a separate port, using multiprocessing. - A `test_client.py` script to send a task to the HostAgent and receive the final report, demonstrating an end-to-end workflow. - Unit tests for each agent, ensuring individual components function correctly. These tests cover success cases, edge cases, and error handling, with mocking employed for HostAgent's dependencies on other agents. This example provides a practical demonstration of how to build and orchestrate multiple agents using the a2a SDK. --- examples/multi_agent_system/__init__.py | 0 examples/multi_agent_system/host_agent.py | 239 ++++++++++++++++ examples/multi_agent_system/main.py | 242 +++++++++++++++++ examples/multi_agent_system/plan_agent.py | 168 ++++++++++++ examples/multi_agent_system/report_agent.py | 142 ++++++++++ examples/multi_agent_system/search_agent.py | 142 ++++++++++ examples/multi_agent_system/test_client.py | 127 +++++++++ .../tests/test_host_agent.py | 257 ++++++++++++++++++ .../tests/test_plan_agent.py | 151 ++++++++++ .../tests/test_report_agent.py | 100 +++++++ .../tests/test_search_agent.py | 98 +++++++ 11 files changed, 1666 insertions(+) create mode 100644 examples/multi_agent_system/__init__.py create mode 100644 examples/multi_agent_system/host_agent.py create mode 100644 examples/multi_agent_system/main.py create mode 100644 examples/multi_agent_system/plan_agent.py create mode 100644 examples/multi_agent_system/report_agent.py create mode 100644 examples/multi_agent_system/search_agent.py create mode 100644 examples/multi_agent_system/test_client.py create mode 100644 examples/multi_agent_system/tests/test_host_agent.py create mode 100644 examples/multi_agent_system/tests/test_plan_agent.py create mode 100644 examples/multi_agent_system/tests/test_report_agent.py create mode 100644 examples/multi_agent_system/tests/test_search_agent.py diff --git a/examples/multi_agent_system/__init__.py b/examples/multi_agent_system/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/multi_agent_system/host_agent.py b/examples/multi_agent_system/host_agent.py new file mode 100644 index 00000000..32fa30fc --- /dev/null +++ b/examples/multi_agent_system/host_agent.py @@ -0,0 +1,239 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import httpx +import uuid # For generating unique IDs in the test block + +# Core imports from the a2a framework +from a2a.client.client import A2AClient, A2AClientTaskInfo +from a2a.server.agent_execution.agent_executor import AgentExecutor +from a2a.types import Message, Part, Role, TextPart # Core types +from a2a.utils.message import new_agent_text_message, get_message_text, new_user_text_message # Message utilities + + +class HostAgent(AgentExecutor): + """ + An agent that orchestrates calls to PlanAgent, SearchAgent, and ReportAgent + to process a user's task. + """ + + def __init__( + self, + plan_agent_url: str, + search_agent_url: str, + report_agent_url: str, + name: str = "HostAgent", + ): + super().__init__(name=name) + self.plan_agent_url = plan_agent_url + self.search_agent_url = search_agent_url + self.report_agent_url = report_agent_url + # A2AClients will be initialized within execute, along with httpx.AsyncClient + + async def _call_sub_agent( + self, + client: A2AClient, + agent_name: str, # For logging/error messages + input_text: str, + original_message: Message, # To carry over contextId, taskId + ) -> str: + """Helper to call a sub-agent and extract its text response.""" + # Create a message to send to the sub-agent. + # It's a "user" message from the perspective of the sub-agent. + # However, the A2AClient might wrap this in a Task structure. + # The A2AClient's execute_agent_task expects a list of Message objects as input. + sub_agent_input_message = new_user_text_message( # HostAgent acts as a "user" to sub-agents + text=input_text, + context_id=original_message.contextId, # Propagate context + task_id=original_message.taskId, # Propagate task + ) + + try: + # The A2AClient.execute_agent_task expects a list of Messages + # and returns an A2AClientTaskInfo object. + task_info: A2AClientTaskInfo = await client.execute_agent_task( + messages=[sub_agent_input_message] + ) + + # The final message from the sub-agent is often in task_info.result.messages + if task_info.result and task_info.result.messages: + # Assuming the last message is the agent's response + agent_response_message = task_info.result.messages[-1] + if agent_response_message.role == Role.AGENT: + return get_message_text(agent_response_message) + else: + return f"Error: {agent_name} did not respond with an AGENT message." + else: + return f"Error: No response messages from {agent_name}." + + except Exception as e: + # Log the exception or handle it more gracefully + print(f"Error calling {agent_name} at {client._server_url}: {e}") + return f"Error: Could not get response from {agent_name} due to {type(e).__name__}." + + + async def execute(self, message: Message) -> Message: + """ + Orchestrates the sub-agents to process the task. + """ + task_description = get_message_text(message) + if not task_description: + return new_agent_text_message( + text="Error: HostAgent received a message with no task description.", + context_id=message.contextId, + task_id=message.taskId, + ) + + final_report = "Error: Orchestration failed." # Default error message + + async with httpx.AsyncClient() as http_client: + plan_agent_client = A2AClient(server_url=self.plan_agent_url, http_client=http_client) + search_agent_client = A2AClient(server_url=self.search_agent_url, http_client=http_client) + report_agent_client = A2AClient(server_url=self.report_agent_url, http_client=http_client) + + # 1. Call PlanAgent + plan = await self._call_sub_agent( + plan_agent_client, "PlanAgent", task_description, message + ) + if plan.startswith("Error:"): + return new_agent_text_message(text=plan, context_id=message.contextId, task_id=message.taskId) + + # 2. Call SearchAgent + # For simplicity, using the original task description as the search query. + # A more advanced version might parse the plan to create specific queries. + search_query = task_description + search_results = await self._call_sub_agent( + search_agent_client, "SearchAgent", search_query, message + ) + if search_results.startswith("Error:"): + # Proceed with reporting what we have, or return error + combined_input_for_report = f"Plan:\n{plan}\n\nSearch Results: Failed - {search_results}" + else: + combined_input_for_report = f"Plan:\n{plan}\n\nSearch Results:\n{search_results}" + + # 3. Call ReportAgent + final_report = await self._call_sub_agent( + report_agent_client, "ReportAgent", combined_input_for_report, message + ) + # If final_report itself is an error string from _call_sub_agent, it will be returned. + + # Return the final report from ReportAgent + return new_agent_text_message( + text=final_report, + context_id=message.contextId, + task_id=message.taskId, + ) + + async def cancel(self, interaction_id: str) -> None: + """ + Cancels an ongoing task. + For HostAgent, this would ideally involve propagating cancellations to sub-agents. + """ + print(f"Cancellation requested for interaction/context/task '{interaction_id}' in {self.name}.") + # TODO: Implement cancellation propagation to sub-agents if their A2AClient interface supports it. + # For now, this is a placeholder. + raise NotImplementedError( + "HostAgent cancellation requires propagation to sub-agents, which is not yet implemented." + ) + + +if __name__ == "__main__": + # This example is more complex to run directly as it involves HTTP calls + # to other agents. For a simple test, we would mock A2AClient. + + # --- Mocking section --- + class MockA2AClient: + def __init__(self, server_url: str, http_client=None): + self._server_url = server_url + self.http_client = http_client # Keep httpx.AsyncClient for realism if used by HostAgent + + async def execute_agent_task(self, messages: list[Message]) -> A2AClientTaskInfo: + input_text = get_message_text(messages[0]) + # Simulate responses based on the agent URL or input + response_text = "" + if "plan" in self._server_url: + response_text = f"Plan for '{input_text}': Step 1, Step 2." + elif "search" in self._server_url: + response_text = f"Search results for '{input_text}': Result A, Result B." + elif "report" in self._server_url: + response_text = f"Report based on: {input_text}" + + # Simulate A2AClientTaskInfo structure + response_message = new_agent_text_message( + text=response_text, + context_id=messages[0].contextId, + task_id=messages[0].taskId + ) + # Simplified TaskResult and A2AClientTaskInfo + class MockTaskResult: + def __init__(self, messages): + self.messages = messages + class MockA2AClientTaskInfo(A2AClientTaskInfo): + def __init__(self, messages): + super().__init__(task_id="", status="", messages=messages, result=MockTaskResult(messages=messages)) + + return MockA2AClientTaskInfo(messages=[response_message]) + + # Store original and apply mock + original_a2a_client = A2AClient + A2AClient = MockA2AClient # type: ignore + + # Mock AgentExecutor for HostAgent itself + class MockAgentExecutor: + def __init__(self, name: str): + self.name = name + original_agent_executor = AgentExecutor + AgentExecutor = MockAgentExecutor # type: ignore + # --- End Mocking section --- + + async def main_test(): + # Dummy URLs for the mocked clients + plan_url = "http://mockplanagent.test" + search_url = "http://mocksearchagent.test" + report_url = "http://mockreportagent.test" + + host_agent = HostAgent( + plan_agent_url=plan_url, + search_agent_url=search_url, + report_agent_url=report_url, + ) + + user_task = "Research benefits of async programming and report them." + test_message = new_user_text_message( + text=user_task, + context_id=str(uuid.uuid4()), + task_id=str(uuid.uuid4()) + ) + + print(f"HostAgent processing task: '{user_task}'") + final_response = await host_agent.execute(test_message) + + print("\nHostAgent Final Response:") + print(get_message_text(final_response)) + + # Test cancellation (will raise NotImplementedError as per implementation) + try: + print("\nTesting HostAgent cancellation...") + await host_agent.cancel(test_message.contextId) + except NotImplementedError as e: + print(f"Cancellation test: Caught expected error - {e}") + + try: + asyncio.run(main_test()) + finally: + # Restore original classes + A2AClient = original_a2a_client # type: ignore + AgentExecutor = original_agent_executor # type: ignore + print("\nRestored A2AClient and AgentExecutor. HostAgent example finished.") diff --git a/examples/multi_agent_system/main.py b/examples/multi_agent_system/main.py new file mode 100644 index 00000000..bb773553 --- /dev/null +++ b/examples/multi_agent_system/main.py @@ -0,0 +1,242 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import multiprocessing +import uvicorn +import signal # For graceful shutdown +import time # For process join timeout + +from a2a.server.apps import A2AStarletteApplication +from a2a.server.request_handlers import DefaultRequestHandler +from a2a.server.tasks import InMemoryTaskStore +from a2a.types import AgentCard, AgentSkill, AgentParameter, AgentCapabilities, ParameterType + +# Import the agent executor classes +from examples.multi_agent_system.host_agent import HostAgent +from examples.multi_agent_system.plan_agent import PlanAgent +from examples.multi_agent_system.search_agent import SearchAgent +from examples.multi_agent_system.report_agent import ReportAgent + +# Define base URL and ports +BASE_URL = "http://localhost" +HOST_AGENT_PORT = 8000 +PLAN_AGENT_PORT = 8001 +SEARCH_AGENT_PORT = 8002 +REPORT_AGENT_PORT = 8003 + +# Agent URLs +HOST_AGENT_URL = f"{BASE_URL}:{HOST_AGENT_PORT}" +PLAN_AGENT_URL = f"{BASE_URL}:{PLAN_AGENT_PORT}" +SEARCH_AGENT_URL = f"{BASE_URL}:{SEARCH_AGENT_PORT}" +REPORT_AGENT_URL = f"{BASE_URL}:{REPORT_AGENT_PORT}" + +# Agent Cards Definition +# Common parameter for tasks +task_param = AgentParameter(name="task_description", type=ParameterType.TEXT, description="The task to process.") +query_param = AgentParameter(name="search_query", type=ParameterType.TEXT, description="The query to search for.") +data_param = AgentParameter(name="combined_data", type=ParameterType.TEXT, description="Data to include in the report.") + + +host_agent_card = AgentCard( + id="host-agent-001", # Unique ID + name="Host Orchestrator Agent", + description="Orchestrates planning, searching, and reporting agents.", + icon_uri="https://storage.googleapis.com/agentsea-public-assets/agent-icons/orchestrator.png", + capabilities=AgentCapabilities(skills=[ + AgentSkill( + id="orchestrate_task_v1", # Unique skill ID + description="Processes a complex task by coordinating with other agents.", + parameters=[task_param], + target_url=f"{HOST_AGENT_URL}/execute", + ) + ]), + trust_level=1, + version="0.1.0" +) + +plan_agent_card = AgentCard( + id="plan-agent-001", + name="Planning Agent", + description="Generates a plan for a given task.", + icon_uri="https://storage.googleapis.com/agentsea-public-assets/agent-icons/planner.png", + capabilities=AgentCapabilities(skills=[ + AgentSkill( + id="generate_plan_v1", + description="Creates a step-by-step plan for a task description.", + parameters=[task_param], + target_url=f"{PLAN_AGENT_URL}/execute", + ) + ]), + trust_level=1, + version="0.1.0" +) + +search_agent_card = AgentCard( + id="search-agent-001", + name="Search Agent", + description="Performs searches based on a query.", + icon_uri="https://storage.googleapis.com/agentsea-public-assets/agent-icons/search.png", + capabilities=AgentCapabilities(skills=[ + AgentSkill( + id="perform_search_v1", + description="Searches for information based on a query string.", + parameters=[query_param], + target_url=f"{SEARCH_AGENT_URL}/execute", + ) + ]), + trust_level=1, + version="0.1.0" +) + +report_agent_card = AgentCard( + id="report-agent-001", + name="Reporting Agent", + description="Generates a report from combined data.", + icon_uri="https://storage.googleapis.com/agentsea-public-assets/agent-icons/reporter.png", + capabilities=AgentCapabilities(skills=[ + AgentSkill( + id="generate_report_v1", + description="Creates a formatted report from input data.", + parameters=[data_param], + target_url=f"{REPORT_AGENT_URL}/execute", + ) + ]), + trust_level=1, + version="0.1.0" +) + + +def run_agent_server(agent_executor_class, agent_card, port, agent_urls_for_host=None): + """ + Sets up and runs a single agent server. + agent_urls_for_host is a dict required only for HostAgent. + """ + print(f"Configuring {agent_card.name} on port {port}...") + + if agent_executor_class == HostAgent: + if agent_urls_for_host is None: + raise ValueError("HostAgent requires agent_urls_for_host (plan, search, report URLs)") + agent_executor = HostAgent( + plan_agent_url=agent_urls_for_host["plan"], + search_agent_url=agent_urls_for_host["search"], + report_agent_url=agent_urls_for_host["report"], + name=agent_card.name, + ) + else: + agent_executor = agent_executor_class(name=agent_card.name) + + task_store = InMemoryTaskStore() + # Ensure agent_id is passed to DefaultRequestHandler, as it's required + request_handler = DefaultRequestHandler( + agent_executor=agent_executor, + task_store=task_store, + agent_id=agent_card.id, + ) + + app = A2AStarletteApplication( + agent_card=agent_card, + request_handler=request_handler, + root_path="", + ) + + # uvicorn.run can be problematic with multiprocessing on some platforms/OS versions + # especially with signal handling. For simplicity, we'll proceed, but in a production + # setup, alternatives like gunicorn or running uvicorn programmatically with different + # loop policies might be needed. + print(f"Starting {agent_card.name} server on {BASE_URL}:{port}...") + uvicorn.run(app, host="localhost", port=port, log_level="info") + + +# Global list to keep track of processes for signal handling +processes = [] + +def signal_handler(sig, frame): + print(f"\nCaught signal {sig}, initiating graceful shutdown...") + for p_info in processes: + print(f"Terminating process {p_info['name']} (PID: {p_info['process'].pid})...") + if p_info['process'].is_alive(): + p_info['process'].terminate() # Send SIGTERM + + # Wait for processes to terminate + for p_info in processes: + try: + p_info['process'].join(timeout=10) # Wait for 10 seconds + if p_info['process'].is_alive(): + print(f"Process {p_info['name']} (PID: {p_info['process'].pid}) did not terminate gracefully, killing.") + p_info['process'].kill() # Send SIGKILL + else: + print(f"Process {p_info['name']} (PID: {p_info['process'].pid}) terminated.") + except Exception as e: + print(f"Error during termination of {p_info['name']}: {e}") + + print("All agent processes have been dealt with. Exiting.") + exit(0) + + +if __name__ == "__main__": + print("Starting multi-agent system...") + + # Register signal handlers for graceful shutdown + signal.signal(signal.SIGINT, signal_handler) # Ctrl+C + signal.signal(signal.SIGTERM, signal_handler) # kill command + + host_agent_sub_urls = { + "plan": PLAN_AGENT_URL, + "search": SEARCH_AGENT_URL, + "report": REPORT_AGENT_URL, + } + + agents_config = [ + (HostAgent, host_agent_card, HOST_AGENT_PORT, host_agent_sub_urls, "HostAgent"), + (PlanAgent, plan_agent_card, PLAN_AGENT_PORT, None, "PlanAgent"), + (SearchAgent, search_agent_card, SEARCH_AGENT_PORT, None, "SearchAgent"), + (ReportAgent, report_agent_card, REPORT_AGENT_PORT, None, "ReportAgent"), + ] + + # Clear global processes list before starting new ones + processes.clear() + + for agent_class, card, port, sub_urls, name_for_logging in agents_config: + process = multiprocessing.Process( + target=run_agent_server, + args=(agent_class, card, port, sub_urls) + ) + processes.append({"process": process, "name": name_for_logging, "card": card}) + process.start() + print(f"Launched {card.name} process (PID: {process.pid}).") + + print("All agent servers launched. System is running. Press Ctrl+C to stop.") + + # Keep the main process alive until a signal is received + try: + while True: + time.sleep(1) # Keep main thread alive to handle signals + # Optionally, check if processes are alive and restart if needed (more complex) + all_stopped = True + for p_info in processes: + if p_info['process'].is_alive(): + all_stopped = False + break + if all_stopped and processes: # if processes list is not empty and all are stopped + print("All agent processes seem to have stopped unexpectedly. Exiting main.") + break + except KeyboardInterrupt: # Should be caught by signal handler, but as a fallback + print("KeyboardInterrupt in main loop, initiating shutdown via signal handler logic...") + signal_handler(signal.SIGINT, None) + finally: + # Ensure cleanup if loop exits for reasons other than signals handled by signal_handler + if any(p_info['process'].is_alive() for p_info in processes): + print("Main loop exited, ensuring processes are terminated...") + signal_handler(signal.SIGTERM, None) # Trigger cleanup + print("Multi-agent system main process finished.") diff --git a/examples/multi_agent_system/plan_agent.py b/examples/multi_agent_system/plan_agent.py new file mode 100644 index 00000000..a7f3e190 --- /dev/null +++ b/examples/multi_agent_system/plan_agent.py @@ -0,0 +1,168 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import uuid # For generating unique IDs if needed for context/task + +# Corrected imports based on project structure +from a2a.server.agent_execution.agent_executor import AgentExecutor +from a2a.types import Message, Part, Role, TextPart # Core types +from a2a.utils.message import new_agent_text_message, get_message_text # Message utilities + + +class PlanAgent(AgentExecutor): + """ + An agent that receives a task description and creates a plan. + It uses the a2a.types.Message format. + """ + + def __init__(self, name: str = "PlanAgent"): + super().__init__(name=name) + + async def execute(self, message: Message) -> Message: + """ + Executes the planning task. + + Args: + message: A Message object, expected to contain text parts with the task description. + + Returns: + A new Message object (from the agent) containing the generated plan as text. + """ + # Check if the message is from the user and contains text + # (Simplified check: actual role check might be more complex depending on system design) + if message.role != Role.USER: # Assuming plan requests come from USER role + # Or handle messages from other agents if applicable + print(f"Warning: {self.name} received a message not from USER role, but from {message.role}.") + # Depending on strictness, could return an error message here. + + task_description = get_message_text(message) + + if not task_description: + # Return an error message if no text could be extracted + return new_agent_text_message( + text="Error: PlanAgent received a message with no text content.", + context_id=message.contextId, + task_id=message.taskId, + ) + + # Create a simple plan. + plan = [ + f"Step 1: Understand the task: '{task_description}'", + "Step 2: Identify key components for the task.", + "Step 3: Break down components into actionable steps.", + "Step 4: Sequence the steps logically.", + "Step 5: Output the plan.", + ] + plan_content = "\n".join(plan) + + # Create an agent message containing the plan + # The new_agent_text_message sets role=Role.agent automatically. + response_message = new_agent_text_message( + text=plan_content, + context_id=message.contextId, # Carry over context ID + task_id=message.taskId, # Carry over task ID + ) + # Recipient is implicitly the system/user that sent the original message. + return response_message + + async def cancel(self, interaction_id: str) -> None: + """ + Cancels an ongoing task. + For PlanAgent, execute is quick. This fulfills the AgentExecutor interface. + The interaction_id could correspond to a context_id or task_id. + """ + print(f"Cancellation requested for interaction/context/task '{interaction_id}' in {self.name}.") + # In a real scenario, you might try to find tasks associated with interaction_id + # and stop them if they are long-running. + # For now, we'll consider it a successful no-op as the execute() is not long-running. + # super().cancel() might not exist or have a different signature in the base AgentExecutor. + # If the base class `AgentExecutor` from `a2a` has a specific `cancel` to be called: + # await super().cancel(interaction_id) + # For now, raising NotImplementedError if specific cancel logic is expected but not implemented. + # However, a simple acknowledgment is also fine for simple agents. + # Let's assume the base AgentExecutor doesn't have a cancel to be called or it's handled by the framework. + pass # Or raise NotImplementedError("PlanAgent does not support explicit task cancellation.") + + +if __name__ == "__main__": + # Example of how to use the PlanAgent (for testing) + # This requires the a2a types and utils to be available in PYTHONPATH. + # Also, the AgentExecutor base class needs to be correctly defined. + + # Mock necessary parts of a2a for the main() to be runnable if a2a is not fully set up + # This is a simplified mock for local testing of this file only. + class MockAgentExecutor: + def __init__(self, name: str): + self.name = name + # async def cancel(self, interaction_id: str): # Base cancel, if any + # print(f"MockAgentExecutor: Cancel called for {interaction_id}") + # pass + + # Replace the actual AgentExecutor with the mock for this test script + # This is a common technique for unit testing or examples when the full environment isn't available. + # However, this means we are not testing the *actual* base class behavior here. + # For true integration, the script should run within the project's environment. + + # To make this example runnable without modifying the original AgentExecutor line: + # We would need to ensure a2a.server.agent_execution.agent_executor.AgentExecutor is mockable + # or the PYTHONPATH is set up. For now, let's assume the imports work. + # If not, one would typically run this via a test runner that handles paths. + + original_agent_executor = AgentExecutor # Save original + AgentExecutor = MockAgentExecutor # Replace with mock for this test block + + async def main_test(): + plan_agent = PlanAgent() + + # Simulate an incoming user message + user_task_description = "Develop a new feature for the user authentication module." + # Create a Message object similar to how the system might provide it + # (messageId, taskId, contextId would usually be generated by the system) + test_message = Message( + role=Role.USER, # Message from the user + parts=[Part(root=TextPart(text=user_task_description))], + messageId=str(uuid.uuid4()), + taskId=str(uuid.uuid4()), + contextId=str(uuid.uuid4()) + ) + + print(f"Sending task to PlanAgent: '{user_task_description}'") + response_message = await plan_agent.execute(test_message) + + print(f"\nPlan Agent Response (Role: {response_message.role}):") + print(get_message_text(response_message)) + + # Test cancel + print(f"\nRequesting cancellation for contextId: {test_message.contextId}") + await plan_agent.cancel(test_message.contextId) + + # Test message with no text + empty_message = Message( + role=Role.USER, + parts=[], # No parts, or non-TextPart parts + messageId=str(uuid.uuid4()), + taskId=str(uuid.uuid4()), + contextId=str(uuid.uuid4()) + ) + print(f"\nSending empty message to PlanAgent...") + error_response = await plan_agent.execute(empty_message) + print(f"Plan Agent Error Response:\n{get_message_text(error_response)}") + + + import asyncio + try: + asyncio.run(main_test()) + finally: + AgentExecutor = original_agent_executor # Restore original + print("\nRestored AgentExecutor. Example finished.") diff --git a/examples/multi_agent_system/report_agent.py b/examples/multi_agent_system/report_agent.py new file mode 100644 index 00000000..01be50bc --- /dev/null +++ b/examples/multi_agent_system/report_agent.py @@ -0,0 +1,142 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import uuid # For generating unique IDs in the test block + +# Core imports from the a2a framework +from a2a.server.agent_execution.agent_executor import AgentExecutor +from a2a.types import Message, Part, Role, TextPart # Core types +from a2a.utils.message import new_agent_text_message, get_message_text # Message utilities + + +class ReportAgent(AgentExecutor): + """ + An agent that receives combined data (e.g., plan and search results) + and generates a simple report. + It uses the a2a.types.Message format. + """ + + def __init__(self, name: str = "ReportAgent"): + super().__init__(name=name) + + async def execute(self, message: Message) -> Message: + """ + Executes the report generation task. + + Args: + message: A Message object, expected to contain text parts with the + combined plan and search results. + + Returns: + A new Message object (from the agent) containing the generated report as text. + """ + + combined_input_text = get_message_text(message) + + if not combined_input_text: + # Return an error message if no text input could be extracted + return new_agent_text_message( + text="Error: ReportAgent received a message with no content to report.", + context_id=message.contextId, + task_id=message.taskId, + ) + + # Generate a simple report. + # In a real scenario, this might involve more sophisticated formatting, + # summarization, or data extraction. + report_content = f"--- Combined Report ---\n\n" + report_content += "Processed Input:\n" + report_content += "---------------------\n" + report_content += combined_input_text + report_content += "\n---------------------\n" + report_content += "End of Report.\n" + + + # Create an agent message containing the report + response_message = new_agent_text_message( + text=report_content, + context_id=message.contextId, # Carry over context ID + task_id=message.taskId, # Carry over task ID + ) + return response_message + + async def cancel(self, interaction_id: str) -> None: + """ + Cancels an ongoing task. + For ReportAgent, execute is quick. This fulfills the AgentExecutor interface. + The interaction_id could correspond to a context_id or task_id. + """ + print(f"Cancellation requested for interaction/context/task '{interaction_id}' in {self.name}.") + # As execute() is not long-running, no specific cancellation logic is implemented. + # raise NotImplementedError("ReportAgent does not support explicit task cancellation.") + pass # Acknowledging the request is sufficient for this simple agent. + + +if __name__ == "__main__": + # Example of how to use the ReportAgent (for testing) + + # Mock AgentExecutor for the main() to be runnable if a2a is not fully set up + class MockAgentExecutor: + def __init__(self, name: str): + self.name = name + # async def cancel(self, interaction_id: str): + # pass + + original_agent_executor = AgentExecutor + AgentExecutor = MockAgentExecutor + + async def main_test(): + report_agent = ReportAgent() + + # Simulate an incoming message with combined plan and search results + simulated_plan = "Step 1: Query for X\nStep 2: Analyze Y" + simulated_search_results = "Result 1 for X\nResult 2 for Y" + combined_data = f"Plan:\n{simulated_plan}\n\nSearch Results:\n{simulated_search_results}" + + test_message = Message( + role=Role.USER, # Or could be from another agent (e.g., HostAgent) + parts=[Part(root=TextPart(text=combined_data))], + messageId=str(uuid.uuid4()), + taskId=str(uuid.uuid4()), + contextId=str(uuid.uuid4()) + ) + + print(f"Sending data to ReportAgent:\n'{combined_data}'") + response_message = await report_agent.execute(test_message) + + print(f"\nReport Agent Response (Role: {response_message.role}):") + print(get_message_text(response_message)) + + # Test cancel + print(f"\nRequesting cancellation for contextId: {test_message.contextId}") + await report_agent.cancel(test_message.contextId) + + # Test message with no content + empty_message = Message( + role=Role.USER, + parts=[], + messageId=str(uuid.uuid4()), + taskId=str(uuid.uuid4()), + contextId=str(uuid.uuid4()) + ) + print(f"\nSending empty message to ReportAgent...") + error_response = await report_agent.execute(empty_message) + print(f"Report Agent Error Response:\n{get_message_text(error_response)}") + + import asyncio + try: + asyncio.run(main_test()) + finally: + AgentExecutor = original_agent_executor # Restore original + print("\nRestored AgentExecutor. ReportAgent example finished.") diff --git a/examples/multi_agent_system/search_agent.py b/examples/multi_agent_system/search_agent.py new file mode 100644 index 00000000..9e747b02 --- /dev/null +++ b/examples/multi_agent_system/search_agent.py @@ -0,0 +1,142 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import uuid # For generating unique IDs in the test block + +# Core imports from the a2a framework +from a2a.server.agent_execution.agent_executor import AgentExecutor +from a2a.types import Message, Part, Role, TextPart # Core types +from a2a.utils.message import new_agent_text_message, get_message_text # Message utilities + + +class SearchAgent(AgentExecutor): + """ + An agent that receives a search query and returns dummy search results. + It uses the a2a.types.Message format. + """ + + def __init__(self, name: str = "SearchAgent"): + super().__init__(name=name) + + async def execute(self, message: Message) -> Message: + """ + Executes the search task. + + Args: + message: A Message object, expected to contain text parts with the search query. + + Returns: + A new Message object (from the agent) containing the dummy search results as text. + """ + # Assuming search requests come from a USER role or another agent that has processed a plan. + # For this example, we won't be too strict on message.role. + + search_query = get_message_text(message) + + if not search_query: + # Return an error message if no text query could be extracted + return new_agent_text_message( + text="Error: SearchAgent received a message with no search query.", + context_id=message.contextId, + task_id=message.taskId, + ) + + # Create dummy search results. + # In a real scenario, this would involve calling a search API or database. + dummy_results = [ + f"https://example.com/search?q={search_query.replace(' ', '+')}&result=1", + f"https://example.com/search?q={search_query.replace(' ', '+')}&result=2", + f"https://en.wikipedia.org/wiki/{search_query.replace(' ', '_')}", + f"Relevant internal document: DOC-{search_query.upper().replace(' ', '-')}-001", + ] + + results_content = f"Search results for query: '{search_query}'\n" + "\n".join(dummy_results) + + # Create an agent message containing the search results + response_message = new_agent_text_message( + text=results_content, + context_id=message.contextId, # Carry over context ID + task_id=message.taskId, # Carry over task ID + ) + return response_message + + async def cancel(self, interaction_id: str) -> None: + """ + Cancels an ongoing task. + For SearchAgent, execute is quick. This fulfills the AgentExecutor interface. + The interaction_id could correspond to a context_id or task_id. + """ + print(f"Cancellation requested for interaction/context/task '{interaction_id}' in {self.name}.") + # As execute() is not long-running, no specific cancellation logic is implemented. + # In a real search agent that makes external API calls, this would involve + # attempting to cancel the HTTP request or database query. + # raise NotImplementedError("SearchAgent does not support explicit task cancellation yet.") + pass # Acknowledging the request is sufficient for this simple agent. + + +if __name__ == "__main__": + # Example of how to use the SearchAgent (for testing) + + # Mock AgentExecutor for the main() to be runnable if a2a is not fully set up + class MockAgentExecutor: + def __init__(self, name: str): + self.name = name + # async def cancel(self, interaction_id: str): + # pass + + original_agent_executor = AgentExecutor + AgentExecutor = MockAgentExecutor + + async def main_test(): + search_agent = SearchAgent() + + # Simulate an incoming user message with a search query + user_search_query = "large language models applications" + + test_message = Message( + role=Role.USER, # Or could be from another agent + parts=[Part(root=TextPart(text=user_search_query))], + messageId=str(uuid.uuid4()), + taskId=str(uuid.uuid4()), # Task ID from the overall process + contextId=str(uuid.uuid4()) # Context ID for the interaction + ) + + print(f"Sending query to SearchAgent: '{user_search_query}'") + response_message = await search_agent.execute(test_message) + + print(f"\nSearch Agent Response (Role: {response_message.role}):") + print(get_message_text(response_message)) + + # Test cancel + print(f"\nRequesting cancellation for contextId: {test_message.contextId}") + await search_agent.cancel(test_message.contextId) + + # Test message with no query + empty_message = Message( + role=Role.USER, + parts=[], + messageId=str(uuid.uuid4()), + taskId=str(uuid.uuid4()), + contextId=str(uuid.uuid4()) + ) + print(f"\nSending empty message to SearchAgent...") + error_response = await search_agent.execute(empty_message) + print(f"Search Agent Error Response:\n{get_message_text(error_response)}") + + import asyncio + try: + asyncio.run(main_test()) + finally: + AgentExecutor = original_agent_executor # Restore original + print("\nRestored AgentExecutor. SearchAgent example finished.") diff --git a/examples/multi_agent_system/test_client.py b/examples/multi_agent_system/test_client.py new file mode 100644 index 00000000..dc1ff99b --- /dev/null +++ b/examples/multi_agent_system/test_client.py @@ -0,0 +1,127 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import httpx +import uuid + +from a2a.client.client import A2AClient +from a2a.types import ( + SendMessageRequest, + MessageSendParams, + Message, + Part, + TextPart, + Role, + SendMessageResponse, # Correct response type for client.send_message +) +from a2a.utils.message import get_message_text # Helper to extract text + +# Configuration +HOST_AGENT_BASE_URL = "http://localhost:8000" # Assuming HostAgent runs on port 8000 +HOST_AGENT_CARD_URL = f"{HOST_AGENT_BASE_URL}/agent_card" # Default endpoint for agent card + +async def main(): + """ + Client script to interact with the HostAgent. + """ + print(f"Attempting to connect to HostAgent via card URL: {HOST_AGENT_CARD_URL}") + + try: + async with httpx.AsyncClient() as http_client: + # Get an A2AClient instance for the HostAgent using its agent card URL + # This client will be configured to interact with the HostAgent's skills + try: + client: A2AClient = await A2AClient.get_client_from_agent_card_url( + agent_card_url=HOST_AGENT_CARD_URL, http_client=http_client + ) + print(f"Successfully created A2AClient for HostAgent: {client.agent_id}") + except Exception as e: + print(f"Error creating A2AClient from agent card: {e}") + print("Please ensure the HostAgent server is running and accessible.") + return + + # Define the task for the HostAgent + task_description = "Plan a two-day trip to Paris, including museum visits and local dining." + print(f"\nSending task to HostAgent: '{task_description}'") + + # Create the message to send + # For send_message, a context_id (or task_id) is usually required if the interaction + # is not the very first one. If get_client_from_agent_card_url or the first send_message + # doesn't establish a context, this might need adjustment (e.g., using start_task). + # For this example, we'll assume send_message can initiate if no context_id is provided, + # or that the client handles it for the primary skill. + + # A task_id is required by DefaultRequestHandler to create/lookup a task. + # Let's generate one for this new interaction. + # The `context_id` in `MessageSendParams` is often the same as `task_id` for the first message. + current_task_id = str(uuid.uuid4()) + + user_message = Message( + messageId=str(uuid.uuid4()), + role=Role.USER, + parts=[Part(root=TextPart(text=task_description))], + # taskId and contextId within the Message object itself are optional here, + # as they are primarily for server-side tracking within a Task. + # The crucial part is setting them in MessageSendParams for the request. + ) + + # Create MessageSendParams + # The context_id here refers to the ongoing conversation/task context. + # If this is the first message of a new task, this context_id will be used to create that task. + message_params = MessageSendParams( + message=user_message, + # context_id=current_task_id, # DefaultRequestHandler uses task_id + task_id=current_task_id # This is what DefaultRequestHandler will use to create a new task + ) + + # Create SendMessageRequest + send_request = SendMessageRequest(params=message_params) + + try: + # Call client.send_message() + # This will typically send the message to the agent's default skill + # or a skill determined by the client/server if context is ambiguous. + # The DefaultRequestHandler on the server should pick this up, + # create a new task using current_task_id, and pass it to HostAgent.execute(). + print(f"Sending message with task_id: {current_task_id}...") + response: SendMessageResponse = await client.send_message(request=send_request) + + # The response from send_message is SendMessageResponse, which contains a Message + host_agent_reply: Message = response.message + + # Extract text from the response message + # Using get_message_text for robustness, though direct access is also possible. + final_report_text = get_message_text(host_agent_reply) + + print("\n--- HostAgent Final Report ---") + print(final_report_text) + print("--- End of Report ---") + + except httpx.HTTPStatusError as e: + print(f"HTTP error occurred while sending message: {e.response.status_code} - {e.response.text}") + except Exception as e: + print(f"An error occurred while sending message or processing response: {e}") + + + except httpx.ConnectError: + print(f"Connection error: Could not connect to HostAgent at {HOST_AGENT_BASE_URL}.") + print("Please ensure the multi-agent system (main.py) is running.") + except Exception as e: + print(f"An unexpected error occurred: {e}") + +if __name__ == "__main__": + print("Starting HostAgent test client...") + asyncio.run(main()) + print("\nTest client finished.") diff --git a/examples/multi_agent_system/tests/test_host_agent.py b/examples/multi_agent_system/tests/test_host_agent.py new file mode 100644 index 00000000..fb8064a7 --- /dev/null +++ b/examples/multi_agent_system/tests/test_host_agent.py @@ -0,0 +1,257 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import pytest_asyncio # For async fixtures if needed, though not strictly here +import uuid +from unittest.mock import AsyncMock, patch, MagicMock, call + +from examples.multi_agent_system.host_agent import HostAgent +from a2a.client.client import A2AClient, A2AClientTaskInfo +from a2a.types import Message, Role, TextPart, Part, TaskResult, TaskStatus +from a2a.utils.message import get_message_text, new_user_text_message, new_agent_text_message + +DUMMY_PLAN_URL = "http://plan.test" +DUMMY_SEARCH_URL = "http://search.test" +DUMMY_REPORT_URL = "http://report.test" + +@pytest.fixture +def host_agent_instance(): + """Provides an instance of HostAgent with dummy URLs.""" + return HostAgent( + plan_agent_url=DUMMY_PLAN_URL, + search_agent_url=DUMMY_SEARCH_URL, + report_agent_url=DUMMY_REPORT_URL, + name="TestHostAgent" + ) + +@pytest.fixture +def sample_task_description(): + return "Test task: orchestrate sub-agents" + +@pytest.fixture +def sample_user_message(sample_task_description): + """Provides a sample user Message object for HostAgent input.""" + return new_user_text_message( + text=sample_task_description, + context_id=str(uuid.uuid4()), + task_id=str(uuid.uuid4()) + ) + +@pytest.fixture +def empty_user_message(): + """Provides a user Message object with no text content.""" + return new_user_text_message(text="", context_id=str(uuid.uuid4()), task_id=str(uuid.uuid4())) + + +def create_mock_a2a_client_task_info(text_response: str, original_message: Message) -> A2AClientTaskInfo: + """Helper to create a mock A2AClientTaskInfo object.""" + agent_reply_message = new_agent_text_message( + text=text_response, + context_id=original_message.contextId, + task_id=original_message.taskId + ) + # Create a TaskResult object that A2AClientTaskInfo expects + task_result = TaskResult( + task_id=original_message.taskId or str(uuid.uuid4()), # Ensure task_id is not None + status=TaskStatus.COMPLETED, + messages=[agent_reply_message] # List of messages, last one is typically the result + ) + return A2AClientTaskInfo( + task_id=original_message.taskId or str(uuid.uuid4()), + status=TaskStatus.COMPLETED, # Simplified status + messages=[agent_reply_message], # Full message history for the task + result=task_result # The final result object + ) + + +@pytest.mark.asyncio +@patch('examples.multi_agent_system.host_agent.A2AClient') # Patch A2AClient where it's used +async def test_host_agent_execute_success( + MockA2AClientConstructor: MagicMock, + host_agent_instance: HostAgent, + sample_user_message: Message, + sample_task_description: str +): + # Create AsyncMock instances for each sub-agent client's execute_agent_task method + mock_plan_client = AsyncMock(spec=A2AClient) + mock_search_client = AsyncMock(spec=A2AClient) + mock_report_client = AsyncMock(spec=A2AClient) + + # Configure the constructor mock to return different client mocks based on URL + def side_effect_constructor(server_url, http_client): + if server_url == DUMMY_PLAN_URL: + return mock_plan_client + elif server_url == DUMMY_SEARCH_URL: + return mock_search_client + elif server_url == DUMMY_REPORT_URL: + return mock_report_client + raise ValueError(f"Unexpected server_url: {server_url}") + + MockA2AClientConstructor.side_effect = side_effect_constructor + + # Define responses from sub-agents + plan_response_text = f"Plan for '{sample_task_description}'" + search_response_text = f"Search results for '{sample_task_description}'" + report_response_text = f"Final report based on Plan and Search" + + mock_plan_client.execute_agent_task = AsyncMock( + return_value=create_mock_a2a_client_task_info(plan_response_text, sample_user_message) + ) + mock_search_client.execute_agent_task = AsyncMock( + return_value=create_mock_a2a_client_task_info(search_response_text, sample_user_message) + ) + mock_report_client.execute_agent_task = AsyncMock( + return_value=create_mock_a2a_client_task_info(report_response_text, sample_user_message) + ) + + # Execute HostAgent + final_message = await host_agent_instance.execute(message=sample_user_message) + + # Assertions + MockA2AClientConstructor.assert_any_call(server_url=DUMMY_PLAN_URL, http_client=unittest.mock.ANY) + MockA2AClientConstructor.assert_any_call(server_url=DUMMY_SEARCH_URL, http_client=unittest.mock.ANY) + MockA2AClientConstructor.assert_any_call(server_url=DUMMY_REPORT_URL, http_client=unittest.mock.ANY) + + mock_plan_client.execute_agent_task.assert_called_once() + # Check the content of the message sent to plan_agent + plan_call_args = mock_plan_client.execute_agent_task.call_args[0][0] # messages list + assert get_message_text(plan_call_args[0]) == sample_task_description + + mock_search_client.execute_agent_task.assert_called_once() + search_call_args = mock_search_client.execute_agent_task.call_args[0][0] + assert get_message_text(search_call_args[0]) == sample_task_description # Simple pass-through for query + + mock_report_client.execute_agent_task.assert_called_once() + report_call_args = mock_report_client.execute_agent_task.call_args[0][0] + expected_report_input = f"Plan:\n{plan_response_text}\n\nSearch Results:\n{search_response_text}" + assert get_message_text(report_call_args[0]) == expected_report_input + + assert final_message.role == Role.AGENT + assert get_message_text(final_message) == report_response_text + assert final_message.contextId == sample_user_message.contextId + assert final_message.taskId == sample_user_message.taskId + + +@pytest.mark.asyncio +@patch('examples.multi_agent_system.host_agent.A2AClient') +async def test_host_agent_execute_plan_agent_error( + MockA2AClientConstructor: MagicMock, + host_agent_instance: HostAgent, + sample_user_message: Message +): + mock_plan_client = AsyncMock(spec=A2AClient) + # Other clients won't be called if plan fails early + mock_search_client = AsyncMock(spec=A2AClient) + mock_report_client = AsyncMock(spec=A2AClient) + + def side_effect_constructor(server_url, http_client): + if server_url == DUMMY_PLAN_URL: return mock_plan_client + if server_url == DUMMY_SEARCH_URL: return mock_search_client # Should not be constructed + if server_url == DUMMY_REPORT_URL: return mock_report_client # Should not be constructed + raise ValueError(f"Unexpected server_url: {server_url}") + MockA2AClientConstructor.side_effect = side_effect_constructor + + error_response_text = "Error: PlanAgent failed spectacularly." + mock_plan_client.execute_agent_task = AsyncMock( + # Simulate an error message being returned by _call_sub_agent + return_value=create_mock_a2a_client_task_info(error_response_text, sample_user_message) + ) + # Hack: Modify the response text from create_mock_a2a_client_task_info to ensure it starts with "Error:" + # for the HostAgent's internal check `if plan.startswith("Error:")` + # This is a bit brittle. A better mock for _call_sub_agent would be ideal but is more complex. + # For now, we make sure the text inside the message starts with "Error:" + error_task_info = create_mock_a2a_client_task_info(error_response_text, sample_user_message) + error_task_info.result.messages[0].parts[0].root.text = error_response_text + mock_plan_client.execute_agent_task.return_value = error_task_info + + + final_message = await host_agent_instance.execute(message=sample_user_message) + + mock_plan_client.execute_agent_task.assert_called_once() + mock_search_client.execute_agent_task.assert_not_called() # Should not be called + mock_report_client.execute_agent_task.assert_not_called() # Should not be called + + assert final_message.role == Role.AGENT + assert get_message_text(final_message) == error_response_text + + +@pytest.mark.asyncio +@patch('examples.multi_agent_system.host_agent.A2AClient') +async def test_host_agent_execute_search_agent_error( + MockA2AClientConstructor: MagicMock, + host_agent_instance: HostAgent, + sample_user_message: Message, + sample_task_description: str +): + mock_plan_client = AsyncMock(spec=A2AClient) + mock_search_client = AsyncMock(spec=A2AClient) + mock_report_client = AsyncMock(spec=A2AClient) + + MockA2AClientConstructor.side_effect = lambda server_url, http_client: { + DUMMY_PLAN_URL: mock_plan_client, + DUMMY_SEARCH_URL: mock_search_client, + DUMMY_REPORT_URL: mock_report_client + }.get(server_url) + + plan_response_text = f"Plan for '{sample_task_description}'" + search_error_text = "Error: SearchAgent could not find anything." + report_response_text = "Final report reflecting search failure" + + mock_plan_client.execute_agent_task = AsyncMock(return_value=create_mock_a2a_client_task_info(plan_response_text, sample_user_message)) + + search_error_task_info = create_mock_a2a_client_task_info(search_error_text, sample_user_message) + search_error_task_info.result.messages[0].parts[0].root.text = search_error_text # Ensure it starts with "Error:" + mock_search_client.execute_agent_task = AsyncMock(return_value=search_error_task_info) + + mock_report_client.execute_agent_task = AsyncMock(return_value=create_mock_a2a_client_task_info(report_response_text, sample_user_message)) + + final_message = await host_agent_instance.execute(message=sample_user_message) + + mock_plan_client.execute_agent_task.assert_called_once() + mock_search_client.execute_agent_task.assert_called_once() + mock_report_client.execute_agent_task.assert_called_once() + + # Check that ReportAgent received input indicating search failure + report_call_args = mock_report_client.execute_agent_task.call_args[0][0] + expected_report_input = f"Plan:\n{plan_response_text}\n\nSearch Results: Failed - {search_error_text}" + assert get_message_text(report_call_args[0]) == expected_report_input + + assert get_message_text(final_message) == report_response_text + + +@pytest.mark.asyncio +async def test_host_agent_execute_empty_task_description(host_agent_instance: HostAgent, empty_user_message: Message): + """ + Tests HostAgent's response to an empty task description. + """ + # No mocking needed for A2AClient as it should return error before client calls + response_message = await host_agent_instance.execute(message=empty_user_message) + + assert response_message is not None + assert response_message.role == Role.AGENT + response_text = get_message_text(response_message) + assert "Error: HostAgent received a message with no task description." in response_text + + +@pytest.mark.asyncio +async def test_host_agent_cancel_method(host_agent_instance: HostAgent): + """ + Tests the cancel method of the HostAgent. + Current implementation raises NotImplementedError. + """ + test_interaction_id = "test_cancel_host_interaction_123" + with pytest.raises(NotImplementedError) as excinfo: + await host_agent_instance.cancel(interaction_id=test_interaction_id) + assert "HostAgent cancellation requires propagation to sub-agents" in str(excinfo.value) diff --git a/examples/multi_agent_system/tests/test_plan_agent.py b/examples/multi_agent_system/tests/test_plan_agent.py new file mode 100644 index 00000000..65c78d52 --- /dev/null +++ b/examples/multi_agent_system/tests/test_plan_agent.py @@ -0,0 +1,151 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import uuid +from unittest.mock import Mock # Though not used for event_queue/context in execute directly + +from examples.multi_agent_system.plan_agent import PlanAgent +from a2a.types import Message, Role, TextPart, Part +from a2a.utils.message import get_message_text, new_user_text_message + +@pytest.fixture +def plan_agent_instance(): + """Provides an instance of PlanAgent.""" + return PlanAgent(name="TestPlanAgent") + +@pytest.fixture +def sample_task_description(): + """Provides a sample task description string.""" + return "Develop a marketing strategy for a new product." + +@pytest.fixture +def user_message_with_task(sample_task_description): + """Provides a sample user Message object with a task description.""" + # Using new_user_text_message for convenience, though manual creation is also fine + return new_user_text_message( + text=sample_task_description, + context_id=str(uuid.uuid4()), # Or task_id + task_id=str(uuid.uuid4()) + ) + +@pytest.fixture +def empty_user_message(): + """Provides a user Message object with no text content.""" + return Message( + messageId=str(uuid.uuid4()), + role=Role.USER, + parts=[], # No text parts + contextId=str(uuid.uuid4()), + taskId=str(uuid.uuid4()) + ) + +@pytest.mark.asyncio +async def test_plan_agent_execute_success(plan_agent_instance: PlanAgent, user_message_with_task: Message, sample_task_description: str): + """ + Tests successful plan generation by the PlanAgent. + """ + response_message = await plan_agent_instance.execute(message=user_message_with_task) + + assert response_message is not None + assert response_message.role == Role.AGENT + assert response_message.contextId == user_message_with_task.contextId + assert response_message.taskId == user_message_with_task.taskId + + response_text = get_message_text(response_message) + assert sample_task_description in response_text + assert "Step 1: Understand the task" in response_text + assert "Step 5: Output the plan." in response_text + assert not response_text.startswith("Error:") + +@pytest.mark.asyncio +async def test_plan_agent_execute_empty_task_description(plan_agent_instance: PlanAgent, empty_user_message: Message): + """ + Tests PlanAgent's response to an empty task description. + """ + response_message = await plan_agent_instance.execute(message=empty_user_message) + + assert response_message is not None + assert response_message.role == Role.AGENT # It still replies as an agent + + response_text = get_message_text(response_message) + assert "Error: PlanAgent received a message with no text content." in response_text + +@pytest.mark.asyncio +async def test_plan_agent_execute_non_user_role_message(plan_agent_instance: PlanAgent, sample_task_description: str): + """ + Tests PlanAgent's behavior when receiving a message from a non-USER role. + The current implementation logs a warning but still processes the task. + """ + agent_message_with_task = Message( + messageId=str(uuid.uuid4()), + role=Role.AGENT, # Message from another agent + parts=[Part(root=TextPart(text=sample_task_description))], + contextId=str(uuid.uuid4()), + taskId=str(uuid.uuid4()) + ) + + response_message = await plan_agent_instance.execute(message=agent_message_with_task) + + assert response_message is not None + assert response_message.role == Role.AGENT + response_text = get_message_text(response_message) + assert sample_task_description in response_text # Should still process + assert "Step 1: Understand the task" in response_text + assert not response_text.startswith("Error:") + + +@pytest.mark.asyncio +async def test_plan_agent_cancel_method(plan_agent_instance: PlanAgent): + """ + Tests the cancel method of the PlanAgent. + The current implementation just prints and passes. + """ + test_interaction_id = "test_cancel_interaction_123" + try: + await plan_agent_instance.cancel(interaction_id=test_interaction_id) + # If it had super().cancel that could fail: + # await super(PlanAgent, plan_agent_instance).cancel(interaction_id=test_interaction_id) + except Exception as e: + pytest.fail(f"PlanAgent.cancel() raised an unexpected exception: {e}") + # No specific assertion needed if it's just a print/pass, other than it doesn't error. + # If it were to raise NotImplementedError, the test would be: + # with pytest.raises(NotImplementedError): + # await plan_agent_instance.cancel(interaction_id=test_interaction_id) + +# Example of how one might mock if PlanAgent.execute *did* use event_queue: +# @pytest.mark.asyncio +# async def test_plan_agent_execute_with_mocked_event_queue(plan_agent_instance: PlanAgent, user_message_with_task: Message): +# """ +# Example test if PlanAgent.execute directly used an event_queue. +# THIS IS NOT HOW THE CURRENT PlanAgent WORKS. +# """ +# mock_event_queue = Mock() +# mock_context = Mock() # If context was also a direct argument + +# # Let's imagine PlanAgent was modified to take these: +# # response_message = await plan_agent_instance.execute( +# # message=user_message_with_task, +# # event_queue=mock_event_queue, # Hypothetical argument +# # context=mock_context # Hypothetical argument +# # ) + +# # If execute was supposed to enqueue its own result: +# # mock_event_queue.enqueue_event.assert_called_once() +# # called_event_message = mock_event_queue.enqueue_event.call_args[0][0] # Assuming event is the first arg +# # assert called_event_message.role == Role.AGENT +# # assert "Step 1: Understand the task" in get_message_text(called_event_message) + +# # This test is for illustration; the actual PlanAgent.execute returns a Message. +# pass diff --git a/examples/multi_agent_system/tests/test_report_agent.py b/examples/multi_agent_system/tests/test_report_agent.py new file mode 100644 index 00000000..d608fad4 --- /dev/null +++ b/examples/multi_agent_system/tests/test_report_agent.py @@ -0,0 +1,100 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import uuid + +from examples.multi_agent_system.report_agent import ReportAgent +from a2a.types import Message, Role, TextPart, Part +from a2a.utils.message import get_message_text, new_user_text_message + +@pytest.fixture +def report_agent_instance(): + """Provides an instance of ReportAgent.""" + return ReportAgent(name="TestReportAgent") + +@pytest.fixture +def sample_combined_data(): + """Provides a sample combined data string (plan + search results).""" + plan = "Plan:\nStep 1: Do this.\nStep 2: Do that." + search_results = "Search Results:\nResult A found.\nResult B found." + return f"{plan}\n\n{search_results}" + +@pytest.fixture +def user_message_with_data(sample_combined_data): + """Provides a sample user Message object with combined data.""" + return new_user_text_message( + text=sample_combined_data, + context_id=str(uuid.uuid4()), + task_id=str(uuid.uuid4()) + ) + +@pytest.fixture +def empty_user_message(): + """Provides a user Message object with no text content.""" + return Message( + messageId=str(uuid.uuid4()), + role=Role.USER, + parts=[], # No text parts + contextId=str(uuid.uuid4()), + taskId=str(uuid.uuid4()) + ) + +@pytest.mark.asyncio +async def test_report_agent_execute_success(report_agent_instance: ReportAgent, user_message_with_data: Message, sample_combined_data: str): + """ + Tests successful report generation by the ReportAgent. + """ + response_message = await report_agent_instance.execute(message=user_message_with_data) + + assert response_message is not None + assert response_message.role == Role.AGENT + assert response_message.contextId == user_message_with_data.contextId + assert response_message.taskId == user_message_with_data.taskId + + response_text = get_message_text(response_message) + assert "--- Combined Report ---" in response_text + assert "Processed Input:" in response_text + assert sample_combined_data in response_text # Original data should be part of the report + assert "End of Report." in response_text + assert not response_text.startswith("Error:") + +@pytest.mark.asyncio +async def test_report_agent_execute_empty_data(report_agent_instance: ReportAgent, empty_user_message: Message): + """ + Tests ReportAgent's response to empty input data. + """ + response_message = await report_agent_instance.execute(message=empty_user_message) + + assert response_message is not None + assert response_message.role == Role.AGENT # Still replies as an agent + + response_text = get_message_text(response_message) + assert "Error: ReportAgent received a message with no content to report." in response_text + +@pytest.mark.asyncio +async def test_report_agent_cancel_method(report_agent_instance: ReportAgent): + """ + Tests the cancel method of the ReportAgent. + The current implementation just prints and passes. + """ + test_interaction_id = "test_cancel_interaction_report_789" + try: + await report_agent_instance.cancel(interaction_id=test_interaction_id) + except Exception as e: + pytest.fail(f"ReportAgent.cancel() raised an unexpected exception: {e}") + # No specific assertion needed if it's just a print/pass. + # If it were to raise NotImplementedError, the test would be: + # with pytest.raises(NotImplementedError): + # await report_agent_instance.cancel(interaction_id=test_interaction_id) diff --git a/examples/multi_agent_system/tests/test_search_agent.py b/examples/multi_agent_system/tests/test_search_agent.py new file mode 100644 index 00000000..24b31fba --- /dev/null +++ b/examples/multi_agent_system/tests/test_search_agent.py @@ -0,0 +1,98 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import uuid + +from examples.multi_agent_system.search_agent import SearchAgent +from a2a.types import Message, Role, TextPart, Part +from a2a.utils.message import get_message_text, new_user_text_message + +@pytest.fixture +def search_agent_instance(): + """Provides an instance of SearchAgent.""" + return SearchAgent(name="TestSearchAgent") + +@pytest.fixture +def sample_search_query(): + """Provides a sample search query string.""" + return "latest advancements in AI" + +@pytest.fixture +def user_message_with_query(sample_search_query): + """Provides a sample user Message object with a search query.""" + return new_user_text_message( + text=sample_search_query, + context_id=str(uuid.uuid4()), + task_id=str(uuid.uuid4()) + ) + +@pytest.fixture +def empty_user_message(): + """Provides a user Message object with no text content.""" + return Message( + messageId=str(uuid.uuid4()), + role=Role.USER, + parts=[], # No text parts + contextId=str(uuid.uuid4()), + taskId=str(uuid.uuid4()) + ) + +@pytest.mark.asyncio +async def test_search_agent_execute_success(search_agent_instance: SearchAgent, user_message_with_query: Message, sample_search_query: str): + """ + Tests successful search operation by the SearchAgent. + """ + response_message = await search_agent_instance.execute(message=user_message_with_query) + + assert response_message is not None + assert response_message.role == Role.AGENT + assert response_message.contextId == user_message_with_query.contextId + assert response_message.taskId == user_message_with_query.taskId + + response_text = get_message_text(response_message) + assert f"Search results for query: '{sample_search_query}'" in response_text + # Check for dummy result structure (e.g., contains "https://example.com") + assert "https://example.com/search?q=" in response_text + assert "https://en.wikipedia.org/wiki/" in response_text + assert not response_text.startswith("Error:") + +@pytest.mark.asyncio +async def test_search_agent_execute_empty_query(search_agent_instance: SearchAgent, empty_user_message: Message): + """ + Tests SearchAgent's response to an empty search query. + """ + response_message = await search_agent_instance.execute(message=empty_user_message) + + assert response_message is not None + assert response_message.role == Role.AGENT # Still replies as an agent + + response_text = get_message_text(response_message) + assert "Error: SearchAgent received a message with no search query." in response_text + +@pytest.mark.asyncio +async def test_search_agent_cancel_method(search_agent_instance: SearchAgent): + """ + Tests the cancel method of the SearchAgent. + The current implementation just prints and passes. + """ + test_interaction_id = "test_cancel_interaction_search_456" + try: + await search_agent_instance.cancel(interaction_id=test_interaction_id) + except Exception as e: + pytest.fail(f"SearchAgent.cancel() raised an unexpected exception: {e}") + # No specific assertion needed if it's just a print/pass. + # If it were to raise NotImplementedError, the test would be: + # with pytest.raises(NotImplementedError): + # await search_agent_instance.cancel(interaction_id=test_interaction_id)