Skip to content

Commit a39c24d

Browse files
committed
Simplify
Signed-off-by: Sid Murching <sid.murching@databricks.com>
1 parent f2a9b82 commit a39c24d

File tree

2 files changed

+54
-40
lines changed

2 files changed

+54
-40
lines changed

src/mcp/client/auth.py

Lines changed: 37 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -108,10 +108,6 @@ class OAuthContext:
108108
# State
109109
lock: anyio.Lock = field(default_factory=anyio.Lock)
110110

111-
# Discovery state for fallback support
112-
discovery_base_url: str | None = None
113-
discovery_pathname: str | None = None
114-
115111
def get_authorization_base_url(self, server_url: str) -> str:
116112
"""Extract base URL by removing path component."""
117113
parsed = urlparse(server_url)
@@ -228,23 +224,23 @@ def _extract_resource_metadata_from_www_auth(self, init_response: httpx.Response
228224

229225
return None
230226

231-
async def _discover_protected_resource(self, init_response: httpx.Response) -> httpx.Request:
232-
# RFC9728: Try to extract resource_metadata URL from WWW-Authenticate header of the initial response
233-
url = self._extract_resource_metadata_from_www_auth(init_response)
234-
235-
if not url:
236-
# Fallback to well-known discovery with path component included
237-
parsed = urlparse(self.context.server_url)
238-
auth_base_url = f"{parsed.scheme}://{parsed.netloc}"
227+
def _get_protected_resource_discovery_urls(self) -> list[str]:
228+
"""Generate ordered list of URLs for protected resource discovery attempts."""
229+
urls: list[str] = []
230+
parsed = urlparse(self.context.server_url)
231+
base_url = f"{parsed.scheme}://{parsed.netloc}"
239232

240-
if parsed.path and parsed.path != "/":
241-
# Include path component in the well-known URL
242-
path_component = parsed.path.rstrip("/")
243-
url = urljoin(auth_base_url, f"/.well-known/oauth-protected-resource{path_component}")
244-
else:
245-
url = urljoin(auth_base_url, "/.well-known/oauth-protected-resource")
233+
if parsed.path and parsed.path != "/":
234+
# Try path-specific endpoint first
235+
path_component = parsed.path.rstrip("/")
236+
urls.append(urljoin(base_url, f"/.well-known/oauth-protected-resource{path_component}"))
237+
# Then fallback to base endpoint
238+
urls.append(urljoin(base_url, "/.well-known/oauth-protected-resource"))
239+
else:
240+
# No path, just use base endpoint
241+
urls.append(urljoin(base_url, "/.well-known/oauth-protected-resource"))
246242

247-
return httpx.Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION})
243+
return urls
248244

249245
async def _handle_protected_resource_response(self, response: httpx.Response) -> None:
250246
"""Handle discovery response."""
@@ -517,9 +513,28 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
517513
try:
518514
# OAuth flow must be inline due to generator constraints
519515
# Step 1: Discover protected resource metadata (RFC9728 with WWW-Authenticate support)
520-
discovery_request = await self._discover_protected_resource(response)
521-
discovery_response = yield discovery_request
522-
await self._handle_protected_resource_response(discovery_response)
516+
# Check if WWW-Authenticate provides resource_metadata URL first
517+
www_auth_url = self._extract_resource_metadata_from_www_auth(response)
518+
if www_auth_url:
519+
discovery_request = httpx.Request(
520+
"GET", www_auth_url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION}
521+
)
522+
discovery_response = yield discovery_request
523+
await self._handle_protected_resource_response(discovery_response)
524+
else:
525+
# Try well-known discovery URLs with fallback
526+
discovery_urls = self._get_protected_resource_discovery_urls()
527+
for url in discovery_urls:
528+
discovery_request = httpx.Request(
529+
"GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION}
530+
)
531+
discovery_response = yield discovery_request
532+
533+
if discovery_response.status_code == 200:
534+
await self._handle_protected_resource_response(discovery_response)
535+
break # Success, stop trying other URLs
536+
elif discovery_response.status_code != 404:
537+
break # Non-404 error, stop trying
523538

524539
# Step 2: Discover OAuth metadata (with fallback for legacy servers)
525540
discovery_urls = self._get_discovery_urls()

tests/client/test_auth.py

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -198,15 +198,16 @@ class TestOAuthFlow:
198198
"""Test OAuth flow methods."""
199199

200200
@pytest.mark.anyio
201-
async def test_discover_protected_resource_request(self, client_metadata, mock_storage):
202-
"""Test protected resource discovery request building maintains backward compatibility."""
201+
async def test_protected_resource_discovery_urls(self, client_metadata, mock_storage):
202+
"""Test protected resource discovery URL generation with fallback."""
203203

204204
async def redirect_handler(url: str) -> None:
205205
pass
206206

207207
async def callback_handler() -> tuple[str, str | None]:
208208
return "test_auth_code", "test_state"
209209

210+
# Test with path component
210211
provider = OAuthClientProvider(
211212
server_url="https://api.example.com/api/2.0/mcp",
212213
client_metadata=client_metadata,
@@ -215,25 +216,23 @@ async def callback_handler() -> tuple[str, str | None]:
215216
callback_handler=callback_handler,
216217
)
217218

218-
# Test without WWW-Authenticate (fallback)
219-
init_response = httpx.Response(
220-
status_code=401, headers={}, request=httpx.Request("GET", "https://request-api.example.com")
221-
)
222-
223-
request = await provider._discover_protected_resource(init_response)
224-
assert request.method == "GET"
225-
assert str(request.url) == "https://api.example.com/.well-known/oauth-protected-resource/api/2.0/mcp"
226-
assert "mcp-protocol-version" in request.headers
219+
urls = provider._get_protected_resource_discovery_urls()
220+
assert urls == [
221+
"https://api.example.com/.well-known/oauth-protected-resource/api/2.0/mcp",
222+
"https://api.example.com/.well-known/oauth-protected-resource",
223+
]
227224

228-
# Test with WWW-Authenticate header
229-
init_response.headers["WWW-Authenticate"] = (
230-
'Bearer resource_metadata="https://prm.example.com/.well-known/oauth-protected-resource/path"'
225+
# Test without path component
226+
provider = OAuthClientProvider(
227+
server_url="https://api.example.com",
228+
client_metadata=client_metadata,
229+
storage=mock_storage,
230+
redirect_handler=redirect_handler,
231+
callback_handler=callback_handler,
231232
)
232233

233-
request = await provider._discover_protected_resource(init_response)
234-
assert request.method == "GET"
235-
assert str(request.url) == "https://prm.example.com/.well-known/oauth-protected-resource/path"
236-
assert "mcp-protocol-version" in request.headers
234+
urls = provider._get_protected_resource_discovery_urls()
235+
assert urls == ["https://api.example.com/.well-known/oauth-protected-resource"]
237236

238237
@pytest.mark.anyio
239238
def test_create_oauth_metadata_request(self, oauth_provider):

0 commit comments

Comments
 (0)