Skip to content

Commit a2e96d9

Browse files
committed
refactor: pull out oauth helper functions
1 parent b7b0f8e commit a2e96d9

File tree

3 files changed

+195
-175
lines changed

3 files changed

+195
-175
lines changed

src/mcp/client/auth/oauth2.py

Lines changed: 21 additions & 133 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import base64
88
import hashlib
99
import logging
10-
import re
1110
import secrets
1211
import string
1312
import time
@@ -20,6 +19,14 @@
2019
import httpx
2120
from pydantic import BaseModel, Field, ValidationError
2221

22+
from mcp.client.auth.utils import (
23+
build_protected_resource_discovery_urls,
24+
extract_field_from_www_auth,
25+
extract_resource_metadata_from_www_auth,
26+
extract_scope_from_www_auth,
27+
get_client_metadata_scopes,
28+
get_discovery_urls,
29+
)
2330
from mcp.client.streamable_http import MCP_PROTOCOL_VERSION
2431
from mcp.shared.auth import (
2532
OAuthClientInformationFull,
@@ -200,85 +207,6 @@ def __init__(
200207
)
201208
self._initialized = False
202209

203-
def _build_protected_resource_discovery_urls(self, init_response: httpx.Response) -> list[str]:
204-
"""
205-
Build ordered list of URLs to try for protected resource metadata discovery.
206-
207-
Per SEP-985, the client MUST:
208-
1. Try resource_metadata from WWW-Authenticate header (if present)
209-
2. Fall back to path-based well-known URI: /.well-known/oauth-protected-resource/{path}
210-
3. Fall back to root-based well-known URI: /.well-known/oauth-protected-resource
211-
212-
Args:
213-
init_response: The initial 401 response from the server
214-
215-
Returns:
216-
Ordered list of URLs to try for discovery
217-
"""
218-
urls: list[str] = []
219-
220-
# Priority 1: WWW-Authenticate header with resource_metadata parameter
221-
www_auth_url = self._extract_resource_metadata_from_www_auth(init_response)
222-
if www_auth_url:
223-
urls.append(www_auth_url)
224-
225-
# Priority 2-3: Well-known URIs (RFC 9728)
226-
parsed = urlparse(self.context.server_url)
227-
base_url = f"{parsed.scheme}://{parsed.netloc}"
228-
229-
# Priority 2: Path-based well-known URI (if server has a path component)
230-
if parsed.path and parsed.path != "/":
231-
path_based_url = urljoin(base_url, f"/.well-known/oauth-protected-resource{parsed.path}")
232-
urls.append(path_based_url)
233-
234-
# Priority 3: Root-based well-known URI
235-
root_based_url = urljoin(base_url, "/.well-known/oauth-protected-resource")
236-
urls.append(root_based_url)
237-
238-
return urls
239-
240-
def _extract_field_from_www_auth(self, init_response: httpx.Response, field_name: str) -> str | None:
241-
"""
242-
Extract field from WWW-Authenticate header.
243-
244-
Returns:
245-
Field value if found in WWW-Authenticate header, None otherwise
246-
"""
247-
www_auth_header = init_response.headers.get("WWW-Authenticate")
248-
if not www_auth_header:
249-
return None
250-
251-
# Pattern matches: field_name="value" or field_name=value (unquoted)
252-
pattern = rf'{field_name}=(?:"([^"]+)"|([^\s,]+))'
253-
match = re.search(pattern, www_auth_header)
254-
255-
if match:
256-
# Return quoted value if present, otherwise unquoted value
257-
return match.group(1) or match.group(2)
258-
259-
return None
260-
261-
def _extract_resource_metadata_from_www_auth(self, init_response: httpx.Response) -> str | None:
262-
"""
263-
Extract protected resource metadata URL from WWW-Authenticate header as per RFC9728.
264-
265-
Returns:
266-
Resource metadata URL if found in WWW-Authenticate header, None otherwise
267-
"""
268-
if not init_response or init_response.status_code != 401:
269-
return None
270-
271-
return self._extract_field_from_www_auth(init_response, "resource_metadata")
272-
273-
def _extract_scope_from_www_auth(self, init_response: httpx.Response) -> str | None:
274-
"""
275-
Extract scope parameter from WWW-Authenticate header as per RFC6750.
276-
277-
Returns:
278-
Scope string if found in WWW-Authenticate header, None otherwise
279-
"""
280-
return self._extract_field_from_www_auth(init_response, "scope")
281-
282210
async def _handle_protected_resource_response(self, response: httpx.Response) -> bool:
283211
"""
284212
Handle protected resource metadata discovery response.
@@ -309,54 +237,6 @@ async def _handle_protected_resource_response(self, response: httpx.Response) ->
309237
# Other error - fail immediately
310238
raise OAuthFlowError(f"Protected Resource Metadata request failed: {response.status_code}")
311239

312-
def _select_scopes(self, init_response: httpx.Response) -> None:
313-
"""Select scopes as outlined in the 'Scope Selection Strategy in the MCP spec."""
314-
# Per MCP spec, scope selection priority order:
315-
# 1. Use scope from WWW-Authenticate header (if provided)
316-
# 2. Use all scopes from PRM scopes_supported (if available)
317-
# 3. Omit scope parameter if neither is available
318-
#
319-
www_authenticate_scope = self._extract_scope_from_www_auth(init_response)
320-
if www_authenticate_scope is not None:
321-
# Priority 1: WWW-Authenticate header scope
322-
self.context.client_metadata.scope = www_authenticate_scope
323-
elif (
324-
self.context.protected_resource_metadata is not None
325-
and self.context.protected_resource_metadata.scopes_supported is not None
326-
):
327-
# Priority 2: PRM scopes_supported
328-
self.context.client_metadata.scope = " ".join(self.context.protected_resource_metadata.scopes_supported)
329-
else:
330-
# Priority 3: Omit scope parameter
331-
self.context.client_metadata.scope = None
332-
333-
def _get_discovery_urls(self) -> list[str]:
334-
"""Generate ordered list of (url, type) tuples for discovery attempts."""
335-
urls: list[str] = []
336-
auth_server_url = self.context.auth_server_url or self.context.server_url
337-
parsed = urlparse(auth_server_url)
338-
base_url = f"{parsed.scheme}://{parsed.netloc}"
339-
340-
# RFC 8414: Path-aware OAuth discovery
341-
if parsed.path and parsed.path != "/":
342-
oauth_path = f"/.well-known/oauth-authorization-server{parsed.path.rstrip('/')}"
343-
urls.append(urljoin(base_url, oauth_path))
344-
345-
# OAuth root fallback
346-
urls.append(urljoin(base_url, "/.well-known/oauth-authorization-server"))
347-
348-
# RFC 8414 section 5: Path-aware OIDC discovery
349-
# See https://www.rfc-editor.org/rfc/rfc8414.html#section-5
350-
if parsed.path and parsed.path != "/":
351-
oidc_path = f"/.well-known/openid-configuration{parsed.path.rstrip('/')}"
352-
urls.append(urljoin(base_url, oidc_path))
353-
354-
# OIDC 1.0 fallback (appends to full URL per OIDC spec)
355-
oidc_fallback = f"{auth_server_url.rstrip('/')}/.well-known/openid-configuration"
356-
urls.append(oidc_fallback)
357-
358-
return urls
359-
360240
async def _register_client(self) -> httpx.Request | None:
361241
"""Build registration request or skip if already registered."""
362242
if self.context.client_info:
@@ -610,8 +490,12 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
610490
# Perform full OAuth flow
611491
try:
612492
# OAuth flow must be inline due to generator constraints
493+
www_auth_resource_metadata_url = extract_resource_metadata_from_www_auth(response)
494+
613495
# Step 1: Discover protected resource metadata (SEP-985 with fallback support)
614-
discovery_urls = self._build_protected_resource_discovery_urls(response)
496+
discovery_urls = build_protected_resource_discovery_urls(
497+
www_auth_resource_metadata_url, self.context.server_url
498+
)
615499
discovery_success = False
616500
for url in discovery_urls:
617501
discovery_request = httpx.Request(
@@ -626,10 +510,12 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
626510
raise OAuthFlowError("Protected resource metadata discovery failed: no valid metadata found")
627511

628512
# Step 2: Apply scope selection strategy
629-
self._select_scopes(response)
513+
self.context.client_metadata.scope = get_client_metadata_scopes(
514+
www_auth_resource_metadata_url, self.context.protected_resource_metadata
515+
)
630516

631517
# Step 3: Discover OAuth metadata (with fallback for legacy servers)
632-
discovery_urls = self._get_discovery_urls()
518+
discovery_urls = get_discovery_urls(self.context.auth_server_url or self.context.server_url)
633519
for url in discovery_urls:
634520
oauth_metadata_request = self._create_oauth_metadata_request(url)
635521
oauth_metadata_response = yield oauth_metadata_request
@@ -661,13 +547,15 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
661547
yield request
662548
elif response.status_code == 403:
663549
# Step 1: Extract error field from WWW-Authenticate header
664-
error = self._extract_field_from_www_auth(response, "error")
550+
error = extract_field_from_www_auth(response, "error")
665551

666552
# Step 2: Check if we need to step-up authorization
667553
if error == "insufficient_scope":
668554
try:
669555
# Step 2a: Update the required scopes
670-
self._select_scopes(response)
556+
self.context.client_metadata.scope = get_client_metadata_scopes(
557+
extract_scope_from_www_auth(response), self.context.protected_resource_metadata
558+
)
671559

672560
# Step 2b: Perform (re-)authorization and token exchange
673561
token_response = yield await self._perform_authorization()

src/mcp/client/auth/utils.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
import re
2+
from urllib.parse import urljoin, urlparse
3+
4+
from httpx import Response
5+
6+
from mcp.shared.auth import ProtectedResourceMetadata
7+
8+
9+
def extract_field_from_www_auth(response: Response, field_name: str) -> str | None:
10+
"""
11+
Extract field from WWW-Authenticate header.
12+
13+
Returns:
14+
Field value if found in WWW-Authenticate header, None otherwise
15+
"""
16+
www_auth_header = response.headers.get("WWW-Authenticate")
17+
if not www_auth_header:
18+
return None
19+
20+
# Pattern matches: field_name="value" or field_name=value (unquoted)
21+
pattern = rf'{field_name}=(?:"([^"]+)"|([^\s,]+))'
22+
match = re.search(pattern, www_auth_header)
23+
24+
if match:
25+
# Return quoted value if present, otherwise unquoted value
26+
return match.group(1) or match.group(2)
27+
28+
return None
29+
30+
31+
def extract_scope_from_www_auth(response: Response) -> str | None:
32+
"""
33+
Extract scope parameter from WWW-Authenticate header as per RFC6750.
34+
35+
Returns:
36+
Scope string if found in WWW-Authenticate header, None otherwise
37+
"""
38+
return extract_field_from_www_auth(response, "scope")
39+
40+
41+
def extract_resource_metadata_from_www_auth(response: Response) -> str | None:
42+
"""
43+
Extract protected resource metadata URL from WWW-Authenticate header as per RFC9728.
44+
45+
Returns:
46+
Resource metadata URL if found in WWW-Authenticate header, None otherwise
47+
"""
48+
if not response or response.status_code != 401:
49+
return None
50+
51+
return extract_field_from_www_auth(response, "resource_metadata")
52+
53+
54+
def build_protected_resource_discovery_urls(www_auth_url: str | None, server_url: str) -> list[str]:
55+
"""
56+
Build ordered list of URLs to try for protected resource metadata discovery.
57+
58+
Per SEP-985, the client MUST:
59+
1. Try resource_metadata from WWW-Authenticate header (if present)
60+
2. Fall back to path-based well-known URI: /.well-known/oauth-protected-resource/{path}
61+
3. Fall back to root-based well-known URI: /.well-known/oauth-protected-resource
62+
63+
Args:
64+
www_auth_url: optional resource_metadata url extracted from the WWW-Authenticate header
65+
server_url: server url
66+
67+
Returns:
68+
Ordered list of URLs to try for discovery
69+
"""
70+
urls: list[str] = []
71+
72+
# Priority 1: WWW-Authenticate header with resource_metadata parameter
73+
if www_auth_url:
74+
urls.append(www_auth_url)
75+
76+
# Priority 2-3: Well-known URIs (RFC 9728)
77+
parsed = urlparse(server_url)
78+
base_url = f"{parsed.scheme}://{parsed.netloc}"
79+
80+
# Priority 2: Path-based well-known URI (if server has a path component)
81+
if parsed.path and parsed.path != "/":
82+
path_based_url = urljoin(base_url, f"/.well-known/oauth-protected-resource{parsed.path}")
83+
urls.append(path_based_url)
84+
85+
# Priority 3: Root-based well-known URI
86+
root_based_url = urljoin(base_url, "/.well-known/oauth-protected-resource")
87+
urls.append(root_based_url)
88+
89+
return urls
90+
91+
92+
def get_client_metadata_scopes(
93+
www_authenticate_scope: str | None, protected_resource_metadata: ProtectedResourceMetadata | None
94+
) -> str | None:
95+
"""Select scopes as outlined in the 'Scope Selection Strategy' in the MCP spec."""
96+
# Per MCP spec, scope selection priority order:
97+
# 1. Use scope from WWW-Authenticate header (if provided)
98+
# 2. Use all scopes from PRM scopes_supported (if available)
99+
# 3. Omit scope parameter if neither is available
100+
101+
if www_authenticate_scope is not None:
102+
# Priority 1: WWW-Authenticate header scope
103+
return www_authenticate_scope
104+
elif protected_resource_metadata is not None and protected_resource_metadata.scopes_supported is not None:
105+
# Priority 2: PRM scopes_supported
106+
return " ".join(protected_resource_metadata.scopes_supported)
107+
else:
108+
# Priority 3: Omit scope parameter
109+
return None
110+
111+
112+
def get_discovery_urls(auth_server_url: str) -> list[str]:
113+
"""Generate ordered list of (url, type) tuples for discovery attempts."""
114+
urls: list[str] = []
115+
parsed = urlparse(auth_server_url)
116+
base_url = f"{parsed.scheme}://{parsed.netloc}"
117+
118+
# RFC 8414: Path-aware OAuth discovery
119+
if parsed.path and parsed.path != "/":
120+
oauth_path = f"/.well-known/oauth-authorization-server{parsed.path.rstrip('/')}"
121+
urls.append(urljoin(base_url, oauth_path))
122+
123+
# OAuth root fallback
124+
urls.append(urljoin(base_url, "/.well-known/oauth-authorization-server"))
125+
126+
# RFC 8414 section 5: Path-aware OIDC discovery
127+
# See https://www.rfc-editor.org/rfc/rfc8414.html#section-5
128+
if parsed.path and parsed.path != "/":
129+
oidc_path = f"/.well-known/openid-configuration{parsed.path.rstrip('/')}"
130+
urls.append(urljoin(base_url, oidc_path))
131+
132+
# OIDC 1.0 fallback (appends to full URL per OIDC spec)
133+
oidc_fallback = f"{auth_server_url.rstrip('/')}/.well-known/openid-configuration"
134+
urls.append(oidc_fallback)
135+
136+
return urls

0 commit comments

Comments
 (0)