1+ """DNS rebinding protection for MCP server transports."""
2+
3+ import logging
4+
5+ from pydantic import BaseModel , Field
6+ from starlette .requests import Request
7+ from starlette .responses import Response
8+
9+ logger = logging .getLogger (__name__ )
10+
11+
12+ class TransportSecuritySettings (BaseModel ):
13+ """Settings for MCP transport security features.
14+
15+ These settings help protect against DNS rebinding attacks by validating
16+ incoming request headers.
17+ """
18+
19+ enable_dns_rebinding_protection : bool = Field (
20+ default = True ,
21+ description = "Enable DNS rebinding protection (recommended for production)"
22+ )
23+
24+ allowed_hosts : list [str ] = Field (
25+ default = [],
26+ description = "List of allowed Host header values. If None, all hosts "
27+ "are allowed when protection is disabled, or only localhost/127.0.0.1 "
28+ "when enabled."
29+ )
30+
31+ allowed_origins : list [str ] = Field (
32+ default = [],
33+ description = "List of allowed Origin header values. If None, all "
34+ "origins are allowed when protection is disabled, or only localhost "
35+ "origins when enabled."
36+ )
37+
38+
39+ class TransportSecurityMiddleware :
40+ """Middleware to enforce DNS rebinding protection for MCP transport endpoints."""
41+
42+ def __init__ (self , settings : TransportSecuritySettings | None = None ):
43+ # If not specified, disable DNS rebinding protection by default
44+ # for backwards compatibility
45+ self .settings = settings or TransportSecuritySettings (
46+ enable_dns_rebinding_protection = False
47+ )
48+
49+ def _validate_host (self , host : str | None ) -> bool :
50+ """Validate the Host header against allowed values."""
51+ if not self .settings .enable_dns_rebinding_protection :
52+ return True
53+
54+ if not host :
55+ logger .warning ("Missing Host header in request" )
56+ return False
57+
58+ # Check exact match first
59+ if host in self .settings .allowed_hosts :
60+ return True
61+
62+ # Check wildcard port patterns
63+ for allowed in self .settings .allowed_hosts :
64+ if allowed .endswith (":*" ):
65+ # Extract base host from pattern
66+ base_host = allowed [:- 2 ]
67+ # Check if the actual host starts with base host and has a port
68+ if host .startswith (base_host + ":" ):
69+ return True
70+
71+ logger .warning (f"Invalid Host header: { host } " )
72+ return False
73+
74+ def _validate_origin (self , origin : str | None ) -> bool :
75+ """Validate the Origin header against allowed values."""
76+ if not self .settings .enable_dns_rebinding_protection :
77+ return True
78+
79+ # Origin can be absent for same-origin requests
80+ if not origin :
81+ return True
82+
83+ # Check exact match first
84+ if origin in self .settings .allowed_origins :
85+ return True
86+
87+ # Check wildcard port patterns
88+ for allowed in self .settings .allowed_origins :
89+ if allowed .endswith (":*" ):
90+ # Extract base origin from pattern
91+ base_origin = allowed [:- 2 ]
92+ # Check if the actual origin starts with base origin and has a port
93+ if origin .startswith (base_origin + ":" ):
94+ return True
95+
96+ logger .warning (f"Invalid Origin header: { origin } " )
97+ return False
98+
99+ def _validate_content_type (self , content_type : str | None ) -> bool :
100+ """Validate the Content-Type header for POST requests."""
101+ if not content_type :
102+ logger .warning ("Missing Content-Type header in POST request" )
103+ return False
104+
105+ # Content-Type must start with application/json
106+ if not content_type .lower ().startswith ("application/json" ):
107+ logger .warning (f"Invalid Content-Type header: { content_type } " )
108+ return False
109+
110+ return True
111+
112+ async def validate_request (
113+ self , request : Request , is_post : bool = False
114+ ) -> Response | None :
115+ """Validate request headers for DNS rebinding protection.
116+
117+ Returns None if validation passes, or an error Response if validation fails.
118+ """
119+ # Validate Host header
120+ host = request .headers .get ("host" )
121+ if not self ._validate_host (host ):
122+ return Response ("Invalid Host header" , status_code = 400 )
123+
124+ # Validate Origin header
125+ origin = request .headers .get ("origin" )
126+ if not self ._validate_origin (origin ):
127+ return Response ("Invalid Origin header" , status_code = 400 )
128+
129+ # Validate Content-Type for POST requests
130+ if is_post :
131+ content_type = request .headers .get ("content-type" )
132+ if not self ._validate_content_type (content_type ):
133+ return Response ("Invalid Content-Type header" , status_code = 400 )
134+
135+ return None
0 commit comments