Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from a2a.server.agent_execution import RequestContext, RequestContextBuilder
from a2a.server.context import ServerCallContext
from a2a.server.id_generator import IDGenerator
from a2a.server.tasks import TaskStore
from a2a.types import MessageSendParams, Task

Expand All @@ -13,6 +14,8 @@
self,
should_populate_referred_tasks: bool = False,
task_store: TaskStore | None = None,
task_id_generator: IDGenerator | None = None,
context_id_generator: IDGenerator | None = None,
) -> None:
"""Initializes the SimpleRequestContextBuilder.

Expand All @@ -22,35 +25,39 @@
`related_tasks` field in the RequestContext. Defaults to False.
task_store: The TaskStore instance to use for fetching referred tasks.
Required if `should_populate_referred_tasks` is True.
task_id_generator: ID generator for new task IDs. Defaults to None.
context_id_generator: ID generator for new context IDs. Defaults to None.
"""
self._task_store = task_store
self._should_populate_referred_tasks = should_populate_referred_tasks
self._task_id_generator = task_id_generator
self._context_id_generator = context_id_generator

async def build(
self,
params: MessageSendParams | None = None,
task_id: str | None = None,
context_id: str | None = None,
task: Task | None = None,
context: ServerCallContext | None = None,
) -> RequestContext:
"""Builds the request context for an agent execution.

This method assembles the RequestContext object. If the builder was
initialized with `should_populate_referred_tasks=True`, it fetches all tasks
referenced in `params.message.reference_task_ids` from the `task_store`.

Args:
params: The parameters of the incoming message send request.
task_id: The ID of the task being executed.
context_id: The ID of the current execution context.
task: The primary task object associated with the request.
context: The server call context, containing metadata about the call.

Returns:
An instance of RequestContext populated with the provided information
and potentially a list of related tasks.
"""

Check notice on line 60 in src/a2a/server/agent_execution/simple_request_context_builder.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/server/agent_execution/request_context_builder.py (12-20)
related_tasks: list[Task] | None = None

if (
Expand All @@ -74,4 +81,6 @@
task=task,
related_tasks=related_tasks,
call_context=context,
task_id_generator=self._task_id_generator,
context_id_generator=self._context_id_generator,
)
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
SimpleRequestContextBuilder,
)
from a2a.server.context import ServerCallContext
from a2a.server.id_generator import IDGenerator
from a2a.server.tasks.task_store import TaskStore
from a2a.types import (
Message,
Expand Down Expand Up @@ -275,6 +276,65 @@ async def test_build_populate_false_with_reference_task_ids(self) -> None:
self.assertEqual(request_context.related_tasks, [])
self.mock_task_store.get.assert_not_called()

async def test_build_with_custom_id_generators(self) -> None:
mock_task_id_generator = AsyncMock(spec=IDGenerator)
mock_context_id_generator = AsyncMock(spec=IDGenerator)
mock_task_id_generator.generate.return_value = 'custom_task_id'
mock_context_id_generator.generate.return_value = 'custom_context_id'

builder = SimpleRequestContextBuilder(
should_populate_referred_tasks=False,
task_store=self.mock_task_store,
task_id_generator=mock_task_id_generator,
context_id_generator=mock_context_id_generator,
)
params = MessageSendParams(message=create_sample_message())
server_call_context = ServerCallContext(user=UnauthenticatedUser())

request_context = await builder.build(
params=params,
task_id=None,
context_id=None,
task=None,
context=server_call_context,
)

mock_task_id_generator.generate.assert_called_once()
mock_context_id_generator.generate.assert_called_once()
self.assertEqual(request_context.task_id, 'custom_task_id')
self.assertEqual(request_context.context_id, 'custom_context_id')

async def test_build_with_provided_ids_and_custom_id_generators(
self,
) -> None:
mock_task_id_generator = AsyncMock(spec=IDGenerator)
mock_context_id_generator = AsyncMock(spec=IDGenerator)

builder = SimpleRequestContextBuilder(
should_populate_referred_tasks=False,
task_store=self.mock_task_store,
task_id_generator=mock_task_id_generator,
context_id_generator=mock_context_id_generator,
)
params = MessageSendParams(message=create_sample_message())
server_call_context = ServerCallContext(user=UnauthenticatedUser())

provided_task_id = 'provided_task_id'
provided_context_id = 'provided_context_id'

request_context = await builder.build(
params=params,
task_id=provided_task_id,
context_id=provided_context_id,
task=None,
context=server_call_context,
)

mock_task_id_generator.generate.assert_not_called()
mock_context_id_generator.generate.assert_not_called()
self.assertEqual(request_context.task_id, provided_task_id)
self.assertEqual(request_context.context_id, provided_context_id)


if __name__ == '__main__':
unittest.main()
Loading