Skip to content

Commit a98c092

Browse files
committed
mvc sse session support session timeou configuration
1 parent 082444e commit a98c092

File tree

2 files changed

+22
-4
lines changed

2 files changed

+22
-4
lines changed

mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,11 @@ public class WebMvcSseServerTransportProvider implements McpServerTransportProvi
118118

119119
private KeepAliveScheduler keepAliveScheduler;
120120

121+
/**
122+
* sse session timeout
123+
*/
124+
private final Duration sessionTimeout;
125+
121126
/**
122127
* Constructs a new WebMvcSseServerTransportProvider instance.
123128
* @param jsonMapper The McpJsonMapper to use for JSON serialization/deserialization
@@ -135,18 +140,20 @@ public class WebMvcSseServerTransportProvider implements McpServerTransportProvi
135140
*/
136141
private WebMvcSseServerTransportProvider(McpJsonMapper jsonMapper, String baseUrl, String messageEndpoint,
137142
String sseEndpoint, Duration keepAliveInterval,
138-
McpTransportContextExtractor<ServerRequest> contextExtractor) {
143+
McpTransportContextExtractor<ServerRequest> contextExtractor, Duration sessionTimeout) {
139144
Assert.notNull(jsonMapper, "McpJsonMapper must not be null");
140145
Assert.notNull(baseUrl, "Message base URL must not be null");
141146
Assert.notNull(messageEndpoint, "Message endpoint must not be null");
142147
Assert.notNull(sseEndpoint, "SSE endpoint must not be null");
143148
Assert.notNull(contextExtractor, "Context extractor must not be null");
149+
Assert.notNull(sessionTimeout, "Session timeout must not be null");
144150

145151
this.jsonMapper = jsonMapper;
146152
this.baseUrl = baseUrl;
147153
this.messageEndpoint = messageEndpoint;
148154
this.sseEndpoint = sseEndpoint;
149155
this.contextExtractor = contextExtractor;
156+
this.sessionTimeout = sessionTimeout;
150157
this.routerFunction = RouterFunctions.route()
151158
.GET(this.sseEndpoint, this::handleSseConnection)
152159
.POST(this.messageEndpoint, this::handleMessage)
@@ -279,7 +286,7 @@ private ServerResponse handleSseConnection(ServerRequest request) {
279286
this.sessions.remove(sessionId);
280287
sseBuilder.error(e);
281288
}
282-
}, Duration.ZERO);
289+
}, this.sessionTimeout);
283290
}
284291

285292
/**
@@ -471,6 +478,8 @@ public static class Builder {
471478

472479
private Duration keepAliveInterval;
473480

481+
private Duration sessionTimeout = Duration.ZERO;
482+
474483
private McpTransportContextExtractor<ServerRequest> contextExtractor = (
475484
serverRequest) -> McpTransportContext.EMPTY;
476485

@@ -549,6 +558,11 @@ public Builder contextExtractor(McpTransportContextExtractor<ServerRequest> cont
549558
return this;
550559
}
551560

561+
public Builder sessionTimeout(Duration sessionTimeout) {
562+
this.sessionTimeout = sessionTimeout;
563+
return this;
564+
}
565+
552566
/**
553567
* Builds a new instance of WebMvcSseServerTransportProvider with the configured
554568
* settings.
@@ -560,7 +574,7 @@ public WebMvcSseServerTransportProvider build() {
560574
throw new IllegalStateException("MessageEndpoint must be set");
561575
}
562576
return new WebMvcSseServerTransportProvider(jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper,
563-
baseUrl, messageEndpoint, sseEndpoint, keepAliveInterval, contextExtractor);
577+
baseUrl, messageEndpoint, sseEndpoint, keepAliveInterval, contextExtractor, sessionTimeout);
564578
}
565579

566580
}

mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProviderTests.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@
2424
import org.springframework.web.servlet.function.RouterFunction;
2525
import org.springframework.web.servlet.function.ServerResponse;
2626

27+
import java.time.Duration;
28+
import java.util.concurrent.TimeUnit;
29+
2730
import static org.assertj.core.api.Assertions.assertThat;
2831

2932
/**
@@ -66,7 +69,7 @@ public void before() {
6669
}
6770

6871
@Test
69-
void validBaseUrl() {
72+
void validBaseUrl() throws InterruptedException {
7073
McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").build();
7174
try (var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0"))
7275
.build()) {
@@ -106,6 +109,7 @@ public WebMvcSseServerTransportProvider webMvcSseServerTransportProvider() {
106109
.sseEndpoint(WebMvcSseServerTransportProvider.DEFAULT_SSE_ENDPOINT)
107110
.jsonMapper(McpJsonMapper.getDefault())
108111
.contextExtractor(req -> McpTransportContext.EMPTY)
112+
.sessionTimeout(Duration.ofSeconds(1))
109113
.build();
110114
}
111115

0 commit comments

Comments
 (0)