Skip to content

Commit 8a091e3

Browse files
authored
Merge branch 'main' into test-card-resolver
2 parents de32cb5 + e12ca42 commit 8a091e3

File tree

3 files changed

+80
-1
lines changed

3 files changed

+80
-1
lines changed

src/a2a/server/agent_execution/simple_request_context_builder.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from a2a.server.agent_execution import RequestContext, RequestContextBuilder
44
from a2a.server.context import ServerCallContext
5+
from a2a.server.id_generator import IDGenerator
56
from a2a.server.tasks import TaskStore
67
from a2a.types import MessageSendParams, Task
78

@@ -13,6 +14,8 @@ def __init__(
1314
self,
1415
should_populate_referred_tasks: bool = False,
1516
task_store: TaskStore | None = None,
17+
task_id_generator: IDGenerator | None = None,
18+
context_id_generator: IDGenerator | None = None,
1619
) -> None:
1720
"""Initializes the SimpleRequestContextBuilder.
1821
@@ -22,9 +25,13 @@ def __init__(
2225
`related_tasks` field in the RequestContext. Defaults to False.
2326
task_store: The TaskStore instance to use for fetching referred tasks.
2427
Required if `should_populate_referred_tasks` is True.
28+
task_id_generator: ID generator for new task IDs. Defaults to None.
29+
context_id_generator: ID generator for new context IDs. Defaults to None.
2530
"""
2631
self._task_store = task_store
2732
self._should_populate_referred_tasks = should_populate_referred_tasks
33+
self._task_id_generator = task_id_generator
34+
self._context_id_generator = context_id_generator
2835

2936
async def build(
3037
self,
@@ -74,4 +81,6 @@ async def build(
7481
task=task,
7582
related_tasks=related_tasks,
7683
call_context=context,
84+
task_id_generator=self._task_id_generator,
85+
context_id_generator=self._context_id_generator,
7786
)

tests/auth/test_user.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,19 @@
11
import unittest
22

3-
from a2a.auth.user import UnauthenticatedUser
3+
from inspect import isabstract
4+
5+
from a2a.auth.user import UnauthenticatedUser, User
6+
7+
8+
class TestUser(unittest.TestCase):
9+
def test_is_abstract(self):
10+
self.assertTrue(isabstract(User))
411

512

613
class TestUnauthenticatedUser(unittest.TestCase):
14+
def test_is_user_subclass(self):
15+
self.assertTrue(issubclass(UnauthenticatedUser, User))
16+
717
def test_is_authenticated_returns_false(self):
818
user = UnauthenticatedUser()
919
self.assertFalse(user.is_authenticated)

tests/server/agent_execution/test_simple_request_context_builder.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
SimpleRequestContextBuilder,
1111
)
1212
from a2a.server.context import ServerCallContext
13+
from a2a.server.id_generator import IDGenerator
1314
from a2a.server.tasks.task_store import TaskStore
1415
from a2a.types import (
1516
Message,
@@ -275,6 +276,65 @@ async def test_build_populate_false_with_reference_task_ids(self) -> None:
275276
self.assertEqual(request_context.related_tasks, [])
276277
self.mock_task_store.get.assert_not_called()
277278

279+
async def test_build_with_custom_id_generators(self) -> None:
280+
mock_task_id_generator = AsyncMock(spec=IDGenerator)
281+
mock_context_id_generator = AsyncMock(spec=IDGenerator)
282+
mock_task_id_generator.generate.return_value = 'custom_task_id'
283+
mock_context_id_generator.generate.return_value = 'custom_context_id'
284+
285+
builder = SimpleRequestContextBuilder(
286+
should_populate_referred_tasks=False,
287+
task_store=self.mock_task_store,
288+
task_id_generator=mock_task_id_generator,
289+
context_id_generator=mock_context_id_generator,
290+
)
291+
params = MessageSendParams(message=create_sample_message())
292+
server_call_context = ServerCallContext(user=UnauthenticatedUser())
293+
294+
request_context = await builder.build(
295+
params=params,
296+
task_id=None,
297+
context_id=None,
298+
task=None,
299+
context=server_call_context,
300+
)
301+
302+
mock_task_id_generator.generate.assert_called_once()
303+
mock_context_id_generator.generate.assert_called_once()
304+
self.assertEqual(request_context.task_id, 'custom_task_id')
305+
self.assertEqual(request_context.context_id, 'custom_context_id')
306+
307+
async def test_build_with_provided_ids_and_custom_id_generators(
308+
self,
309+
) -> None:
310+
mock_task_id_generator = AsyncMock(spec=IDGenerator)
311+
mock_context_id_generator = AsyncMock(spec=IDGenerator)
312+
313+
builder = SimpleRequestContextBuilder(
314+
should_populate_referred_tasks=False,
315+
task_store=self.mock_task_store,
316+
task_id_generator=mock_task_id_generator,
317+
context_id_generator=mock_context_id_generator,
318+
)
319+
params = MessageSendParams(message=create_sample_message())
320+
server_call_context = ServerCallContext(user=UnauthenticatedUser())
321+
322+
provided_task_id = 'provided_task_id'
323+
provided_context_id = 'provided_context_id'
324+
325+
request_context = await builder.build(
326+
params=params,
327+
task_id=provided_task_id,
328+
context_id=provided_context_id,
329+
task=None,
330+
context=server_call_context,
331+
)
332+
333+
mock_task_id_generator.generate.assert_not_called()
334+
mock_context_id_generator.generate.assert_not_called()
335+
self.assertEqual(request_context.task_id, provided_task_id)
336+
self.assertEqual(request_context.context_id, provided_context_id)
337+
278338

279339
if __name__ == '__main__':
280340
unittest.main()

0 commit comments

Comments
 (0)