Skip to content

Commit be61de3

Browse files
committed
Support specifying new handlers in McpServer spec
Signed-off-by: Dariusz Jędrzejczyk <dariusz.jedrzejczyk@broadcom.com>
1 parent 09d5b3a commit be61de3

File tree

12 files changed

+788
-175
lines changed

12 files changed

+788
-175
lines changed

mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,9 @@ private Mono<ServerResponse> handleMessage(ServerRequest request) {
293293
McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body);
294294
return session.handle(message).flatMap(response -> ServerResponse.ok().build()).onErrorResume(error -> {
295295
logger.error("Error processing message: {}", error.getMessage());
296+
// TODO: instead of signalling the error, just respond with 200 OK
297+
// - the error is signalled on the SSE connection
298+
// return ServerResponse.ok().build();
296299
return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR)
297300
.bodyValue(new McpError(error.getMessage()));
298301
});

mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java

Lines changed: 152 additions & 124 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import io.modelcontextprotocol.server.McpServer;
1818
import io.modelcontextprotocol.server.McpServerFeatures;
1919
import io.modelcontextprotocol.server.transport.WebFluxSseServerTransport;
20+
import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider;
2021
import io.modelcontextprotocol.spec.McpError;
2122
import io.modelcontextprotocol.spec.McpSchema;
2223
import io.modelcontextprotocol.spec.McpSchema.CallToolResult;
@@ -55,16 +56,16 @@ public class WebFluxSseIntegrationTests {
5556

5657
private DisposableServer httpServer;
5758

58-
private WebFluxSseServerTransport mcpServerTransport;
59+
private WebFluxSseServerTransportProvider mcpServerTransportProvider;
5960

6061
ConcurrentHashMap<String, McpClient.SyncSpec> clientBulders = new ConcurrentHashMap<>();
6162

6263
@BeforeEach
6364
public void before() {
6465

65-
this.mcpServerTransport = new WebFluxSseServerTransport(new ObjectMapper(), MESSAGE_ENDPOINT);
66+
this.mcpServerTransportProvider = new WebFluxSseServerTransportProvider(new ObjectMapper(), MESSAGE_ENDPOINT);
6667

67-
HttpHandler httpHandler = RouterFunctions.toHttpHandler(mcpServerTransport.getRouterFunction());
68+
HttpHandler httpHandler = RouterFunctions.toHttpHandler(mcpServerTransportProvider.getRouterFunction());
6869
ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler);
6970
this.httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow();
7071

@@ -84,89 +85,109 @@ public void after() {
8485
// ---------------------------------------
8586
// Sampling Tests
8687
// ---------------------------------------
87-
@Test
88-
void testCreateMessageWithoutInitialization() {
89-
var mcpAsyncServer = McpServer.async(mcpServerTransport).serverInfo("test-server", "1.0.0").build();
90-
91-
var messages = List
92-
.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Test message")));
93-
var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0);
94-
95-
var request = new McpSchema.CreateMessageRequest(messages, modelPrefs, null,
96-
McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), Map.of());
97-
98-
StepVerifier.create(mcpAsyncServer.createMessage(request)).verifyErrorSatisfies(error -> {
99-
assertThat(error).isInstanceOf(McpError.class)
100-
.hasMessage("Client must be initialized. Call the initialize method first!");
101-
});
102-
}
103-
104-
@ParameterizedTest(name = "{0} : {displayName} ")
105-
@ValueSource(strings = { "httpclient", "webflux" })
106-
void testCreateMessageWithoutSamplingCapabilities(String clientType) {
107-
108-
var mcpAsyncServer = McpServer.async(mcpServerTransport).serverInfo("test-server", "1.0.0").build();
109-
110-
var clientBuilder = clientBulders.get(clientType);
111-
112-
var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")).build();
113-
114-
InitializeResult initResult = client.initialize();
115-
assertThat(initResult).isNotNull();
116-
117-
var messages = List
118-
.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Test message")));
119-
var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0);
120-
121-
var request = new McpSchema.CreateMessageRequest(messages, modelPrefs, null,
122-
McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), Map.of());
123-
124-
StepVerifier.create(mcpAsyncServer.createMessage(request)).verifyErrorSatisfies(error -> {
125-
assertThat(error).isInstanceOf(McpError.class)
126-
.hasMessage("Client must be configured with sampling capabilities");
127-
});
128-
}
129-
130-
@ParameterizedTest(name = "{0} : {displayName} ")
131-
@ValueSource(strings = { "httpclient", "webflux" })
132-
void testCreateMessageSuccess(String clientType) throws InterruptedException {
133-
134-
var clientBuilder = clientBulders.get(clientType);
135-
136-
var mcpAsyncServer = McpServer.async(mcpServerTransport).serverInfo("test-server", "1.0.0").build();
137-
138-
Function<CreateMessageRequest, CreateMessageResult> samplingHandler = request -> {
139-
assertThat(request.messages()).hasSize(1);
140-
assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class);
141-
142-
return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName",
143-
CreateMessageResult.StopReason.STOP_SEQUENCE);
144-
};
145-
146-
var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0"))
147-
.capabilities(ClientCapabilities.builder().sampling().build())
148-
.sampling(samplingHandler)
149-
.build();
150-
151-
InitializeResult initResult = client.initialize();
152-
assertThat(initResult).isNotNull();
153-
154-
var messages = List
155-
.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Test message")));
156-
var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0);
157-
158-
var request = new McpSchema.CreateMessageRequest(messages, modelPrefs, null,
159-
McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), Map.of());
160-
161-
StepVerifier.create(mcpAsyncServer.createMessage(request)).consumeNextWith(result -> {
162-
assertThat(result).isNotNull();
163-
assertThat(result.role()).isEqualTo(Role.USER);
164-
assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class);
165-
assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message");
166-
assertThat(result.model()).isEqualTo("MockModelName");
167-
assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE);
168-
}).verifyComplete();
169-
}
88+
// TODO implement within a tool execution
89+
// @Test
90+
// void testCreateMessageWithoutInitialization() {
91+
// var mcpAsyncServer =
92+
// McpServer.async(mcpServerTransportProvider).serverInfo("test-server",
93+
// "1.0.0").build();
94+
//
95+
// var messages = List
96+
// .of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new
97+
// McpSchema.TextContent("Test message")));
98+
// var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0);
99+
//
100+
// var request = new McpSchema.CreateMessageRequest(messages, modelPrefs, null,
101+
// McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(),
102+
// Map.of());
103+
//
104+
// StepVerifier.create(mcpAsyncServer.createMessage(request)).verifyErrorSatisfies(error
105+
// -> {
106+
// assertThat(error).isInstanceOf(McpError.class)
107+
// .hasMessage("Client must be initialized. Call the initialize method first!");
108+
// });
109+
// }
110+
//
111+
// @ParameterizedTest(name = "{0} : {displayName} ")
112+
// @ValueSource(strings = { "httpclient", "webflux" })
113+
// void testCreateMessageWithoutSamplingCapabilities(String clientType) {
114+
//
115+
// var mcpAsyncServer =
116+
// McpServer.async(mcpServerTransportProvider).serverInfo("test-server",
117+
// "1.0.0").build();
118+
//
119+
// var clientBuilder = clientBulders.get(clientType);
120+
//
121+
// var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client",
122+
// "0.0.0")).build();
123+
//
124+
// InitializeResult initResult = client.initialize();
125+
// assertThat(initResult).isNotNull();
126+
//
127+
// var messages = List
128+
// .of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new
129+
// McpSchema.TextContent("Test message")));
130+
// var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0);
131+
//
132+
// var request = new McpSchema.CreateMessageRequest(messages, modelPrefs, null,
133+
// McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(),
134+
// Map.of());
135+
//
136+
// StepVerifier.create(mcpAsyncServer.createMessage(request)).verifyErrorSatisfies(error
137+
// -> {
138+
// assertThat(error).isInstanceOf(McpError.class)
139+
// .hasMessage("Client must be configured with sampling capabilities");
140+
// });
141+
// }
142+
//
143+
// @ParameterizedTest(name = "{0} : {displayName} ")
144+
// @ValueSource(strings = { "httpclient", "webflux" })
145+
// void testCreateMessageSuccess(String clientType) throws InterruptedException {
146+
//
147+
// var clientBuilder = clientBulders.get(clientType);
148+
//
149+
// var mcpAsyncServer =
150+
// McpServer.async(mcpServerTransportProvider).serverInfo("test-server",
151+
// "1.0.0").build();
152+
//
153+
// Function<CreateMessageRequest, CreateMessageResult> samplingHandler = request -> {
154+
// assertThat(request.messages()).hasSize(1);
155+
// assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class);
156+
//
157+
// return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test
158+
// message"), "MockModelName",
159+
// CreateMessageResult.StopReason.STOP_SEQUENCE);
160+
// };
161+
//
162+
// var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client",
163+
// "0.0.0"))
164+
// .capabilities(ClientCapabilities.builder().sampling().build())
165+
// .sampling(samplingHandler)
166+
// .build();
167+
//
168+
// InitializeResult initResult = client.initialize();
169+
// assertThat(initResult).isNotNull();
170+
//
171+
// var messages = List
172+
// .of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new
173+
// McpSchema.TextContent("Test message")));
174+
// var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0);
175+
//
176+
// var request = new McpSchema.CreateMessageRequest(messages, modelPrefs, null,
177+
// McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(),
178+
// Map.of());
179+
//
180+
// StepVerifier.create(mcpAsyncServer.createMessage(request)).consumeNextWith(result
181+
// -> {
182+
// assertThat(result).isNotNull();
183+
// assertThat(result.role()).isEqualTo(Role.USER);
184+
// assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class);
185+
// assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test
186+
// message");
187+
// assertThat(result.model()).isEqualTo("MockModelName");
188+
// assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE);
189+
// }).verifyComplete();
190+
// }
170191

171192
// ---------------------------------------
172193
// Roots Tests
@@ -179,8 +200,8 @@ void testRootsSuccess(String clientType) {
179200
List<Root> roots = List.of(new Root("uri1://", "root1"), new Root("uri2://", "root2"));
180201

181202
AtomicReference<List<Root>> rootsRef = new AtomicReference<>();
182-
var mcpServer = McpServer.sync(mcpServerTransport)
183-
.rootsChangeConsumer(rootsUpdate -> rootsRef.set(rootsUpdate))
203+
var mcpServer = McpServer.sync(mcpServerTransportProvider)
204+
.rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate))
184205
.build();
185206

186207
var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build())
@@ -192,8 +213,6 @@ void testRootsSuccess(String clientType) {
192213

193214
assertThat(rootsRef.get()).isNull();
194215

195-
assertThat(mcpServer.listRoots().roots()).containsAll(roots);
196-
197216
mcpClient.rootsListChangedNotification();
198217

199218
await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> {
@@ -219,39 +238,48 @@ void testRootsSuccess(String clientType) {
219238
mcpServer.close();
220239
}
221240

222-
@ParameterizedTest(name = "{0} : {displayName} ")
223-
@ValueSource(strings = { "httpclient", "webflux" })
224-
void testRootsWithoutCapability(String clientType) {
225-
var clientBuilder = clientBulders.get(clientType);
226-
227-
var mcpServer = McpServer.sync(mcpServerTransport).rootsChangeConsumer(rootsUpdate -> {
228-
}).build();
229-
230-
// Create client without roots capability
231-
var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()) // No
232-
// roots
233-
// capability
234-
.build();
235-
236-
InitializeResult initResult = mcpClient.initialize();
237-
assertThat(initResult).isNotNull();
238-
239-
// Attempt to list roots should fail
240-
assertThatThrownBy(() -> mcpServer.listRoots().roots()).isInstanceOf(McpError.class)
241-
.hasMessage("Roots not supported");
242-
243-
mcpClient.close();
244-
mcpServer.close();
245-
}
241+
// @ParameterizedTest(name = "{0} : {displayName} ")
242+
// @ValueSource(strings = { "httpclient", "webflux" })
243+
// void testRootsWithoutCapability(String clientType) {
244+
// var clientBuilder = clientBulders.get(clientType);
245+
// AtomicReference<Exception> errorRef = new AtomicReference<>();
246+
//
247+
// var mcpServer =
248+
// McpServer.sync(mcpServerTransportProvider)
249+
// // TODO: implement tool handling and try to list roots
250+
// .tool(tool, (exchange, args) -> {
251+
// try {
252+
// exchange.listRoots();
253+
// } catch (Exception e) {
254+
// errorRef.set(e);
255+
// }
256+
// }).build();
257+
//
258+
// // Create client without roots capability
259+
// var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()) //
260+
// No
261+
// // roots
262+
// // capability
263+
// .build();
264+
//
265+
// InitializeResult initResult = mcpClient.initialize();
266+
// assertThat(initResult).isNotNull();
267+
//
268+
// assertThat(errorRef.get()).isInstanceOf(McpError.class).hasMessage("Roots not
269+
// supported");
270+
//
271+
// mcpClient.close();
272+
// mcpServer.close();
273+
// }
246274

247275
@ParameterizedTest(name = "{0} : {displayName} ")
248276
@ValueSource(strings = { "httpclient", "webflux" })
249277
void testRootsWithEmptyRootsList(String clientType) {
250278
var clientBuilder = clientBulders.get(clientType);
251279

252280
AtomicReference<List<Root>> rootsRef = new AtomicReference<>();
253-
var mcpServer = McpServer.sync(mcpServerTransport)
254-
.rootsChangeConsumer(rootsUpdate -> rootsRef.set(rootsUpdate))
281+
var mcpServer = McpServer.sync(mcpServerTransportProvider)
282+
.rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate))
255283
.build();
256284

257285
var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build())
@@ -273,17 +301,17 @@ void testRootsWithEmptyRootsList(String clientType) {
273301

274302
@ParameterizedTest(name = "{0} : {displayName} ")
275303
@ValueSource(strings = { "httpclient", "webflux" })
276-
void testRootsWithMultipleConsumers(String clientType) {
304+
void testRootsWithMultipleHandlers(String clientType) {
277305
var clientBuilder = clientBulders.get(clientType);
278306

279307
List<Root> roots = List.of(new Root("uri1://", "root1"));
280308

281309
AtomicReference<List<Root>> rootsRef1 = new AtomicReference<>();
282310
AtomicReference<List<Root>> rootsRef2 = new AtomicReference<>();
283311

284-
var mcpServer = McpServer.sync(mcpServerTransport)
285-
.rootsChangeConsumer(rootsUpdate -> rootsRef1.set(rootsUpdate))
286-
.rootsChangeConsumer(rootsUpdate -> rootsRef2.set(rootsUpdate))
312+
var mcpServer = McpServer.sync(mcpServerTransportProvider)
313+
.rootsChangeHandler((exchange, rootsUpdate) -> rootsRef1.set(rootsUpdate))
314+
.rootsChangeHandler((exchange, rootsUpdate) -> rootsRef2.set(rootsUpdate))
287315
.build();
288316

289317
var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build())
@@ -313,8 +341,8 @@ void testRootsServerCloseWithActiveSubscription(String clientType) {
313341
List<Root> roots = List.of(new Root("uri1://", "root1"));
314342

315343
AtomicReference<List<Root>> rootsRef = new AtomicReference<>();
316-
var mcpServer = McpServer.sync(mcpServerTransport)
317-
.rootsChangeConsumer(rootsUpdate -> rootsRef.set(rootsUpdate))
344+
var mcpServer = McpServer.sync(mcpServerTransportProvider)
345+
.rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate))
318346
.build();
319347

320348
var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build())
@@ -368,7 +396,7 @@ void testToolCallSuccess(String clientType) {
368396
return callResponse;
369397
});
370398

371-
var mcpServer = McpServer.sync(mcpServerTransport)
399+
var mcpServer = McpServer.sync(mcpServerTransportProvider)
372400
.capabilities(ServerCapabilities.builder().tools(true).build())
373401
.tools(tool1)
374402
.build();
@@ -408,7 +436,7 @@ void testToolListChangeHandlingSuccess(String clientType) {
408436
return callResponse;
409437
});
410438

411-
var mcpServer = McpServer.sync(mcpServerTransport)
439+
var mcpServer = McpServer.sync(mcpServerTransportProvider)
412440
.capabilities(ServerCapabilities.builder().tools(true).build())
413441
.tools(tool1)
414442
.build();

0 commit comments

Comments
 (0)