@@ -80,6 +80,40 @@ class JWTParameters(BaseModel):
8080 jwt_signing_key : str | None = Field (default = None , description = "Private key for JWT signing." )
8181 jwt_lifetime_seconds : int = Field (default = 300 , description = "Lifetime of generated JWT in seconds." )
8282
83+ def to_assertion (self , with_audience_fallback : str | None = None ) -> str :
84+ if self .assertion is not None :
85+ # Prebuilt JWT (e.g. acquired out-of-band)
86+ assertion = self .assertion
87+ else :
88+ if not self .jwt_signing_key :
89+ raise OAuthFlowError ("Missing signing key for JWT bearer grant" )
90+ if not self .issuer :
91+ raise OAuthFlowError ("Missing issuer for JWT bearer grant" )
92+ if not self .subject :
93+ raise OAuthFlowError ("Missing subject for JWT bearer grant" )
94+
95+ audience = self .audience if self .audience else with_audience_fallback
96+ if not audience :
97+ raise OAuthFlowError ("Missing audience for JWT bearer grant" )
98+
99+ now = int (time .time ())
100+ claims : dict [str , Any ] = {
101+ "iss" : self .issuer ,
102+ "sub" : self .subject ,
103+ "aud" : audience ,
104+ "exp" : now + self .jwt_lifetime_seconds ,
105+ "iat" : now ,
106+ "jti" : str (uuid4 ()),
107+ }
108+ claims .update (self .claims or {})
109+
110+ assertion = jwt .encode (
111+ claims ,
112+ self .jwt_signing_key ,
113+ algorithm = self .jwt_signing_algorithm or "RS256" ,
114+ )
115+ return assertion
116+
83117
84118class TokenStorage (Protocol ):
85119 """Protocol for token storage implementations."""
@@ -111,7 +145,6 @@ class OAuthContext:
111145 redirect_handler : Callable [[str ], Awaitable [None ]] | None
112146 callback_handler : Callable [[], Awaitable [tuple [str , str | None ]]] | None
113147 timeout : float = 300.0
114- jwt_parameters : JWTParameters | None = None
115148
116149 # Discovered metadata
117150 protected_resource_metadata : ProtectedResourceMetadata | None = None
@@ -213,7 +246,6 @@ def __init__(
213246 redirect_handler : Callable [[str ], Awaitable [None ]] | None = None ,
214247 callback_handler : Callable [[], Awaitable [tuple [str , str | None ]]] | None = None ,
215248 timeout : float = 300.0 ,
216- jwt_parameters : JWTParameters | None = None ,
217249 ):
218250 """Initialize OAuth2 authentication."""
219251 self .context = OAuthContext (
@@ -223,7 +255,6 @@ def __init__(
223255 redirect_handler = redirect_handler ,
224256 callback_handler = callback_handler ,
225257 timeout = timeout ,
226- jwt_parameters = jwt_parameters ,
227258 )
228259 self ._initialized = False
229260
@@ -334,16 +365,9 @@ async def _handle_registration_response(self, response: httpx.Response) -> None:
334365
335366 async def _perform_authorization (self ) -> httpx .Request :
336367 """Perform the authorization flow."""
337- if "client_credentials" in self .context .client_metadata .grant_types :
338- token_request = await self ._exchange_token_client_credentials ()
339- return token_request
340- elif "urn:ietf:params:oauth:grant-type:jwt-bearer" in self .context .client_metadata .grant_types :
341- token_request = await self ._exchange_token_jwt_bearer ()
342- return token_request
343- else :
344- auth_code , code_verifier = await self ._perform_authorization_code_grant ()
345- token_request = await self ._exchange_token_authorization_code (auth_code , code_verifier )
346- return token_request
368+ auth_code , code_verifier = await self ._perform_authorization_code_grant ()
369+ token_request = await self ._exchange_token_authorization_code (auth_code , code_verifier )
370+ return token_request
347371
348372 async def _perform_authorization_code_grant (self ) -> tuple [str , str ]:
349373 """Perform the authorization redirect and get auth code."""
@@ -406,21 +430,25 @@ def _get_token_endpoint(self) -> str:
406430 token_url = urljoin (auth_base_url , "/token" )
407431 return token_url
408432
409- async def _exchange_token_authorization_code (self , auth_code : str , code_verifier : str ) -> httpx .Request :
433+ async def _exchange_token_authorization_code (
434+ self , auth_code : str , code_verifier : str , * , token_data : dict [str , Any ] = {}
435+ ) -> httpx .Request :
410436 """Build token exchange request for authorization_code flow."""
411437 if self .context .client_metadata .redirect_uris is None :
412438 raise OAuthFlowError ("No redirect URIs provided for authorization code grant" )
413439 if not self .context .client_info :
414440 raise OAuthFlowError ("Missing client info" )
415441
416442 token_url = self ._get_token_endpoint ()
417- token_data = {
418- "grant_type" : "authorization_code" ,
419- "code" : auth_code ,
420- "redirect_uri" : str (self .context .client_metadata .redirect_uris [0 ]),
421- "client_id" : self .context .client_info .client_id ,
422- "code_verifier" : code_verifier ,
423- }
443+ token_data .update (
444+ {
445+ "grant_type" : "authorization_code" ,
446+ "code" : auth_code ,
447+ "redirect_uri" : str (self .context .client_metadata .redirect_uris [0 ]),
448+ "client_id" : self .context .client_info .client_id ,
449+ "code_verifier" : code_verifier ,
450+ }
451+ )
424452
425453 # Only include resource param if conditions are met
426454 if self .context .should_include_resource_param (self .context .protocol_version ):
@@ -433,131 +461,6 @@ async def _exchange_token_authorization_code(self, auth_code: str, code_verifier
433461 "POST" , token_url , data = token_data , headers = {"Content-Type" : "application/x-www-form-urlencoded" }
434462 )
435463
436- async def _exchange_token_client_credentials (self ) -> httpx .Request :
437- """Build token exchange request for client_credentials flow."""
438- if not self .context .client_info :
439- raise OAuthFlowError ("Missing client info" )
440-
441- token_url = self ._get_token_endpoint ()
442- token_data = {
443- "grant_type" : "client_credentials" ,
444- }
445-
446- headers = {"Content-Type" : "application/x-www-form-urlencoded" }
447-
448- # Only include resource param if conditions are met
449- if self .context .should_include_resource_param (self .context .protocol_version ):
450- token_data ["resource" ] = self .context .get_resource_url () # RFC 8707
451-
452- if self .context .client_metadata .scope :
453- token_data ["scope" ] = self .context .client_metadata .scope
454-
455- if self .context .client_metadata .token_endpoint_auth_method == "client_secret_post" :
456- # Include in request body
457- if self .context .client_info .client_id :
458- token_data ["client_id" ] = self .context .client_info .client_id
459- if self .context .client_info .client_secret :
460- token_data ["client_secret" ] = self .context .client_info .client_secret
461- elif self .context .client_metadata .token_endpoint_auth_method == "client_secret_basic" :
462- # Include as Basic auth header
463- if not self .context .client_info .client_id :
464- raise OAuthTokenError ("Missing client_id in Basic auth flow" )
465- if not self .context .client_info .client_secret :
466- raise OAuthTokenError ("Missing client_secret in Basic auth flow" )
467- raw_auth = f"{ self .context .client_info .client_id } :{ self .context .client_info .client_secret } "
468- headers ["Authorization" ] = f"Basic { base64 .b64encode (raw_auth .encode ()).decode ()} "
469- elif self .context .client_metadata .token_endpoint_auth_method == "private_key_jwt" :
470- # Use JWT assertion for client authentication
471- if not self .context .jwt_parameters :
472- raise OAuthTokenError ("Missing JWT parameters for private_key_jwt flow" )
473-
474- if self .context .jwt_parameters .assertion is not None :
475- # Prebuilt JWT (e.g. acquired out-of-band)
476- assertion = self .context .jwt_parameters .assertion
477- else :
478- if not self .context .jwt_parameters .jwt_signing_key :
479- raise OAuthTokenError ("Missing JWT signing key for private_key_jwt flow" )
480- if not self .context .jwt_parameters .jwt_signing_algorithm :
481- raise OAuthTokenError ("Missing JWT signing algorithm for private_key_jwt flow" )
482-
483- now = int (time .time ())
484- claims = {
485- "iss" : self .context .jwt_parameters .issuer ,
486- "sub" : self .context .jwt_parameters .subject ,
487- "aud" : self .context .jwt_parameters .audience if self .context .jwt_parameters .audience else token_url ,
488- "exp" : now + self .context .jwt_parameters .jwt_lifetime_seconds ,
489- "iat" : now ,
490- "jti" : str (uuid4 ()),
491- }
492- claims .update (self .context .jwt_parameters .claims or {})
493-
494- assertion = jwt .encode (
495- claims ,
496- self .context .jwt_parameters .jwt_signing_key ,
497- algorithm = self .context .jwt_parameters .jwt_signing_algorithm or "RS256" ,
498- )
499-
500- # When using private_key_jwt, in a client_credentials flow, we use RFC 7523 Section 2.2
501- token_data ["client_assertion" ] = assertion
502- token_data ["client_assertion_type" ] = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
503- # We need to set the audience to the token endpoint, the audience is difference from the one in claims
504- # it represents the resource server that will validate the token
505- token_data ["audience" ] = self .context .get_resource_url ()
506-
507- return httpx .Request ("POST" , token_url , data = token_data , headers = headers )
508-
509- async def _exchange_token_jwt_bearer (self ) -> httpx .Request :
510- """Build token exchange request for JWT bearer grant."""
511- if not self .context .client_info :
512- raise OAuthFlowError ("Missing client info" )
513- if not self .context .jwt_parameters :
514- raise OAuthFlowError ("Missing JWT parameters" )
515-
516- token_url = self ._get_token_endpoint ()
517-
518- if self .context .jwt_parameters .assertion is not None :
519- # Prebuilt JWT (e.g. acquired out-of-band)
520- assertion = self .context .jwt_parameters .assertion
521- else :
522- if not self .context .jwt_parameters .jwt_signing_key :
523- raise OAuthFlowError ("Missing signing key for JWT bearer grant" )
524- if not self .context .jwt_parameters .issuer :
525- raise OAuthFlowError ("Missing issuer for JWT bearer grant" )
526- if not self .context .jwt_parameters .subject :
527- raise OAuthFlowError ("Missing subject for JWT bearer grant" )
528-
529- now = int (time .time ())
530- claims = {
531- "iss" : self .context .jwt_parameters .issuer ,
532- "sub" : self .context .jwt_parameters .subject ,
533- "aud" : token_url ,
534- "exp" : now + self .context .jwt_parameters .jwt_lifetime_seconds ,
535- "iat" : now ,
536- "jti" : str (uuid4 ()),
537- }
538- claims .update (self .context .jwt_parameters .claims or {})
539-
540- assertion = jwt .encode (
541- claims ,
542- self .context .jwt_parameters .jwt_signing_key ,
543- algorithm = self .context .jwt_parameters .jwt_signing_algorithm or "RS256" ,
544- )
545-
546- token_data = {
547- "grant_type" : "urn:ietf:params:oauth:grant-type:jwt-bearer" ,
548- "assertion" : assertion ,
549- }
550-
551- if self .context .should_include_resource_param (self .context .protocol_version ):
552- token_data ["resource" ] = self .context .get_resource_url ()
553-
554- if self .context .client_metadata .scope :
555- token_data ["scope" ] = self .context .client_metadata .scope
556-
557- return httpx .Request (
558- "POST" , token_url , data = token_data , headers = {"Content-Type" : "application/x-www-form-urlencoded" }
559- )
560-
561464 async def _handle_token_response (self , response : httpx .Response ) -> None :
562465 """Handle token exchange response."""
563466 if response .status_code != 200 :
@@ -720,3 +623,78 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
720623 # Retry with new tokens
721624 self ._add_auth_header (request )
722625 yield request
626+
627+
628+ class RFC7523OAuthClientProvider (OAuthClientProvider ):
629+ """OAuth client provider for RFC7532 clients."""
630+
631+ jwt_parameters : JWTParameters | None = None
632+
633+ def __init__ (
634+ self ,
635+ server_url : str ,
636+ client_metadata : OAuthClientMetadata ,
637+ storage : TokenStorage ,
638+ redirect_handler : Callable [[str ], Awaitable [None ]] | None = None ,
639+ callback_handler : Callable [[], Awaitable [tuple [str , str | None ]]] | None = None ,
640+ timeout : float = 300.0 ,
641+ jwt_parameters : JWTParameters | None = None ,
642+ ) -> None :
643+ super ().__init__ (server_url , client_metadata , storage , redirect_handler , callback_handler , timeout )
644+ self .jwt_parameters = jwt_parameters
645+
646+ async def _exchange_token_authorization_code (
647+ self , auth_code : str , code_verifier : str , * , token_data : dict [str , Any ] = {}
648+ ) -> httpx .Request :
649+ """Build token exchange request for authorization_code flow."""
650+ if self .context .client_metadata .token_endpoint_auth_method == "private_key_jwt" :
651+ self ._add_client_authentication_jwt (token_data = token_data )
652+ return await super ()._exchange_token_authorization_code (auth_code , code_verifier , token_data = token_data )
653+
654+ async def _perform_authorization (self ) -> httpx .Request :
655+ """Perform the authorization flow."""
656+ if "urn:ietf:params:oauth:grant-type:jwt-bearer" in self .context .client_metadata .grant_types :
657+ token_request = await self ._exchange_token_jwt_bearer ()
658+ return token_request
659+ else :
660+ return await super ()._perform_authorization ()
661+
662+ def _add_client_authentication_jwt (self , * , token_data : dict [str , Any ]):
663+ """Add JWT assertion for client authentication to token endpoint parameters."""
664+ if not self .jwt_parameters :
665+ raise OAuthTokenError ("Missing JWT parameters for private_key_jwt flow" )
666+
667+ token_url = self ._get_token_endpoint ()
668+ assertion = self .jwt_parameters .to_assertion (with_audience_fallback = token_url )
669+
670+ # When using private_key_jwt, in a client_credentials flow, we use RFC 7523 Section 2.2
671+ token_data ["client_assertion" ] = assertion
672+ token_data ["client_assertion_type" ] = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
673+ # We need to set the audience to the token endpoint, the audience is difference from the one in claims
674+ # it represents the resource server that will validate the token
675+ token_data ["audience" ] = self .context .get_resource_url ()
676+
677+ async def _exchange_token_jwt_bearer (self ) -> httpx .Request :
678+ """Build token exchange request for JWT bearer grant."""
679+ if not self .context .client_info :
680+ raise OAuthFlowError ("Missing client info" )
681+ if not self .jwt_parameters :
682+ raise OAuthFlowError ("Missing JWT parameters" )
683+
684+ token_url = self ._get_token_endpoint ()
685+ assertion = self .jwt_parameters .to_assertion (with_audience_fallback = token_url )
686+
687+ token_data = {
688+ "grant_type" : "urn:ietf:params:oauth:grant-type:jwt-bearer" ,
689+ "assertion" : assertion ,
690+ }
691+
692+ if self .context .should_include_resource_param (self .context .protocol_version ):
693+ token_data ["resource" ] = self .context .get_resource_url ()
694+
695+ if self .context .client_metadata .scope :
696+ token_data ["scope" ] = self .context .client_metadata .scope
697+
698+ return httpx .Request (
699+ "POST" , token_url , data = token_data , headers = {"Content-Type" : "application/x-www-form-urlencoded" }
700+ )
0 commit comments