Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion src/a2a/server/request_handlers/default_request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,11 +286,19 @@ async def on_message_send(

interrupted_or_non_blocking = False
try:
# Create async callback for push notifications
async def push_notification_callback() -> None:
await self._send_push_notification_if_needed(
task_id, result_aggregator
)

(
result,
interrupted_or_non_blocking,
) = await result_aggregator.consume_and_break_on_interrupt(
consumer, blocking=blocking
consumer,
blocking=blocking,
event_callback=push_notification_callback,
)
if not result:
raise ServerError(error=InternalError()) # noqa: TRY301
Expand Down
26 changes: 21 additions & 5 deletions src/a2a/server/tasks/result_aggregator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import logging

from collections.abc import AsyncGenerator, AsyncIterator
from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Callable

from a2a.server.events import Event, EventConsumer
from a2a.server.tasks.task_manager import TaskManager
Expand All @@ -24,7 +24,10 @@ class ResultAggregator:
Task object and emit that Task object.
"""

def __init__(self, task_manager: TaskManager):
def __init__(
self,
task_manager: TaskManager,
) -> None:
"""Initializes the ResultAggregator.

Args:
Expand Down Expand Up @@ -92,7 +95,10 @@ async def consume_all(
return await self.task_manager.get_task()

async def consume_and_break_on_interrupt(
self, consumer: EventConsumer, blocking: bool = True
self,
consumer: EventConsumer,
blocking: bool = True,
event_callback: Callable[[], Awaitable[None]] | None = None,
) -> tuple[Task | Message | None, bool]:
"""Processes the event stream until completion or an interruptable state is encountered.

Expand All @@ -105,6 +111,9 @@ async def consume_and_break_on_interrupt(
consumer: The `EventConsumer` to read events from.
blocking: If `False`, the method returns as soon as a task/message
is available. If `True`, it waits for a terminal state.
event_callback: Optional async callback function to be called after each event
is processed in the background continuation.
Mainly used for push notifications currently.

Returns:
A tuple containing:
Expand Down Expand Up @@ -150,13 +159,17 @@ async def consume_and_break_on_interrupt(
if should_interrupt:
# Continue consuming the rest of the events in the background.
# TODO: We should track all outstanding tasks to ensure they eventually complete.
asyncio.create_task(self._continue_consuming(event_stream)) # noqa: RUF006
asyncio.create_task( # noqa: RUF006
self._continue_consuming(event_stream, event_callback)
)
interrupted = True
break
return await self.task_manager.get_task(), interrupted

async def _continue_consuming(
self, event_stream: AsyncIterator[Event]
self,
event_stream: AsyncIterator[Event],
event_callback: Callable[[], Awaitable[None]] | None = None,
) -> None:
"""Continues processing an event stream in a background task.

Expand All @@ -165,6 +178,9 @@ async def _continue_consuming(

Args:
event_stream: The remaining `AsyncIterator` of events from the consumer.
event_callback: Optional async callback function to be called after each event is processed.
"""
async for event in event_stream:
await self.task_manager.process(event)
if event_callback:
await event_callback()
128 changes: 128 additions & 0 deletions tests/server/request_handlers/test_default_request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,134 @@ async def get_current_result():
mock_agent_executor.execute.assert_awaited_once()


@pytest.mark.asyncio
async def test_on_message_send_with_push_notification_in_non_blocking_request():
"""Test that push notification callback is called during background event processing for non-blocking requests."""
mock_task_store = AsyncMock(spec=TaskStore)
mock_push_notification_store = AsyncMock(spec=PushNotificationConfigStore)
mock_agent_executor = AsyncMock(spec=AgentExecutor)
mock_request_context_builder = AsyncMock(spec=RequestContextBuilder)
mock_push_sender = AsyncMock()

task_id = 'non_blocking_task_1'
context_id = 'non_blocking_ctx_1'

# Create a task that will be returned after the first event
initial_task = create_sample_task(
task_id=task_id, context_id=context_id, status_state=TaskState.working
)

# Create a final task that will be available during background processing
final_task = create_sample_task(
task_id=task_id, context_id=context_id, status_state=TaskState.completed
)

mock_task_store.get.return_value = None

# Mock request context
mock_request_context = MagicMock(spec=RequestContext)
mock_request_context.task_id = task_id
mock_request_context.context_id = context_id
mock_request_context_builder.build.return_value = mock_request_context

request_handler = DefaultRequestHandler(
agent_executor=mock_agent_executor,
task_store=mock_task_store,
push_config_store=mock_push_notification_store,
request_context_builder=mock_request_context_builder,
push_sender=mock_push_sender,
)

# Configure push notification
push_config = PushNotificationConfig(url='http://callback.com/push')
message_config = MessageSendConfiguration(
push_notification_config=push_config,
accepted_output_modes=['text/plain'],
blocking=False, # Non-blocking request
)
params = MessageSendParams(
message=Message(
role=Role.user,
message_id='msg_non_blocking',
parts=[],
task_id=task_id,
context_id=context_id,
),
configuration=message_config,
)

# Mock ResultAggregator with custom behavior
mock_result_aggregator_instance = AsyncMock(spec=ResultAggregator)

# First call returns the initial task and indicates interruption (non-blocking)
mock_result_aggregator_instance.consume_and_break_on_interrupt.return_value = (
initial_task,
True, # interrupted = True for non-blocking
)

# Mock the current_result property to return the final task
async def get_current_result():
return final_task

type(mock_result_aggregator_instance).current_result = PropertyMock(
return_value=get_current_result()
)

# Track if the event_callback was passed to consume_and_break_on_interrupt
event_callback_passed = False
event_callback_received = None

async def mock_consume_and_break_on_interrupt(
consumer, blocking=True, event_callback=None
):
nonlocal event_callback_passed, event_callback_received
event_callback_passed = event_callback is not None
event_callback_received = event_callback
return initial_task, True # interrupted = True for non-blocking

mock_result_aggregator_instance.consume_and_break_on_interrupt = (
mock_consume_and_break_on_interrupt
)

with (
patch(
'a2a.server.request_handlers.default_request_handler.ResultAggregator',
return_value=mock_result_aggregator_instance,
),
patch(
'a2a.server.request_handlers.default_request_handler.TaskManager.get_task',
return_value=initial_task,
),
patch(
'a2a.server.request_handlers.default_request_handler.TaskManager.update_with_message',
return_value=initial_task,
),
):
# Execute the non-blocking request
result = await request_handler.on_message_send(
params, create_server_call_context()
)

# Verify the result is the initial task (non-blocking behavior)
assert result == initial_task

# Verify that the event_callback was passed to consume_and_break_on_interrupt
assert event_callback_passed, (
'event_callback should have been passed to consume_and_break_on_interrupt'
)
assert event_callback_received is not None, (
'event_callback should not be None'
)

# Verify that the push notification was sent with the final task
mock_push_sender.send_notification.assert_called_with(final_task)

# Verify that the push notification config was stored
mock_push_notification_store.set_info.assert_awaited_once_with(
task_id, push_config
)


@pytest.mark.asyncio
async def test_on_message_send_with_push_notification_no_existing_Task():
"""Test on_message_send for new task sets push notification info if provided."""
Expand Down
Loading