Skip to content

Commit 1a829d0

Browse files
committed
fix: Resolve URIs
1 parent 07e7b8f commit 1a829d0

File tree

7 files changed

+191
-7
lines changed

7 files changed

+191
-7
lines changed

mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,10 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseU
154154
Assert.notNull(messageEndpoint, "Message endpoint must not be null");
155155
Assert.notNull(sseEndpoint, "SSE endpoint must not be null");
156156

157+
if (baseUrl.endsWith("/")) {
158+
baseUrl = baseUrl.substring(0, baseUrl.length() - 1);
159+
}
160+
157161
this.objectMapper = objectMapper;
158162
this.baseUrl = baseUrl;
159163
this.messageEndpoint = messageEndpoint;

mcp-spring/mcp-spring-webmvc/pom.xml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,12 @@
7171
<version>${junit.version}</version>
7272
<scope>test</scope>
7373
</dependency>
74+
<dependency>
75+
<groupId>org.junit.jupiter</groupId>
76+
<artifactId>junit-jupiter-params</artifactId>
77+
<version>${junit.version}</version>
78+
<scope>test</scope>
79+
</dependency>
7480
<dependency>
7581
<groupId>org.mockito</groupId>
7682
<artifactId>mockito-core</artifactId>
@@ -128,7 +134,7 @@
128134
<scope>test</scope>
129135
</dependency>
130136

131-
</dependencies>
137+
</dependencies>
132138

133139

134140
</project>

mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import java.io.IOException;
88
import java.time.Duration;
9+
import java.util.Collections;
910
import java.util.Map;
1011
import java.util.UUID;
1112
import java.util.concurrent.ConcurrentHashMap;
@@ -152,7 +153,13 @@ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUr
152153
Assert.notNull(objectMapper, "ObjectMapper must not be null");
153154
Assert.notNull(baseUrl, "Message base URL must not be null");
154155
Assert.notNull(messageEndpoint, "Message endpoint must not be null");
156+
Assert.hasText(messageEndpoint, "Message endpoint must not be empty");
155157
Assert.notNull(sseEndpoint, "SSE endpoint must not be null");
158+
Assert.hasText(sseEndpoint, "SSE endpoint must not be empty");
159+
160+
if (baseUrl.endsWith("/")) {
161+
baseUrl = baseUrl.substring(0, baseUrl.length() - 1);
162+
}
156163

157164
this.objectMapper = objectMapper;
158165
this.baseUrl = baseUrl;

mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomContextPathTests.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,8 @@ public void before() {
4949
throw new RuntimeException("Failed to start Tomcat", e);
5050
}
5151

52-
var clientTransport = HttpClientSseClientTransport.builder("http://localhost:" + PORT)
53-
.sseEndpoint(CUSTOM_CONTEXT_PATH + WebMvcSseServerTransportProvider.DEFAULT_SSE_ENDPOINT)
52+
var clientTransport = HttpClientSseClientTransport.builder("http://localhost:" + PORT + CUSTOM_CONTEXT_PATH)
53+
.sseEndpoint(WebMvcSseServerTransportProvider.DEFAULT_SSE_ENDPOINT)
5454
.build();
5555

5656
clientBuilder = McpClient.sync(clientTransport);
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
package io.modelcontextprotocol.server;
2+
3+
import io.modelcontextprotocol.client.McpClient;
4+
import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport;
5+
import io.modelcontextprotocol.server.transport.WebMvcSseServerTransportProvider;
6+
import io.modelcontextprotocol.spec.McpSchema;
7+
import java.util.List;
8+
import java.util.Map;
9+
import java.util.stream.Stream;
10+
import reactor.core.publisher.Mono;
11+
import static org.assertj.core.api.Assertions.assertThat;
12+
13+
import org.apache.catalina.LifecycleException;
14+
import org.apache.catalina.LifecycleState;
15+
import org.junit.jupiter.api.AfterEach;
16+
import org.junit.jupiter.params.ParameterizedTest;
17+
import org.junit.jupiter.params.provider.Arguments;
18+
import org.junit.jupiter.params.provider.MethodSource;
19+
import org.springframework.context.annotation.Bean;
20+
import org.springframework.context.annotation.Configuration;
21+
import org.springframework.web.servlet.config.annotation.EnableWebMvc;
22+
import org.springframework.web.servlet.function.RouterFunction;
23+
import org.springframework.web.servlet.function.ServerResponse;
24+
25+
import com.fasterxml.jackson.databind.ObjectMapper;
26+
27+
public class WebMvcSseCustomPathIntegrationTests {
28+
29+
private static final int PORT = TestUtil.findAvailablePort();
30+
31+
private WebMvcSseServerTransportProvider mcpServerTransportProvider;
32+
33+
McpClient.SyncSpec clientBuilder;
34+
35+
private TomcatTestUtil.TomcatServer tomcatServer;
36+
37+
String emptyJsonSchema = """
38+
{
39+
"$schema": "http://json-schema.org/draft-07/schema#",
40+
"type": "object",
41+
"properties": {}
42+
}
43+
""";
44+
45+
@Configuration
46+
@EnableWebMvc
47+
static class TestConfig {
48+
49+
@Bean
50+
public WebMvcSseServerTransportProvider transportProvider(org.springframework.core.env.Environment env) {
51+
String baseUrl = env.getProperty("test.baseUrl");
52+
String messageEndpoint = env.getProperty("test.messageEndpoint");
53+
String sseEndpoint = env.getProperty("test.sseEndpoint");
54+
55+
return new WebMvcSseServerTransportProvider(new ObjectMapper(), baseUrl, messageEndpoint, sseEndpoint);
56+
}
57+
58+
@Bean
59+
public RouterFunction<ServerResponse> routerFunction(WebMvcSseServerTransportProvider transportProvider) {
60+
return transportProvider.getRouterFunction();
61+
}
62+
63+
}
64+
65+
@ParameterizedTest(name = "baseUrl = \"{0}\" messageEndpoint = \"{1}\" sseEndpoint = \"{2}\" : {displayName} ")
66+
@MethodSource("provideCustomEndpoints")
67+
public void testCustomizedEndpoints(String baseUrl, String messageEndpoint, String sseEndpoint) {
68+
System.setProperty("test.baseUrl", baseUrl);
69+
System.setProperty("test.messageEndpoint", messageEndpoint);
70+
System.setProperty("test.sseEndpoint", sseEndpoint);
71+
72+
tomcatServer = TomcatTestUtil.createTomcatServer(baseUrl, PORT, TestConfig.class);
73+
74+
try {
75+
tomcatServer.tomcat().start();
76+
assertThat(tomcatServer.tomcat().getServer().getState()).isEqualTo(LifecycleState.STARTED);
77+
}
78+
catch (Exception e) {
79+
throw new RuntimeException("Failed to start Tomcat", e);
80+
}
81+
82+
clientBuilder = McpClient.sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT + baseUrl)
83+
.sseEndpoint(sseEndpoint)
84+
.build());
85+
86+
McpSchema.CallToolResult callResponse = new McpSchema.CallToolResult(
87+
List.of(new McpSchema.TextContent("CALL RESPONSE")), null);
88+
89+
McpServerFeatures.AsyncToolSpecification tool1 = new McpServerFeatures.AsyncToolSpecification(
90+
new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema),
91+
(exchange, request) -> Mono.just(callResponse));
92+
93+
mcpServerTransportProvider = tomcatServer.appContext().getBean(WebMvcSseServerTransportProvider.class);
94+
95+
var server = McpServer.async(mcpServerTransportProvider)
96+
.serverInfo("test-server", "1.0.0")
97+
.capabilities(McpSchema.ServerCapabilities.builder().tools(true).build())
98+
.tools(tool1)
99+
.build();
100+
101+
try (var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")).build()) {
102+
assertThat(client.initialize()).isNotNull();
103+
assertThat(client.listTools().tools()).contains(tool1.tool());
104+
105+
McpSchema.CallToolResult response = client.callTool(new McpSchema.CallToolRequest("tool1", Map.of()));
106+
assertThat(response).isNotNull().isEqualTo(callResponse);
107+
}
108+
109+
server.close();
110+
}
111+
112+
private static Stream<Arguments> provideCustomEndpoints() {
113+
String[] baseUrls = { "", "/v1", "/api/v1", "/", "/v1/", "/api/v1/" };
114+
String[] messageEndpoints = { "/message", "/another/sse", "/" };
115+
String[] sseEndpoints = { "/sse", "/another/sse", "/" };
116+
String[] contextPath = { "", "/v1", "/api/v1", "/", "/v1/", "/api/v1/" };
117+
118+
return Stream.of(baseUrls)
119+
.flatMap(baseUrl -> Stream.of(messageEndpoints)
120+
.flatMap(messageEndpoint -> Stream.of(sseEndpoints)
121+
.map(sseEndpoint -> Arguments.of(baseUrl, messageEndpoint, sseEndpoint))
122+
123+
));
124+
}
125+
126+
@AfterEach
127+
public void after() {
128+
if (mcpServerTransportProvider != null) {
129+
mcpServerTransportProvider.closeGracefully().block();
130+
}
131+
if (tomcatServer.appContext() != null) {
132+
tomcatServer.appContext().close();
133+
}
134+
if (tomcatServer.tomcat() != null) {
135+
try {
136+
tomcatServer.tomcat().stop();
137+
tomcatServer.tomcat().destroy();
138+
}
139+
catch (LifecycleException e) {
140+
throw new RuntimeException("Failed to stop Tomcat", e);
141+
}
142+
}
143+
}
144+
145+
}

mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -175,9 +175,9 @@ public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, HttpReques
175175
HttpClientSseClientTransport(HttpClient httpClient, HttpRequest.Builder requestBuilder, String baseUri,
176176
String sseEndpoint, ObjectMapper objectMapper) {
177177
Assert.notNull(objectMapper, "ObjectMapper must not be null");
178-
Assert.hasText(baseUri, "baseUri must not be empty");
179-
Assert.hasText(sseEndpoint, "sseEndpoint must not be empty");
180-
Assert.notNull(httpClient, "httpClient must not be null");
178+
Assert.notNull(baseUri, "baseUri must not be null");
179+
Assert.notNull(sseEndpoint, "SSE endpoint must not be null");
180+
Assert.hasText(sseEndpoint, "SSE endpoint must not be empty");
181181
Assert.notNull(requestBuilder, "requestBuilder must not be null");
182182
this.baseUri = URI.create(baseUri);
183183
this.sseEndpoint = sseEndpoint;
@@ -341,7 +341,8 @@ public Mono<Void> connect(Function<Mono<JSONRPCMessage>, Mono<JSONRPCMessage>> h
341341
CompletableFuture<Void> future = new CompletableFuture<>();
342342
connectionFuture.set(future);
343343

344-
URI clientUri = Utils.resolveUri(this.baseUri, this.sseEndpoint);
344+
URI clientUri = Utils.resolveSseUri(this.baseUri, this.sseEndpoint);
345+
logger.debug("Subscribing to {}", clientUri);
345346
sseClient.subscribe(clientUri.toString(), new FlowSseClient.SseEventHandler() {
346347
@Override
347348
public void onEvent(SseEvent event) {

mcp/src/main/java/io/modelcontextprotocol/util/Utils.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
package io.modelcontextprotocol.util;
66

7+
import java.net.URL;
78
import reactor.util.annotation.Nullable;
89

910
import java.net.URI;
@@ -78,6 +79,26 @@ public static URI resolveUri(URI baseUrl, String endpointUrl) {
7879
}
7980
}
8081

82+
public static URI resolveSseUri(URI baseUrl, String endpointUrl) {
83+
String sanitizedEndpoint = stripLeadingSlash(endpointUrl);
84+
URI endpointUri = URI.create(sanitizedEndpoint);
85+
if (endpointUri.isAbsolute() && !isUnderBaseUri(baseUrl, endpointUri)) {
86+
throw new IllegalArgumentException("Absolute endpoint URL does not match the base URL.");
87+
}
88+
89+
URI res = ensureTrailingSlash(baseUrl).resolve(endpointUri);
90+
return res;
91+
}
92+
93+
private static String stripLeadingSlash(String url) {
94+
return url.startsWith("/") ? url.substring(1) : url;
95+
}
96+
97+
private static URI ensureTrailingSlash(URI uri) {
98+
String uriString = uri.toString();
99+
return !uriString.endsWith("/") ? URI.create(uriString.concat("/")) : uri;
100+
}
101+
81102
/**
82103
* Checks if the given absolute endpoint URI falls under the base URI. It validates
83104
* the scheme, authority (host and port), and ensures that the base path is a prefix

0 commit comments

Comments
 (0)