Skip to content

Commit 00d8467

Browse files
committed
Address review comments
Signed-off-by: Christian Tzolov <christian.tzolov@broadcom.com>
1 parent ef140fc commit 00d8467

File tree

3 files changed

+79
-70
lines changed

3 files changed

+79
-70
lines changed

mcp/src/main/java/io/modelcontextprotocol/client/LifecyleInitializer.java renamed to mcp/src/main/java/io/modelcontextprotocol/client/LifecycleInitializer.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,9 @@
7272
* the initialized notification</li>
7373
* </ul>
7474
*/
75-
public class LifecyleInitializer {
75+
class LifecycleInitializer {
7676

77-
private static final Logger logger = LoggerFactory.getLogger(LifecyleInitializer.class);
77+
private static final Logger logger = LoggerFactory.getLogger(LifecycleInitializer.class);
7878

7979
/**
8080
* The MCP session supplier that manages bidirectional JSON-RPC communication between
@@ -95,7 +95,7 @@ public class LifecyleInitializer {
9595
*/
9696
private final Duration initializationTimeout;
9797

98-
public LifecyleInitializer(McpSchema.ClientCapabilities clientCapabilities, McpSchema.Implementation clientInfo,
98+
public LifecycleInitializer(McpSchema.ClientCapabilities clientCapabilities, McpSchema.Implementation clientInfo,
9999
List<String> protocolVersions, Duration initializationTimeout,
100100
Function<ContextView, McpClientSession> sessionSupplier) {
101101

@@ -233,7 +233,7 @@ private Mono<Void> closeGracefully() {
233233
}
234234

235235
public boolean isInitialized() {
236-
return currentInitializationResult() != null;
236+
return this.currentInitializationResult() != null;
237237
}
238238

239239
public McpSchema.InitializeResult currentInitializationResult() {

mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ public class McpAsyncClient {
145145
/**
146146
* The lifecycle initializer that manages the client-server connection initialization.
147147
*/
148-
private LifecyleInitializer initializer;
148+
private final LifecycleInitializer initializer;
149149

150150
/**
151151
* Create a new McpAsyncClient with the given transport and session request-response
@@ -253,7 +253,7 @@ public class McpAsyncClient {
253253
notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_MESSAGE,
254254
asyncLoggingNotificationHandler(loggingConsumersFinal));
255255

256-
this.initializer = new LifecyleInitializer(clientCapabilities, clientInfo,
256+
this.initializer = new LifecycleInitializer(clientCapabilities, clientInfo,
257257
List.of(McpSchema.LATEST_PROTOCOL_VERSION), initializationTimeout,
258258
ctx -> new McpClientSession(requestTimeout, transport, requestHandlers, notificationHandlers,
259259
con -> con.contextWrite(ctx)));

mcp/src/test/java/io/modelcontextprotocol/client/LifecyleInitializerTests.java renamed to mcp/src/test/java/io/modelcontextprotocol/client/LifecycleInitializerTests.java

Lines changed: 73 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@
2020
import io.modelcontextprotocol.spec.McpSchema;
2121
import io.modelcontextprotocol.spec.McpTransportSessionNotFoundException;
2222
import reactor.core.publisher.Mono;
23+
import reactor.core.scheduler.Schedulers;
2324
import reactor.test.StepVerifier;
25+
import reactor.test.scheduler.VirtualTimeScheduler;
2426
import reactor.util.context.Context;
2527
import reactor.util.context.ContextView;
2628

@@ -34,14 +36,12 @@
3436
import static org.mockito.Mockito.when;
3537

3638
/**
37-
* Tests for {@link LifecyleInitializer}.
39+
* Tests for {@link LifecycleInitializer}.
3840
*/
39-
class LifecyleInitializerTests {
41+
class LifecycleInitializerTests {
4042

4143
private static final Duration INITIALIZATION_TIMEOUT = Duration.ofSeconds(5);
4244

43-
private static final Duration SHORT_TIMEOUT = Duration.ofMillis(100);
44-
4545
private static final McpSchema.ClientCapabilities CLIENT_CAPABILITIES = McpSchema.ClientCapabilities.builder()
4646
.build();
4747

@@ -54,56 +54,56 @@ class LifecyleInitializerTests {
5454
"Test instructions");
5555

5656
@Mock
57-
private McpClientSession mockSession;
57+
private McpClientSession mockClientSession;
5858

5959
@Mock
6060
private Function<ContextView, McpClientSession> mockSessionSupplier;
6161

62-
private LifecyleInitializer initializer;
62+
private LifecycleInitializer initializer;
6363

6464
@BeforeEach
6565
void setUp() {
6666
MockitoAnnotations.openMocks(this);
6767

68-
when(mockSessionSupplier.apply(any(ContextView.class))).thenReturn(mockSession);
69-
when(mockSession.sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any()))
68+
when(mockSessionSupplier.apply(any(ContextView.class))).thenReturn(mockClientSession);
69+
when(mockClientSession.sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any()))
7070
.thenReturn(Mono.just(MOCK_INIT_RESULT));
71-
when(mockSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_INITIALIZED), any()))
71+
when(mockClientSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_INITIALIZED), any()))
7272
.thenReturn(Mono.empty());
73-
when(mockSession.closeGracefully()).thenReturn(Mono.empty());
73+
when(mockClientSession.closeGracefully()).thenReturn(Mono.empty());
7474

75-
initializer = new LifecyleInitializer(CLIENT_CAPABILITIES, CLIENT_INFO, PROTOCOL_VERSIONS,
75+
initializer = new LifecycleInitializer(CLIENT_CAPABILITIES, CLIENT_INFO, PROTOCOL_VERSIONS,
7676
INITIALIZATION_TIMEOUT, mockSessionSupplier);
7777
}
7878

7979
@Test
8080
void constructorShouldValidateParameters() {
81-
assertThatThrownBy(() -> new LifecyleInitializer(null, CLIENT_INFO, PROTOCOL_VERSIONS, INITIALIZATION_TIMEOUT,
81+
assertThatThrownBy(() -> new LifecycleInitializer(null, CLIENT_INFO, PROTOCOL_VERSIONS, INITIALIZATION_TIMEOUT,
8282
mockSessionSupplier))
8383
.isInstanceOf(IllegalArgumentException.class)
8484
.hasMessageContaining("Client capabilities must not be null");
8585

86-
assertThatThrownBy(() -> new LifecyleInitializer(CLIENT_CAPABILITIES, null, PROTOCOL_VERSIONS,
86+
assertThatThrownBy(() -> new LifecycleInitializer(CLIENT_CAPABILITIES, null, PROTOCOL_VERSIONS,
8787
INITIALIZATION_TIMEOUT, mockSessionSupplier))
8888
.isInstanceOf(IllegalArgumentException.class)
8989
.hasMessageContaining("Client info must not be null");
9090

91-
assertThatThrownBy(() -> new LifecyleInitializer(CLIENT_CAPABILITIES, CLIENT_INFO, null, INITIALIZATION_TIMEOUT,
92-
mockSessionSupplier))
91+
assertThatThrownBy(() -> new LifecycleInitializer(CLIENT_CAPABILITIES, CLIENT_INFO, null,
92+
INITIALIZATION_TIMEOUT, mockSessionSupplier))
9393
.isInstanceOf(IllegalArgumentException.class)
9494
.hasMessageContaining("Protocol versions must not be empty");
9595

96-
assertThatThrownBy(() -> new LifecyleInitializer(CLIENT_CAPABILITIES, CLIENT_INFO, List.of(),
96+
assertThatThrownBy(() -> new LifecycleInitializer(CLIENT_CAPABILITIES, CLIENT_INFO, List.of(),
9797
INITIALIZATION_TIMEOUT, mockSessionSupplier))
9898
.isInstanceOf(IllegalArgumentException.class)
9999
.hasMessageContaining("Protocol versions must not be empty");
100100

101-
assertThatThrownBy(() -> new LifecyleInitializer(CLIENT_CAPABILITIES, CLIENT_INFO, PROTOCOL_VERSIONS, null,
101+
assertThatThrownBy(() -> new LifecycleInitializer(CLIENT_CAPABILITIES, CLIENT_INFO, PROTOCOL_VERSIONS, null,
102102
mockSessionSupplier))
103103
.isInstanceOf(IllegalArgumentException.class)
104104
.hasMessageContaining("Initialization timeout must not be null");
105105

106-
assertThatThrownBy(() -> new LifecyleInitializer(CLIENT_CAPABILITIES, CLIENT_INFO, PROTOCOL_VERSIONS,
106+
assertThatThrownBy(() -> new LifecycleInitializer(CLIENT_CAPABILITIES, CLIENT_INFO, PROTOCOL_VERSIONS,
107107
INITIALIZATION_TIMEOUT, null))
108108
.isInstanceOf(IllegalArgumentException.class)
109109
.hasMessageContaining("Session supplier must not be null");
@@ -119,15 +119,16 @@ void shouldInitializeSuccessfully() {
119119
})
120120
.verifyComplete();
121121

122-
verify(mockSession).sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(McpSchema.InitializeRequest.class), any());
123-
verify(mockSession).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_INITIALIZED), eq(null));
122+
verify(mockClientSession).sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(McpSchema.InitializeRequest.class),
123+
any());
124+
verify(mockClientSession).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_INITIALIZED), eq(null));
124125
}
125126

126127
@Test
127128
void shouldUseLatestProtocolVersionInInitializeRequest() {
128129
AtomicReference<McpSchema.InitializeRequest> capturedRequest = new AtomicReference<>();
129130

130-
when(mockSession.sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any())).thenAnswer(invocation -> {
131+
when(mockClientSession.sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any())).thenAnswer(invocation -> {
131132
capturedRequest.set((McpSchema.InitializeRequest) invocation.getArgument(1));
132133
return Mono.just(MOCK_INIT_RESULT);
133134
});
@@ -149,26 +150,34 @@ void shouldFailForUnsupportedProtocolVersion() {
149150
McpSchema.ServerCapabilities.builder().build(), new McpSchema.Implementation("test-server", "1.0.0"),
150151
"Test instructions");
151152

152-
when(mockSession.sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any()))
153+
when(mockClientSession.sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any()))
153154
.thenReturn(Mono.just(unsupportedResult));
154155

155156
StepVerifier.create(initializer.withIntitialization("test", init -> Mono.just(init.initializeResult())))
156157
.expectError(McpError.class)
157158
.verify();
158159

159-
verify(mockSession, never()).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_INITIALIZED), any());
160+
verify(mockClientSession, never()).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_INITIALIZED), any());
160161
}
161162

162163
@Test
163164
void shouldTimeoutOnSlowInitialization() {
164-
LifecyleInitializer shortTimeoutInitializer = new LifecyleInitializer(CLIENT_CAPABILITIES, CLIENT_INFO,
165-
PROTOCOL_VERSIONS, SHORT_TIMEOUT, mockSessionSupplier);
165+
VirtualTimeScheduler virtualTimeScheduler = VirtualTimeScheduler.getOrSet();
166166

167-
when(mockSession.<McpSchema.InitializeResult>sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any()))
168-
.thenReturn(Mono.just(MOCK_INIT_RESULT).delayElement(Duration.ofSeconds(1)));
167+
Duration INITIALIZE_TIMEOUT = Duration.ofSeconds(1);
168+
Duration SLOW_RESPONSE_DELAY = Duration.ofSeconds(5);
169+
170+
LifecycleInitializer shortTimeoutInitializer = new LifecycleInitializer(CLIENT_CAPABILITIES, CLIENT_INFO,
171+
PROTOCOL_VERSIONS, INITIALIZE_TIMEOUT, mockSessionSupplier);
172+
173+
when(mockClientSession.<McpSchema.InitializeResult>sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any()))
174+
.thenReturn(Mono.just(MOCK_INIT_RESULT).delayElement(SLOW_RESPONSE_DELAY, virtualTimeScheduler));
169175

170176
StepVerifier
171-
.create(shortTimeoutInitializer.withIntitialization("test", init -> Mono.just(init.initializeResult())))
177+
.withVirtualTime(() -> shortTimeoutInitializer.withIntitialization("test",
178+
init -> Mono.just(init.initializeResult())), () -> virtualTimeScheduler, Long.MAX_VALUE)
179+
.expectSubscription()
180+
.expectNoEvent(INITIALIZE_TIMEOUT)
172181
.expectError(McpError.class)
173182
.verify();
174183
}
@@ -187,7 +196,7 @@ void shouldReuseExistingInitialization() {
187196

188197
// Verify session was created only once
189198
verify(mockSessionSupplier, times(1)).apply(any(ContextView.class));
190-
verify(mockSession, times(1)).sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any());
199+
verify(mockClientSession, times(1)).sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any());
191200
}
192201

193202
@Test
@@ -196,13 +205,17 @@ void shouldHandleConcurrentInitializationRequests() {
196205

197206
when(mockSessionSupplier.apply(any(ContextView.class))).thenAnswer(invocation -> {
198207
sessionCreationCount.incrementAndGet();
199-
return mockSession;
208+
return mockClientSession;
200209
});
201210

202-
// Start multiple concurrent initializations
203-
Mono<String> init1 = initializer.withIntitialization("test1", init -> Mono.just("result1"));
204-
Mono<String> init2 = initializer.withIntitialization("test2", init -> Mono.just("result2"));
205-
Mono<String> init3 = initializer.withIntitialization("test3", init -> Mono.just("result3"));
211+
// Start multiple concurrent initializations using subscribeOn with parallel
212+
// scheduler
213+
Mono<String> init1 = initializer.withIntitialization("test1", init -> Mono.just("result1"))
214+
.subscribeOn(Schedulers.parallel());
215+
Mono<String> init2 = initializer.withIntitialization("test2", init -> Mono.just("result2"))
216+
.subscribeOn(Schedulers.parallel());
217+
Mono<String> init3 = initializer.withIntitialization("test3", init -> Mono.just("result3"))
218+
.subscribeOn(Schedulers.parallel());
206219

207220
StepVerifier.create(Mono.zip(init1, init2, init3)).assertNext(tuple -> {
208221
assertThat(tuple.getT1()).isEqualTo("result1");
@@ -212,12 +225,12 @@ void shouldHandleConcurrentInitializationRequests() {
212225

213226
// Should only create one session despite concurrent requests
214227
assertThat(sessionCreationCount.get()).isEqualTo(1);
215-
verify(mockSession, times(1)).sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any());
228+
verify(mockClientSession, times(1)).sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any());
216229
}
217230

218231
@Test
219232
void shouldHandleInitializationFailure() {
220-
when(mockSession.sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any()))
233+
when(mockClientSession.sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any()))
221234
.thenReturn(Mono.error(new RuntimeException("Connection failed")));
222235

223236
StepVerifier.create(initializer.withIntitialization("test", init -> Mono.just(init.initializeResult())))
@@ -230,27 +243,24 @@ void shouldHandleInitializationFailure() {
230243

231244
@Test
232245
void shouldHandleTransportSessionNotFoundException() {
233-
// Simulate a successful initialization first
246+
// successful initialization first
234247
StepVerifier.create(initializer.withIntitialization("test", init -> Mono.just(init.initializeResult())))
235248
.expectNext(MOCK_INIT_RESULT)
236249
.verifyComplete();
237250

238251
assertThat(initializer.isInitialized()).isTrue();
239252

240-
// Simulate transport session not found exception
253+
// Simulate transport session not found
241254
initializer.handleException(new McpTransportSessionNotFoundException("Session not found"));
242255

243-
// The exception handling resets the initialization state and triggers
244-
// re-initialization
245-
// We need to wait a bit for the async re-initialization to start
246-
try {
247-
Thread.sleep(10); // Small delay to allow async processing
248-
}
249-
catch (InterruptedException e) {
250-
Thread.currentThread().interrupt();
251-
}
252-
253-
verify(mockSession).close();
256+
assertThat(initializer.isInitialized()).isTrue();
257+
258+
// Verify that the session was closed and re-initialized
259+
verify(mockClientSession).close();
260+
261+
// Verify session was created 2 times (once for initial and once for
262+
// re-initialization)
263+
verify(mockSessionSupplier, times(2)).apply(any(ContextView.class));
254264
}
255265

256266
@Test
@@ -267,34 +277,33 @@ void shouldHandleOtherExceptions() {
267277

268278
// Should still be initialized
269279
assertThat(initializer.isInitialized()).isTrue();
270-
verify(mockSession, never()).close();
280+
verify(mockClientSession, never()).close();
281+
// Verify that the session was not re-created
282+
verify(mockSessionSupplier, times(1)).apply(any(ContextView.class));
271283
}
272284

273285
@Test
274286
void shouldCloseGracefully() {
275-
// Initialize first
276287
StepVerifier.create(initializer.withIntitialization("test", init -> Mono.just(init.initializeResult())))
277288
.expectNext(MOCK_INIT_RESULT)
278289
.verifyComplete();
279290

280-
// Close gracefully
281291
StepVerifier.create(initializer.closeGracefully()).verifyComplete();
282292

283-
verify(mockSession).closeGracefully();
293+
verify(mockClientSession).closeGracefully();
284294
assertThat(initializer.isInitialized()).isFalse();
285295
}
286296

287297
@Test
288298
void shouldCloseImmediately() {
289-
// Initialize first
290299
StepVerifier.create(initializer.withIntitialization("test", init -> Mono.just(init.initializeResult())))
291300
.expectNext(MOCK_INIT_RESULT)
292301
.verifyComplete();
293302

294303
// Close immediately
295304
initializer.close();
296305

297-
verify(mockSession).close();
306+
verify(mockClientSession).close();
298307
assertThat(initializer.isInitialized()).isFalse();
299308
}
300309

@@ -305,8 +314,8 @@ void shouldHandleCloseWithoutInitialization() {
305314

306315
StepVerifier.create(initializer.closeGracefully()).verifyComplete();
307316

308-
verify(mockSession, never()).close();
309-
verify(mockSession, never()).closeGracefully();
317+
verify(mockClientSession, never()).close();
318+
verify(mockClientSession, never()).closeGracefully();
310319
}
311320

312321
@Test
@@ -316,7 +325,7 @@ void shouldSetProtocolVersionsForTesting() {
316325

317326
AtomicReference<McpSchema.InitializeRequest> capturedRequest = new AtomicReference<>();
318327

319-
when(mockSession.sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any())).thenAnswer(invocation -> {
328+
when(mockClientSession.sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any())).thenAnswer(invocation -> {
320329
capturedRequest.set((McpSchema.InitializeRequest) invocation.getArgument(1));
321330
return Mono.just(new McpSchema.InitializeResult("4.0.0", McpSchema.ServerCapabilities.builder().build(),
322331
new McpSchema.Implementation("test-server", "1.0.0"), "Test instructions"));
@@ -339,7 +348,7 @@ void shouldPassContextToSessionSupplier() {
339348

340349
when(mockSessionSupplier.apply(any(ContextView.class))).thenAnswer(invocation -> {
341350
capturedContext.set(invocation.getArgument(0));
342-
return mockSession;
351+
return mockClientSession;
343352
});
344353

345354
StepVerifier
@@ -355,23 +364,23 @@ void shouldPassContextToSessionSupplier() {
355364
@Test
356365
void shouldProvideAccessToMcpSessionAndInitializeResult() {
357366
StepVerifier.create(initializer.withIntitialization("test", init -> {
358-
assertThat(init.mcpSession()).isEqualTo(mockSession);
367+
assertThat(init.mcpSession()).isEqualTo(mockClientSession);
359368
assertThat(init.initializeResult()).isEqualTo(MOCK_INIT_RESULT);
360369
return Mono.just("success");
361370
})).expectNext("success").verifyComplete();
362371
}
363372

364373
@Test
365374
void shouldHandleNotificationFailure() {
366-
when(mockSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_INITIALIZED), any()))
375+
when(mockClientSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_INITIALIZED), any()))
367376
.thenReturn(Mono.error(new RuntimeException("Notification failed")));
368377

369378
StepVerifier.create(initializer.withIntitialization("test", init -> Mono.just(init.initializeResult())))
370379
.expectError(RuntimeException.class)
371380
.verify();
372381

373-
verify(mockSession).sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any());
374-
verify(mockSession).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_INITIALIZED), eq(null));
382+
verify(mockClientSession).sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any());
383+
verify(mockClientSession).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_INITIALIZED), eq(null));
375384
}
376385

377386
@Test
@@ -397,7 +406,7 @@ void shouldReinitializeAfterTransportSessionException() {
397406

398407
// Verify two separate initializations occurred
399408
verify(mockSessionSupplier, times(2)).apply(any(ContextView.class));
400-
verify(mockSession, times(2)).sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any());
409+
verify(mockClientSession, times(2)).sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any());
401410
}
402411

403412
}

0 commit comments

Comments
 (0)