Skip to content

Commit 1bb4bc9

Browse files
committed
merge
1 parent d6acc58 commit 1bb4bc9

File tree

2 files changed

+9
-25
lines changed

2 files changed

+9
-25
lines changed

src/mcp/client/auth/oauth2.py

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -19,23 +19,20 @@
1919
import httpx
2020
from pydantic import BaseModel, Field, ValidationError
2121

22-
from mcp.client.auth import OAuthFlowError, OAuthTokenError
22+
from mcp.client.auth import OAuthFlowError, OAuthRegistrationError, OAuthTokenError
2323
from mcp.client.auth.utils import (
2424
build_oauth_authorization_server_metadata_discovery_urls,
2525
build_protected_resource_metadata_discovery_urls,
2626
create_client_registration_request,
27-
create_oauth_metadata_request,
2827
extract_field_from_www_auth,
2928
extract_resource_metadata_from_www_auth,
3029
extract_scope_from_www_auth,
3130
get_client_metadata_scopes,
3231
handle_auth_metadata_response,
3332
handle_protected_resource_response,
34-
handle_registration_response,
3533
handle_token_response_scopes,
3634
)
3735
from mcp.client.streamable_http import MCP_PROTOCOL_VERSION
38-
from mcp.types import LATEST_PROTOCOL_VERSION
3936
from mcp.shared.auth import (
4037
OAuthClientInformationFull,
4138
OAuthClientMetadata,
@@ -48,6 +45,7 @@
4845
check_resource_allowed,
4946
resource_url_from_server_url,
5047
)
48+
from mcp.types import LATEST_PROTOCOL_VERSION
5149

5250
logger = logging.getLogger(__name__)
5351

@@ -251,9 +249,7 @@ def _create_registration_request(self, metadata: OAuthMetadata | None = None) ->
251249
headers={"Content-Type": "application/json"},
252250
)
253251

254-
async def _handle_registration_response(
255-
self, response: httpx.Response
256-
) -> OAuthClientInformationFull:
252+
async def _handle_registration_response(self, response: httpx.Response) -> OAuthClientInformationFull:
257253
if response.status_code not in (200, 201):
258254
await response.aread()
259255
raise OAuthRegistrationError(f"Registration failed: {response.status_code} {response.text}")
@@ -323,15 +319,11 @@ def __init__(
323319

324320
def _build_protected_resource_discovery_urls(self, resource_metadata_url: str | None) -> list[str]:
325321
"""Build the list of PRM discovery URLs with legacy fallbacks."""
326-
return build_protected_resource_metadata_discovery_urls(
327-
resource_metadata_url, self.context.server_url
328-
)
322+
return build_protected_resource_metadata_discovery_urls(resource_metadata_url, self.context.server_url)
329323

330324
def _get_discovery_urls(self, server_url: str | None = None) -> list[str]:
331325
"""Build OAuth authorization server discovery URLs with legacy fallbacks."""
332-
return build_oauth_authorization_server_metadata_discovery_urls(
333-
server_url, self.context.server_url
334-
)
326+
return build_oauth_authorization_server_metadata_discovery_urls(server_url, self.context.server_url)
335327

336328
async def _handle_protected_resource_response(self, response: httpx.Response) -> bool:
337329
"""
@@ -356,9 +348,7 @@ async def _handle_protected_resource_response(self, response: httpx.Response) ->
356348
)
357349
return False
358350

359-
async def _handle_oauth_metadata_response(
360-
self, response: httpx.Response
361-
) -> tuple[bool, OAuthMetadata | None]:
351+
async def _handle_oauth_metadata_response(self, response: httpx.Response) -> tuple[bool, OAuthMetadata | None]:
362352
ok, asm = await handle_auth_metadata_response(response)
363353
if asm:
364354
self.context.oauth_metadata = asm
@@ -580,9 +570,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
580570
self._metadata = None
581571

582572
# Step 1: Discover protected resource metadata (SEP-985 with fallback support)
583-
prm_discovery_urls = self._build_protected_resource_discovery_urls(
584-
www_auth_resource_metadata_url
585-
)
573+
prm_discovery_urls = self._build_protected_resource_discovery_urls(www_auth_resource_metadata_url)
586574

587575
for url in prm_discovery_urls: # pragma: no branch
588576
discovery_request = self._create_oauth_metadata_request(url)

tests/client/test_auth.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1364,9 +1364,7 @@ async def callback_handler() -> tuple[str, str | None]:
13641364
)
13651365

13661366
# Mock authorization
1367-
provider._perform_authorization_code_grant = AsyncMock(
1368-
return_value=("test_auth_code", "test_code_verifier")
1369-
)
1367+
provider._perform_authorization_code_grant = AsyncMock(return_value=("test_auth_code", "test_code_verifier"))
13701368

13711369
# Next should be token exchange
13721370
token_request = await auth_flow.asend(oauth_metadata_response)
@@ -1470,9 +1468,7 @@ async def callback_handler() -> tuple[str, str | None]:
14701468
request=oauth_metadata_request,
14711469
)
14721470

1473-
provider._perform_authorization_code_grant = AsyncMock(
1474-
return_value=("test_auth_code", "test_code_verifier")
1475-
)
1471+
provider._perform_authorization_code_grant = AsyncMock(return_value=("test_auth_code", "test_code_verifier"))
14761472

14771473
token_request = await auth_flow.asend(oauth_metadata_response)
14781474
assert str(token_request.url) == "https://api.example.com/token"

0 commit comments

Comments
 (0)