Skip to content

Commit 9903b56

Browse files
committed
Add Origin header validation
- Fixes #695 - Does not implement Host header validation yet Signed-off-by: Daniel Garnier-Moiroux <git@garnier.wf>
1 parent a47920c commit 9903b56

File tree

17 files changed

+1804
-30
lines changed

17 files changed

+1804
-30
lines changed

conformance-tests/server-servlet/src/main/java/io/modelcontextprotocol/conformance/server/ConformanceServlet.java

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,34 @@
88
import io.modelcontextprotocol.server.McpServer;
99
import io.modelcontextprotocol.server.McpServerFeatures;
1010
import io.modelcontextprotocol.server.transport.HttpServletStreamableServerTransportProvider;
11-
import io.modelcontextprotocol.spec.McpSchema.*;
11+
import io.modelcontextprotocol.spec.McpSchema.AudioContent;
12+
import io.modelcontextprotocol.spec.McpSchema.BlobResourceContents;
13+
import io.modelcontextprotocol.spec.McpSchema.CallToolResult;
14+
import io.modelcontextprotocol.spec.McpSchema.CompleteResult;
15+
import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest;
16+
import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult;
17+
import io.modelcontextprotocol.spec.McpSchema.ElicitRequest;
18+
import io.modelcontextprotocol.spec.McpSchema.ElicitResult;
19+
import io.modelcontextprotocol.spec.McpSchema.EmbeddedResource;
20+
import io.modelcontextprotocol.spec.McpSchema.GetPromptResult;
21+
import io.modelcontextprotocol.spec.McpSchema.ImageContent;
22+
import io.modelcontextprotocol.spec.McpSchema.JsonSchema;
23+
import io.modelcontextprotocol.spec.McpSchema.LoggingLevel;
24+
import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification;
25+
import io.modelcontextprotocol.spec.McpSchema.ProgressNotification;
26+
import io.modelcontextprotocol.spec.McpSchema.Prompt;
27+
import io.modelcontextprotocol.spec.McpSchema.PromptArgument;
28+
import io.modelcontextprotocol.spec.McpSchema.PromptMessage;
29+
import io.modelcontextprotocol.spec.McpSchema.PromptReference;
30+
import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult;
31+
import io.modelcontextprotocol.spec.McpSchema.Resource;
32+
import io.modelcontextprotocol.spec.McpSchema.ResourceTemplate;
33+
import io.modelcontextprotocol.spec.McpSchema.Role;
34+
import io.modelcontextprotocol.spec.McpSchema.SamplingMessage;
35+
import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities;
36+
import io.modelcontextprotocol.spec.McpSchema.TextContent;
37+
import io.modelcontextprotocol.spec.McpSchema.TextResourceContents;
38+
import io.modelcontextprotocol.spec.McpSchema.Tool;
1239
import org.apache.catalina.Context;
1340
import org.apache.catalina.LifecycleException;
1441
import org.apache.catalina.startup.Tomcat;
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
/*
2+
* Copyright 2026-2026 the original author or authors.
3+
*/
4+
5+
package io.modelcontextprotocol.server.transport;
6+
7+
import java.util.ArrayList;
8+
import java.util.List;
9+
import java.util.Map;
10+
11+
import io.modelcontextprotocol.util.Assert;
12+
13+
/**
14+
* Default implementation of {@link ServerTransportSecurityValidator} that validates the
15+
* Origin header against a list of allowed origins.
16+
*
17+
* <p>
18+
* Supports exact matches and wildcard port patterns (e.g., "http://example.com:*").
19+
*
20+
* @author Daniel Garnier-Moiroux
21+
* @see ServerTransportSecurityValidator
22+
* @see ServerTransportSecurityException
23+
*/
24+
public class DefaultServerTransportSecurityValidator implements ServerTransportSecurityValidator {
25+
26+
private static final String ORIGIN_HEADER = "Origin";
27+
28+
private static final ServerTransportSecurityException INVALID_ORIGIN = new ServerTransportSecurityException(403,
29+
"Invalid Origin header");
30+
31+
private final List<String> allowedOrigins;
32+
33+
/**
34+
* Creates a new validator with the specified allowed origins.
35+
* @param allowedOrigins List of allowed origin patterns. Supports exact matches
36+
* (e.g., "http://example.com:8080") and wildcard ports (e.g., "http://example.com:*")
37+
*/
38+
public DefaultServerTransportSecurityValidator(List<String> allowedOrigins) {
39+
Assert.notNull(allowedOrigins, "allowedOrigins must not be null");
40+
this.allowedOrigins = allowedOrigins;
41+
}
42+
43+
@Override
44+
public void validateHeaders(Map<String, List<String>> headers) throws ServerTransportSecurityException {
45+
for (Map.Entry<String, List<String>> entry : headers.entrySet()) {
46+
if (ORIGIN_HEADER.equalsIgnoreCase(entry.getKey())) {
47+
List<String> values = entry.getValue();
48+
if (values != null && !values.isEmpty()) {
49+
validateOrigin(values.get(0));
50+
}
51+
break;
52+
}
53+
}
54+
}
55+
56+
/**
57+
* Validates a single origin value against the allowed origins. Subclasses can
58+
* override this method to customize origin validation logic.
59+
* @param origin The origin header value, or null if not present
60+
* @throws ServerTransportSecurityException if the origin is not allowed
61+
*/
62+
protected void validateOrigin(String origin) throws ServerTransportSecurityException {
63+
// Origin absent = no validation needed (same-origin request)
64+
if (origin == null || origin.isBlank()) {
65+
return;
66+
}
67+
68+
for (String allowed : allowedOrigins) {
69+
if (allowed.equals(origin)) {
70+
return;
71+
}
72+
else if (allowed.endsWith(":*")) {
73+
// Wildcard port pattern: "http://example.com:*"
74+
String baseOrigin = allowed.substring(0, allowed.length() - 2);
75+
if (origin.equals(baseOrigin) || origin.startsWith(baseOrigin + ":")) {
76+
return;
77+
}
78+
}
79+
80+
}
81+
82+
throw INVALID_ORIGIN;
83+
}
84+
85+
/**
86+
* Creates a new builder for constructing a DefaultServerTransportSecurityValidator.
87+
* @return A new builder instance
88+
*/
89+
public static Builder builder() {
90+
return new Builder();
91+
}
92+
93+
/**
94+
* Builder for creating instances of {@link DefaultServerTransportSecurityValidator}.
95+
*/
96+
public static class Builder {
97+
98+
private final List<String> allowedOrigins = new ArrayList<>();
99+
100+
/**
101+
* Adds an allowed origin pattern.
102+
* @param origin The origin to allow (e.g., "http://localhost:8080" or
103+
* "http://example.com:*")
104+
* @return this builder instance
105+
*/
106+
public Builder allowedOrigin(String origin) {
107+
this.allowedOrigins.add(origin);
108+
return this;
109+
}
110+
111+
/**
112+
* Adds multiple allowed origin patterns.
113+
* @param origins The origins to allow
114+
* @return this builder instance
115+
*/
116+
public Builder allowedOrigins(List<String> origins) {
117+
Assert.notNull(origins, "origins must not be null");
118+
this.allowedOrigins.addAll(origins);
119+
return this;
120+
}
121+
122+
/**
123+
* Builds the validator instance.
124+
* @return A new DefaultServerTransportSecurityValidator
125+
*/
126+
public DefaultServerTransportSecurityValidator build() {
127+
return new DefaultServerTransportSecurityValidator(allowedOrigins);
128+
}
129+
130+
}
131+
132+
}

mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java

Lines changed: 62 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2024 - 2024 the original author or authors.
2+
* Copyright 2024 - 2026 the original author or authors.
33
*/
44

55
package io.modelcontextprotocol.server.transport;
@@ -8,6 +8,9 @@
88
import java.io.IOException;
99
import java.io.PrintWriter;
1010
import java.time.Duration;
11+
import java.util.Collections;
12+
import java.util.Enumeration;
13+
import java.util.HashMap;
1114
import java.util.List;
1215
import java.util.Map;
1316
import java.util.UUID;
@@ -142,6 +145,11 @@ public class HttpServletSseServerTransportProvider extends HttpServlet implement
142145
*/
143146
private KeepAliveScheduler keepAliveScheduler;
144147

148+
/**
149+
* Security validator for validating HTTP requests.
150+
*/
151+
private final ServerTransportSecurityValidator securityValidator;
152+
145153
/**
146154
* Creates a new HttpServletSseServerTransportProvider instance with a custom SSE
147155
* endpoint.
@@ -153,23 +161,25 @@ public class HttpServletSseServerTransportProvider extends HttpServlet implement
153161
* @param keepAliveInterval The interval for keep-alive pings, or null to disable
154162
* keep-alive functionality
155163
* @param contextExtractor The extractor for transport context from the request.
156-
* @deprecated Use the builder {@link #builder()} instead for better configuration
157-
* options.
164+
* @param securityValidator The security validator for validating HTTP requests.
158165
*/
159166
private HttpServletSseServerTransportProvider(McpJsonMapper jsonMapper, String baseUrl, String messageEndpoint,
160167
String sseEndpoint, Duration keepAliveInterval,
161-
McpTransportContextExtractor<HttpServletRequest> contextExtractor) {
168+
McpTransportContextExtractor<HttpServletRequest> contextExtractor,
169+
ServerTransportSecurityValidator securityValidator) {
162170

163171
Assert.notNull(jsonMapper, "JsonMapper must not be null");
164172
Assert.notNull(messageEndpoint, "messageEndpoint must not be null");
165173
Assert.notNull(sseEndpoint, "sseEndpoint must not be null");
166174
Assert.notNull(contextExtractor, "Context extractor must not be null");
175+
Assert.notNull(securityValidator, "Security validator must not be null");
167176

168177
this.jsonMapper = jsonMapper;
169178
this.baseUrl = baseUrl;
170179
this.messageEndpoint = messageEndpoint;
171180
this.sseEndpoint = sseEndpoint;
172181
this.contextExtractor = contextExtractor;
182+
this.securityValidator = securityValidator;
173183

174184
if (keepAliveInterval != null) {
175185

@@ -246,6 +256,15 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response)
246256
return;
247257
}
248258

259+
try {
260+
Map<String, List<String>> headers = extractHeaders(request);
261+
this.securityValidator.validateHeaders(headers);
262+
}
263+
catch (ServerTransportSecurityException e) {
264+
response.sendError(e.getStatusCode(), e.getMessage());
265+
return;
266+
}
267+
249268
response.setContentType("text/event-stream");
250269
response.setCharacterEncoding(UTF_8);
251270
response.setHeader("Cache-Control", "no-cache");
@@ -311,6 +330,15 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)
311330
return;
312331
}
313332

333+
try {
334+
Map<String, List<String>> headers = extractHeaders(request);
335+
this.securityValidator.validateHeaders(headers);
336+
}
337+
catch (ServerTransportSecurityException e) {
338+
response.sendError(e.getStatusCode(), e.getMessage());
339+
return;
340+
}
341+
314342
// Get the session ID from the request parameter
315343
String sessionId = request.getParameter("sessionId");
316344
if (sessionId == null) {
@@ -411,6 +439,21 @@ private void sendEvent(PrintWriter writer, String eventType, String data) throws
411439
}
412440
}
413441

442+
/**
443+
* Extracts all headers from the HTTP servlet request into a map.
444+
* @param request The HTTP servlet request
445+
* @return A map of header names to their values
446+
*/
447+
private Map<String, List<String>> extractHeaders(HttpServletRequest request) {
448+
Map<String, List<String>> headers = new HashMap<>();
449+
Enumeration<String> names = request.getHeaderNames();
450+
while (names.hasMoreElements()) {
451+
String name = names.nextElement();
452+
headers.put(name, Collections.list(request.getHeaders(name)));
453+
}
454+
return headers;
455+
}
456+
414457
/**
415458
* Cleans up resources when the servlet is being destroyed.
416459
* <p>
@@ -547,6 +590,8 @@ public static class Builder {
547590

548591
private Duration keepAliveInterval;
549592

593+
private ServerTransportSecurityValidator securityValidator = ServerTransportSecurityValidator.NOOP;
594+
550595
/**
551596
* Sets the JsonMapper implementation to use for serialization/deserialization. If
552597
* not specified, a JacksonJsonMapper will be created from the configured
@@ -621,6 +666,18 @@ public Builder keepAliveInterval(Duration keepAliveInterval) {
621666
return this;
622667
}
623668

669+
/**
670+
* Sets the security validator for validating HTTP requests.
671+
* @param securityValidator The security validator to use. Must not be null.
672+
* @return This builder instance
673+
* @throws IllegalArgumentException if securityValidator is null
674+
*/
675+
public Builder securityValidator(ServerTransportSecurityValidator securityValidator) {
676+
Assert.notNull(securityValidator, "Security validator must not be null");
677+
this.securityValidator = securityValidator;
678+
return this;
679+
}
680+
624681
/**
625682
* Builds a new instance of HttpServletSseServerTransportProvider with the
626683
* configured settings.
@@ -633,7 +690,7 @@ public HttpServletSseServerTransportProvider build() {
633690
}
634691
return new HttpServletSseServerTransportProvider(
635692
jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, baseUrl, messageEndpoint, sseEndpoint,
636-
keepAliveInterval, contextExtractor);
693+
keepAliveInterval, contextExtractor, securityValidator);
637694
}
638695

639696
}

0 commit comments

Comments
 (0)