Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion src/a2a/server/agent_execution/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,17 @@

from a2a.server.agent_execution.agent_executor import AgentExecutor
from a2a.server.agent_execution.context import RequestContext
from a2a.server.agent_execution.request_context_builder import (
RequestContextBuilder,
)
from a2a.server.agent_execution.simple_request_context_builder import (
SimpleRequestContextBuilder,
)


__all__ = ['AgentExecutor', 'RequestContext']
__all__ = [
'AgentExecutor',
'RequestContext',
'RequestContextBuilder',
'SimpleRequestContextBuilder',
]
18 changes: 18 additions & 0 deletions src/a2a/server/agent_execution/request_context_builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from abc import ABC, abstractmethod

from a2a.server.agent_execution import RequestContext
from a2a.types import MessageSendParams, Task


class RequestContextBuilder(ABC):
"""Builds request context to be supplied to agent executor"""

@abstractmethod
async def build(
self,
params: MessageSendParams | None = None,
task_id: str | None = None,
context_id: str | None = None,
task: Task | None = None,
) -> RequestContext:
pass
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import asyncio

from a2a.server.agent_execution import RequestContext, RequestContextBuilder
from a2a.server.tasks import TaskStore
from a2a.types import MessageSendParams, Task


class SimpleRequestContextBuilder(RequestContextBuilder):
"""Builds request context and populates referred tasks"""

def __init__(
self,
should_populate_referred_tasks: bool = False,
task_store: TaskStore | None = None,
) -> None:
self._task_store = task_store
self._should_populate_referred_tasks = should_populate_referred_tasks

async def build(
self,
params: MessageSendParams | None = None,
task_id: str | None = None,
context_id: str | None = None,
task: Task | None = None,
) -> RequestContext:
related_tasks: list[Task] | None = None

if (
self._task_store
and self._should_populate_referred_tasks
and params
and params.message.referenceTaskIds
):
tasks = await asyncio.gather(
*[
self._task_store.get(task_id)
for task_id in params.message.referenceTaskIds
]
)
related_tasks = [x for x in tasks if x is not None]

return RequestContext(
request=params,
task_id=task_id,
context_id=context_id,
task=task,
related_tasks=related_tasks,
)
36 changes: 25 additions & 11 deletions src/a2a/server/request_handlers/default_request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@
from collections.abc import AsyncGenerator
from typing import cast

from a2a.server.agent_execution import AgentExecutor, RequestContext
from a2a.server.agent_execution import (
AgentExecutor,
RequestContext,
RequestContextBuilder,
SimpleRequestContextBuilder,
)
from a2a.server.events import (
Event,
EventConsumer,
Expand Down Expand Up @@ -57,6 +62,7 @@ def __init__(
task_store: TaskStore,
queue_manager: QueueManager | None = None,
push_notifier: PushNotifier | None = None,
request_context_builder: RequestContextBuilder | None = None,
) -> None:
"""Initializes the DefaultRequestHandler.

Expand All @@ -70,6 +76,12 @@ def __init__(
self.task_store = task_store
self._queue_manager = queue_manager or InMemoryQueueManager()
self._push_notifier = push_notifier
self._request_context_builder = (
request_context_builder
or SimpleRequestContextBuilder(
should_populate_referred_tasks=False, task_store=self.task_store
)
)
# TODO: Likely want an interface for managing this, like AgentExecutionManager.
self._running_agents = {}
self._running_agents_lock = asyncio.Lock()
Expand Down Expand Up @@ -167,12 +179,13 @@ async def on_message_send(
await self._push_notifier.set_info(
task.id, params.configuration.pushNotificationConfig
)
request_context = RequestContext(
params,
task.id if task else None,
task.contextId if task else None,
task,
request_context = await self._request_context_builder.build(
params=params,
task_id=task.id if task else None,
context_id=params.message.contextId,
task=task,
)

task_id = cast(str, request_context.task_id)
# Always assign a task ID. We may not actually upgrade to a task, but
# dictating the task ID at this layer is useful for tracking running
Expand Down Expand Up @@ -244,12 +257,13 @@ async def on_message_send_stream(
else:
queue = EventQueue()
result_aggregator = ResultAggregator(task_manager)
request_context = RequestContext(
params,
task.id if task else None,
task.contextId if task else None,
task,
request_context = await self._request_context_builder.build(
params=params,
task_id=task.id if task else None,
context_id=params.message.contextId,
task=task,
)

task_id = cast(str, request_context.task_id)
queue = await self._queue_manager.create_or_tap(task_id)
producer_task = asyncio.create_task(
Expand Down
2 changes: 0 additions & 2 deletions src/a2a/server/tasks/task_updater.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,6 @@ def start_work(self, message: Message | None = None):
def new_agent_message(
self,
parts: list[Part],
final: bool | None = None,
metadata: dict[str, Any] | None = None,
) -> Message:
"""Creates a new message object sent by the agent for this task/context.
Expand All @@ -136,6 +135,5 @@ def new_agent_message(
contextId=self.context_id,
messageId=str(uuid.uuid4()),
metadata=metadata,
final=final,
parts=parts,
)
4 changes: 4 additions & 0 deletions src/a2a/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,10 @@ class Message(BaseModel):
"""Extension metadata."""
parts: list[Part]
"""Message content."""
referenceTaskIds: list[str] | None = None
"""
list of tasks referenced as context by this message.
"""
role: Role
"""Message sender's role."""
taskId: str | None = None
Expand Down
4 changes: 2 additions & 2 deletions tests/server/tasks/test_task_updater.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def test_new_agent_message(self, task_updater, sample_parts):
assert message.parts == sample_parts
assert message.metadata is None

def test_new_agent_message_with_metadata_and_final(
def test_new_agent_message_with_metadata(
self, task_updater, sample_parts
):
"""Test creating a new agent message with metadata and final=True."""
Expand All @@ -223,7 +223,7 @@ def test_new_agent_message_with_metadata_and_final(
return_value=uuid.UUID('12345678-1234-5678-1234-567812345678'),
):
message = task_updater.new_agent_message(
parts=sample_parts, final=True, metadata=metadata
parts=sample_parts, metadata=metadata
)

assert message.role == Role.agent
Expand Down