44
55package io .modelcontextprotocol .server ;
66
7- import static net .javacrumbs .jsonunit .assertj .JsonAssertions .assertThatJson ;
8- import static net .javacrumbs .jsonunit .assertj .JsonAssertions .json ;
9- import static org .assertj .core .api .Assertions .assertThat ;
10- import static org .awaitility .Awaitility .await ;
11-
12- import java .time .Duration ;
13- import java .util .List ;
14- import java .util .Map ;
15- import java .util .concurrent .ConcurrentHashMap ;
16- import java .util .concurrent .atomic .AtomicReference ;
17- import java .util .function .BiFunction ;
18-
19- import org .apache .catalina .LifecycleException ;
20- import org .apache .catalina .LifecycleState ;
21- import org .apache .catalina .startup .Tomcat ;
22- import org .junit .jupiter .api .AfterEach ;
23- import org .junit .jupiter .api .BeforeEach ;
24- import org .junit .jupiter .params .ParameterizedTest ;
25- import org .junit .jupiter .params .provider .ValueSource ;
26- import org .springframework .web .client .RestClient ;
27-
287import com .fasterxml .jackson .databind .ObjectMapper ;
29-
308import io .modelcontextprotocol .client .McpClient ;
319import io .modelcontextprotocol .client .transport .HttpClientStreamableHttpTransport ;
3210import io .modelcontextprotocol .server .transport .HttpServletStatelessServerTransport ;
4220import io .modelcontextprotocol .spec .McpSchema .ServerCapabilities ;
4321import io .modelcontextprotocol .spec .McpSchema .Tool ;
4422import net .javacrumbs .jsonunit .core .Option ;
23+ import org .apache .catalina .LifecycleException ;
24+ import org .apache .catalina .LifecycleState ;
25+ import org .apache .catalina .startup .Tomcat ;
26+ import org .junit .jupiter .api .AfterEach ;
27+ import org .junit .jupiter .api .BeforeEach ;
28+ import org .junit .jupiter .api .Test ;
29+ import org .junit .jupiter .params .ParameterizedTest ;
30+ import org .junit .jupiter .params .provider .ValueSource ;
31+ import org .springframework .web .client .RestClient ;
32+
33+ import java .net .URI ;
34+ import java .net .http .HttpClient ;
35+ import java .net .http .HttpRequest ;
36+ import java .net .http .HttpResponse ;
37+ import java .time .Duration ;
38+ import java .util .Iterator ;
39+ import java .util .List ;
40+ import java .util .Map ;
41+ import java .util .UUID ;
42+ import java .util .concurrent .ConcurrentHashMap ;
43+ import java .util .concurrent .atomic .AtomicReference ;
44+ import java .util .function .BiFunction ;
45+ import java .util .stream .Stream ;
46+
47+ import static io .modelcontextprotocol .server .transport .HttpServletStatelessServerTransport .APPLICATION_JSON ;
48+ import static io .modelcontextprotocol .server .transport .HttpServletStatelessServerTransport .TEXT_EVENT_STREAM ;
49+ import static net .javacrumbs .jsonunit .assertj .JsonAssertions .assertThatJson ;
50+ import static net .javacrumbs .jsonunit .assertj .JsonAssertions .json ;
51+ import static org .assertj .core .api .Assertions .assertThat ;
52+ import static org .awaitility .Awaitility .await ;
4553
4654class HttpServletStatelessIntegrationTests {
4755
@@ -55,10 +63,13 @@ class HttpServletStatelessIntegrationTests {
5563
5664 private Tomcat tomcat ;
5765
66+ private ObjectMapper objectMapper ;
67+
5868 @ BeforeEach
5969 public void before () {
70+ objectMapper = new ObjectMapper ();
6071 this .mcpStatelessServerTransport = HttpServletStatelessServerTransport .builder ()
61- .objectMapper (new ObjectMapper () )
72+ .objectMapper (objectMapper )
6273 .messageEndpoint (CUSTOM_MESSAGE_ENDPOINT )
6374 .build ();
6475
@@ -213,6 +224,87 @@ void testCompletionShouldReturnExpectedSuggestions(String clientType) {
213224 mcpServer .close ();
214225 }
215226
227+ @ Test
228+ void testNotifications () throws Exception {
229+
230+ Tool tool = Tool .builder ().name ("test" ).build ();
231+
232+ final int PROGRESS_QTY = 1000 ;
233+ final String progressMessage = "We're working on it..." ;
234+
235+ var progressToken = UUID .randomUUID ().toString ();
236+ var callResponse = new CallToolResult (List .of (), null , null , Map .of ("progressToken" , progressToken ));
237+ McpStatelessServerFeatures .SyncToolSpecification toolSpecification = new McpStatelessServerFeatures .SyncToolSpecification (
238+ tool , (transportContext , request ) -> {
239+ // Simulate sending progress notifications - send enough to ensure
240+ // that cunked transfer encoding is used
241+ for (int i = 0 ; i < PROGRESS_QTY ; i ++) {
242+ transportContext .sendNotification (McpSchema .METHOD_NOTIFICATION_PROGRESS ,
243+ new McpSchema .ProgressNotification (progressToken , i , 5.0 , progressMessage ));
244+ }
245+ return callResponse ;
246+ });
247+
248+ var mcpServer = McpServer .sync (mcpStatelessServerTransport )
249+ .capabilities (ServerCapabilities .builder ().tools (true ).build ())
250+ .tools (toolSpecification )
251+ .build ();
252+
253+ HttpClient client = HttpClient .newBuilder ().version (HttpClient .Version .HTTP_1_1 ).build ();
254+ HttpRequest request = HttpRequest .newBuilder ()
255+ .method ("POST" ,
256+ HttpRequest .BodyPublishers .ofString (
257+ objectMapper .writeValueAsString (new McpSchema .JSONRPCRequest (McpSchema .JSONRPC_VERSION ,
258+ "tools/call" , "1" , new McpSchema .CallToolRequest ("test" , Map .of ())))))
259+ .header ("Content-Type" , APPLICATION_JSON )
260+ .header ("Accept" , APPLICATION_JSON + "," + TEXT_EVENT_STREAM )
261+ .uri (URI .create ("http://localhost:" + PORT + CUSTOM_MESSAGE_ENDPOINT ))
262+ .build ();
263+
264+ HttpResponse <Stream <String >> response = client .send (request , HttpResponse .BodyHandlers .ofLines ());
265+ assertThat (response .headers ().firstValue ("Transfer-Encoding" )).contains ("chunked" );
266+
267+ List <String > responseBody = response .body ().toList ();
268+
269+ assertThat (responseBody ).hasSize ((PROGRESS_QTY + 1 ) * 4 ); // 4 lines per progress
270+ // notification + 4
271+ // for
272+ // the call result
273+
274+ Iterator <String > iterator = responseBody .iterator ();
275+ for (int i = 0 ; i < PROGRESS_QTY ; ++i ) {
276+ String eventLine = iterator .next ();
277+ String idLine = iterator .next ();
278+ String dataLine = iterator .next ();
279+ String blankLine = iterator .next ();
280+
281+ McpSchema .ProgressNotification expectedNotification = new McpSchema .ProgressNotification (progressToken , i ,
282+ 5.0 , progressMessage );
283+ McpSchema .JSONRPCNotification expectedJsonRpcNotification = new McpSchema .JSONRPCNotification (
284+ McpSchema .JSONRPC_VERSION , McpSchema .METHOD_NOTIFICATION_PROGRESS , expectedNotification );
285+
286+ assertThat (eventLine ).isEqualTo ("event: notification" );
287+ assertThat (idLine ).isEqualTo ("id: " + i );
288+ assertThat (dataLine ).isEqualTo ("data: " + objectMapper .writeValueAsString (expectedJsonRpcNotification ));
289+ assertThat (blankLine ).isBlank ();
290+ }
291+
292+ String eventLine = iterator .next ();
293+ String idLine = iterator .next ();
294+ String dataLine = iterator .next ();
295+ String blankLine = iterator .next ();
296+
297+ assertThat (eventLine ).isEqualTo ("event: result" );
298+ assertThat (idLine ).isEqualTo ("id: " + PROGRESS_QTY );
299+ assertThat (dataLine ).isEqualTo ("data: " + objectMapper
300+ .writeValueAsString (new McpSchema .JSONRPCResponse (McpSchema .JSONRPC_VERSION , "1" , callResponse , null )));
301+ assertThat (blankLine ).isBlank ();
302+
303+ assertThat (iterator .hasNext ()).isFalse ();
304+
305+ mcpServer .close ();
306+ }
307+
216308 // ---------------------------------------
217309 // Tool Structured Output Schema Tests
218310 // ---------------------------------------
0 commit comments