Skip to content

Commit 5da9560

Browse files
didier-durandmartimfasantos
authored andcommitted
test: adding 12 tests for InMemoryQueueManager (a2aproject#57)
1 parent 946c0f4 commit 5da9560

File tree

1 file changed

+151
-0
lines changed

1 file changed

+151
-0
lines changed
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
import asyncio
2+
3+
from unittest.mock import MagicMock
4+
5+
import pytest
6+
7+
from a2a.server.events import InMemoryQueueManager
8+
from a2a.server.events.event_queue import EventQueue
9+
from a2a.server.events.queue_manager import (
10+
NoTaskQueue,
11+
TaskQueueExists,
12+
)
13+
14+
15+
class TestInMemoryQueueManager:
16+
@pytest.fixture
17+
def queue_manager(self):
18+
"""Fixture to create a fresh InMemoryQueueManager for each test."""
19+
manager = InMemoryQueueManager()
20+
return manager
21+
22+
@pytest.fixture
23+
def event_queue(self):
24+
"""Fixture to create a mock EventQueue."""
25+
queue = MagicMock(spec=EventQueue)
26+
# Mock the tap method to return itself
27+
queue.tap.return_value = queue
28+
return queue
29+
30+
@pytest.mark.asyncio
31+
async def test_init(self, queue_manager):
32+
"""Test that the InMemoryQueueManager initializes with empty task queue and a lock."""
33+
assert queue_manager._task_queue == {}
34+
assert isinstance(queue_manager._lock, asyncio.Lock)
35+
36+
@pytest.mark.asyncio
37+
async def test_add_new_queue(self, queue_manager, event_queue):
38+
"""Test adding a new queue to the manager."""
39+
task_id = 'test_task_id'
40+
await queue_manager.add(task_id, event_queue)
41+
assert queue_manager._task_queue[task_id] == event_queue
42+
43+
@pytest.mark.asyncio
44+
async def test_add_existing_queue(self, queue_manager, event_queue):
45+
"""Test adding a queue with an existing task_id raises TaskQueueExists."""
46+
task_id = 'test_task_id'
47+
await queue_manager.add(task_id, event_queue)
48+
49+
with pytest.raises(TaskQueueExists):
50+
await queue_manager.add(task_id, event_queue)
51+
52+
@pytest.mark.asyncio
53+
async def test_get_existing_queue(self, queue_manager, event_queue):
54+
"""Test getting an existing queue returns the queue."""
55+
task_id = 'test_task_id'
56+
await queue_manager.add(task_id, event_queue)
57+
58+
result = await queue_manager.get(task_id)
59+
assert result == event_queue
60+
61+
@pytest.mark.asyncio
62+
async def test_get_nonexistent_queue(self, queue_manager):
63+
"""Test getting a nonexistent queue returns None."""
64+
result = await queue_manager.get('nonexistent_task_id')
65+
assert result is None
66+
67+
@pytest.mark.asyncio
68+
async def test_tap_existing_queue(self, queue_manager, event_queue):
69+
"""Test tapping an existing queue returns the tapped queue."""
70+
task_id = 'test_task_id'
71+
await queue_manager.add(task_id, event_queue)
72+
73+
result = await queue_manager.tap(task_id)
74+
assert result == event_queue
75+
event_queue.tap.assert_called_once()
76+
77+
@pytest.mark.asyncio
78+
async def test_tap_nonexistent_queue(self, queue_manager):
79+
"""Test tapping a nonexistent queue returns None."""
80+
result = await queue_manager.tap('nonexistent_task_id')
81+
assert result is None
82+
83+
@pytest.mark.asyncio
84+
async def test_close_existing_queue(self, queue_manager, event_queue):
85+
"""Test closing an existing queue removes it from the manager."""
86+
task_id = 'test_task_id'
87+
await queue_manager.add(task_id, event_queue)
88+
89+
await queue_manager.close(task_id)
90+
assert task_id not in queue_manager._task_queue
91+
92+
@pytest.mark.asyncio
93+
async def test_close_nonexistent_queue(self, queue_manager):
94+
"""Test closing a nonexistent queue raises NoTaskQueue."""
95+
with pytest.raises(NoTaskQueue):
96+
await queue_manager.close('nonexistent_task_id')
97+
98+
@pytest.mark.asyncio
99+
async def test_create_or_tap_new_queue(self, queue_manager):
100+
"""Test create_or_tap with a new task_id creates and returns a new queue."""
101+
task_id = 'test_task_id'
102+
103+
result = await queue_manager.create_or_tap(task_id)
104+
assert isinstance(result, EventQueue)
105+
assert queue_manager._task_queue[task_id] == result
106+
107+
@pytest.mark.asyncio
108+
async def test_create_or_tap_existing_queue(
109+
self, queue_manager, event_queue
110+
):
111+
"""Test create_or_tap with an existing task_id taps and returns the existing queue."""
112+
task_id = 'test_task_id'
113+
await queue_manager.add(task_id, event_queue)
114+
115+
result = await queue_manager.create_or_tap(task_id)
116+
117+
assert result == event_queue
118+
event_queue.tap.assert_called_once()
119+
120+
@pytest.mark.asyncio
121+
async def test_concurrency(self, queue_manager):
122+
"""Test concurrent access to the queue manager."""
123+
124+
async def add_task(task_id):
125+
queue = EventQueue()
126+
await queue_manager.add(task_id, queue)
127+
return task_id
128+
129+
async def get_task(task_id):
130+
return await queue_manager.get(task_id)
131+
132+
# Create 10 different task IDs
133+
task_ids = [f'task_{i}' for i in range(10)]
134+
135+
# Add tasks concurrently
136+
add_tasks = [add_task(task_id) for task_id in task_ids]
137+
added_task_ids = await asyncio.gather(*add_tasks)
138+
139+
# Verify all tasks were added
140+
assert set(added_task_ids) == set(task_ids)
141+
142+
# Get tasks concurrently
143+
get_tasks = [get_task(task_id) for task_id in task_ids]
144+
queues = await asyncio.gather(*get_tasks)
145+
146+
# Verify all queues are not None
147+
assert all(queue is not None for queue in queues)
148+
149+
# Verify all tasks are in the manager
150+
for task_id in task_ids:
151+
assert task_id in queue_manager._task_queue

0 commit comments

Comments
 (0)