2020import io .modelcontextprotocol .spec .McpSchema ;
2121import io .modelcontextprotocol .spec .McpTransportSessionNotFoundException ;
2222import reactor .core .publisher .Mono ;
23+ import reactor .core .scheduler .Schedulers ;
2324import reactor .test .StepVerifier ;
25+ import reactor .test .scheduler .VirtualTimeScheduler ;
2426import reactor .util .context .Context ;
2527import reactor .util .context .ContextView ;
2628
3436import static org .mockito .Mockito .when ;
3537
3638/**
37- * Tests for {@link LifecyleInitializer }.
39+ * Tests for {@link LifecycleInitializer }.
3840 */
39- class LifecyleInitializerTests {
41+ class LifecycleInitializerTests {
4042
4143 private static final Duration INITIALIZATION_TIMEOUT = Duration .ofSeconds (5 );
4244
43- private static final Duration SHORT_TIMEOUT = Duration .ofMillis (100 );
44-
4545 private static final McpSchema .ClientCapabilities CLIENT_CAPABILITIES = McpSchema .ClientCapabilities .builder ()
4646 .build ();
4747
@@ -54,56 +54,56 @@ class LifecyleInitializerTests {
5454 "Test instructions" );
5555
5656 @ Mock
57- private McpClientSession mockSession ;
57+ private McpClientSession mockClientSession ;
5858
5959 @ Mock
6060 private Function <ContextView , McpClientSession > mockSessionSupplier ;
6161
62- private LifecyleInitializer initializer ;
62+ private LifecycleInitializer initializer ;
6363
6464 @ BeforeEach
6565 void setUp () {
6666 MockitoAnnotations .openMocks (this );
6767
68- when (mockSessionSupplier .apply (any (ContextView .class ))).thenReturn (mockSession );
69- when (mockSession .sendRequest (eq (McpSchema .METHOD_INITIALIZE ), any (), any ()))
68+ when (mockSessionSupplier .apply (any (ContextView .class ))).thenReturn (mockClientSession );
69+ when (mockClientSession .sendRequest (eq (McpSchema .METHOD_INITIALIZE ), any (), any ()))
7070 .thenReturn (Mono .just (MOCK_INIT_RESULT ));
71- when (mockSession .sendNotification (eq (McpSchema .METHOD_NOTIFICATION_INITIALIZED ), any ()))
71+ when (mockClientSession .sendNotification (eq (McpSchema .METHOD_NOTIFICATION_INITIALIZED ), any ()))
7272 .thenReturn (Mono .empty ());
73- when (mockSession .closeGracefully ()).thenReturn (Mono .empty ());
73+ when (mockClientSession .closeGracefully ()).thenReturn (Mono .empty ());
7474
75- initializer = new LifecyleInitializer (CLIENT_CAPABILITIES , CLIENT_INFO , PROTOCOL_VERSIONS ,
75+ initializer = new LifecycleInitializer (CLIENT_CAPABILITIES , CLIENT_INFO , PROTOCOL_VERSIONS ,
7676 INITIALIZATION_TIMEOUT , mockSessionSupplier );
7777 }
7878
7979 @ Test
8080 void constructorShouldValidateParameters () {
81- assertThatThrownBy (() -> new LifecyleInitializer (null , CLIENT_INFO , PROTOCOL_VERSIONS , INITIALIZATION_TIMEOUT ,
81+ assertThatThrownBy (() -> new LifecycleInitializer (null , CLIENT_INFO , PROTOCOL_VERSIONS , INITIALIZATION_TIMEOUT ,
8282 mockSessionSupplier ))
8383 .isInstanceOf (IllegalArgumentException .class )
8484 .hasMessageContaining ("Client capabilities must not be null" );
8585
86- assertThatThrownBy (() -> new LifecyleInitializer (CLIENT_CAPABILITIES , null , PROTOCOL_VERSIONS ,
86+ assertThatThrownBy (() -> new LifecycleInitializer (CLIENT_CAPABILITIES , null , PROTOCOL_VERSIONS ,
8787 INITIALIZATION_TIMEOUT , mockSessionSupplier ))
8888 .isInstanceOf (IllegalArgumentException .class )
8989 .hasMessageContaining ("Client info must not be null" );
9090
91- assertThatThrownBy (() -> new LifecyleInitializer (CLIENT_CAPABILITIES , CLIENT_INFO , null , INITIALIZATION_TIMEOUT ,
92- mockSessionSupplier ))
91+ assertThatThrownBy (() -> new LifecycleInitializer (CLIENT_CAPABILITIES , CLIENT_INFO , null ,
92+ INITIALIZATION_TIMEOUT , mockSessionSupplier ))
9393 .isInstanceOf (IllegalArgumentException .class )
9494 .hasMessageContaining ("Protocol versions must not be empty" );
9595
96- assertThatThrownBy (() -> new LifecyleInitializer (CLIENT_CAPABILITIES , CLIENT_INFO , List .of (),
96+ assertThatThrownBy (() -> new LifecycleInitializer (CLIENT_CAPABILITIES , CLIENT_INFO , List .of (),
9797 INITIALIZATION_TIMEOUT , mockSessionSupplier ))
9898 .isInstanceOf (IllegalArgumentException .class )
9999 .hasMessageContaining ("Protocol versions must not be empty" );
100100
101- assertThatThrownBy (() -> new LifecyleInitializer (CLIENT_CAPABILITIES , CLIENT_INFO , PROTOCOL_VERSIONS , null ,
101+ assertThatThrownBy (() -> new LifecycleInitializer (CLIENT_CAPABILITIES , CLIENT_INFO , PROTOCOL_VERSIONS , null ,
102102 mockSessionSupplier ))
103103 .isInstanceOf (IllegalArgumentException .class )
104104 .hasMessageContaining ("Initialization timeout must not be null" );
105105
106- assertThatThrownBy (() -> new LifecyleInitializer (CLIENT_CAPABILITIES , CLIENT_INFO , PROTOCOL_VERSIONS ,
106+ assertThatThrownBy (() -> new LifecycleInitializer (CLIENT_CAPABILITIES , CLIENT_INFO , PROTOCOL_VERSIONS ,
107107 INITIALIZATION_TIMEOUT , null ))
108108 .isInstanceOf (IllegalArgumentException .class )
109109 .hasMessageContaining ("Session supplier must not be null" );
@@ -119,15 +119,16 @@ void shouldInitializeSuccessfully() {
119119 })
120120 .verifyComplete ();
121121
122- verify (mockSession ).sendRequest (eq (McpSchema .METHOD_INITIALIZE ), any (McpSchema .InitializeRequest .class ), any ());
123- verify (mockSession ).sendNotification (eq (McpSchema .METHOD_NOTIFICATION_INITIALIZED ), eq (null ));
122+ verify (mockClientSession ).sendRequest (eq (McpSchema .METHOD_INITIALIZE ), any (McpSchema .InitializeRequest .class ),
123+ any ());
124+ verify (mockClientSession ).sendNotification (eq (McpSchema .METHOD_NOTIFICATION_INITIALIZED ), eq (null ));
124125 }
125126
126127 @ Test
127128 void shouldUseLatestProtocolVersionInInitializeRequest () {
128129 AtomicReference <McpSchema .InitializeRequest > capturedRequest = new AtomicReference <>();
129130
130- when (mockSession .sendRequest (eq (McpSchema .METHOD_INITIALIZE ), any (), any ())).thenAnswer (invocation -> {
131+ when (mockClientSession .sendRequest (eq (McpSchema .METHOD_INITIALIZE ), any (), any ())).thenAnswer (invocation -> {
131132 capturedRequest .set ((McpSchema .InitializeRequest ) invocation .getArgument (1 ));
132133 return Mono .just (MOCK_INIT_RESULT );
133134 });
@@ -149,26 +150,34 @@ void shouldFailForUnsupportedProtocolVersion() {
149150 McpSchema .ServerCapabilities .builder ().build (), new McpSchema .Implementation ("test-server" , "1.0.0" ),
150151 "Test instructions" );
151152
152- when (mockSession .sendRequest (eq (McpSchema .METHOD_INITIALIZE ), any (), any ()))
153+ when (mockClientSession .sendRequest (eq (McpSchema .METHOD_INITIALIZE ), any (), any ()))
153154 .thenReturn (Mono .just (unsupportedResult ));
154155
155156 StepVerifier .create (initializer .withIntitialization ("test" , init -> Mono .just (init .initializeResult ())))
156157 .expectError (McpError .class )
157158 .verify ();
158159
159- verify (mockSession , never ()).sendNotification (eq (McpSchema .METHOD_NOTIFICATION_INITIALIZED ), any ());
160+ verify (mockClientSession , never ()).sendNotification (eq (McpSchema .METHOD_NOTIFICATION_INITIALIZED ), any ());
160161 }
161162
162163 @ Test
163164 void shouldTimeoutOnSlowInitialization () {
164- LifecyleInitializer shortTimeoutInitializer = new LifecyleInitializer (CLIENT_CAPABILITIES , CLIENT_INFO ,
165- PROTOCOL_VERSIONS , SHORT_TIMEOUT , mockSessionSupplier );
165+ VirtualTimeScheduler virtualTimeScheduler = VirtualTimeScheduler .getOrSet ();
166166
167- when (mockSession .<McpSchema .InitializeResult >sendRequest (eq (McpSchema .METHOD_INITIALIZE ), any (), any ()))
168- .thenReturn (Mono .just (MOCK_INIT_RESULT ).delayElement (Duration .ofSeconds (1 )));
167+ Duration INITIALIZE_TIMEOUT = Duration .ofSeconds (1 );
168+ Duration SLOW_RESPONSE_DELAY = Duration .ofSeconds (5 );
169+
170+ LifecycleInitializer shortTimeoutInitializer = new LifecycleInitializer (CLIENT_CAPABILITIES , CLIENT_INFO ,
171+ PROTOCOL_VERSIONS , INITIALIZE_TIMEOUT , mockSessionSupplier );
172+
173+ when (mockClientSession .<McpSchema .InitializeResult >sendRequest (eq (McpSchema .METHOD_INITIALIZE ), any (), any ()))
174+ .thenReturn (Mono .just (MOCK_INIT_RESULT ).delayElement (SLOW_RESPONSE_DELAY , virtualTimeScheduler ));
169175
170176 StepVerifier
171- .create (shortTimeoutInitializer .withIntitialization ("test" , init -> Mono .just (init .initializeResult ())))
177+ .withVirtualTime (() -> shortTimeoutInitializer .withIntitialization ("test" ,
178+ init -> Mono .just (init .initializeResult ())), () -> virtualTimeScheduler , Long .MAX_VALUE )
179+ .expectSubscription ()
180+ .expectNoEvent (INITIALIZE_TIMEOUT )
172181 .expectError (McpError .class )
173182 .verify ();
174183 }
@@ -187,7 +196,7 @@ void shouldReuseExistingInitialization() {
187196
188197 // Verify session was created only once
189198 verify (mockSessionSupplier , times (1 )).apply (any (ContextView .class ));
190- verify (mockSession , times (1 )).sendRequest (eq (McpSchema .METHOD_INITIALIZE ), any (), any ());
199+ verify (mockClientSession , times (1 )).sendRequest (eq (McpSchema .METHOD_INITIALIZE ), any (), any ());
191200 }
192201
193202 @ Test
@@ -196,13 +205,17 @@ void shouldHandleConcurrentInitializationRequests() {
196205
197206 when (mockSessionSupplier .apply (any (ContextView .class ))).thenAnswer (invocation -> {
198207 sessionCreationCount .incrementAndGet ();
199- return mockSession ;
208+ return mockClientSession ;
200209 });
201210
202- // Start multiple concurrent initializations
203- Mono <String > init1 = initializer .withIntitialization ("test1" , init -> Mono .just ("result1" ));
204- Mono <String > init2 = initializer .withIntitialization ("test2" , init -> Mono .just ("result2" ));
205- Mono <String > init3 = initializer .withIntitialization ("test3" , init -> Mono .just ("result3" ));
211+ // Start multiple concurrent initializations using subscribeOn with parallel
212+ // scheduler
213+ Mono <String > init1 = initializer .withIntitialization ("test1" , init -> Mono .just ("result1" ))
214+ .subscribeOn (Schedulers .parallel ());
215+ Mono <String > init2 = initializer .withIntitialization ("test2" , init -> Mono .just ("result2" ))
216+ .subscribeOn (Schedulers .parallel ());
217+ Mono <String > init3 = initializer .withIntitialization ("test3" , init -> Mono .just ("result3" ))
218+ .subscribeOn (Schedulers .parallel ());
206219
207220 StepVerifier .create (Mono .zip (init1 , init2 , init3 )).assertNext (tuple -> {
208221 assertThat (tuple .getT1 ()).isEqualTo ("result1" );
@@ -212,12 +225,12 @@ void shouldHandleConcurrentInitializationRequests() {
212225
213226 // Should only create one session despite concurrent requests
214227 assertThat (sessionCreationCount .get ()).isEqualTo (1 );
215- verify (mockSession , times (1 )).sendRequest (eq (McpSchema .METHOD_INITIALIZE ), any (), any ());
228+ verify (mockClientSession , times (1 )).sendRequest (eq (McpSchema .METHOD_INITIALIZE ), any (), any ());
216229 }
217230
218231 @ Test
219232 void shouldHandleInitializationFailure () {
220- when (mockSession .sendRequest (eq (McpSchema .METHOD_INITIALIZE ), any (), any ()))
233+ when (mockClientSession .sendRequest (eq (McpSchema .METHOD_INITIALIZE ), any (), any ()))
221234 .thenReturn (Mono .error (new RuntimeException ("Connection failed" )));
222235
223236 StepVerifier .create (initializer .withIntitialization ("test" , init -> Mono .just (init .initializeResult ())))
@@ -230,27 +243,24 @@ void shouldHandleInitializationFailure() {
230243
231244 @ Test
232245 void shouldHandleTransportSessionNotFoundException () {
233- // Simulate a successful initialization first
246+ // successful initialization first
234247 StepVerifier .create (initializer .withIntitialization ("test" , init -> Mono .just (init .initializeResult ())))
235248 .expectNext (MOCK_INIT_RESULT )
236249 .verifyComplete ();
237250
238251 assertThat (initializer .isInitialized ()).isTrue ();
239252
240- // Simulate transport session not found exception
253+ // Simulate transport session not found
241254 initializer .handleException (new McpTransportSessionNotFoundException ("Session not found" ));
242255
243- // The exception handling resets the initialization state and triggers
244- // re-initialization
245- // We need to wait a bit for the async re-initialization to start
246- try {
247- Thread .sleep (10 ); // Small delay to allow async processing
248- }
249- catch (InterruptedException e ) {
250- Thread .currentThread ().interrupt ();
251- }
252-
253- verify (mockSession ).close ();
256+ assertThat (initializer .isInitialized ()).isTrue ();
257+
258+ // Verify that the session was closed and re-initialized
259+ verify (mockClientSession ).close ();
260+
261+ // Verify session was created 2 times (once for initial and once for
262+ // re-initialization)
263+ verify (mockSessionSupplier , times (2 )).apply (any (ContextView .class ));
254264 }
255265
256266 @ Test
@@ -267,34 +277,33 @@ void shouldHandleOtherExceptions() {
267277
268278 // Should still be initialized
269279 assertThat (initializer .isInitialized ()).isTrue ();
270- verify (mockSession , never ()).close ();
280+ verify (mockClientSession , never ()).close ();
281+ // Verify that the session was not re-created
282+ verify (mockSessionSupplier , times (1 )).apply (any (ContextView .class ));
271283 }
272284
273285 @ Test
274286 void shouldCloseGracefully () {
275- // Initialize first
276287 StepVerifier .create (initializer .withIntitialization ("test" , init -> Mono .just (init .initializeResult ())))
277288 .expectNext (MOCK_INIT_RESULT )
278289 .verifyComplete ();
279290
280- // Close gracefully
281291 StepVerifier .create (initializer .closeGracefully ()).verifyComplete ();
282292
283- verify (mockSession ).closeGracefully ();
293+ verify (mockClientSession ).closeGracefully ();
284294 assertThat (initializer .isInitialized ()).isFalse ();
285295 }
286296
287297 @ Test
288298 void shouldCloseImmediately () {
289- // Initialize first
290299 StepVerifier .create (initializer .withIntitialization ("test" , init -> Mono .just (init .initializeResult ())))
291300 .expectNext (MOCK_INIT_RESULT )
292301 .verifyComplete ();
293302
294303 // Close immediately
295304 initializer .close ();
296305
297- verify (mockSession ).close ();
306+ verify (mockClientSession ).close ();
298307 assertThat (initializer .isInitialized ()).isFalse ();
299308 }
300309
@@ -305,8 +314,8 @@ void shouldHandleCloseWithoutInitialization() {
305314
306315 StepVerifier .create (initializer .closeGracefully ()).verifyComplete ();
307316
308- verify (mockSession , never ()).close ();
309- verify (mockSession , never ()).closeGracefully ();
317+ verify (mockClientSession , never ()).close ();
318+ verify (mockClientSession , never ()).closeGracefully ();
310319 }
311320
312321 @ Test
@@ -316,7 +325,7 @@ void shouldSetProtocolVersionsForTesting() {
316325
317326 AtomicReference <McpSchema .InitializeRequest > capturedRequest = new AtomicReference <>();
318327
319- when (mockSession .sendRequest (eq (McpSchema .METHOD_INITIALIZE ), any (), any ())).thenAnswer (invocation -> {
328+ when (mockClientSession .sendRequest (eq (McpSchema .METHOD_INITIALIZE ), any (), any ())).thenAnswer (invocation -> {
320329 capturedRequest .set ((McpSchema .InitializeRequest ) invocation .getArgument (1 ));
321330 return Mono .just (new McpSchema .InitializeResult ("4.0.0" , McpSchema .ServerCapabilities .builder ().build (),
322331 new McpSchema .Implementation ("test-server" , "1.0.0" ), "Test instructions" ));
@@ -339,7 +348,7 @@ void shouldPassContextToSessionSupplier() {
339348
340349 when (mockSessionSupplier .apply (any (ContextView .class ))).thenAnswer (invocation -> {
341350 capturedContext .set (invocation .getArgument (0 ));
342- return mockSession ;
351+ return mockClientSession ;
343352 });
344353
345354 StepVerifier
@@ -355,23 +364,23 @@ void shouldPassContextToSessionSupplier() {
355364 @ Test
356365 void shouldProvideAccessToMcpSessionAndInitializeResult () {
357366 StepVerifier .create (initializer .withIntitialization ("test" , init -> {
358- assertThat (init .mcpSession ()).isEqualTo (mockSession );
367+ assertThat (init .mcpSession ()).isEqualTo (mockClientSession );
359368 assertThat (init .initializeResult ()).isEqualTo (MOCK_INIT_RESULT );
360369 return Mono .just ("success" );
361370 })).expectNext ("success" ).verifyComplete ();
362371 }
363372
364373 @ Test
365374 void shouldHandleNotificationFailure () {
366- when (mockSession .sendNotification (eq (McpSchema .METHOD_NOTIFICATION_INITIALIZED ), any ()))
375+ when (mockClientSession .sendNotification (eq (McpSchema .METHOD_NOTIFICATION_INITIALIZED ), any ()))
367376 .thenReturn (Mono .error (new RuntimeException ("Notification failed" )));
368377
369378 StepVerifier .create (initializer .withIntitialization ("test" , init -> Mono .just (init .initializeResult ())))
370379 .expectError (RuntimeException .class )
371380 .verify ();
372381
373- verify (mockSession ).sendRequest (eq (McpSchema .METHOD_INITIALIZE ), any (), any ());
374- verify (mockSession ).sendNotification (eq (McpSchema .METHOD_NOTIFICATION_INITIALIZED ), eq (null ));
382+ verify (mockClientSession ).sendRequest (eq (McpSchema .METHOD_INITIALIZE ), any (), any ());
383+ verify (mockClientSession ).sendNotification (eq (McpSchema .METHOD_NOTIFICATION_INITIALIZED ), eq (null ));
375384 }
376385
377386 @ Test
@@ -397,7 +406,7 @@ void shouldReinitializeAfterTransportSessionException() {
397406
398407 // Verify two separate initializations occurred
399408 verify (mockSessionSupplier , times (2 )).apply (any (ContextView .class ));
400- verify (mockSession , times (2 )).sendRequest (eq (McpSchema .METHOD_INITIALIZE ), any (), any ());
409+ verify (mockClientSession , times (2 )).sendRequest (eq (McpSchema .METHOD_INITIALIZE ), any (), any ());
401410 }
402411
403412}
0 commit comments