22"""
33Simple MCP client example with OAuth authentication support.
44
5- This client connects to an MCP server using streamable HTTP transport with OAuth authentication .
6- It provides an interactive command-line interface to list tools and execute them.
5+ This client connects to an MCP server using streamable HTTP transport with OAuth.
6+
77"""
88
99import asyncio
2121 OAuthClientProvider ,
2222 discover_oauth_metadata ,
2323)
24+ from mcp .client .oauth_auth import OAuthAuth
2425from mcp .client .session import ClientSession
2526from mcp .client .streamable_http import streamablehttp_client
2627from mcp .shared .auth import OAuthClientInformationFull , OAuthClientMetadata , OAuthToken
27- from pydantic import AnyHttpUrl
2828
2929
3030class CallbackHandler (BaseHTTPRequestHandler ):
3131 """Simple HTTP handler to capture OAuth callback."""
3232
3333 authorization_code = None
34+ state = None
3435 error = None
3536
3637 def do_GET (self ):
@@ -40,6 +41,7 @@ def do_GET(self):
4041
4142 if "code" in query_params :
4243 CallbackHandler .authorization_code = query_params ["code" ][0 ]
44+ CallbackHandler .state = query_params .get ("state" , [None ])[0 ]
4345 self .send_response (200 )
4446 self .send_header ("Content-type" , "text/html" )
4547 self .end_headers ()
@@ -116,8 +118,11 @@ class JsonSerializableOAuthClientMetadata(OAuthClientMetadata):
116118 """OAuth client metadata that handles JSON serialization properly."""
117119
118120 def model_dump (self , ** kwargs ) -> dict [str , Any ]:
119- """Override to ensure URLs are serialized as strings."""
121+ """Override to ensure URLs are serialized as strings and exclude null values."""
122+ # Exclude null values by default
123+ kwargs .setdefault ("exclude_none" , True )
120124 data = super ().model_dump (** kwargs )
125+
121126 # Convert AnyHttpUrl objects to strings
122127 if "redirect_uris" in data :
123128 data ["redirect_uris" ] = [str (url ) for url in data ["redirect_uris" ]]
@@ -193,9 +198,7 @@ async def tokens(self) -> OAuthToken | None:
193198
194199 async def save_tokens (self , tokens : OAuthToken ) -> None :
195200 self ._tokens = tokens
196- print (
197- f"Saved OAuth tokens, access token starts with: { tokens .access_token [:10 ]} ..."
198- )
201+ print (f"Saved OAuth tokens: { tokens .access_token [:10 ]} ..." )
199202
200203 async def redirect_to_authorization (self , authorization_url : str ) -> None :
201204 # Start callback server
@@ -252,66 +255,41 @@ async def connect(self):
252255 """Connect to the MCP server."""
253256 print (f"🔗 Attempting to connect to { self .server_url } ..." )
254257
255- # The streamable HTTP transport will handle the OAuth flow automatically
256- # We just need to wait for it to complete successfully
257258 try :
258- # Discover OAuth metadata first to set proper scopes
259- await self .auth_provider ._discover_and_update_metadata ()
260-
261- # Check if we already have tokens, if not do auth flow first
262- existing_tokens = await self .auth_provider .tokens ()
263- if not existing_tokens :
264- print ("🔐 No existing tokens found, initiating OAuth flow..." )
265- await self .auth_provider ._discover_and_update_metadata ()
266-
267- # Start the auth flow to get tokens
268- from mcp .client .auth import auth
269-
270- auth_result = await auth (
271- self .auth_provider , server_url = self .server_url .replace ("/mcp" , "" )
272- )
273-
274- if auth_result == "REDIRECT" :
275- print ("🔄 Waiting for OAuth completion..." )
276- # Wait for authorization code to be set by the redirect handler
277- timeout = 300 # 5 minutes
278- start_time = time .time ()
279- while (
280- not self .auth_provider ._authorization_code
281- and time .time () - start_time < timeout
282- ):
283- await asyncio .sleep (0.1 )
284-
285- if not self .auth_provider ._authorization_code :
286- raise Exception ("Timeout waiting for OAuth authorization" )
287-
288- # Now exchange the authorization code for tokens
289- auth_result = await auth (
290- self .auth_provider ,
291- server_url = self .server_url .replace ("/mcp" , "" ),
292- authorization_code = self .auth_provider ._authorization_code ,
293- )
294-
295- if auth_result != "AUTHORIZED" :
296- raise Exception ("Failed to authorize with server" )
297-
298- # Verify we have tokens now
299- tokens = await self .auth_provider .tokens ()
300- if not tokens :
301- raise Exception ("OAuth completed but no tokens were saved" )
259+ # Set up callback server
260+ callback_server = CallbackServer (port = 3000 )
261+ callback_server .start ()
262+
263+ async def callback_handler () -> tuple [str , str | None ]:
264+ """Wait for OAuth callback and return auth code and state."""
265+ print ("⏳ Waiting for authorization callback..." )
266+ try :
267+ auth_code = callback_server .wait_for_callback (timeout = 300 )
268+ return auth_code , CallbackHandler .state
269+ finally :
270+ callback_server .stop ()
271+
272+ # Create OAuth authentication handler using the new interface
273+ oauth_auth = OAuthAuth (
274+ server_url = self .server_url .replace ("/mcp" , "" ),
275+ client_metadata = self .auth_provider .client_metadata ,
276+ storage = None , # Use in-memory storage
277+ redirect_handler = None , # Use default (open browser)
278+ callback_handler = callback_handler ,
279+ )
302280
303- print (
304- f"✅ OAuth authorization successful! Access token: { tokens .access_token [:20 ]} ..."
305- )
281+ # Initialize the auth handler and ensure we have tokens
306282
307- # Create streamable HTTP transport with auth
283+ # Create streamable HTTP transport with auth handler
308284 stream_context = streamablehttp_client (
309285 url = self .server_url ,
310- auth_provider = self . auth_provider ,
311- timeout = timedelta (seconds = 60 ), # Longer timeout for OAuth flow
286+ auth = oauth_auth ,
287+ timeout = timedelta (seconds = 60 ),
312288 )
313289
314- print ("📡 Opening transport connection..." )
290+ print (
291+ "📡 Opening transport connection (HTTPX handles auth automatically)..."
292+ )
315293 async with stream_context as (read_stream , write_stream , get_session_id ):
316294 print ("🤝 Initializing MCP session..." )
317295 async with ClientSession (read_stream , write_stream ) as session :
@@ -365,7 +343,7 @@ async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None = Non
365343 print (f"\n 🔧 Tool '{ tool_name } ' result:" )
366344 if hasattr (result , "content" ):
367345 for content in result .content :
368- if hasattr ( content , "text" ) :
346+ if content . type == "text" :
369347 print (content .text )
370348 else :
371349 print (content )
0 commit comments