@@ -110,6 +110,21 @@ def _generate_code_challenge(self, code_verifier: str) -> str:
110110 digest = hashlib .sha256 (code_verifier .encode ()).digest ()
111111 return base64 .urlsafe_b64encode (digest ).decode ().rstrip ("=" )
112112
113+ def _get_authorization_base_url (self , server_url : str ) -> str :
114+ """
115+ Determine the authorization base URL by discarding any path component.
116+
117+ Per MCP spec Section 2.3.2: "The authorization base URL MUST be determined
118+ from the MCP server URL by discarding any existing path component."
119+
120+ Example: https://api.example.com/v1/mcp -> https://api.example.com
121+ """
122+ from urllib .parse import urlparse , urlunparse
123+
124+ parsed = urlparse (server_url )
125+ # Discard path component by setting it to empty
126+ return urlunparse ((parsed .scheme , parsed .netloc , "" , "" , "" , "" ))
127+
113128 async def _discover_oauth_metadata (self , server_url : str ) -> OAuthMetadata | None :
114129 """
115130 Discovers OAuth metadata from the server's well-known endpoint.
@@ -120,7 +135,9 @@ async def _discover_oauth_metadata(self, server_url: str) -> OAuthMetadata | Non
120135 Returns:
121136 OAuthMetadata if found, None otherwise
122137 """
123- url = urljoin (server_url , "/.well-known/oauth-authorization-server" )
138+ # Get authorization base URL per MCP spec Section 2.3.2
139+ auth_base_url = self ._get_authorization_base_url (server_url )
140+ url = urljoin (auth_base_url , "/.well-known/oauth-authorization-server" )
124141 headers = {"MCP-Protocol-Version" : LATEST_PROTOCOL_VERSION }
125142
126143 async with httpx .AsyncClient () as client :
@@ -171,24 +188,15 @@ async def _register_oauth_client(
171188 if metadata and metadata .registration_endpoint :
172189 registration_url = str (metadata .registration_endpoint )
173190 else :
174- registration_url = urljoin (server_url , "/register" )
191+ # Use authorization base URL for fallback registration endpoint
192+ auth_base_url = self ._get_authorization_base_url (server_url )
193+ registration_url = urljoin (auth_base_url , "/register" )
175194
176- # Prepare registration data and adjust scope based on server metadata
195+ # Prepare registration data
177196 registration_data = client_metadata .model_dump (
178197 by_alias = True , mode = "json" , exclude_none = True
179198 )
180199
181- # If the server has supported scopes, use them instead of the requested scope
182- if metadata and metadata .scopes_supported :
183- # Use the first supported scope or "user" if available
184- if "user" in metadata .scopes_supported :
185- registration_data ["scope" ] = "user"
186- else :
187- registration_data ["scope" ] = metadata .scopes_supported [0 ]
188- logger .debug (
189- f"Adjusted scope to server-supported: { registration_data ['scope' ]} "
190- )
191-
192200 async with httpx .AsyncClient () as client :
193201 try :
194202 response = await client .post (
@@ -252,6 +260,55 @@ def _has_valid_token(self) -> bool:
252260
253261 return True
254262
263+ async def _validate_token_scopes (self , token_response : OAuthToken ) -> None :
264+ """
265+ Validate that returned scopes are a subset of requested scopes.
266+
267+ Per OAuth 2.1 Section 3.3, the authorization server may issue a narrower
268+ set of scopes than requested, but must not grant additional scopes.
269+ """
270+ if not token_response .scope :
271+ # If no scope is returned, validation passes (server didn't grant anything extra)
272+ return
273+
274+ # Get the originally requested scopes
275+ requested_scopes : set [str ] = set ()
276+
277+ # Check for explicitly requested scopes from client metadata
278+ if self .client_metadata .scope :
279+ requested_scopes .update (self .client_metadata .scope .split ())
280+
281+ # If we have registered client info with specific scopes, use those
282+ # (This handles cases where scopes were negotiated during registration)
283+ if (
284+ self ._client_info
285+ and hasattr (self ._client_info , "scope" )
286+ and self ._client_info .scope
287+ ):
288+ # Only override if the client metadata didn't have explicit scopes
289+ # This represents what was actually registered/negotiated with the server
290+ if not requested_scopes :
291+ requested_scopes .update (self ._client_info .scope .split ())
292+
293+ # Parse returned scopes
294+ returned_scopes : set [str ] = set (token_response .scope .split ())
295+
296+ # Validate that returned scopes are a subset of requested scopes
297+ # Only enforce strict validation if we actually have requested scopes
298+ if requested_scopes :
299+ unauthorized_scopes : set [str ] = returned_scopes - requested_scopes
300+ if unauthorized_scopes :
301+ raise Exception (
302+ f"Server granted unauthorized scopes: { unauthorized_scopes } . "
303+ f"Requested: { requested_scopes } , Returned: { returned_scopes } "
304+ )
305+ else :
306+ # If no scopes were originally requested (fell back to server defaults),
307+ # accept whatever the server returned
308+ logger .debug (
309+ f"No specific scopes were requested, accepting server-granted scopes: { returned_scopes } "
310+ )
311+
255312 async def initialize (self ) -> None :
256313 """Initialize the auth handler by loading stored tokens and client info."""
257314 self ._current_tokens = await self .storage .get_tokens ()
@@ -307,7 +364,9 @@ async def _perform_oauth_flow(self) -> None:
307364 if self ._metadata and self ._metadata .authorization_endpoint :
308365 auth_url_base = str (self ._metadata .authorization_endpoint )
309366 else :
310- auth_url_base = urljoin (self .server_url , "/authorize" )
367+ # Use authorization base URL for fallback authorization endpoint
368+ auth_base_url = self ._get_authorization_base_url (self .server_url )
369+ auth_url_base = urljoin (auth_base_url , "/authorize" )
311370
312371 # Build authorization URL
313372 auth_params = {
@@ -319,16 +378,16 @@ async def _perform_oauth_flow(self) -> None:
319378 "code_challenge_method" : "S256" ,
320379 }
321380
322- if hasattr (client_info , "scope" ) and client_info .scope :
323- auth_params ["scope" ] = client_info .scope
324- elif self ._metadata and self ._metadata .scopes_supported :
325- # Use "user" if available, otherwise the first supported scope
326- if "user" in self ._metadata .scopes_supported :
327- auth_params ["scope" ] = "user"
328- else :
329- auth_params ["scope" ] = self ._metadata .scopes_supported [0 ]
330- elif self .client_metadata .scope :
381+ # Set scope parameter following OAuth 2.1 principles:
382+ # 1. Use client's explicit request first (what developer wants)
383+ # 2. Use registered client scope as fallback (what was negotiated)
384+ # 3. No scope = let server decide (omit scope parameter)
385+ if self .client_metadata .scope :
331386 auth_params ["scope" ] = self .client_metadata .scope
387+ elif hasattr (client_info , "scope" ) and client_info .scope :
388+ auth_params ["scope" ] = client_info .scope
389+ # If no scope specified anywhere, don't include scope parameter
390+ # This lets the server grant default scopes per OAuth 2.1
332391
333392 auth_url = f"{ auth_url_base } ?{ urlencode (auth_params )} "
334393
@@ -339,7 +398,7 @@ async def _perform_oauth_flow(self) -> None:
339398
340399 # Validate state parameter
341400 if returned_state != auth_params ["state" ]:
342- raise Exception ("State parameter mismatch - possible CSRF attack " )
401+ raise Exception ("State parameter mismatch" )
343402
344403 if not auth_code :
345404 raise Exception ("No authorization code received" )
@@ -355,7 +414,9 @@ async def _exchange_code_for_token(
355414 if self ._metadata and self ._metadata .token_endpoint :
356415 token_url = str (self ._metadata .token_endpoint )
357416 else :
358- token_url = urljoin (self .server_url , "/token" )
417+ # Use authorization base URL for fallback token endpoint
418+ auth_base_url = self ._get_authorization_base_url (self .server_url )
419+ token_url = urljoin (auth_base_url , "/token" )
359420
360421 token_data = {
361422 "grant_type" : "authorization_code" ,
@@ -384,6 +445,9 @@ async def _exchange_code_for_token(
384445 # Parse and store tokens
385446 token_response = OAuthToken .model_validate (response .json ())
386447
448+ # Validate returned scopes against requested scopes (OAuth 2.1 Section 3.3)
449+ await self ._validate_token_scopes (token_response )
450+
387451 # Calculate expiry time if available
388452 if token_response .expires_in :
389453 self ._token_expiry_time = time .time () + token_response .expires_in
@@ -406,7 +470,9 @@ async def _refresh_access_token(self) -> bool:
406470 if self ._metadata and self ._metadata .token_endpoint :
407471 token_url = str (self ._metadata .token_endpoint )
408472 else :
409- token_url = urljoin (self .server_url , "/token" )
473+ # Use authorization base URL for fallback token endpoint
474+ auth_base_url = self ._get_authorization_base_url (self .server_url )
475+ token_url = urljoin (auth_base_url , "/token" )
410476
411477 refresh_data = {
412478 "grant_type" : "refresh_token" ,
@@ -433,6 +499,9 @@ async def _refresh_access_token(self) -> bool:
433499 # Parse and store new tokens
434500 token_response = OAuthToken .model_validate (response .json ())
435501
502+ # Validate returned scopes against requested scopes (OAuth 2.1 Section 3.3)
503+ await self ._validate_token_scopes (token_response )
504+
436505 # Calculate expiry time if available
437506 if token_response .expires_in :
438507 self ._token_expiry_time = time .time () + token_response .expires_in
0 commit comments