Skip to content

Commit e3d542a

Browse files
seanzhougooglecopybara-github
authored andcommitted
feat: Support authentication for MCP tool listing
Currently only tool calling supports MCP auth. This refactors the auth logic into a auth_utils file and uses it for tool listing as well. Fixes #2168. Co-authored-by: Xiang (Sean) Zhou <seanzhougoogle@google.com> PiperOrigin-RevId: 859201722
1 parent d62f9c8 commit e3d542a

File tree

6 files changed

+386
-490
lines changed

6 files changed

+386
-490
lines changed

src/google/adk/tools/mcp_tool/mcp_auth_utils.py

Lines changed: 0 additions & 110 deletions
This file was deleted.

src/google/adk/tools/mcp_tool/mcp_tool.py

Lines changed: 87 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from __future__ import annotations
1616

17+
import base64
1718
import inspect
1819
import logging
1920
from typing import Any
@@ -23,6 +24,7 @@
2324
from typing import Union
2425
import warnings
2526

27+
from fastapi.openapi.models import APIKeyIn
2628
from google.genai.types import FunctionDeclaration
2729
from mcp.types import Tool as McpBaseTool
2830
from typing_extensions import override
@@ -37,7 +39,6 @@
3739
from ..base_authenticated_tool import BaseAuthenticatedTool
3840
# import
3941
from ..tool_context import ToolContext
40-
from .mcp_auth_utils import get_mcp_auth_headers
4142
from .mcp_session_manager import MCPSessionManager
4243
from .mcp_session_manager import retry_on_errors
4344

@@ -194,12 +195,7 @@ async def _run_async_impl(
194195
Any: The response from the tool.
195196
"""
196197
# Extract headers from credential for session pooling
197-
auth_scheme = (
198-
self._auth_config.auth_scheme
199-
if hasattr(self, "_auth_config") and self._auth_config
200-
else None
201-
)
202-
auth_headers = get_mcp_auth_headers(auth_scheme, credential)
198+
auth_headers = await self._get_headers(tool_context, credential)
203199
dynamic_headers = None
204200
if self._header_provider:
205201
dynamic_headers = self._header_provider(
@@ -221,6 +217,90 @@ async def _run_async_impl(
221217
response = await session.call_tool(self._mcp_tool.name, arguments=args)
222218
return response.model_dump(exclude_none=True, mode="json")
223219

220+
async def _get_headers(
221+
self, tool_context: ToolContext, credential: AuthCredential
222+
) -> Optional[dict[str, str]]:
223+
"""Extracts authentication headers from credentials.
224+
225+
Args:
226+
tool_context: The tool context of the current invocation.
227+
credential: The authentication credential to process.
228+
229+
Returns:
230+
Dictionary of headers to add to the request, or None if no auth.
231+
232+
Raises:
233+
ValueError: If API key authentication is configured for non-header location.
234+
"""
235+
headers: Optional[dict[str, str]] = None
236+
if credential:
237+
if credential.oauth2:
238+
headers = {"Authorization": f"Bearer {credential.oauth2.access_token}"}
239+
elif credential.http:
240+
# Handle HTTP authentication schemes
241+
if (
242+
credential.http.scheme.lower() == "bearer"
243+
and credential.http.credentials.token
244+
):
245+
headers = {
246+
"Authorization": f"Bearer {credential.http.credentials.token}"
247+
}
248+
elif credential.http.scheme.lower() == "basic":
249+
# Handle basic auth
250+
if (
251+
credential.http.credentials.username
252+
and credential.http.credentials.password
253+
):
254+
255+
credentials = f"{credential.http.credentials.username}:{credential.http.credentials.password}"
256+
encoded_credentials = base64.b64encode(
257+
credentials.encode()
258+
).decode()
259+
headers = {"Authorization": f"Basic {encoded_credentials}"}
260+
elif credential.http.credentials.token:
261+
# Handle other HTTP schemes with token
262+
headers = {
263+
"Authorization": (
264+
f"{credential.http.scheme} {credential.http.credentials.token}"
265+
)
266+
}
267+
elif credential.api_key:
268+
if (
269+
not self._credentials_manager
270+
or not self._credentials_manager._auth_config
271+
):
272+
error_msg = (
273+
"Cannot find corresponding auth scheme for API key credential"
274+
f" {credential}"
275+
)
276+
logger.error(error_msg)
277+
raise ValueError(error_msg)
278+
elif (
279+
self._credentials_manager._auth_config.auth_scheme.in_
280+
!= APIKeyIn.header
281+
):
282+
error_msg = (
283+
"McpTool only supports header-based API key authentication."
284+
" Configured location:"
285+
f" {self._credentials_manager._auth_config.auth_scheme.in_}"
286+
)
287+
logger.error(error_msg)
288+
raise ValueError(error_msg)
289+
else:
290+
headers = {
291+
self._credentials_manager._auth_config.auth_scheme.name: (
292+
credential.api_key
293+
)
294+
}
295+
elif credential.service_account:
296+
# Service accounts should be exchanged for access tokens before reaching this point
297+
logger.warning(
298+
"Service account credentials should be exchanged before MCP"
299+
" session creation"
300+
)
301+
302+
return headers
303+
224304

225305
class MCPTool(McpTool):
226306
"""Deprecated name, use `McpTool` instead."""

src/google/adk/tools/mcp_tool/mcp_toolset.py

Lines changed: 3 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,11 @@
3333
from ...agents.readonly_context import ReadonlyContext
3434
from ...auth.auth_credential import AuthCredential
3535
from ...auth.auth_schemes import AuthScheme
36-
from ...auth.auth_tool import AuthConfig
37-
from ...auth.credential_manager import CredentialManager
3836
from ..base_tool import BaseTool
3937
from ..base_toolset import BaseToolset
4038
from ..base_toolset import ToolPredicate
4139
from ..tool_configs import BaseToolConfig
4240
from ..tool_configs import ToolArgsConfig
43-
from .mcp_auth_utils import get_mcp_auth_headers
4441
from .mcp_session_manager import MCPSessionManager
4542
from .mcp_session_manager import retry_on_errors
4643
from .mcp_session_manager import SseConnectionParams
@@ -157,50 +154,13 @@ async def get_tools(
157154
Returns:
158155
List[BaseTool]: A list of tools available under the specified context.
159156
"""
160-
provided_headers = (
157+
headers = (
161158
self._header_provider(readonly_context)
162159
if self._header_provider and readonly_context
163-
else {}
160+
else None
164161
)
165-
166-
auth_headers = {}
167-
if self._auth_scheme:
168-
try:
169-
# Instantiate CredentialsManager to resolve credentials
170-
auth_config = AuthConfig(
171-
auth_scheme=self._auth_scheme,
172-
raw_auth_credential=self._auth_credential,
173-
)
174-
credentials_manager = CredentialManager(auth_config)
175-
176-
# Resolve the credential
177-
resolved_credential = await credentials_manager.get_auth_credential(
178-
readonly_context
179-
)
180-
181-
if resolved_credential:
182-
auth_headers = get_mcp_auth_headers(
183-
self._auth_scheme, resolved_credential
184-
)
185-
else:
186-
logger.warning(
187-
"Failed to resolve credential for tool listing, proceeding"
188-
" without auth headers."
189-
)
190-
except Exception as e:
191-
logger.warning(
192-
"Error generating auth headers for tool listing: %s, proceeding"
193-
" without auth headers.",
194-
e,
195-
exc_info=True,
196-
)
197-
198-
merged_headers = {**(provided_headers or {}), **(auth_headers or {})}
199-
200162
# Get session from session manager
201-
session = await self._mcp_session_manager.create_session(
202-
headers=merged_headers
203-
)
163+
session = await self._mcp_session_manager.create_session(headers=headers)
204164

205165
# Fetch available tools from the MCP server
206166
timeout_in_seconds = (

0 commit comments

Comments
 (0)