From 518ea50f312411f20f940bee1b0103bb8c0576c3 Mon Sep 17 00:00:00 2001 From: swapnilag Date: Tue, 20 May 2025 03:08:49 +0000 Subject: [PATCH 1/3] feat: Add request context builder with referenceTasks --- src/a2a/server/agent_execution/__init__.py | 13 ++++- .../request_context_builder.py | 18 +++++++ .../simple_request_context_builder.py | 48 +++++++++++++++++++ .../default_request_handler.py | 36 +++++++++----- src/a2a/server/tasks/task_updater.py | 1 - src/a2a/types.py | 4 ++ 6 files changed, 107 insertions(+), 13 deletions(-) create mode 100644 src/a2a/server/agent_execution/request_context_builder.py create mode 100644 src/a2a/server/agent_execution/simple_request_context_builder.py diff --git a/src/a2a/server/agent_execution/__init__.py b/src/a2a/server/agent_execution/__init__.py index 88660d62..b93e19bf 100644 --- a/src/a2a/server/agent_execution/__init__.py +++ b/src/a2a/server/agent_execution/__init__.py @@ -1,5 +1,16 @@ from a2a.server.agent_execution.agent_executor import AgentExecutor from a2a.server.agent_execution.context import RequestContext +from a2a.server.agent_execution.request_context_builder import ( + RequestContextBuilder, +) +from a2a.server.agent_execution.simple_request_context_builder import ( + SimpleRequestContextBuilder, +) -__all__ = ['AgentExecutor', 'RequestContext'] +__all__ = [ + 'AgentExecutor', + 'RequestContext', + 'RequestContextBuilder', + 'SimpleRequestContextBuilder', +] diff --git a/src/a2a/server/agent_execution/request_context_builder.py b/src/a2a/server/agent_execution/request_context_builder.py new file mode 100644 index 00000000..5a59eb96 --- /dev/null +++ b/src/a2a/server/agent_execution/request_context_builder.py @@ -0,0 +1,18 @@ +from abc import ABC, abstractmethod + +from a2a.server.agent_execution import RequestContext +from a2a.types import MessageSendParams, Task + + +class RequestContextBuilder(ABC): + """Builds request context to be supplied to agent executor""" + + @abstractmethod + async def build( + self, + params: MessageSendParams | None = None, + task_id: str | None = None, + context_id: str | None = None, + task: Task | None = None, + ) -> RequestContext: + pass diff --git a/src/a2a/server/agent_execution/simple_request_context_builder.py b/src/a2a/server/agent_execution/simple_request_context_builder.py new file mode 100644 index 00000000..4a9b9a88 --- /dev/null +++ b/src/a2a/server/agent_execution/simple_request_context_builder.py @@ -0,0 +1,48 @@ +import asyncio + +from a2a.server.agent_execution import RequestContext, RequestContextBuilder +from a2a.server.tasks import TaskStore +from a2a.types import MessageSendParams, Task + + +class SimpleRequestContextBuilder(RequestContextBuilder): + """Builds request context and populates referred tasks""" + + def __init__( + self, + should_populate_referred_tasks: bool = False, + task_store: TaskStore | None = None, + ) -> None: + self._task_store = task_store + self._should_populate_referred_tasks = should_populate_referred_tasks + + async def build( + self, + params: MessageSendParams | None = None, + task_id: str | None = None, + context_id: str | None = None, + task: Task | None = None, + ) -> RequestContext: + related_tasks: list[Task] | None = None + + if ( + self._task_store + and self._should_populate_referred_tasks + and params + and params.message.referenceTaskIds + ): + tasks = await asyncio.gather( + *[ + self._task_store.get(task_id) + for task_id in params.message.referenceTaskIds + ] + ) + related_tasks = [x for x in tasks if x is not None] + + return RequestContext( + request=params, + task_id=task_id, + context_id=context_id, + task=task, + related_tasks=related_tasks, + ) diff --git a/src/a2a/server/request_handlers/default_request_handler.py b/src/a2a/server/request_handlers/default_request_handler.py index 17cbd49f..5b92a658 100644 --- a/src/a2a/server/request_handlers/default_request_handler.py +++ b/src/a2a/server/request_handlers/default_request_handler.py @@ -4,7 +4,12 @@ from collections.abc import AsyncGenerator from typing import cast -from a2a.server.agent_execution import AgentExecutor, RequestContext +from a2a.server.agent_execution import ( + AgentExecutor, + RequestContext, + RequestContextBuilder, + SimpleRequestContextBuilder, +) from a2a.server.events import ( Event, EventConsumer, @@ -52,11 +57,18 @@ def __init__( task_store: TaskStore, queue_manager: QueueManager | None = None, push_notifier: PushNotifier | None = None, + request_context_builder: RequestContextBuilder | None = None, ) -> None: self.agent_executor = agent_executor self.task_store = task_store self._queue_manager = queue_manager or InMemoryQueueManager() self._push_notifier = push_notifier + self._request_context_builder = ( + request_context_builder + or SimpleRequestContextBuilder( + should_populate_referred_tasks=False, task_store=self.task_store + ) + ) # TODO: Likely want an interface for managing this, like AgentExecutionManager. self._running_agents = {} self._running_agents_lock = asyncio.Lock() @@ -139,12 +151,13 @@ async def on_message_send( await self._push_notifier.set_info( task.id, params.configuration.pushNotificationConfig ) - request_context = RequestContext( - params, - task.id if task else None, - task.contextId if task else None, - task, + request_context = await self._request_context_builder.build( + params=params, + task_id=task.id if task else None, + context_id=params.message.contextId, + task=task, ) + task_id = cast(str, request_context.task_id) # Always assign a task ID. We may not actually upgrade to a task, but # dictating the task ID at this layer is useful for tracking running @@ -212,12 +225,13 @@ async def on_message_send_stream( else: queue = EventQueue() result_aggregator = ResultAggregator(task_manager) - request_context = RequestContext( - params, - task.id if task else None, - task.contextId if task else None, - task, + request_context = await self._request_context_builder.build( + params=params, + task_id=task.id if task else None, + context_id=params.message.contextId, + task=task, ) + task_id = cast(str, request_context.task_id) queue = await self._queue_manager.create_or_tap(task_id) producer_task = asyncio.create_task( diff --git a/src/a2a/server/tasks/task_updater.py b/src/a2a/server/tasks/task_updater.py index fdecde5a..8c0c9669 100644 --- a/src/a2a/server/tasks/task_updater.py +++ b/src/a2a/server/tasks/task_updater.py @@ -96,6 +96,5 @@ def new_agent_message( contextId=self.context_id, messageId=str(uuid.uuid4()), metadata=metadata, - final=final, parts=parts, ) diff --git a/src/a2a/types.py b/src/a2a/types.py index a7d1155a..aa4a5adf 100644 --- a/src/a2a/types.py +++ b/src/a2a/types.py @@ -1117,6 +1117,10 @@ class Message(BaseModel): """ message content """ + referenceTaskIds: list[str] | None = None + """ + list of tasks referenced as context by this message. + """ role: Role """ message sender's role From ac4bc51079b14055383b8b10219c36fe55b48ed7 Mon Sep 17 00:00:00 2001 From: swapnilag Date: Tue, 20 May 2025 19:09:35 +0000 Subject: [PATCH 2/3] Remove final from new_agent_message --- src/a2a/server/tasks/task_updater.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/a2a/server/tasks/task_updater.py b/src/a2a/server/tasks/task_updater.py index 8c0c9669..ad597cf8 100644 --- a/src/a2a/server/tasks/task_updater.py +++ b/src/a2a/server/tasks/task_updater.py @@ -87,7 +87,7 @@ def start_work(self, message: Message | None = None): ) def new_agent_message( - self, parts: list[Part], final=False, metadata=None + self, parts: list[Part], metadata=None ) -> Message: """Create a new message for the task.""" return Message( From 6540c70319343bf0eecf709fad4d43c4e493d773 Mon Sep 17 00:00:00 2001 From: swapnilag Date: Tue, 20 May 2025 20:08:07 +0000 Subject: [PATCH 3/3] Fix unit test --- tests/server/tasks/test_task_updater.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/server/tasks/test_task_updater.py b/tests/server/tasks/test_task_updater.py index 2df39ae9..fd278929 100644 --- a/tests/server/tasks/test_task_updater.py +++ b/tests/server/tasks/test_task_updater.py @@ -212,7 +212,7 @@ def test_new_agent_message(self, task_updater, sample_parts): assert message.parts == sample_parts assert message.metadata is None - def test_new_agent_message_with_metadata_and_final( + def test_new_agent_message_with_metadata( self, task_updater, sample_parts ): """Test creating a new agent message with metadata and final=True.""" @@ -223,7 +223,7 @@ def test_new_agent_message_with_metadata_and_final( return_value=uuid.UUID('12345678-1234-5678-1234-567812345678'), ): message = task_updater.new_agent_message( - parts=sample_parts, final=True, metadata=metadata + parts=sample_parts, metadata=metadata ) assert message.role == Role.agent