Skip to content

Commit cdbdeb3

Browse files
committed
Add example proxy OAuth server implementation
Signed-off-by: Jesse Sanford <108698+jessesanford@users.noreply.github.com>
1 parent 85039f8 commit cdbdeb3

File tree

7 files changed

+895
-0
lines changed

7 files changed

+895
-0
lines changed
Lines changed: 230 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
# pyright: reportMissingImports=false
2+
import argparse
3+
import asyncio
4+
import logging
5+
import os
6+
import time
7+
8+
from dotenv import load_dotenv # type: ignore
9+
from mcp.server.auth.provider import AccessToken
10+
from mcp.server.auth.providers.transparent_proxy import (
11+
ProxySettings, # type: ignore
12+
ProxyTokenHandler,
13+
TransparentOAuthProxyProvider,
14+
)
15+
from mcp.server.auth.routes import cors_middleware, create_auth_routes
16+
from mcp.server.auth.settings import ClientRegistrationOptions
17+
from pydantic import AnyHttpUrl
18+
from starlette.applications import Starlette
19+
from starlette.requests import Request # type: ignore
20+
from starlette.responses import JSONResponse, Response
21+
from starlette.routing import Route
22+
from uvicorn import Config, Server
23+
24+
# Load environment variables from .env if present
25+
load_dotenv()
26+
27+
# Configure logging after .env so LOG_LEVEL can come from environment
28+
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
29+
30+
logging.basicConfig(
31+
level=LOG_LEVEL,
32+
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
33+
datefmt="%Y-%m-%d %H:%M:%S",
34+
)
35+
36+
# Dedicated logger for this server module
37+
logger = logging.getLogger("proxy_auth.auth_server")
38+
39+
# Suppress noisy INFO messages from the FastMCP low-level server unless we are
40+
# explicitly running in DEBUG mode. These logs (e.g. "Processing request of type
41+
# ListToolsRequest") are helpful for debugging but clutter normal output.
42+
43+
_mcp_lowlevel_logger = logging.getLogger("mcp.server.lowlevel.server")
44+
if LOG_LEVEL == "DEBUG":
45+
# In full debug mode, allow the library to emit its detailed logs
46+
_mcp_lowlevel_logger.setLevel(logging.DEBUG)
47+
else:
48+
# Otherwise, only warnings and above
49+
_mcp_lowlevel_logger.setLevel(logging.WARNING)
50+
51+
# ----------------------------------------------------------------------------
52+
# Environment configuration
53+
# ----------------------------------------------------------------------------
54+
# Load and validate settings from the environment (uses .env automatically)
55+
settings = ProxySettings.load()
56+
57+
# Upstream endpoints (fully-qualified URLs)
58+
UPSTREAM_AUTHORIZE: str = str(settings.upstream_authorize)
59+
UPSTREAM_TOKEN: str = str(settings.upstream_token)
60+
UPSTREAM_JWKS_URI = settings.jwks_uri
61+
# Derive base URL from the authorize endpoint for convenience / tests
62+
UPSTREAM_BASE: str = UPSTREAM_AUTHORIZE.rsplit("/", 1)[0]
63+
64+
# Client credentials & defaults
65+
CLIENT_ID: str = settings.client_id or "demo-client-id"
66+
CLIENT_SECRET = settings.client_secret
67+
DEFAULT_SCOPE: str = settings.default_scope
68+
69+
# Metadata URL (only used if we need to fetch from upstream)
70+
UPSTREAM_METADATA = f"{UPSTREAM_BASE}/.well-known/oauth-authorization-server"
71+
72+
# ---------------------------------------------------------------------------
73+
# Logging helpers
74+
# ---------------------------------------------------------------------------
75+
76+
77+
def _mask_secret(secret: str | None) -> str | None: # noqa: D401
78+
"""Return a masked version of the given secret.
79+
80+
The first and last four characters are preserved (if available) and the
81+
middle section is replaced by asterisks. If the secret is shorter than
82+
eight characters, the entire value is replaced by ``*``.
83+
"""
84+
85+
if not secret:
86+
return None
87+
88+
if len(secret) <= 8:
89+
return "*" * len(secret)
90+
91+
return f"{secret[:4]}{'*' * (len(secret) - 8)}{secret[-4:]}"
92+
93+
94+
# Consolidated configuration (with sensitive data redacted)
95+
_masked_settings = settings.model_dump(exclude_none=True).copy()
96+
97+
if "client_secret" in _masked_settings:
98+
_masked_settings["client_secret"] = _mask_secret(_masked_settings["client_secret"])
99+
100+
# Log configuration at *debug* level only so it can be enabled when needed
101+
logger.debug("[Auth Proxy Config] %s", _masked_settings)
102+
103+
# Server host/port
104+
AUTH_SERVER_PORT = int(os.getenv("AUTH_SERVER_PORT", "9000"))
105+
AUTH_SERVER_HOST = os.getenv("AUTH_SERVER_HOST", "localhost")
106+
AUTH_SERVER_URL = os.getenv(
107+
"AUTH_SERVER_URL", f"http://{AUTH_SERVER_HOST}:{AUTH_SERVER_PORT}"
108+
)
109+
110+
# ----------------------------------------------------------------------------
111+
# Auth Server
112+
# ----------------------------------------------------------------------------
113+
114+
# Create auth provider
115+
oauth_provider = TransparentOAuthProxyProvider(settings=settings)
116+
117+
# Enable client registration
118+
client_registration_options = ClientRegistrationOptions(
119+
enabled=True,
120+
valid_scopes=["openid"],
121+
default_scopes=["openid"],
122+
)
123+
124+
# Create auth routes
125+
routes = create_auth_routes(
126+
provider=oauth_provider,
127+
issuer_url=AnyHttpUrl(AUTH_SERVER_URL),
128+
service_documentation_url=None,
129+
client_registration_options=client_registration_options,
130+
revocation_options=None,
131+
)
132+
133+
# Add token endpoint handler
134+
# We need to replace any existing token endpoint route
135+
routes = [r for r in routes if not (hasattr(r, "path") and r.path == "/token")]
136+
137+
# Create token handler and add it to routes
138+
proxy_token_handler = ProxyTokenHandler(oauth_provider)
139+
routes.append(Route("/token", endpoint=proxy_token_handler.handle, methods=["POST"]))
140+
141+
142+
# Add token introspection endpoint for Resource Servers
143+
async def introspect_handler(request: Request) -> Response:
144+
"""
145+
Token introspection endpoint for Resource Servers.
146+
147+
Resource Servers call this endpoint to validate tokens without
148+
needing direct access to token storage.
149+
"""
150+
form = await request.form()
151+
token = form.get("token")
152+
if not token or not isinstance(token, str):
153+
return JSONResponse({"active": False}, status_code=400)
154+
155+
# For the transparent proxy, we don't actually validate tokens
156+
# Just create a dummy AccessToken like the provider does
157+
access_token = AccessToken(
158+
token=token, client_id=str(CLIENT_ID), scopes=[DEFAULT_SCOPE], expires_at=None
159+
)
160+
161+
return JSONResponse(
162+
{
163+
"active": True,
164+
"client_id": access_token.client_id,
165+
"scope": " ".join(access_token.scopes),
166+
"exp": access_token.expires_at,
167+
"iat": int(time.time()),
168+
"token_type": "Bearer",
169+
"aud": access_token.resource, # RFC 8707 audience claim
170+
}
171+
)
172+
173+
174+
routes.append(
175+
Route(
176+
"/introspect",
177+
endpoint=cors_middleware(introspect_handler, ["POST", "OPTIONS"]),
178+
methods=["POST", "OPTIONS"],
179+
)
180+
)
181+
182+
# Create Starlette app with routes
183+
auth_app = Starlette(routes=routes)
184+
185+
186+
async def run_server(host: str = AUTH_SERVER_HOST, port: int = AUTH_SERVER_PORT):
187+
"""Run the Authorization Server."""
188+
config = Config(
189+
auth_app,
190+
host=host,
191+
port=port,
192+
log_level="info",
193+
)
194+
server = Server(config)
195+
196+
logger.info(f"🚀 MCP Authorization Server running on http://{host}:{port}")
197+
198+
await server.serve()
199+
200+
201+
def main():
202+
"""Command-line entry point for the Authorization Server."""
203+
parser = argparse.ArgumentParser(description="MCP OAuth Proxy Authorization Server")
204+
parser.add_argument(
205+
"--host",
206+
default=None,
207+
help="Host to bind to (overrides AUTH_SERVER_HOST env var)",
208+
)
209+
parser.add_argument(
210+
"--port",
211+
type=int,
212+
default=None,
213+
help="Port to bind to (overrides AUTH_SERVER_PORT env var)",
214+
)
215+
216+
args = parser.parse_args()
217+
218+
# Use command-line arguments only if provided, otherwise use environment variables
219+
host = args.host or AUTH_SERVER_HOST
220+
port = args.port or AUTH_SERVER_PORT
221+
222+
# Log the configuration being used
223+
logger.info(f"Starting Authorization Server with host={host}, port={port}")
224+
logger.info("Using environment variables from .env file if present")
225+
226+
asyncio.run(run_server(host=host, port=port))
227+
228+
229+
if __name__ == "__main__":
230+
main()

0 commit comments

Comments
 (0)