@@ -15,48 +15,45 @@ class TransportSecuritySettings(BaseModel):
1515 These settings help protect against DNS rebinding attacks by validating
1616 incoming request headers.
1717 """
18-
18+
1919 enable_dns_rebinding_protection : bool = Field (
2020 default = True ,
21- description = "Enable DNS rebinding protection (recommended for production)"
21+ description = "Enable DNS rebinding protection (recommended for production)" ,
2222 )
23-
23+
2424 allowed_hosts : list [str ] = Field (
2525 default = [],
26- description = "List of allowed Host header values. Only applies when " +
27- "enable_dns_rebinding_protection is True."
26+ description = "List of allowed Host header values. Only applies when "
27+ + "enable_dns_rebinding_protection is True." ,
2828 )
29-
29+
3030 allowed_origins : list [str ] = Field (
3131 default = [],
32- description = "List of allowed Origin header values. Only applies when " +
33- "enable_dns_rebinding_protection is True."
32+ description = "List of allowed Origin header values. Only applies when "
33+ + "enable_dns_rebinding_protection is True." ,
3434 )
3535
3636
3737class TransportSecurityMiddleware :
3838 """Middleware to enforce DNS rebinding protection for MCP transport endpoints."""
39-
39+
4040 def __init__ (self , settings : TransportSecuritySettings | None = None ):
4141 # If not specified, disable DNS rebinding protection by default
4242 # for backwards compatibility
4343 self .settings = settings or TransportSecuritySettings (
4444 enable_dns_rebinding_protection = False
4545 )
46-
46+
4747 def _validate_host (self , host : str | None ) -> bool :
4848 """Validate the Host header against allowed values."""
49- if not self .settings .enable_dns_rebinding_protection :
50- return True
51-
5249 if not host :
5350 logger .warning ("Missing Host header in request" )
5451 return False
55-
52+
5653 # Check exact match first
5754 if host in self .settings .allowed_hosts :
5855 return True
59-
56+
6057 # Check wildcard port patterns
6158 for allowed in self .settings .allowed_hosts :
6259 if allowed .endswith (":*" ):
@@ -65,23 +62,20 @@ def _validate_host(self, host: str | None) -> bool:
6562 # Check if the actual host starts with base host and has a port
6663 if host .startswith (base_host + ":" ):
6764 return True
68-
65+
6966 logger .warning (f"Invalid Host header: { host } " )
7067 return False
71-
68+
7269 def _validate_origin (self , origin : str | None ) -> bool :
7370 """Validate the Origin header against allowed values."""
74- if not self .settings .enable_dns_rebinding_protection :
75- return True
76-
7771 # Origin can be absent for same-origin requests
7872 if not origin :
7973 return True
80-
74+
8175 # Check exact match first
8276 if origin in self .settings .allowed_origins :
8377 return True
84-
78+
8579 # Check wildcard port patterns
8680 for allowed in self .settings .allowed_origins :
8781 if allowed .endswith (":*" ):
@@ -90,44 +84,48 @@ def _validate_origin(self, origin: str | None) -> bool:
9084 # Check if the actual origin starts with base origin and has a port
9185 if origin .startswith (base_origin + ":" ):
9286 return True
93-
87+
9488 logger .warning (f"Invalid Origin header: { origin } " )
9589 return False
96-
90+
9791 def _validate_content_type (self , content_type : str | None ) -> bool :
9892 """Validate the Content-Type header for POST requests."""
9993 if not content_type :
10094 logger .warning ("Missing Content-Type header in POST request" )
10195 return False
102-
96+
10397 # Content-Type must start with application/json
10498 if not content_type .lower ().startswith ("application/json" ):
10599 logger .warning (f"Invalid Content-Type header: { content_type } " )
106100 return False
107-
101+
108102 return True
109-
103+
110104 async def validate_request (
111105 self , request : Request , is_post : bool = False
112106 ) -> Response | None :
113107 """Validate request headers for DNS rebinding protection.
114-
108+
115109 Returns None if validation passes, or an error Response if validation fails.
116110 """
111+ # Always validate Content-Type for POST requests
112+ if is_post :
113+ content_type = request .headers .get ("content-type" )
114+ if not self ._validate_content_type (content_type ):
115+ return Response ("Invalid Content-Type header" , status_code = 400 )
116+
117+ # Skip remaining validation if DNS rebinding protection is disabled
118+ if not self .settings .enable_dns_rebinding_protection :
119+ return None
120+
117121 # Validate Host header
118122 host = request .headers .get ("host" )
119123 if not self ._validate_host (host ):
120124 return Response ("Invalid Host header" , status_code = 400 )
121-
125+
122126 # Validate Origin header
123127 origin = request .headers .get ("origin" )
124128 if not self ._validate_origin (origin ):
125129 return Response ("Invalid Origin header" , status_code = 400 )
126-
127- # Validate Content-Type for POST requests
128- if is_post :
129- content_type = request .headers .get ("content-type" )
130- if not self ._validate_content_type (content_type ):
131- return Response ("Invalid Content-Type header" , status_code = 400 )
132-
133- return None
130+
131+ return None
0 commit comments