|
10 | 10 | SimpleRequestContextBuilder, |
11 | 11 | ) |
12 | 12 | from a2a.server.context import ServerCallContext |
| 13 | +from a2a.server.id_generator import IDGenerator |
13 | 14 | from a2a.server.tasks.task_store import TaskStore |
14 | 15 | from a2a.types import ( |
15 | 16 | Message, |
@@ -275,6 +276,65 @@ async def test_build_populate_false_with_reference_task_ids(self) -> None: |
275 | 276 | self.assertEqual(request_context.related_tasks, []) |
276 | 277 | self.mock_task_store.get.assert_not_called() |
277 | 278 |
|
| 279 | + async def test_build_with_custom_id_generators(self) -> None: |
| 280 | + mock_task_id_generator = AsyncMock(spec=IDGenerator) |
| 281 | + mock_context_id_generator = AsyncMock(spec=IDGenerator) |
| 282 | + mock_task_id_generator.generate.return_value = 'custom_task_id' |
| 283 | + mock_context_id_generator.generate.return_value = 'custom_context_id' |
| 284 | + |
| 285 | + builder = SimpleRequestContextBuilder( |
| 286 | + should_populate_referred_tasks=False, |
| 287 | + task_store=self.mock_task_store, |
| 288 | + task_id_generator=mock_task_id_generator, |
| 289 | + context_id_generator=mock_context_id_generator, |
| 290 | + ) |
| 291 | + params = MessageSendParams(message=create_sample_message()) |
| 292 | + server_call_context = ServerCallContext(user=UnauthenticatedUser()) |
| 293 | + |
| 294 | + request_context = await builder.build( |
| 295 | + params=params, |
| 296 | + task_id=None, |
| 297 | + context_id=None, |
| 298 | + task=None, |
| 299 | + context=server_call_context, |
| 300 | + ) |
| 301 | + |
| 302 | + mock_task_id_generator.generate.assert_called_once() |
| 303 | + mock_context_id_generator.generate.assert_called_once() |
| 304 | + self.assertEqual(request_context.task_id, 'custom_task_id') |
| 305 | + self.assertEqual(request_context.context_id, 'custom_context_id') |
| 306 | + |
| 307 | + async def test_build_with_provided_ids_and_custom_id_generators( |
| 308 | + self, |
| 309 | + ) -> None: |
| 310 | + mock_task_id_generator = AsyncMock(spec=IDGenerator) |
| 311 | + mock_context_id_generator = AsyncMock(spec=IDGenerator) |
| 312 | + |
| 313 | + builder = SimpleRequestContextBuilder( |
| 314 | + should_populate_referred_tasks=False, |
| 315 | + task_store=self.mock_task_store, |
| 316 | + task_id_generator=mock_task_id_generator, |
| 317 | + context_id_generator=mock_context_id_generator, |
| 318 | + ) |
| 319 | + params = MessageSendParams(message=create_sample_message()) |
| 320 | + server_call_context = ServerCallContext(user=UnauthenticatedUser()) |
| 321 | + |
| 322 | + provided_task_id = 'provided_task_id' |
| 323 | + provided_context_id = 'provided_context_id' |
| 324 | + |
| 325 | + request_context = await builder.build( |
| 326 | + params=params, |
| 327 | + task_id=provided_task_id, |
| 328 | + context_id=provided_context_id, |
| 329 | + task=None, |
| 330 | + context=server_call_context, |
| 331 | + ) |
| 332 | + |
| 333 | + mock_task_id_generator.generate.assert_not_called() |
| 334 | + mock_context_id_generator.generate.assert_not_called() |
| 335 | + self.assertEqual(request_context.task_id, provided_task_id) |
| 336 | + self.assertEqual(request_context.context_id, provided_context_id) |
| 337 | + |
278 | 338 |
|
279 | 339 | if __name__ == '__main__': |
280 | 340 | unittest.main() |
0 commit comments