66"""
77
88import logging
9- from typing import Any
9+ from collections . abc import Awaitable , Callable
1010
1111import httpx
1212from pydantic import BaseModel , Field
@@ -166,8 +166,8 @@ def __init__(
166166 storage : TokenStorage ,
167167 idp_token_endpoint : str ,
168168 token_exchange_params : TokenExchangeParameters ,
169- redirect_handler : Any = None ,
170- callback_handler : Any = None ,
169+ redirect_handler : Callable [[ str ], Awaitable [ None ]] | None = None ,
170+ callback_handler : Callable [[], Awaitable [ tuple [ str , str | None ]]] | None = None ,
171171 timeout : float = 300.0 ,
172172 ) -> None :
173173 """
@@ -228,7 +228,8 @@ async def exchange_token_for_id_jag(
228228
229229 # Add client authentication if needed
230230 if self .context .client_info :
231- token_data ["client_id" ] = self .context .client_info .client_id
231+ if self .context .client_info .client_id is not None :
232+ token_data ["client_id" ] = self .context .client_info .client_id
232233 if self .context .client_info .client_secret is not None :
233234 token_data ["client_secret" ] = self .context .client_info .client_secret
234235
@@ -240,11 +241,11 @@ async def exchange_token_for_id_jag(
240241 )
241242
242243 if response .status_code != 200 :
243- error_data : dict [str , str ] = (
244+ error_data : dict [str , object ] = (
244245 response .json () if response .headers .get ("content-type" , "" ).startswith ("application/json" ) else {}
245246 )
246- error : str = error_data .get ("error" , "unknown_error" )
247- error_description : str = error_data .get ("error_description" , "Token exchange failed" )
247+ error = str ( error_data .get ("error" , "unknown_error" ) )
248+ error_description = str ( error_data .get ("error_description" , "Token exchange failed" ) )
248249 raise OAuthTokenError (f"Token exchange failed: { error } - { error_description } " )
249250
250251 # Parse response
@@ -298,7 +299,8 @@ async def exchange_id_jag_for_access_token(
298299
299300 # Add client authentication
300301 if self .context .client_info :
301- token_data ["client_id" ] = self .context .client_info .client_id
302+ if self .context .client_info .client_id is not None :
303+ token_data ["client_id" ] = self .context .client_info .client_id
302304 if self .context .client_info .client_secret is not None :
303305 token_data ["client_secret" ] = self .context .client_info .client_secret
304306
@@ -310,11 +312,11 @@ async def exchange_id_jag_for_access_token(
310312 )
311313
312314 if response .status_code != 200 :
313- error_data : dict [str , str ] = (
315+ error_data : dict [str , object ] = (
314316 response .json () if response .headers .get ("content-type" , "" ).startswith ("application/json" ) else {}
315317 )
316- error : str = error_data .get ("error" , "unknown_error" )
317- error_description : str = error_data .get ("error_description" , "JWT bearer grant failed" )
318+ error = str ( error_data .get ("error" , "unknown_error" ) )
319+ error_description = str ( error_data .get ("error_description" , "JWT bearer grant failed" ) )
318320 raise OAuthTokenError (f"JWT bearer grant failed: { error } - { error_description } " )
319321
320322 # Parse OAuth token response
0 commit comments