77"""
88
99import asyncio
10- import json
1110import os
1211import threading
1312import time
14- import webbrowser
1513from datetime import timedelta
1614from http .server import BaseHTTPRequestHandler , HTTPServer
1715from typing import Any
1816from urllib .parse import parse_qs , urlparse
1917
20- from mcp .client .auth import (
21- OAuthClientProvider ,
22- discover_oauth_metadata ,
23- )
2418from mcp .client .oauth_auth import OAuthAuth
2519from mcp .client .session import ClientSession
2620from mcp .client .streamable_http import streamablehttp_client
27- from mcp .shared .auth import OAuthClientInformationFull , OAuthClientMetadata , OAuthToken
21+ from mcp .shared .auth import OAuthClientMetadata
2822
2923
3024class CallbackHandler (BaseHTTPRequestHandler ):
@@ -114,141 +108,14 @@ def wait_for_callback(self, timeout=300):
114108 raise Exception ("Timeout waiting for OAuth callback" )
115109
116110
117- class JsonSerializableOAuthClientMetadata (OAuthClientMetadata ):
118- """OAuth client metadata that handles JSON serialization properly."""
119-
120- def model_dump (self , ** kwargs ) -> dict [str , Any ]:
121- """Override to ensure URLs are serialized as strings and exclude null values."""
122- # Exclude null values by default
123- kwargs .setdefault ("exclude_none" , True )
124- data = super ().model_dump (** kwargs )
125-
126- # Convert AnyHttpUrl objects to strings
127- if "redirect_uris" in data :
128- data ["redirect_uris" ] = [str (url ) for url in data ["redirect_uris" ]]
129-
130- # Debug: print what we're sending
131- print (f"🐛 Client metadata being sent: { json .dumps (data , indent = 2 )} " )
132- return data
133-
134-
135- class SimpleOAuthProvider (OAuthClientProvider ):
136- """Simple OAuth client provider for demonstration purposes."""
137-
138- def __init__ (self , server_url : str , callback_port : int = 3000 ):
139- self ._callback_port = callback_port
140- self ._redirect_uri = f"http://localhost:{ callback_port } /callback"
141- self ._server_url = server_url
142- self ._callback_server = None
143- print (f"🐛 OAuth provider initialized with redirect URI: { self ._redirect_uri } " )
144- # Store the raw data for easy serialization - scope will be updated dynamically
145- self ._client_metadata_dict = {
146- "client_name" : "Simple Auth Client" ,
147- "redirect_uris" : [self ._redirect_uri ],
148- "grant_types" : ["authorization_code" , "refresh_token" ],
149- "response_types" : ["code" ],
150- "token_endpoint_auth_method" : "client_secret_post" , # Use client secret
151- "scope" : "read" , # Default scope, will be updated
152- }
153- self ._client_info : OAuthClientInformationFull | None = None
154- self ._tokens : OAuthToken | None = None
155- self ._code_verifier : str | None = None
156- self ._authorization_code : str | None = None
157- self ._metadata_discovered = False
158-
159- @property
160- def redirect_url (self ) -> str :
161- return self ._redirect_uri
162-
163- async def _discover_and_update_metadata (self ):
164- """Discover server OAuth metadata and update client scope accordingly."""
165- if self ._metadata_discovered :
166- return
167-
168- try :
169- print ("🐛 Discovering OAuth metadata..." )
170- metadata = await discover_oauth_metadata (self ._server_url )
171- if metadata and metadata .scopes_supported :
172- scope = " " .join (metadata .scopes_supported )
173- self ._client_metadata_dict ["scope" ] = scope
174- print (f"🐛 Updated scope to: { scope } " )
175- self ._metadata_discovered = True
176- except Exception as e :
177- print (f"🐛 Failed to discover metadata: { e } , using default scope" )
178- self ._metadata_discovered = True
179-
180- @property
181- def client_metadata (self ) -> OAuthClientMetadata :
182- # Create a fresh instance each time using our custom serializable version
183- return JsonSerializableOAuthClientMetadata .model_validate (
184- self ._client_metadata_dict
185- )
186-
187- async def client_information (self ) -> OAuthClientInformationFull | None :
188- return self ._client_info
189-
190- async def save_client_information (
191- self , client_information : OAuthClientInformationFull
192- ) -> None :
193- self ._client_info = client_information
194- print (f"Saved client information: { client_information .client_id } " )
195-
196- async def tokens (self ) -> OAuthToken | None :
197- return self ._tokens
198-
199- async def save_tokens (self , tokens : OAuthToken ) -> None :
200- self ._tokens = tokens
201- print (f"Saved OAuth tokens: { tokens .access_token [:10 ]} ..." )
202-
203- async def redirect_to_authorization (self , authorization_url : str ) -> None :
204- # Start callback server
205- self ._callback_server = CallbackServer (self ._callback_port )
206- self ._callback_server .start ()
207-
208- print ("\n 🌐 Opening authorization URL in your default browser..." )
209- print (f"URL: { authorization_url } " )
210- webbrowser .open (authorization_url )
211-
212- print ("⏳ Waiting for authorization callback..." )
213- print ("(Complete the authorization in your browser)" )
214-
215- try :
216- # Wait for the callback with authorization code
217- authorization_code = self ._callback_server .wait_for_callback (timeout = 300 )
218- print (f"✅ Received authorization code: { authorization_code [:20 ]} ..." )
219-
220- # Store the authorization code so auth() can handle token exchange
221- self ._authorization_code = authorization_code
222- print ("🎉 OAuth callback received successfully!" )
223-
224- except Exception as e :
225- print (f"❌ OAuth flow failed: { e } " )
226- raise
227- finally :
228- # Always stop the callback server
229- if self ._callback_server :
230- self ._callback_server .stop ()
231- self ._callback_server = None
232-
233- async def save_code_verifier (self , code_verifier : str ) -> None :
234- self ._code_verifier = code_verifier
235-
236- async def code_verifier (self ) -> str :
237- if self ._code_verifier is None :
238- raise ValueError ("No code verifier available" )
239- return self ._code_verifier
240-
241-
242111class SimpleAuthClient :
243112 """Simple MCP client with auth support."""
244113
245114 def __init__ (self , server_url : str ):
246115 self .server_url = server_url
247116 # Extract base URL for auth server (remove /mcp endpoint for auth endpoints)
248- auth_server_url = server_url .replace ("/mcp" , "" )
249117 # Use default redirect URI - this is where the auth server will redirect the user
250118 # The user will need to copy the authorization code from this callback URL
251- self .auth_provider = SimpleOAuthProvider (auth_server_url )
252119 self .session : ClientSession | None = None
253120
254121 async def connect (self ):
@@ -269,10 +136,21 @@ async def callback_handler() -> tuple[str, str | None]:
269136 finally :
270137 callback_server .stop ()
271138
139+ client_metadata_dict = {
140+ "client_name" : "Simple Auth Client" ,
141+ "redirect_uris" : ["http://localhost:3000/callback" ],
142+ "grant_types" : ["authorization_code" , "refresh_token" ],
143+ "response_types" : ["code" ],
144+ "token_endpoint_auth_method" : "client_secret_post" , # Use client secret
145+ "scope" : "read" , # Default scope, will be updated
146+ }
147+
272148 # Create OAuth authentication handler using the new interface
273149 oauth_auth = OAuthAuth (
274150 server_url = self .server_url .replace ("/mcp" , "" ),
275- client_metadata = self .auth_provider .client_metadata ,
151+ client_metadata = OAuthClientMetadata .model_validate (
152+ client_metadata_dict
153+ ),
276154 storage = None , # Use in-memory storage
277155 redirect_handler = None , # Use default (open browser)
278156 callback_handler = callback_handler ,
0 commit comments