Skip to content

Commit d79be8f

Browse files
committed
Fixups while integrating new auth capabilities
1 parent 9fee929 commit d79be8f

File tree

5 files changed

+91
-66
lines changed

5 files changed

+91
-66
lines changed

src/mcp/server/auth/provider.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Protocol
1+
from typing import Generic, Protocol, TypeVar
22
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
33

44
from pydantic import AnyHttpUrl, BaseModel
@@ -62,7 +62,16 @@ async def register_client(self, client_info: OAuthClientInformationFull) -> None
6262
...
6363

6464

65-
class OAuthServerProvider(Protocol):
65+
# NOTE: FastMCP doesn't render any of these types in the user response, so it's
66+
# OK to add fields to subclasses which should not be exposed externally.
67+
AuthorizationCodeT = TypeVar("AuthorizationCodeT", bound=AuthorizationCode)
68+
RefreshTokenT = TypeVar("RefreshTokenT", bound=RefreshToken)
69+
AuthInfoT = TypeVar("AuthInfoT", bound=AuthInfo)
70+
71+
72+
class OAuthServerProvider(
73+
Protocol, Generic[AuthorizationCodeT, RefreshTokenT, AuthInfoT]
74+
):
6675
@property
6776
def clients_store(self) -> OAuthRegisteredClientsStore:
6877
"""
@@ -107,7 +116,7 @@ async def authorize(
107116

108117
async def load_authorization_code(
109118
self, client: OAuthClientInformationFull, authorization_code: str
110-
) -> AuthorizationCode | None:
119+
) -> AuthorizationCodeT | None:
111120
"""
112121
Loads metadata for the authorization code challenge.
113122
@@ -121,7 +130,7 @@ async def load_authorization_code(
121130
...
122131

123132
async def exchange_authorization_code(
124-
self, client: OAuthClientInformationFull, authorization_code: AuthorizationCode
133+
self, client: OAuthClientInformationFull, authorization_code: AuthorizationCodeT
125134
) -> OAuthToken:
126135
"""
127136
Exchanges an authorization code for an access token.
@@ -137,12 +146,12 @@ async def exchange_authorization_code(
137146

138147
async def load_refresh_token(
139148
self, client: OAuthClientInformationFull, refresh_token: str
140-
) -> RefreshToken | None: ...
149+
) -> RefreshTokenT | None: ...
141150

142151
async def exchange_refresh_token(
143152
self,
144153
client: OAuthClientInformationFull,
145-
refresh_token: RefreshToken,
154+
refresh_token: RefreshTokenT,
146155
scopes: list[str],
147156
) -> OAuthToken:
148157
"""
@@ -158,7 +167,7 @@ async def exchange_refresh_token(
158167
"""
159168
...
160169

161-
async def load_access_token(self, token: str) -> AuthInfo | None:
170+
async def load_access_token(self, token: str) -> AuthInfoT | None:
162171
"""
163172
Verifies an access token and returns information about it.
164173
@@ -172,7 +181,7 @@ async def load_access_token(self, token: str) -> AuthInfo | None:
172181

173182
async def revoke_token(
174183
self,
175-
token: AuthInfo | RefreshToken,
184+
token: AuthInfoT | RefreshTokenT,
176185
) -> None:
177186
"""
178187
Revokes an access or refresh token.

src/mcp/server/auth/router.py

Lines changed: 23 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from typing import Callable
33

44
from pydantic import AnyHttpUrl
5-
from starlette.routing import Route, Router
5+
from starlette.routing import Route
66

77
from mcp.server.auth.handlers.authorize import AuthorizationHandler
88
from mcp.server.auth.handlers.metadata import MetadataHandler
@@ -57,27 +57,13 @@ def validate_issuer_url(url: AnyHttpUrl):
5757
REVOCATION_PATH = "/revoke"
5858

5959

60-
def create_auth_router(
60+
def create_auth_routes(
6161
provider: OAuthServerProvider,
6262
issuer_url: AnyHttpUrl,
6363
service_documentation_url: AnyHttpUrl | None = None,
6464
client_registration_options: ClientRegistrationOptions | None = None,
6565
revocation_options: RevocationOptions | None = None,
66-
) -> Router:
67-
"""
68-
Create a Starlette router with standard MCP authorization endpoints.
69-
70-
Args:
71-
provider: OAuth server provider
72-
issuer_url: Issuer URL for the authorization server
73-
service_documentation_url: Optional URL for service documentation
74-
client_registration_options: Options for client registration
75-
revocation_options: Options for token revocation
76-
77-
Returns:
78-
Starlette router with authorization endpoints
79-
"""
80-
66+
) -> list[Route]:
8167
validate_issuer_url(issuer_url)
8268

8369
client_registration_options = (
@@ -93,32 +79,30 @@ def create_auth_router(
9379
client_authenticator = ClientAuthenticator(provider.clients_store)
9480

9581
# Create routes
96-
auth_router = Router(
97-
routes=[
98-
Route(
99-
"/.well-known/oauth-authorization-server",
100-
endpoint=MetadataHandler(metadata).handle,
101-
methods=["GET"],
102-
),
103-
Route(
104-
AUTHORIZATION_PATH,
105-
endpoint=AuthorizationHandler(provider).handle,
106-
methods=["GET", "POST"],
107-
),
108-
Route(
109-
TOKEN_PATH,
110-
endpoint=TokenHandler(provider, client_authenticator).handle,
111-
methods=["POST"],
112-
),
113-
]
114-
)
82+
routes = [
83+
Route(
84+
"/.well-known/oauth-authorization-server",
85+
endpoint=MetadataHandler(metadata).handle,
86+
methods=["GET"],
87+
),
88+
Route(
89+
AUTHORIZATION_PATH,
90+
endpoint=AuthorizationHandler(provider).handle,
91+
methods=["GET", "POST"],
92+
),
93+
Route(
94+
TOKEN_PATH,
95+
endpoint=TokenHandler(provider, client_authenticator).handle,
96+
methods=["POST"],
97+
),
98+
]
11599

116100
if client_registration_options.enabled:
117101
registration_handler = RegistrationHandler(
118102
provider.clients_store,
119103
client_secret_expiry_seconds=client_registration_options.client_secret_expiry_seconds,
120104
)
121-
auth_router.routes.append(
105+
routes.append(
122106
Route(
123107
REGISTRATION_PATH,
124108
endpoint=registration_handler.handle,
@@ -128,11 +112,11 @@ def create_auth_router(
128112

129113
if revocation_options.enabled:
130114
revocation_handler = RevocationHandler(provider, client_authenticator)
131-
auth_router.routes.append(
115+
routes.append(
132116
Route(REVOCATION_PATH, endpoint=revocation_handler.handle, methods=["POST"])
133117
)
134118

135-
return auth_router
119+
return routes
136120

137121

138122
def modify_url_path(url: AnyHttpUrl, path_mapper: Callable[[str], str]) -> AnyHttpUrl:

src/mcp/server/fastmcp/server.py

Lines changed: 46 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
asynccontextmanager,
1212
)
1313
from itertools import chain
14-
from typing import Any, Callable, Generic, Literal, Sequence
14+
from typing import Any, Awaitable, Callable, Generic, Literal, Sequence
1515

1616
import anyio
1717
import pydantic_core
@@ -24,6 +24,7 @@
2424
from starlette.authentication import requires
2525
from starlette.middleware.authentication import AuthenticationMiddleware
2626

27+
from mcp.server.auth.middleware.auth_context import AuthContextMiddleware
2728
from mcp.server.auth.middleware.bearer_auth import (
2829
BearerAuthBackend,
2930
RequireAuthMiddleware,
@@ -151,6 +152,7 @@ def __init__(
151152
warn_on_duplicate_prompts=self.settings.warn_on_duplicate_prompts
152153
)
153154
self._auth_provider = auth_provider
155+
self._custom_starlette_routes = []
154156
self.dependencies = self.settings.dependencies
155157

156158
# Set up MCP protocol handlers
@@ -477,6 +479,33 @@ def decorator(func: AnyFunction) -> AnyFunction:
477479

478480
return decorator
479481

482+
def custom_route(
483+
self,
484+
path: str,
485+
methods: list[str],
486+
name: str | None = None,
487+
include_in_schema: bool = True,
488+
):
489+
from starlette.requests import Request
490+
from starlette.responses import Response
491+
from starlette.routing import Route
492+
493+
def decorator(
494+
func: Callable[[Request], Awaitable[Response]],
495+
) -> Callable[[Request], Awaitable[Response]]:
496+
self._custom_starlette_routes.append(
497+
Route(
498+
path,
499+
endpoint=func,
500+
methods=methods,
501+
name=name,
502+
include_in_schema=include_in_schema,
503+
)
504+
)
505+
return func
506+
507+
return decorator
508+
480509
async def run_stdio_async(self) -> None:
481510
"""Run the server using stdio transport."""
482511
async with stdio_server() as (read_stream, write_stream):
@@ -513,31 +542,33 @@ async def handle_sse(request) -> EventSourceResponse:
513542
routes = []
514543
middleware = []
515544
required_scopes = self.settings.auth_required_scopes or []
516-
auth_router = None
517545

518546
# Add auth endpoints if auth provider is configured
519547
if self._auth_provider and self.settings.auth_issuer_url:
520-
from mcp.server.auth.router import create_auth_router
548+
from mcp.server.auth.router import create_auth_routes
521549

522-
# Set up bearer auth middleware if auth is required
523550
middleware = [
551+
# extract auth info from request (but do not require it)
524552
Middleware(
525553
AuthenticationMiddleware,
526554
backend=BearerAuthBackend(
527555
provider=self._auth_provider,
528556
),
529-
)
557+
),
558+
# Add the auth context middleware to store
559+
# authenticated user in a contextvar
560+
Middleware(AuthContextMiddleware),
530561
]
531-
auth_router = create_auth_router(
532-
provider=self._auth_provider,
533-
issuer_url=self.settings.auth_issuer_url,
534-
service_documentation_url=self.settings.auth_service_documentation_url,
535-
client_registration_options=self.settings.auth_client_registration_options,
536-
revocation_options=self.settings.auth_revocation_options,
562+
routes.extend(
563+
create_auth_routes(
564+
provider=self._auth_provider,
565+
issuer_url=self.settings.auth_issuer_url,
566+
service_documentation_url=self.settings.auth_service_documentation_url,
567+
client_registration_options=self.settings.auth_client_registration_options,
568+
revocation_options=self.settings.auth_revocation_options,
569+
)
537570
)
538571

539-
# Add the auth router as a mount
540-
541572
routes.append(
542573
Route(
543574
"/sse", endpoint=requires(required_scopes)(handle_sse), methods=["GET"]
@@ -549,8 +580,8 @@ async def handle_sse(request) -> EventSourceResponse:
549580
app=RequireAuthMiddleware(sse.handle_post_message, required_scopes),
550581
)
551582
)
552-
if auth_router:
553-
routes.append(Mount("/", app=auth_router))
583+
# mount these routes last, so they have the lowest route matching precedence
584+
routes.extend(self._custom_starlette_routes)
554585

555586
# Create Starlette app with routes and middleware
556587
return Starlette(

src/mcp/shared/auth.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ class OAuthClientMetadata(BaseModel):
4141
)
4242
# grant_types: this implementation only supports authorization_code & refresh_token
4343
grant_types: list[Literal["authorization_code", "refresh_token"]] = [
44-
"authorization_code"
44+
"authorization_code",
45+
"refresh_token",
4546
]
4647
# this implementation only supports code; ie: it does not support implicit grants
4748
response_types: list[Literal["code"]] = ["code"]

tests/server/fastmcp/auth/test_auth_integration.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from mcp.server.auth.router import (
3131
ClientRegistrationOptions,
3232
RevocationOptions,
33-
create_auth_router,
33+
create_auth_routes,
3434
)
3535
from mcp.server.fastmcp import FastMCP
3636
from mcp.shared.auth import (
@@ -222,7 +222,7 @@ def mock_oauth_provider():
222222
@pytest.fixture
223223
def auth_app(mock_oauth_provider):
224224
# Create auth router
225-
auth_router = create_auth_router(
225+
auth_routes = create_auth_routes(
226226
mock_oauth_provider,
227227
AnyHttpUrl("https://auth.example.com"),
228228
AnyHttpUrl("https://docs.example.com"),
@@ -231,7 +231,7 @@ def auth_app(mock_oauth_provider):
231231
)
232232

233233
# Create Starlette app
234-
app = Starlette(routes=[Mount("/", app=auth_router)])
234+
app = Starlette(routes=auth_routes)
235235

236236
return app
237237

0 commit comments

Comments
 (0)