@@ -1081,6 +1081,116 @@ async def test_auth_flow_no_unnecessary_retry_after_oauth(
10811081 # Verify exactly one request was yielded (no double-sending)
10821082 assert request_yields == 1 , f"Expected 1 request yield, got { request_yields } "
10831083
1084+ @pytest .mark .anyio
1085+ async def test_token_exchange_accepts_201_status (
1086+ self , oauth_provider : OAuthClientProvider , mock_storage : MockTokenStorage
1087+ ):
1088+ """Test that token exchange accepts both 200 and 201 status codes."""
1089+ # Ensure no tokens are stored
1090+ oauth_provider .context .current_tokens = None
1091+ oauth_provider .context .token_expiry_time = None
1092+ oauth_provider ._initialized = True
1093+
1094+ # Create a test request
1095+ test_request = httpx .Request ("GET" , "https://api.example.com/mcp" )
1096+
1097+ # Mock the auth flow
1098+ auth_flow = oauth_provider .async_auth_flow (test_request )
1099+
1100+ # First request should be the original request without auth header
1101+ request = await auth_flow .__anext__ ()
1102+ assert "Authorization" not in request .headers
1103+
1104+ # Send a 401 response to trigger the OAuth flow
1105+ response = httpx .Response (
1106+ 401 ,
1107+ headers = {
1108+ "WWW-Authenticate" : 'Bearer resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"'
1109+ },
1110+ request = test_request ,
1111+ )
1112+
1113+ # Next request should be to discover protected resource metadata
1114+ discovery_request = await auth_flow .asend (response )
1115+ assert discovery_request .method == "GET"
1116+ assert str (discovery_request .url ) == "https://api.example.com/.well-known/oauth-protected-resource"
1117+
1118+ # Send a successful discovery response with minimal protected resource metadata
1119+ discovery_response = httpx .Response (
1120+ 200 ,
1121+ content = b'{"resource": "https://api.example.com/mcp", "authorization_servers": ["https://auth.example.com"]}' ,
1122+ request = discovery_request ,
1123+ )
1124+
1125+ # Next request should be to discover OAuth metadata
1126+ oauth_metadata_request = await auth_flow .asend (discovery_response )
1127+ assert oauth_metadata_request .method == "GET"
1128+ assert str (oauth_metadata_request .url ).startswith ("https://auth.example.com/" )
1129+ assert "mcp-protocol-version" in oauth_metadata_request .headers
1130+
1131+ # Send a successful OAuth metadata response
1132+ oauth_metadata_response = httpx .Response (
1133+ 200 ,
1134+ content = (
1135+ b'{"issuer": "https://auth.example.com", '
1136+ b'"authorization_endpoint": "https://auth.example.com/authorize", '
1137+ b'"token_endpoint": "https://auth.example.com/token", '
1138+ b'"registration_endpoint": "https://auth.example.com/register"}'
1139+ ),
1140+ request = oauth_metadata_request ,
1141+ )
1142+
1143+ # Next request should be to register client
1144+ registration_request = await auth_flow .asend (oauth_metadata_response )
1145+ assert registration_request .method == "POST"
1146+ assert str (registration_request .url ) == "https://auth.example.com/register"
1147+
1148+ # Send a successful registration response with 201 status
1149+ registration_response = httpx .Response (
1150+ 201 ,
1151+ content = b'{"client_id": "test_client_id", "client_secret": "test_client_secret", "redirect_uris": ["http://localhost:3030/callback"]}' ,
1152+ request = registration_request ,
1153+ )
1154+
1155+ # Mock the authorization process
1156+ oauth_provider ._perform_authorization_code_grant = mock .AsyncMock (
1157+ return_value = ("test_auth_code" , "test_code_verifier" )
1158+ )
1159+
1160+ # Next request should be to exchange token
1161+ token_request = await auth_flow .asend (registration_response )
1162+ assert token_request .method == "POST"
1163+ assert str (token_request .url ) == "https://auth.example.com/token"
1164+ assert "code=test_auth_code" in token_request .content .decode ()
1165+
1166+ # Send a successful token response with 201 status code (test both 200 and 201 are accepted)
1167+ token_response = httpx .Response (
1168+ 201 ,
1169+ content = (
1170+ b'{"access_token": "new_access_token", "token_type": "Bearer", "expires_in": 3600, '
1171+ b'"refresh_token": "new_refresh_token"}'
1172+ ),
1173+ request = token_request ,
1174+ )
1175+
1176+ # Final request should be the original request with auth header
1177+ final_request = await auth_flow .asend (token_response )
1178+ assert final_request .headers ["Authorization" ] == "Bearer new_access_token"
1179+ assert final_request .method == "GET"
1180+ assert str (final_request .url ) == "https://api.example.com/mcp"
1181+
1182+ # Send final success response to properly close the generator
1183+ final_response = httpx .Response (200 , request = final_request )
1184+ try :
1185+ await auth_flow .asend (final_response )
1186+ except StopAsyncIteration :
1187+ pass # Expected - generator should complete
1188+
1189+ # Verify tokens were stored
1190+ assert oauth_provider .context .current_tokens is not None
1191+ assert oauth_provider .context .current_tokens .access_token == "new_access_token"
1192+ assert oauth_provider .context .token_expiry_time is not None
1193+
10841194 @pytest .mark .anyio
10851195 async def test_403_insufficient_scope_updates_scope_from_header (
10861196 self ,
0 commit comments