|
16 | 16 | from unittest.mock import Mock |
17 | 17 | from unittest.mock import patch |
18 | 18 |
|
| 19 | +from fastapi.openapi.models import OAuth2 |
| 20 | +from fastapi.openapi.models import OAuthFlowAuthorizationCode |
| 21 | +from fastapi.openapi.models import OAuthFlowImplicit |
| 22 | +from fastapi.openapi.models import OAuthFlows |
19 | 23 | from google.adk.auth.auth_credential import AuthCredential |
20 | 24 | from google.adk.auth.auth_credential import AuthCredentialTypes |
21 | 25 | from google.adk.auth.auth_credential import OAuth2Auth |
22 | 26 | from google.adk.auth.auth_credential import ServiceAccount |
23 | 27 | from google.adk.auth.auth_credential import ServiceAccountCredential |
24 | 28 | from google.adk.auth.auth_schemes import AuthScheme |
25 | 29 | from google.adk.auth.auth_schemes import AuthSchemeType |
| 30 | +from google.adk.auth.auth_schemes import ExtendedOAuth2 |
26 | 31 | from google.adk.auth.auth_tool import AuthConfig |
27 | 32 | from google.adk.auth.credential_manager import CredentialManager |
| 33 | +from google.adk.auth.oauth2_discovery import AuthorizationServerMetadata |
28 | 34 | import pytest |
29 | 35 |
|
30 | 36 |
|
@@ -390,6 +396,28 @@ async def test_validate_credential_oauth2_missing_oauth2_field(self): |
390 | 396 | with pytest.raises(ValueError, match="oauth2 required for credential type"): |
391 | 397 | await manager._validate_credential() |
392 | 398 |
|
| 399 | + @pytest.mark.asyncio |
| 400 | + async def test_validate_credential_oauth2_missing_scheme_info( |
| 401 | + self, extended_oauth2_scheme |
| 402 | + ): |
| 403 | + """Test _validate_credential with OAuth2 missing scheme info.""" |
| 404 | + mock_raw_credential = Mock(spec=AuthCredential) |
| 405 | + mock_raw_credential.auth_type = AuthCredentialTypes.OAUTH2 |
| 406 | + mock_raw_credential.oauth2 = Mock(spec=OAuth2Auth) |
| 407 | + |
| 408 | + auth_config = Mock(spec=AuthConfig) |
| 409 | + auth_config.raw_auth_credential = mock_raw_credential |
| 410 | + auth_config.auth_scheme = extended_oauth2_scheme |
| 411 | + |
| 412 | + manager = CredentialManager(auth_config) |
| 413 | + |
| 414 | + with patch.object( |
| 415 | + manager, |
| 416 | + "_populate_auth_scheme", |
| 417 | + return_value=False, |
| 418 | + ) and pytest.raises(ValueError, match="OAuth scheme info is missing"): |
| 419 | + await manager._validate_credential() |
| 420 | + |
393 | 421 | @pytest.mark.asyncio |
394 | 422 | async def test_exchange_credentials_service_account(self): |
395 | 423 | """Test _exchange_credential with service account credential.""" |
@@ -445,6 +473,95 @@ async def test_exchange_credential_no_exchanger(self): |
445 | 473 | assert result == mock_credential |
446 | 474 | assert was_exchanged is False |
447 | 475 |
|
| 476 | + @pytest.fixture |
| 477 | + def auth_server_metadata(self): |
| 478 | + """Create AuthorizationServerMetadata object.""" |
| 479 | + return AuthorizationServerMetadata( |
| 480 | + issuer="https://auth.example.com", |
| 481 | + authorization_endpoint="https://auth.example.com/authorize", |
| 482 | + token_endpoint="https://auth.example.com/token", |
| 483 | + scopes_supported=["read", "write"], |
| 484 | + ) |
| 485 | + |
| 486 | + @pytest.fixture |
| 487 | + def extended_oauth2_scheme(self): |
| 488 | + """Create ExtendedOAuth2 object with empty endpoints.""" |
| 489 | + return ExtendedOAuth2( |
| 490 | + issuer_url="https://auth.example.com", |
| 491 | + flows=OAuthFlows( |
| 492 | + authorizationCode=OAuthFlowAuthorizationCode( |
| 493 | + authorizationUrl="", |
| 494 | + tokenUrl="", |
| 495 | + ) |
| 496 | + ), |
| 497 | + ) |
| 498 | + |
| 499 | + @pytest.fixture |
| 500 | + def implicit_oauth2_scheme(self): |
| 501 | + """Create OAuth2 object with implicit flow.""" |
| 502 | + return OAuth2( |
| 503 | + flows=OAuthFlows( |
| 504 | + implicit=OAuthFlowImplicit( |
| 505 | + authorizationUrl="https://auth.example.com/authorize" |
| 506 | + ) |
| 507 | + ) |
| 508 | + ) |
| 509 | + |
| 510 | + @pytest.mark.asyncio |
| 511 | + async def test_populate_auth_scheme_success( |
| 512 | + self, auth_server_metadata, extended_oauth2_scheme |
| 513 | + ): |
| 514 | + """Test _populate_auth_scheme successfully populates missing info.""" |
| 515 | + auth_config = Mock(spec=AuthConfig) |
| 516 | + auth_config.auth_scheme = extended_oauth2_scheme |
| 517 | + |
| 518 | + manager = CredentialManager(auth_config) |
| 519 | + with patch.object( |
| 520 | + manager._discovery_manager, |
| 521 | + "discover_auth_server_metadata", |
| 522 | + return_value=auth_server_metadata, |
| 523 | + ): |
| 524 | + assert await manager._populate_auth_scheme() |
| 525 | + |
| 526 | + assert ( |
| 527 | + manager._auth_config.auth_scheme.flows.authorizationCode.authorizationUrl |
| 528 | + == "https://auth.example.com/authorize" |
| 529 | + ) |
| 530 | + assert ( |
| 531 | + manager._auth_config.auth_scheme.flows.authorizationCode.tokenUrl |
| 532 | + == "https://auth.example.com/token" |
| 533 | + ) |
| 534 | + |
| 535 | + @pytest.mark.asyncio |
| 536 | + async def test_populate_auth_scheme_fail(self, extended_oauth2_scheme): |
| 537 | + """Test _populate_auth_scheme when auto-discovery fails.""" |
| 538 | + auth_config = Mock(spec=AuthConfig) |
| 539 | + auth_config.auth_scheme = extended_oauth2_scheme |
| 540 | + |
| 541 | + manager = CredentialManager(auth_config) |
| 542 | + with patch.object( |
| 543 | + manager._discovery_manager, |
| 544 | + "discover_auth_server_metadata", |
| 545 | + return_value=None, |
| 546 | + ): |
| 547 | + assert not await manager._populate_auth_scheme() |
| 548 | + |
| 549 | + assert ( |
| 550 | + not manager._auth_config.auth_scheme.flows.authorizationCode.authorizationUrl |
| 551 | + ) |
| 552 | + assert not manager._auth_config.auth_scheme.flows.authorizationCode.tokenUrl |
| 553 | + |
| 554 | + @pytest.mark.asyncio |
| 555 | + async def test_populate_auth_scheme_noop(self, implicit_oauth2_scheme): |
| 556 | + """Test _populate_auth_scheme when auth scheme info not missing.""" |
| 557 | + auth_config = Mock(spec=AuthConfig) |
| 558 | + auth_config.auth_scheme = implicit_oauth2_scheme |
| 559 | + |
| 560 | + manager = CredentialManager(auth_config) |
| 561 | + assert not await manager._populate_auth_scheme() # no-op |
| 562 | + |
| 563 | + assert manager._auth_config.auth_scheme == implicit_oauth2_scheme |
| 564 | + |
448 | 565 |
|
449 | 566 | @pytest.fixture |
450 | 567 | def oauth2_auth_scheme(): |
|
0 commit comments