@@ -2482,56 +2482,153 @@ The `EnterpriseAuthOAuthClientProvider` class extends the standard OAuth provide
248224823 . ** Exchange ID-JAG for Access Token** using RFC 7523 JWT Bearer Grant
248324834 . ** Use Access Token** to call protected MCP server tools
24842484
2485+ ** Using the Access Token with MCP Server:**
2486+
2487+ 1 . Once you have obtained the access token, you can use it to authenticate requests to the MCP server
2488+ 2 . The access token is automatically included in all subsequent requests to the MCP server, allowing you to access protected tools and resources based on your enterprise identity and permissions.
2489+
2490+ ** Handling Token Expiration and Refresh:**
2491+
2492+ Access tokens have a limited lifetime and will expire. When tokens expire:
2493+
2494+ - ** Check Token Expiration** : Use the ` expires_in ` field to determine when the token expires
2495+ - ** Refresh Flow** : When expired, repeat the token exchange flow with a fresh ID token from your IdP
2496+ - ** Automatic Refresh** : Implement automatic token refresh before expiration (recommended for production)
2497+ - ** Error Handling** : Catch authentication errors and retry with refreshed tokens
2498+
2499+ ** Important Notes:**
2500+
2501+ - ** ID Token Expiration** : If the ID token from your IdP expires, you must re-authenticate with the IdP to obtain a new ID token before performing token exchange
2502+ - ** Token Storage** : Store tokens securely and implement the ` TokenStorage ` interface to persist tokens between application restarts
2503+ - ** Scope Changes** : If you need different scopes, you must obtain a new ID token from the IdP with the required scopes
2504+ - ** Security** : Never log or expose access tokens or ID tokens in production environments
2505+
24852506** Example Usage:**
24862507
2508+ <!-- snippet-source examples/snippets/clients/enterprise_managed_auth_client.py -->
24872509``` python
24882510import asyncio
2511+ from datetime import datetime, timedelta, timezone
2512+ from typing import Any
2513+
24892514import httpx
24902515from pydantic import AnyUrl
24912516
2517+ from mcp import ClientSession
2518+ from mcp.client.auth import OAuthTokenError, TokenStorage
24922519from mcp.client.auth.extensions import (
24932520 EnterpriseAuthOAuthClientProvider,
24942521 TokenExchangeParameters,
24952522)
2496- from mcp.shared.auth import OAuthClientMetadata
2497- from mcp.client.auth import TokenStorage
2523+ from mcp.client.sse import sse_client
2524+ from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
2525+ from mcp.types import CallToolResult
2526+
2527+
2528+ # Placeholder function for IdP authentication
2529+ async def get_id_token_from_idp () -> str :
2530+ """
2531+ Placeholder function to get ID token from your IdP.
2532+ In production, implement actual IdP authentication flow.
2533+ """
2534+ raise NotImplementedError (" Implement your IdP authentication flow here" )
2535+
24982536
24992537# Define token storage implementation
25002538class SimpleTokenStorage (TokenStorage ):
2501- def __init__ (self ):
2502- self ._tokens = None
2503- self ._client_info = None
2504-
2505- async def get_tokens (self ):
2539+ def __init__ (self ) -> None :
2540+ self ._tokens: OAuthToken | None = None
2541+ self ._client_info: OAuthClientInformationFull | None = None
2542+
2543+ async def get_tokens (self ) -> OAuthToken | None :
25062544 return self ._tokens
2507-
2508- async def set_tokens (self , tokens ) :
2545+
2546+ async def set_tokens (self , tokens : OAuthToken) -> None :
25092547 self ._tokens = tokens
2510-
2511- async def get_client_info (self ):
2548+
2549+ async def get_client_info (self ) -> OAuthClientInformationFull | None :
25122550 return self ._client_info
2513-
2514- async def set_client_info (self , client_info ) :
2551+
2552+ async def set_client_info (self , client_info : OAuthClientInformationFull) -> None :
25152553 self ._client_info = client_info
25162554
2517- async def main ():
2555+
2556+ def is_token_expired (access_token : OAuthToken) -> bool :
2557+ """ Check if the access token has expired."""
2558+ if access_token.expires_in:
2559+ # Calculate expiration time
2560+ issued_at = datetime.now(timezone.utc)
2561+ expiration_time = issued_at + timedelta(seconds = access_token.expires_in)
2562+ return datetime.now(timezone.utc) >= expiration_time
2563+ return False
2564+
2565+
2566+ async def refresh_access_token (
2567+ enterprise_auth : EnterpriseAuthOAuthClientProvider,
2568+ client : httpx.AsyncClient,
2569+ id_token : str ,
2570+ ) -> OAuthToken:
2571+ """ Refresh the access token when it expires."""
2572+ try :
2573+ # Update token exchange parameters with fresh ID token
2574+ enterprise_auth.token_exchange_params.subject_token = id_token
2575+
2576+ # Re-exchange for new ID-JAG
2577+ id_jag = await enterprise_auth.exchange_token_for_id_jag(client)
2578+
2579+ # Get new access token
2580+ access_token = await enterprise_auth.exchange_id_jag_for_access_token(client, id_jag)
2581+ return access_token
2582+ except Exception as e:
2583+ print (f " Token refresh failed: { e} " )
2584+ # Re-authenticate with IdP if ID token is also expired
2585+ id_token = await get_id_token_from_idp()
2586+ return await refresh_access_token(enterprise_auth, client, id_token)
2587+
2588+
2589+ async def call_tool_with_retry (
2590+ session : ClientSession,
2591+ tool_name : str ,
2592+ arguments : dict[str , Any],
2593+ enterprise_auth : EnterpriseAuthOAuthClientProvider,
2594+ client : httpx.AsyncClient,
2595+ id_token : str ,
2596+ ) -> CallToolResult | None :
2597+ """ Call a tool with automatic retry on token expiration."""
2598+ max_retries = 1
2599+
2600+ for attempt in range (max_retries + 1 ):
2601+ try :
2602+ result = await session.call_tool(tool_name, arguments)
2603+ return result
2604+ except OAuthTokenError:
2605+ if attempt < max_retries:
2606+ print (" Token expired, refreshing..." )
2607+ # Refresh token and reconnect
2608+ _access_token = await refresh_access_token(enterprise_auth, client, id_token)
2609+ # Note: In production, you'd need to reconnect the session here
2610+ else :
2611+ raise
2612+ return None
2613+
2614+
2615+ async def main () -> None :
25182616 # Step 1: Get ID token from your IdP (example with Okta)
25192617 id_token = await get_id_token_from_idp() # Your IdP authentication
2520-
2618+
25212619 # Step 2: Configure token exchange parameters
25222620 token_exchange_params = TokenExchangeParameters.from_id_token(
25232621 id_token = id_token,
25242622 mcp_server_auth_issuer = " https://your-idp.com" , # IdP issuer URL
25252623 mcp_server_resource_id = " https://mcp-server.example.com" , # MCP server resource ID
25262624 scope = " mcp:tools mcp:resources" , # Optional scopes
25272625 )
2528-
2626+
25292627 # Step 3: Create enterprise auth provider
25302628 enterprise_auth = EnterpriseAuthOAuthClientProvider(
25312629 server_url = " https://mcp-server.example.com" ,
25322630 client_metadata = OAuthClientMetadata(
25332631 client_name = " Enterprise MCP Client" ,
2534- client_id = " your-client-id" ,
25352632 redirect_uris = [AnyUrl(" http://localhost:3000/callback" )],
25362633 grant_types = [" urn:ietf:params:oauth:grant-type:jwt-bearer" ],
25372634 response_types = [" token" ],
@@ -2540,23 +2637,85 @@ async def main():
25402637 idp_token_endpoint = " https://your-idp.com/oauth2/v1/token" ,
25412638 token_exchange_params = token_exchange_params,
25422639 )
2543-
2544- # Step 4: Perform token exchange and get access token
2640+
25452641 async with httpx.AsyncClient() as client:
2546- # Exchange ID token for ID-JAG
2642+ # Step 4: Exchange ID token for ID-JAG
25472643 id_jag = await enterprise_auth.exchange_token_for_id_jag(client)
25482644 print (f " Obtained ID-JAG: { id_jag[:50 ]} ... " )
2549-
2550- # Exchange ID-JAG for access token
2551- access_token = await enterprise_auth.exchange_id_jag_for_access_token(
2552- client, id_jag
2553- )
2645+
2646+ # Step 5: Exchange ID-JAG for access token
2647+ access_token = await enterprise_auth.exchange_id_jag_for_access_token(client, id_jag)
25542648 print (f " Access token obtained, expires in: { access_token.expires_in} s " )
25552649
2650+ # Step 6: Check if token is expired (for demonstration)
2651+ if is_token_expired(access_token):
2652+ print (" Token is expired, refreshing..." )
2653+ access_token = await refresh_access_token(enterprise_auth, client, id_token)
2654+
2655+ # Step 7: Use the access token to connect to MCP server
2656+ headers = {" Authorization" : f " Bearer { access_token.access_token} " }
2657+
2658+ async with sse_client(url = " https://mcp-server.example.com" , headers = headers) as (read, write):
2659+ async with ClientSession(read, write) as session:
2660+ await session.initialize()
2661+
2662+ # Call tools with automatic retry on token expiration
2663+ result = await call_tool_with_retry(
2664+ session, " enterprise_tool" , {" param" : " value" }, enterprise_auth, client, id_token
2665+ )
2666+ if result:
2667+ print (f " Tool result: { result.content} " )
2668+
2669+ # List available resources
2670+ resources = await session.list_resources()
2671+ for resource in resources.resources:
2672+ print (f " Resource: { resource.uri} " )
2673+
2674+
2675+ async def maintain_active_session (
2676+ enterprise_auth : EnterpriseAuthOAuthClientProvider,
2677+ mcp_server_url : str ,
2678+ ) -> None :
2679+ """ Maintain an active session with automatic token refresh."""
2680+ id_token_var = await get_id_token_from_idp()
2681+
2682+ async with httpx.AsyncClient() as client:
2683+ while True :
2684+ try :
2685+ # Update token exchange params with current ID token
2686+ enterprise_auth.token_exchange_params.subject_token = id_token_var
2687+
2688+ # Get access token
2689+ id_jag = await enterprise_auth.exchange_token_for_id_jag(client)
2690+ access_token = await enterprise_auth.exchange_id_jag_for_access_token(client, id_jag)
2691+
2692+ # Calculate refresh time (refresh before expiration)
2693+ refresh_in = access_token.expires_in - 60 if access_token.expires_in else 300
2694+
2695+ # Use the token for MCP operations
2696+ headers = {" Authorization" : f " Bearer { access_token.access_token} " }
2697+ async with sse_client(mcp_server_url, headers = headers) as (read, write):
2698+ async with ClientSession(read, write) as session:
2699+ await session.initialize()
2700+
2701+ # Perform operations...
2702+ # Schedule refresh before token expires
2703+ await asyncio.sleep(refresh_in)
2704+
2705+ except Exception as e:
2706+ print (f " Session error: { e} " )
2707+ # Re-authenticate with IdP
2708+ id_token_var = await get_id_token_from_idp()
2709+ await asyncio.sleep(5 ) # Wait before retry
2710+
2711+
25562712if __name__ == " __main__" :
25572713 asyncio.run(main())
25582714```
25592715
2716+ _ Full example: [ examples/snippets/clients/enterprise_managed_auth_client.py] ( https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/clients/enterprise_managed_auth_client.py ) _
2717+ <!-- /snippet-source -->
2718+
25602719** Working with SAML Assertions:**
25612720
25622721If your enterprise uses SAML instead of OIDC, you can exchange SAML assertions:
0 commit comments