|
10 | 10 | SimpleRequestContextBuilder, |
11 | 11 | ) |
12 | 12 | from a2a.server.context import ServerCallContext |
| 13 | +from a2a.server.id_generator import IDGenerator |
13 | 14 | from a2a.server.tasks.task_store import TaskStore |
14 | 15 | from a2a.types import ( |
15 | 16 | Message, |
@@ -276,5 +277,58 @@ async def test_build_populate_false_with_reference_task_ids(self) -> None: |
276 | 277 | self.mock_task_store.get.assert_not_called() |
277 | 278 |
|
278 | 279 |
|
| 280 | + async def test_build_with_custom_id_generators(self) -> None: |
| 281 | + mock_task_id_generator = AsyncMock(spec=IDGenerator) |
| 282 | + mock_context_id_generator = AsyncMock(spec=IDGenerator) |
| 283 | + |
| 284 | + builder = SimpleRequestContextBuilder( |
| 285 | + should_populate_referred_tasks=False, |
| 286 | + task_store=self.mock_task_store, |
| 287 | + task_id_generator=mock_task_id_generator, |
| 288 | + context_id_generator=mock_context_id_generator, |
| 289 | + ) |
| 290 | + params = MessageSendParams(message=create_sample_message()) |
| 291 | + server_call_context = ServerCallContext(user=UnauthenticatedUser()) |
| 292 | + |
| 293 | + request_context = await builder.build( |
| 294 | + params=params, |
| 295 | + task_id=None, |
| 296 | + context_id=None, |
| 297 | + task=None, |
| 298 | + context=server_call_context, |
| 299 | + ) |
| 300 | + |
| 301 | + mock_task_id_generator.generate.assert_called_once() |
| 302 | + mock_context_id_generator.generate.assert_called_once() |
| 303 | + |
| 304 | + async def test_build_with_provided_ids_and_custom_id_generators(self) -> None: |
| 305 | + mock_task_id_generator = AsyncMock(spec=IDGenerator) |
| 306 | + mock_context_id_generator = AsyncMock(spec=IDGenerator) |
| 307 | + |
| 308 | + builder = SimpleRequestContextBuilder( |
| 309 | + should_populate_referred_tasks=False, |
| 310 | + task_store=self.mock_task_store, |
| 311 | + task_id_generator=mock_task_id_generator, |
| 312 | + context_id_generator=mock_context_id_generator, |
| 313 | + ) |
| 314 | + params = MessageSendParams(message=create_sample_message()) |
| 315 | + server_call_context = ServerCallContext(user=UnauthenticatedUser()) |
| 316 | + |
| 317 | + provided_task_id = 'provided_task_id' |
| 318 | + provided_context_id = 'provided_context_id' |
| 319 | + |
| 320 | + request_context = await builder.build( |
| 321 | + params=params, |
| 322 | + task_id=provided_task_id, |
| 323 | + context_id=provided_context_id, |
| 324 | + task=None, |
| 325 | + context=server_call_context, |
| 326 | + ) |
| 327 | + |
| 328 | + mock_task_id_generator.generate.assert_not_called() |
| 329 | + mock_context_id_generator.generate.assert_not_called() |
| 330 | + self.assertEqual(request_context.task_id, provided_task_id) |
| 331 | + self.assertEqual(request_context.context_id, provided_context_id) |
| 332 | + |
279 | 333 | if __name__ == '__main__': |
280 | 334 | unittest.main() |
0 commit comments