88import java .util .List ;
99import java .util .Map ;
1010import java .util .concurrent .ConcurrentHashMap ;
11+ import java .util .concurrent .TimeUnit ;
1112import java .util .concurrent .atomic .AtomicReference ;
1213import java .util .function .Function ;
1314import java .util .stream .Collectors ;
1819import io .modelcontextprotocol .client .transport .WebFluxSseClientTransport ;
1920import io .modelcontextprotocol .server .McpServer ;
2021import io .modelcontextprotocol .server .McpServerFeatures ;
22+ import io .modelcontextprotocol .server .TestUtil ;
2123import io .modelcontextprotocol .server .transport .WebFluxSseServerTransportProvider ;
2224import io .modelcontextprotocol .spec .McpError ;
2325import io .modelcontextprotocol .spec .McpSchema ;
3537import org .junit .jupiter .api .BeforeEach ;
3638import org .junit .jupiter .params .ParameterizedTest ;
3739import org .junit .jupiter .params .provider .ValueSource ;
38- import reactor .core .publisher .Mono ;
3940import reactor .netty .DisposableServer ;
4041import reactor .netty .http .server .HttpServer ;
41- import reactor .test .StepVerifier ;
4242
4343import org .springframework .http .server .reactive .HttpHandler ;
4444import org .springframework .http .server .reactive .ReactorHttpHandlerAdapter ;
4747import org .springframework .web .reactive .function .server .RouterFunctions ;
4848
4949import static org .assertj .core .api .Assertions .assertThat ;
50+ import static org .assertj .core .api .Assertions .assertThatExceptionOfType ;
51+ import static org .assertj .core .api .Assertions .assertWith ;
5052import static org .awaitility .Awaitility .await ;
5153import static org .mockito .Mockito .mock ;
5254
53- public class WebFluxSseIntegrationTests {
55+ class WebFluxSseIntegrationTests {
5456
55- private static final int PORT = 8182 ;
57+ private static final int PORT = TestUtil . findAvailablePort () ;
5658
5759 private static final String CUSTOM_SSE_ENDPOINT = "/somePath/sse" ;
5860
@@ -106,12 +108,9 @@ void testCreateMessageWithoutSamplingCapabilities(String clientType) {
106108 var clientBuilder = clientBuilders .get (clientType );
107109
108110 McpServerFeatures .AsyncToolSpecification tool = new McpServerFeatures .AsyncToolSpecification (
109- new McpSchema .Tool ("tool1" , "tool1 description" , emptyJsonSchema ), (exchange , request ) -> {
110-
111- exchange .createMessage (mock (McpSchema .CreateMessageRequest .class )).block ();
112-
113- return Mono .just (mock (CallToolResult .class ));
114- });
111+ new McpSchema .Tool ("tool1" , "tool1 description" , emptyJsonSchema ),
112+ (exchange , request ) -> exchange .createMessage (mock (CreateMessageRequest .class ))
113+ .thenReturn (mock (CallToolResult .class )));
115114
116115 var server = McpServer .async (mcpServerTransportProvider ).serverInfo ("test-server" , "1.0.0" ).tools (tool ).build ();
117116
@@ -133,7 +132,7 @@ void testCreateMessageWithoutSamplingCapabilities(String clientType) {
133132
134133 @ ParameterizedTest (name = "{0} : {displayName} " )
135134 @ ValueSource (strings = { "httpclient" , "webflux" })
136- void testCreateMessageSuccess (String clientType ) throws InterruptedException {
135+ void testCreateMessageSuccess (String clientType ) {
137136
138137 var clientBuilder = clientBuilders .get (clientType );
139138
@@ -148,10 +147,12 @@ void testCreateMessageSuccess(String clientType) throws InterruptedException {
148147 CallToolResult callResponse = new McpSchema .CallToolResult (List .of (new McpSchema .TextContent ("CALL RESPONSE" )),
149148 null );
150149
150+ AtomicReference <CreateMessageResult > samplingResult = new AtomicReference <>();
151+
151152 McpServerFeatures .AsyncToolSpecification tool = new McpServerFeatures .AsyncToolSpecification (
152153 new McpSchema .Tool ("tool1" , "tool1 description" , emptyJsonSchema ), (exchange , request ) -> {
153154
154- var craeteMessageRequest = McpSchema .CreateMessageRequest .builder ()
155+ var createMessageRequest = McpSchema .CreateMessageRequest .builder ()
155156 .messages (List .of (new McpSchema .SamplingMessage (McpSchema .Role .USER ,
156157 new McpSchema .TextContent ("Test message" ))))
157158 .modelPreferences (ModelPreferences .builder ()
@@ -162,19 +163,89 @@ void testCreateMessageSuccess(String clientType) throws InterruptedException {
162163 .build ())
163164 .build ();
164165
165- StepVerifier .create (exchange .createMessage (craeteMessageRequest )).consumeNextWith (result -> {
166- assertThat (result ).isNotNull ();
167- assertThat (result .role ()).isEqualTo (Role .USER );
168- assertThat (result .content ()).isInstanceOf (McpSchema .TextContent .class );
169- assertThat (((McpSchema .TextContent ) result .content ()).text ()).isEqualTo ("Test message" );
170- assertThat (result .model ()).isEqualTo ("MockModelName" );
171- assertThat (result .stopReason ()).isEqualTo (CreateMessageResult .StopReason .STOP_SEQUENCE );
172- }).verifyComplete ();
166+ return exchange .createMessage (createMessageRequest )
167+ .doOnNext (samplingResult ::set )
168+ .thenReturn (callResponse );
169+ });
170+
171+ var mcpServer = McpServer .async (mcpServerTransportProvider )
172+ .serverInfo ("test-server" , "1.0.0" )
173+ .tools (tool )
174+ .build ();
175+
176+ try (var mcpClient = clientBuilder .clientInfo (new McpSchema .Implementation ("Sample client" , "0.0.0" ))
177+ .capabilities (ClientCapabilities .builder ().sampling ().build ())
178+ .sampling (samplingHandler )
179+ .build ()) {
180+
181+ InitializeResult initResult = mcpClient .initialize ();
182+ assertThat (initResult ).isNotNull ();
183+
184+ CallToolResult response = mcpClient .callTool (new McpSchema .CallToolRequest ("tool1" , Map .of ()));
185+
186+ assertThat (response ).isNotNull ();
187+ assertThat (response ).isEqualTo (callResponse );
188+
189+ assertWith (samplingResult .get (), result -> {
190+ assertThat (result ).isNotNull ();
191+ assertThat (result .role ()).isEqualTo (Role .USER );
192+ assertThat (result .content ()).isInstanceOf (McpSchema .TextContent .class );
193+ assertThat (((McpSchema .TextContent ) result .content ()).text ()).isEqualTo ("Test message" );
194+ assertThat (result .model ()).isEqualTo ("MockModelName" );
195+ assertThat (result .stopReason ()).isEqualTo (CreateMessageResult .StopReason .STOP_SEQUENCE );
196+ });
197+ }
198+ mcpServer .closeGracefully ().block ();
199+ }
200+
201+ @ ParameterizedTest (name = "{0} : {displayName} " )
202+ @ ValueSource (strings = { "httpclient" , "webflux" })
203+ void testCreateMessageWithRequestTimeoutSuccess (String clientType ) throws InterruptedException {
204+
205+ // Client
206+ var clientBuilder = clientBuilders .get (clientType );
207+
208+ Function <CreateMessageRequest , CreateMessageResult > samplingHandler = request -> {
209+ assertThat (request .messages ()).hasSize (1 );
210+ assertThat (request .messages ().get (0 ).content ()).isInstanceOf (McpSchema .TextContent .class );
211+ try {
212+ TimeUnit .SECONDS .sleep (2 );
213+ }
214+ catch (InterruptedException e ) {
215+ throw new RuntimeException (e );
216+ }
217+ return new CreateMessageResult (Role .USER , new McpSchema .TextContent ("Test message" ), "MockModelName" ,
218+ CreateMessageResult .StopReason .STOP_SEQUENCE );
219+ };
220+
221+ // Server
222+
223+ CallToolResult callResponse = new McpSchema .CallToolResult (List .of (new McpSchema .TextContent ("CALL RESPONSE" )),
224+ null );
225+
226+ AtomicReference <CreateMessageResult > samplingResult = new AtomicReference <>();
227+
228+ McpServerFeatures .AsyncToolSpecification tool = new McpServerFeatures .AsyncToolSpecification (
229+ new McpSchema .Tool ("tool1" , "tool1 description" , emptyJsonSchema ), (exchange , request ) -> {
230+
231+ var craeteMessageRequest = McpSchema .CreateMessageRequest .builder ()
232+ .messages (List .of (new McpSchema .SamplingMessage (McpSchema .Role .USER ,
233+ new McpSchema .TextContent ("Test message" ))))
234+ .modelPreferences (ModelPreferences .builder ()
235+ .hints (List .of ())
236+ .costPriority (1.0 )
237+ .speedPriority (1.0 )
238+ .intelligencePriority (1.0 )
239+ .build ())
240+ .build ();
173241
174- return Mono .just (callResponse );
242+ return exchange .createMessage (craeteMessageRequest )
243+ .doOnNext (samplingResult ::set )
244+ .thenReturn (callResponse );
175245 });
176246
177247 var mcpServer = McpServer .async (mcpServerTransportProvider )
248+ .requestTimeout (Duration .ofSeconds (4 ))
178249 .serverInfo ("test-server" , "1.0.0" )
179250 .tools (tool )
180251 .build ();
@@ -191,8 +262,77 @@ void testCreateMessageSuccess(String clientType) throws InterruptedException {
191262
192263 assertThat (response ).isNotNull ();
193264 assertThat (response ).isEqualTo (callResponse );
265+
266+ assertWith (samplingResult .get (), result -> {
267+ assertThat (result ).isNotNull ();
268+ assertThat (result .role ()).isEqualTo (Role .USER );
269+ assertThat (result .content ()).isInstanceOf (McpSchema .TextContent .class );
270+ assertThat (((McpSchema .TextContent ) result .content ()).text ()).isEqualTo ("Test message" );
271+ assertThat (result .model ()).isEqualTo ("MockModelName" );
272+ assertThat (result .stopReason ()).isEqualTo (CreateMessageResult .StopReason .STOP_SEQUENCE );
273+ });
194274 }
195- mcpServer .close ();
275+
276+ mcpServer .closeGracefully ().block ();
277+ }
278+
279+ @ ParameterizedTest (name = "{0} : {displayName} " )
280+ @ ValueSource (strings = { "httpclient" , "webflux" })
281+ void testCreateMessageWithRequestTimeoutFail (String clientType ) throws InterruptedException {
282+
283+ // Client
284+ var clientBuilder = clientBuilders .get (clientType );
285+
286+ Function <CreateMessageRequest , CreateMessageResult > samplingHandler = request -> {
287+ assertThat (request .messages ()).hasSize (1 );
288+ assertThat (request .messages ().get (0 ).content ()).isInstanceOf (McpSchema .TextContent .class );
289+ try {
290+ TimeUnit .SECONDS .sleep (2 );
291+ }
292+ catch (InterruptedException e ) {
293+ throw new RuntimeException (e );
294+ }
295+ return new CreateMessageResult (Role .USER , new McpSchema .TextContent ("Test message" ), "MockModelName" ,
296+ CreateMessageResult .StopReason .STOP_SEQUENCE );
297+ };
298+
299+ // Server
300+
301+ CallToolResult callResponse = new McpSchema .CallToolResult (List .of (new McpSchema .TextContent ("CALL RESPONSE" )),
302+ null );
303+
304+ McpServerFeatures .AsyncToolSpecification tool = new McpServerFeatures .AsyncToolSpecification (
305+ new McpSchema .Tool ("tool1" , "tool1 description" , emptyJsonSchema ), (exchange , request ) -> {
306+
307+ var craeteMessageRequest = McpSchema .CreateMessageRequest .builder ()
308+ .messages (List .of (new McpSchema .SamplingMessage (McpSchema .Role .USER ,
309+ new McpSchema .TextContent ("Test message" ))))
310+ .build ();
311+
312+ return exchange .createMessage (craeteMessageRequest ).thenReturn (callResponse );
313+ });
314+
315+ var mcpServer = McpServer .async (mcpServerTransportProvider )
316+ .requestTimeout (Duration .ofSeconds (1 ))
317+ .serverInfo ("test-server" , "1.0.0" )
318+ .tools (tool )
319+ .build ();
320+
321+ try (var mcpClient = clientBuilder .clientInfo (new McpSchema .Implementation ("Sample client" , "0.0.0" ))
322+ .capabilities (ClientCapabilities .builder ().sampling ().build ())
323+ .sampling (samplingHandler )
324+ .build ()) {
325+
326+ InitializeResult initResult = mcpClient .initialize ();
327+ assertThat (initResult ).isNotNull ();
328+
329+ assertThatExceptionOfType (McpError .class ).isThrownBy (() -> {
330+ mcpClient .callTool (new McpSchema .CallToolRequest ("tool1" , Map .of ()));
331+ }).withMessageContaining ("within 1000ms" );
332+
333+ }
334+
335+ mcpServer .closeGracefully ().block ();
196336 }
197337
198338 // ---------------------------------------
@@ -262,9 +402,8 @@ void testRootsWithoutCapability(String clientType) {
262402 var mcpServer = McpServer .sync (mcpServerTransportProvider ).rootsChangeHandler ((exchange , rootsUpdate ) -> {
263403 }).tools (tool ).build ();
264404
265- try (
266- // Create client without roots capability
267- var mcpClient = clientBuilder .capabilities (ClientCapabilities .builder ().build ()).build ()) {
405+ // Create client without roots capability
406+ try (var mcpClient = clientBuilder .capabilities (ClientCapabilities .builder ().build ()).build ()) {
268407
269408 assertThat (mcpClient .initialize ()).isNotNull ();
270409
@@ -282,7 +421,7 @@ void testRootsWithoutCapability(String clientType) {
282421
283422 @ ParameterizedTest (name = "{0} : {displayName} " )
284423 @ ValueSource (strings = { "httpclient" , "webflux" })
285- void testRootsNotifciationWithEmptyRootsList (String clientType ) {
424+ void testRootsNotificationWithEmptyRootsList (String clientType ) {
286425 var clientBuilder = clientBuilders .get (clientType );
287426
288427 AtomicReference <List <Root >> rootsRef = new AtomicReference <>();
0 commit comments