Skip to content

Commit 1960179

Browse files
committed
Add comprehensive tests for OAuth discovery fallback behavior
This commit adds and updates tests to properly cover the new OAuth discovery logic for legacy server compatibility and path-aware discovery. New tests: - test_oauth_discovery_legacy_fallback_when_no_prm: Verifies that when PRM discovery fails, only root OAuth URL is tried (March 2025 spec) - test_oauth_discovery_path_aware_when_auth_server_has_path: Ensures path-based URLs are tried when auth server URL has a path - test_oauth_discovery_root_when_auth_server_has_no_path: Ensures root URLs are tried when auth server URL has no path - test_oauth_discovery_root_when_auth_server_has_only_slash: Handles trailing slash edge case - test_legacy_server_no_prm_falls_back_to_root_oauth_discovery: End-to-end test simulating Linear-style legacy servers - test_legacy_server_with_different_prm_and_root_urls: Tests fallback with custom WWW-Authenticate PRM URLs Updated tests: - test_oauth_discovery_fallback_conditions: Updated to reflect new path-aware discovery behavior (no root URLs when auth server has path) - test_oauth_discovery_fallback_order: Simplified to focus on path-aware case All tests pass with proper linting and type checking. Github-Issue: #1495 Github-Issue: #1623
1 parent de38133 commit 1960179

File tree

1 file changed

+263
-6
lines changed

1 file changed

+263
-6
lines changed

tests/client/test_auth.py

Lines changed: 263 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,57 @@ def test_create_oauth_metadata_request(self, oauth_provider: OAuthClientProvider
307307
class TestOAuthFallback:
308308
"""Test OAuth discovery fallback behavior for legacy (act as AS not RS) servers."""
309309

310+
@pytest.mark.anyio
311+
async def test_oauth_discovery_legacy_fallback_when_no_prm(self):
312+
"""Test that when PRM discovery fails, only root OAuth URL is tried (March 2025 spec)."""
313+
# When auth_server_url is None (PRM failed), we use server_url and only try root
314+
discovery_urls = build_oauth_authorization_server_metadata_discovery_urls(None, "https://mcp.linear.app/sse")
315+
316+
# Should only try the root URL (legacy behavior)
317+
assert discovery_urls == [
318+
"https://mcp.linear.app/.well-known/oauth-authorization-server",
319+
]
320+
321+
@pytest.mark.anyio
322+
async def test_oauth_discovery_path_aware_when_auth_server_has_path(self):
323+
"""Test that when auth server URL has a path, only path-based URLs are tried."""
324+
discovery_urls = build_oauth_authorization_server_metadata_discovery_urls(
325+
"https://auth.example.com/tenant1", "https://api.example.com/mcp"
326+
)
327+
328+
# Should try path-based URLs only (no root URLs)
329+
assert discovery_urls == [
330+
"https://auth.example.com/.well-known/oauth-authorization-server/tenant1",
331+
"https://auth.example.com/.well-known/openid-configuration/tenant1",
332+
"https://auth.example.com/tenant1/.well-known/openid-configuration",
333+
]
334+
335+
@pytest.mark.anyio
336+
async def test_oauth_discovery_root_when_auth_server_has_no_path(self):
337+
"""Test that when auth server URL has no path, only root URLs are tried."""
338+
discovery_urls = build_oauth_authorization_server_metadata_discovery_urls(
339+
"https://auth.example.com", "https://api.example.com/mcp"
340+
)
341+
342+
# Should try root URLs only
343+
assert discovery_urls == [
344+
"https://auth.example.com/.well-known/oauth-authorization-server",
345+
"https://auth.example.com/.well-known/openid-configuration",
346+
]
347+
348+
@pytest.mark.anyio
349+
async def test_oauth_discovery_root_when_auth_server_has_only_slash(self):
350+
"""Test that when auth server URL has only trailing slash, treated as root."""
351+
discovery_urls = build_oauth_authorization_server_metadata_discovery_urls(
352+
"https://auth.example.com/", "https://api.example.com/mcp"
353+
)
354+
355+
# Should try root URLs only
356+
assert discovery_urls == [
357+
"https://auth.example.com/.well-known/oauth-authorization-server",
358+
"https://auth.example.com/.well-known/openid-configuration",
359+
]
360+
310361
@pytest.mark.anyio
311362
async def test_oauth_discovery_fallback_order(self, oauth_provider: OAuthClientProvider):
312363
"""Test fallback URL construction order when auth server URL has a path."""
@@ -362,13 +413,14 @@ async def test_oauth_discovery_fallback_conditions(self, oauth_provider: OAuthCl
362413
assert discovery_request.method == "GET"
363414

364415
# Send a successful discovery response with minimal protected resource metadata
416+
# Note: auth server URL has a path (/v1/mcp), so only path-based URLs will be tried
365417
discovery_response = httpx.Response(
366418
200,
367419
content=b'{"resource": "https://api.example.com/v1/mcp", "authorization_servers": ["https://auth.example.com/v1/mcp"]}',
368420
request=discovery_request,
369421
)
370422

371-
# Next request should be to discover OAuth metadata
423+
# Next request should be to discover OAuth metadata at path-aware OAuth URL
372424
oauth_metadata_request_1 = await auth_flow.asend(discovery_response)
373425
assert (
374426
str(oauth_metadata_request_1.url)
@@ -383,9 +435,9 @@ async def test_oauth_discovery_fallback_conditions(self, oauth_provider: OAuthCl
383435
request=oauth_metadata_request_1,
384436
)
385437

386-
# Next request should be to discover OAuth metadata at the next endpoint
438+
# Next request should be path-aware OIDC URL (not root URL since auth server has path)
387439
oauth_metadata_request_2 = await auth_flow.asend(oauth_metadata_response_1)
388-
assert str(oauth_metadata_request_2.url) == "https://auth.example.com/.well-known/oauth-authorization-server"
440+
assert str(oauth_metadata_request_2.url) == "https://auth.example.com/.well-known/openid-configuration/v1/mcp"
389441
assert oauth_metadata_request_2.method == "GET"
390442

391443
# Send a 400 response
@@ -395,9 +447,9 @@ async def test_oauth_discovery_fallback_conditions(self, oauth_provider: OAuthCl
395447
request=oauth_metadata_request_2,
396448
)
397449

398-
# Next request should be to discover OAuth metadata at the next endpoint
450+
# Next request should be OIDC path-appended URL
399451
oauth_metadata_request_3 = await auth_flow.asend(oauth_metadata_response_2)
400-
assert str(oauth_metadata_request_3.url) == "https://auth.example.com/.well-known/openid-configuration/v1/mcp"
452+
assert str(oauth_metadata_request_3.url) == "https://auth.example.com/v1/mcp/.well-known/openid-configuration"
401453
assert oauth_metadata_request_3.method == "GET"
402454

403455
# Send a 500 response
@@ -412,7 +464,8 @@ async def test_oauth_discovery_fallback_conditions(self, oauth_provider: OAuthCl
412464
return_value=("test_auth_code", "test_code_verifier")
413465
)
414466

415-
# Next request should fall back to legacy behavior and auth with the RS (mocked /authorize, next is /token)
467+
# All path-based URLs failed, flow continues with default endpoints
468+
# Next request should be token exchange using MCP server base URL (fallback when OAuth metadata not found)
416469
token_request = await auth_flow.asend(oauth_metadata_response_3)
417470
assert str(token_request.url) == "https://api.example.com/token"
418471
assert token_request.method == "POST"
@@ -1059,6 +1112,210 @@ def test_build_metadata(
10591112
)
10601113

10611114

1115+
class TestLegacyServerFallback:
1116+
"""Test backward compatibility with legacy servers that don't support PRM (issue #1495)."""
1117+
1118+
@pytest.mark.anyio
1119+
async def test_legacy_server_no_prm_falls_back_to_root_oauth_discovery(
1120+
self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage
1121+
):
1122+
"""Test that when PRM discovery fails completely, we fall back to root OAuth discovery (March 2025 spec)."""
1123+
1124+
async def redirect_handler(url: str) -> None:
1125+
pass # pragma: no cover
1126+
1127+
async def callback_handler() -> tuple[str, str | None]:
1128+
return "test_auth_code", "test_state" # pragma: no cover
1129+
1130+
# Simulate a legacy server like Linear
1131+
provider = OAuthClientProvider(
1132+
server_url="https://mcp.linear.app/sse",
1133+
client_metadata=client_metadata,
1134+
storage=mock_storage,
1135+
redirect_handler=redirect_handler,
1136+
callback_handler=callback_handler,
1137+
)
1138+
1139+
provider.context.current_tokens = None
1140+
provider.context.token_expiry_time = None
1141+
provider._initialized = True
1142+
1143+
# Mock client info to skip DCR
1144+
provider.context.client_info = OAuthClientInformationFull(
1145+
client_id="existing_client",
1146+
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
1147+
)
1148+
1149+
test_request = httpx.Request("GET", "https://mcp.linear.app/sse")
1150+
auth_flow = provider.async_auth_flow(test_request)
1151+
1152+
# First request
1153+
request = await auth_flow.__anext__()
1154+
assert "Authorization" not in request.headers
1155+
1156+
# Send 401 without WWW-Authenticate header (typical legacy server)
1157+
response = httpx.Response(401, headers={}, request=test_request)
1158+
1159+
# Should try path-based PRM first
1160+
prm_request_1 = await auth_flow.asend(response)
1161+
assert str(prm_request_1.url) == "https://mcp.linear.app/.well-known/oauth-protected-resource/sse"
1162+
1163+
# PRM returns 404
1164+
prm_response_1 = httpx.Response(404, request=prm_request_1)
1165+
1166+
# Should try root-based PRM
1167+
prm_request_2 = await auth_flow.asend(prm_response_1)
1168+
assert str(prm_request_2.url) == "https://mcp.linear.app/.well-known/oauth-protected-resource"
1169+
1170+
# PRM returns 404 again - all PRM URLs failed
1171+
prm_response_2 = httpx.Response(404, request=prm_request_2)
1172+
1173+
# Should fall back to root OAuth discovery (March 2025 spec behavior)
1174+
oauth_metadata_request = await auth_flow.asend(prm_response_2)
1175+
assert str(oauth_metadata_request.url) == "https://mcp.linear.app/.well-known/oauth-authorization-server"
1176+
assert oauth_metadata_request.method == "GET"
1177+
1178+
# Send successful OAuth metadata response
1179+
oauth_metadata_response = httpx.Response(
1180+
200,
1181+
content=(
1182+
b'{"issuer": "https://mcp.linear.app", '
1183+
b'"authorization_endpoint": "https://mcp.linear.app/authorize", '
1184+
b'"token_endpoint": "https://mcp.linear.app/token"}'
1185+
),
1186+
request=oauth_metadata_request,
1187+
)
1188+
1189+
# Mock authorization
1190+
provider._perform_authorization_code_grant = mock.AsyncMock(
1191+
return_value=("test_auth_code", "test_code_verifier")
1192+
)
1193+
1194+
# Next should be token exchange
1195+
token_request = await auth_flow.asend(oauth_metadata_response)
1196+
assert str(token_request.url) == "https://mcp.linear.app/token"
1197+
1198+
# Send successful token response
1199+
token_response = httpx.Response(
1200+
200,
1201+
content=b'{"access_token": "linear_token", "token_type": "Bearer", "expires_in": 3600}',
1202+
request=token_request,
1203+
)
1204+
1205+
# Final request with auth header
1206+
final_request = await auth_flow.asend(token_response)
1207+
assert final_request.headers["Authorization"] == "Bearer linear_token"
1208+
assert str(final_request.url) == "https://mcp.linear.app/sse"
1209+
1210+
# Complete flow
1211+
final_response = httpx.Response(200, request=final_request)
1212+
try:
1213+
await auth_flow.asend(final_response)
1214+
except StopAsyncIteration:
1215+
pass
1216+
1217+
@pytest.mark.anyio
1218+
async def test_legacy_server_with_different_prm_and_root_urls(
1219+
self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage
1220+
):
1221+
"""Test PRM fallback with different WWW-Authenticate and root URLs."""
1222+
1223+
async def redirect_handler(url: str) -> None:
1224+
pass # pragma: no cover
1225+
1226+
async def callback_handler() -> tuple[str, str | None]:
1227+
return "test_auth_code", "test_state" # pragma: no cover
1228+
1229+
provider = OAuthClientProvider(
1230+
server_url="https://api.example.com/v1/mcp",
1231+
client_metadata=client_metadata,
1232+
storage=mock_storage,
1233+
redirect_handler=redirect_handler,
1234+
callback_handler=callback_handler,
1235+
)
1236+
1237+
provider.context.current_tokens = None
1238+
provider.context.token_expiry_time = None
1239+
provider._initialized = True
1240+
1241+
provider.context.client_info = OAuthClientInformationFull(
1242+
client_id="existing_client",
1243+
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
1244+
)
1245+
1246+
test_request = httpx.Request("GET", "https://api.example.com/v1/mcp")
1247+
auth_flow = provider.async_auth_flow(test_request)
1248+
1249+
await auth_flow.__anext__()
1250+
1251+
# 401 with custom WWW-Authenticate PRM URL
1252+
response = httpx.Response(
1253+
401,
1254+
headers={
1255+
"WWW-Authenticate": 'Bearer resource_metadata="https://custom.prm.com/.well-known/oauth-protected-resource"'
1256+
},
1257+
request=test_request,
1258+
)
1259+
1260+
# Try custom PRM URL first
1261+
prm_request_1 = await auth_flow.asend(response)
1262+
assert str(prm_request_1.url) == "https://custom.prm.com/.well-known/oauth-protected-resource"
1263+
1264+
# Returns 500
1265+
prm_response_1 = httpx.Response(500, request=prm_request_1)
1266+
1267+
# Try path-based fallback
1268+
prm_request_2 = await auth_flow.asend(prm_response_1)
1269+
assert str(prm_request_2.url) == "https://api.example.com/.well-known/oauth-protected-resource/v1/mcp"
1270+
1271+
# Returns 404
1272+
prm_response_2 = httpx.Response(404, request=prm_request_2)
1273+
1274+
# Try root fallback
1275+
prm_request_3 = await auth_flow.asend(prm_response_2)
1276+
assert str(prm_request_3.url) == "https://api.example.com/.well-known/oauth-protected-resource"
1277+
1278+
# Also returns 404 - all PRM URLs failed
1279+
prm_response_3 = httpx.Response(404, request=prm_request_3)
1280+
1281+
# Should fall back to root OAuth discovery
1282+
oauth_metadata_request = await auth_flow.asend(prm_response_3)
1283+
assert str(oauth_metadata_request.url) == "https://api.example.com/.well-known/oauth-authorization-server"
1284+
1285+
# Complete the flow
1286+
oauth_metadata_response = httpx.Response(
1287+
200,
1288+
content=(
1289+
b'{"issuer": "https://api.example.com", '
1290+
b'"authorization_endpoint": "https://api.example.com/authorize", '
1291+
b'"token_endpoint": "https://api.example.com/token"}'
1292+
),
1293+
request=oauth_metadata_request,
1294+
)
1295+
1296+
provider._perform_authorization_code_grant = mock.AsyncMock(
1297+
return_value=("test_auth_code", "test_code_verifier")
1298+
)
1299+
1300+
token_request = await auth_flow.asend(oauth_metadata_response)
1301+
assert str(token_request.url) == "https://api.example.com/token"
1302+
1303+
token_response = httpx.Response(
1304+
200,
1305+
content=b'{"access_token": "test_token", "token_type": "Bearer", "expires_in": 3600}',
1306+
request=token_request,
1307+
)
1308+
1309+
final_request = await auth_flow.asend(token_response)
1310+
assert final_request.headers["Authorization"] == "Bearer test_token"
1311+
1312+
final_response = httpx.Response(200, request=final_request)
1313+
try:
1314+
await auth_flow.asend(final_response)
1315+
except StopAsyncIteration:
1316+
pass
1317+
1318+
10621319
class TestSEP985Discovery:
10631320
"""Test SEP-985 protected resource metadata discovery with fallback."""
10641321

0 commit comments

Comments
 (0)