Skip to content

Commit d7f2cc0

Browse files
committed
- Added typing for request payload structures TokenExchangeRequestData and JWTBearerGrantRequestData.
- Added snippet file for adding code to the README.md file. - Added new section in README.md file to add information regarding: "how to use the access token once you get it" and "How does this work when the client ID is expired?".
1 parent a4f0c40 commit d7f2cc0

File tree

4 files changed

+417
-52
lines changed

4 files changed

+417
-52
lines changed

README.md

Lines changed: 184 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2482,56 +2482,153 @@ The `EnterpriseAuthOAuthClientProvider` class extends the standard OAuth provide
24822482
3. **Exchange ID-JAG for Access Token** using RFC 7523 JWT Bearer Grant
24832483
4. **Use Access Token** to call protected MCP server tools
24842484

2485+
**Using the Access Token with MCP Server:**
2486+
2487+
1. Once you have obtained the access token, you can use it to authenticate requests to the MCP server
2488+
2. The access token is automatically included in all subsequent requests to the MCP server, allowing you to access protected tools and resources based on your enterprise identity and permissions.
2489+
2490+
**Handling Token Expiration and Refresh:**
2491+
2492+
Access tokens have a limited lifetime and will expire. When tokens expire:
2493+
2494+
- **Check Token Expiration**: Use the `expires_in` field to determine when the token expires
2495+
- **Refresh Flow**: When expired, repeat the token exchange flow with a fresh ID token from your IdP
2496+
- **Automatic Refresh**: Implement automatic token refresh before expiration (recommended for production)
2497+
- **Error Handling**: Catch authentication errors and retry with refreshed tokens
2498+
2499+
**Important Notes:**
2500+
2501+
- **ID Token Expiration**: If the ID token from your IdP expires, you must re-authenticate with the IdP to obtain a new ID token before performing token exchange
2502+
- **Token Storage**: Store tokens securely and implement the `TokenStorage` interface to persist tokens between application restarts
2503+
- **Scope Changes**: If you need different scopes, you must obtain a new ID token from the IdP with the required scopes
2504+
- **Security**: Never log or expose access tokens or ID tokens in production environments
2505+
24852506
**Example Usage:**
24862507

2508+
<!-- snippet-source examples/snippets/clients/enterprise_managed_auth_client.py -->
24872509
```python
24882510
import asyncio
2511+
from datetime import datetime, timedelta, timezone
2512+
from typing import Any
2513+
24892514
import httpx
24902515
from pydantic import AnyUrl
24912516

2517+
from mcp import ClientSession
2518+
from mcp.client.auth import OAuthTokenError, TokenStorage
24922519
from mcp.client.auth.extensions import (
24932520
EnterpriseAuthOAuthClientProvider,
24942521
TokenExchangeParameters,
24952522
)
2496-
from mcp.shared.auth import OAuthClientMetadata
2497-
from mcp.client.auth import TokenStorage
2523+
from mcp.client.sse import sse_client
2524+
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
2525+
from mcp.types import CallToolResult
2526+
2527+
2528+
# Placeholder function for IdP authentication
2529+
async def get_id_token_from_idp() -> str:
2530+
"""
2531+
Placeholder function to get ID token from your IdP.
2532+
In production, implement actual IdP authentication flow.
2533+
"""
2534+
raise NotImplementedError("Implement your IdP authentication flow here")
2535+
24982536

24992537
# Define token storage implementation
25002538
class SimpleTokenStorage(TokenStorage):
2501-
def __init__(self):
2502-
self._tokens = None
2503-
self._client_info = None
2504-
2505-
async def get_tokens(self):
2539+
def __init__(self) -> None:
2540+
self._tokens: OAuthToken | None = None
2541+
self._client_info: OAuthClientInformationFull | None = None
2542+
2543+
async def get_tokens(self) -> OAuthToken | None:
25062544
return self._tokens
2507-
2508-
async def set_tokens(self, tokens):
2545+
2546+
async def set_tokens(self, tokens: OAuthToken) -> None:
25092547
self._tokens = tokens
2510-
2511-
async def get_client_info(self):
2548+
2549+
async def get_client_info(self) -> OAuthClientInformationFull | None:
25122550
return self._client_info
2513-
2514-
async def set_client_info(self, client_info):
2551+
2552+
async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
25152553
self._client_info = client_info
25162554

2517-
async def main():
2555+
2556+
def is_token_expired(access_token: OAuthToken) -> bool:
2557+
"""Check if the access token has expired."""
2558+
if access_token.expires_in:
2559+
# Calculate expiration time
2560+
issued_at = datetime.now(timezone.utc)
2561+
expiration_time = issued_at + timedelta(seconds=access_token.expires_in)
2562+
return datetime.now(timezone.utc) >= expiration_time
2563+
return False
2564+
2565+
2566+
async def refresh_access_token(
2567+
enterprise_auth: EnterpriseAuthOAuthClientProvider,
2568+
client: httpx.AsyncClient,
2569+
id_token: str,
2570+
) -> OAuthToken:
2571+
"""Refresh the access token when it expires."""
2572+
try:
2573+
# Update token exchange parameters with fresh ID token
2574+
enterprise_auth.token_exchange_params.subject_token = id_token
2575+
2576+
# Re-exchange for new ID-JAG
2577+
id_jag = await enterprise_auth.exchange_token_for_id_jag(client)
2578+
2579+
# Get new access token
2580+
access_token = await enterprise_auth.exchange_id_jag_for_access_token(client, id_jag)
2581+
return access_token
2582+
except Exception as e:
2583+
print(f"Token refresh failed: {e}")
2584+
# Re-authenticate with IdP if ID token is also expired
2585+
id_token = await get_id_token_from_idp()
2586+
return await refresh_access_token(enterprise_auth, client, id_token)
2587+
2588+
2589+
async def call_tool_with_retry(
2590+
session: ClientSession,
2591+
tool_name: str,
2592+
arguments: dict[str, Any],
2593+
enterprise_auth: EnterpriseAuthOAuthClientProvider,
2594+
client: httpx.AsyncClient,
2595+
id_token: str,
2596+
) -> CallToolResult | None:
2597+
"""Call a tool with automatic retry on token expiration."""
2598+
max_retries = 1
2599+
2600+
for attempt in range(max_retries + 1):
2601+
try:
2602+
result = await session.call_tool(tool_name, arguments)
2603+
return result
2604+
except OAuthTokenError:
2605+
if attempt < max_retries:
2606+
print("Token expired, refreshing...")
2607+
# Refresh token and reconnect
2608+
_access_token = await refresh_access_token(enterprise_auth, client, id_token)
2609+
# Note: In production, you'd need to reconnect the session here
2610+
else:
2611+
raise
2612+
return None
2613+
2614+
2615+
async def main() -> None:
25182616
# Step 1: Get ID token from your IdP (example with Okta)
25192617
id_token = await get_id_token_from_idp() # Your IdP authentication
2520-
2618+
25212619
# Step 2: Configure token exchange parameters
25222620
token_exchange_params = TokenExchangeParameters.from_id_token(
25232621
id_token=id_token,
25242622
mcp_server_auth_issuer="https://your-idp.com", # IdP issuer URL
25252623
mcp_server_resource_id="https://mcp-server.example.com", # MCP server resource ID
25262624
scope="mcp:tools mcp:resources", # Optional scopes
25272625
)
2528-
2626+
25292627
# Step 3: Create enterprise auth provider
25302628
enterprise_auth = EnterpriseAuthOAuthClientProvider(
25312629
server_url="https://mcp-server.example.com",
25322630
client_metadata=OAuthClientMetadata(
25332631
client_name="Enterprise MCP Client",
2534-
client_id="your-client-id",
25352632
redirect_uris=[AnyUrl("http://localhost:3000/callback")],
25362633
grant_types=["urn:ietf:params:oauth:grant-type:jwt-bearer"],
25372634
response_types=["token"],
@@ -2540,23 +2637,85 @@ async def main():
25402637
idp_token_endpoint="https://your-idp.com/oauth2/v1/token",
25412638
token_exchange_params=token_exchange_params,
25422639
)
2543-
2544-
# Step 4: Perform token exchange and get access token
2640+
25452641
async with httpx.AsyncClient() as client:
2546-
# Exchange ID token for ID-JAG
2642+
# Step 4: Exchange ID token for ID-JAG
25472643
id_jag = await enterprise_auth.exchange_token_for_id_jag(client)
25482644
print(f"Obtained ID-JAG: {id_jag[:50]}...")
2549-
2550-
# Exchange ID-JAG for access token
2551-
access_token = await enterprise_auth.exchange_id_jag_for_access_token(
2552-
client, id_jag
2553-
)
2645+
2646+
# Step 5: Exchange ID-JAG for access token
2647+
access_token = await enterprise_auth.exchange_id_jag_for_access_token(client, id_jag)
25542648
print(f"Access token obtained, expires in: {access_token.expires_in}s")
25552649

2650+
# Step 6: Check if token is expired (for demonstration)
2651+
if is_token_expired(access_token):
2652+
print("Token is expired, refreshing...")
2653+
access_token = await refresh_access_token(enterprise_auth, client, id_token)
2654+
2655+
# Step 7: Use the access token to connect to MCP server
2656+
headers = {"Authorization": f"Bearer {access_token.access_token}"}
2657+
2658+
async with sse_client(url="https://mcp-server.example.com", headers=headers) as (read, write):
2659+
async with ClientSession(read, write) as session:
2660+
await session.initialize()
2661+
2662+
# Call tools with automatic retry on token expiration
2663+
result = await call_tool_with_retry(
2664+
session, "enterprise_tool", {"param": "value"}, enterprise_auth, client, id_token
2665+
)
2666+
if result:
2667+
print(f"Tool result: {result.content}")
2668+
2669+
# List available resources
2670+
resources = await session.list_resources()
2671+
for resource in resources.resources:
2672+
print(f"Resource: {resource.uri}")
2673+
2674+
2675+
async def maintain_active_session(
2676+
enterprise_auth: EnterpriseAuthOAuthClientProvider,
2677+
mcp_server_url: str,
2678+
) -> None:
2679+
"""Maintain an active session with automatic token refresh."""
2680+
id_token_var = await get_id_token_from_idp()
2681+
2682+
async with httpx.AsyncClient() as client:
2683+
while True:
2684+
try:
2685+
# Update token exchange params with current ID token
2686+
enterprise_auth.token_exchange_params.subject_token = id_token_var
2687+
2688+
# Get access token
2689+
id_jag = await enterprise_auth.exchange_token_for_id_jag(client)
2690+
access_token = await enterprise_auth.exchange_id_jag_for_access_token(client, id_jag)
2691+
2692+
# Calculate refresh time (refresh before expiration)
2693+
refresh_in = access_token.expires_in - 60 if access_token.expires_in else 300
2694+
2695+
# Use the token for MCP operations
2696+
headers = {"Authorization": f"Bearer {access_token.access_token}"}
2697+
async with sse_client(mcp_server_url, headers=headers) as (read, write):
2698+
async with ClientSession(read, write) as session:
2699+
await session.initialize()
2700+
2701+
# Perform operations...
2702+
# Schedule refresh before token expires
2703+
await asyncio.sleep(refresh_in)
2704+
2705+
except Exception as e:
2706+
print(f"Session error: {e}")
2707+
# Re-authenticate with IdP
2708+
id_token_var = await get_id_token_from_idp()
2709+
await asyncio.sleep(5) # Wait before retry
2710+
2711+
25562712
if __name__ == "__main__":
25572713
asyncio.run(main())
25582714
```
25592715

2716+
_Full example: [examples/snippets/clients/enterprise_managed_auth_client.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/clients/enterprise_managed_auth_client.py)_
2717+
<!-- /snippet-source -->
2718+
25602719
**Working with SAML Assertions:**
25612720

25622721
If your enterprise uses SAML instead of OIDC, you can exchange SAML assertions:

0 commit comments

Comments
 (0)