Skip to content

Commit d0ea308

Browse files
committed
feat(tools): support additional headers for google api toolset #non-breaking
1 parent 2424d6a commit d0ea308

File tree

6 files changed

+96
-4
lines changed

6 files changed

+96
-4
lines changed

src/google/adk/tools/google_api_tool/google_api_tool.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,16 +39,18 @@ def __init__(
3939
client_id: Optional[str] = None,
4040
client_secret: Optional[str] = None,
4141
service_account: Optional[ServiceAccount] = None,
42+
additional_headers: Optional[Dict[str, str]] = None,
4243
):
4344
super().__init__(
4445
name=rest_api_tool.name,
4546
description=rest_api_tool.description,
4647
is_long_running=rest_api_tool.is_long_running,
4748
)
4849
self._rest_api_tool = rest_api_tool
50+
self._rest_api_tool.set_default_headers(additional_headers or {})
4951
if service_account is not None:
5052
self.configure_sa_auth(service_account)
51-
else:
53+
elif client_id is not None and client_secret is not None:
5254
self.configure_auth(client_id, client_secret)
5355

5456
@override
@@ -57,7 +59,7 @@ def _get_declaration(self) -> FunctionDeclaration:
5759

5860
@override
5961
async def run_async(
60-
self, *, args: dict[str, Any], tool_context: Optional[ToolContext]
62+
self, *, args: Dict[str, Any], tool_context: Optional[ToolContext]
6163
) -> Dict[str, Any]:
6264
return await self._rest_api_tool.run_async(
6365
args=args, tool_context=tool_context

src/google/adk/tools/google_api_tool/google_api_toolset.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from __future__ import annotations
1616

17+
from typing import Dict
1718
from typing import List
1819
from typing import Optional
1920
from typing import Union
@@ -45,6 +46,8 @@ class GoogleApiToolset(BaseToolset):
4546
tool_filter: Optional filter to include only specific tools or use a predicate function.
4647
service_account: Optional service account for authentication.
4748
tool_name_prefix: Optional prefix to add to all tool names in this toolset.
49+
additional_headers: Optional dict of HTTP headers to inject into every request
50+
executed by this toolset.
4851
"""
4952

5053
def __init__(
@@ -56,13 +59,15 @@ def __init__(
5659
tool_filter: Optional[Union[ToolPredicate, List[str]]] = None,
5760
service_account: Optional[ServiceAccount] = None,
5861
tool_name_prefix: Optional[str] = None,
62+
additional_headers: Optional[Dict[str, str]] = None,
5963
):
6064
super().__init__(tool_filter=tool_filter, tool_name_prefix=tool_name_prefix)
6165
self.api_name = api_name
6266
self.api_version = api_version
6367
self._client_id = client_id
6468
self._client_secret = client_secret
6569
self._service_account = service_account
70+
self._additional_headers = dict(additional_headers or {})
6671
self._openapi_toolset = self._load_toolset_with_oidc_auth()
6772

6873
@override
@@ -72,7 +77,11 @@ async def get_tools(
7277
"""Get all tools in the toolset."""
7378
return [
7479
GoogleApiTool(
75-
tool, self._client_id, self._client_secret, self._service_account
80+
tool,
81+
self._client_id,
82+
self._client_secret,
83+
self._service_account,
84+
self._additional_headers,
7685
)
7786
for tool in await self._openapi_toolset.get_tools(readonly_context)
7887
if self._is_tool_selected(tool, readonly_context)

src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ def __init__(
128128
else operation
129129
)
130130
self.auth_credential, self.auth_scheme = None, None
131+
self._default_headers: Dict[str, str] = {}
131132

132133
self.configure_auth_credential(auth_credential)
133134
self.configure_auth_scheme(auth_scheme)
@@ -216,6 +217,10 @@ def configure_auth_credential(
216217
auth_credential = AuthCredential.model_validate_json(auth_credential)
217218
self.auth_credential = auth_credential
218219

220+
def set_default_headers(self, headers: Dict[str, str]):
221+
"""Sets default headers that are merged into every request."""
222+
self._default_headers = dict(headers)
223+
219224
def _prepare_auth_request_params(
220225
self,
221226
auth_scheme: AuthScheme,
@@ -335,6 +340,9 @@ def _prepare_request_params(
335340
k: v for k, v in query_params.items() if v is not None
336341
}
337342

343+
for key, value in self._default_headers.items():
344+
header_params.setdefault(key, value)
345+
338346
request_params: Dict[str, Any] = {
339347
"method": method,
340348
"url": url,

tests/unittests/tools/google_api_tool/test_google_api_tool.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,14 @@ def test_init(self, mock_rest_api_tool):
5656
assert tool.is_long_running is False
5757
assert tool._rest_api_tool == mock_rest_api_tool
5858

59+
def test_init_with_additional_headers(self, mock_rest_api_tool):
60+
"""Test GoogleApiTool initialization with additional headers."""
61+
headers = {"developer-token": "test-token"}
62+
63+
GoogleApiTool(mock_rest_api_tool, additional_headers=headers)
64+
65+
mock_rest_api_tool.set_default_headers.assert_called_once_with(headers)
66+
5967
def test_get_declaration(self, mock_rest_api_tool):
6068
"""Test _get_declaration method."""
6169
tool = GoogleApiTool(mock_rest_api_tool)

tests/unittests/tools/google_api_tool/test_google_api_toolset.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,12 +126,14 @@ def test_init(
126126

127127
client_id = "test_client_id"
128128
client_secret = "test_client_secret"
129+
additional_headers = {"developer-token": "abc123"}
129130

130131
tool_set = GoogleApiToolset(
131132
api_name=TEST_API_NAME,
132133
api_version=TEST_API_VERSION,
133134
client_id=client_id,
134135
client_secret=client_secret,
136+
additional_headers=additional_headers,
135137
)
136138

137139
assert tool_set.api_name == TEST_API_NAME
@@ -141,6 +143,7 @@ def test_init(
141143
assert tool_set._service_account is None
142144
assert tool_set.tool_filter is None
143145
assert tool_set._openapi_toolset == mock_openapi_toolset_instance
146+
assert tool_set._additional_headers == additional_headers
144147

145148
mock_converter_class.assert_called_once_with(
146149
TEST_API_NAME, TEST_API_VERSION
@@ -191,13 +194,15 @@ async def test_get_tools(
191194
client_id = "cid"
192195
client_secret = "csecret"
193196
sa_mock = mock.MagicMock(spec=ServiceAccount)
197+
additional_headers = {"developer-token": "token"}
194198

195199
tool_set = GoogleApiToolset(
196200
api_name=TEST_API_NAME,
197201
api_version=TEST_API_VERSION,
198202
client_id=client_id,
199203
client_secret=client_secret,
200204
service_account=sa_mock,
205+
additional_headers=additional_headers,
201206
)
202207

203208
tools = await tool_set.get_tools(mock_readonly_context)
@@ -209,7 +214,7 @@ async def test_get_tools(
209214

210215
for i, rest_tool in enumerate(mock_rest_api_tools):
211216
mock_google_api_tool_class.assert_any_call(
212-
rest_tool, client_id, client_secret, sa_mock
217+
rest_tool, client_id, client_secret, sa_mock, additional_headers
213218
)
214219
assert tools[i] is mock_google_api_tool_instances[i]
215220

tests/unittests/tools/openapi_tool/openapi_spec_parser/test_rest_api_tool.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -685,6 +685,66 @@ def test_prepare_request_params_unknown_parameter(
685685

686686
# Make sure unknown parameters are ignored and do not raise errors.
687687
assert "unknown_param" not in request_params["params"]
688+
def test_prepare_request_params_merges_default_headers(
689+
self,
690+
sample_endpoint,
691+
sample_auth_credential,
692+
sample_auth_scheme,
693+
sample_operation,
694+
):
695+
tool = RestApiTool(
696+
name="test_tool",
697+
description="Test Tool",
698+
endpoint=sample_endpoint,
699+
operation=sample_operation,
700+
auth_credential=sample_auth_credential,
701+
auth_scheme=sample_auth_scheme,
702+
)
703+
tool.set_default_headers({"developer-token": "token"})
704+
705+
request_params = tool._prepare_request_params([], {})
706+
707+
assert request_params["headers"]["developer-token"] == "token"
708+
709+
def test_prepare_request_params_preserves_existing_headers(
710+
self,
711+
sample_endpoint,
712+
sample_auth_credential,
713+
sample_auth_scheme,
714+
sample_operation,
715+
sample_api_parameters,
716+
):
717+
tool = RestApiTool(
718+
name="test_tool",
719+
description="Test Tool",
720+
endpoint=sample_endpoint,
721+
operation=sample_operation,
722+
auth_credential=sample_auth_credential,
723+
auth_scheme=sample_auth_scheme,
724+
)
725+
tool.set_default_headers(
726+
{
727+
"Content-Type": "text/plain",
728+
"developer-token": "token",
729+
"User-Agent": "custom-default",
730+
}
731+
)
732+
733+
header_param = ApiParameter(
734+
original_name="User-Agent",
735+
py_name="user_agent",
736+
param_location="header",
737+
param_schema=OpenAPISchema(type="string"),
738+
)
739+
740+
params = sample_api_parameters + [header_param]
741+
kwargs = {"test_body_param": "value", "user_agent": "api-client"}
742+
743+
request_params = tool._prepare_request_params(params, kwargs)
744+
745+
assert request_params["headers"]["Content-Type"] == "application/json"
746+
assert request_params["headers"]["developer-token"] == "token"
747+
assert request_params["headers"]["User-Agent"] == "api-client"
688748

689749
def test_prepare_request_params_base_url_handling(
690750
self, sample_auth_credential, sample_auth_scheme, sample_operation

0 commit comments

Comments
 (0)