|
19 | 19 | import httpx |
20 | 20 | from pydantic import BaseModel, Field, ValidationError |
21 | 21 |
|
| 22 | +from mcp.client.auth import OAuthFlowError, OAuthTokenError |
22 | 23 | from mcp.client.auth.utils import ( |
23 | 24 | build_protected_resource_discovery_urls, |
| 25 | + create_client_registration_request, |
| 26 | + create_oauth_metadata_request, |
24 | 27 | extract_field_from_www_auth, |
25 | 28 | extract_resource_metadata_from_www_auth, |
26 | 29 | extract_scope_from_www_auth, |
27 | 30 | get_client_metadata_scopes, |
28 | 31 | get_discovery_urls, |
| 32 | + handle_auth_metadata_response, |
29 | 33 | handle_protected_resource_response, |
| 34 | + handle_registration_response, |
30 | 35 | ) |
31 | 36 | from mcp.client.streamable_http import MCP_PROTOCOL_VERSION |
32 | 37 | from mcp.shared.auth import ( |
|
37 | 42 | ProtectedResourceMetadata, |
38 | 43 | ) |
39 | 44 | from mcp.shared.auth_utils import check_resource_allowed, resource_url_from_server_url |
40 | | -from mcp.types import LATEST_PROTOCOL_VERSION |
41 | 45 |
|
42 | 46 | logger = logging.getLogger(__name__) |
43 | 47 |
|
44 | 48 |
|
45 | | -class OAuthFlowError(Exception): |
46 | | - """Base exception for OAuth flow errors.""" |
47 | | - |
48 | | - |
49 | | -class OAuthTokenError(OAuthFlowError): |
50 | | - """Raised when token operations fail.""" |
51 | | - |
52 | | - |
53 | | -class OAuthRegistrationError(OAuthFlowError): |
54 | | - """Raised when client registration fails.""" |
55 | | - |
56 | | - |
57 | 49 | class PKCEParameters(BaseModel): |
58 | 50 | """PKCE (Proof Key for Code Exchange) parameters.""" |
59 | 51 |
|
@@ -255,19 +247,19 @@ async def _register_client(self) -> httpx.Request | None: |
255 | 247 | "POST", registration_url, json=registration_data, headers={"Content-Type": "application/json"} |
256 | 248 | ) |
257 | 249 |
|
258 | | - async def _handle_registration_response(self, response: httpx.Response) -> None: |
259 | | - """Handle registration response.""" |
260 | | - if response.status_code not in (200, 201): |
261 | | - await response.aread() |
262 | | - raise OAuthRegistrationError(f"Registration failed: {response.status_code} {response.text}") |
263 | | - |
264 | | - try: |
265 | | - content = await response.aread() |
266 | | - client_info = OAuthClientInformationFull.model_validate_json(content) |
267 | | - self.context.client_info = client_info |
268 | | - await self.context.storage.set_client_info(client_info) |
269 | | - except ValidationError as e: |
270 | | - raise OAuthRegistrationError(f"Invalid registration response: {e}") |
| 250 | + # async def _handle_registration_response(self, response: httpx.Response) -> None: |
| 251 | + # """Handle registration response.""" |
| 252 | + # if response.status_code not in (200, 201): |
| 253 | + # await response.aread() |
| 254 | + # raise OAuthRegistrationError(f"Registration failed: {response.status_code} {response.text}") |
| 255 | + # |
| 256 | + # try: |
| 257 | + # content = await response.aread() |
| 258 | + # client_info = OAuthClientInformationFull.model_validate_json(content) |
| 259 | + # self.context.client_info = client_info |
| 260 | + # await self.context.storage.set_client_info(client_info) |
| 261 | + # except ValidationError as e: |
| 262 | + # raise OAuthRegistrationError(f"Invalid registration response: {e}") |
271 | 263 |
|
272 | 264 | async def _perform_authorization(self) -> httpx.Request: |
273 | 265 | """Perform the authorization flow.""" |
@@ -456,9 +448,6 @@ def _add_auth_header(self, request: httpx.Request) -> None: |
456 | 448 | if self.context.current_tokens and self.context.current_tokens.access_token: |
457 | 449 | request.headers["Authorization"] = f"Bearer {self.context.current_tokens.access_token}" |
458 | 450 |
|
459 | | - def _create_oauth_metadata_request(self, url: str) -> httpx.Request: |
460 | | - return httpx.Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION}) |
461 | | - |
462 | 451 | async def _handle_oauth_metadata_response(self, response: httpx.Response) -> None: |
463 | 452 | content = await response.aread() |
464 | 453 | metadata = OAuthMetadata.model_validate_json(content) |
@@ -494,54 +483,63 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. |
494 | 483 | www_auth_resource_metadata_url = extract_resource_metadata_from_www_auth(response) |
495 | 484 |
|
496 | 485 | # Step 1: Discover protected resource metadata (SEP-985 with fallback support) |
497 | | - discovery_urls = build_protected_resource_discovery_urls( |
| 486 | + prm_discovery_urls = build_protected_resource_discovery_urls( |
498 | 487 | www_auth_resource_metadata_url, self.context.server_url |
499 | 488 | ) |
500 | | - discovery_success = False |
501 | | - for url in discovery_urls: |
502 | | - discovery_request = httpx.Request( |
503 | | - "GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION} |
504 | | - ) |
505 | | - discovery_response = yield discovery_request |
506 | | - discovery_success, prm, auth_server_url = await handle_protected_resource_response( |
507 | | - discovery_response |
508 | | - ) |
509 | | - if discovery_success: |
510 | | - assert prm is not None |
511 | | - assert auth_server_url is not None |
| 489 | + prm_discovery_success = False |
| 490 | + for url in prm_discovery_urls: |
| 491 | + discovery_request = create_oauth_metadata_request(url) |
| 492 | + |
| 493 | + discovery_response = yield discovery_request # sending request |
| 494 | + |
| 495 | + prm = await handle_protected_resource_response(discovery_response) |
| 496 | + if prm: |
| 497 | + prm_discovery_success = True |
| 498 | + |
| 499 | + # saving the response metadata |
512 | 500 | self.context.protected_resource_metadata = prm |
513 | 501 | if prm.authorization_servers: |
514 | | - self.context.auth_server_url = auth_server_url |
515 | | - break |
| 502 | + self.context.auth_server_url = str(prm.authorization_servers[0]) |
516 | 503 |
|
517 | | - if not discovery_success: |
| 504 | + break |
| 505 | + else: |
| 506 | + logger.debug(f"Protected resource metadata discovery failed: {url}") |
| 507 | + if not prm_discovery_success: |
518 | 508 | raise OAuthFlowError("Protected resource metadata discovery failed: no valid metadata found") |
519 | 509 |
|
520 | | - # Step 2: Apply scope selection strategy |
521 | | - self.context.client_metadata.scope = get_client_metadata_scopes( |
522 | | - www_auth_resource_metadata_url, self.context.protected_resource_metadata |
523 | | - ) |
524 | | - |
525 | | - # Step 3: Discover OAuth metadata (with fallback for legacy servers) |
526 | | - discovery_urls = get_discovery_urls(self.context.auth_server_url or self.context.server_url) |
527 | | - for url in discovery_urls: |
528 | | - oauth_metadata_request = self._create_oauth_metadata_request(url) |
| 510 | + # Step 2: Discover OAuth metadata (with fallback for legacy servers) |
| 511 | + asm_discovery_urls = get_discovery_urls(self.context.auth_server_url or self.context.server_url) |
| 512 | + for url in asm_discovery_urls: |
| 513 | + oauth_metadata_request = create_oauth_metadata_request(url) |
529 | 514 | oauth_metadata_response = yield oauth_metadata_request |
530 | 515 |
|
531 | | - if oauth_metadata_response.status_code == 200: |
532 | | - try: |
533 | | - await self._handle_oauth_metadata_response(oauth_metadata_response) |
534 | | - break |
535 | | - except ValidationError: |
536 | | - continue |
537 | | - elif oauth_metadata_response.status_code < 400 or oauth_metadata_response.status_code >= 500: |
538 | | - break # Non-4XX error, stop trying |
| 516 | + ok, asm = await handle_auth_metadata_response(oauth_metadata_response) |
| 517 | + if not ok: |
| 518 | + break |
| 519 | + if ok and asm: |
| 520 | + self.context.oauth_metadata = asm |
| 521 | + break |
| 522 | + else: |
| 523 | + logger.debug(f"OAuth metadata discovery failed: {url}") |
| 524 | + |
| 525 | + # Step 3: Apply scope selection strategy |
| 526 | + self.context.client_metadata.scope = get_client_metadata_scopes( |
| 527 | + www_auth_resource_metadata_url, |
| 528 | + self.context.protected_resource_metadata, |
| 529 | + self.context.oauth_metadata, |
| 530 | + ) |
539 | 531 |
|
540 | 532 | # Step 4: Register client if needed |
541 | | - registration_request = await self._register_client() |
542 | | - if registration_request: |
| 533 | + registration_request = create_client_registration_request( |
| 534 | + self.context.oauth_metadata, |
| 535 | + self.context.client_metadata, |
| 536 | + self.context.get_authorization_base_url(self.context.server_url), |
| 537 | + ) |
| 538 | + if not self.context.client_info: |
543 | 539 | registration_response = yield registration_request |
544 | | - await self._handle_registration_response(registration_response) |
| 540 | + client_information = await handle_registration_response(registration_response) |
| 541 | + self.context.client_info = client_information |
| 542 | + await self.context.storage.set_client_info(client_information) |
545 | 543 |
|
546 | 544 | # Step 5: Perform authorization and complete token exchange |
547 | 545 | token_response = yield await self._perform_authorization() |
|
0 commit comments