Skip to content

Commit 63a8eba

Browse files
wuliang229copybara-github
authored andcommitted
fix: Ensure database sessions are always rolled back on errors
Fixes issue#3328 Co-authored-by: Liang Wu <wuliang@google.com> PiperOrigin-RevId: 863314939
1 parent 47221cd commit 63a8eba

File tree

2 files changed

+238
-6
lines changed

2 files changed

+238
-6
lines changed

src/google/adk/sessions/database_session_service.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,13 @@
1414
from __future__ import annotations
1515

1616
import asyncio
17+
from contextlib import asynccontextmanager
1718
import copy
1819
from datetime import datetime
1920
from datetime import timezone
2021
import logging
2122
from typing import Any
23+
from typing import AsyncIterator
2224
from typing import Optional
2325

2426
from sqlalchemy import delete
@@ -156,6 +158,23 @@ def __init__(self, db_url: str, **kwargs: Any):
156158
def _get_schema_classes(self) -> _SchemaClasses:
157159
return _SchemaClasses(self._db_schema_version)
158160

161+
@asynccontextmanager
162+
async def _rollback_on_exception_session(
163+
self,
164+
) -> AsyncIterator[DatabaseSessionFactory]:
165+
"""Yields a database session with guaranteed rollback on errors.
166+
167+
On normal exit the caller is responsible for committing; on any exception
168+
the transaction is explicitly rolled back before the error propagates,
169+
preventing connection-pool exhaustion from lingering invalid transactions.
170+
"""
171+
async with self.database_session_factory() as sql_session:
172+
try:
173+
yield sql_session
174+
except BaseException:
175+
await sql_session.rollback()
176+
raise
177+
159178
async def _prepare_tables(self):
160179
"""Ensure database tables are ready for use.
161180
@@ -204,7 +223,7 @@ async def _prepare_tables(self):
204223
self._tables_created = True
205224

206225
if self._db_schema_version == _schema_check_utils.LATEST_SCHEMA_VERSION:
207-
async with self.database_session_factory() as sql_session:
226+
async with self._rollback_on_exception_session() as sql_session:
208227
# Check if schema version is set, if not, set it to the latest
209228
# version
210229
stmt = select(StorageMetadata).where(
@@ -236,7 +255,7 @@ async def create_session(
236255
# 5. Return the session
237256
await self._prepare_tables()
238257
schema = self._get_schema_classes()
239-
async with self.database_session_factory() as sql_session:
258+
async with self._rollback_on_exception_session() as sql_session:
240259
if session_id and await sql_session.get(
241260
schema.StorageSession, (app_name, user_id, session_id)
242261
):
@@ -313,7 +332,7 @@ async def get_session(
313332
# 2. Get all the events based on session id and filtering config
314333
# 3. Convert and return the session
315334
schema = self._get_schema_classes()
316-
async with self.database_session_factory() as sql_session:
335+
async with self._rollback_on_exception_session() as sql_session:
317336
storage_session = await sql_session.get(
318337
schema.StorageSession, (app_name, user_id, session_id)
319338
)
@@ -368,7 +387,7 @@ async def list_sessions(
368387
) -> ListSessionsResponse:
369388
await self._prepare_tables()
370389
schema = self._get_schema_classes()
371-
async with self.database_session_factory() as sql_session:
390+
async with self._rollback_on_exception_session() as sql_session:
372391
stmt = select(schema.StorageSession).filter(
373392
schema.StorageSession.app_name == app_name
374393
)
@@ -418,7 +437,7 @@ async def delete_session(
418437
) -> None:
419438
await self._prepare_tables()
420439
schema = self._get_schema_classes()
421-
async with self.database_session_factory() as sql_session:
440+
async with self._rollback_on_exception_session() as sql_session:
422441
stmt = delete(schema.StorageSession).where(
423442
schema.StorageSession.app_name == app_name,
424443
schema.StorageSession.user_id == user_id,
@@ -440,7 +459,7 @@ async def append_event(self, session: Session, event: Event) -> Event:
440459
# 2. Update session attributes based on event config
441460
# 3. Store event to table
442461
schema = self._get_schema_classes()
443-
async with self.database_session_factory() as sql_session:
462+
async with self._rollback_on_exception_session() as sql_session:
444463
storage_session = await sql_session.get(
445464
schema.StorageSession, (session.app_name, session.user_id, session.id)
446465
)

tests/unittests/sessions/test_session_service.py

Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -643,3 +643,216 @@ async def test_partial_events_are_not_persisted(session_service):
643643
app_name=app_name, user_id=user_id, session_id=session.id
644644
)
645645
assert len(session_got.events) == 0
646+
647+
648+
# ---------------------------------------------------------------------------
649+
# Rollback tests – verify _rollback_on_exception_session explicitly rolls back
650+
# on errors
651+
# ---------------------------------------------------------------------------
652+
class _RollbackSpySession:
653+
"""Wraps an AsyncSession to spy on rollback() and optionally fail commit()."""
654+
655+
def __init__(self, real_session, *, fail_commit=False):
656+
self._real = real_session
657+
self._fail_commit = fail_commit
658+
self.rollback_called = False
659+
660+
async def __aenter__(self):
661+
self._real = await self._real.__aenter__()
662+
return self
663+
664+
async def __aexit__(self, *args):
665+
return await self._real.__aexit__(*args)
666+
667+
async def commit(self):
668+
if self._fail_commit:
669+
raise RuntimeError('simulated commit failure')
670+
return await self._real.commit()
671+
672+
async def rollback(self):
673+
self.rollback_called = True
674+
return await self._real.rollback()
675+
676+
def __getattr__(self, name):
677+
return getattr(self._real, name)
678+
679+
680+
@pytest.mark.asyncio
681+
async def test_create_session_calls_rollback_on_commit_failure():
682+
"""Verifies that a commit failure during create_session triggers an explicit
683+
rollback() call via _rollback_on_exception_session, not just a close()."""
684+
service = DatabaseSessionService('sqlite+aiosqlite:///:memory:')
685+
try:
686+
# Ensure tables are initialized.
687+
await service.create_session(
688+
app_name='app', user_id='user', session_id='good'
689+
)
690+
691+
original_factory = service.database_session_factory
692+
spy_sessions = []
693+
694+
def _spy_factory():
695+
spy = _RollbackSpySession(original_factory(), fail_commit=True)
696+
spy_sessions.append(spy)
697+
return spy
698+
699+
service.database_session_factory = _spy_factory
700+
701+
with pytest.raises(RuntimeError, match='simulated commit failure'):
702+
await service.create_session(
703+
app_name='app', user_id='user', session_id='should_fail'
704+
)
705+
706+
# The key assertion: rollback() must have been called explicitly.
707+
assert len(spy_sessions) == 1
708+
assert spy_sessions[0].rollback_called, (
709+
'rollback() was not called – _rollback_on_exception_session is not'
710+
' protecting this path'
711+
)
712+
713+
# Restore and verify the failed session was not persisted.
714+
service.database_session_factory = original_factory
715+
assert (
716+
await service.get_session(
717+
app_name='app', user_id='user', session_id='should_fail'
718+
)
719+
is None
720+
)
721+
finally:
722+
await service.close()
723+
724+
725+
@pytest.mark.asyncio
726+
async def test_append_event_calls_rollback_on_commit_failure():
727+
"""Verifies that a commit failure during append_event triggers an explicit
728+
rollback() call via _rollback_on_exception_session."""
729+
service = DatabaseSessionService('sqlite+aiosqlite:///:memory:')
730+
try:
731+
session = await service.create_session(
732+
app_name='app', user_id='user', session_id='s1'
733+
)
734+
735+
# Successfully append one event first.
736+
event1 = Event(
737+
invocation_id='inv1',
738+
author='user',
739+
actions=EventActions(state_delta={'key1': 'value1'}),
740+
)
741+
await service.append_event(session, event1)
742+
743+
original_factory = service.database_session_factory
744+
spy_sessions = []
745+
746+
def _spy_factory():
747+
spy = _RollbackSpySession(original_factory(), fail_commit=True)
748+
spy_sessions.append(spy)
749+
return spy
750+
751+
service.database_session_factory = _spy_factory
752+
753+
event2 = Event(
754+
invocation_id='inv2',
755+
author='user',
756+
actions=EventActions(state_delta={'key2': 'value2'}),
757+
)
758+
with pytest.raises(RuntimeError, match='simulated commit failure'):
759+
await service.append_event(session, event2)
760+
761+
assert len(spy_sessions) == 1
762+
assert spy_sessions[0].rollback_called, (
763+
'rollback() was not called – _rollback_on_exception_session is not'
764+
' protecting this path'
765+
)
766+
767+
# Restore and verify only the first event was persisted.
768+
service.database_session_factory = original_factory
769+
got = await service.get_session(
770+
app_name='app', user_id='user', session_id='s1'
771+
)
772+
assert len(got.events) == 1
773+
assert got.events[0].invocation_id == 'inv1'
774+
finally:
775+
await service.close()
776+
777+
778+
@pytest.mark.asyncio
779+
async def test_delete_session_calls_rollback_on_commit_failure():
780+
"""Verifies that a commit failure during delete_session triggers an explicit
781+
rollback() call via _rollback_on_exception_session."""
782+
service = DatabaseSessionService('sqlite+aiosqlite:///:memory:')
783+
try:
784+
await service.create_session(
785+
app_name='app', user_id='user', session_id='s1'
786+
)
787+
788+
original_factory = service.database_session_factory
789+
spy_sessions = []
790+
791+
def _spy_factory():
792+
spy = _RollbackSpySession(original_factory(), fail_commit=True)
793+
spy_sessions.append(spy)
794+
return spy
795+
796+
service.database_session_factory = _spy_factory
797+
798+
with pytest.raises(RuntimeError, match='simulated commit failure'):
799+
await service.delete_session(
800+
app_name='app', user_id='user', session_id='s1'
801+
)
802+
803+
assert len(spy_sessions) == 1
804+
assert spy_sessions[0].rollback_called, (
805+
'rollback() was not called – _rollback_on_exception_session is not'
806+
' protecting this path'
807+
)
808+
809+
# Restore and verify the session still exists (delete was rolled back).
810+
service.database_session_factory = original_factory
811+
got = await service.get_session(
812+
app_name='app', user_id='user', session_id='s1'
813+
)
814+
assert got is not None
815+
finally:
816+
await service.close()
817+
818+
819+
@pytest.mark.asyncio
820+
async def test_service_recovers_after_multiple_failures():
821+
"""After several consecutive commit failures, every single one must trigger
822+
a rollback() call and the service must remain functional afterward."""
823+
service = DatabaseSessionService('sqlite+aiosqlite:///:memory:')
824+
try:
825+
await service.create_session(
826+
app_name='app', user_id='user', session_id='seed'
827+
)
828+
829+
original_factory = service.database_session_factory
830+
spy_sessions = []
831+
832+
def _spy_factory():
833+
spy = _RollbackSpySession(original_factory(), fail_commit=True)
834+
spy_sessions.append(spy)
835+
return spy
836+
837+
service.database_session_factory = _spy_factory
838+
839+
num_failures = 5
840+
for i in range(num_failures):
841+
with pytest.raises(RuntimeError, match='simulated commit failure'):
842+
await service.create_session(
843+
app_name='app', user_id='user', session_id=f'fail_{i}'
844+
)
845+
846+
# Every failure must have triggered a rollback.
847+
assert len(spy_sessions) == num_failures
848+
for i, spy in enumerate(spy_sessions):
849+
assert spy.rollback_called, f'rollback() was not called on failure #{i}'
850+
851+
# Restore and verify the service is still healthy.
852+
service.database_session_factory = original_factory
853+
session = await service.create_session(
854+
app_name='app', user_id='user', session_id='recovered'
855+
)
856+
assert session.id == 'recovered'
857+
finally:
858+
await service.close()

0 commit comments

Comments
 (0)