Skip to content

Commit 9cc92aa

Browse files
committed
fix: adds StreamableHttpServerTransportProvide
1 parent 028ad6d commit 9cc92aa

File tree

3 files changed

+376
-0
lines changed

3 files changed

+376
-0
lines changed
Lines changed: 276 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,276 @@
1+
package io.modelcontextprotocol.server.transport;
2+
3+
import com.fasterxml.jackson.core.type.TypeReference;
4+
import com.fasterxml.jackson.databind.JsonNode;
5+
import com.fasterxml.jackson.databind.ObjectMapper;
6+
import io.modelcontextprotocol.spec.McpSchema;
7+
import io.modelcontextprotocol.spec.McpServerSession;
8+
import io.modelcontextprotocol.spec.McpServerTransport;
9+
import io.modelcontextprotocol.spec.McpServerTransportProvider;
10+
import io.modelcontextprotocol.spec.McpSession;
11+
import io.modelcontextprotocol.spec.StatelessMcpSession;
12+
import jakarta.servlet.AsyncContext;
13+
import jakarta.servlet.ServletException;
14+
import jakarta.servlet.http.HttpServlet;
15+
import jakarta.servlet.http.HttpServletRequest;
16+
import jakarta.servlet.http.HttpServletResponse;
17+
import org.slf4j.Logger;
18+
import org.slf4j.LoggerFactory;
19+
import reactor.core.publisher.Flux;
20+
import reactor.core.publisher.Mono;
21+
22+
import java.io.IOException;
23+
import java.io.InputStream;
24+
import java.io.OutputStream;
25+
import java.nio.charset.StandardCharsets;
26+
import java.time.Duration;
27+
import java.util.ArrayList;
28+
import java.util.Arrays;
29+
import java.util.List;
30+
import java.util.Map;
31+
import java.util.Optional;
32+
import java.util.Set;
33+
import java.util.UUID;
34+
import java.util.concurrent.ConcurrentHashMap;
35+
36+
/**
37+
* @author Aliaksei_Darafeyeu
38+
*/
39+
public class StreamableHttpServerTransportProvider extends HttpServlet implements McpServerTransportProvider {
40+
/**
41+
* Logger for this class
42+
*/
43+
private static final Logger logger = LoggerFactory.getLogger(StreamableHttpServerTransportProvider.class);
44+
45+
private static final String MCP_SESSION_ID = "Mcp-Session-Id";
46+
private static final String APPLICATION_JSON = "application/json";
47+
private static final String TEXT_EVENT_STREAM = "text/event-stream";
48+
49+
private McpServerSession.Factory sessionFactory;
50+
51+
private final ObjectMapper objectMapper;
52+
53+
private final McpServerTransportProvider legacyTransportProvider;
54+
55+
private final Set<String> allowedOrigins;
56+
57+
/**
58+
* Map of active client sessions, keyed by session ID
59+
*/
60+
private final Map<String, McpSession> sessions = new ConcurrentHashMap<>();
61+
62+
public StreamableHttpServerTransportProvider(final ObjectMapper objectMapper, final McpServerTransportProvider legacyTransportProvider, final Set<String> allowedOrigins) {
63+
this.objectMapper = objectMapper;
64+
this.legacyTransportProvider = legacyTransportProvider;
65+
this.allowedOrigins = allowedOrigins;
66+
}
67+
68+
@Override
69+
public void setSessionFactory(McpServerSession.Factory sessionFactory) {
70+
this.sessionFactory = sessionFactory;
71+
}
72+
73+
@Override
74+
public Mono<Void> notifyClients(String method, Object params) {
75+
if (sessions.isEmpty()) {
76+
logger.debug("No active sessions to broadcast message to");
77+
return Mono.empty();
78+
}
79+
80+
logger.debug("Attempting to broadcast message to {} active sessions", sessions.size());
81+
return Flux.fromIterable(sessions.values())
82+
.flatMap(session -> session.sendNotification(method, params)
83+
.doOnError(e -> logger.error("Failed to send message to session {}: {}", session.getId(), e.getMessage()))
84+
.onErrorComplete())
85+
.then();
86+
}
87+
88+
@Override
89+
public Mono<Void> closeGracefully() {
90+
logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size());
91+
return Flux.fromIterable(sessions.values()).flatMap(McpSession::closeGracefully).then();
92+
}
93+
94+
@Override
95+
public void destroy() {
96+
closeGracefully().block();
97+
super.destroy();
98+
}
99+
100+
@Override
101+
protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
102+
// 1. Origin header check
103+
String origin = req.getHeader("Origin");
104+
if (origin != null && !allowedOrigins.contains(origin)) {
105+
resp.sendError(HttpServletResponse.SC_FORBIDDEN, "Origin not allowed");
106+
return;
107+
}
108+
109+
// 2. Accept header routing
110+
final String accept = Optional.ofNullable(req.getHeader("Accept")).orElse("");
111+
final List<String> acceptTypes = Arrays.stream(accept.split(","))
112+
.map(String::trim)
113+
.toList();
114+
115+
// todo!!!!
116+
if (!acceptTypes.contains(APPLICATION_JSON) && !acceptTypes.contains(TEXT_EVENT_STREAM)) {
117+
if (legacyTransportProvider instanceof HttpServletSseServerTransportProvider legacy) {
118+
legacy.doPost(req, resp);
119+
} else {
120+
resp.sendError(HttpServletResponse.SC_NOT_ACCEPTABLE, "Legacy transport not available");
121+
}
122+
return;
123+
}
124+
125+
// 3. Enable async
126+
final AsyncContext asyncContext = req.startAsync();
127+
asyncContext.setTimeout(0);
128+
129+
// resp
130+
resp.setStatus(HttpServletResponse.SC_OK);
131+
resp.setCharacterEncoding("UTF-8");
132+
133+
final McpServerTransport transport = new StreamableHttpServerTransport(resp.getOutputStream(), objectMapper);
134+
final McpSession session = getOrCreateSession(req.getHeader(MCP_SESSION_ID), transport);
135+
if (!"stateless".equals(session.getId())) {
136+
resp.setHeader(MCP_SESSION_ID, session.getId());
137+
}
138+
final Flux<McpSchema.JSONRPCMessage> messages = parseRequestBodyAsStream(req);
139+
140+
if (accept.contains(TEXT_EVENT_STREAM)) {
141+
// TODO: Handle streaming JSON-RPC over HTTP
142+
resp.setContentType(TEXT_EVENT_STREAM);
143+
resp.setHeader("Connection", "keep-alive");
144+
145+
messages.flatMap(session::handle)
146+
.doOnError(e -> sendError(resp, 500, "Streaming failed: " + e.getMessage()))
147+
.then(transport.closeGracefully())
148+
.subscribe();
149+
} else if (accept.contains(APPLICATION_JSON)) {
150+
// TODO: Handle traditional JSON-RPC response
151+
resp.setContentType(APPLICATION_JSON);
152+
153+
messages.flatMap(session::handle)
154+
.collectList()
155+
.flatMap(responses -> {
156+
try {
157+
String json = new ObjectMapper().writeValueAsString(
158+
responses.size() == 1 ? responses.get(0) : responses
159+
);
160+
resp.getWriter().write(json);
161+
return transport.closeGracefully();
162+
} catch (IOException e) {
163+
return Mono.error(e);
164+
}
165+
})
166+
.doOnError(e -> sendError(resp, 500, "JSON response failed: " + e.getMessage()))
167+
.subscribe();
168+
169+
} else {
170+
resp.sendError(HttpServletResponse.SC_NOT_ACCEPTABLE, "Unsupported Accept header");
171+
}
172+
}
173+
174+
@Override
175+
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
176+
if (legacyTransportProvider instanceof HttpServletSseServerTransportProvider legacy) {
177+
legacy.doGet(req, resp);
178+
} else {
179+
resp.sendError(HttpServletResponse.SC_NOT_ACCEPTABLE, "Legacy transport not available");
180+
}
181+
}
182+
183+
protected void doDelete(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
184+
final String sessionId = req.getHeader("mcp-session-id");
185+
if (sessionId == null || !sessions.containsKey(sessionId)) {
186+
resp.sendError(HttpServletResponse.SC_NOT_FOUND, "Session not found");
187+
return;
188+
}
189+
190+
final McpSession session = sessions.remove(sessionId);
191+
session.closeGracefully().subscribe();
192+
resp.setStatus(HttpServletResponse.SC_NO_CONTENT);
193+
}
194+
195+
// todo:!!!
196+
private Flux<McpSchema.JSONRPCMessage> parseRequestBodyAsStream(final HttpServletRequest req) {
197+
return Mono.fromCallable(() -> {
198+
try (final InputStream inputStream = req.getInputStream()) {
199+
final JsonNode node = objectMapper.readTree(inputStream);
200+
if (node.isArray()) {
201+
final List<McpSchema.JSONRPCMessage> messages = new ArrayList<>();
202+
for (final JsonNode item : node) {
203+
messages.add(objectMapper.treeToValue(item, McpSchema.JSONRPCMessage.class));
204+
}
205+
return messages;
206+
} else if (node.isObject()) {
207+
return List.of(objectMapper.treeToValue(node, McpSchema.JSONRPCMessage.class));
208+
} else {
209+
throw new IllegalArgumentException("Invalid JSON-RPC request: not object or array");
210+
}
211+
}
212+
}).flatMapMany(Flux::fromIterable);
213+
}
214+
215+
private McpSession getOrCreateSession(final String sessionId, final McpServerTransport transport) {
216+
if (sessionId != null && sessionFactory != null) {
217+
// Reuse or track sessions if you support that; for now, we just create new ones
218+
return sessions.get(sessionId);
219+
} else if (sessionFactory != null) {
220+
final String newSessionId = UUID.randomUUID().toString();
221+
return sessions.put(newSessionId, sessionFactory.create(transport));
222+
} else {
223+
return new StatelessMcpSession(transport);
224+
}
225+
}
226+
227+
private void sendError(final HttpServletResponse resp, final int code, final String msg) {
228+
try {
229+
resp.sendError(code, msg);
230+
} catch (IOException ignored) {
231+
logger.debug("Exception during send error");
232+
}
233+
}
234+
235+
public static class StreamableHttpServerTransport implements McpServerTransport {
236+
private final ObjectMapper objectMapper;
237+
private final OutputStream outputStream;
238+
239+
public StreamableHttpServerTransport(final OutputStream outputStream, final ObjectMapper objectMapper) {
240+
this.objectMapper = objectMapper;
241+
this.outputStream = outputStream;
242+
}
243+
244+
@Override
245+
public Mono<Void> sendMessage(final McpSchema.JSONRPCMessage message) {
246+
return Mono.fromRunnable(() -> {
247+
try {
248+
String json = objectMapper.writeValueAsString(message);
249+
outputStream.write(json.getBytes(StandardCharsets.UTF_8));
250+
outputStream.write('\n');
251+
outputStream.flush();
252+
} catch (IOException e) {
253+
throw new RuntimeException("Failed to send message", e);
254+
}
255+
});
256+
}
257+
258+
@Override
259+
public <T> T unmarshalFrom(Object data, TypeReference<T> typeRef) {
260+
return objectMapper.convertValue(data, typeRef);
261+
}
262+
263+
@Override
264+
public Mono<Void> closeGracefully() {
265+
return Mono.fromRunnable(() -> {
266+
try {
267+
outputStream.flush();
268+
outputStream.close();
269+
} catch (IOException e) {
270+
// ignore or log
271+
}
272+
});
273+
}
274+
}
275+
276+
}

mcp/src/main/java/io/modelcontextprotocol/spec/McpSession.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,24 @@
2525
*/
2626
public interface McpSession {
2727

28+
/**
29+
* Retrieve the session id.
30+
* @return session id
31+
*/
32+
String getId();
33+
34+
/**
35+
* Called by the {@link McpServerTransportProvider} once the session is determined.
36+
* The purpose of this method is to dispatch the message to an appropriate handler as
37+
* specified by the MCP server implementation
38+
* ({@link io.modelcontextprotocol.server.McpAsyncServer} or
39+
* {@link io.modelcontextprotocol.server.McpSyncServer}) via
40+
* {@link McpServerSession.Factory} that the server creates.
41+
* @param message the incoming JSON-RPC message
42+
* @return a Mono that completes when the message is processed
43+
*/
44+
Mono<Void> handle(McpSchema.JSONRPCMessage message);
45+
2846
/**
2947
* Sends a request to the model counterparty and expects a response of type T.
3048
*
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
package io.modelcontextprotocol.spec;
2+
3+
import com.fasterxml.jackson.core.type.TypeReference;
4+
import reactor.core.publisher.Mono;
5+
6+
import java.util.UUID;
7+
8+
/**
9+
* @author Aliaksei_Darafeyeu
10+
*/
11+
public class StatelessMcpSession implements McpSession {
12+
13+
private final McpTransport transport;
14+
15+
public StatelessMcpSession(final McpTransport transport) {
16+
this.transport = transport;
17+
}
18+
19+
@Override
20+
public String getId() {
21+
return "stateless";
22+
}
23+
24+
@Override
25+
public Mono<Void> handle(McpSchema.JSONRPCMessage message) {
26+
if (message instanceof McpSchema.JSONRPCRequest request) {
27+
// Stateless sessions do not support incoming requests
28+
McpSchema.JSONRPCResponse errorResponse = new McpSchema.JSONRPCResponse(
29+
McpSchema.JSONRPC_VERSION,
30+
request.id(),
31+
null,
32+
new McpSchema.JSONRPCResponse.JSONRPCError(
33+
McpSchema.ErrorCodes.METHOD_NOT_FOUND,
34+
"Stateless session does not handle requests",
35+
null
36+
)
37+
);
38+
return transport.sendMessage(errorResponse);
39+
}
40+
else if (message instanceof McpSchema.JSONRPCNotification notification) {
41+
// Stateless session ignores incoming notifications
42+
return Mono.empty();
43+
}
44+
else if (message instanceof McpSchema.JSONRPCResponse response) {
45+
// No request/response correlation in stateless mode
46+
return Mono.empty();
47+
}
48+
else {
49+
return Mono.empty();
50+
}
51+
}
52+
53+
54+
@Override
55+
public <T> Mono<T> sendRequest(String method, Object requestParams, TypeReference<T> typeRef) {
56+
// Stateless = no request/response correlation
57+
String requestId = UUID.randomUUID().toString();
58+
McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(
59+
McpSchema.JSONRPC_VERSION, method, requestId, requestParams
60+
);
61+
62+
return Mono.defer(() -> Mono.from(this.transport.sendMessage(request)).then(Mono.error(new IllegalStateException("Stateless session cannot receive responses")))
63+
);
64+
}
65+
66+
@Override
67+
public Mono<Void> sendNotification(String method, Object params) {
68+
McpSchema.JSONRPCNotification notification =
69+
new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, method, params);
70+
return Mono.from(this.transport.sendMessage(notification));
71+
}
72+
73+
@Override
74+
public Mono<Void> closeGracefully() {
75+
return this.transport.closeGracefully();
76+
}
77+
78+
@Override
79+
public void close() {
80+
this.closeGracefully().subscribe();
81+
}
82+
}

0 commit comments

Comments
 (0)