Skip to content

Commit 4b30e0a

Browse files
committed
fix listening sse stream close blocking
1 parent e87387b commit 4b30e0a

File tree

4 files changed

+54
-39
lines changed

4 files changed

+54
-39
lines changed

mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import io.modelcontextprotocol.spec.McpError;
1313
import io.modelcontextprotocol.spec.McpSchema;
1414
import io.modelcontextprotocol.spec.McpStreamableServerSession;
15+
import io.modelcontextprotocol.spec.McpStreamableServerSession.McpStreamableServerSessionStream;
1516
import io.modelcontextprotocol.spec.McpStreamableServerTransport;
1617
import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider;
1718
import io.modelcontextprotocol.spec.ProtocolVersions;
@@ -199,9 +200,16 @@ private Mono<ServerResponse> handleGet(ServerRequest request) {
199200
.body(Flux.<ServerSentEvent<?>>create(sink -> {
200201
WebFluxStreamableMcpSessionTransport sessionTransport = new WebFluxStreamableMcpSessionTransport(
201202
sink);
202-
McpStreamableServerSession.McpStreamableServerSessionStream listeningStream = session
203-
.listeningStream(sessionTransport);
204-
sink.onDispose(listeningStream::close);
203+
session.listeningStream(sessionTransport)
204+
.doOnNext(serverSessionStream -> sink
205+
.onDispose(() -> serverSessionStream.closeGracefully().subscribe(v -> {
206+
}, error -> logger.warn("Failed to close listening stream gracefully", error))))
207+
.doOnError(error -> {
208+
logger.error("Failed to create listening stream", error);
209+
sink.error(error);
210+
})
211+
.subscribe(serverSessionStream -> logger.debug("Listening stream created successfully"),
212+
sink::error);
205213
}), ServerSentEvent.class);
206214

207215
}).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext));

mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import java.util.concurrent.ConcurrentHashMap;
1111
import java.util.concurrent.locks.ReentrantLock;
1212

13+
import io.modelcontextprotocol.spec.McpStreamableServerSession.McpStreamableServerSessionStream;
1314
import org.slf4j.Logger;
1415
import org.slf4j.LoggerFactory;
1516
import org.springframework.http.HttpStatus;
@@ -289,13 +290,15 @@ private ServerResponse handleGet(ServerRequest request) {
289290
}
290291
else {
291292
// Establish new listening stream
292-
McpStreamableServerSession.McpStreamableServerSessionStream listeningStream = session
293+
Mono<McpStreamableServerSession.McpStreamableServerSessionStream> listeningStream = session
293294
.listeningStream(sessionTransport);
294-
295-
sseBuilder.onComplete(() -> {
295+
listeningStream.subscribe(serverSessionStream -> sseBuilder.onComplete(() -> {
296296
logger.debug("SSE connection completed for session: {}", sessionId);
297-
listeningStream.close();
298-
});
297+
serverSessionStream.close();
298+
}), error -> {
299+
sseBuilder.error(error);
300+
logger.error("Failed to create listening stream", error);
301+
}, () -> logger.debug("Listening stream created successfully"));
299302
}
300303
}, Duration.ZERO);
301304
}

mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java

Lines changed: 30 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import java.util.concurrent.ConcurrentHashMap;
1414
import java.util.concurrent.locks.ReentrantLock;
1515

16+
import io.modelcontextprotocol.spec.McpStreamableServerSession.McpStreamableServerSessionStream;
1617
import org.slf4j.Logger;
1718
import org.slf4j.LoggerFactory;
1819

@@ -316,33 +317,36 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response)
316317
}
317318
else {
318319
// Establish new listening stream
319-
McpStreamableServerSession.McpStreamableServerSessionStream listeningStream = session
320-
.listeningStream(sessionTransport);
321-
322-
asyncContext.addListener(new jakarta.servlet.AsyncListener() {
323-
@Override
324-
public void onComplete(jakarta.servlet.AsyncEvent event) throws IOException {
325-
logger.debug("SSE connection completed for session: {}", sessionId);
326-
listeningStream.close();
327-
}
328-
329-
@Override
330-
public void onTimeout(jakarta.servlet.AsyncEvent event) throws IOException {
331-
logger.debug("SSE connection timed out for session: {}", sessionId);
332-
listeningStream.close();
333-
}
320+
session.listeningStream(sessionTransport)
321+
.doOnNext(serverSessionStream -> asyncContext.addListener(new jakarta.servlet.AsyncListener() {
322+
@Override
323+
public void onComplete(jakarta.servlet.AsyncEvent event) throws IOException {
324+
logger.debug("SSE connection completed for session: {}", sessionId);
325+
serverSessionStream.close();
326+
}
327+
328+
@Override
329+
public void onTimeout(jakarta.servlet.AsyncEvent event) throws IOException {
330+
logger.debug("SSE connection timed out for session: {}", sessionId);
331+
serverSessionStream.close();
332+
}
333+
334+
@Override
335+
public void onError(jakarta.servlet.AsyncEvent event) throws IOException {
336+
logger.debug("SSE connection error for session: {}", sessionId);
337+
serverSessionStream.close();
338+
}
339+
340+
@Override
341+
public void onStartAsync(jakarta.servlet.AsyncEvent event) throws IOException {
342+
// No action needed
343+
}
344+
}))
345+
.doOnError(error -> {
346+
logger.error("Failed to create listening stream", error);
347+
})
348+
.subscribe(serverSessionStream -> logger.debug("Listening stream created successfully"));
334349

335-
@Override
336-
public void onError(jakarta.servlet.AsyncEvent event) throws IOException {
337-
logger.debug("SSE connection error for session: {}", sessionId);
338-
listeningStream.close();
339-
}
340-
341-
@Override
342-
public void onStartAsync(jakarta.servlet.AsyncEvent event) throws IOException {
343-
// No action needed
344-
}
345-
});
346350
}
347351
}
348352
catch (Exception e) {

mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -134,16 +134,16 @@ public Mono<Void> delete() {
134134
* @param transport The dedicated SSE transport stream
135135
* @return a stream representation
136136
*/
137-
public McpStreamableServerSessionStream listeningStream(McpStreamableServerTransport transport) {
137+
public Mono<McpStreamableServerSessionStream> listeningStream(McpStreamableServerTransport transport) {
138138
McpStreamableServerSessionStream listeningStream = new McpStreamableServerSessionStream(transport);
139-
McpLoggableSession listenedStream = this.listeningStreamRef.getAndSet(listeningStream);
140-
if (listenedStream != null) {
139+
McpLoggableSession oldStream = this.listeningStreamRef.getAndSet(listeningStream);
140+
if (oldStream != null) {
141141
logger.debug(
142142
"Listening stream already exists for this session:{} and will be closed to make way for the new listening SSE stream",
143143
this.id);
144-
listenedStream.closeGracefully().block();
144+
return oldStream.closeGracefully().thenReturn(listeningStream);
145145
}
146-
return listeningStream;
146+
return Mono.just(listeningStream);
147147
}
148148

149149
// TODO: keep track of history by keeping a map from eventId to stream and then

0 commit comments

Comments
 (0)