Skip to content

Commit b7e88d8

Browse files
committed
Added unit-tests
1 parent dd0902e commit b7e88d8

File tree

2 files changed

+169
-31
lines changed

2 files changed

+169
-31
lines changed

src/mcp/client/auth.py

Lines changed: 27 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -34,31 +34,6 @@
3434
logger = logging.getLogger(__name__)
3535

3636

37-
def _extract_resource_metadata_from_www_auth(response: httpx.Response) -> str | None:
38-
"""
39-
Extract protected resource metadata URL from WWW-Authenticate header as per RFC9728.
40-
41-
Returns:
42-
Resource metadata URL if found in WWW-Authenticate header, None otherwise
43-
"""
44-
if not response or response.status_code != 401:
45-
return None
46-
47-
www_auth_header = response.headers.get("WWW-Authenticate")
48-
if not www_auth_header:
49-
return None
50-
51-
# Pattern matches: resource_metadata="url" or resource_metadata=url (unquoted)
52-
pattern = r'resource_metadata=(?:"([^"]+)"|([^\s,]+))'
53-
match = re.search(pattern, www_auth_header)
54-
55-
if match:
56-
# Return quoted value if present, otherwise unquoted value
57-
return match.group(1) or match.group(2)
58-
59-
return None
60-
61-
6237
class OAuthFlowError(Exception):
6338
"""Base exception for OAuth flow errors."""
6439

@@ -229,9 +204,33 @@ def __init__(
229204
)
230205
self._initialized = False
231206

232-
async def _discover_protected_resource(self, response: httpx.Response | None = None) -> httpx.Request:
233-
# RFC9728: Try to extract resource_metadata URL from WWW-Authenticate header
234-
url = _extract_resource_metadata_from_www_auth(response) if response else None
207+
def _extract_resource_metadata_from_www_auth(self, init_response: httpx.Response) -> str | None:
208+
"""
209+
Extract protected resource metadata URL from WWW-Authenticate header as per RFC9728.
210+
211+
Returns:
212+
Resource metadata URL if found in WWW-Authenticate header, None otherwise
213+
"""
214+
if not init_response or init_response.status_code != 401:
215+
return None
216+
217+
www_auth_header = init_response.headers.get("WWW-Authenticate")
218+
if not www_auth_header:
219+
return None
220+
221+
# Pattern matches: resource_metadata="url" or resource_metadata=url (unquoted)
222+
pattern = r'resource_metadata=(?:"([^"]+)"|([^\s,]+))'
223+
match = re.search(pattern, www_auth_header)
224+
225+
if match:
226+
# Return quoted value if present, otherwise unquoted value
227+
return match.group(1) or match.group(2)
228+
229+
return None
230+
231+
async def _discover_protected_resource(self, init_response: httpx.Response | None = None) -> httpx.Request:
232+
# RFC9728: Try to extract resource_metadata URL from WWW-Authenticate header of the initial response
233+
url = self._extract_resource_metadata_from_www_auth(init_response) if init_response else None
235234

236235
if not url:
237236
# Fallback to well-known discovery

tests/client/test_auth.py

Lines changed: 142 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -196,13 +196,47 @@ class TestOAuthFlow:
196196
"""Test OAuth flow methods."""
197197

198198
@pytest.mark.anyio
199-
async def test_discover_protected_resource_request(self, oauth_provider):
200-
"""Test protected resource discovery request building."""
201-
request = await oauth_provider._discover_protected_resource()
199+
async def test_discover_protected_resource_request(self, client_metadata, mock_storage):
200+
"""Test protected resource discovery request building maintains backward compatibility."""
201+
async def redirect_handler(url: str) -> None:
202+
pass
203+
204+
async def callback_handler() -> tuple[str, str | None]:
205+
return "test_auth_code", "test_state"
202206

207+
provider = OAuthClientProvider(
208+
server_url="https://api.example.com",
209+
client_metadata=client_metadata,
210+
storage=mock_storage,
211+
redirect_handler=redirect_handler,
212+
callback_handler=callback_handler,
213+
)
214+
215+
# Test without response (backward compatibility)
216+
request = await provider._discover_protected_resource()
203217
assert request.method == "GET"
204218
assert str(request.url) == "https://api.example.com/.well-known/oauth-protected-resource"
205219
assert "mcp-protocol-version" in request.headers
220+
221+
# Test with response but no WWW-Authenticate (fallback)
222+
init_response = httpx.Response(
223+
status_code=401,
224+
headers={},
225+
request=httpx.Request("GET", "https://request-api.example.com")
226+
)
227+
228+
request = await provider._discover_protected_resource(init_response)
229+
assert request.method == "GET"
230+
assert str(request.url) == "https://api.example.com/.well-known/oauth-protected-resource"
231+
assert "mcp-protocol-version" in request.headers
232+
233+
# Test with WWW-Authenticate header
234+
init_response.headers["WWW-Authenticate"] = 'Bearer resource_metadata="https://prm.example.com/.well-known/oauth-protected-resource/path"'
235+
236+
request = await provider._discover_protected_resource(init_response)
237+
assert request.method == "GET"
238+
assert str(request.url) == "https://prm.example.com/.well-known/oauth-protected-resource/path"
239+
assert "mcp-protocol-version" in request.headers
206240

207241
@pytest.mark.anyio
208242
async def test_discover_oauth_metadata_request(self, oauth_provider):
@@ -544,3 +578,108 @@ async def test_auth_flow_with_valid_tokens(self, oauth_provider, mock_storage, v
544578
await auth_flow.asend(response)
545579
except StopAsyncIteration:
546580
pass # Expected
581+
582+
583+
class TestRFC9728WWWAuthenticate:
584+
"""Test RFC9728 WWW-Authenticate header parsing functionality."""
585+
586+
@pytest.mark.parametrize("www_auth_header,expected_url", [
587+
# Quoted URL
588+
('Bearer resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"',
589+
"https://api.example.com/.well-known/oauth-protected-resource"),
590+
# Unquoted URL
591+
("Bearer resource_metadata=https://api.example.com/.well-known/oauth-protected-resource",
592+
"https://api.example.com/.well-known/oauth-protected-resource"),
593+
# Complex header with multiple parameters
594+
('Bearer realm="api", resource_metadata="https://api.example.com/.well-known/oauth-protected-resource", error="insufficient_scope"',
595+
"https://api.example.com/.well-known/oauth-protected-resource"),
596+
# Different URL format
597+
('Bearer resource_metadata="https://custom.domain.com/metadata"',
598+
"https://custom.domain.com/metadata"),
599+
# With path and query params
600+
('Bearer resource_metadata="https://api.example.com/auth/metadata?version=1"',
601+
"https://api.example.com/auth/metadata?version=1"),
602+
])
603+
def test_extract_resource_metadata_from_www_auth_valid_cases(self, client_metadata, mock_storage, www_auth_header, expected_url):
604+
"""Test extraction of resource_metadata URL from various valid WWW-Authenticate headers."""
605+
async def redirect_handler(url: str) -> None:
606+
pass
607+
608+
async def callback_handler() -> tuple[str, str | None]:
609+
return "test_auth_code", "test_state"
610+
611+
provider = OAuthClientProvider(
612+
server_url="https://api.example.com/v1/mcp",
613+
client_metadata=client_metadata,
614+
storage=mock_storage,
615+
redirect_handler=redirect_handler,
616+
callback_handler=callback_handler,
617+
)
618+
619+
init_response = httpx.Response(
620+
status_code=401,
621+
headers={"WWW-Authenticate": www_auth_header},
622+
request=httpx.Request("GET", "https://api.example.com/test")
623+
)
624+
625+
result = provider._extract_resource_metadata_from_www_auth(init_response)
626+
assert result == expected_url
627+
628+
@pytest.mark.parametrize("status_code,www_auth_header,description", [
629+
# No header
630+
(401, None, "no WWW-Authenticate header"),
631+
# Empty header
632+
(401, "", "empty WWW-Authenticate header"),
633+
# Header without resource_metadata
634+
(401, 'Bearer realm="api", error="insufficient_scope"', "no resource_metadata parameter"),
635+
# Malformed header
636+
(401, "Bearer resource_metadata=", "malformed resource_metadata parameter"),
637+
# Non-401 status code
638+
(200, 'Bearer resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"', "200 OK response"),
639+
(500, 'Bearer resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"', "500 error response"),
640+
])
641+
def test_extract_resource_metadata_from_www_auth_invalid_cases(self, client_metadata, mock_storage, status_code, www_auth_header, description):
642+
"""Test extraction returns None for invalid cases."""
643+
async def redirect_handler(url: str) -> None:
644+
pass
645+
646+
async def callback_handler() -> tuple[str, str | None]:
647+
return "test_auth_code", "test_state"
648+
649+
provider = OAuthClientProvider(
650+
server_url="https://api.example.com/v1/mcp",
651+
client_metadata=client_metadata,
652+
storage=mock_storage,
653+
redirect_handler=redirect_handler,
654+
callback_handler=callback_handler,
655+
)
656+
657+
headers = {"WWW-Authenticate": www_auth_header} if www_auth_header is not None else {}
658+
init_response = httpx.Response(
659+
status_code=status_code,
660+
headers=headers,
661+
request=httpx.Request("GET", "https://api.example.com/test")
662+
)
663+
664+
result = provider._extract_resource_metadata_from_www_auth(init_response)
665+
assert result is None, f"Should return None for {description}"
666+
667+
def test_extract_resource_metadata_from_www_auth_none_response(self, client_metadata, mock_storage):
668+
"""Test extraction with None response returns None."""
669+
async def redirect_handler(url: str) -> None:
670+
pass
671+
672+
async def callback_handler() -> tuple[str, str | None]:
673+
return "test_auth_code", "test_state"
674+
675+
provider = OAuthClientProvider(
676+
server_url="https://api.example.com/v1/mcp",
677+
client_metadata=client_metadata,
678+
storage=mock_storage,
679+
redirect_handler=redirect_handler,
680+
callback_handler=callback_handler,
681+
)
682+
683+
result = provider._extract_resource_metadata_from_www_auth(None)
684+
assert result is None
685+

0 commit comments

Comments
 (0)