Skip to content

Commit b2bbcd1

Browse files
committed
Move gate to validate_request to avoid calling functions unnecessarily
1 parent fb3ce68 commit b2bbcd1

File tree

1 file changed

+35
-37
lines changed

1 file changed

+35
-37
lines changed

src/mcp/server/transport_security.py

Lines changed: 35 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -15,48 +15,45 @@ class TransportSecuritySettings(BaseModel):
1515
These settings help protect against DNS rebinding attacks by validating
1616
incoming request headers.
1717
"""
18-
18+
1919
enable_dns_rebinding_protection: bool = Field(
2020
default=True,
21-
description="Enable DNS rebinding protection (recommended for production)"
21+
description="Enable DNS rebinding protection (recommended for production)",
2222
)
23-
23+
2424
allowed_hosts: list[str] = Field(
2525
default=[],
26-
description="List of allowed Host header values. Only applies when " +
27-
"enable_dns_rebinding_protection is True."
26+
description="List of allowed Host header values. Only applies when "
27+
+ "enable_dns_rebinding_protection is True.",
2828
)
29-
29+
3030
allowed_origins: list[str] = Field(
3131
default=[],
32-
description="List of allowed Origin header values. Only applies when " +
33-
"enable_dns_rebinding_protection is True."
32+
description="List of allowed Origin header values. Only applies when "
33+
+ "enable_dns_rebinding_protection is True.",
3434
)
3535

3636

3737
class TransportSecurityMiddleware:
3838
"""Middleware to enforce DNS rebinding protection for MCP transport endpoints."""
39-
39+
4040
def __init__(self, settings: TransportSecuritySettings | None = None):
4141
# If not specified, disable DNS rebinding protection by default
4242
# for backwards compatibility
4343
self.settings = settings or TransportSecuritySettings(
4444
enable_dns_rebinding_protection=False
4545
)
46-
46+
4747
def _validate_host(self, host: str | None) -> bool:
4848
"""Validate the Host header against allowed values."""
49-
if not self.settings.enable_dns_rebinding_protection:
50-
return True
51-
5249
if not host:
5350
logger.warning("Missing Host header in request")
5451
return False
55-
52+
5653
# Check exact match first
5754
if host in self.settings.allowed_hosts:
5855
return True
59-
56+
6057
# Check wildcard port patterns
6158
for allowed in self.settings.allowed_hosts:
6259
if allowed.endswith(":*"):
@@ -65,23 +62,20 @@ def _validate_host(self, host: str | None) -> bool:
6562
# Check if the actual host starts with base host and has a port
6663
if host.startswith(base_host + ":"):
6764
return True
68-
65+
6966
logger.warning(f"Invalid Host header: {host}")
7067
return False
71-
68+
7269
def _validate_origin(self, origin: str | None) -> bool:
7370
"""Validate the Origin header against allowed values."""
74-
if not self.settings.enable_dns_rebinding_protection:
75-
return True
76-
7771
# Origin can be absent for same-origin requests
7872
if not origin:
7973
return True
80-
74+
8175
# Check exact match first
8276
if origin in self.settings.allowed_origins:
8377
return True
84-
78+
8579
# Check wildcard port patterns
8680
for allowed in self.settings.allowed_origins:
8781
if allowed.endswith(":*"):
@@ -90,44 +84,48 @@ def _validate_origin(self, origin: str | None) -> bool:
9084
# Check if the actual origin starts with base origin and has a port
9185
if origin.startswith(base_origin + ":"):
9286
return True
93-
87+
9488
logger.warning(f"Invalid Origin header: {origin}")
9589
return False
96-
90+
9791
def _validate_content_type(self, content_type: str | None) -> bool:
9892
"""Validate the Content-Type header for POST requests."""
9993
if not content_type:
10094
logger.warning("Missing Content-Type header in POST request")
10195
return False
102-
96+
10397
# Content-Type must start with application/json
10498
if not content_type.lower().startswith("application/json"):
10599
logger.warning(f"Invalid Content-Type header: {content_type}")
106100
return False
107-
101+
108102
return True
109-
103+
110104
async def validate_request(
111105
self, request: Request, is_post: bool = False
112106
) -> Response | None:
113107
"""Validate request headers for DNS rebinding protection.
114-
108+
115109
Returns None if validation passes, or an error Response if validation fails.
116110
"""
111+
# Always validate Content-Type for POST requests
112+
if is_post:
113+
content_type = request.headers.get("content-type")
114+
if not self._validate_content_type(content_type):
115+
return Response("Invalid Content-Type header", status_code=400)
116+
117+
# Skip remaining validation if DNS rebinding protection is disabled
118+
if not self.settings.enable_dns_rebinding_protection:
119+
return None
120+
117121
# Validate Host header
118122
host = request.headers.get("host")
119123
if not self._validate_host(host):
120124
return Response("Invalid Host header", status_code=400)
121-
125+
122126
# Validate Origin header
123127
origin = request.headers.get("origin")
124128
if not self._validate_origin(origin):
125129
return Response("Invalid Origin header", status_code=400)
126-
127-
# Validate Content-Type for POST requests
128-
if is_post:
129-
content_type = request.headers.get("content-type")
130-
if not self._validate_content_type(content_type):
131-
return Response("Invalid Content-Type header", status_code=400)
132-
133-
return None
130+
131+
return None

0 commit comments

Comments
 (0)