Skip to content

Commit 49e0502

Browse files
committed
fix: use common parsing/context module between
NorthAuthenticationMiddleware and FastMCPNorthMiddleware
1 parent 7135119 commit 49e0502

File tree

7 files changed

+612
-440
lines changed

7 files changed

+612
-440
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ def echo(_: dict) -> dict:
4949

5050
The middleware reads the `X-North-ID-Token` header (if present) and parses Base64-encoded JSON from `X-North-Connector-Tokens`. It never returns a 401—it simply exposes these values through a context variable and `request.state.north_context` for downstream handlers.
5151

52+
When you use `NorthMCPServer`, the authentication stack now populates the same request context automatically, so utilities built against `get_north_request_context()` can be shared across FastMCP apps and fully authenticated servers.
53+
5254
## Examples
5355

5456
This repository contains example servers that you can use as a quickstart. You can find them in the [examples directory](https://github.com/cohere-ai/north-mcp-python-sdk/tree/main/examples).

src/north_mcp_python_sdk/auth.py

Lines changed: 78 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,16 @@
1818
from starlette.responses import JSONResponse
1919
from starlette.types import ASGIApp, Receive, Scope, Send
2020

21+
from .north_context import (
22+
DEFAULT_CONNECTOR_TOKENS_HEADER,
23+
DEFAULT_SERVER_SECRET_HEADER,
24+
DEFAULT_USER_ID_TOKEN_HEADER,
25+
NORTH_CONTEXT_SCOPE_KEY,
26+
NorthRequestContext,
27+
decode_connector_tokens,
28+
reset_north_request_context,
29+
set_north_request_context,
30+
)
2131

2232
class AuthHeaderTokens(BaseModel):
2333
server_secret: str | None
@@ -30,9 +40,15 @@ def __init__(
3040
self,
3141
connector_access_tokens: dict[str, str],
3242
email: str | None = None,
43+
user_id_token: str | None = None,
3344
):
3445
self.connector_access_tokens = connector_access_tokens
3546
self.email = email
47+
self.user_id_token = user_id_token
48+
self.north_context = NorthRequestContext(
49+
user_id_token=user_id_token,
50+
connector_tokens=connector_access_tokens,
51+
)
3652

3753

3854
class NorthAuthenticationMiddleware(AuthenticationMiddleware):
@@ -153,16 +169,32 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send):
153169
return await self.app(scope, receive, send)
154170

155171
user = scope.get("user")
172+
existing_context = scope.get(NORTH_CONTEXT_SCOPE_KEY)
173+
174+
def store_context(context: NorthRequestContext) -> None:
175+
scope[NORTH_CONTEXT_SCOPE_KEY] = context
176+
state = scope.get("state")
177+
if state is None:
178+
scope["state"] = {"north_context": context}
179+
elif isinstance(state, dict):
180+
state["north_context"] = context
181+
else:
182+
setattr(state, "north_context", context)
156183

157184
# For custom routes that don't require auth, user will be None
158185
if user is None:
159186
self.logger.debug(
160187
"Custom route accessed without authentication (operational endpoint)"
161188
)
189+
context = existing_context or NorthRequestContext()
190+
store_context(context)
191+
192+
context_token = set_north_request_context(context)
162193
token = auth_context_var.set(None)
163194
try:
164195
await self.app(scope, receive, send)
165196
finally:
197+
reset_north_request_context(context_token)
166198
auth_context_var.reset(token)
167199
return
168200

@@ -179,10 +211,15 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send):
179211
list(user.connector_access_tokens.keys()),
180212
)
181213

214+
context = existing_context or user.north_context
215+
store_context(context)
216+
217+
context_token = set_north_request_context(context)
182218
token = auth_context_var.set(user)
183219
try:
184220
await self.app(scope, receive, send)
185221
finally:
222+
reset_north_request_context(context_token)
186223
auth_context_var.reset(token)
187224

188225

@@ -211,26 +248,12 @@ def _has_x_north_headers(self, conn: HTTPConnection) -> bool:
211248
return any(
212249
conn.headers.get(header)
213250
for header in [
214-
"X-North-ID-Token",
215-
"X-North-Connector-Tokens",
216-
"X-North-Server-Secret",
251+
DEFAULT_USER_ID_TOKEN_HEADER,
252+
DEFAULT_CONNECTOR_TOKENS_HEADER,
253+
DEFAULT_SERVER_SECRET_HEADER,
217254
]
218255
)
219256

220-
def _parse_connector_tokens(self, header_value: str) -> dict[str, str]:
221-
"""Parse Base64 URL-safe encoded JSON connector tokens."""
222-
try:
223-
# Add padding if needed for Base64 decoding
224-
padded = header_value + "=" * (4 - len(header_value) % 4)
225-
decoded_json = base64.urlsafe_b64decode(padded).decode()
226-
tokens = json.loads(decoded_json)
227-
if not isinstance(tokens, dict):
228-
raise ValueError("Connector tokens must be a JSON object")
229-
return tokens
230-
except Exception as e:
231-
self.logger.debug("Failed to parse connector tokens: %s", e)
232-
raise AuthenticationError("invalid connector tokens format")
233-
234257
def _validate_server_secret(self, provided_secret: str | None) -> None:
235258
"""Validate server secret matches expected value."""
236259
if self._server_secret and self._server_secret != provided_secret:
@@ -271,11 +294,16 @@ def _process_user_id_token(self, user_id_token: str | None) -> str | None:
271294
raise AuthenticationError("invalid user id token")
272295

273296
def _create_authenticated_user(
274-
self, email: str | None, connector_access_tokens: dict[str, str]
297+
self,
298+
email: str | None,
299+
connector_access_tokens: dict[str, str],
300+
user_id_token: str | None,
275301
) -> tuple[AuthCredentials, AuthenticatedNorthUser]:
276302
"""Create authenticated user from validated tokens."""
277303
return AuthCredentials(), AuthenticatedNorthUser(
278-
connector_access_tokens=connector_access_tokens, email=email
304+
connector_access_tokens=connector_access_tokens,
305+
email=email,
306+
user_id_token=user_id_token,
279307
)
280308

281309
async def _authenticate_x_north_headers(
@@ -285,16 +313,25 @@ async def _authenticate_x_north_headers(
285313
self.logger.debug("Using X-North headers for authentication")
286314

287315
# Extract headers
288-
user_id_token = conn.headers.get("X-North-ID-Token")
289-
connector_tokens_header = conn.headers.get("X-North-Connector-Tokens")
290-
server_secret = conn.headers.get("X-North-Server-Secret")
316+
user_id_token = conn.headers.get(DEFAULT_USER_ID_TOKEN_HEADER)
317+
connector_tokens_header = conn.headers.get(
318+
DEFAULT_CONNECTOR_TOKENS_HEADER
319+
)
320+
server_secret = conn.headers.get(DEFAULT_SERVER_SECRET_HEADER)
291321

292322
# Parse connector tokens (Base64 URL-safe encoded JSON)
293323
connector_access_tokens = {}
294324
if connector_tokens_header:
295-
connector_access_tokens = self._parse_connector_tokens(
296-
connector_tokens_header
297-
)
325+
try:
326+
connector_access_tokens = decode_connector_tokens(
327+
connector_tokens_header,
328+
logger=self.logger,
329+
raise_on_error=True,
330+
)
331+
except ValueError as exc:
332+
raise AuthenticationError(
333+
"invalid connector tokens format"
334+
) from exc
298335

299336
self.logger.debug(
300337
"X-North headers parsed. Has server_secret: %s, Has user_id_token: %s, Connector count: %d",
@@ -309,8 +346,16 @@ async def _authenticate_x_north_headers(
309346
self._validate_server_secret(server_secret)
310347
email = self._process_user_id_token(user_id_token)
311348

349+
context = NorthRequestContext(
350+
user_id_token=user_id_token,
351+
connector_tokens=connector_access_tokens,
352+
)
353+
conn.scope[NORTH_CONTEXT_SCOPE_KEY] = context
354+
312355
self.logger.debug("X-North authentication successful")
313-
return self._create_authenticated_user(email, connector_access_tokens)
356+
return self._create_authenticated_user(
357+
email, connector_access_tokens, user_id_token
358+
)
314359

315360
async def _authenticate_legacy_bearer(
316361
self, conn: HTTPConnection
@@ -358,9 +403,15 @@ async def _authenticate_legacy_bearer(
358403
self._validate_server_secret(tokens.server_secret)
359404
email = self._process_user_id_token(tokens.user_id_token)
360405

406+
context = NorthRequestContext(
407+
user_id_token=tokens.user_id_token,
408+
connector_tokens=tokens.connector_access_tokens,
409+
)
410+
conn.scope[NORTH_CONTEXT_SCOPE_KEY] = context
411+
361412
self.logger.debug("Legacy authentication successful")
362413
return self._create_authenticated_user(
363-
email, tokens.connector_access_tokens
414+
email, tokens.connector_access_tokens, tokens.user_id_token
364415
)
365416

366417
async def authenticate(

src/north_mcp_python_sdk/middleware.py

Lines changed: 15 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,21 @@
1-
import base64
2-
import binascii
3-
import contextvars
4-
import json
51
import logging
6-
from dataclasses import dataclass, field
7-
from typing import Dict, Optional
82

93
from starlette.middleware.base import BaseHTTPMiddleware
104
from starlette.requests import Request
115
from starlette.types import ASGIApp
126

13-
14-
_DEFAULT_USER_ID_TOKEN_HEADER = "X-North-ID-Token"
15-
_DEFAULT_CONNECTOR_TOKENS_HEADER = "X-North-Connector-Tokens"
16-
17-
18-
@dataclass(frozen=True)
19-
class NorthRequestContext:
20-
"""Holds North-specific request context extracted from headers."""
21-
22-
user_id_token: Optional[str] = None
23-
connector_tokens: Dict[str, str] = field(default_factory=dict)
24-
25-
26-
north_request_context_var = contextvars.ContextVar[NorthRequestContext](
27-
"north_request_context",
28-
default=NorthRequestContext(),
7+
from .north_context import (
8+
DEFAULT_CONNECTOR_TOKENS_HEADER,
9+
DEFAULT_USER_ID_TOKEN_HEADER,
10+
NorthRequestContext,
11+
decode_connector_tokens,
12+
get_north_request_context,
13+
north_request_context_var,
14+
reset_north_request_context,
15+
set_north_request_context,
2916
)
3017

3118

32-
def get_north_request_context() -> NorthRequestContext:
33-
"""
34-
Retrieve the North request context for the current request.
35-
36-
Returns:
37-
NorthRequestContext: The context extracted by FastMCPNorthMiddleware.
38-
"""
39-
return north_request_context_var.get()
40-
41-
4219
class FastMCPNorthMiddleware(BaseHTTPMiddleware):
4320
"""
4421
Lightweight middleware that extracts North request metadata from headers.
@@ -52,8 +29,8 @@ def __init__(
5229
self,
5330
app: ASGIApp,
5431
*,
55-
user_id_token_header: str = _DEFAULT_USER_ID_TOKEN_HEADER,
56-
connector_tokens_header: str = _DEFAULT_CONNECTOR_TOKENS_HEADER,
32+
user_id_token_header: str = DEFAULT_USER_ID_TOKEN_HEADER,
33+
connector_tokens_header: str = DEFAULT_CONNECTOR_TOKENS_HEADER,
5734
debug: bool = False,
5835
) -> None:
5936
super().__init__(app)
@@ -64,43 +41,13 @@ def __init__(
6441
self._logger.setLevel(logging.DEBUG)
6542
self._debug = debug
6643

67-
def _parse_connector_tokens(self, raw_header: str) -> Dict[str, str]:
68-
"""
69-
Parse the connector tokens header, expected to be Base64-encoded JSON.
70-
71-
Returns an empty dict when the header cannot be decoded or does not
72-
resolve to a JSON object of string keys and values.
73-
"""
74-
if not raw_header:
75-
return {}
76-
77-
padding = (-len(raw_header)) % 4
78-
padded_value = raw_header + ("=" * padding)
79-
80-
try:
81-
decoded_bytes = base64.urlsafe_b64decode(padded_value)
82-
decoded_json = decoded_bytes.decode()
83-
parsed = json.loads(decoded_json)
84-
if isinstance(parsed, dict):
85-
return {
86-
str(key): str(value)
87-
for key, value in parsed.items()
88-
if isinstance(key, str) and isinstance(value, str)
89-
}
90-
except (ValueError, json.JSONDecodeError, binascii.Error) as exc:
91-
self._logger.debug(
92-
"Failed to decode connector tokens header: %s", exc
93-
)
94-
95-
return {}
96-
9744
async def dispatch(self, request: Request, call_next):
9845
user_id_token = request.headers.get(self._user_id_token_header)
9946
connector_tokens_header = request.headers.get(
10047
self._connector_tokens_header
10148
)
102-
connector_tokens = self._parse_connector_tokens(
103-
connector_tokens_header or ""
49+
connector_tokens = decode_connector_tokens(
50+
connector_tokens_header or "", logger=self._logger
10451
)
10552

10653
if self._debug:
@@ -116,12 +63,12 @@ async def dispatch(self, request: Request, call_next):
11663
)
11764

11865
request.state.north_context = context
119-
token = north_request_context_var.set(context)
66+
token = set_north_request_context(context)
12067

12168
try:
12269
response = await call_next(request)
12370
finally:
124-
north_request_context_var.reset(token)
71+
reset_north_request_context(token)
12572
# Request.state lives for the lifetime of the request; no cleanup needed.
12673

12774
return response

0 commit comments

Comments
 (0)