Skip to content

Commit c8c6cd7

Browse files
google-genai-botcopybara-github
authored andcommitted
feat: Introduce ExtendedOAuth2 scheme that auto-populates auth/token URLs
Use auto-discovered auth_endpoint and token_endpoint in CredentialManager. PiperOrigin-RevId: 811183929
1 parent f159bd9 commit c8c6cd7

File tree

3 files changed

+199
-1
lines changed

3 files changed

+199
-1
lines changed

src/google/adk/auth/auth_schemes.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,22 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from __future__ import annotations
16+
1517
from enum import Enum
1618
from typing import List
1719
from typing import Optional
1820
from typing import Union
1921

22+
from fastapi.openapi.models import OAuth2
2023
from fastapi.openapi.models import OAuthFlows
2124
from fastapi.openapi.models import SecurityBase
2225
from fastapi.openapi.models import SecurityScheme
2326
from fastapi.openapi.models import SecuritySchemeType
2427
from pydantic import Field
2528

29+
from ..utils.feature_decorator import experimental
30+
2631

2732
class OpenIdConnectWithConfig(SecurityBase):
2833
type_: SecuritySchemeType = Field(
@@ -65,3 +70,10 @@ def from_flow(flow: OAuthFlows) -> "OAuthGrantType":
6570

6671
# AuthSchemeType re-exports SecuritySchemeType from OpenAPI 3.0.
6772
AuthSchemeType = SecuritySchemeType
73+
74+
75+
@experimental
76+
class ExtendedOAuth2(OAuth2):
77+
"""OAuth2 scheme that incorporates auto-discovery for endpoints."""
78+
79+
issuer_url: Optional[str] = None # Used for endpoint-discovery

src/google/adk/auth/credential_manager.py

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,26 @@
1414

1515
from __future__ import annotations
1616

17+
import logging
1718
from typing import Optional
1819

20+
from fastapi.openapi.models import OAuth2
21+
1922
from ..agents.callback_context import CallbackContext
2023
from ..utils.feature_decorator import experimental
2124
from .auth_credential import AuthCredential
2225
from .auth_credential import AuthCredentialTypes
2326
from .auth_schemes import AuthSchemeType
27+
from .auth_schemes import ExtendedOAuth2
2428
from .auth_tool import AuthConfig
2529
from .exchanger.base_credential_exchanger import BaseCredentialExchanger
2630
from .exchanger.credential_exchanger_registry import CredentialExchangerRegistry
31+
from .oauth2_discovery import OAuth2DiscoveryManager
2732
from .refresher.base_credential_refresher import BaseCredentialRefresher
2833
from .refresher.credential_refresher_registry import CredentialRefresherRegistry
2934

35+
logger = logging.getLogger("google_adk." + __name__)
36+
3037

3138
@experimental
3239
class CredentialManager:
@@ -74,6 +81,7 @@ def __init__(
7481
self._auth_config = auth_config
7582
self._exchanger_registry = CredentialExchangerRegistry()
7683
self._refresher_registry = CredentialRefresherRegistry()
84+
self._discovery_manager = OAuth2DiscoveryManager()
7785

7886
# Register default exchangers and refreshers
7987
# TODO: support service account credential exchanger
@@ -247,7 +255,14 @@ async def _validate_credential(self) -> None:
247255
"auth_config.raw_credential.oauth2 required for credential type "
248256
f"{raw_credential.auth_type}"
249257
)
250-
# Additional validation can be added here
258+
259+
if self._missing_oauth_info() and not await self._populate_auth_scheme():
260+
raise ValueError(
261+
"OAuth scheme info is missing, and auto-discovery has failed to fill"
262+
" them in."
263+
)
264+
265+
# Additional validation can be added here
251266

252267
async def _save_credential(
253268
self, callback_context: CallbackContext, credential: AuthCredential
@@ -259,3 +274,57 @@ async def _save_credential(
259274
credential_service = callback_context._invocation_context.credential_service
260275
if credential_service:
261276
await callback_context.save_credential(self._auth_config)
277+
278+
async def _populate_auth_scheme(self) -> bool:
279+
"""Auto-discover server metadata and populate missing auth scheme info.
280+
281+
Returns:
282+
True if auto-discovery was successful, False otherwise.
283+
"""
284+
auth_scheme = self._auth_config.auth_scheme
285+
if (
286+
not isinstance(auth_scheme, ExtendedOAuth2)
287+
or not auth_scheme.issuer_url
288+
):
289+
logger.warning("No issuer_url was provided for auto-discovery.")
290+
return False
291+
292+
metadata = await self._discovery_manager.discover_auth_server_metadata(
293+
auth_scheme.issuer_url
294+
)
295+
if not metadata:
296+
logger.warning("Auto-discovery has failed to populate OAuth scheme info.")
297+
return False
298+
299+
flows = auth_scheme.flows
300+
301+
if flows.implicit and not flows.implicit.authorizationUrl:
302+
flows.implicit.authorizationUrl = metadata.authorization_endpoint
303+
if flows.password and not flows.password.tokenUrl:
304+
flows.password.tokenUrl = metadata.token_endpoint
305+
if flows.clientCredentials and not flows.clientCredentials.tokenUrl:
306+
flows.clientCredentials.tokenUrl = metadata.token_endpoint
307+
if flows.authorizationCode and not flows.authorizationCode.authorizationUrl:
308+
flows.authorizationCode.authorizationUrl = metadata.authorization_endpoint
309+
if flows.authorizationCode and not flows.authorizationCode.tokenUrl:
310+
flows.authorizationCode.tokenUrl = metadata.token_endpoint
311+
return True
312+
313+
def _missing_oauth_info(self) -> bool:
314+
"""Checks if we are missing auth/token URLs needed for OAuth."""
315+
auth_scheme = self._auth_config.auth_scheme
316+
if isinstance(auth_scheme, OAuth2):
317+
flows = auth_scheme.flows
318+
return (
319+
flows.implicit
320+
and not flows.implicit.authorizationUrl
321+
or flows.password
322+
and not flows.password.tokenUrl
323+
or flows.clientCredentials
324+
and not flows.clientCredentials.tokenUrl
325+
or flows.authorizationCode
326+
and not flows.authorizationCode.authorizationUrl
327+
or flows.authorizationCode
328+
and not flows.authorizationCode.tokenUrl
329+
)
330+
return False

tests/unittests/auth/test_credential_manager.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,21 @@
1616
from unittest.mock import Mock
1717
from unittest.mock import patch
1818

19+
from fastapi.openapi.models import OAuth2
20+
from fastapi.openapi.models import OAuthFlowAuthorizationCode
21+
from fastapi.openapi.models import OAuthFlowImplicit
22+
from fastapi.openapi.models import OAuthFlows
1923
from google.adk.auth.auth_credential import AuthCredential
2024
from google.adk.auth.auth_credential import AuthCredentialTypes
2125
from google.adk.auth.auth_credential import OAuth2Auth
2226
from google.adk.auth.auth_credential import ServiceAccount
2327
from google.adk.auth.auth_credential import ServiceAccountCredential
2428
from google.adk.auth.auth_schemes import AuthScheme
2529
from google.adk.auth.auth_schemes import AuthSchemeType
30+
from google.adk.auth.auth_schemes import ExtendedOAuth2
2631
from google.adk.auth.auth_tool import AuthConfig
2732
from google.adk.auth.credential_manager import CredentialManager
33+
from google.adk.auth.oauth2_discovery import AuthorizationServerMetadata
2834
import pytest
2935

3036

@@ -390,6 +396,28 @@ async def test_validate_credential_oauth2_missing_oauth2_field(self):
390396
with pytest.raises(ValueError, match="oauth2 required for credential type"):
391397
await manager._validate_credential()
392398

399+
@pytest.mark.asyncio
400+
async def test_validate_credential_oauth2_missing_scheme_info(
401+
self, extended_oauth2_scheme
402+
):
403+
"""Test _validate_credential with OAuth2 missing scheme info."""
404+
mock_raw_credential = Mock(spec=AuthCredential)
405+
mock_raw_credential.auth_type = AuthCredentialTypes.OAUTH2
406+
mock_raw_credential.oauth2 = Mock(spec=OAuth2Auth)
407+
408+
auth_config = Mock(spec=AuthConfig)
409+
auth_config.raw_auth_credential = mock_raw_credential
410+
auth_config.auth_scheme = extended_oauth2_scheme
411+
412+
manager = CredentialManager(auth_config)
413+
414+
with patch.object(
415+
manager,
416+
"_populate_auth_scheme",
417+
return_value=False,
418+
) and pytest.raises(ValueError, match="OAuth scheme info is missing"):
419+
await manager._validate_credential()
420+
393421
@pytest.mark.asyncio
394422
async def test_exchange_credentials_service_account(self):
395423
"""Test _exchange_credential with service account credential."""
@@ -445,6 +473,95 @@ async def test_exchange_credential_no_exchanger(self):
445473
assert result == mock_credential
446474
assert was_exchanged is False
447475

476+
@pytest.fixture
477+
def auth_server_metadata(self):
478+
"""Create AuthorizationServerMetadata object."""
479+
return AuthorizationServerMetadata(
480+
issuer="https://auth.example.com",
481+
authorization_endpoint="https://auth.example.com/authorize",
482+
token_endpoint="https://auth.example.com/token",
483+
scopes_supported=["read", "write"],
484+
)
485+
486+
@pytest.fixture
487+
def extended_oauth2_scheme(self):
488+
"""Create ExtendedOAuth2 object with empty endpoints."""
489+
return ExtendedOAuth2(
490+
issuer_url="https://auth.example.com",
491+
flows=OAuthFlows(
492+
authorizationCode=OAuthFlowAuthorizationCode(
493+
authorizationUrl="",
494+
tokenUrl="",
495+
)
496+
),
497+
)
498+
499+
@pytest.fixture
500+
def implicit_oauth2_scheme(self):
501+
"""Create OAuth2 object with implicit flow."""
502+
return OAuth2(
503+
flows=OAuthFlows(
504+
implicit=OAuthFlowImplicit(
505+
authorizationUrl="https://auth.example.com/authorize"
506+
)
507+
)
508+
)
509+
510+
@pytest.mark.asyncio
511+
async def test_populate_auth_scheme_success(
512+
self, auth_server_metadata, extended_oauth2_scheme
513+
):
514+
"""Test _populate_auth_scheme successfully populates missing info."""
515+
auth_config = Mock(spec=AuthConfig)
516+
auth_config.auth_scheme = extended_oauth2_scheme
517+
518+
manager = CredentialManager(auth_config)
519+
with patch.object(
520+
manager._discovery_manager,
521+
"discover_auth_server_metadata",
522+
return_value=auth_server_metadata,
523+
):
524+
assert await manager._populate_auth_scheme()
525+
526+
assert (
527+
manager._auth_config.auth_scheme.flows.authorizationCode.authorizationUrl
528+
== "https://auth.example.com/authorize"
529+
)
530+
assert (
531+
manager._auth_config.auth_scheme.flows.authorizationCode.tokenUrl
532+
== "https://auth.example.com/token"
533+
)
534+
535+
@pytest.mark.asyncio
536+
async def test_populate_auth_scheme_fail(self, extended_oauth2_scheme):
537+
"""Test _populate_auth_scheme when auto-discovery fails."""
538+
auth_config = Mock(spec=AuthConfig)
539+
auth_config.auth_scheme = extended_oauth2_scheme
540+
541+
manager = CredentialManager(auth_config)
542+
with patch.object(
543+
manager._discovery_manager,
544+
"discover_auth_server_metadata",
545+
return_value=None,
546+
):
547+
assert not await manager._populate_auth_scheme()
548+
549+
assert (
550+
not manager._auth_config.auth_scheme.flows.authorizationCode.authorizationUrl
551+
)
552+
assert not manager._auth_config.auth_scheme.flows.authorizationCode.tokenUrl
553+
554+
@pytest.mark.asyncio
555+
async def test_populate_auth_scheme_noop(self, implicit_oauth2_scheme):
556+
"""Test _populate_auth_scheme when auth scheme info not missing."""
557+
auth_config = Mock(spec=AuthConfig)
558+
auth_config.auth_scheme = implicit_oauth2_scheme
559+
560+
manager = CredentialManager(auth_config)
561+
assert not await manager._populate_auth_scheme() # no-op
562+
563+
assert manager._auth_config.auth_scheme == implicit_oauth2_scheme
564+
448565

449566
@pytest.fixture
450567
def oauth2_auth_scheme():

0 commit comments

Comments
 (0)