Skip to content

Commit 6901553

Browse files
committed
refactor: pull more oauth helpers out
1 parent 3ff6257 commit 6901553

File tree

5 files changed

+175
-87
lines changed

5 files changed

+175
-87
lines changed

src/mcp/client/auth/__init__.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,9 @@
44
Implements authorization code flow with PKCE and automatic token refresh.
55
"""
66

7+
from mcp.client.auth.exceptions import OAuthFlowError, OAuthRegistrationError, OAuthTokenError
78
from mcp.client.auth.oauth2 import (
89
OAuthClientProvider,
9-
OAuthFlowError,
10-
OAuthRegistrationError,
11-
OAuthTokenError,
1210
PKCEParameters,
1311
TokenStorage,
1412
)

src/mcp/client/auth/exceptions.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
class OAuthFlowError(Exception):
2+
"""Base exception for OAuth flow errors."""
3+
4+
5+
class OAuthTokenError(OAuthFlowError):
6+
"""Raised when token operations fail."""
7+
8+
9+
class OAuthRegistrationError(OAuthFlowError):
10+
"""Raised when client registration fails."""

src/mcp/client/auth/oauth2.py

Lines changed: 63 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,19 @@
1919
import httpx
2020
from pydantic import BaseModel, Field, ValidationError
2121

22+
from mcp.client.auth import OAuthFlowError, OAuthTokenError
2223
from mcp.client.auth.utils import (
2324
build_protected_resource_discovery_urls,
25+
create_client_registration_request,
26+
create_oauth_metadata_request,
2427
extract_field_from_www_auth,
2528
extract_resource_metadata_from_www_auth,
2629
extract_scope_from_www_auth,
2730
get_client_metadata_scopes,
2831
get_discovery_urls,
32+
handle_auth_metadata_response,
2933
handle_protected_resource_response,
34+
handle_registration_response,
3035
)
3136
from mcp.client.streamable_http import MCP_PROTOCOL_VERSION
3237
from mcp.shared.auth import (
@@ -37,23 +42,10 @@
3742
ProtectedResourceMetadata,
3843
)
3944
from mcp.shared.auth_utils import check_resource_allowed, resource_url_from_server_url
40-
from mcp.types import LATEST_PROTOCOL_VERSION
4145

4246
logger = logging.getLogger(__name__)
4347

4448

45-
class OAuthFlowError(Exception):
46-
"""Base exception for OAuth flow errors."""
47-
48-
49-
class OAuthTokenError(OAuthFlowError):
50-
"""Raised when token operations fail."""
51-
52-
53-
class OAuthRegistrationError(OAuthFlowError):
54-
"""Raised when client registration fails."""
55-
56-
5749
class PKCEParameters(BaseModel):
5850
"""PKCE (Proof Key for Code Exchange) parameters."""
5951

@@ -255,19 +247,19 @@ async def _register_client(self) -> httpx.Request | None:
255247
"POST", registration_url, json=registration_data, headers={"Content-Type": "application/json"}
256248
)
257249

258-
async def _handle_registration_response(self, response: httpx.Response) -> None:
259-
"""Handle registration response."""
260-
if response.status_code not in (200, 201):
261-
await response.aread()
262-
raise OAuthRegistrationError(f"Registration failed: {response.status_code} {response.text}")
263-
264-
try:
265-
content = await response.aread()
266-
client_info = OAuthClientInformationFull.model_validate_json(content)
267-
self.context.client_info = client_info
268-
await self.context.storage.set_client_info(client_info)
269-
except ValidationError as e:
270-
raise OAuthRegistrationError(f"Invalid registration response: {e}")
250+
# async def _handle_registration_response(self, response: httpx.Response) -> None:
251+
# """Handle registration response."""
252+
# if response.status_code not in (200, 201):
253+
# await response.aread()
254+
# raise OAuthRegistrationError(f"Registration failed: {response.status_code} {response.text}")
255+
#
256+
# try:
257+
# content = await response.aread()
258+
# client_info = OAuthClientInformationFull.model_validate_json(content)
259+
# self.context.client_info = client_info
260+
# await self.context.storage.set_client_info(client_info)
261+
# except ValidationError as e:
262+
# raise OAuthRegistrationError(f"Invalid registration response: {e}")
271263

272264
async def _perform_authorization(self) -> httpx.Request:
273265
"""Perform the authorization flow."""
@@ -456,9 +448,6 @@ def _add_auth_header(self, request: httpx.Request) -> None:
456448
if self.context.current_tokens and self.context.current_tokens.access_token:
457449
request.headers["Authorization"] = f"Bearer {self.context.current_tokens.access_token}"
458450

459-
def _create_oauth_metadata_request(self, url: str) -> httpx.Request:
460-
return httpx.Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION})
461-
462451
async def _handle_oauth_metadata_response(self, response: httpx.Response) -> None:
463452
content = await response.aread()
464453
metadata = OAuthMetadata.model_validate_json(content)
@@ -494,54 +483,63 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
494483
www_auth_resource_metadata_url = extract_resource_metadata_from_www_auth(response)
495484

496485
# Step 1: Discover protected resource metadata (SEP-985 with fallback support)
497-
discovery_urls = build_protected_resource_discovery_urls(
486+
prm_discovery_urls = build_protected_resource_discovery_urls(
498487
www_auth_resource_metadata_url, self.context.server_url
499488
)
500-
discovery_success = False
501-
for url in discovery_urls:
502-
discovery_request = httpx.Request(
503-
"GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION}
504-
)
505-
discovery_response = yield discovery_request
506-
discovery_success, prm, auth_server_url = await handle_protected_resource_response(
507-
discovery_response
508-
)
509-
if discovery_success:
510-
assert prm is not None
511-
assert auth_server_url is not None
489+
prm_discovery_success = False
490+
for url in prm_discovery_urls:
491+
discovery_request = create_oauth_metadata_request(url)
492+
493+
discovery_response = yield discovery_request # sending request
494+
495+
prm = await handle_protected_resource_response(discovery_response)
496+
if prm:
497+
prm_discovery_success = True
498+
499+
# saving the response metadata
512500
self.context.protected_resource_metadata = prm
513501
if prm.authorization_servers:
514-
self.context.auth_server_url = auth_server_url
515-
break
502+
self.context.auth_server_url = str(prm.authorization_servers[0])
516503

517-
if not discovery_success:
504+
break
505+
else:
506+
logger.debug(f"Protected resource metadata discovery failed: {url}")
507+
if not prm_discovery_success:
518508
raise OAuthFlowError("Protected resource metadata discovery failed: no valid metadata found")
519509

520-
# Step 2: Apply scope selection strategy
521-
self.context.client_metadata.scope = get_client_metadata_scopes(
522-
www_auth_resource_metadata_url, self.context.protected_resource_metadata
523-
)
524-
525-
# Step 3: Discover OAuth metadata (with fallback for legacy servers)
526-
discovery_urls = get_discovery_urls(self.context.auth_server_url or self.context.server_url)
527-
for url in discovery_urls:
528-
oauth_metadata_request = self._create_oauth_metadata_request(url)
510+
# Step 2: Discover OAuth metadata (with fallback for legacy servers)
511+
asm_discovery_urls = get_discovery_urls(self.context.auth_server_url or self.context.server_url)
512+
for url in asm_discovery_urls:
513+
oauth_metadata_request = create_oauth_metadata_request(url)
529514
oauth_metadata_response = yield oauth_metadata_request
530515

531-
if oauth_metadata_response.status_code == 200:
532-
try:
533-
await self._handle_oauth_metadata_response(oauth_metadata_response)
534-
break
535-
except ValidationError:
536-
continue
537-
elif oauth_metadata_response.status_code < 400 or oauth_metadata_response.status_code >= 500:
538-
break # Non-4XX error, stop trying
516+
ok, asm = await handle_auth_metadata_response(oauth_metadata_response)
517+
if not ok:
518+
break
519+
if ok and asm:
520+
self.context.oauth_metadata = asm
521+
break
522+
else:
523+
logger.debug(f"OAuth metadata discovery failed: {url}")
524+
525+
# Step 3: Apply scope selection strategy
526+
self.context.client_metadata.scope = get_client_metadata_scopes(
527+
www_auth_resource_metadata_url,
528+
self.context.protected_resource_metadata,
529+
self.context.oauth_metadata,
530+
)
539531

540532
# Step 4: Register client if needed
541-
registration_request = await self._register_client()
542-
if registration_request:
533+
registration_request = create_client_registration_request(
534+
self.context.oauth_metadata,
535+
self.context.client_metadata,
536+
self.context.get_authorization_base_url(self.context.server_url),
537+
)
538+
if not self.context.client_info:
543539
registration_response = yield registration_request
544-
await self._handle_registration_response(registration_response)
540+
client_information = await handle_registration_response(registration_response)
541+
self.context.client_info = client_information
542+
await self.context.storage.set_client_info(client_information)
545543

546544
# Step 5: Perform authorization and complete token exchange
547545
token_response = yield await self._perform_authorization()

src/mcp/client/auth/utils.py

Lines changed: 96 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@
22
import re
33
from urllib.parse import urljoin, urlparse
44

5-
from httpx import Response
5+
from httpx import Request, Response
66
from pydantic import ValidationError
77

8-
from mcp.client.auth import OAuthFlowError
9-
from mcp.shared.auth import ProtectedResourceMetadata
8+
from mcp.client.auth import OAuthRegistrationError
9+
from mcp.client.streamable_http import MCP_PROTOCOL_VERSION
10+
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthMetadata, ProtectedResourceMetadata
11+
from mcp.types import LATEST_PROTOCOL_VERSION
1012

1113
logger = logging.getLogger(__name__)
1214

@@ -95,7 +97,9 @@ def build_protected_resource_discovery_urls(www_auth_url: str | None, server_url
9597

9698

9799
def get_client_metadata_scopes(
98-
www_authenticate_scope: str | None, protected_resource_metadata: ProtectedResourceMetadata | None
100+
www_authenticate_scope: str | None,
101+
protected_resource_metadata: ProtectedResourceMetadata | None,
102+
authorization_server_metadata: OAuthMetadata | None = None,
99103
) -> str | None:
100104
"""Select scopes as outlined in the 'Scope Selection Strategy' in the MCP spec."""
101105
# Per MCP spec, scope selection priority order:
@@ -109,6 +113,8 @@ def get_client_metadata_scopes(
109113
elif protected_resource_metadata is not None and protected_resource_metadata.scopes_supported is not None:
110114
# Priority 2: PRM scopes_supported
111115
return " ".join(protected_resource_metadata.scopes_supported)
116+
elif authorization_server_metadata is not None and authorization_server_metadata.scopes_supported is not None:
117+
return " ".join(authorization_server_metadata.scopes_supported)
112118
else:
113119
# Priority 3: Omit scope parameter
114120
return None
@@ -143,7 +149,7 @@ def get_discovery_urls(auth_server_url: str) -> list[str]:
143149

144150
async def handle_protected_resource_response(
145151
response: Response,
146-
) -> tuple[bool, ProtectedResourceMetadata | None, str | None]:
152+
) -> ProtectedResourceMetadata | None:
147153
"""
148154
Handle protected resource metadata discovery response.
149155
@@ -156,22 +162,96 @@ async def handle_protected_resource_response(
156162
try:
157163
content = await response.aread()
158164
metadata = ProtectedResourceMetadata.model_validate_json(content)
159-
auth_server_url: str | None = None
160-
if metadata.authorization_servers:
161-
auth_server_url = str(metadata.authorization_servers[0])
162-
return True, metadata, auth_server_url
165+
return metadata
163166

164167
except ValidationError:
165168
# Invalid metadata - try next URL
166-
logger.warning(f"Invalid protected resource metadata at {response.request.url}")
167-
return False, None, None
168-
elif response.status_code == 404:
169+
return None
170+
else:
169171
# Not found - try next URL in fallback chain
170-
logger.debug(f"Protected resource metadata not found at {response.request.url}, trying next URL")
171-
return False, None, None
172+
return None
173+
174+
175+
async def handle_auth_metadata_response(response: Response) -> tuple[bool, OAuthMetadata | None]:
176+
if response.status_code == 200:
177+
try:
178+
content = await response.aread()
179+
asm = OAuthMetadata.model_validate_json(content)
180+
return True, asm
181+
except ValidationError:
182+
return True, None
183+
elif response.status_code < 400 or response.status_code >= 500:
184+
return False, None # Non-4XX error, stop trying
185+
return True, None
186+
187+
188+
def create_oauth_metadata_request(url: str) -> Request:
189+
return Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION})
190+
191+
192+
def create_client_registration_request(
193+
auth_server_metadata: OAuthMetadata | None, client_metadata: OAuthClientMetadata, auth_base_url: str
194+
) -> Request:
195+
"""Build registration request or skip if already registered."""
196+
197+
if auth_server_metadata and auth_server_metadata.registration_endpoint:
198+
registration_url = str(auth_server_metadata.registration_endpoint)
172199
else:
173-
# Other error - fail immediately
174-
raise OAuthFlowError(f"Protected Resource Metadata request failed: {response.status_code}")
200+
registration_url = urljoin(auth_base_url, "/register")
201+
202+
registration_data = client_metadata.model_dump(by_alias=True, mode="json", exclude_none=True)
203+
204+
return Request("POST", registration_url, json=registration_data, headers={"Content-Type": "application/json"})
205+
206+
207+
async def handle_registration_response(response: Response) -> OAuthClientInformationFull:
208+
"""Handle registration response."""
209+
if response.status_code not in (200, 201):
210+
await response.aread()
211+
raise OAuthRegistrationError(f"Registration failed: {response.status_code} {response.text}")
212+
213+
try:
214+
content = await response.aread()
215+
client_info = OAuthClientInformationFull.model_validate_json(content)
216+
return client_info
217+
# self.context.client_info = client_info
218+
# await self.context.storage.set_client_info(client_info)
219+
except ValidationError as e:
220+
raise OAuthRegistrationError(f"Invalid registration response: {e}")
221+
222+
223+
# async def prm_discovery(
224+
# server_url: str,
225+
# www_auth_resource_metadata_url: str | None,
226+
# ) -> ProtectedResourceMetadata | None:
227+
# # Step 1: Discover protected resource metadata (SEP-985 with fallback support)
228+
# prm_discovery_urls = build_protected_resource_discovery_urls(www_auth_resource_metadata_url, server_url)
229+
# prm_discovery_success = False
230+
# for url in prm_discovery_urls:
231+
# discovery_request = create_oauth_metadata_request(url)
232+
#
233+
# discovery_response = yield discovery_request # sending request
234+
#
235+
# prm = await handle_protected_resource_response(discovery_response)
236+
# if prm:
237+
# prm_discovery_success = True
238+
#
239+
# # saving the response metadata
240+
# self.context.protected_resource_metadata = prm
241+
# if prm.authorization_servers:
242+
# self.context.auth_server_url = str(prm.authorization_servers[0])
243+
#
244+
# break
245+
# else:
246+
# logger.debug(f"Protected resource metadata discovery failed: {url}")
247+
# if not prm_discovery_success:
248+
# raise OAuthFlowError("Protected resource metadata discovery failed: no valid metadata found")
249+
250+
251+
# class OAuthHandler:
252+
# async def make_request(self, request: Request) -> Response: ...
253+
#
254+
# async def store_prm(self, prm: ProtectedResourceMetadata) -> None: ...
175255

176256

177257
# async def discovery_process(discovery_urls: list[str]) -> AsyncGenerator[Request, Response]:

0 commit comments

Comments
 (0)