Skip to content

Commit 3a45f32

Browse files
Add tests
1 parent 98d4414 commit 3a45f32

File tree

1 file changed

+54
-0
lines changed

1 file changed

+54
-0
lines changed

tests/server/agent_execution/test_simple_request_context_builder.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
SimpleRequestContextBuilder,
1111
)
1212
from a2a.server.context import ServerCallContext
13+
from a2a.server.id_generator import IDGenerator
1314
from a2a.server.tasks.task_store import TaskStore
1415
from a2a.types import (
1516
Message,
@@ -276,5 +277,58 @@ async def test_build_populate_false_with_reference_task_ids(self) -> None:
276277
self.mock_task_store.get.assert_not_called()
277278

278279

280+
async def test_build_with_custom_id_generators(self) -> None:
281+
mock_task_id_generator = AsyncMock(spec=IDGenerator)
282+
mock_context_id_generator = AsyncMock(spec=IDGenerator)
283+
284+
builder = SimpleRequestContextBuilder(
285+
should_populate_referred_tasks=False,
286+
task_store=self.mock_task_store,
287+
task_id_generator=mock_task_id_generator,
288+
context_id_generator=mock_context_id_generator,
289+
)
290+
params = MessageSendParams(message=create_sample_message())
291+
server_call_context = ServerCallContext(user=UnauthenticatedUser())
292+
293+
request_context = await builder.build(
294+
params=params,
295+
task_id=None,
296+
context_id=None,
297+
task=None,
298+
context=server_call_context,
299+
)
300+
301+
mock_task_id_generator.generate.assert_called_once()
302+
mock_context_id_generator.generate.assert_called_once()
303+
304+
async def test_build_with_provided_ids_and_custom_id_generators(self) -> None:
305+
mock_task_id_generator = AsyncMock(spec=IDGenerator)
306+
mock_context_id_generator = AsyncMock(spec=IDGenerator)
307+
308+
builder = SimpleRequestContextBuilder(
309+
should_populate_referred_tasks=False,
310+
task_store=self.mock_task_store,
311+
task_id_generator=mock_task_id_generator,
312+
context_id_generator=mock_context_id_generator,
313+
)
314+
params = MessageSendParams(message=create_sample_message())
315+
server_call_context = ServerCallContext(user=UnauthenticatedUser())
316+
317+
provided_task_id = 'provided_task_id'
318+
provided_context_id = 'provided_context_id'
319+
320+
request_context = await builder.build(
321+
params=params,
322+
task_id=provided_task_id,
323+
context_id=provided_context_id,
324+
task=None,
325+
context=server_call_context,
326+
)
327+
328+
mock_task_id_generator.generate.assert_not_called()
329+
mock_context_id_generator.generate.assert_not_called()
330+
self.assertEqual(request_context.task_id, provided_task_id)
331+
self.assertEqual(request_context.context_id, provided_context_id)
332+
279333
if __name__ == '__main__':
280334
unittest.main()

0 commit comments

Comments
 (0)