@@ -643,3 +643,216 @@ async def test_partial_events_are_not_persisted(session_service):
643643 app_name = app_name , user_id = user_id , session_id = session .id
644644 )
645645 assert len (session_got .events ) == 0
646+
647+
648+ # ---------------------------------------------------------------------------
649+ # Rollback tests – verify _rollback_on_exception_session explicitly rolls back
650+ # on errors
651+ # ---------------------------------------------------------------------------
652+ class _RollbackSpySession :
653+ """Wraps an AsyncSession to spy on rollback() and optionally fail commit()."""
654+
655+ def __init__ (self , real_session , * , fail_commit = False ):
656+ self ._real = real_session
657+ self ._fail_commit = fail_commit
658+ self .rollback_called = False
659+
660+ async def __aenter__ (self ):
661+ self ._real = await self ._real .__aenter__ ()
662+ return self
663+
664+ async def __aexit__ (self , * args ):
665+ return await self ._real .__aexit__ (* args )
666+
667+ async def commit (self ):
668+ if self ._fail_commit :
669+ raise RuntimeError ('simulated commit failure' )
670+ return await self ._real .commit ()
671+
672+ async def rollback (self ):
673+ self .rollback_called = True
674+ return await self ._real .rollback ()
675+
676+ def __getattr__ (self , name ):
677+ return getattr (self ._real , name )
678+
679+
680+ @pytest .mark .asyncio
681+ async def test_create_session_calls_rollback_on_commit_failure ():
682+ """Verifies that a commit failure during create_session triggers an explicit
683+ rollback() call via _rollback_on_exception_session, not just a close()."""
684+ service = DatabaseSessionService ('sqlite+aiosqlite:///:memory:' )
685+ try :
686+ # Ensure tables are initialized.
687+ await service .create_session (
688+ app_name = 'app' , user_id = 'user' , session_id = 'good'
689+ )
690+
691+ original_factory = service .database_session_factory
692+ spy_sessions = []
693+
694+ def _spy_factory ():
695+ spy = _RollbackSpySession (original_factory (), fail_commit = True )
696+ spy_sessions .append (spy )
697+ return spy
698+
699+ service .database_session_factory = _spy_factory
700+
701+ with pytest .raises (RuntimeError , match = 'simulated commit failure' ):
702+ await service .create_session (
703+ app_name = 'app' , user_id = 'user' , session_id = 'should_fail'
704+ )
705+
706+ # The key assertion: rollback() must have been called explicitly.
707+ assert len (spy_sessions ) == 1
708+ assert spy_sessions [0 ].rollback_called , (
709+ 'rollback() was not called – _rollback_on_exception_session is not'
710+ ' protecting this path'
711+ )
712+
713+ # Restore and verify the failed session was not persisted.
714+ service .database_session_factory = original_factory
715+ assert (
716+ await service .get_session (
717+ app_name = 'app' , user_id = 'user' , session_id = 'should_fail'
718+ )
719+ is None
720+ )
721+ finally :
722+ await service .close ()
723+
724+
725+ @pytest .mark .asyncio
726+ async def test_append_event_calls_rollback_on_commit_failure ():
727+ """Verifies that a commit failure during append_event triggers an explicit
728+ rollback() call via _rollback_on_exception_session."""
729+ service = DatabaseSessionService ('sqlite+aiosqlite:///:memory:' )
730+ try :
731+ session = await service .create_session (
732+ app_name = 'app' , user_id = 'user' , session_id = 's1'
733+ )
734+
735+ # Successfully append one event first.
736+ event1 = Event (
737+ invocation_id = 'inv1' ,
738+ author = 'user' ,
739+ actions = EventActions (state_delta = {'key1' : 'value1' }),
740+ )
741+ await service .append_event (session , event1 )
742+
743+ original_factory = service .database_session_factory
744+ spy_sessions = []
745+
746+ def _spy_factory ():
747+ spy = _RollbackSpySession (original_factory (), fail_commit = True )
748+ spy_sessions .append (spy )
749+ return spy
750+
751+ service .database_session_factory = _spy_factory
752+
753+ event2 = Event (
754+ invocation_id = 'inv2' ,
755+ author = 'user' ,
756+ actions = EventActions (state_delta = {'key2' : 'value2' }),
757+ )
758+ with pytest .raises (RuntimeError , match = 'simulated commit failure' ):
759+ await service .append_event (session , event2 )
760+
761+ assert len (spy_sessions ) == 1
762+ assert spy_sessions [0 ].rollback_called , (
763+ 'rollback() was not called – _rollback_on_exception_session is not'
764+ ' protecting this path'
765+ )
766+
767+ # Restore and verify only the first event was persisted.
768+ service .database_session_factory = original_factory
769+ got = await service .get_session (
770+ app_name = 'app' , user_id = 'user' , session_id = 's1'
771+ )
772+ assert len (got .events ) == 1
773+ assert got .events [0 ].invocation_id == 'inv1'
774+ finally :
775+ await service .close ()
776+
777+
778+ @pytest .mark .asyncio
779+ async def test_delete_session_calls_rollback_on_commit_failure ():
780+ """Verifies that a commit failure during delete_session triggers an explicit
781+ rollback() call via _rollback_on_exception_session."""
782+ service = DatabaseSessionService ('sqlite+aiosqlite:///:memory:' )
783+ try :
784+ await service .create_session (
785+ app_name = 'app' , user_id = 'user' , session_id = 's1'
786+ )
787+
788+ original_factory = service .database_session_factory
789+ spy_sessions = []
790+
791+ def _spy_factory ():
792+ spy = _RollbackSpySession (original_factory (), fail_commit = True )
793+ spy_sessions .append (spy )
794+ return spy
795+
796+ service .database_session_factory = _spy_factory
797+
798+ with pytest .raises (RuntimeError , match = 'simulated commit failure' ):
799+ await service .delete_session (
800+ app_name = 'app' , user_id = 'user' , session_id = 's1'
801+ )
802+
803+ assert len (spy_sessions ) == 1
804+ assert spy_sessions [0 ].rollback_called , (
805+ 'rollback() was not called – _rollback_on_exception_session is not'
806+ ' protecting this path'
807+ )
808+
809+ # Restore and verify the session still exists (delete was rolled back).
810+ service .database_session_factory = original_factory
811+ got = await service .get_session (
812+ app_name = 'app' , user_id = 'user' , session_id = 's1'
813+ )
814+ assert got is not None
815+ finally :
816+ await service .close ()
817+
818+
819+ @pytest .mark .asyncio
820+ async def test_service_recovers_after_multiple_failures ():
821+ """After several consecutive commit failures, every single one must trigger
822+ a rollback() call and the service must remain functional afterward."""
823+ service = DatabaseSessionService ('sqlite+aiosqlite:///:memory:' )
824+ try :
825+ await service .create_session (
826+ app_name = 'app' , user_id = 'user' , session_id = 'seed'
827+ )
828+
829+ original_factory = service .database_session_factory
830+ spy_sessions = []
831+
832+ def _spy_factory ():
833+ spy = _RollbackSpySession (original_factory (), fail_commit = True )
834+ spy_sessions .append (spy )
835+ return spy
836+
837+ service .database_session_factory = _spy_factory
838+
839+ num_failures = 5
840+ for i in range (num_failures ):
841+ with pytest .raises (RuntimeError , match = 'simulated commit failure' ):
842+ await service .create_session (
843+ app_name = 'app' , user_id = 'user' , session_id = f'fail_{ i } '
844+ )
845+
846+ # Every failure must have triggered a rollback.
847+ assert len (spy_sessions ) == num_failures
848+ for i , spy in enumerate (spy_sessions ):
849+ assert spy .rollback_called , f'rollback() was not called on failure #{ i } '
850+
851+ # Restore and verify the service is still healthy.
852+ service .database_session_factory = original_factory
853+ session = await service .create_session (
854+ app_name = 'app' , user_id = 'user' , session_id = 'recovered'
855+ )
856+ assert session .id == 'recovered'
857+ finally :
858+ await service .close ()
0 commit comments