Skip to content

Commit 90ecbb4

Browse files
committed
resolve conflicts
2 parents 4b30e0a + 629464b commit 90ecbb4

File tree

53 files changed

+1251
-262
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

53 files changed

+1251
-262
lines changed

mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111

1212
import com.fasterxml.jackson.core.type.TypeReference;
1313
import com.fasterxml.jackson.databind.ObjectMapper;
14+
15+
import io.modelcontextprotocol.common.McpTransportContext;
16+
import io.modelcontextprotocol.server.McpTransportContextExtractor;
1417
import io.modelcontextprotocol.spec.McpError;
1518
import io.modelcontextprotocol.spec.McpSchema;
1619
import io.modelcontextprotocol.spec.McpServerSession;
@@ -115,6 +118,8 @@ public class WebFluxSseServerTransportProvider implements McpServerTransportProv
115118
*/
116119
private final ConcurrentHashMap<String, McpServerSession> sessions = new ConcurrentHashMap<>();
117120

121+
private McpTransportContextExtractor<ServerRequest> contextExtractor;
122+
118123
/**
119124
* Flag indicating if the transport is shutting down.
120125
*/
@@ -194,15 +199,38 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseU
194199
@Deprecated
195200
public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint,
196201
String sseEndpoint, Duration keepAliveInterval) {
202+
this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, keepAliveInterval,
203+
(serverRequest) -> McpTransportContext.EMPTY);
204+
}
205+
206+
/**
207+
* Constructs a new WebFlux SSE server transport provider instance.
208+
* @param objectMapper The ObjectMapper to use for JSON serialization/deserialization
209+
* of MCP messages. Must not be null.
210+
* @param baseUrl webflux message base path
211+
* @param messageEndpoint The endpoint URI where clients should send their JSON-RPC
212+
* messages. This endpoint will be communicated to clients during SSE connection
213+
* setup. Must not be null.
214+
* @param sseEndpoint The SSE endpoint path. Must not be null.
215+
* @param keepAliveInterval The interval for sending keep-alive pings to clients.
216+
* @param contextExtractor The context extractor to use for extracting MCP transport
217+
* context from HTTP requests. Must not be null.
218+
* @throws IllegalArgumentException if either parameter is null
219+
*/
220+
private WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint,
221+
String sseEndpoint, Duration keepAliveInterval,
222+
McpTransportContextExtractor<ServerRequest> contextExtractor) {
197223
Assert.notNull(objectMapper, "ObjectMapper must not be null");
198224
Assert.notNull(baseUrl, "Message base path must not be null");
199225
Assert.notNull(messageEndpoint, "Message endpoint must not be null");
200226
Assert.notNull(sseEndpoint, "SSE endpoint must not be null");
227+
Assert.notNull(contextExtractor, "Context extractor must not be null");
201228

202229
this.objectMapper = objectMapper;
203230
this.baseUrl = baseUrl;
204231
this.messageEndpoint = messageEndpoint;
205232
this.sseEndpoint = sseEndpoint;
233+
this.contextExtractor = contextExtractor;
206234
this.routerFunction = RouterFunctions.route()
207235
.GET(this.sseEndpoint, this::handleSseConnection)
208236
.POST(this.messageEndpoint, this::handleMessage)
@@ -315,6 +343,8 @@ private Mono<ServerResponse> handleSseConnection(ServerRequest request) {
315343
return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down");
316344
}
317345

346+
McpTransportContext transportContext = this.contextExtractor.extract(request);
347+
318348
return ServerResponse.ok()
319349
.contentType(MediaType.TEXT_EVENT_STREAM)
320350
.body(Flux.<ServerSentEvent<?>>create(sink -> {
@@ -336,7 +366,7 @@ private Mono<ServerResponse> handleSseConnection(ServerRequest request) {
336366
logger.debug("Session {} cancelled", sessionId);
337367
sessions.remove(sessionId);
338368
});
339-
}), ServerSentEvent.class);
369+
}).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)), ServerSentEvent.class);
340370
}
341371

342372
/**
@@ -370,6 +400,8 @@ private Mono<ServerResponse> handleMessage(ServerRequest request) {
370400
.bodyValue(new McpError("Session not found: " + request.queryParam("sessionId").get()));
371401
}
372402

403+
McpTransportContext transportContext = this.contextExtractor.extract(request);
404+
373405
return request.bodyToMono(String.class).flatMap(body -> {
374406
try {
375407
McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body);
@@ -386,7 +418,7 @@ private Mono<ServerResponse> handleMessage(ServerRequest request) {
386418
logger.error("Failed to deserialize message: {}", e.getMessage());
387419
return ServerResponse.badRequest().bodyValue(new McpError("Invalid message format"));
388420
}
389-
});
421+
}).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext));
390422
}
391423

392424
private class WebFluxMcpSessionTransport implements McpServerTransport {
@@ -458,6 +490,9 @@ public static class Builder {
458490

459491
private Duration keepAliveInterval;
460492

493+
private McpTransportContextExtractor<ServerRequest> contextExtractor = (
494+
serverRequest) -> McpTransportContext.EMPTY;
495+
461496
/**
462497
* Sets the ObjectMapper to use for JSON serialization/deserialization of MCP
463498
* messages.
@@ -519,6 +554,22 @@ public Builder keepAliveInterval(Duration keepAliveInterval) {
519554
return this;
520555
}
521556

557+
/**
558+
* Sets the context extractor that allows providing the MCP feature
559+
* implementations to inspect HTTP transport level metadata that was present at
560+
* HTTP request processing time. This allows to extract custom headers and other
561+
* useful data for use during execution later on in the process.
562+
* @param contextExtractor The contextExtractor to fill in a
563+
* {@link McpTransportContext}.
564+
* @return this builder instance
565+
* @throws IllegalArgumentException if contextExtractor is null
566+
*/
567+
public Builder contextExtractor(McpTransportContextExtractor<ServerRequest> contextExtractor) {
568+
Assert.notNull(contextExtractor, "contextExtractor must not be null");
569+
this.contextExtractor = contextExtractor;
570+
return this;
571+
}
572+
522573
/**
523574
* Builds a new instance of {@link WebFluxSseServerTransportProvider} with the
524575
* configured settings.
@@ -530,7 +581,7 @@ public WebFluxSseServerTransportProvider build() {
530581
Assert.notNull(messageEndpoint, "Message endpoint must be set");
531582

532583
return new WebFluxSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint,
533-
keepAliveInterval);
584+
keepAliveInterval, contextExtractor);
534585
}
535586

536587
}

mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStatelessServerTransport.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,12 @@
55
package io.modelcontextprotocol.server.transport;
66

77
import com.fasterxml.jackson.databind.ObjectMapper;
8+
import io.modelcontextprotocol.common.McpTransportContext;
89
import io.modelcontextprotocol.server.McpStatelessServerHandler;
9-
import io.modelcontextprotocol.server.DefaultMcpTransportContext;
1010
import io.modelcontextprotocol.server.McpTransportContextExtractor;
1111
import io.modelcontextprotocol.spec.McpError;
1212
import io.modelcontextprotocol.spec.McpSchema;
1313
import io.modelcontextprotocol.spec.McpStatelessServerTransport;
14-
import io.modelcontextprotocol.server.McpTransportContext;
1514
import io.modelcontextprotocol.util.Assert;
1615
import org.slf4j.Logger;
1716
import org.slf4j.LoggerFactory;
@@ -97,7 +96,7 @@ private Mono<ServerResponse> handlePost(ServerRequest request) {
9796
return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down");
9897
}
9998

100-
McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext());
99+
McpTransportContext transportContext = this.contextExtractor.extract(request);
101100

102101
List<MediaType> acceptHeaders = request.headers().asHttpHeaders().getAccept();
103102
if (!(acceptHeaders.contains(MediaType.APPLICATION_JSON)
@@ -151,7 +150,8 @@ public static class Builder {
151150

152151
private String mcpEndpoint = "/mcp";
153152

154-
private McpTransportContextExtractor<ServerRequest> contextExtractor = (serverRequest, context) -> context;
153+
private McpTransportContextExtractor<ServerRequest> contextExtractor = (
154+
serverRequest) -> McpTransportContext.EMPTY;
155155

156156
private Builder() {
157157
// used by a static method

mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,15 @@
66

77
import com.fasterxml.jackson.core.type.TypeReference;
88
import com.fasterxml.jackson.databind.ObjectMapper;
9-
import io.modelcontextprotocol.server.DefaultMcpTransportContext;
9+
import io.modelcontextprotocol.common.McpTransportContext;
1010
import io.modelcontextprotocol.server.McpTransportContextExtractor;
1111
import io.modelcontextprotocol.spec.HttpHeaders;
1212
import io.modelcontextprotocol.spec.McpError;
1313
import io.modelcontextprotocol.spec.McpSchema;
1414
import io.modelcontextprotocol.spec.McpStreamableServerSession;
15-
import io.modelcontextprotocol.spec.McpStreamableServerSession.McpStreamableServerSessionStream;
1615
import io.modelcontextprotocol.spec.McpStreamableServerTransport;
1716
import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider;
1817
import io.modelcontextprotocol.spec.ProtocolVersions;
19-
import io.modelcontextprotocol.server.McpTransportContext;
2018
import io.modelcontextprotocol.util.Assert;
2119
import io.modelcontextprotocol.util.KeepAliveScheduler;
2220

@@ -167,7 +165,7 @@ private Mono<ServerResponse> handleGet(ServerRequest request) {
167165
return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down");
168166
}
169167

170-
McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext());
168+
McpTransportContext transportContext = this.contextExtractor.extract(request);
171169

172170
return Mono.defer(() -> {
173171
List<MediaType> acceptHeaders = request.headers().asHttpHeaders().getAccept();
@@ -192,7 +190,9 @@ private Mono<ServerResponse> handleGet(ServerRequest request) {
192190
String lastId = request.headers().asHttpHeaders().getFirst(HttpHeaders.LAST_EVENT_ID);
193191
return ServerResponse.ok()
194192
.contentType(MediaType.TEXT_EVENT_STREAM)
195-
.body(session.replay(lastId), ServerSentEvent.class);
193+
.body(session.replay(lastId)
194+
.contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)),
195+
ServerSentEvent.class);
196196
}
197197

198198
return ServerResponse.ok()
@@ -210,7 +210,9 @@ private Mono<ServerResponse> handleGet(ServerRequest request) {
210210
})
211211
.subscribe(serverSessionStream -> logger.debug("Listening stream created successfully"),
212212
sink::error);
213-
}), ServerSentEvent.class);
213+
// TODO Clarify why the outer context is not present in the
214+
// Flux.create sink?
215+
}).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)), ServerSentEvent.class);
214216

215217
}).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext));
216218
}
@@ -225,7 +227,7 @@ private Mono<ServerResponse> handlePost(ServerRequest request) {
225227
return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down");
226228
}
227229

228-
McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext());
230+
McpTransportContext transportContext = this.contextExtractor.extract(request);
229231

230232
List<MediaType> acceptHeaders = request.headers().asHttpHeaders().getAccept();
231233
if (!(acceptHeaders.contains(MediaType.APPLICATION_JSON)
@@ -290,7 +292,10 @@ else if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest) {
290292
return true;
291293
}).contextWrite(sink.contextView()).subscribe();
292294
sink.onCancel(streamSubscription);
293-
}), ServerSentEvent.class);
295+
// TODO Clarify why the outer context is not present in the
296+
// Flux.create sink?
297+
}).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)),
298+
ServerSentEvent.class);
294299
}
295300
else {
296301
return ServerResponse.badRequest().bodyValue(new McpError("Unknown message type"));
@@ -310,7 +315,7 @@ private Mono<ServerResponse> handleDelete(ServerRequest request) {
310315
return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down");
311316
}
312317

313-
McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext());
318+
McpTransportContext transportContext = this.contextExtractor.extract(request);
314319

315320
return Mono.defer(() -> {
316321
if (!request.headers().asHttpHeaders().containsKey(HttpHeaders.MCP_SESSION_ID)) {
@@ -403,7 +408,8 @@ public static class Builder {
403408

404409
private String mcpEndpoint = "/mcp";
405410

406-
private McpTransportContextExtractor<ServerRequest> contextExtractor = (serverRequest, context) -> context;
411+
private McpTransportContextExtractor<ServerRequest> contextExtractor = (
412+
serverRequest) -> McpTransportContext.EMPTY;
407413

408414
private boolean disallowDelete;
409415

mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
package io.modelcontextprotocol;
66

77
import java.time.Duration;
8+
import java.util.Map;
89

910
import org.junit.jupiter.api.AfterEach;
1011
import org.junit.jupiter.api.BeforeEach;
@@ -13,15 +14,18 @@
1314
import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter;
1415
import org.springframework.web.reactive.function.client.WebClient;
1516
import org.springframework.web.reactive.function.server.RouterFunctions;
17+
import org.springframework.web.reactive.function.server.ServerRequest;
1618

1719
import com.fasterxml.jackson.databind.ObjectMapper;
1820

1921
import io.modelcontextprotocol.client.McpClient;
2022
import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport;
2123
import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport;
24+
import io.modelcontextprotocol.common.McpTransportContext;
2225
import io.modelcontextprotocol.server.McpServer;
2326
import io.modelcontextprotocol.server.McpServer.AsyncSpecification;
2427
import io.modelcontextprotocol.server.McpServer.SingleSessionSyncSpecification;
28+
import io.modelcontextprotocol.server.McpTransportContextExtractor;
2529
import io.modelcontextprotocol.server.TestUtil;
2630
import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider;
2731
import reactor.netty.DisposableServer;
@@ -40,6 +44,9 @@ class WebFluxSseIntegrationTests extends AbstractMcpClientServerIntegrationTests
4044

4145
private WebFluxSseServerTransportProvider mcpServerTransportProvider;
4246

47+
static McpTransportContextExtractor<ServerRequest> TEST_CONTEXT_EXTRACTOR = (r) -> McpTransportContext
48+
.create(Map.of("important", "value"));
49+
4350
@Override
4451
protected void prepareClients(int port, String mcpEndpoint) {
4552

@@ -75,6 +82,7 @@ public void before() {
7582
.objectMapper(new ObjectMapper())
7683
.messageEndpoint(CUSTOM_MESSAGE_ENDPOINT)
7784
.sseEndpoint(CUSTOM_SSE_ENDPOINT)
85+
.contextExtractor(TEST_CONTEXT_EXTRACTOR)
7886
.build();
7987

8088
HttpHandler httpHandler = RouterFunctions.toHttpHandler(mcpServerTransportProvider.getRouterFunction());

mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStreamableIntegrationTests.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
package io.modelcontextprotocol;
66

77
import java.time.Duration;
8+
import java.util.Map;
89

910
import org.junit.jupiter.api.AfterEach;
1011
import org.junit.jupiter.api.BeforeEach;
@@ -13,15 +14,18 @@
1314
import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter;
1415
import org.springframework.web.reactive.function.client.WebClient;
1516
import org.springframework.web.reactive.function.server.RouterFunctions;
17+
import org.springframework.web.reactive.function.server.ServerRequest;
1618

1719
import com.fasterxml.jackson.databind.ObjectMapper;
1820

1921
import io.modelcontextprotocol.client.McpClient;
2022
import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport;
2123
import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport;
24+
import io.modelcontextprotocol.common.McpTransportContext;
2225
import io.modelcontextprotocol.server.McpServer;
2326
import io.modelcontextprotocol.server.McpServer.AsyncSpecification;
2427
import io.modelcontextprotocol.server.McpServer.SyncSpecification;
28+
import io.modelcontextprotocol.server.McpTransportContextExtractor;
2529
import io.modelcontextprotocol.server.TestUtil;
2630
import io.modelcontextprotocol.server.transport.WebFluxStreamableServerTransportProvider;
2731
import reactor.netty.DisposableServer;
@@ -38,6 +42,9 @@ class WebFluxStreamableIntegrationTests extends AbstractMcpClientServerIntegrati
3842

3943
private WebFluxStreamableServerTransportProvider mcpStreamableServerTransportProvider;
4044

45+
static McpTransportContextExtractor<ServerRequest> TEST_CONTEXT_EXTRACTOR = (r) -> McpTransportContext
46+
.create(Map.of("important", "value"));
47+
4148
@Override
4249
protected void prepareClients(int port, String mcpEndpoint) {
4350

@@ -71,6 +78,7 @@ public void before() {
7178
this.mcpStreamableServerTransportProvider = WebFluxStreamableServerTransportProvider.builder()
7279
.objectMapper(new ObjectMapper())
7380
.messageEndpoint(CUSTOM_MESSAGE_ENDPOINT)
81+
.contextExtractor(TEST_CONTEXT_EXTRACTOR)
7482
.build();
7583

7684
HttpHandler httpHandler = RouterFunctions

0 commit comments

Comments
 (0)