Skip to content

Commit 40818c7

Browse files
test(sessions): add regression test for event pagination client closure
Add test_get_session_pagination_keeps_client_open to verify that get_session() keeps the API client open while iterating through paginated events. The test uses a mock client that tracks context state and raises RuntimeError if iteration occurs outside the async with block, matching real httpx behavior.
1 parent 51e66e4 commit 40818c7

File tree

1 file changed

+96
-0
lines changed

1 file changed

+96
-0
lines changed

tests/unittests/sessions/test_vertex_ai_session_service.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,102 @@ async def _append_event(
396396
else:
397397
self.event_dict[session_id] = ([event_json], None)
398398

399+
class MockAsyncClientWithPagination:
400+
"""Mock client that simulates pagination requiring an open client connection.
401+
402+
This mock tracks whether the client context is active and raises RuntimeError
403+
if iteration occurs outside the context, simulating the real httpx behavior.
404+
"""
405+
406+
def __init__(self, session_data: dict, events_pages: list[list[dict]]):
407+
self._session_data = session_data
408+
self._events_pages = events_pages
409+
self._context_active = False
410+
self.agent_engines = mock.AsyncMock()
411+
self.agent_engines.sessions.get.side_effect = self._get_session
412+
self.agent_engines.sessions.events.list.side_effect = self._list_events
413+
414+
async def __aenter__(self):
415+
self._context_active = True
416+
return self
417+
418+
async def __aexit__(self, exc_type, exc_val, exc_tb):
419+
self._context_active = False
420+
421+
async def _get_session(self, name: str):
422+
return _convert_to_object(self._session_data)
423+
424+
async def _list_events(self, name: str, **kwargs):
425+
return self._paginated_events_iterator()
426+
427+
async def _paginated_events_iterator(self):
428+
for page in self._events_pages:
429+
for event in page:
430+
if not self._context_active:
431+
raise RuntimeError(
432+
'Cannot send a request, as the client has been closed.'
433+
)
434+
yield _convert_to_object(event)
435+
436+
437+
def _generate_events_for_page(session_id: str, start_idx: int, count: int):
438+
events = []
439+
start_time = isoparse('2024-12-12T12:12:12.123456Z')
440+
for i in range(count):
441+
idx = start_idx + i
442+
event_time = start_time + datetime.timedelta(microseconds=idx * 1000)
443+
events.append({
444+
'name': (
445+
'projects/test-project/locations/test-location/'
446+
f'reasoningEngines/123/sessions/{session_id}/events/{idx}'
447+
),
448+
'invocation_id': f'invocation_{idx}',
449+
'author': 'pagination_user',
450+
'timestamp': event_time.isoformat().replace('+00:00', 'Z'),
451+
})
452+
return events
453+
454+
455+
@pytest.mark.asyncio
456+
async def test_get_session_pagination_keeps_client_open():
457+
"""Regression test: event iteration must occur inside the api_client context.
458+
459+
This test verifies that get_session() keeps the API client open while
460+
iterating through paginated events. Before the fix, the events_iterator
461+
was consumed outside the async with block, causing RuntimeError when
462+
fetching subsequent pages.
463+
"""
464+
session_data = {
465+
'name': (
466+
'projects/test-project/locations/test-location/'
467+
'reasoningEngines/123/sessions/pagination_test'
468+
),
469+
'update_time': '2024-12-12T12:12:12.123456Z',
470+
'user_id': 'pagination_user',
471+
}
472+
page1_events = _generate_events_for_page('pagination_test', 0, 100)
473+
page2_events = _generate_events_for_page('pagination_test', 100, 100)
474+
page3_events = _generate_events_for_page('pagination_test', 200, 50)
475+
476+
mock_client = MockAsyncClientWithPagination(
477+
session_data=session_data,
478+
events_pages=[page1_events, page2_events, page3_events],
479+
)
480+
481+
session_service = mock_vertex_ai_session_service()
482+
483+
with mock.patch.object(
484+
session_service, '_get_api_client', return_value=mock_client
485+
):
486+
session = await session_service.get_session(
487+
app_name='123', user_id='pagination_user', session_id='pagination_test'
488+
)
489+
490+
assert session is not None
491+
assert len(session.events) == 250
492+
assert session.events[0].invocation_id == 'invocation_0'
493+
assert session.events[249].invocation_id == 'invocation_249'
494+
399495

400496
def mock_vertex_ai_session_service(
401497
project: Optional[str] = 'test-project',

0 commit comments

Comments
 (0)