Skip to content

Commit 31618c1

Browse files
committed
Add auth context middleware
1 parent 8c86bce commit 31618c1

File tree

1 file changed

+57
-0
lines changed

1 file changed

+57
-0
lines changed
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import contextvars
2+
3+
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
4+
from starlette.requests import Request
5+
from starlette.responses import Response
6+
7+
from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser
8+
from mcp.server.auth.provider import AuthInfo
9+
10+
# Create a contextvar to store the authenticated user
11+
# The default is None, indicating no authenticated user is present
12+
auth_context_var = contextvars.ContextVar[AuthenticatedUser | None](
13+
"auth_context", default=None
14+
)
15+
16+
17+
def get_current_auth_info() -> AuthInfo | None:
18+
"""
19+
Get the auth info from the current context.
20+
21+
Returns:
22+
The auth info if an authenticated user is available, None otherwise.
23+
"""
24+
auth_user = auth_context_var.get()
25+
return auth_user.auth_info if auth_user else None
26+
27+
28+
class AuthContextMiddleware(BaseHTTPMiddleware):
29+
"""
30+
Middleware that extracts the authenticated user from the request
31+
and sets it in a contextvar for easy access throughout the request lifecycle.
32+
33+
This middleware should be added after the AuthenticationMiddleware in the
34+
middleware stack to ensure that the user is properly authenticated before
35+
being stored in the context.
36+
"""
37+
38+
async def dispatch(
39+
self, request: Request, call_next: RequestResponseEndpoint
40+
) -> Response:
41+
# Get the authenticated user from the request if it exists
42+
user = getattr(request, "user", None)
43+
44+
# Only set the context var if the user is an AuthenticatedUser
45+
if isinstance(user, AuthenticatedUser):
46+
# Set the authenticated user in the contextvar
47+
token = auth_context_var.set(user)
48+
try:
49+
# Process the request
50+
response = await call_next(request)
51+
return response
52+
finally:
53+
# Reset the contextvar after the request is processed
54+
auth_context_var.reset(token)
55+
else:
56+
# No authenticated user, just process the request
57+
return await call_next(request)

0 commit comments

Comments
 (0)