diff --git a/src/a2a/server/agent_execution/__init__.py b/src/a2a/server/agent_execution/__init__.py index f6c853f6..e00df649 100644 --- a/src/a2a/server/agent_execution/__init__.py +++ b/src/a2a/server/agent_execution/__init__.py @@ -2,6 +2,17 @@ from a2a.server.agent_execution.agent_executor import AgentExecutor from a2a.server.agent_execution.context import RequestContext +from a2a.server.agent_execution.request_context_builder import ( + RequestContextBuilder, +) +from a2a.server.agent_execution.simple_request_context_builder import ( + SimpleRequestContextBuilder, +) -__all__ = ['AgentExecutor', 'RequestContext'] +__all__ = [ + 'AgentExecutor', + 'RequestContext', + 'RequestContextBuilder', + 'SimpleRequestContextBuilder', +] diff --git a/src/a2a/server/agent_execution/request_context_builder.py b/src/a2a/server/agent_execution/request_context_builder.py new file mode 100644 index 00000000..5a59eb96 --- /dev/null +++ b/src/a2a/server/agent_execution/request_context_builder.py @@ -0,0 +1,18 @@ +from abc import ABC, abstractmethod + +from a2a.server.agent_execution import RequestContext +from a2a.types import MessageSendParams, Task + + +class RequestContextBuilder(ABC): + """Builds request context to be supplied to agent executor""" + + @abstractmethod + async def build( + self, + params: MessageSendParams | None = None, + task_id: str | None = None, + context_id: str | None = None, + task: Task | None = None, + ) -> RequestContext: + pass diff --git a/src/a2a/server/agent_execution/simple_request_context_builder.py b/src/a2a/server/agent_execution/simple_request_context_builder.py new file mode 100644 index 00000000..4a9b9a88 --- /dev/null +++ b/src/a2a/server/agent_execution/simple_request_context_builder.py @@ -0,0 +1,48 @@ +import asyncio + +from a2a.server.agent_execution import RequestContext, RequestContextBuilder +from a2a.server.tasks import TaskStore +from a2a.types import MessageSendParams, Task + + +class SimpleRequestContextBuilder(RequestContextBuilder): + """Builds request context and populates referred tasks""" + + def __init__( + self, + should_populate_referred_tasks: bool = False, + task_store: TaskStore | None = None, + ) -> None: + self._task_store = task_store + self._should_populate_referred_tasks = should_populate_referred_tasks + + async def build( + self, + params: MessageSendParams | None = None, + task_id: str | None = None, + context_id: str | None = None, + task: Task | None = None, + ) -> RequestContext: + related_tasks: list[Task] | None = None + + if ( + self._task_store + and self._should_populate_referred_tasks + and params + and params.message.referenceTaskIds + ): + tasks = await asyncio.gather( + *[ + self._task_store.get(task_id) + for task_id in params.message.referenceTaskIds + ] + ) + related_tasks = [x for x in tasks if x is not None] + + return RequestContext( + request=params, + task_id=task_id, + context_id=context_id, + task=task, + related_tasks=related_tasks, + ) diff --git a/src/a2a/server/request_handlers/default_request_handler.py b/src/a2a/server/request_handlers/default_request_handler.py index e3a47355..eb8de0a5 100644 --- a/src/a2a/server/request_handlers/default_request_handler.py +++ b/src/a2a/server/request_handlers/default_request_handler.py @@ -4,7 +4,12 @@ from collections.abc import AsyncGenerator from typing import cast -from a2a.server.agent_execution import AgentExecutor, RequestContext +from a2a.server.agent_execution import ( + AgentExecutor, + RequestContext, + RequestContextBuilder, + SimpleRequestContextBuilder, +) from a2a.server.events import ( Event, EventConsumer, @@ -57,6 +62,7 @@ def __init__( task_store: TaskStore, queue_manager: QueueManager | None = None, push_notifier: PushNotifier | None = None, + request_context_builder: RequestContextBuilder | None = None, ) -> None: """Initializes the DefaultRequestHandler. @@ -70,6 +76,12 @@ def __init__( self.task_store = task_store self._queue_manager = queue_manager or InMemoryQueueManager() self._push_notifier = push_notifier + self._request_context_builder = ( + request_context_builder + or SimpleRequestContextBuilder( + should_populate_referred_tasks=False, task_store=self.task_store + ) + ) # TODO: Likely want an interface for managing this, like AgentExecutionManager. self._running_agents = {} self._running_agents_lock = asyncio.Lock() @@ -167,12 +179,13 @@ async def on_message_send( await self._push_notifier.set_info( task.id, params.configuration.pushNotificationConfig ) - request_context = RequestContext( - params, - task.id if task else None, - task.contextId if task else None, - task, + request_context = await self._request_context_builder.build( + params=params, + task_id=task.id if task else None, + context_id=params.message.contextId, + task=task, ) + task_id = cast(str, request_context.task_id) # Always assign a task ID. We may not actually upgrade to a task, but # dictating the task ID at this layer is useful for tracking running @@ -244,12 +257,13 @@ async def on_message_send_stream( else: queue = EventQueue() result_aggregator = ResultAggregator(task_manager) - request_context = RequestContext( - params, - task.id if task else None, - task.contextId if task else None, - task, + request_context = await self._request_context_builder.build( + params=params, + task_id=task.id if task else None, + context_id=params.message.contextId, + task=task, ) + task_id = cast(str, request_context.task_id) queue = await self._queue_manager.create_or_tap(task_id) producer_task = asyncio.create_task( diff --git a/src/a2a/server/tasks/task_updater.py b/src/a2a/server/tasks/task_updater.py index 58c2ca13..c079edd4 100644 --- a/src/a2a/server/tasks/task_updater.py +++ b/src/a2a/server/tasks/task_updater.py @@ -114,7 +114,6 @@ def start_work(self, message: Message | None = None): def new_agent_message( self, parts: list[Part], - final: bool | None = None, metadata: dict[str, Any] | None = None, ) -> Message: """Creates a new message object sent by the agent for this task/context. @@ -136,6 +135,5 @@ def new_agent_message( contextId=self.context_id, messageId=str(uuid.uuid4()), metadata=metadata, - final=final, parts=parts, ) diff --git a/src/a2a/types.py b/src/a2a/types.py index e04c5973..675ed832 100644 --- a/src/a2a/types.py +++ b/src/a2a/types.py @@ -613,6 +613,10 @@ class Message(BaseModel): """Extension metadata.""" parts: list[Part] """Message content.""" + referenceTaskIds: list[str] | None = None + """ + list of tasks referenced as context by this message. + """ role: Role """Message sender's role.""" taskId: str | None = None diff --git a/tests/server/tasks/test_task_updater.py b/tests/server/tasks/test_task_updater.py index 2df39ae9..fd278929 100644 --- a/tests/server/tasks/test_task_updater.py +++ b/tests/server/tasks/test_task_updater.py @@ -212,7 +212,7 @@ def test_new_agent_message(self, task_updater, sample_parts): assert message.parts == sample_parts assert message.metadata is None - def test_new_agent_message_with_metadata_and_final( + def test_new_agent_message_with_metadata( self, task_updater, sample_parts ): """Test creating a new agent message with metadata and final=True.""" @@ -223,7 +223,7 @@ def test_new_agent_message_with_metadata_and_final( return_value=uuid.UUID('12345678-1234-5678-1234-567812345678'), ): message = task_updater.new_agent_message( - parts=sample_parts, final=True, metadata=metadata + parts=sample_parts, metadata=metadata ) assert message.role == Role.agent