Skip to content

Commit e87387b

Browse files
committed
Atomically close the existing listening stream and switch to the new one.
1 parent c7cbe98 commit e87387b

File tree

4 files changed

+20
-44
lines changed

4 files changed

+20
-44
lines changed

mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,8 @@
1010
import io.modelcontextprotocol.server.McpTransportContextExtractor;
1111
import io.modelcontextprotocol.spec.HttpHeaders;
1212
import io.modelcontextprotocol.spec.McpError;
13-
import io.modelcontextprotocol.spec.McpLoggableSession;
1413
import io.modelcontextprotocol.spec.McpSchema;
1514
import io.modelcontextprotocol.spec.McpStreamableServerSession;
16-
import io.modelcontextprotocol.spec.McpStreamableServerSession.McpStreamableServerSessionStream;
1715
import io.modelcontextprotocol.spec.McpStreamableServerTransport;
1816
import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider;
1917
import io.modelcontextprotocol.spec.ProtocolVersions;
@@ -189,19 +187,12 @@ private Mono<ServerResponse> handleGet(ServerRequest request) {
189187
return ServerResponse.notFound().build();
190188
}
191189

192-
McpLoggableSession listenedStream = session.getListeningStream();
193190
if (request.headers().asHttpHeaders().containsKey(HttpHeaders.LAST_EVENT_ID)) {
194191
String lastId = request.headers().asHttpHeaders().getFirst(HttpHeaders.LAST_EVENT_ID);
195192
return ServerResponse.ok()
196193
.contentType(MediaType.TEXT_EVENT_STREAM)
197194
.body(session.replay(lastId), ServerSentEvent.class);
198195
}
199-
if (listenedStream instanceof McpStreamableServerSessionStream) {
200-
logger.debug(
201-
"Listening stream already exists for this session:{} and will be closed to make way for the new listening SSE stream",
202-
sessionId);
203-
listenedStream.close();
204-
}
205196

206197
return ServerResponse.ok()
207198
.contentType(MediaType.TEXT_EVENT_STREAM)

mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@
1010
import java.util.concurrent.ConcurrentHashMap;
1111
import java.util.concurrent.locks.ReentrantLock;
1212

13-
import io.modelcontextprotocol.spec.McpLoggableSession;
14-
import io.modelcontextprotocol.spec.McpStreamableServerSession.McpStreamableServerSessionStream;
1513
import org.slf4j.Logger;
1614
import org.slf4j.LoggerFactory;
1715
import org.springframework.http.HttpStatus;
@@ -254,6 +252,7 @@ private ServerResponse handleGet(ServerRequest request) {
254252
}
255253

256254
logger.debug("Handling GET request for session: {}", sessionId);
255+
257256
try {
258257
return ServerResponse.sse(sseBuilder -> {
259258
sseBuilder.onTimeout(() -> {
@@ -266,6 +265,7 @@ private ServerResponse handleGet(ServerRequest request) {
266265
// Check if this is a replay request
267266
if (request.headers().asHttpHeaders().containsKey(HttpHeaders.LAST_EVENT_ID)) {
268267
String lastId = request.headers().asHttpHeaders().getFirst(HttpHeaders.LAST_EVENT_ID);
268+
269269
try {
270270
session.replay(lastId)
271271
.contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext))
@@ -288,13 +288,6 @@ private ServerResponse handleGet(ServerRequest request) {
288288
}
289289
}
290290
else {
291-
McpLoggableSession listenedStream = session.getListeningStream();
292-
if (listenedStream instanceof McpStreamableServerSessionStream) {
293-
logger.debug(
294-
"Listening stream already exists for this session:{} and will be closed to make way for the new listening SSE stream",
295-
sessionId);
296-
listenedStream.close();
297-
}
298291
// Establish new listening stream
299292
McpStreamableServerSession.McpStreamableServerSessionStream listeningStream = session
300293
.listeningStream(sessionTransport);

mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@
1313
import java.util.concurrent.ConcurrentHashMap;
1414
import java.util.concurrent.locks.ReentrantLock;
1515

16-
import io.modelcontextprotocol.spec.McpLoggableSession;
17-
import io.modelcontextprotocol.spec.McpStreamableServerSession.McpStreamableServerSessionStream;
1816
import org.slf4j.Logger;
1917
import org.slf4j.LoggerFactory;
2018

@@ -275,6 +273,7 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response)
275273
}
276274

277275
logger.debug("Handling GET request for session: {}", sessionId);
276+
278277
McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext());
279278

280279
try {
@@ -316,13 +315,6 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response)
316315
}
317316
}
318317
else {
319-
McpLoggableSession listenedStream = session.getListeningStream();
320-
if (listenedStream instanceof McpStreamableServerSessionStream) {
321-
logger.debug(
322-
"Listening stream already exists for this session:{} and will be closed to make way for the new listening SSE stream",
323-
sessionId);
324-
listenedStream.close();
325-
}
326318
// Establish new listening stream
327319
McpStreamableServerSession.McpStreamableServerSessionStream listeningStream = session
328320
.listeningStream(sessionTransport);

mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,28 +4,26 @@
44

55
package io.modelcontextprotocol.spec;
66

7-
import java.time.Duration;
8-
import java.util.Map;
9-
import java.util.UUID;
10-
import java.util.concurrent.ConcurrentHashMap;
11-
import java.util.concurrent.atomic.AtomicLong;
12-
import java.util.concurrent.atomic.AtomicReference;
13-
import java.util.function.Supplier;
14-
15-
import org.slf4j.Logger;
16-
import org.slf4j.LoggerFactory;
17-
187
import com.fasterxml.jackson.core.type.TypeReference;
19-
208
import io.modelcontextprotocol.server.McpAsyncServerExchange;
219
import io.modelcontextprotocol.server.McpNotificationHandler;
2210
import io.modelcontextprotocol.server.McpRequestHandler;
2311
import io.modelcontextprotocol.server.McpTransportContext;
2412
import io.modelcontextprotocol.util.Assert;
13+
import org.slf4j.Logger;
14+
import org.slf4j.LoggerFactory;
2515
import reactor.core.publisher.Flux;
2616
import reactor.core.publisher.Mono;
2717
import reactor.core.publisher.MonoSink;
2818

19+
import java.time.Duration;
20+
import java.util.Map;
21+
import java.util.UUID;
22+
import java.util.concurrent.ConcurrentHashMap;
23+
import java.util.concurrent.atomic.AtomicLong;
24+
import java.util.concurrent.atomic.AtomicReference;
25+
import java.util.function.Supplier;
26+
2927
/**
3028
* Representation of a Streamable HTTP server session that keeps track of mapping
3129
* server-initiated requests to the client and mapping arriving responses. It also allows
@@ -138,14 +136,16 @@ public Mono<Void> delete() {
138136
*/
139137
public McpStreamableServerSessionStream listeningStream(McpStreamableServerTransport transport) {
140138
McpStreamableServerSessionStream listeningStream = new McpStreamableServerSessionStream(transport);
141-
this.listeningStreamRef.set(listeningStream);
139+
McpLoggableSession listenedStream = this.listeningStreamRef.getAndSet(listeningStream);
140+
if (listenedStream != null) {
141+
logger.debug(
142+
"Listening stream already exists for this session:{} and will be closed to make way for the new listening SSE stream",
143+
this.id);
144+
listenedStream.closeGracefully().block();
145+
}
142146
return listeningStream;
143147
}
144148

145-
public McpLoggableSession getListeningStream() {
146-
return this.listeningStreamRef.get();
147-
}
148-
149149
// TODO: keep track of history by keeping a map from eventId to stream and then
150150
// iterate over the events using the lastEventId
151151
public Flux<McpSchema.JSONRPCMessage> replay(Object lastEventId) {

0 commit comments

Comments
 (0)