@@ -251,14 +251,20 @@ def _create_registration_request(self, metadata: OAuthMetadata | None = None) ->
251251 headers = {"Content-Type" : "application/json" },
252252 )
253253
254- async def _handle_registration_response (self , response : httpx .Response ) -> None :
254+ async def _handle_registration_response (
255+ self , response : httpx .Response
256+ ) -> OAuthClientInformationFull :
255257 if response .status_code not in (200 , 201 ):
256258 await response .aread ()
257259 raise OAuthRegistrationError (f"Registration failed: { response .status_code } { response .text } " )
258260 content = await response .aread ()
259261 client_info = OAuthClientInformationFull .model_validate_json (content )
260262 self ._client_info = client_info
261263 await self .storage .set_client_info (client_info )
264+ context = getattr (self , "context" , None )
265+ if context is not None :
266+ context .client_info = client_info
267+ return client_info
262268
263269 def _apply_client_auth (
264270 self ,
@@ -315,6 +321,18 @@ def __init__(
315321 )
316322 self ._initialized = False
317323
324+ def _build_protected_resource_discovery_urls (self , resource_metadata_url : str | None ) -> list [str ]:
325+ """Build the list of PRM discovery URLs with legacy fallbacks."""
326+ return build_protected_resource_metadata_discovery_urls (
327+ resource_metadata_url , self .context .server_url
328+ )
329+
330+ def _get_discovery_urls (self , server_url : str | None = None ) -> list [str ]:
331+ """Build OAuth authorization server discovery URLs with legacy fallbacks."""
332+ return build_oauth_authorization_server_metadata_discovery_urls (
333+ server_url , self .context .server_url
334+ )
335+
318336 async def _handle_protected_resource_response (self , response : httpx .Response ) -> bool :
319337 """
320338 Handle protected resource metadata discovery response.
@@ -324,28 +342,30 @@ async def _handle_protected_resource_response(self, response: httpx.Response) ->
324342 Returns:
325343 True if metadata was successfully discovered, False if we should try next URL
326344 """
327- if response .status_code == 200 :
328- try :
329- content = await response .aread ()
330- metadata = ProtectedResourceMetadata .model_validate_json (content )
331- self .context .protected_resource_metadata = metadata
332- if metadata .authorization_servers : # pragma: no branch
333- self .context .auth_server_url = str (metadata .authorization_servers [0 ])
334- return True
335-
336- except ValidationError : # pragma: no cover
337- # Invalid metadata - try next URL
338- logger .warning (f"Invalid protected resource metadata at { response .request .url } " )
339- return False
340- elif response .status_code == 404 : # pragma: no cover
341- # Not found - try next URL in fallback chain
342- logger .debug (f"Protected resource metadata not found at { response .request .url } , trying next URL" )
343- return False
344- else :
345- # Other error - fail immediately
346- raise OAuthFlowError (
347- f"Protected Resource Metadata request failed: { response .status_code } "
348- ) # pragma: no cover
345+ metadata = await handle_protected_resource_response (response )
346+ if metadata :
347+ self .context .protected_resource_metadata = metadata
348+ if metadata .authorization_servers : # pragma: no branch
349+ self .context .auth_server_url = str (metadata .authorization_servers [0 ])
350+ return True
351+
352+ logger .debug (
353+ "Protected resource metadata discovery failed with status %s at %s" ,
354+ response .status_code ,
355+ response .request .url ,
356+ )
357+ return False
358+
359+ async def _handle_oauth_metadata_response (
360+ self , response : httpx .Response
361+ ) -> tuple [bool , OAuthMetadata | None ]:
362+ ok , asm = await handle_auth_metadata_response (response )
363+ if asm :
364+ self .context .oauth_metadata = asm
365+ self ._metadata = asm
366+ if self .context .client_metadata .scope is None and asm .scopes_supported is not None :
367+ self .context .client_metadata .scope = " " .join (asm .scopes_supported )
368+ return ok , asm
349369
350370 async def _perform_authorization (self ) -> httpx .Request :
351371 """Perform the authorization flow."""
@@ -560,34 +580,33 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
560580 self ._metadata = None
561581
562582 # Step 1: Discover protected resource metadata (SEP-985 with fallback support)
563- prm_discovery_urls = build_protected_resource_metadata_discovery_urls (
564- www_auth_resource_metadata_url , self . context . server_url
583+ prm_discovery_urls = self . _build_protected_resource_discovery_urls (
584+ www_auth_resource_metadata_url
565585 )
566586
567587 for url in prm_discovery_urls : # pragma: no branch
568- discovery_request = create_oauth_metadata_request (url )
588+ discovery_request = self . _create_oauth_metadata_request (url )
569589 discovery_response = yield discovery_request
570590
571- prm = await handle_protected_resource_response (discovery_response )
572- if prm :
573- self .context .protected_resource_metadata = prm
574- if prm .authorization_servers : # pragma: no branch
575- self .context .auth_server_url = str (prm .authorization_servers [0 ])
591+ handled = await self ._handle_protected_resource_response (discovery_response )
592+ if handled :
576593 break
577594
578- logger .debug (f"Protected resource metadata discovery failed: { url } " )
579-
580595 # Step 2: Discover OAuth Authorization Server Metadata (OASM) (with fallback for legacy servers)
581- asm_discovery_urls = build_oauth_authorization_server_metadata_discovery_urls (
582- self .context .auth_server_url , self .context .server_url
583- )
596+ asm_discovery_urls = self ._get_discovery_urls (self .context .auth_server_url )
584597
585598 authorization_metadata : OAuthMetadata | None = None
586599 for url in asm_discovery_urls : # pragma: no branch
587- oauth_metadata_request = create_oauth_metadata_request (url )
600+ oauth_metadata_request = self . _create_oauth_metadata_request (url )
588601 oauth_metadata_response = yield oauth_metadata_request
589602
590- ok , asm = await handle_auth_metadata_response (oauth_metadata_response )
603+ result = await self ._handle_oauth_metadata_response (oauth_metadata_response )
604+ if isinstance (result , tuple ):
605+ ok , asm = result
606+ else :
607+ ok = bool (result ) if result is not None else True
608+ asm = self .context .oauth_metadata or self ._metadata
609+
591610 if not ok :
592611 break
593612 if asm :
@@ -615,9 +634,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
615634 self .context .get_authorization_base_url (self .context .server_url ),
616635 )
617636 registration_response = yield registration_request
618- client_information = await handle_registration_response (registration_response )
619- self .context .client_info = client_information
620- await self .context .storage .set_client_info (client_information )
637+ await self ._handle_registration_response (registration_response )
621638
622639 # Step 5: Perform authorization and complete token exchange
623640 token_response = yield await self ._perform_authorization ()
0 commit comments