|
| 1 | +import base64 |
1 | 2 | import urllib.parse |
2 | 3 |
|
3 | 4 | import jwt |
@@ -161,3 +162,292 @@ async def test_token_exchange_request_jwt(self, rfc7523_oauth_provider: RFC7523O |
161 | 162 | assert claims["name"] == "John Doe" |
162 | 163 | assert claims["admin"] |
163 | 164 | assert claims["iat"] == 1516239022 |
| 165 | + |
| 166 | + @pytest.mark.anyio |
| 167 | + async def test_exchange_token_client_credentials_with_private_key_jwt( |
| 168 | + self, rfc7523_oauth_provider: RFC7523OAuthClientProvider |
| 169 | + ): |
| 170 | + """Test client_credentials token exchange with private_key_jwt authentication.""" |
| 171 | + # Set up required context for client_credentials with private_key_jwt |
| 172 | + rfc7523_oauth_provider.context.client_info = OAuthClientInformationFull( |
| 173 | + client_id="test-client", |
| 174 | + grant_types=["client_credentials"], |
| 175 | + token_endpoint_auth_method="private_key_jwt", |
| 176 | + redirect_uris=[AnyUrl("http://localhost:0/unused")], |
| 177 | + scope="read write", |
| 178 | + ) |
| 179 | + rfc7523_oauth_provider.context.oauth_metadata = OAuthMetadata( |
| 180 | + issuer=AnyHttpUrl("https://auth.example.com"), |
| 181 | + authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"), |
| 182 | + token_endpoint=AnyHttpUrl("https://auth.example.com/token"), |
| 183 | + ) |
| 184 | + rfc7523_oauth_provider.context.client_metadata = rfc7523_oauth_provider.context.client_info |
| 185 | + rfc7523_oauth_provider.context.protocol_version = "2025-06-18" |
| 186 | + rfc7523_oauth_provider.jwt_parameters = JWTParameters( |
| 187 | + issuer="test-client", |
| 188 | + subject="test-client", |
| 189 | + jwt_signing_algorithm="HS256", |
| 190 | + jwt_signing_key="a-string-secret-at-least-256-bits-long", |
| 191 | + jwt_lifetime_seconds=300, |
| 192 | + ) |
| 193 | + |
| 194 | + request = await rfc7523_oauth_provider._exchange_token_client_credentials() |
| 195 | + |
| 196 | + assert request.method == "POST" |
| 197 | + assert str(request.url) == "https://auth.example.com/token" |
| 198 | + assert request.headers["Content-Type"] == "application/x-www-form-urlencoded" |
| 199 | + |
| 200 | + # Check form data |
| 201 | + content = urllib.parse.unquote_plus(request.content.decode()) |
| 202 | + assert "grant_type=client_credentials" in content |
| 203 | + assert "scope=read write" in content |
| 204 | + assert "resource=https://api.example.com/v1/mcp" in content |
| 205 | + assert "client_assertion=" in content |
| 206 | + assert "client_assertion_type=urn:ietf:params:oauth:client-assertion-type:jwt-bearer" in content |
| 207 | + |
| 208 | + @pytest.mark.anyio |
| 209 | + async def test_exchange_token_client_credentials_with_client_secret_basic( |
| 210 | + self, rfc7523_oauth_provider: RFC7523OAuthClientProvider |
| 211 | + ): |
| 212 | + """Test client_credentials token exchange with client_secret_basic authentication.""" |
| 213 | + # Set up required context for client_credentials with client_secret_basic |
| 214 | + rfc7523_oauth_provider.context.client_info = OAuthClientInformationFull( |
| 215 | + client_id="test-client", |
| 216 | + client_secret="test-secret", |
| 217 | + grant_types=["client_credentials"], |
| 218 | + token_endpoint_auth_method="client_secret_basic", |
| 219 | + redirect_uris=[AnyUrl("http://localhost:0/unused")], |
| 220 | + scope="read write", |
| 221 | + ) |
| 222 | + rfc7523_oauth_provider.context.oauth_metadata = OAuthMetadata( |
| 223 | + issuer=AnyHttpUrl("https://auth.example.com"), |
| 224 | + authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"), |
| 225 | + token_endpoint=AnyHttpUrl("https://auth.example.com/token"), |
| 226 | + ) |
| 227 | + rfc7523_oauth_provider.context.client_metadata = rfc7523_oauth_provider.context.client_info |
| 228 | + rfc7523_oauth_provider.context.protocol_version = "2025-06-18" |
| 229 | + # No JWT parameters needed for client_secret_basic |
| 230 | + |
| 231 | + request = await rfc7523_oauth_provider._exchange_token_client_credentials() |
| 232 | + |
| 233 | + assert request.method == "POST" |
| 234 | + assert str(request.url) == "https://auth.example.com/token" |
| 235 | + assert request.headers["Content-Type"] == "application/x-www-form-urlencoded" |
| 236 | + |
| 237 | + # Check Authorization header (Basic auth) |
| 238 | + assert "Authorization" in request.headers |
| 239 | + auth_header = request.headers["Authorization"] |
| 240 | + assert auth_header.startswith("Basic ") |
| 241 | + |
| 242 | + # Decode and verify credentials |
| 243 | + encoded_creds = auth_header[6:] # Remove "Basic " prefix |
| 244 | + decoded = base64.b64decode(encoded_creds).decode() |
| 245 | + assert decoded == "test-client:test-secret" |
| 246 | + |
| 247 | + # Check form data |
| 248 | + content = urllib.parse.unquote_plus(request.content.decode()) |
| 249 | + assert "grant_type=client_credentials" in content |
| 250 | + assert "scope=read write" in content |
| 251 | + assert "resource=https://api.example.com/v1/mcp" in content |
| 252 | + # client_secret should NOT be in body for client_secret_basic |
| 253 | + assert "client_secret=" not in content |
| 254 | + |
| 255 | + @pytest.mark.anyio |
| 256 | + async def test_perform_authorization_routes_to_client_credentials( |
| 257 | + self, rfc7523_oauth_provider: RFC7523OAuthClientProvider |
| 258 | + ): |
| 259 | + """Test that _perform_authorization routes to client_credentials when configured.""" |
| 260 | + # Set up required context for client_credentials flow |
| 261 | + rfc7523_oauth_provider.context.client_info = OAuthClientInformationFull( |
| 262 | + client_id="test-client", |
| 263 | + client_secret="test-secret", |
| 264 | + grant_types=["client_credentials"], |
| 265 | + token_endpoint_auth_method="client_secret_basic", |
| 266 | + redirect_uris=[AnyUrl("http://localhost:0/unused")], |
| 267 | + ) |
| 268 | + rfc7523_oauth_provider.context.oauth_metadata = OAuthMetadata( |
| 269 | + issuer=AnyHttpUrl("https://auth.example.com"), |
| 270 | + authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"), |
| 271 | + token_endpoint=AnyHttpUrl("https://auth.example.com/token"), |
| 272 | + ) |
| 273 | + rfc7523_oauth_provider.context.client_metadata = rfc7523_oauth_provider.context.client_info |
| 274 | + rfc7523_oauth_provider.context.protocol_version = "2025-06-18" |
| 275 | + |
| 276 | + request = await rfc7523_oauth_provider._perform_authorization() |
| 277 | + |
| 278 | + # Should route to client_credentials flow |
| 279 | + content = urllib.parse.unquote_plus(request.content.decode()) |
| 280 | + assert "grant_type=client_credentials" in content |
| 281 | + |
| 282 | + @pytest.mark.anyio |
| 283 | + async def test_perform_authorization_routes_to_jwt_bearer( |
| 284 | + self, rfc7523_oauth_provider: RFC7523OAuthClientProvider |
| 285 | + ): |
| 286 | + """Test that _perform_authorization routes to jwt-bearer when configured.""" |
| 287 | + # Set up required context for jwt-bearer flow |
| 288 | + rfc7523_oauth_provider.context.client_info = OAuthClientInformationFull( |
| 289 | + client_id="test-client", |
| 290 | + grant_types=["urn:ietf:params:oauth:grant-type:jwt-bearer"], |
| 291 | + token_endpoint_auth_method="private_key_jwt", |
| 292 | + redirect_uris=[AnyUrl("http://localhost:0/unused")], |
| 293 | + ) |
| 294 | + rfc7523_oauth_provider.context.oauth_metadata = OAuthMetadata( |
| 295 | + issuer=AnyHttpUrl("https://auth.example.com"), |
| 296 | + authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"), |
| 297 | + token_endpoint=AnyHttpUrl("https://auth.example.com/token"), |
| 298 | + ) |
| 299 | + rfc7523_oauth_provider.context.client_metadata = rfc7523_oauth_provider.context.client_info |
| 300 | + rfc7523_oauth_provider.context.protocol_version = "2025-06-18" |
| 301 | + rfc7523_oauth_provider.jwt_parameters = JWTParameters( |
| 302 | + issuer="test-client", |
| 303 | + subject="test-client", |
| 304 | + jwt_signing_algorithm="HS256", |
| 305 | + jwt_signing_key="a-string-secret-at-least-256-bits-long", |
| 306 | + ) |
| 307 | + |
| 308 | + request = await rfc7523_oauth_provider._perform_authorization() |
| 309 | + |
| 310 | + # Should route to jwt-bearer flow |
| 311 | + content = urllib.parse.unquote_plus(request.content.decode()) |
| 312 | + assert "grant_type=urn:ietf:params:oauth:grant-type:jwt-bearer" in content |
| 313 | + |
| 314 | + @pytest.mark.anyio |
| 315 | + async def test_add_client_authentication_jwt( |
| 316 | + self, rfc7523_oauth_provider: RFC7523OAuthClientProvider |
| 317 | + ): |
| 318 | + """Test _add_client_authentication_jwt adds correct JWT assertion parameters.""" |
| 319 | + # Set up required context |
| 320 | + rfc7523_oauth_provider.context.oauth_metadata = OAuthMetadata( |
| 321 | + issuer=AnyHttpUrl("https://auth.example.com"), |
| 322 | + authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"), |
| 323 | + token_endpoint=AnyHttpUrl("https://auth.example.com/token"), |
| 324 | + ) |
| 325 | + rfc7523_oauth_provider.jwt_parameters = JWTParameters( |
| 326 | + issuer="test-client", |
| 327 | + subject="test-client", |
| 328 | + jwt_signing_algorithm="HS256", |
| 329 | + jwt_signing_key="a-string-secret-at-least-256-bits-long", |
| 330 | + ) |
| 331 | + |
| 332 | + token_data: dict = {} |
| 333 | + rfc7523_oauth_provider._add_client_authentication_jwt(token_data=token_data) |
| 334 | + |
| 335 | + # Check that JWT assertion parameters were added |
| 336 | + assert "client_assertion" in token_data |
| 337 | + assert token_data["client_assertion_type"] == "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" |
| 338 | + assert token_data["audience"] == "https://api.example.com/v1/mcp" |
| 339 | + |
| 340 | + # Verify the JWT assertion is valid and has correct audience (issuer identifier) |
| 341 | + claims = jwt.decode( |
| 342 | + token_data["client_assertion"], |
| 343 | + key="a-string-secret-at-least-256-bits-long", |
| 344 | + algorithms=["HS256"], |
| 345 | + audience="https://auth.example.com/", # Should be issuer, not token endpoint |
| 346 | + verify=True, |
| 347 | + ) |
| 348 | + assert claims["iss"] == "test-client" |
| 349 | + assert claims["sub"] == "test-client" |
| 350 | + |
| 351 | + @pytest.mark.anyio |
| 352 | + async def test_exchange_token_authorization_code_with_private_key_jwt( |
| 353 | + self, rfc7523_oauth_provider: RFC7523OAuthClientProvider |
| 354 | + ): |
| 355 | + """Test authorization_code token exchange adds JWT when using private_key_jwt.""" |
| 356 | + # Set up required context |
| 357 | + rfc7523_oauth_provider.context.client_info = OAuthClientInformationFull( |
| 358 | + client_id="test-client", |
| 359 | + grant_types=["authorization_code"], |
| 360 | + token_endpoint_auth_method="private_key_jwt", |
| 361 | + redirect_uris=[AnyUrl("http://localhost:3030/callback")], |
| 362 | + ) |
| 363 | + rfc7523_oauth_provider.context.oauth_metadata = OAuthMetadata( |
| 364 | + issuer=AnyHttpUrl("https://auth.example.com"), |
| 365 | + authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"), |
| 366 | + token_endpoint=AnyHttpUrl("https://auth.example.com/token"), |
| 367 | + ) |
| 368 | + rfc7523_oauth_provider.context.client_metadata = rfc7523_oauth_provider.context.client_info |
| 369 | + rfc7523_oauth_provider.context.protocol_version = "2025-06-18" |
| 370 | + rfc7523_oauth_provider.jwt_parameters = JWTParameters( |
| 371 | + issuer="test-client", |
| 372 | + subject="test-client", |
| 373 | + jwt_signing_algorithm="HS256", |
| 374 | + jwt_signing_key="a-string-secret-at-least-256-bits-long", |
| 375 | + ) |
| 376 | + |
| 377 | + request = await rfc7523_oauth_provider._exchange_token_authorization_code( |
| 378 | + "test-auth-code", "test-verifier" |
| 379 | + ) |
| 380 | + |
| 381 | + # Check form data contains JWT assertion |
| 382 | + content = urllib.parse.unquote_plus(request.content.decode()) |
| 383 | + assert "grant_type=authorization_code" in content |
| 384 | + assert "code=test-auth-code" in content |
| 385 | + assert "client_assertion=" in content |
| 386 | + assert "client_assertion_type=urn:ietf:params:oauth:client-assertion-type:jwt-bearer" in content |
| 387 | + |
| 388 | + @pytest.mark.anyio |
| 389 | + async def test_exchange_token_authorization_code_without_private_key_jwt( |
| 390 | + self, rfc7523_oauth_provider: RFC7523OAuthClientProvider |
| 391 | + ): |
| 392 | + """Test authorization_code token exchange without private_key_jwt uses standard auth.""" |
| 393 | + # Set up required context with client_secret_post (not private_key_jwt) |
| 394 | + rfc7523_oauth_provider.context.client_info = OAuthClientInformationFull( |
| 395 | + client_id="test-client", |
| 396 | + client_secret="test-secret", |
| 397 | + grant_types=["authorization_code"], |
| 398 | + token_endpoint_auth_method="client_secret_post", |
| 399 | + redirect_uris=[AnyUrl("http://localhost:3030/callback")], |
| 400 | + ) |
| 401 | + rfc7523_oauth_provider.context.oauth_metadata = OAuthMetadata( |
| 402 | + issuer=AnyHttpUrl("https://auth.example.com"), |
| 403 | + authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"), |
| 404 | + token_endpoint=AnyHttpUrl("https://auth.example.com/token"), |
| 405 | + ) |
| 406 | + rfc7523_oauth_provider.context.client_metadata = rfc7523_oauth_provider.context.client_info |
| 407 | + rfc7523_oauth_provider.context.protocol_version = "2025-06-18" |
| 408 | + |
| 409 | + request = await rfc7523_oauth_provider._exchange_token_authorization_code( |
| 410 | + "test-auth-code", "test-verifier" |
| 411 | + ) |
| 412 | + |
| 413 | + # Check form data does NOT contain JWT assertion |
| 414 | + content = urllib.parse.unquote_plus(request.content.decode()) |
| 415 | + assert "grant_type=authorization_code" in content |
| 416 | + assert "code=test-auth-code" in content |
| 417 | + assert "client_assertion=" not in content |
| 418 | + # Should have client_secret in body for client_secret_post |
| 419 | + assert "client_secret=test-secret" in content |
| 420 | + |
| 421 | + @pytest.mark.anyio |
| 422 | + async def test_perform_authorization_falls_back_to_parent( |
| 423 | + self, rfc7523_oauth_provider: RFC7523OAuthClientProvider |
| 424 | + ): |
| 425 | + """Test that _perform_authorization falls back to parent when not client_credentials or jwt-bearer.""" |
| 426 | + # Set up required context with authorization_code grant (not client_credentials or jwt-bearer) |
| 427 | + rfc7523_oauth_provider.context.client_info = OAuthClientInformationFull( |
| 428 | + client_id="test-client", |
| 429 | + client_secret="test-secret", |
| 430 | + grant_types=["authorization_code", "refresh_token"], |
| 431 | + token_endpoint_auth_method="client_secret_post", |
| 432 | + redirect_uris=[AnyUrl("http://localhost:3030/callback")], |
| 433 | + ) |
| 434 | + rfc7523_oauth_provider.context.oauth_metadata = OAuthMetadata( |
| 435 | + issuer=AnyHttpUrl("https://auth.example.com"), |
| 436 | + authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"), |
| 437 | + token_endpoint=AnyHttpUrl("https://auth.example.com/token"), |
| 438 | + ) |
| 439 | + rfc7523_oauth_provider.context.client_metadata = rfc7523_oauth_provider.context.client_info |
| 440 | + rfc7523_oauth_provider.context.protocol_version = "2025-06-18" |
| 441 | + |
| 442 | + # Mock the parent class's _perform_authorization since it would try to do real OAuth |
| 443 | + from unittest.mock import AsyncMock, patch |
| 444 | + |
| 445 | + mock_request = AsyncMock() |
| 446 | + with patch.object( |
| 447 | + rfc7523_oauth_provider.__class__.__bases__[0], |
| 448 | + "_perform_authorization", |
| 449 | + new=AsyncMock(return_value=mock_request), |
| 450 | + ) as mock_parent: |
| 451 | + result = await rfc7523_oauth_provider._perform_authorization() |
| 452 | + mock_parent.assert_called_once() |
| 453 | + assert result == mock_request |
0 commit comments