Skip to content

Commit c9ca956

Browse files
committed
fix: Make context path in transportprovider work.
1 parent 1a829d0 commit c9ca956

File tree

6 files changed

+188
-29
lines changed

6 files changed

+188
-29
lines changed

mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,8 @@ public class WebFluxSseServerTransportProvider implements McpServerTransportProv
8282
*/
8383
public static final String DEFAULT_SSE_ENDPOINT = "/sse";
8484

85+
public static final String DEFAULT_CONTEXT_PATH = "";
86+
8587
public static final String DEFAULT_BASE_URL = "";
8688

8789
private final ObjectMapper objectMapper;
@@ -92,6 +94,8 @@ public class WebFluxSseServerTransportProvider implements McpServerTransportProv
9294
*/
9395
private final String baseUrl;
9496

97+
private final String contextPath;
98+
9599
private final String messageEndpoint;
96100

97101
private final String sseEndpoint;
@@ -134,7 +138,7 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messa
134138
* @throws IllegalArgumentException if either parameter is null
135139
*/
136140
public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) {
137-
this(objectMapper, DEFAULT_BASE_URL, messageEndpoint, sseEndpoint);
141+
this(objectMapper, DEFAULT_CONTEXT_PATH, DEFAULT_BASE_URL, messageEndpoint, sseEndpoint);
138142
}
139143

140144
/**
@@ -147,24 +151,29 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messa
147151
* setup. Must not be null.
148152
* @throws IllegalArgumentException if either parameter is null
149153
*/
150-
public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint,
151-
String sseEndpoint) {
154+
public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String contextPath, String baseUrl,
155+
String messageEndpoint, String sseEndpoint) {
152156
Assert.notNull(objectMapper, "ObjectMapper must not be null");
157+
Assert.notNull(contextPath, "Context path must not be null");
153158
Assert.notNull(baseUrl, "Message base path must not be null");
154159
Assert.notNull(messageEndpoint, "Message endpoint must not be null");
155160
Assert.notNull(sseEndpoint, "SSE endpoint must not be null");
156161

157162
if (baseUrl.endsWith("/")) {
158163
baseUrl = baseUrl.substring(0, baseUrl.length() - 1);
159164
}
165+
if (contextPath.endsWith("/")) {
166+
contextPath = contextPath.substring(0, contextPath.length() - 1);
167+
}
160168

161169
this.objectMapper = objectMapper;
170+
this.contextPath = contextPath;
162171
this.baseUrl = baseUrl;
163172
this.messageEndpoint = messageEndpoint;
164173
this.sseEndpoint = sseEndpoint;
165174
this.routerFunction = RouterFunctions.route()
166-
.GET(this.sseEndpoint, this::handleSseConnection)
167-
.POST(this.messageEndpoint, this::handleMessage)
175+
.GET(this.baseUrl + this.sseEndpoint, this::handleSseConnection)
176+
.POST(this.baseUrl + this.messageEndpoint, this::handleMessage)
168177
.build();
169178
}
170179

@@ -275,7 +284,7 @@ private Mono<ServerResponse> handleSseConnection(ServerRequest request) {
275284
logger.debug("Sending initial endpoint event to session: {}", sessionId);
276285
sink.next(ServerSentEvent.builder()
277286
.event(ENDPOINT_EVENT_TYPE)
278-
.data(this.baseUrl + this.messageEndpoint + "?sessionId=" + sessionId)
287+
.data(this.contextPath + this.baseUrl + this.messageEndpoint + "?sessionId=" + sessionId)
279288
.build());
280289
sink.onCancel(() -> {
281290
logger.debug("Session {} cancelled", sessionId);
@@ -395,6 +404,8 @@ public static class Builder {
395404

396405
private ObjectMapper objectMapper;
397406

407+
private String contextPath = DEFAULT_CONTEXT_PATH;
408+
398409
private String baseUrl = DEFAULT_BASE_URL;
399410

400411
private String messageEndpoint;
@@ -461,7 +472,8 @@ public WebFluxSseServerTransportProvider build() {
461472
Assert.notNull(objectMapper, "ObjectMapper must be set");
462473
Assert.notNull(messageEndpoint, "Message endpoint must be set");
463474

464-
return new WebFluxSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint);
475+
return new WebFluxSseServerTransportProvider(objectMapper, contextPath, baseUrl, messageEndpoint,
476+
sseEndpoint);
465477
}
466478

467479
}
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
package io.modelcontextprotocol.server;
2+
3+
import com.fasterxml.jackson.databind.ObjectMapper;
4+
import io.modelcontextprotocol.client.McpClient;
5+
import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport;
6+
import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider;
7+
import io.modelcontextprotocol.spec.McpSchema;
8+
import org.junit.jupiter.api.AfterEach;
9+
import org.junit.jupiter.params.ParameterizedTest;
10+
import org.junit.jupiter.params.provider.Arguments;
11+
import org.junit.jupiter.params.provider.MethodSource;
12+
import org.springframework.http.server.reactive.HttpHandler;
13+
import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter;
14+
import org.springframework.web.reactive.function.client.WebClient;
15+
import org.springframework.web.reactive.function.server.RequestPredicates;
16+
import org.springframework.web.reactive.function.server.RouterFunction;
17+
import org.springframework.web.reactive.function.server.RouterFunctions;
18+
import org.springframework.web.reactive.function.server.ServerResponse;
19+
import reactor.core.publisher.Mono;
20+
import reactor.netty.DisposableServer;
21+
import reactor.netty.http.server.HttpServer;
22+
23+
import java.util.List;
24+
import java.util.Map;
25+
import java.util.function.Supplier;
26+
import java.util.stream.Stream;
27+
28+
import static org.assertj.core.api.Assertions.assertThat;
29+
import static org.springframework.web.reactive.function.server.RequestPredicates.path;
30+
import static org.springframework.web.reactive.function.server.RouterFunctions.nest;
31+
import static org.springframework.web.reactive.function.server.RouterFunctions.route;
32+
33+
public class WebFluxSseCustomPathIntegrationTests {
34+
35+
private static final int PORT = TestUtil.findAvailablePort();
36+
37+
private DisposableServer httpServer;
38+
39+
private WebFluxSseServerTransportProvider mcpServerTransportProvider;
40+
41+
String emptyJsonSchema = """
42+
{
43+
"$schema": "http://json-schema.org/draft-07/schema#",
44+
"type": "object",
45+
"properties": {}
46+
}
47+
""";
48+
49+
@ParameterizedTest(
50+
name = "baseUrl = \"{0}\" messageEndpoint = \"{1}\" sseEndpoint = \"{2}\" contextPath = \"{3}\" : {displayName} ")
51+
@MethodSource("provideCustomEndpoints")
52+
public void testCustomizedEndpoints(String baseUrl, String messageEndpoint, String sseEndpoint,
53+
String contextPath) {
54+
55+
this.mcpServerTransportProvider = new WebFluxSseServerTransportProvider(new ObjectMapper(), contextPath,
56+
baseUrl, messageEndpoint, sseEndpoint);
57+
58+
RouterFunction<?> router = this.mcpServerTransportProvider.getRouterFunction();
59+
RouterFunction<ServerResponse> nestedRouter = (RouterFunction<ServerResponse>) nest(path(contextPath), router);
60+
HttpHandler httpHandler = RouterFunctions.toHttpHandler(nestedRouter);
61+
ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler);
62+
63+
this.httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow();
64+
65+
var c = contextPath;
66+
var b = baseUrl;
67+
var s = sseEndpoint;
68+
if (baseUrl.endsWith("/")) {
69+
b = b.substring(0, b.length() - 1);
70+
}
71+
if (contextPath.endsWith("/")) {
72+
c = c.substring(0, c.length() - 1);
73+
}
74+
75+
var clientBuilder = McpClient
76+
.sync(WebFluxSseClientTransport.builder(WebClient.builder().baseUrl("http://localhost:" + PORT))
77+
.sseEndpoint(c + b + s)
78+
.build());
79+
80+
McpSchema.CallToolResult callResponse = new McpSchema.CallToolResult(
81+
List.of(new McpSchema.TextContent("CALL RESPONSE")), null);
82+
83+
McpServerFeatures.AsyncToolSpecification tool1 = new McpServerFeatures.AsyncToolSpecification(
84+
new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema),
85+
(exchange, request) -> Mono.just(callResponse));
86+
87+
var server = McpServer.async(mcpServerTransportProvider)
88+
.serverInfo("test-server", "1.0.0")
89+
.capabilities(McpSchema.ServerCapabilities.builder().tools(true).build())
90+
.tools(tool1)
91+
.build();
92+
93+
try (var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")).build()) {
94+
assertThat(client.initialize()).isNotNull();
95+
assertThat(client.listTools().tools()).contains(tool1.tool());
96+
97+
McpSchema.CallToolResult response = client.callTool(new McpSchema.CallToolRequest("tool1", Map.of()));
98+
assertThat(response).isNotNull().isEqualTo(callResponse);
99+
}
100+
101+
server.close();
102+
103+
}
104+
105+
private static Stream<Arguments> provideCustomEndpoints() {
106+
String[] baseUrls = { "", "/v1", "/api/v1", "/", "/v1/", "/api/v1/" };
107+
String[] messageEndpoints = { "/message", "/another/sse", "/" };
108+
String[] sseEndpoints = { "/sse", "/another/sse", "/" };
109+
String[] contextPaths = { "", "/mcp", "/root/mcp", "/", "/mcp/", "/root/mcp/" };
110+
111+
return Stream.of(baseUrls)
112+
.flatMap(baseUrl -> Stream.of(messageEndpoints)
113+
.flatMap(messageEndpoint -> Stream.of(sseEndpoints)
114+
.flatMap(sseEndpoint -> Stream.of(contextPaths)
115+
.map(contextPath -> Arguments.of(baseUrl, messageEndpoint, sseEndpoint, contextPath)))));
116+
}
117+
118+
@AfterEach
119+
public void after() {
120+
if (httpServer != null) {
121+
httpServer.disposeNow();
122+
}
123+
}
124+
125+
}

mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,8 @@ public class WebMvcSseServerTransportProvider implements McpServerTransportProvi
9494

9595
private final String baseUrl;
9696

97+
private final String contextPath;
98+
9799
private final RouterFunction<ServerResponse> routerFunction;
98100

99101
private McpServerSession.Factory sessionFactory;
@@ -133,13 +135,14 @@ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messag
133135
* @throws IllegalArgumentException if any parameter is null
134136
*/
135137
public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) {
136-
this(objectMapper, "", messageEndpoint, sseEndpoint);
138+
this(objectMapper, "", "", messageEndpoint, sseEndpoint);
137139
}
138140

139141
/**
140142
* Constructs a new WebMvcSseServerTransportProvider instance.
141143
* @param objectMapper The ObjectMapper to use for JSON serialization/deserialization
142144
* of messages.
145+
* @param contextPath The context path under which the server runs.
143146
* @param baseUrl The base URL for the message endpoint, used to construct the full
144147
* endpoint URL for clients.
145148
* @param messageEndpoint The endpoint URI where clients should send their JSON-RPC
@@ -148,9 +151,10 @@ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messag
148151
* @param sseEndpoint The endpoint URI where clients establish their SSE connections.
149152
* @throws IllegalArgumentException if any parameter is null
150153
*/
151-
public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint,
152-
String sseEndpoint) {
154+
public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String contextPath, String baseUrl,
155+
String messageEndpoint, String sseEndpoint) {
153156
Assert.notNull(objectMapper, "ObjectMapper must not be null");
157+
Assert.notNull(contextPath, "Context path must not be null");
154158
Assert.notNull(baseUrl, "Message base URL must not be null");
155159
Assert.notNull(messageEndpoint, "Message endpoint must not be null");
156160
Assert.hasText(messageEndpoint, "Message endpoint must not be empty");
@@ -161,13 +165,18 @@ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUr
161165
baseUrl = baseUrl.substring(0, baseUrl.length() - 1);
162166
}
163167

168+
if (contextPath.endsWith("/")) {
169+
contextPath = contextPath.substring(0, contextPath.length() - 1);
170+
}
171+
164172
this.objectMapper = objectMapper;
165173
this.baseUrl = baseUrl;
174+
this.contextPath = contextPath;
166175
this.messageEndpoint = messageEndpoint;
167176
this.sseEndpoint = sseEndpoint;
168177
this.routerFunction = RouterFunctions.route()
169-
.GET(this.sseEndpoint, this::handleSseConnection)
170-
.POST(this.messageEndpoint, this::handleMessage)
178+
.GET(this.baseUrl + this.sseEndpoint, this::handleSseConnection)
179+
.POST(this.baseUrl + this.messageEndpoint, this::handleMessage)
171180
.build();
172181
}
173182

@@ -276,7 +285,7 @@ private ServerResponse handleSseConnection(ServerRequest request) {
276285
try {
277286
sseBuilder.id(sessionId)
278287
.event(ENDPOINT_EVENT_TYPE)
279-
.data(this.baseUrl + this.messageEndpoint + "?sessionId=" + sessionId);
288+
.data(this.contextPath + this.baseUrl + this.messageEndpoint + "?sessionId=" + sessionId);
280289
}
281290
catch (Exception e) {
282291
logger.error("Failed to send initial endpoint event: {}", e.getMessage());

mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomContextPathTests.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,8 @@ public void before() {
4949
throw new RuntimeException("Failed to start Tomcat", e);
5050
}
5151

52-
var clientTransport = HttpClientSseClientTransport.builder("http://localhost:" + PORT + CUSTOM_CONTEXT_PATH)
53-
.sseEndpoint(WebMvcSseServerTransportProvider.DEFAULT_SSE_ENDPOINT)
52+
var clientTransport = HttpClientSseClientTransport.builder("http://localhost:" + PORT)
53+
.sseEndpoint(CUSTOM_CONTEXT_PATH + WebMvcSseServerTransportProvider.DEFAULT_SSE_ENDPOINT)
5454
.build();
5555

5656
clientBuilder = McpClient.sync(clientTransport);
@@ -91,7 +91,7 @@ static class TestConfig {
9191
@Bean
9292
public WebMvcSseServerTransportProvider webMvcSseServerTransportProvider() {
9393

94-
return new WebMvcSseServerTransportProvider(new ObjectMapper(), CUSTOM_CONTEXT_PATH, MESSAGE_ENDPOINT,
94+
return new WebMvcSseServerTransportProvider(new ObjectMapper(), "", CUSTOM_CONTEXT_PATH, MESSAGE_ENDPOINT,
9595
WebMvcSseServerTransportProvider.DEFAULT_SSE_ENDPOINT);
9696
}
9797

mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomPathIntegrationTests.java

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,10 @@ public WebMvcSseServerTransportProvider transportProvider(org.springframework.co
5151
String baseUrl = env.getProperty("test.baseUrl");
5252
String messageEndpoint = env.getProperty("test.messageEndpoint");
5353
String sseEndpoint = env.getProperty("test.sseEndpoint");
54+
String contextPath = env.getProperty("test.contextPath");
5455

55-
return new WebMvcSseServerTransportProvider(new ObjectMapper(), baseUrl, messageEndpoint, sseEndpoint);
56+
return new WebMvcSseServerTransportProvider(new ObjectMapper(), contextPath, baseUrl, messageEndpoint,
57+
sseEndpoint);
5658
}
5759

5860
@Bean
@@ -62,14 +64,17 @@ public RouterFunction<ServerResponse> routerFunction(WebMvcSseServerTransportPro
6264

6365
}
6466

65-
@ParameterizedTest(name = "baseUrl = \"{0}\" messageEndpoint = \"{1}\" sseEndpoint = \"{2}\" : {displayName} ")
67+
@ParameterizedTest(
68+
name = "baseUrl = \"{0}\" messageEndpoint = \"{1}\" sseEndpoint = \"{2}\" contextPath = \"{3}\" : {displayName} ")
6669
@MethodSource("provideCustomEndpoints")
67-
public void testCustomizedEndpoints(String baseUrl, String messageEndpoint, String sseEndpoint) {
70+
public void testCustomizedEndpoints(String baseUrl, String messageEndpoint, String sseEndpoint,
71+
String contextPath) {
6872
System.setProperty("test.baseUrl", baseUrl);
6973
System.setProperty("test.messageEndpoint", messageEndpoint);
7074
System.setProperty("test.sseEndpoint", sseEndpoint);
75+
System.setProperty("test.contextPath", contextPath);
7176

72-
tomcatServer = TomcatTestUtil.createTomcatServer(baseUrl, PORT, TestConfig.class);
77+
tomcatServer = TomcatTestUtil.createTomcatServer(contextPath, PORT, TestConfig.class);
7378

7479
try {
7580
tomcatServer.tomcat().start();
@@ -79,9 +84,18 @@ public void testCustomizedEndpoints(String baseUrl, String messageEndpoint, Stri
7984
throw new RuntimeException("Failed to start Tomcat", e);
8085
}
8186

82-
clientBuilder = McpClient.sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT + baseUrl)
83-
.sseEndpoint(sseEndpoint)
84-
.build());
87+
var c = contextPath;
88+
var b = baseUrl;
89+
var s = sseEndpoint;
90+
if (baseUrl.endsWith("/")) {
91+
b = b.substring(0, b.length() - 1);
92+
}
93+
if (contextPath.endsWith("/")) {
94+
c = c.substring(0, c.length() - 1);
95+
}
96+
97+
clientBuilder = McpClient
98+
.sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT).sseEndpoint(c + b + s).build());
8599

86100
McpSchema.CallToolResult callResponse = new McpSchema.CallToolResult(
87101
List.of(new McpSchema.TextContent("CALL RESPONSE")), null);
@@ -113,14 +127,13 @@ private static Stream<Arguments> provideCustomEndpoints() {
113127
String[] baseUrls = { "", "/v1", "/api/v1", "/", "/v1/", "/api/v1/" };
114128
String[] messageEndpoints = { "/message", "/another/sse", "/" };
115129
String[] sseEndpoints = { "/sse", "/another/sse", "/" };
116-
String[] contextPath = { "", "/v1", "/api/v1", "/", "/v1/", "/api/v1/" };
130+
String[] contextPaths = { "", "/mcp", "/root/mcp", "/", "/mcp/", "/root/mcp/" };
117131

118132
return Stream.of(baseUrls)
119133
.flatMap(baseUrl -> Stream.of(messageEndpoints)
120134
.flatMap(messageEndpoint -> Stream.of(sseEndpoints)
121-
.map(sseEndpoint -> Arguments.of(baseUrl, messageEndpoint, sseEndpoint))
122-
123-
));
135+
.flatMap(sseEndpoint -> Stream.of(contextPaths)
136+
.map(contextPath -> Arguments.of(baseUrl, messageEndpoint, sseEndpoint, contextPath)))));
124137
}
125138

126139
@AfterEach

mcp/src/main/java/io/modelcontextprotocol/util/Utils.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,12 +70,12 @@ public static boolean isEmpty(@Nullable Map<?, ?> map) {
7070
* base URL or URI is malformed
7171
*/
7272
public static URI resolveUri(URI baseUrl, String endpointUrl) {
73-
URI endpointUri = URI.create(endpointUrl);
73+
URI endpointUri = URI.create(endpointUrl.startsWith("/") ? endpointUrl.substring(1) : endpointUrl);
7474
if (endpointUri.isAbsolute() && !isUnderBaseUri(baseUrl, endpointUri)) {
7575
throw new IllegalArgumentException("Absolute endpoint URL does not match the base URL.");
7676
}
7777
else {
78-
return baseUrl.resolve(endpointUri);
78+
return ensureTrailingSlash(baseUrl).resolve(endpointUri);
7979
}
8080
}
8181

0 commit comments

Comments
 (0)