Skip to content

Commit caaef7c

Browse files
committed
split resource server and auth server
Signed-off-by: Jesse Sanford <108698+jessesanford@users.noreply.github.com>
1 parent cb646af commit caaef7c

File tree

2 files changed

+322
-0
lines changed

2 files changed

+322
-0
lines changed
Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
# pyright: reportMissingImports=false
2+
import logging
3+
import os
4+
import time
5+
6+
from dotenv import load_dotenv # type: ignore
7+
from mcp.server.auth.provider import AccessToken, OAuthToken
8+
from mcp.server.auth.providers.transparent_proxy import (
9+
ProxySettings, # type: ignore
10+
TransparentOAuthProxyProvider,
11+
ProxyTokenHandler,
12+
)
13+
from mcp.server.auth.routes import cors_middleware, create_auth_routes
14+
from mcp.server.auth.settings import ClientRegistrationOptions
15+
from pydantic import AnyHttpUrl
16+
from starlette.applications import Starlette
17+
from starlette.requests import Request # type: ignore
18+
from starlette.responses import JSONResponse, Response
19+
from starlette.routing import Route
20+
from uvicorn import Config, Server
21+
22+
# Load environment variables from .env if present
23+
load_dotenv()
24+
25+
# Configure logging after .env so LOG_LEVEL can come from environment
26+
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
27+
28+
logging.basicConfig(
29+
level=LOG_LEVEL,
30+
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
31+
datefmt="%Y-%m-%d %H:%M:%S",
32+
)
33+
34+
# Dedicated logger for this server module
35+
logger = logging.getLogger("proxy_oauth.auth_server")
36+
37+
# Suppress noisy INFO messages from the FastMCP low-level server unless we are
38+
# explicitly running in DEBUG mode. These logs (e.g. "Processing request of type
39+
# ListToolsRequest") are helpful for debugging but clutter normal output.
40+
41+
_mcp_lowlevel_logger = logging.getLogger("mcp.server.lowlevel.server")
42+
if LOG_LEVEL == "DEBUG":
43+
# In full debug mode, allow the library to emit its detailed logs
44+
_mcp_lowlevel_logger.setLevel(logging.DEBUG)
45+
else:
46+
# Otherwise, only warnings and above
47+
_mcp_lowlevel_logger.setLevel(logging.WARNING)
48+
49+
# ----------------------------------------------------------------------------
50+
# Environment configuration
51+
# ----------------------------------------------------------------------------
52+
# Load and validate settings from the environment (uses .env automatically)
53+
settings = ProxySettings.load()
54+
55+
# Upstream endpoints (fully-qualified URLs)
56+
UPSTREAM_AUTHORIZE: str = str(settings.upstream_authorize)
57+
UPSTREAM_TOKEN: str = str(settings.upstream_token)
58+
UPSTREAM_JWKS_URI = settings.jwks_uri
59+
# Derive base URL from the authorize endpoint for convenience / tests
60+
UPSTREAM_BASE: str = UPSTREAM_AUTHORIZE.rsplit("/", 1)[0]
61+
62+
# Client credentials & defaults
63+
CLIENT_ID: str = settings.client_id or "demo-client-id"
64+
CLIENT_SECRET = settings.client_secret
65+
DEFAULT_SCOPE: str = settings.default_scope
66+
67+
# Optional audience passthrough (not part of ProxySettings yet)
68+
AUDIENCE = os.getenv("PROXY_AUDIENCE")
69+
70+
# Metadata URL (only used if we need to fetch from upstream)
71+
UPSTREAM_METADATA = f"{UPSTREAM_BASE}/.well-known/oauth-authorization-server"
72+
73+
# ---------------------------------------------------------------------------
74+
# Logging helpers
75+
# ---------------------------------------------------------------------------
76+
77+
78+
def _mask_secret(secret: str | None) -> str | None: # noqa: D401
79+
"""Return a masked version of the given secret.
80+
81+
The first and last four characters are preserved (if available) and the
82+
middle section is replaced by asterisks. If the secret is shorter than
83+
eight characters, the entire value is replaced by ``*``.
84+
"""
85+
86+
if not secret:
87+
return None
88+
89+
if len(secret) <= 8:
90+
return "*" * len(secret)
91+
92+
return f"{secret[:4]}{'*' * (len(secret) - 8)}{secret[-4:]}"
93+
94+
95+
# Consolidated configuration (with sensitive data redacted)
96+
_masked_settings = settings.model_dump(exclude_none=True).copy()
97+
98+
if "client_secret" in _masked_settings:
99+
_masked_settings["client_secret"] = _mask_secret(_masked_settings["client_secret"])
100+
101+
# Log configuration at *debug* level only so it can be enabled when needed
102+
logger.debug("[Auth Proxy Config] %s", _masked_settings)
103+
104+
# Server host/port
105+
AUTH_SERVER_PORT = int(os.getenv("AUTH_SERVER_PORT", "9000"))
106+
AUTH_SERVER_HOST = os.getenv("AUTH_SERVER_HOST", "localhost")
107+
AUTH_SERVER_URL = os.getenv(
108+
"AUTH_SERVER_URL", f"http://{AUTH_SERVER_HOST}:{AUTH_SERVER_PORT}"
109+
)
110+
111+
# ----------------------------------------------------------------------------
112+
# Auth Server
113+
# ----------------------------------------------------------------------------
114+
115+
# Create auth provider
116+
oauth_provider = TransparentOAuthProxyProvider(settings=settings)
117+
118+
# Enable client registration
119+
client_registration_options = ClientRegistrationOptions(
120+
enabled=True,
121+
valid_scopes=["openid"],
122+
default_scopes=["openid"],
123+
)
124+
125+
# Create auth routes
126+
routes = create_auth_routes(
127+
provider=oauth_provider,
128+
issuer_url=AnyHttpUrl(AUTH_SERVER_URL),
129+
service_documentation_url=None,
130+
client_registration_options=client_registration_options,
131+
revocation_options=None,
132+
)
133+
134+
# Add token endpoint handler
135+
# We need to replace any existing token endpoint route
136+
routes = [r for r in routes if not (hasattr(r, "path") and r.path == "/token")]
137+
138+
# Create token handler and add it to routes
139+
proxy_token_handler = ProxyTokenHandler(oauth_provider)
140+
routes.append(Route("/token", endpoint=proxy_token_handler.handle, methods=["POST"]))
141+
142+
# Add token introspection endpoint for Resource Servers
143+
async def introspect_handler(request: Request) -> Response:
144+
"""
145+
Token introspection endpoint for Resource Servers.
146+
147+
Resource Servers call this endpoint to validate tokens without
148+
needing direct access to token storage.
149+
"""
150+
form = await request.form()
151+
token = form.get("token")
152+
if not token or not isinstance(token, str):
153+
return JSONResponse({"active": False}, status_code=400)
154+
155+
# For the transparent proxy, we don't actually validate tokens
156+
# Just create a dummy AccessToken like the provider does
157+
access_token = AccessToken(
158+
token=token, client_id=str(CLIENT_ID), scopes=[DEFAULT_SCOPE], expires_at=None
159+
)
160+
161+
return JSONResponse(
162+
{
163+
"active": True,
164+
"client_id": access_token.client_id,
165+
"scope": " ".join(access_token.scopes),
166+
"exp": access_token.expires_at,
167+
"iat": int(time.time()),
168+
"token_type": "Bearer",
169+
"aud": access_token.resource, # RFC 8707 audience claim
170+
}
171+
)
172+
173+
174+
routes.append(
175+
Route(
176+
"/introspect",
177+
endpoint=cors_middleware(introspect_handler, ["POST", "OPTIONS"]),
178+
methods=["POST", "OPTIONS"],
179+
)
180+
)
181+
182+
# Create Starlette app with routes
183+
auth_app = Starlette(routes=routes)
184+
185+
186+
async def run_server():
187+
"""Run the Authorization Server."""
188+
config = Config(
189+
auth_app,
190+
host=AUTH_SERVER_HOST,
191+
port=AUTH_SERVER_PORT,
192+
log_level="info",
193+
)
194+
server = Server(config)
195+
196+
logger.info(f"🚀 MCP Authorization Server running on {AUTH_SERVER_URL}")
197+
198+
await server.serve()
199+
200+
201+
if __name__ == "__main__":
202+
import asyncio
203+
204+
asyncio.run(run_server())
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
"""Example token verifier implementation using OAuth 2.0 Token Introspection."""
2+
3+
import logging
4+
from typing import Any
5+
6+
from mcp.server.auth.provider import AccessToken, TokenVerifier
7+
from mcp.shared.auth_utils import check_resource_allowed, resource_url_from_server_url
8+
9+
logger = logging.getLogger(__name__)
10+
11+
12+
class IntrospectionTokenVerifier(TokenVerifier):
13+
"""Example token verifier that uses OAuth 2.0 Token Introspection (RFC 7662).
14+
15+
This is a simple example implementation for demonstration purposes.
16+
Production implementations should consider:
17+
- Connection pooling and reuse
18+
- More sophisticated error handling
19+
- Rate limiting and retry logic
20+
- Comprehensive configuration options
21+
"""
22+
23+
def __init__(
24+
self,
25+
introspection_endpoint: str,
26+
server_url: str,
27+
validate_resource: bool = False,
28+
):
29+
self.introspection_endpoint = introspection_endpoint
30+
self.server_url = server_url
31+
self.validate_resource = validate_resource
32+
self.resource_url = resource_url_from_server_url(server_url)
33+
34+
async def verify_token(self, token: str) -> AccessToken | None:
35+
"""Verify token via introspection endpoint."""
36+
import httpx
37+
38+
# Validate URL to prevent SSRF attacks
39+
if not self.introspection_endpoint.startswith(
40+
("https://", "http://localhost", "http://127.0.0.1")
41+
):
42+
logger.warning(
43+
f"Rejecting introspection endpoint with unsafe scheme: "
44+
f"{self.introspection_endpoint}"
45+
)
46+
return None
47+
48+
# Configure secure HTTP client
49+
timeout = httpx.Timeout(10.0, connect=5.0)
50+
limits = httpx.Limits(max_connections=10, max_keepalive_connections=5)
51+
52+
async with httpx.AsyncClient(
53+
timeout=timeout,
54+
limits=limits,
55+
verify=True, # Enforce SSL verification
56+
) as client:
57+
try:
58+
response = await client.post(
59+
self.introspection_endpoint,
60+
data={"token": token},
61+
headers={"Content-Type": "application/x-www-form-urlencoded"},
62+
)
63+
64+
if response.status_code != 200:
65+
logger.debug(
66+
f"Token introspection returned status {response.status_code}"
67+
)
68+
return None
69+
70+
data = response.json()
71+
if not data.get("active", False):
72+
return None
73+
74+
# RFC 8707 resource validation (only when --oauth-strict is set)
75+
if self.validate_resource and not self._validate_resource(data):
76+
logger.warning(
77+
f"Token resource validation failed. Expected: "
78+
f"{self.resource_url}"
79+
)
80+
return None
81+
82+
return AccessToken(
83+
token=token,
84+
client_id=data.get("client_id", "unknown"),
85+
scopes=data.get("scope", "").split() if data.get("scope") else [],
86+
expires_at=data.get("exp"),
87+
resource=data.get("aud"), # Include resource in token
88+
)
89+
except Exception as e:
90+
logger.warning(f"Token introspection failed: {e}")
91+
return None
92+
93+
def _validate_resource(self, token_data: dict[str, Any]) -> bool:
94+
"""Validate token was issued for this resource server."""
95+
if not self.server_url or not self.resource_url:
96+
return False # Fail if strict validation requested but URLs missing
97+
98+
# Check 'aud' claim first (standard JWT audience)
99+
aud = token_data.get("aud")
100+
if isinstance(aud, list):
101+
for audience in aud:
102+
if self._is_valid_resource(audience):
103+
return True
104+
return False
105+
elif aud:
106+
return self._is_valid_resource(aud)
107+
108+
# No resource binding - invalid per RFC 8707
109+
return False
110+
111+
def _is_valid_resource(self, resource: str) -> bool:
112+
"""Check if resource matches this server using hierarchical matching."""
113+
if not self.resource_url:
114+
return False
115+
116+
return check_resource_allowed(
117+
requested_resource=self.resource_url, configured_resource=resource
118+
)

0 commit comments

Comments
 (0)