99import java .util .concurrent .atomic .AtomicInteger ;
1010import java .util .function .Function ;
1111
12+ import com .fasterxml .jackson .core .type .TypeReference ;
1213import com .fasterxml .jackson .databind .ObjectMapper ;
14+ import io .modelcontextprotocol .spec .McpClientSession ;
15+ import io .modelcontextprotocol .spec .McpClientTransport ;
1316import io .modelcontextprotocol .spec .McpSchema ;
1417import io .modelcontextprotocol .spec .McpSchema .JSONRPCRequest ;
1518import org .junit .jupiter .api .AfterEach ;
3134import static org .assertj .core .api .Assertions .assertThatThrownBy ;
3235
3336/**
34- * Tests for the {@link WebFluxSseClientTransport} class.
35- *
3637 * @author Christian Tzolov
3738 */
3839@ Timeout (15 )
@@ -46,20 +47,22 @@ class WebFluxSseClientTransportTests {
4647 .withExposedPorts (3001 )
4748 .waitingFor (Wait .forHttp ("/" ).forStatusCode (404 ));
4849
49- private TestSseClientTransport transport ;
50+ private TestSseClientTransportProvider transportProvider ;
51+
52+ private McpClientTransport transport ;
5053
5154 private WebClient .Builder webClientBuilder ;
5255
5356 private ObjectMapper objectMapper ;
5457
5558 // Test class to access protected methods
56- static class TestSseClientTransport extends WebFluxSseClientTransport {
59+ static class TestSseClientTransportProvider extends WebFluxSseClientTransportProvider {
5760
5861 private final AtomicInteger inboundMessageCount = new AtomicInteger (0 );
5962
6063 private Sinks .Many <ServerSentEvent <String >> events = Sinks .many ().unicast ().onBackpressureBuffer ();
6164
62- public TestSseClientTransport (WebClient .Builder webClientBuilder , ObjectMapper objectMapper ) {
65+ public TestSseClientTransportProvider (WebClient .Builder webClientBuilder , ObjectMapper objectMapper ) {
6366 super (webClientBuilder , objectMapper );
6467 }
6568
@@ -69,7 +72,7 @@ protected Flux<ServerSentEvent<String>> eventStream() {
6972 }
7073
7174 public String getLastEndpoint () {
72- return messageEndpointSink .asMono ().block ();
75+ return (( WebFluxSseClientTransport ) getSession (). getTransport ()). messageEndpointSink .asMono ().block ();
7376 }
7477
7578 public int getInboundMessageCount () {
@@ -99,7 +102,10 @@ void setUp() {
99102 startContainer ();
100103 webClientBuilder = WebClient .builder ().baseUrl (host );
101104 objectMapper = new ObjectMapper ();
102- transport = new TestSseClientTransport (webClientBuilder , objectMapper );
105+ transportProvider = new TestSseClientTransportProvider (webClientBuilder , objectMapper );
106+ transportProvider .setSessionFactory (
107+ (transport ) -> new McpClientSession (Duration .ofSeconds (5 ), transport , Map .of (), Map .of ()));
108+ transport = transportProvider .getSession ().getTransport ();
103109 transport .connect (Function .identity ()).block ();
104110 }
105111
@@ -117,44 +123,62 @@ void cleanup() {
117123
118124 @ Test
119125 void testEndpointEventHandling () {
120- assertThat (transport .getLastEndpoint ()).startsWith ("/message?" );
126+ assertThat (transportProvider .getLastEndpoint ()).startsWith ("/message?" );
121127 }
122128
123129 @ Test
124130 void constructorValidation () {
125- assertThatThrownBy (() -> new WebFluxSseClientTransport (null )).isInstanceOf (IllegalArgumentException .class )
131+ assertThatThrownBy (() -> new WebFluxSseClientTransportProvider (null ))
132+ .isInstanceOf (IllegalArgumentException .class )
126133 .hasMessageContaining ("WebClient.Builder must not be null" );
127134
128- assertThatThrownBy (() -> new WebFluxSseClientTransport (webClientBuilder , null ))
135+ assertThatThrownBy (() -> new WebFluxSseClientTransportProvider (webClientBuilder , null ))
129136 .isInstanceOf (IllegalArgumentException .class )
130137 .hasMessageContaining ("ObjectMapper must not be null" );
131138 }
132139
133140 @ Test
134141 void testBuilderPattern () {
135142 // Test default builder
136- WebFluxSseClientTransport transport1 = WebFluxSseClientTransport .builder (webClientBuilder ).build ();
137- assertThatCode (() -> transport1 .closeGracefully ().block ()).doesNotThrowAnyException ();
143+ WebFluxSseClientTransportProvider transportProvider1 = WebFluxSseClientTransportProvider
144+ .builder (webClientBuilder )
145+ .build ();
146+ transportProvider1 .setSessionFactory (
147+ (transport ) -> new McpClientSession (Duration .ofSeconds (5 ), transport , Map .of (), Map .of ()));
148+ transportProvider1 .getSession ();
149+ assertThatCode (() -> transportProvider1 .closeGracefully ().block ()).doesNotThrowAnyException ();
138150
139151 // Test builder with custom ObjectMapper
140152 ObjectMapper customMapper = new ObjectMapper ();
141- WebFluxSseClientTransport transport2 = WebFluxSseClientTransport .builder (webClientBuilder )
153+ WebFluxSseClientTransportProvider transportProvider2 = WebFluxSseClientTransportProvider
154+ .builder (webClientBuilder )
142155 .objectMapper (customMapper )
143156 .build ();
144- assertThatCode (() -> transport2 .closeGracefully ().block ()).doesNotThrowAnyException ();
157+ transportProvider2 .setSessionFactory (
158+ (transport ) -> new McpClientSession (Duration .ofSeconds (5 ), transport , Map .of (), Map .of ()));
159+ transportProvider2 .getSession ();
160+ assertThatCode (() -> transportProvider2 .closeGracefully ().block ()).doesNotThrowAnyException ();
145161
146162 // Test builder with custom SSE endpoint
147- WebFluxSseClientTransport transport3 = WebFluxSseClientTransport .builder (webClientBuilder )
163+ WebFluxSseClientTransportProvider transportProvider3 = WebFluxSseClientTransportProvider
164+ .builder (webClientBuilder )
148165 .sseEndpoint ("/custom-sse" )
149166 .build ();
150- assertThatCode (() -> transport3 .closeGracefully ().block ()).doesNotThrowAnyException ();
167+ transportProvider3 .setSessionFactory (
168+ (transport ) -> new McpClientSession (Duration .ofSeconds (5 ), transport , Map .of (), Map .of ()));
169+ transportProvider3 .getSession ();
170+ assertThatCode (() -> transportProvider3 .closeGracefully ().block ()).doesNotThrowAnyException ();
151171
152172 // Test builder with all custom parameters
153- WebFluxSseClientTransport transport4 = WebFluxSseClientTransport .builder (webClientBuilder )
173+ WebFluxSseClientTransportProvider transportProvider4 = WebFluxSseClientTransportProvider
174+ .builder (webClientBuilder )
154175 .objectMapper (customMapper )
155176 .sseEndpoint ("/custom-sse" )
156177 .build ();
157- assertThatCode (() -> transport4 .closeGracefully ().block ()).doesNotThrowAnyException ();
178+ transportProvider4 .setSessionFactory (
179+ (transport ) -> new McpClientSession (Duration .ofSeconds (5 ), transport , Map .of (), Map .of ()));
180+ transportProvider4 .getSession ();
181+ assertThatCode (() -> transportProvider4 .closeGracefully ().block ()).doesNotThrowAnyException ();
158182 }
159183
160184 @ Test
@@ -164,7 +188,7 @@ void testMessageProcessing() {
164188 Map .of ("key" , "value" ));
165189
166190 // Simulate receiving the message
167- transport .simulateMessageEvent ("""
191+ transportProvider .simulateMessageEvent ("""
168192 {
169193 "jsonrpc": "2.0",
170194 "method": "test-method",
@@ -176,13 +200,13 @@ void testMessageProcessing() {
176200 // Subscribe to messages and verify
177201 StepVerifier .create (transport .sendMessage (testMessage )).verifyComplete ();
178202
179- assertThat (transport .getInboundMessageCount ()).isEqualTo (1 );
203+ assertThat (transportProvider .getInboundMessageCount ()).isEqualTo (1 );
180204 }
181205
182206 @ Test
183207 void testResponseMessageProcessing () {
184208 // Simulate receiving a response message
185- transport .simulateMessageEvent ("""
209+ transportProvider .simulateMessageEvent ("""
186210 {
187211 "jsonrpc": "2.0",
188212 "id": "test-id",
@@ -197,13 +221,13 @@ void testResponseMessageProcessing() {
197221 // Verify message handling
198222 StepVerifier .create (transport .sendMessage (testMessage )).verifyComplete ();
199223
200- assertThat (transport .getInboundMessageCount ()).isEqualTo (1 );
224+ assertThat (transportProvider .getInboundMessageCount ()).isEqualTo (1 );
201225 }
202226
203227 @ Test
204228 void testErrorMessageProcessing () {
205229 // Simulate receiving an error message
206- transport .simulateMessageEvent ("""
230+ transportProvider .simulateMessageEvent ("""
207231 {
208232 "jsonrpc": "2.0",
209233 "id": "test-id",
@@ -221,13 +245,13 @@ void testErrorMessageProcessing() {
221245 // Verify message handling
222246 StepVerifier .create (transport .sendMessage (testMessage )).verifyComplete ();
223247
224- assertThat (transport .getInboundMessageCount ()).isEqualTo (1 );
248+ assertThat (transportProvider .getInboundMessageCount ()).isEqualTo (1 );
225249 }
226250
227251 @ Test
228252 void testNotificationMessageProcessing () {
229253 // Simulate receiving a notification message (no id)
230- transport .simulateMessageEvent ("""
254+ transportProvider .simulateMessageEvent ("""
231255 {
232256 "jsonrpc": "2.0",
233257 "method": "update",
@@ -236,7 +260,7 @@ void testNotificationMessageProcessing() {
236260 """ );
237261
238262 // Verify the notification was processed
239- assertThat (transport .getInboundMessageCount ()).isEqualTo (1 );
263+ assertThat (transportProvider .getInboundMessageCount ()).isEqualTo (1 );
240264 }
241265
242266 @ Test
@@ -252,27 +276,31 @@ void testGracefulShutdown() {
252276 StepVerifier .create (transport .sendMessage (testMessage )).verifyComplete ();
253277
254278 // Message count should remain 0 after shutdown
255- assertThat (transport .getInboundMessageCount ()).isEqualTo (0 );
279+ assertThat (transportProvider .getInboundMessageCount ()).isEqualTo (0 );
256280 }
257281
258282 @ Test
259283 void testRetryBehavior () {
260284 // Create a WebClient that simulates connection failures
261285 WebClient .Builder failingWebClientBuilder = WebClient .builder ().baseUrl ("http://non-existent-host" );
262286
263- WebFluxSseClientTransport failingTransport = WebFluxSseClientTransport .builder (failingWebClientBuilder ).build ();
287+ WebFluxSseClientTransportProvider failingTransportProvider = WebFluxSseClientTransportProvider
288+ .builder (failingWebClientBuilder )
289+ .build ();
290+ failingTransportProvider .setSessionFactory (
291+ (transport ) -> new McpClientSession (Duration .ofSeconds (5 ), transport , Map .of (), Map .of ()));
264292
265293 // Verify that the transport attempts to reconnect
266294 StepVerifier .create (Mono .delay (Duration .ofSeconds (2 ))).expectNextCount (1 ).verifyComplete ();
267295
268296 // Clean up
269- failingTransport .closeGracefully ().block ();
297+ failingTransportProvider . getSession (). getTransport () .closeGracefully ().block ();
270298 }
271299
272300 @ Test
273301 void testMultipleMessageProcessing () {
274302 // Simulate receiving multiple messages in sequence
275- transport .simulateMessageEvent ("""
303+ transportProvider .simulateMessageEvent ("""
276304 {
277305 "jsonrpc": "2.0",
278306 "method": "method1",
@@ -281,7 +309,7 @@ void testMultipleMessageProcessing() {
281309 }
282310 """ );
283311
284- transport .simulateMessageEvent ("""
312+ transportProvider .simulateMessageEvent ("""
285313 {
286314 "jsonrpc": "2.0",
287315 "method": "method2",
@@ -301,13 +329,13 @@ void testMultipleMessageProcessing() {
301329 StepVerifier .create (transport .sendMessage (message1 ).then (transport .sendMessage (message2 ))).verifyComplete ();
302330
303331 // Verify message count
304- assertThat (transport .getInboundMessageCount ()).isEqualTo (2 );
332+ assertThat (transportProvider .getInboundMessageCount ()).isEqualTo (2 );
305333 }
306334
307335 @ Test
308336 void testMessageOrderPreservation () {
309337 // Simulate receiving messages in a specific order
310- transport .simulateMessageEvent ("""
338+ transportProvider .simulateMessageEvent ("""
311339 {
312340 "jsonrpc": "2.0",
313341 "method": "first",
@@ -316,7 +344,7 @@ void testMessageOrderPreservation() {
316344 }
317345 """ );
318346
319- transport .simulateMessageEvent ("""
347+ transportProvider .simulateMessageEvent ("""
320348 {
321349 "jsonrpc": "2.0",
322350 "method": "second",
@@ -325,7 +353,7 @@ void testMessageOrderPreservation() {
325353 }
326354 """ );
327355
328- transport .simulateMessageEvent ("""
356+ transportProvider .simulateMessageEvent ("""
329357 {
330358 "jsonrpc": "2.0",
331359 "method": "third",
@@ -335,7 +363,7 @@ void testMessageOrderPreservation() {
335363 """ );
336364
337365 // Verify message count and order
338- assertThat (transport .getInboundMessageCount ()).isEqualTo (3 );
366+ assertThat (transportProvider .getInboundMessageCount ()).isEqualTo (3 );
339367 }
340368
341369}
0 commit comments