Skip to content

Commit 2a2da0f

Browse files
google-genai-botcopybara-github
authored andcommitted
feat: Introduce OAuth2DiscoveryManager to fetch metadata needed for OAuth
This is the first step to bring ADK to compliance with MCP Authorization Spec. PiperOrigin-RevId: 811177152
1 parent 5a485b0 commit 2a2da0f

File tree

2 files changed

+433
-0
lines changed

2 files changed

+433
-0
lines changed
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
import json
18+
import logging
19+
from typing import List
20+
from typing import Optional
21+
from urllib.parse import urlparse
22+
23+
import httpx
24+
from pydantic import BaseModel
25+
from pydantic import ValidationError
26+
27+
from ..utils.feature_decorator import experimental
28+
29+
logger = logging.getLogger("google_adk." + __name__)
30+
31+
32+
@experimental
33+
class AuthorizationServerMetadata(BaseModel):
34+
"""Represents the OAuth2 authorization server metadata per RFC8414."""
35+
36+
issuer: str
37+
authorization_endpoint: str
38+
token_endpoint: str
39+
scopes_supported: Optional[List[str]] = None
40+
registration_endpoint: Optional[str] = None
41+
42+
43+
@experimental
44+
class ProtectedResourceMetadata(BaseModel):
45+
"""Represents the OAuth2 protected resource metadata per RFC9728."""
46+
47+
resource: str
48+
authorization_servers: List[str] = []
49+
50+
51+
@experimental
52+
class OAuth2DiscoveryManager:
53+
"""Implements Metadata discovery for OAuth2 following RFC8414 and RFC9728."""
54+
55+
async def discover_auth_server_metadata(
56+
self, issuer_url: str
57+
) -> Optional[AuthorizationServerMetadata]:
58+
"""Discovers the OAuth2 authorization server metadata."""
59+
try:
60+
parsed_url = urlparse(issuer_url)
61+
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
62+
path = parsed_url.path
63+
except ValueError as e:
64+
logger.warning("Failed to parse issuer_url %s: %s", issuer_url, e)
65+
return None
66+
67+
# Try the standard well-known endpoints in order.
68+
if path and path != "/":
69+
endpoints_to_try = [
70+
# 1. OAuth 2.0 Authorization Server Metadata with path insertion
71+
f"{base_url}/.well-known/oauth-authorization-server{path}",
72+
# 2. OpenID Connect Discovery 1.0 with path insertion
73+
f"{base_url}/.well-known/openid-configuration{path}",
74+
# 3. OpenID Connect Discovery 1.0 with path appending
75+
f"{base_url}{path}/.well-known/openid-configuration",
76+
]
77+
else:
78+
endpoints_to_try = [
79+
# 1. OAuth 2.0 Authorization Server Metadata
80+
f"{base_url}/.well-known/oauth-authorization-server",
81+
# 2. OpenID Connect Discovery 1.0
82+
f"{base_url}/.well-known/openid-configuration",
83+
]
84+
85+
async with httpx.AsyncClient() as client:
86+
for endpoint in endpoints_to_try:
87+
try:
88+
response = await client.get(endpoint, timeout=5)
89+
response.raise_for_status()
90+
metadata = AuthorizationServerMetadata.model_validate(response.json())
91+
# Validate issuer to defend against MIX-UP attacks
92+
if metadata.issuer == issuer_url.rstrip("/"):
93+
return metadata
94+
else:
95+
logger.warning(
96+
"Issuer in metadata %s does not match issuer_url %s",
97+
metadata.issuer,
98+
issuer_url,
99+
)
100+
except httpx.HTTPError as e:
101+
logger.debug("Failed to fetch metadata from %s: %s", endpoint, e)
102+
except (json.decoder.JSONDecodeError, ValidationError) as e:
103+
logger.debug("Failed to parse metadata from %s: %s", endpoint, e)
104+
return None
105+
106+
async def discover_resource_metadata(
107+
self, resource_url: str
108+
) -> Optional[ProtectedResourceMetadata]:
109+
"""Discovers the OAuth2 protected resource metadata."""
110+
try:
111+
parsed_url = urlparse(resource_url)
112+
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
113+
path = parsed_url.path
114+
except ValueError as e:
115+
logger.warning("Failed to parse resource_url %s: %s", resource_url, e)
116+
return None
117+
118+
if path and path != "/":
119+
well_known_endpoint = (
120+
f"{base_url}/.well-known/oauth-protected-resource{path}"
121+
)
122+
else:
123+
well_known_endpoint = f"{base_url}/.well-known/oauth-protected-resource"
124+
125+
async with httpx.AsyncClient() as client:
126+
try:
127+
response = await client.get(well_known_endpoint, timeout=5)
128+
response.raise_for_status()
129+
metadata = ProtectedResourceMetadata.model_validate(response.json())
130+
# Validate resource to defend against MIX-UP attacks
131+
if metadata.resource == resource_url.rstrip("/"):
132+
return metadata
133+
else:
134+
logger.warning(
135+
"Resource in metadata %s does not match resource_url %s",
136+
metadata.resource,
137+
resource_url,
138+
)
139+
except httpx.HTTPError as e:
140+
logger.debug(
141+
"Failed to fetch metadata from %s: %s", well_known_endpoint, e
142+
)
143+
except (json.decoder.JSONDecodeError, ValidationError) as e:
144+
logger.debug(
145+
"Failed to parse metadata from %s: %s", well_known_endpoint, e
146+
)
147+
148+
return None

0 commit comments

Comments
 (0)