From 08bb1a10722ab19a5cf783404d424425e46dc6c2 Mon Sep 17 00:00:00 2001 From: swapnilag Date: Thu, 15 May 2025 02:29:29 +0000 Subject: [PATCH 1/4] Add push notification support --- .../google_adk/birthday_planner/__main__.py | 2 + .../birthday_planner/adk_agent_executor.py | 28 +- examples/langgraph/__main__.py | 5 +- src/a2a/server/agent_execution/context.py | 2 +- .../server/events/in_memory_queue_manager.py | 2 +- .../default_request_handler.py | 79 +++++- src/a2a/server/tasks/__init__.py | 4 + .../server/tasks/inmemory_push_notifier.py | 49 ++++ src/a2a/server/tasks/push_notifier.py | 25 ++ src/a2a/server/tasks/result_aggregator.py | 5 +- .../request_handlers/test_jsonrpc_handler.py | 241 ++++++++++++++++-- 11 files changed, 389 insertions(+), 53 deletions(-) create mode 100644 src/a2a/server/tasks/inmemory_push_notifier.py create mode 100644 src/a2a/server/tasks/push_notifier.py diff --git a/examples/google_adk/birthday_planner/__main__.py b/examples/google_adk/birthday_planner/__main__.py index d4a0343d..c6ef4225 100644 --- a/examples/google_adk/birthday_planner/__main__.py +++ b/examples/google_adk/birthday_planner/__main__.py @@ -5,6 +5,7 @@ import click import uvicorn + from adk_agent_executor import ADKAgentExecutor from dotenv import load_dotenv @@ -18,6 +19,7 @@ AgentSkill, ) + load_dotenv() logging.basicConfig() diff --git a/examples/google_adk/birthday_planner/adk_agent_executor.py b/examples/google_adk/birthday_planner/adk_agent_executor.py index 49f13e96..30e6826f 100644 --- a/examples/google_adk/birthday_planner/adk_agent_executor.py +++ b/examples/google_adk/birthday_planner/adk_agent_executor.py @@ -1,10 +1,12 @@ import asyncio import logging -from collections.abc import AsyncGenerator -from typing import Any, AsyncIterable + +from collections.abc import AsyncGenerator, AsyncIterable +from typing import Any from uuid import uuid4 import httpx + from google.adk import Runner from google.adk.agents import LlmAgent, RunConfig from google.adk.artifacts import InMemoryArtifactService @@ -42,6 +44,7 @@ from a2a.utils import get_text_parts from a2a.utils.errors import ServerError + logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) @@ -66,7 +69,7 @@ def __init__(self, calendar_agent_url): name='birthday_planner_agent', description='An agent that helps manage birthday parties.', after_tool_callback=self._handle_auth_required_task, - instruction=f""" + instruction=""" You are an agent that helps plan birthday parties. Your job as a party planner is to act as a sounding board and idea generator for @@ -165,7 +168,7 @@ async def _process_request( task_updater.add_artifact(response) task_updater.complete() break - elif calls := event.get_function_calls(): + if calls := event.get_function_calls(): for call in calls: # Provide an update on what we're doing. if call.name == 'message_calendar_agent': @@ -314,23 +317,21 @@ def convert_a2a_part_to_genai(part: Part) -> types.Part: part = part.root if isinstance(part, TextPart): return types.Part(text=part.text) - elif isinstance(part, FilePart): + if isinstance(part, FilePart): if isinstance(part.file, FileWithUri): return types.Part( file_data=types.FileData( file_uri=part.file.uri, mime_type=part.file.mime_type ) ) - elif isinstance(part.file, FileWithBytes): + if isinstance(part.file, FileWithBytes): return types.Part( inline_data=types.Blob( data=part.file.bytes, mime_type=part.file.mime_type ) ) - else: - raise ValueError(f'Unsupported file type: {type(part.file)}') - else: - raise ValueError(f'Unsupported part type: {type(part)}') + raise ValueError(f'Unsupported file type: {type(part.file)}') + raise ValueError(f'Unsupported part type: {type(part)}') def convert_genai_parts_to_a2a(parts: list[types.Part]) -> list[Part]: @@ -346,14 +347,14 @@ def convert_genai_part_to_a2a(part: types.Part) -> Part: """Convert a single Google GenAI Part type into an A2A Part type.""" if part.text: return TextPart(text=part.text) - elif part.file_data: + if part.file_data: return FilePart( file=FileWithUri( uri=part.file_data.file_uri, mime_type=part.file_data.mime_type, ) ) - elif part.inline_data: + if part.inline_data: return Part( root=FilePart( file=FileWithBytes( @@ -362,5 +363,4 @@ def convert_genai_part_to_a2a(part: types.Part) -> Part: ) ) ) - else: - raise ValueError(f'Unsupported part type: {part}') + raise ValueError(f'Unsupported part type: {part}') diff --git a/examples/langgraph/__main__.py b/examples/langgraph/__main__.py index af06aeae..6f151ed3 100644 --- a/examples/langgraph/__main__.py +++ b/examples/langgraph/__main__.py @@ -2,6 +2,7 @@ import sys import click +import httpx from agent import CurrencyAgent from agent_executor import CurrencyAgentExecutor @@ -9,7 +10,7 @@ from a2a.server.apps import A2AStarletteApplication from a2a.server.request_handlers import DefaultRequestHandler -from a2a.server.tasks import InMemoryTaskStore +from a2a.server.tasks import InMemoryPushNotifier, InMemoryTaskStore from a2a.types import ( AgentAuthentication, AgentCapabilities, @@ -29,9 +30,11 @@ def main(host: str, port: int): print('GOOGLE_API_KEY environment variable not set.') sys.exit(1) + client = httpx.AsyncClient() request_handler = DefaultRequestHandler( agent_executor=CurrencyAgentExecutor(), task_store=InMemoryTaskStore(), + push_notifier=InMemoryPushNotifier(client), ) server = A2AStarletteApplication( diff --git a/src/a2a/server/agent_execution/context.py b/src/a2a/server/agent_execution/context.py index 3c61dc6b..870c5f8e 100644 --- a/src/a2a/server/agent_execution/context.py +++ b/src/a2a/server/agent_execution/context.py @@ -19,7 +19,7 @@ def __init__( task_id: str | None = None, context_id: str | None = None, task: Task | None = None, - related_tasks: list[Task] = None, + related_tasks: list[Task] | None = None, ): if related_tasks is None: related_tasks = [] diff --git a/src/a2a/server/events/in_memory_queue_manager.py b/src/a2a/server/events/in_memory_queue_manager.py index a0d95f8e..9d4a135b 100644 --- a/src/a2a/server/events/in_memory_queue_manager.py +++ b/src/a2a/server/events/in_memory_queue_manager.py @@ -18,7 +18,7 @@ class InMemoryQueueManager(QueueManager): true scalable deployment. """ - def __init__(self): + def __init__(self) -> None: self._task_queue: dict[str, EventQueue] = {} self._lock = asyncio.Lock() diff --git a/src/a2a/server/request_handlers/default_request_handler.py b/src/a2a/server/request_handlers/default_request_handler.py index f656e414..a1604919 100644 --- a/src/a2a/server/request_handlers/default_request_handler.py +++ b/src/a2a/server/request_handlers/default_request_handler.py @@ -1,6 +1,6 @@ import asyncio -import contextlib import logging + from collections.abc import AsyncGenerator from typing import cast @@ -10,12 +10,16 @@ EventConsumer, EventQueue, InMemoryQueueManager, - NoTaskQueue, QueueManager, TaskQueueExists, ) from a2a.server.request_handlers.request_handler import RequestHandler -from a2a.server.tasks import ResultAggregator, TaskManager, TaskStore +from a2a.server.tasks import ( + PushNotifier, + ResultAggregator, + TaskManager, + TaskStore, +) from a2a.types import ( InternalError, Message, @@ -29,6 +33,7 @@ ) from a2a.utils.errors import ServerError + logger = logging.getLogger(__name__) @@ -42,10 +47,12 @@ def __init__( agent_executor: AgentExecutor, task_store: TaskStore, queue_manager: QueueManager | None = None, + push_notifier: PushNotifier | None = None, ) -> None: self.agent_executor = agent_executor self.task_store = task_store self._queue_manager = queue_manager or InMemoryQueueManager() + self._push_notifier = push_notifier # TODO: Likely want an interface for managing this, like AgentExecutionManager. self._running_agents = {} self._running_agents_lock = asyncio.Lock() @@ -116,6 +123,15 @@ async def on_message_send( task: Task | None = await task_manager.get_task() if task: task = task_manager.update_with_message(params.message, task) + if ( + self._push_notifier + and params.configuration + and params.configuration.pushNotificationConfig + and not params.configuration.blocking + ): + await self._push_notifier.set_info( + task.id, params.configuration.pushNotificationConfig + ) request_context = RequestContext( params, task.id if task else None, @@ -173,6 +189,16 @@ async def on_message_send_stream( if task: task = task_manager.update_with_message(params.message, task) + if ( + self._push_notifier + and params.configuration + and params.configuration.pushNotificationConfig + ): + await self._push_notifier.set_info( + task.id, params.configuration.pushNotificationConfig + ) + else: + queue = EventQueue() result_aggregator = ResultAggregator(task_manager) request_context = RequestContext( params, @@ -196,12 +222,26 @@ async def on_message_send_stream( # Now we know we have a Task, register the queue if isinstance(event, Task): try: - await self._queue_manager.add(event.id, queue) - task_id = event.id + created_task: Task = event + await self._queue_manager.add(created_task.id, queue) + task_id = created_task.id except TaskQueueExists: logging.info( 'Multiple Task objects created in event stream.' ) + if ( + self._push_notifier + and params.configuration + and params.configuration.pushNotificationConfig + ): + await self._push_notifier.set_info( + created_task.id, + params.configuration.pushNotificationConfig, + ) + if self._push_notifier and task_id: + latest_task = await result_aggregator.current_result + if isinstance(latest_task, Task): + await self._push_notifier.send_notification(latest_task) yield event finally: await self._cleanup_producer(producer_task, task_id) @@ -220,13 +260,38 @@ async def on_set_task_push_notification_config( self, params: TaskPushNotificationConfig ) -> TaskPushNotificationConfig: """Default handler for 'tasks/pushNotificationConfig/set'.""" - raise ServerError(error=UnsupportedOperationError()) + if not self._push_notifier: + raise ServerError(error=UnsupportedOperationError()) + + task: Task | None = await self.task_store.get(params.taskId) + if not task: + raise ServerError(error=TaskNotFoundError()) + + await self._push_notifier.set_info( + params.taskId, + params.pushNotificationConfig, + ) + + return params async def on_get_task_push_notification_config( self, params: TaskIdParams ) -> TaskPushNotificationConfig: """Default handler for 'tasks/pushNotificationConfig/get'.""" - raise ServerError(error=UnsupportedOperationError()) + if not self._push_notifier: + raise ServerError(error=UnsupportedOperationError()) + + task: Task | None = await self.task_store.get(params.id) + if not task: + raise ServerError(error=TaskNotFoundError()) + + push_notification_config = await self._push_notifier.get_info(params.id) + if not push_notification_config: + raise ServerError(error=TaskNotFoundError()) + + return TaskPushNotificationConfig( + taskId=params.id, pushNotificationConfig=push_notification_config + ) async def on_resubscribe_to_task( self, params: TaskIdParams diff --git a/src/a2a/server/tasks/__init__.py b/src/a2a/server/tasks/__init__.py index d61df11f..4dc94947 100644 --- a/src/a2a/server/tasks/__init__.py +++ b/src/a2a/server/tasks/__init__.py @@ -1,4 +1,6 @@ +from a2a.server.tasks.inmemory_push_notifier import InMemoryPushNotifier from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore +from a2a.server.tasks.push_notifier import PushNotifier from a2a.server.tasks.result_aggregator import ResultAggregator from a2a.server.tasks.task_manager import TaskManager from a2a.server.tasks.task_store import TaskStore @@ -6,7 +8,9 @@ __all__ = [ + 'InMemoryPushNotifier', 'InMemoryTaskStore', + 'PushNotifier', 'ResultAggregator', 'TaskManager', 'TaskStore', diff --git a/src/a2a/server/tasks/inmemory_push_notifier.py b/src/a2a/server/tasks/inmemory_push_notifier.py new file mode 100644 index 00000000..6af15516 --- /dev/null +++ b/src/a2a/server/tasks/inmemory_push_notifier.py @@ -0,0 +1,49 @@ +import asyncio +import logging + +import httpx + +from a2a.server.tasks.push_notifier import PushNotifier +from a2a.types import PushNotificationConfig, Task + + +logger = logging.getLogger(__name__) + + +class InMemoryPushNotifier(PushNotifier): + """In-memory implementation of PushNotifier interface.""" + + def __init__(self, httpx_client: httpx.AsyncClient) -> None: + self._client = httpx_client + self.lock = asyncio.Lock() + self._push_notification_infos: dict[str, PushNotificationConfig] = {} + + async def set_info( + self, task_id: str, notification_config: PushNotificationConfig + ): + async with self.lock: + self._push_notification_infos[task_id] = notification_config + + async def get_info(self, task_id: str) -> PushNotificationConfig | None: + async with self.lock: + return self._push_notification_infos.get(task_id) + + async def delete_info(self, task_id: str): + async with self.lock: + if task_id in self._push_notification_infos: + del self._push_notification_infos[task_id] + + async def send_notification(self, task: Task): + push_info = await self.get_info(task.id) + if not push_info: + return + url = push_info.url + + try: + response = await self._client.post( + url, json=task.model_dump(exclude_none=True) + ) + response.raise_for_status() + logger.info(f'Push-notification sent for URL: {url}') + except Exception as e: + logger.error(f'Error sending push-notification: {e}') diff --git a/src/a2a/server/tasks/push_notifier.py b/src/a2a/server/tasks/push_notifier.py new file mode 100644 index 00000000..10f01f3a --- /dev/null +++ b/src/a2a/server/tasks/push_notifier.py @@ -0,0 +1,25 @@ +from abc import ABC, abstractmethod + +from a2a.types import PushNotificationConfig, Task + + +class PushNotifier(ABC): + """PushNotifier interface to store, retrieve push notification for tasks and send push notifications.""" + + @abstractmethod + async def set_info( + self, task_id: str, notification_config: PushNotificationConfig + ): + pass + + @abstractmethod + async def get_info(self, task_id: str) -> PushNotificationConfig | None: + pass + + @abstractmethod + async def delete_info(self, task_id: str): + pass + + @abstractmethod + async def send_notification(self, task: Task): + pass diff --git a/src/a2a/server/tasks/result_aggregator.py b/src/a2a/server/tasks/result_aggregator.py index 9220fc86..2b7bc5f6 100644 --- a/src/a2a/server/tasks/result_aggregator.py +++ b/src/a2a/server/tasks/result_aggregator.py @@ -1,12 +1,13 @@ import asyncio import logging + from collections.abc import AsyncGenerator, AsyncIterator -from typing import Tuple from a2a.server.events import Event, EventConsumer from a2a.server.tasks.task_manager import TaskManager from a2a.types import Message, Task, TaskState, TaskStatusUpdateEvent + logger = logging.getLogger(__name__) @@ -54,7 +55,7 @@ async def consume_all( async def consume_and_break_on_interrupt( self, consumer: EventConsumer - ) -> Tuple[Task | Message | None, bool]: + ) -> tuple[Task | Message | None, bool]: """Process the event stream until completion or an interruptable state is encountered.""" event_stream = consumer.consume_all() interrupted = False diff --git a/tests/server/request_handlers/test_jsonrpc_handler.py b/tests/server/request_handlers/test_jsonrpc_handler.py index c482fcf7..8dd02092 100644 --- a/tests/server/request_handlers/test_jsonrpc_handler.py +++ b/tests/server/request_handlers/test_jsonrpc_handler.py @@ -1,49 +1,63 @@ import unittest import unittest.async_case -from unittest.mock import AsyncMock, patch, MagicMock + +from collections.abc import AsyncGenerator +from typing import Any +from unittest.mock import AsyncMock, MagicMock, call, patch + +import httpx import pytest -from a2a.server.events.event_queue import EventQueue + from a2a.server.agent_execution import AgentExecutor -from a2a.utils.errors import ServerError +from a2a.server.events import ( + QueueManager, +) +from a2a.server.events.event_queue import EventQueue from a2a.server.request_handlers import ( DefaultRequestHandler, JSONRPCHandler, ) -from a2a.server.events import ( - QueueManager, -) -from a2a.server.tasks import TaskStore +from a2a.server.tasks import InMemoryPushNotifier, PushNotifier, TaskStore from a2a.types import ( - AgentCard, AgentCapabilities, + AgentCard, + Artifact, + CancelTaskRequest, + CancelTaskSuccessResponse, + GetTaskPushNotificationConfigRequest, + GetTaskPushNotificationConfigResponse, + GetTaskPushNotificationConfigSuccessResponse, GetTaskRequest, GetTaskResponse, GetTaskSuccessResponse, - Task, - TaskQueryParams, JSONRPCErrorResponse, - TaskNotFoundError, - TaskIdParams, - CancelTaskRequest, - CancelTaskSuccessResponse, - UnsupportedOperationError, - SendMessageRequest, Message, + MessageSendConfiguration, MessageSendParams, + Part, + PushNotificationConfig, + SendMessageRequest, SendMessageSuccessResponse, SendStreamingMessageRequest, SendStreamingMessageSuccessResponse, + SetTaskPushNotificationConfigRequest, + SetTaskPushNotificationConfigResponse, + SetTaskPushNotificationConfigSuccessResponse, + Task, TaskArtifactUpdateEvent, + TaskIdParams, + TaskNotFoundError, + TaskPushNotificationConfig, + TaskQueryParams, + TaskResubscriptionRequest, + TaskState, + TaskStatus, TaskStatusUpdateEvent, - Artifact, - Part, TextPart, - TaskStatus, - TaskState, - TaskResubscriptionRequest, + UnsupportedOperationError, ) -from collections.abc import AsyncGenerator -from typing import Any +from a2a.utils.errors import ServerError + MINIMAL_TASK: dict[str, Any] = { 'id': 'task_123', @@ -316,7 +330,7 @@ async def streaming_coro(): assert isinstance( event.root, SendStreamingMessageSuccessResponse ) - assert collected_events[i].root.result == events[i] + assert event.root.result == events[i] mock_agent_executor.execute.assert_called_once() async def test_on_message_stream_new_message_existing_task_success( @@ -371,13 +385,186 @@ async def streaming_coro(): ) response = handler.on_message_send_stream(request) assert isinstance(response, AsyncGenerator) - collected_events: list[Any] = [] - async for event in response: - collected_events.append(event) + collected_events = [item async for item in response] assert len(collected_events) == len(events) mock_agent_executor.execute.assert_called_once() assert mock_task.history is not None and len(mock_task.history) == 1 + async def test_set_push_notif_success(self) -> None: + mock_agent_executor = AsyncMock(spec=AgentExecutor) + mock_task_store = AsyncMock(spec=TaskStore) + mock_push_notifier = AsyncMock(spec=PushNotifier) + request_handler = DefaultRequestHandler( + mock_agent_executor, + mock_task_store, + push_notifier=mock_push_notifier, + ) + self.mock_agent_card.capabilities = AgentCapabilities( + streaming=True, pushNotifications=True + ) + handler = JSONRPCHandler(self.mock_agent_card, request_handler) + mock_task = Task(**MINIMAL_TASK) + mock_task_store.get.return_value = mock_task + task_push_config = TaskPushNotificationConfig( + taskId=mock_task.id, + pushNotificationConfig=PushNotificationConfig( + url='http://example.com' + ), + ) + request = SetTaskPushNotificationConfigRequest( + id='1', params=task_push_config + ) + response: SetTaskPushNotificationConfigResponse = ( + await handler.set_push_notification(request) + ) + self.assertIsInstance( + response.root, SetTaskPushNotificationConfigSuccessResponse + ) + assert response.root.result == task_push_config # type: ignore + mock_push_notifier.set_info.assert_called_once_with( + mock_task.id, task_push_config.pushNotificationConfig + ) + + async def test_get_push_notif_success(self) -> None: + mock_agent_executor = AsyncMock(spec=AgentExecutor) + mock_task_store = AsyncMock(spec=TaskStore) + mock_httpx_client = AsyncMock(spec=httpx.AsyncClient) + push_notifier = InMemoryPushNotifier(httpx_client=mock_httpx_client) + request_handler = DefaultRequestHandler( + mock_agent_executor, mock_task_store, push_notifier=push_notifier + ) + self.mock_agent_card.capabilities = AgentCapabilities( + streaming=True, pushNotifications=True + ) + handler = JSONRPCHandler(self.mock_agent_card, request_handler) + mock_task = Task(**MINIMAL_TASK) + mock_task_store.get.return_value = mock_task + task_push_config = TaskPushNotificationConfig( + taskId=mock_task.id, + pushNotificationConfig=PushNotificationConfig( + url='http://example.com' + ), + ) + request = SetTaskPushNotificationConfigRequest( + id='1', params=task_push_config + ) + await handler.set_push_notification(request) + + get_request: GetTaskPushNotificationConfigRequest = ( + GetTaskPushNotificationConfigRequest( + id='1', params=TaskIdParams(id=mock_task.id) + ) + ) + get_response: GetTaskPushNotificationConfigResponse = ( + await handler.get_push_notification(get_request) + ) + self.assertIsInstance( + get_response.root, GetTaskPushNotificationConfigSuccessResponse + ) + assert get_response.root.result == task_push_config # type: ignore + + async def test_on_message_stream_new_message_send_push_notif_success( + self, + ) -> None: + mock_agent_executor = AsyncMock(spec=AgentExecutor) + mock_task_store = AsyncMock(spec=TaskStore) + mock_httpx_client = AsyncMock(spec=httpx.AsyncClient) + push_notifier = InMemoryPushNotifier(httpx_client=mock_httpx_client) + request_handler = DefaultRequestHandler( + mock_agent_executor, mock_task_store, push_notifier=push_notifier + ) + self.mock_agent_card.capabilities = AgentCapabilities( + streaming=True, pushNotifications=True + ) + + handler = JSONRPCHandler(self.mock_agent_card, request_handler) + events: list[Any] = [ + Task(**MINIMAL_TASK), + TaskArtifactUpdateEvent( + taskId='task_123', + contextId='session-xyz', + artifact=Artifact( + artifactId='11', parts=[Part(TextPart(text='text'))] + ), + ), + TaskStatusUpdateEvent( + taskId='task_123', + contextId='session-xyz', + status=TaskStatus(state=TaskState.completed), + final=True, + ), + ] + + async def streaming_coro(): + for event in events: + yield event + + with patch( + 'a2a.server.request_handlers.default_request_handler.EventConsumer.consume_all', + return_value=streaming_coro(), + ): + mock_task_store.get.return_value = None + mock_agent_executor.execute.return_value = None + mock_httpx_client.post.return_value = httpx.Response(200) + request = SendStreamingMessageRequest( + id='1', + params=MessageSendParams(message=Message(**MESSAGE_PAYLOAD)), + ) + request.params.configuration = MessageSendConfiguration( + acceptedOutputModes=['text'], + pushNotificationConfig=PushNotificationConfig( + url='http://example.com' + ), + ) + response = handler.on_message_send_stream(request) + assert isinstance(response, AsyncGenerator) + + collected_events = [item async for item in response] + assert len(collected_events) == len(events) + + calls = [ + call( + 'http://example.com', + json={ + 'contextId': 'session-xyz', + 'id': 'task_123', + 'status': {'state': TaskState.submitted}, + 'type': 'task', + }, + ), + call( + 'http://example.com', + json={ + 'artifacts': [ + { + 'artifactId': '11', + 'parts': [{'text': 'text', 'type': 'text'}], + } + ], + 'contextId': 'session-xyz', + 'id': 'task_123', + 'status': {'state': TaskState.submitted}, + 'type': 'task', + }, + ), + call( + 'http://example.com', + json={ + 'artifacts': [ + { + 'artifactId': '11', + 'parts': [{'text': 'text', 'type': 'text'}], + } + ], + 'contextId': 'session-xyz', + 'id': 'task_123', + 'status': {'state': TaskState.completed}, + 'type': 'task', + }, + ), + ] + mock_httpx_client.post.assert_has_calls(calls) + async def test_on_resubscribe_existing_task_success( self, ) -> None: From e5adaa16dc170df7285c236466ef02467f4116b2 Mon Sep 17 00:00:00 2001 From: swapnilag Date: Thu, 15 May 2025 18:13:31 +0000 Subject: [PATCH 2/4] adress review comments: rename notif & reverted elifs --- .../birthday_planner/adk_agent_executor.py | 28 +++++++++---------- .../request_handlers/test_jsonrpc_handler.py | 6 ++-- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/examples/google_adk/birthday_planner/adk_agent_executor.py b/examples/google_adk/birthday_planner/adk_agent_executor.py index 30e6826f..49f13e96 100644 --- a/examples/google_adk/birthday_planner/adk_agent_executor.py +++ b/examples/google_adk/birthday_planner/adk_agent_executor.py @@ -1,12 +1,10 @@ import asyncio import logging - -from collections.abc import AsyncGenerator, AsyncIterable -from typing import Any +from collections.abc import AsyncGenerator +from typing import Any, AsyncIterable from uuid import uuid4 import httpx - from google.adk import Runner from google.adk.agents import LlmAgent, RunConfig from google.adk.artifacts import InMemoryArtifactService @@ -44,7 +42,6 @@ from a2a.utils import get_text_parts from a2a.utils.errors import ServerError - logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) @@ -69,7 +66,7 @@ def __init__(self, calendar_agent_url): name='birthday_planner_agent', description='An agent that helps manage birthday parties.', after_tool_callback=self._handle_auth_required_task, - instruction=""" + instruction=f""" You are an agent that helps plan birthday parties. Your job as a party planner is to act as a sounding board and idea generator for @@ -168,7 +165,7 @@ async def _process_request( task_updater.add_artifact(response) task_updater.complete() break - if calls := event.get_function_calls(): + elif calls := event.get_function_calls(): for call in calls: # Provide an update on what we're doing. if call.name == 'message_calendar_agent': @@ -317,21 +314,23 @@ def convert_a2a_part_to_genai(part: Part) -> types.Part: part = part.root if isinstance(part, TextPart): return types.Part(text=part.text) - if isinstance(part, FilePart): + elif isinstance(part, FilePart): if isinstance(part.file, FileWithUri): return types.Part( file_data=types.FileData( file_uri=part.file.uri, mime_type=part.file.mime_type ) ) - if isinstance(part.file, FileWithBytes): + elif isinstance(part.file, FileWithBytes): return types.Part( inline_data=types.Blob( data=part.file.bytes, mime_type=part.file.mime_type ) ) - raise ValueError(f'Unsupported file type: {type(part.file)}') - raise ValueError(f'Unsupported part type: {type(part)}') + else: + raise ValueError(f'Unsupported file type: {type(part.file)}') + else: + raise ValueError(f'Unsupported part type: {type(part)}') def convert_genai_parts_to_a2a(parts: list[types.Part]) -> list[Part]: @@ -347,14 +346,14 @@ def convert_genai_part_to_a2a(part: types.Part) -> Part: """Convert a single Google GenAI Part type into an A2A Part type.""" if part.text: return TextPart(text=part.text) - if part.file_data: + elif part.file_data: return FilePart( file=FileWithUri( uri=part.file_data.file_uri, mime_type=part.file_data.mime_type, ) ) - if part.inline_data: + elif part.inline_data: return Part( root=FilePart( file=FileWithBytes( @@ -363,4 +362,5 @@ def convert_genai_part_to_a2a(part: types.Part) -> Part: ) ) ) - raise ValueError(f'Unsupported part type: {part}') + else: + raise ValueError(f'Unsupported part type: {part}') diff --git a/tests/server/request_handlers/test_jsonrpc_handler.py b/tests/server/request_handlers/test_jsonrpc_handler.py index 8dd02092..fddf7e0e 100644 --- a/tests/server/request_handlers/test_jsonrpc_handler.py +++ b/tests/server/request_handlers/test_jsonrpc_handler.py @@ -390,7 +390,7 @@ async def streaming_coro(): mock_agent_executor.execute.assert_called_once() assert mock_task.history is not None and len(mock_task.history) == 1 - async def test_set_push_notif_success(self) -> None: + async def test_set_push_notification_success(self) -> None: mock_agent_executor = AsyncMock(spec=AgentExecutor) mock_task_store = AsyncMock(spec=TaskStore) mock_push_notifier = AsyncMock(spec=PushNotifier) @@ -425,7 +425,7 @@ async def test_set_push_notif_success(self) -> None: mock_task.id, task_push_config.pushNotificationConfig ) - async def test_get_push_notif_success(self) -> None: + async def test_get_push_notification_success(self) -> None: mock_agent_executor = AsyncMock(spec=AgentExecutor) mock_task_store = AsyncMock(spec=TaskStore) mock_httpx_client = AsyncMock(spec=httpx.AsyncClient) @@ -463,7 +463,7 @@ async def test_get_push_notif_success(self) -> None: ) assert get_response.root.result == task_push_config # type: ignore - async def test_on_message_stream_new_message_send_push_notif_success( + async def test_on_message_stream_new_message_send_push_notification_success( self, ) -> None: mock_agent_executor = AsyncMock(spec=AgentExecutor) From 4706c680e9005afa81c9c108a38d96ef7a763d61 Mon Sep 17 00:00:00 2001 From: swapnilag Date: Thu, 15 May 2025 18:53:41 +0000 Subject: [PATCH 3/4] use json mode & internal error --- src/a2a/server/request_handlers/default_request_handler.py | 2 +- src/a2a/server/tasks/inmemory_push_notifier.py | 2 +- tests/server/request_handlers/test_jsonrpc_handler.py | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/a2a/server/request_handlers/default_request_handler.py b/src/a2a/server/request_handlers/default_request_handler.py index c97aaa9c..4deba267 100644 --- a/src/a2a/server/request_handlers/default_request_handler.py +++ b/src/a2a/server/request_handlers/default_request_handler.py @@ -289,7 +289,7 @@ async def on_get_task_push_notification_config( push_notification_config = await self._push_notifier.get_info(params.id) if not push_notification_config: - raise ServerError(error=TaskNotFoundError()) + raise ServerError(error=InternalError()) return TaskPushNotificationConfig( taskId=params.id, pushNotificationConfig=push_notification_config diff --git a/src/a2a/server/tasks/inmemory_push_notifier.py b/src/a2a/server/tasks/inmemory_push_notifier.py index 6af15516..222dc90f 100644 --- a/src/a2a/server/tasks/inmemory_push_notifier.py +++ b/src/a2a/server/tasks/inmemory_push_notifier.py @@ -41,7 +41,7 @@ async def send_notification(self, task: Task): try: response = await self._client.post( - url, json=task.model_dump(exclude_none=True) + url, json=task.model_dump(mode='json', exclude_none=True) ) response.raise_for_status() logger.info(f'Push-notification sent for URL: {url}') diff --git a/tests/server/request_handlers/test_jsonrpc_handler.py b/tests/server/request_handlers/test_jsonrpc_handler.py index 896ee557..d5b94605 100644 --- a/tests/server/request_handlers/test_jsonrpc_handler.py +++ b/tests/server/request_handlers/test_jsonrpc_handler.py @@ -527,7 +527,7 @@ async def streaming_coro(): json={ 'contextId': 'session-xyz', 'id': 'task_123', - 'status': {'state': TaskState.submitted}, + 'status': {'state': 'submitted'}, 'type': 'task', }, ), @@ -542,7 +542,7 @@ async def streaming_coro(): ], 'contextId': 'session-xyz', 'id': 'task_123', - 'status': {'state': TaskState.submitted}, + 'status': {'state': 'submitted'}, 'type': 'task', }, ), @@ -557,7 +557,7 @@ async def streaming_coro(): ], 'contextId': 'session-xyz', 'id': 'task_123', - 'status': {'state': TaskState.completed}, + 'status': {'state': 'completed'}, 'type': 'task', }, ), From 9e2b3b2a5f4ea5ccd8de7beebd3ce10cac3c1c19 Mon Sep 17 00:00:00 2001 From: swapnilag Date: Thu, 15 May 2025 21:23:59 +0000 Subject: [PATCH 4/4] add push for blocking --- .../default_request_handler.py | 31 ++++++++++++------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/src/a2a/server/request_handlers/default_request_handler.py b/src/a2a/server/request_handlers/default_request_handler.py index 4deba267..f8501526 100644 --- a/src/a2a/server/request_handlers/default_request_handler.py +++ b/src/a2a/server/request_handlers/default_request_handler.py @@ -23,7 +23,9 @@ from a2a.types import ( InternalError, Message, + MessageSendConfiguration, MessageSendParams, + PushNotificationConfig, Task, TaskIdParams, TaskNotFoundError, @@ -123,12 +125,10 @@ async def on_message_send( task: Task | None = await task_manager.get_task() if task: task = task_manager.update_with_message(params.message, task) - if ( - self._push_notifier - and params.configuration - and params.configuration.pushNotificationConfig - and not params.configuration.blocking - ): + if self.should_add_push_info(params): + assert isinstance(self._push_notifier, PushNotifier) # For typechecker + assert isinstance(params.configuration, MessageSendConfiguration) # For typechecker + assert isinstance(params.configuration.pushNotificationConfig, PushNotificationConfig) # For typechecker await self._push_notifier.set_info( task.id, params.configuration.pushNotificationConfig ) @@ -190,11 +190,10 @@ async def on_message_send_stream( if task: task = task_manager.update_with_message(params.message, task) - if ( - self._push_notifier - and params.configuration - and params.configuration.pushNotificationConfig - ): + if self.should_add_push_info(params): + assert isinstance(self._push_notifier, PushNotifier) # For typechecker + assert isinstance(params.configuration, MessageSendConfiguration) # For typechecker + assert isinstance(params.configuration.pushNotificationConfig, PushNotificationConfig) # For typechecker await self._push_notifier.set_info( task.id, params.configuration.pushNotificationConfig ) @@ -319,3 +318,13 @@ async def on_resubscribe_to_task( consumer = EventConsumer(queue) async for event in result_aggregator.consume_and_emit(consumer): yield event + + def should_add_push_info(self, params: MessageSendParams) -> bool: + if ( + self._push_notifier + and params.configuration + and params.configuration.pushNotificationConfig + ): + return True + else: + return False