Skip to content

Commit 2fe9134

Browse files
committed
Add tests
Signed-off-by: Sid Murching <sid.murching@databricks.com>
1 parent a39c24d commit 2fe9134

File tree

1 file changed

+175
-4
lines changed

1 file changed

+175
-4
lines changed

tests/client/test_auth.py

Lines changed: 175 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -198,16 +198,16 @@ class TestOAuthFlow:
198198
"""Test OAuth flow methods."""
199199

200200
@pytest.mark.anyio
201-
async def test_protected_resource_discovery_urls(self, client_metadata, mock_storage):
202-
"""Test protected resource discovery URL generation with fallback."""
201+
async def test_protected_resource_discovery_urls_generation(self, client_metadata, mock_storage):
202+
"""Test that discovery URL generation works correctly for different server URLs."""
203203

204204
async def redirect_handler(url: str) -> None:
205205
pass
206206

207207
async def callback_handler() -> tuple[str, str | None]:
208208
return "test_auth_code", "test_state"
209209

210-
# Test with path component
210+
# Test with path component - should have both path-specific and base endpoints
211211
provider = OAuthClientProvider(
212212
server_url="https://api.example.com/api/2.0/mcp",
213213
client_metadata=client_metadata,
@@ -222,7 +222,7 @@ async def callback_handler() -> tuple[str, str | None]:
222222
"https://api.example.com/.well-known/oauth-protected-resource",
223223
]
224224

225-
# Test without path component
225+
# Test without path component - should only have base endpoint
226226
provider = OAuthClientProvider(
227227
server_url="https://api.example.com",
228228
client_metadata=client_metadata,
@@ -594,6 +594,177 @@ async def test_auth_flow_with_no_tokens(self, oauth_provider, mock_storage):
594594
assert oauth_provider.context.current_tokens.access_token == "new_access_token"
595595
assert oauth_provider.context.token_expiry_time is not None
596596

597+
@pytest.mark.anyio
598+
async def test_auth_flow_protected_resource_fallback(self, client_metadata, mock_storage):
599+
"""Test that the OAuth flow correctly implements fallback from path-specific to base endpoint."""
600+
601+
async def redirect_handler(url: str) -> None:
602+
pass
603+
604+
async def callback_handler() -> tuple[str, str | None]:
605+
return "test_auth_code", "test_state"
606+
607+
provider = OAuthClientProvider(
608+
server_url="https://api.example.com/api/2.0/mcp",
609+
client_metadata=client_metadata,
610+
storage=mock_storage,
611+
redirect_handler=redirect_handler,
612+
callback_handler=callback_handler,
613+
)
614+
615+
provider.context.current_tokens = None
616+
provider.context.token_expiry_time = None
617+
provider._initialized = True
618+
619+
test_request = httpx.Request("GET", "https://api.example.com/api/2.0/mcp")
620+
auth_flow = provider.async_auth_flow(test_request)
621+
622+
# Step 1: Original request without auth
623+
request = await auth_flow.__anext__()
624+
assert "Authorization" not in request.headers
625+
626+
# Step 2: 401 triggers protected resource discovery - should try path-specific first
627+
response = httpx.Response(401, request=test_request)
628+
path_discovery_request = await auth_flow.asend(response)
629+
assert (
630+
str(path_discovery_request.url)
631+
== "https://api.example.com/.well-known/oauth-protected-resource/api/2.0/mcp"
632+
)
633+
634+
# Step 3: Path-specific fails with 404 - should trigger fallback
635+
path_404_response = httpx.Response(404, request=path_discovery_request)
636+
base_discovery_request = await auth_flow.asend(path_404_response)
637+
assert str(base_discovery_request.url) == "https://api.example.com/.well-known/oauth-protected-resource"
638+
639+
# Step 4: Base endpoint succeeds - should store metadata and continue to OAuth discovery
640+
successful_response = httpx.Response(
641+
200,
642+
content=b'{"resource": "https://api.example.com", "authorization_servers": ["https://api.example.com"]}',
643+
request=base_discovery_request,
644+
)
645+
646+
# Verify the fallback worked and metadata was stored
647+
await auth_flow.asend(successful_response)
648+
assert provider.context.protected_resource_metadata is not None
649+
assert str(provider.context.protected_resource_metadata.resource) == "https://api.example.com/"
650+
651+
# Clean up the generator
652+
try:
653+
await auth_flow.aclose()
654+
except Exception:
655+
pass
656+
657+
@pytest.mark.anyio
658+
async def test_auth_flow_www_authenticate_no_fallback(self, client_metadata, mock_storage):
659+
"""Test that WWW-Authenticate header skips fallback logic entirely."""
660+
661+
async def redirect_handler(url: str) -> None:
662+
pass
663+
664+
async def callback_handler() -> tuple[str, str | None]:
665+
return "test_auth_code", "test_state"
666+
667+
provider = OAuthClientProvider(
668+
server_url="https://api.example.com/api/2.0/mcp",
669+
client_metadata=client_metadata,
670+
storage=mock_storage,
671+
redirect_handler=redirect_handler,
672+
callback_handler=callback_handler,
673+
)
674+
675+
provider.context.current_tokens = None
676+
provider.context.token_expiry_time = None
677+
provider._initialized = True
678+
679+
test_request = httpx.Request("GET", "https://api.example.com/api/2.0/mcp")
680+
auth_flow = provider.async_auth_flow(test_request)
681+
682+
# Step 1: Original request without auth
683+
request = await auth_flow.__anext__()
684+
assert "Authorization" not in request.headers
685+
686+
# Step 2: 401 with WWW-Authenticate should use that URL directly
687+
response = httpx.Response(
688+
401,
689+
headers={
690+
"WWW-Authenticate": 'Bearer resource_metadata="https://custom.example.com/.well-known/oauth-protected-resource"'
691+
},
692+
request=test_request,
693+
)
694+
695+
www_auth_request = await auth_flow.asend(response)
696+
assert str(www_auth_request.url) == "https://custom.example.com/.well-known/oauth-protected-resource"
697+
698+
# Step 3: Should proceed directly to OAuth metadata discovery (no fallback attempted)
699+
successful_response = httpx.Response(
700+
200,
701+
content=b'{"resource": "https://api.example.com/api/2.0/mcp", "authorization_servers": ["https://api.example.com"]}',
702+
request=www_auth_request,
703+
)
704+
705+
await auth_flow.asend(successful_response)
706+
assert provider.context.protected_resource_metadata is not None
707+
708+
# Clean up the generator
709+
try:
710+
await auth_flow.aclose()
711+
except Exception:
712+
pass
713+
714+
@pytest.mark.anyio
715+
async def test_auth_flow_no_fallback_on_success(self, client_metadata, mock_storage):
716+
"""Test that first successful discovery response stops the fallback process."""
717+
718+
async def redirect_handler(url: str) -> None:
719+
pass
720+
721+
async def callback_handler() -> tuple[str, str | None]:
722+
return "test_auth_code", "test_state"
723+
724+
provider = OAuthClientProvider(
725+
server_url="https://api.example.com/api/2.0/mcp",
726+
client_metadata=client_metadata,
727+
storage=mock_storage,
728+
redirect_handler=redirect_handler,
729+
callback_handler=callback_handler,
730+
)
731+
732+
provider.context.current_tokens = None
733+
provider.context.token_expiry_time = None
734+
provider._initialized = True
735+
736+
test_request = httpx.Request("GET", "https://api.example.com/api/2.0/mcp")
737+
auth_flow = provider.async_auth_flow(test_request)
738+
739+
# Step 1: Original request without auth
740+
request = await auth_flow.__anext__()
741+
assert "Authorization" not in request.headers
742+
743+
# Step 2: 401 triggers path-specific discovery
744+
response = httpx.Response(401, request=test_request)
745+
path_discovery_request = await auth_flow.asend(response)
746+
assert (
747+
str(path_discovery_request.url)
748+
== "https://api.example.com/.well-known/oauth-protected-resource/api/2.0/mcp"
749+
)
750+
751+
# Step 3: Path-specific succeeds - should skip fallback and go to OAuth discovery
752+
successful_response = httpx.Response(
753+
200,
754+
content=b'{"resource": "https://api.example.com/api/2.0/mcp", "authorization_servers": ["https://api.example.com"]}',
755+
request=path_discovery_request,
756+
)
757+
758+
await auth_flow.asend(successful_response)
759+
assert provider.context.protected_resource_metadata is not None
760+
assert str(provider.context.protected_resource_metadata.resource) == "https://api.example.com/api/2.0/mcp"
761+
762+
# Clean up the generator
763+
try:
764+
await auth_flow.aclose()
765+
except Exception:
766+
pass
767+
597768

598769
@pytest.mark.parametrize(
599770
(

0 commit comments

Comments
 (0)