11/*
2- * Copyright 2024 - 2024 the original author or authors.
2+ * Copyright 2024 - 2026 the original author or authors.
33 */
44
55package io .modelcontextprotocol .server .transport ;
88import java .io .IOException ;
99import java .io .PrintWriter ;
1010import java .time .Duration ;
11+ import java .util .Collections ;
12+ import java .util .Enumeration ;
13+ import java .util .HashMap ;
1114import java .util .List ;
1215import java .util .Map ;
1316import java .util .UUID ;
@@ -142,6 +145,11 @@ public class HttpServletSseServerTransportProvider extends HttpServlet implement
142145 */
143146 private KeepAliveScheduler keepAliveScheduler ;
144147
148+ /**
149+ * Security validator for validating HTTP requests.
150+ */
151+ private final ServerTransportSecurityValidator securityValidator ;
152+
145153 /**
146154 * Creates a new HttpServletSseServerTransportProvider instance with a custom SSE
147155 * endpoint.
@@ -153,23 +161,25 @@ public class HttpServletSseServerTransportProvider extends HttpServlet implement
153161 * @param keepAliveInterval The interval for keep-alive pings, or null to disable
154162 * keep-alive functionality
155163 * @param contextExtractor The extractor for transport context from the request.
156- * @deprecated Use the builder {@link #builder()} instead for better configuration
157- * options.
164+ * @param securityValidator The security validator for validating HTTP requests.
158165 */
159166 private HttpServletSseServerTransportProvider (McpJsonMapper jsonMapper , String baseUrl , String messageEndpoint ,
160167 String sseEndpoint , Duration keepAliveInterval ,
161- McpTransportContextExtractor <HttpServletRequest > contextExtractor ) {
168+ McpTransportContextExtractor <HttpServletRequest > contextExtractor ,
169+ ServerTransportSecurityValidator securityValidator ) {
162170
163171 Assert .notNull (jsonMapper , "JsonMapper must not be null" );
164172 Assert .notNull (messageEndpoint , "messageEndpoint must not be null" );
165173 Assert .notNull (sseEndpoint , "sseEndpoint must not be null" );
166174 Assert .notNull (contextExtractor , "Context extractor must not be null" );
175+ Assert .notNull (securityValidator , "Security validator must not be null" );
167176
168177 this .jsonMapper = jsonMapper ;
169178 this .baseUrl = baseUrl ;
170179 this .messageEndpoint = messageEndpoint ;
171180 this .sseEndpoint = sseEndpoint ;
172181 this .contextExtractor = contextExtractor ;
182+ this .securityValidator = securityValidator ;
173183
174184 if (keepAliveInterval != null ) {
175185
@@ -246,6 +256,15 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response)
246256 return ;
247257 }
248258
259+ try {
260+ Map <String , List <String >> headers = extractHeaders (request );
261+ this .securityValidator .validateHeaders (headers );
262+ }
263+ catch (ServerTransportSecurityException e ) {
264+ response .sendError (e .getStatusCode (), e .getMessage ());
265+ return ;
266+ }
267+
249268 response .setContentType ("text/event-stream" );
250269 response .setCharacterEncoding (UTF_8 );
251270 response .setHeader ("Cache-Control" , "no-cache" );
@@ -311,6 +330,15 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)
311330 return ;
312331 }
313332
333+ try {
334+ Map <String , List <String >> headers = extractHeaders (request );
335+ this .securityValidator .validateHeaders (headers );
336+ }
337+ catch (ServerTransportSecurityException e ) {
338+ response .sendError (e .getStatusCode (), e .getMessage ());
339+ return ;
340+ }
341+
314342 // Get the session ID from the request parameter
315343 String sessionId = request .getParameter ("sessionId" );
316344 if (sessionId == null ) {
@@ -411,6 +439,21 @@ private void sendEvent(PrintWriter writer, String eventType, String data) throws
411439 }
412440 }
413441
442+ /**
443+ * Extracts all headers from the HTTP servlet request into a map.
444+ * @param request The HTTP servlet request
445+ * @return A map of header names to their values
446+ */
447+ private Map <String , List <String >> extractHeaders (HttpServletRequest request ) {
448+ Map <String , List <String >> headers = new HashMap <>();
449+ Enumeration <String > names = request .getHeaderNames ();
450+ while (names .hasMoreElements ()) {
451+ String name = names .nextElement ();
452+ headers .put (name , Collections .list (request .getHeaders (name )));
453+ }
454+ return headers ;
455+ }
456+
414457 /**
415458 * Cleans up resources when the servlet is being destroyed.
416459 * <p>
@@ -547,6 +590,8 @@ public static class Builder {
547590
548591 private Duration keepAliveInterval ;
549592
593+ private ServerTransportSecurityValidator securityValidator = ServerTransportSecurityValidator .NOOP ;
594+
550595 /**
551596 * Sets the JsonMapper implementation to use for serialization/deserialization. If
552597 * not specified, a JacksonJsonMapper will be created from the configured
@@ -621,6 +666,18 @@ public Builder keepAliveInterval(Duration keepAliveInterval) {
621666 return this ;
622667 }
623668
669+ /**
670+ * Sets the security validator for validating HTTP requests.
671+ * @param securityValidator The security validator to use. Must not be null.
672+ * @return This builder instance
673+ * @throws IllegalArgumentException if securityValidator is null
674+ */
675+ public Builder securityValidator (ServerTransportSecurityValidator securityValidator ) {
676+ Assert .notNull (securityValidator , "Security validator must not be null" );
677+ this .securityValidator = securityValidator ;
678+ return this ;
679+ }
680+
624681 /**
625682 * Builds a new instance of HttpServletSseServerTransportProvider with the
626683 * configured settings.
@@ -633,7 +690,7 @@ public HttpServletSseServerTransportProvider build() {
633690 }
634691 return new HttpServletSseServerTransportProvider (
635692 jsonMapper == null ? McpJsonMapper .getDefault () : jsonMapper , baseUrl , messageEndpoint , sseEndpoint ,
636- keepAliveInterval , contextExtractor );
693+ keepAliveInterval , contextExtractor , securityValidator );
637694 }
638695
639696 }
0 commit comments