Skip to content

Commit aeab631

Browse files
committed
Clean up
1 parent 29a5e3a commit aeab631

File tree

4 files changed

+8
-187
lines changed

4 files changed

+8
-187
lines changed

src/mcp/server/transport_security.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,14 @@ class TransportSecuritySettings(BaseModel):
2323

2424
allowed_hosts: list[str] = Field(
2525
default=[],
26-
description="List of allowed Host header values. If None, all hosts "
27-
"are allowed when protection is disabled, or only localhost/127.0.0.1 "
28-
"when enabled."
26+
description="List of allowed Host header values. Only applies when " +
27+
"enable_dns_rebinding_protection is True."
2928
)
3029

3130
allowed_origins: list[str] = Field(
3231
default=[],
33-
description="List of allowed Origin header values. If None, all "
34-
"origins are allowed when protection is disabled, or only localhost "
35-
"origins when enabled."
32+
description="List of allowed Origin header values. Only applies when " +
33+
"enable_dns_rebinding_protection is True."
3634
)
3735

3836

tests/server/fastmcp/test_integration.py

Lines changed: 0 additions & 173 deletions
Original file line numberDiff line numberDiff line change
@@ -475,179 +475,6 @@ async def test_fastmcp_without_auth(server: None, server_url: str) -> None:
475475
assert tool_result.content[0].text == "Echo: hello"
476476

477477

478-
def make_fastmcp_with_context_app():
479-
"""Create a FastMCP server that can access request context."""
480-
from mcp.server.transport_security import TransportSecuritySettings
481-
482-
transport_security = TransportSecuritySettings(
483-
allowed_hosts=["127.0.0.1:*", "localhost:*"],
484-
allowed_origins=["http://127.0.0.1:*", "http://localhost:*"]
485-
)
486-
mcp = FastMCP(name="ContextServer", transport_security=transport_security)
487-
488-
# Tool that echoes request headers
489-
@mcp.tool(description="Echo request headers from context")
490-
def echo_headers(ctx: Context[Any, Any, Request]) -> str:
491-
"""Returns the request headers as JSON."""
492-
headers_info = {}
493-
if ctx.request_context.request:
494-
# Now the type system knows request is a Starlette Request object
495-
headers_info = dict(ctx.request_context.request.headers)
496-
return json.dumps(headers_info)
497-
498-
# Tool that returns full request context
499-
@mcp.tool(description="Echo request context with custom data")
500-
def echo_context(custom_request_id: str, ctx: Context[Any, Any, Request]) -> str:
501-
"""Returns request context including headers and custom data."""
502-
context_data = {
503-
"custom_request_id": custom_request_id,
504-
"headers": {},
505-
"method": None,
506-
"path": None,
507-
}
508-
if ctx.request_context.request:
509-
request = ctx.request_context.request
510-
context_data["headers"] = dict(request.headers)
511-
context_data["method"] = request.method
512-
context_data["path"] = request.url.path
513-
return json.dumps(context_data)
514-
515-
# Create the SSE app
516-
app = mcp.sse_app()
517-
return mcp, app
518-
519-
520-
def run_context_server(server_port: int) -> None:
521-
"""Run the context-aware FastMCP server."""
522-
_, app = make_fastmcp_with_context_app()
523-
server = uvicorn.Server(
524-
config=uvicorn.Config(
525-
app=app, host="127.0.0.1", port=server_port, log_level="error"
526-
)
527-
)
528-
print(f"Starting context server on port {server_port}")
529-
server.run()
530-
531-
532-
@pytest.fixture()
533-
def context_aware_server(server_port: int) -> Generator[None, None, None]:
534-
"""Start the context-aware server in a separate process."""
535-
proc = multiprocessing.Process(
536-
target=run_context_server, args=(server_port,), daemon=True
537-
)
538-
print("Starting context-aware server process")
539-
proc.start()
540-
541-
# Wait for server to be running
542-
max_attempts = 20
543-
attempt = 0
544-
print("Waiting for context-aware server to start")
545-
while attempt < max_attempts:
546-
try:
547-
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
548-
s.connect(("127.0.0.1", server_port))
549-
break
550-
except ConnectionRefusedError:
551-
time.sleep(0.1)
552-
attempt += 1
553-
else:
554-
raise RuntimeError(
555-
f"Context server failed to start after {max_attempts} attempts"
556-
)
557-
558-
yield
559-
560-
print("Killing context-aware server")
561-
proc.kill()
562-
proc.join(timeout=2)
563-
if proc.is_alive():
564-
print("Context server process failed to terminate")
565-
566-
567-
@pytest.mark.anyio
568-
async def test_fast_mcp_with_request_context(
569-
context_aware_server: None, server_url: str
570-
) -> None:
571-
"""Test that FastMCP properly propagates request context to tools."""
572-
# Test with custom headers
573-
custom_headers = {
574-
"Authorization": "Bearer fastmcp-test-token",
575-
"X-Custom-Header": "fastmcp-value",
576-
"X-Request-Id": "req-123",
577-
}
578-
579-
async with sse_client(server_url + "/sse", headers=custom_headers) as streams:
580-
async with ClientSession(*streams) as session:
581-
# Initialize the session
582-
result = await session.initialize()
583-
assert isinstance(result, InitializeResult)
584-
assert result.serverInfo.name == "ContextServer"
585-
586-
# Test 1: Call tool that echoes headers
587-
headers_result = await session.call_tool("echo_headers", {})
588-
assert len(headers_result.content) == 1
589-
assert isinstance(headers_result.content[0], TextContent)
590-
591-
headers_data = json.loads(headers_result.content[0].text)
592-
assert headers_data.get("authorization") == "Bearer fastmcp-test-token"
593-
assert headers_data.get("x-custom-header") == "fastmcp-value"
594-
assert headers_data.get("x-request-id") == "req-123"
595-
596-
# Test 2: Call tool that returns full context
597-
context_result = await session.call_tool(
598-
"echo_context", {"custom_request_id": "test-123"}
599-
)
600-
assert len(context_result.content) == 1
601-
assert isinstance(context_result.content[0], TextContent)
602-
603-
context_data = json.loads(context_result.content[0].text)
604-
assert context_data["custom_request_id"] == "test-123"
605-
assert (
606-
context_data["headers"].get("authorization")
607-
== "Bearer fastmcp-test-token"
608-
)
609-
assert context_data["method"] == "POST" #
610-
611-
612-
@pytest.mark.anyio
613-
async def test_fast_mcp_request_context_isolation(
614-
context_aware_server: None, server_url: str
615-
) -> None:
616-
"""Test that request contexts are isolated between different FastMCP clients."""
617-
contexts = []
618-
619-
# Create multiple clients with different headers
620-
for i in range(3):
621-
headers = {
622-
"Authorization": f"Bearer token-{i}",
623-
"X-Request-Id": f"fastmcp-req-{i}",
624-
"X-Custom-Value": f"value-{i}",
625-
}
626-
627-
async with sse_client(server_url + "/sse", headers=headers) as streams:
628-
async with ClientSession(*streams) as session:
629-
await session.initialize()
630-
631-
# Call the tool that returns context
632-
tool_result = await session.call_tool(
633-
"echo_context", {"custom_request_id": f"test-req-{i}"}
634-
)
635-
636-
# Parse and store the result
637-
assert len(tool_result.content) == 1
638-
assert isinstance(tool_result.content[0], TextContent)
639-
context_data = json.loads(tool_result.content[0].text)
640-
contexts.append(context_data)
641-
642-
# Verify each request had its own isolated context
643-
assert len(contexts) == 3
644-
for i, ctx in enumerate(contexts):
645-
assert ctx["custom_request_id"] == f"test-req-{i}"
646-
assert ctx["headers"].get("authorization") == f"Bearer token-{i}"
647-
assert ctx["headers"].get("x-request-id") == f"fastmcp-req-{i}"
648-
assert ctx["headers"].get("x-custom-value") == f"value-{i}"
649-
650-
651478
@pytest.mark.anyio
652479
async def test_fastmcp_streamable_http(
653480
streamable_http_server: None, http_server_url: str

tests/shared/test_sse.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,6 @@ async def handle_call_tool(name: str, args: dict) -> list[TextContent]:
8383
# Test fixtures
8484
def make_server_app() -> Starlette:
8585
"""Create test Starlette app with SSE transport"""
86-
from mcp.server.transport_security import TransportSecuritySettings
8786
# Configure security with allowed hosts/origins for testing
8887
security_settings = TransportSecuritySettings(
8988
allowed_hosts=["127.0.0.1:*", "localhost:*"],

tests/shared/test_streamable_http.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
StreamId,
3737
)
3838
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
39+
from mcp.server.transport_security import TransportSecuritySettings
3940
from mcp.shared.context import RequestContext
4041
from mcp.shared.exceptions import McpError
4142
from mcp.shared.message import (
@@ -227,7 +228,6 @@ def create_app(
227228
server = ServerTest()
228229

229230
# Create the session manager
230-
from mcp.server.transport_security import TransportSecuritySettings
231231
security_settings = TransportSecuritySettings(
232232
allowed_hosts=["127.0.0.1:*", "localhost:*"],
233233
allowed_origins=["http://127.0.0.1:*", "http://localhost:*"]
@@ -446,12 +446,9 @@ def test_content_type_validation(basic_server, basic_server_url):
446446
},
447447
data="This is not JSON",
448448
)
449-
# May return 400 (security middleware) or 415 (transport validation)
450-
assert response.status_code in (400, 415)
451-
assert any(
452-
msg in response.text
453-
for msg in ["Invalid Content-Type", "Unsupported Media Type"]
454-
)
449+
450+
assert response.status_code == 400
451+
assert "Invalid Content-Type" in response.text
455452

456453

457454
def test_json_validation(basic_server, basic_server_url):

0 commit comments

Comments
 (0)