Skip to content

Commit 366b3c4

Browse files
committed
Add support for DNS rebinding protections
1 parent 9dad266 commit 366b3c4

File tree

10 files changed

+1276
-15
lines changed

10 files changed

+1276
-15
lines changed

src/mcp/server/fastmcp/server.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
from mcp.server.stdio import stdio_server
5050
from mcp.server.streamable_http import EventStore
5151
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
52+
from mcp.server.transport_security import TransportSecuritySettings
5253
from mcp.shared.context import LifespanContextT, RequestContext
5354
from mcp.types import (
5455
AnyFunction,
@@ -119,6 +120,9 @@ class Settings(BaseSettings, Generic[LifespanResultT]):
119120
) = Field(None, description="Lifespan context manager")
120121

121122
auth: AuthSettings | None = None
123+
124+
# Transport security settings (DNS rebinding protection)
125+
transport_security: TransportSecuritySettings | None = None
122126

123127

124128
def lifespan_wrapper(
@@ -670,6 +674,7 @@ def sse_app(self, mount_path: str | None = None) -> Starlette:
670674

671675
sse = SseServerTransport(
672676
normalized_message_endpoint,
677+
security_settings=self.settings.transport_security,
673678
)
674679

675680
async def handle_sse(scope: Scope, receive: Receive, send: Send):
@@ -777,6 +782,7 @@ def streamable_http_app(self) -> Starlette:
777782
event_store=self._event_store,
778783
json_response=self.settings.json_response,
779784
stateless=self.settings.stateless_http, # Use the stateless setting
785+
security_settings=self.settings.transport_security,
780786
)
781787

782788
# Create the ASGI handler

src/mcp/server/sse.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,10 @@ async def handle_sse(request):
5252
from starlette.types import Receive, Scope, Send
5353

5454
import mcp.types as types
55+
from mcp.server.transport_security import (
56+
TransportSecurityMiddleware,
57+
TransportSecuritySettings,
58+
)
5559
from mcp.shared.message import SessionMessage
5660

5761
logger = logging.getLogger(__name__)
@@ -71,16 +75,24 @@ class SseServerTransport:
7175

7276
_endpoint: str
7377
_read_stream_writers: dict[UUID, MemoryObjectSendStream[SessionMessage | Exception]]
78+
_security: TransportSecurityMiddleware
7479

75-
def __init__(self, endpoint: str) -> None:
80+
def __init__(
81+
self, endpoint: str, security_settings: TransportSecuritySettings | None = None
82+
) -> None:
7683
"""
7784
Creates a new SSE server transport, which will direct the client to POST
7885
messages to the relative or absolute URL given.
86+
87+
Args:
88+
endpoint: The relative or absolute URL for POST messages.
89+
security_settings: Optional security settings for DNS rebinding protection.
7990
"""
8091

8192
super().__init__()
8293
self._endpoint = endpoint
8394
self._read_stream_writers = {}
95+
self._security = TransportSecurityMiddleware(security_settings)
8496
logger.debug(f"SseServerTransport initialized with endpoint: {endpoint}")
8597

8698
@asynccontextmanager
@@ -89,6 +101,13 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send):
89101
logger.error("connect_sse received non-HTTP request")
90102
raise ValueError("connect_sse can only handle HTTP requests")
91103

104+
# Validate request headers for DNS rebinding protection
105+
request = Request(scope, receive)
106+
error_response = await self._security.validate_request(request, is_post=False)
107+
if error_response:
108+
await error_response(scope, receive, send)
109+
raise ValueError("Request validation failed")
110+
92111
logger.debug("Setting up SSE connection")
93112
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
94113
read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]
@@ -169,6 +188,11 @@ async def handle_post_message(
169188
) -> None:
170189
logger.debug("Handling POST message")
171190
request = Request(scope, receive)
191+
192+
# Validate request headers for DNS rebinding protection
193+
error_response = await self._security.validate_request(request, is_post=True)
194+
if error_response:
195+
return await error_response(scope, receive, send)
172196

173197
session_id_param = request.query_params.get("session_id")
174198
if session_id_param is None:

src/mcp/server/streamable_http.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@
2424
from starlette.responses import Response
2525
from starlette.types import Receive, Scope, Send
2626

27+
from mcp.server.transport_security import (
28+
TransportSecurityMiddleware,
29+
TransportSecuritySettings,
30+
)
2731
from mcp.shared.message import ServerMessageMetadata, SessionMessage
2832
from mcp.types import (
2933
INTERNAL_ERROR,
@@ -131,12 +135,14 @@ class StreamableHTTPServerTransport:
131135
_read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] | None = None
132136
_write_stream: MemoryObjectSendStream[SessionMessage] | None = None
133137
_write_stream_reader: MemoryObjectReceiveStream[SessionMessage] | None = None
138+
_security: TransportSecurityMiddleware
134139

135140
def __init__(
136141
self,
137142
mcp_session_id: str | None,
138143
is_json_response_enabled: bool = False,
139144
event_store: EventStore | None = None,
145+
security_settings: TransportSecuritySettings | None = None,
140146
) -> None:
141147
"""
142148
Initialize a new StreamableHTTP server transport.
@@ -149,6 +155,7 @@ def __init__(
149155
event_store: Event store for resumability support. If provided,
150156
resumability will be enabled, allowing clients to
151157
reconnect and resume messages.
158+
security_settings: Optional security settings for DNS rebinding protection.
152159
153160
Raises:
154161
ValueError: If the session ID contains invalid characters.
@@ -163,6 +170,7 @@ def __init__(
163170
self.mcp_session_id = mcp_session_id
164171
self.is_json_response_enabled = is_json_response_enabled
165172
self._event_store = event_store
173+
self._security = TransportSecurityMiddleware(security_settings)
166174
self._request_streams: dict[
167175
RequestId,
168176
tuple[
@@ -260,6 +268,14 @@ async def _clean_up_memory_streams(self, request_id: RequestId) -> None:
260268
async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None:
261269
"""Application entry point that handles all HTTP requests"""
262270
request = Request(scope, receive)
271+
272+
# Validate request headers for DNS rebinding protection
273+
is_post = request.method == "POST"
274+
error_response = await self._security.validate_request(request, is_post=is_post)
275+
if error_response:
276+
await error_response(scope, receive, send)
277+
return
278+
263279
if self._terminated:
264280
# If the session has been terminated, return 404 Not Found
265281
response = self._create_error_response(

src/mcp/server/streamable_http_manager.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
EventStore,
2323
StreamableHTTPServerTransport,
2424
)
25+
from mcp.server.transport_security import TransportSecuritySettings
2526

2627
logger = logging.getLogger(__name__)
2728

@@ -60,11 +61,13 @@ def __init__(
6061
event_store: EventStore | None = None,
6162
json_response: bool = False,
6263
stateless: bool = False,
64+
security_settings: TransportSecuritySettings | None = None,
6365
):
6466
self.app = app
6567
self.event_store = event_store
6668
self.json_response = json_response
6769
self.stateless = stateless
70+
self.security_settings = security_settings
6871

6972
# Session tracking (only used if not stateless)
7073
self._session_creation_lock = anyio.Lock()
@@ -162,6 +165,7 @@ async def _handle_stateless_request(
162165
mcp_session_id=None, # No session tracking in stateless mode
163166
is_json_response_enabled=self.json_response,
164167
event_store=None, # No event store in stateless mode
168+
security_settings=self.security_settings,
165169
)
166170

167171
# Start server in a new task
@@ -222,6 +226,7 @@ async def _handle_stateful_request(
222226
mcp_session_id=new_session_id,
223227
is_json_response_enabled=self.json_response,
224228
event_store=self.event_store, # May be None (no resumability)
229+
security_settings=self.security_settings,
225230
)
226231

227232
assert http_transport.mcp_session_id is not None
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
"""DNS rebinding protection for MCP server transports."""
2+
3+
import logging
4+
5+
from pydantic import BaseModel, Field
6+
from starlette.requests import Request
7+
from starlette.responses import Response
8+
9+
logger = logging.getLogger(__name__)
10+
11+
12+
class TransportSecuritySettings(BaseModel):
13+
"""Settings for MCP transport security features.
14+
15+
These settings help protect against DNS rebinding attacks by validating
16+
incoming request headers.
17+
"""
18+
19+
enable_dns_rebinding_protection: bool = Field(
20+
default=True,
21+
description="Enable DNS rebinding protection (recommended for production)"
22+
)
23+
24+
allowed_hosts: list[str] = Field(
25+
default=[],
26+
description="List of allowed Host header values. If None, all hosts "
27+
"are allowed when protection is disabled, or only localhost/127.0.0.1 "
28+
"when enabled."
29+
)
30+
31+
allowed_origins: list[str] = Field(
32+
default=[],
33+
description="List of allowed Origin header values. If None, all "
34+
"origins are allowed when protection is disabled, or only localhost "
35+
"origins when enabled."
36+
)
37+
38+
39+
class TransportSecurityMiddleware:
40+
"""Middleware to enforce DNS rebinding protection for MCP transport endpoints."""
41+
42+
def __init__(self, settings: TransportSecuritySettings | None = None):
43+
# If not specified, disable DNS rebinding protection by default
44+
# for backwards compatibility
45+
self.settings = settings or TransportSecuritySettings(
46+
enable_dns_rebinding_protection=False
47+
)
48+
49+
def _validate_host(self, host: str | None) -> bool:
50+
"""Validate the Host header against allowed values."""
51+
if not self.settings.enable_dns_rebinding_protection:
52+
return True
53+
54+
if not host:
55+
logger.warning("Missing Host header in request")
56+
return False
57+
58+
# Check exact match first
59+
if host in self.settings.allowed_hosts:
60+
return True
61+
62+
# Check wildcard port patterns
63+
for allowed in self.settings.allowed_hosts:
64+
if allowed.endswith(":*"):
65+
# Extract base host from pattern
66+
base_host = allowed[:-2]
67+
# Check if the actual host starts with base host and has a port
68+
if host.startswith(base_host + ":"):
69+
return True
70+
71+
logger.warning(f"Invalid Host header: {host}")
72+
return False
73+
74+
def _validate_origin(self, origin: str | None) -> bool:
75+
"""Validate the Origin header against allowed values."""
76+
if not self.settings.enable_dns_rebinding_protection:
77+
return True
78+
79+
# Origin can be absent for same-origin requests
80+
if not origin:
81+
return True
82+
83+
# Check exact match first
84+
if origin in self.settings.allowed_origins:
85+
return True
86+
87+
# Check wildcard port patterns
88+
for allowed in self.settings.allowed_origins:
89+
if allowed.endswith(":*"):
90+
# Extract base origin from pattern
91+
base_origin = allowed[:-2]
92+
# Check if the actual origin starts with base origin and has a port
93+
if origin.startswith(base_origin + ":"):
94+
return True
95+
96+
logger.warning(f"Invalid Origin header: {origin}")
97+
return False
98+
99+
def _validate_content_type(self, content_type: str | None) -> bool:
100+
"""Validate the Content-Type header for POST requests."""
101+
if not content_type:
102+
logger.warning("Missing Content-Type header in POST request")
103+
return False
104+
105+
# Content-Type must start with application/json
106+
if not content_type.lower().startswith("application/json"):
107+
logger.warning(f"Invalid Content-Type header: {content_type}")
108+
return False
109+
110+
return True
111+
112+
async def validate_request(
113+
self, request: Request, is_post: bool = False
114+
) -> Response | None:
115+
"""Validate request headers for DNS rebinding protection.
116+
117+
Returns None if validation passes, or an error Response if validation fails.
118+
"""
119+
# Validate Host header
120+
host = request.headers.get("host")
121+
if not self._validate_host(host):
122+
return Response("Invalid Host header", status_code=400)
123+
124+
# Validate Origin header
125+
origin = request.headers.get("origin")
126+
if not self._validate_origin(origin):
127+
return Response("Invalid Origin header", status_code=400)
128+
129+
# Validate Content-Type for POST requests
130+
if is_post:
131+
content_type = request.headers.get("content-type")
132+
if not self._validate_content_type(content_type):
133+
return Response("Invalid Content-Type header", status_code=400)
134+
135+
return None

0 commit comments

Comments
 (0)