Skip to content

Commit b8832da

Browse files
committed
DNS rebinding protection: check host header
Signed-off-by: Daniel Garnier-Moiroux <git@garnier.wf>
1 parent 5ed6063 commit b8832da

File tree

5 files changed

+554
-119
lines changed

5 files changed

+554
-119
lines changed

mcp-core/src/main/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidator.java

Lines changed: 86 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,11 @@
1212

1313
/**
1414
* Default implementation of {@link ServerTransportSecurityValidator} that validates the
15-
* Origin header against a list of allowed origins.
15+
* Origin and Host headers against lists of allowed values.
1616
*
1717
* <p>
18-
* Supports exact matches and wildcard port patterns (e.g., "http://example.com:*").
18+
* Supports exact matches and wildcard port patterns (e.g., "http://example.com:*" for
19+
* origins, "example.com:*" for hosts).
1920
*
2021
* @author Daniel Garnier-Moiroux
2122
* @see ServerTransportSecurityValidator
@@ -25,32 +26,55 @@ public class DefaultServerTransportSecurityValidator implements ServerTransportS
2526

2627
private static final String ORIGIN_HEADER = "Origin";
2728

29+
private static final String HOST_HEADER = "Host";
30+
2831
private static final ServerTransportSecurityException INVALID_ORIGIN = new ServerTransportSecurityException(403,
2932
"Invalid Origin header");
3033

34+
private static final ServerTransportSecurityException INVALID_HOST = new ServerTransportSecurityException(421,
35+
"Invalid Host header");
36+
3137
private final List<String> allowedOrigins;
3238

39+
private final List<String> allowedHosts;
40+
3341
/**
34-
* Creates a new validator with the specified allowed origins.
42+
* Creates a new validator with the specified allowed origins and hosts.
3543
* @param allowedOrigins List of allowed origin patterns. Supports exact matches
3644
* (e.g., "http://example.com:8080") and wildcard ports (e.g., "http://example.com:*")
45+
* @param allowedHosts List of allowed host patterns. Supports exact matches (e.g.,
46+
* "example.com:8080") and wildcard ports (e.g., "example.com:*")
3747
*/
38-
public DefaultServerTransportSecurityValidator(List<String> allowedOrigins) {
48+
public DefaultServerTransportSecurityValidator(List<String> allowedOrigins, List<String> allowedHosts) {
3949
Assert.notNull(allowedOrigins, "allowedOrigins must not be null");
50+
Assert.notNull(allowedHosts, "allowedHosts must not be null");
4051
this.allowedOrigins = allowedOrigins;
52+
this.allowedHosts = allowedHosts;
4153
}
4254

4355
@Override
4456
public void validateHeaders(Map<String, List<String>> headers) throws ServerTransportSecurityException {
57+
boolean missingHost = true;
4558
for (Map.Entry<String, List<String>> entry : headers.entrySet()) {
4659
if (ORIGIN_HEADER.equalsIgnoreCase(entry.getKey())) {
4760
List<String> values = entry.getValue();
48-
if (values != null && !values.isEmpty()) {
49-
validateOrigin(values.get(0));
61+
if (values == null || values.isEmpty()) {
62+
throw INVALID_ORIGIN;
63+
}
64+
validateOrigin(values.get(0));
65+
}
66+
else if (HOST_HEADER.equalsIgnoreCase(entry.getKey())) {
67+
missingHost = false;
68+
List<String> values = entry.getValue();
69+
if (values == null || values.isEmpty()) {
70+
throw INVALID_HOST;
5071
}
51-
break;
72+
validateHost(values.get(0));
5273
}
5374
}
75+
if (!allowedHosts.isEmpty() && missingHost) {
76+
throw INVALID_HOST;
77+
}
5478
}
5579

5680
/**
@@ -82,6 +106,37 @@ else if (allowed.endsWith(":*")) {
82106
throw INVALID_ORIGIN;
83107
}
84108

109+
/**
110+
* Validates a single host value against the allowed hosts.
111+
* @param host The host header value, or null if not present
112+
* @throws ServerTransportSecurityException if the host is not allowed
113+
*/
114+
private void validateHost(String host) throws ServerTransportSecurityException {
115+
if (allowedHosts.isEmpty()) {
116+
return;
117+
}
118+
119+
// Host is required
120+
if (host == null || host.isBlank()) {
121+
throw INVALID_HOST;
122+
}
123+
124+
for (String allowed : allowedHosts) {
125+
if (allowed.equals(host)) {
126+
return;
127+
}
128+
else if (allowed.endsWith(":*")) {
129+
// Wildcard port pattern: "example.com:*"
130+
String baseHost = allowed.substring(0, allowed.length() - 2);
131+
if (host.equals(baseHost) || host.startsWith(baseHost + ":")) {
132+
return;
133+
}
134+
}
135+
}
136+
137+
throw INVALID_HOST;
138+
}
139+
85140
/**
86141
* Creates a new builder for constructing a DefaultServerTransportSecurityValidator.
87142
* @return A new builder instance
@@ -97,6 +152,8 @@ public static class Builder {
97152

98153
private final List<String> allowedOrigins = new ArrayList<>();
99154

155+
private final List<String> allowedHosts = new ArrayList<>();
156+
100157
/**
101158
* Adds an allowed origin pattern.
102159
* @param origin The origin to allow (e.g., "http://localhost:8080" or
@@ -119,12 +176,33 @@ public Builder allowedOrigins(List<String> origins) {
119176
return this;
120177
}
121178

179+
/**
180+
* Adds an allowed host pattern.
181+
* @param host The host to allow (e.g., "localhost:8080" or "example.com:*")
182+
* @return this builder instance
183+
*/
184+
public Builder allowedHost(String host) {
185+
this.allowedHosts.add(host);
186+
return this;
187+
}
188+
189+
/**
190+
* Adds multiple allowed host patterns.
191+
* @param hosts The hosts to allow
192+
* @return this builder instance
193+
*/
194+
public Builder allowedHosts(List<String> hosts) {
195+
Assert.notNull(hosts, "hosts must not be null");
196+
this.allowedHosts.addAll(hosts);
197+
return this;
198+
}
199+
122200
/**
123201
* Builds the validator instance.
124202
* @return A new DefaultServerTransportSecurityValidator
125203
*/
126204
public DefaultServerTransportSecurityValidator build() {
127-
return new DefaultServerTransportSecurityValidator(allowedOrigins);
205+
return new DefaultServerTransportSecurityValidator(allowedOrigins, allowedHosts);
128206
}
129207

130208
}

0 commit comments

Comments
 (0)