Skip to content

Commit e1e5c96

Browse files
committed
feat: Add request with context support.
Signed-off-by: He-Pin <hepin1989@gmail.com>
1 parent e610d85 commit e1e5c96

File tree

15 files changed

+250
-112
lines changed

15 files changed

+250
-112
lines changed

mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,11 @@
11
package io.modelcontextprotocol.server.transport;
22

33
import java.io.IOException;
4-
import java.util.Map;
54
import java.util.concurrent.ConcurrentHashMap;
65

76
import com.fasterxml.jackson.core.type.TypeReference;
87
import com.fasterxml.jackson.databind.ObjectMapper;
9-
import io.modelcontextprotocol.spec.McpError;
10-
import io.modelcontextprotocol.spec.McpSchema;
11-
import io.modelcontextprotocol.spec.McpServerSession;
12-
import io.modelcontextprotocol.spec.McpServerTransport;
13-
import io.modelcontextprotocol.spec.McpServerTransportProvider;
8+
import io.modelcontextprotocol.spec.*;
149
import io.modelcontextprotocol.util.Assert;
1510
import org.slf4j.Logger;
1611
import org.slf4j.LoggerFactory;
@@ -100,6 +95,8 @@ public class WebFluxSseServerTransportProvider implements McpServerTransportProv
10095

10196
private McpServerSession.Factory sessionFactory;
10297

98+
private McpContextFactory mcpContextFactory;
99+
103100
/**
104101
* Map of active client sessions, keyed by session ID.
105102
*/
@@ -169,6 +166,11 @@ public void setSessionFactory(McpServerSession.Factory sessionFactory) {
169166
this.sessionFactory = sessionFactory;
170167
}
171168

169+
@Override
170+
public void setMcpContextFactory(final McpContextFactory mcpContextFactory) {
171+
this.mcpContextFactory = mcpContextFactory;
172+
}
173+
172174
/**
173175
* Broadcasts a JSON-RPC message to all connected clients through their SSE
174176
* connections. The message is serialized to JSON and sent as a server-sent event to
@@ -261,7 +263,7 @@ private Mono<ServerResponse> handleSseConnection(ServerRequest request) {
261263
.body(Flux.<ServerSentEvent<?>>create(sink -> {
262264
WebFluxMcpSessionTransport sessionTransport = new WebFluxMcpSessionTransport(sink);
263265

264-
McpServerSession session = sessionFactory.create(sessionTransport);
266+
McpServerSession session = sessionFactory.create(sessionTransport, createContext(request));
265267
String sessionId = session.getId();
266268

267269
logger.debug("Created new SSE connection for session: {}", sessionId);
@@ -280,6 +282,18 @@ private Mono<ServerResponse> handleSseConnection(ServerRequest request) {
280282
}), ServerSentEvent.class);
281283
}
282284

285+
private McpContext createContext(final ServerRequest request) {
286+
// create a context form the request
287+
McpContext context;
288+
if (mcpContextFactory != null) {
289+
context = mcpContextFactory.create(request);
290+
}
291+
else {
292+
context = McpContext.empty();
293+
}
294+
return context;
295+
}
296+
283297
/**
284298
* Handles incoming JSON-RPC messages from clients. Deserializes the message and
285299
* processes it through the configured message handler.
@@ -314,14 +328,16 @@ private Mono<ServerResponse> handleMessage(ServerRequest request) {
314328
return request.bodyToMono(String.class).flatMap(body -> {
315329
try {
316330
McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body);
317-
return session.handle(message).flatMap(response -> ServerResponse.ok().build()).onErrorResume(error -> {
318-
logger.error("Error processing message: {}", error.getMessage());
319-
// TODO: instead of signalling the error, just respond with 200 OK
320-
// - the error is signalled on the SSE connection
321-
// return ServerResponse.ok().build();
322-
return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR)
323-
.bodyValue(new McpError(error.getMessage()));
324-
});
331+
return session.handle(message, createContext(request))
332+
.flatMap(response -> ServerResponse.ok().build())
333+
.onErrorResume(error -> {
334+
logger.error("Error processing message: {}", error.getMessage());
335+
// TODO: instead of signalling the error, just respond with 200 OK
336+
// - the error is signalled on the SSE connection
337+
// return ServerResponse.ok().build();
338+
return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR)
339+
.bodyValue(new McpError(error.getMessage()));
340+
});
325341
}
326342
catch (IllegalArgumentException | IOException e) {
327343
logger.error("Failed to deserialize message: {}", e.getMessage());

mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,10 @@
2323
import io.modelcontextprotocol.server.TestUtil;
2424
import io.modelcontextprotocol.server.McpSyncServerExchange;
2525
import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider;
26+
import io.modelcontextprotocol.spec.McpContext;
2627
import io.modelcontextprotocol.spec.McpError;
2728
import io.modelcontextprotocol.spec.McpSchema;
2829
import io.modelcontextprotocol.spec.McpSchema.*;
29-
import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities.CompletionCapabilities;
3030
import org.junit.jupiter.api.AfterEach;
3131
import org.junit.jupiter.api.BeforeEach;
3232
import org.junit.jupiter.params.ParameterizedTest;
@@ -767,9 +767,11 @@ void testCompletionShouldReturnExpectedSuggestions(String clientType) {
767767
));
768768

769769
AtomicReference<CompleteRequest> samplingRequest = new AtomicReference<>();
770-
BiFunction<McpSyncServerExchange, CompleteRequest, CompleteResult> completionHandler = (mcpSyncServerExchange,
771-
request) -> {
772-
samplingRequest.set(request);
770+
AtomicReference<McpContext> mcpContext = new AtomicReference<>();
771+
BiFunction<McpSyncServerExchange, McpServerFeatures.RequestWithContext<CompleteRequest>, CompleteResult> completionHandler = (
772+
mcpSyncServerExchange, reqWithContext) -> {
773+
samplingRequest.set(reqWithContext.request());
774+
mcpContext.set(reqWithContext.mcpContext());
773775
return completionResponse;
774776
};
775777

mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,12 @@
66

77
import java.io.IOException;
88
import java.time.Duration;
9-
import java.util.Map;
109
import java.util.UUID;
1110
import java.util.concurrent.ConcurrentHashMap;
1211

1312
import com.fasterxml.jackson.core.type.TypeReference;
1413
import com.fasterxml.jackson.databind.ObjectMapper;
15-
import io.modelcontextprotocol.spec.McpError;
16-
import io.modelcontextprotocol.spec.McpSchema;
17-
import io.modelcontextprotocol.spec.McpServerTransport;
18-
import io.modelcontextprotocol.spec.McpServerTransportProvider;
19-
import io.modelcontextprotocol.spec.McpServerSession;
14+
import io.modelcontextprotocol.spec.*;
2015
import io.modelcontextprotocol.util.Assert;
2116
import org.slf4j.Logger;
2217
import org.slf4j.LoggerFactory;
@@ -97,6 +92,8 @@ public class WebMvcSseServerTransportProvider implements McpServerTransportProvi
9792

9893
private McpServerSession.Factory sessionFactory;
9994

95+
private McpContextFactory mcpContextFactory;
96+
10097
/**
10198
* Map of active client sessions, keyed by session ID.
10299
*/
@@ -169,6 +166,11 @@ public void setSessionFactory(McpServerSession.Factory sessionFactory) {
169166
this.sessionFactory = sessionFactory;
170167
}
171168

169+
@Override
170+
public void setMcpContextFactory(McpContextFactory mcpContextFactory) {
171+
this.mcpContextFactory = mcpContextFactory;
172+
}
173+
172174
/**
173175
* Broadcasts a notification to all connected clients through their SSE connections.
174176
* The message is serialized to JSON and sent as an SSE event with type "message". If
@@ -263,7 +265,7 @@ private ServerResponse handleSseConnection(ServerRequest request) {
263265
});
264266

265267
WebMvcMcpSessionTransport sessionTransport = new WebMvcMcpSessionTransport(sessionId, sseBuilder);
266-
McpServerSession session = sessionFactory.create(sessionTransport);
268+
McpServerSession session = sessionFactory.create(sessionTransport, createContext(request));
267269
this.sessions.put(sessionId, session);
268270

269271
try {
@@ -284,6 +286,18 @@ private ServerResponse handleSseConnection(ServerRequest request) {
284286
}
285287
}
286288

289+
private McpContext createContext(final ServerRequest request) {
290+
// create a context form the request
291+
McpContext context;
292+
if (mcpContextFactory != null) {
293+
context = mcpContextFactory.create(request);
294+
}
295+
else {
296+
context = McpContext.empty();
297+
}
298+
return context;
299+
}
300+
287301
/**
288302
* Handles incoming JSON-RPC messages from clients. This method:
289303
* <ul>
@@ -316,7 +330,8 @@ private ServerResponse handleMessage(ServerRequest request) {
316330
McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body);
317331

318332
// Process the message through the session's handle method
319-
session.handle(message).block(); // Block for WebMVC compatibility
333+
session.handle(message, createContext(request)).block(); // Block for WebMVC
334+
// compatibility
320335

321336
return ServerResponse.ok().build();
322337
}

0 commit comments

Comments
 (0)