Skip to content

Commit 04bcafc

Browse files
feat: Add custom ID generators to SimpleRequestContextBuilder (#594)
# Description This change allows passing custom `task_id_generator` and `context_id_generator` functions to the `SimpleRequestContextBuilder`. This provides flexibility in how task and context IDs are generated, defaulting to the previous behavior if no generators are provided. Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly: - [x] Follow the [`CONTRIBUTING` Guide](https://github.com/a2aproject/a2a-python/blob/main/CONTRIBUTING.md). - [x] Make your Pull Request title in the <https://www.conventionalcommits.org/> specification. - Important Prefixes for [release-please](https://github.com/googleapis/release-please): - `fix:` which represents bug fixes, and correlates to a [SemVer](https://semver.org/) patch. - `feat:` represents a new feature, and correlates to a SemVer minor. - `feat!:`, or `fix!:`, `refactor!:`, etc., which represent a breaking change (indicated by the `!`) and will result in a SemVer major. - [x] Ensure the tests and linter pass (Run `bash scripts/format.sh` from the repository root to format) - [x] Appropriate docs were updated (if necessary) Fixes #<issue_number_goes_here> 🦕
1 parent 03fa4c2 commit 04bcafc

File tree

2 files changed

+69
-0
lines changed

2 files changed

+69
-0
lines changed

src/a2a/server/agent_execution/simple_request_context_builder.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from a2a.server.agent_execution import RequestContext, RequestContextBuilder
44
from a2a.server.context import ServerCallContext
5+
from a2a.server.id_generator import IDGenerator
56
from a2a.server.tasks import TaskStore
67
from a2a.types import MessageSendParams, Task
78

@@ -13,6 +14,8 @@ def __init__(
1314
self,
1415
should_populate_referred_tasks: bool = False,
1516
task_store: TaskStore | None = None,
17+
task_id_generator: IDGenerator | None = None,
18+
context_id_generator: IDGenerator | None = None,
1619
) -> None:
1720
"""Initializes the SimpleRequestContextBuilder.
1821
@@ -22,9 +25,13 @@ def __init__(
2225
`related_tasks` field in the RequestContext. Defaults to False.
2326
task_store: The TaskStore instance to use for fetching referred tasks.
2427
Required if `should_populate_referred_tasks` is True.
28+
task_id_generator: ID generator for new task IDs. Defaults to None.
29+
context_id_generator: ID generator for new context IDs. Defaults to None.
2530
"""
2631
self._task_store = task_store
2732
self._should_populate_referred_tasks = should_populate_referred_tasks
33+
self._task_id_generator = task_id_generator
34+
self._context_id_generator = context_id_generator
2835

2936
async def build(
3037
self,
@@ -74,4 +81,6 @@ async def build(
7481
task=task,
7582
related_tasks=related_tasks,
7683
call_context=context,
84+
task_id_generator=self._task_id_generator,
85+
context_id_generator=self._context_id_generator,
7786
)

tests/server/agent_execution/test_simple_request_context_builder.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
SimpleRequestContextBuilder,
1111
)
1212
from a2a.server.context import ServerCallContext
13+
from a2a.server.id_generator import IDGenerator
1314
from a2a.server.tasks.task_store import TaskStore
1415
from a2a.types import (
1516
Message,
@@ -275,6 +276,65 @@ async def test_build_populate_false_with_reference_task_ids(self) -> None:
275276
self.assertEqual(request_context.related_tasks, [])
276277
self.mock_task_store.get.assert_not_called()
277278

279+
async def test_build_with_custom_id_generators(self) -> None:
280+
mock_task_id_generator = AsyncMock(spec=IDGenerator)
281+
mock_context_id_generator = AsyncMock(spec=IDGenerator)
282+
mock_task_id_generator.generate.return_value = 'custom_task_id'
283+
mock_context_id_generator.generate.return_value = 'custom_context_id'
284+
285+
builder = SimpleRequestContextBuilder(
286+
should_populate_referred_tasks=False,
287+
task_store=self.mock_task_store,
288+
task_id_generator=mock_task_id_generator,
289+
context_id_generator=mock_context_id_generator,
290+
)
291+
params = MessageSendParams(message=create_sample_message())
292+
server_call_context = ServerCallContext(user=UnauthenticatedUser())
293+
294+
request_context = await builder.build(
295+
params=params,
296+
task_id=None,
297+
context_id=None,
298+
task=None,
299+
context=server_call_context,
300+
)
301+
302+
mock_task_id_generator.generate.assert_called_once()
303+
mock_context_id_generator.generate.assert_called_once()
304+
self.assertEqual(request_context.task_id, 'custom_task_id')
305+
self.assertEqual(request_context.context_id, 'custom_context_id')
306+
307+
async def test_build_with_provided_ids_and_custom_id_generators(
308+
self,
309+
) -> None:
310+
mock_task_id_generator = AsyncMock(spec=IDGenerator)
311+
mock_context_id_generator = AsyncMock(spec=IDGenerator)
312+
313+
builder = SimpleRequestContextBuilder(
314+
should_populate_referred_tasks=False,
315+
task_store=self.mock_task_store,
316+
task_id_generator=mock_task_id_generator,
317+
context_id_generator=mock_context_id_generator,
318+
)
319+
params = MessageSendParams(message=create_sample_message())
320+
server_call_context = ServerCallContext(user=UnauthenticatedUser())
321+
322+
provided_task_id = 'provided_task_id'
323+
provided_context_id = 'provided_context_id'
324+
325+
request_context = await builder.build(
326+
params=params,
327+
task_id=provided_task_id,
328+
context_id=provided_context_id,
329+
task=None,
330+
context=server_call_context,
331+
)
332+
333+
mock_task_id_generator.generate.assert_not_called()
334+
mock_context_id_generator.generate.assert_not_called()
335+
self.assertEqual(request_context.task_id, provided_task_id)
336+
self.assertEqual(request_context.context_id, provided_context_id)
337+
278338

279339
if __name__ == '__main__':
280340
unittest.main()

0 commit comments

Comments
 (0)