Skip to content

Commit 8387ff5

Browse files
committed
feat: refactor search agent tools to inherit from a common base class
1 parent 335edb5 commit 8387ff5

File tree

6 files changed

+126
-202
lines changed

6 files changed

+126
-202
lines changed

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -847,8 +847,7 @@ async def _maybe_add_grounding_metadata(
847847
invocation_context.canonical_tools_cache = tools
848848

849849
if not any(
850-
tool.name == 'google_search_agent'
851-
or tool.name == 'enterprise_search_agent'
850+
tool.name in {'google_search_agent', 'enterprise_search_agent'}
852851
for tool in tools
853852
):
854853
return response
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
from typing import Any
18+
19+
from google.genai import types
20+
from typing_extensions import override
21+
22+
from ..agents.llm_agent import LlmAgent
23+
from ..memory.in_memory_memory_service import InMemoryMemoryService
24+
from ..runners import Runner
25+
from ..sessions.in_memory_session_service import InMemorySessionService
26+
from ..utils.context_utils import Aclosing
27+
from ._forwarding_artifact_service import ForwardingArtifactService
28+
from .agent_tool import AgentTool
29+
from .tool_context import ToolContext
30+
31+
32+
class _SearchAgentTool(AgentTool):
33+
"""A base class for search agent tools."""
34+
35+
@override
36+
async def run_async(
37+
self,
38+
*,
39+
args: dict[str, Any],
40+
tool_context: ToolContext,
41+
) -> Any:
42+
from ..agents.llm_agent import LlmAgent
43+
44+
if isinstance(self.agent, LlmAgent) and self.agent.input_schema:
45+
input_value = self.agent.input_schema.model_validate(args)
46+
content = types.Content(
47+
role='user',
48+
parts=[
49+
types.Part.from_text(
50+
text=input_value.model_dump_json(exclude_none=True)
51+
)
52+
],
53+
)
54+
else:
55+
content = types.Content(
56+
role='user',
57+
parts=[types.Part.from_text(text=args['request'])],
58+
)
59+
runner = Runner(
60+
app_name=self.agent.name,
61+
agent=self.agent,
62+
artifact_service=ForwardingArtifactService(tool_context),
63+
session_service=InMemorySessionService(),
64+
memory_service=InMemoryMemoryService(),
65+
credential_service=tool_context._invocation_context.credential_service,
66+
plugins=list(tool_context._invocation_context.plugin_manager.plugins),
67+
)
68+
try:
69+
state_dict = {
70+
k: v
71+
for k, v in tool_context.state.to_dict().items()
72+
if not k.startswith('_adk') # Filter out adk internal states
73+
}
74+
session = await runner.session_service.create_session(
75+
app_name=self.agent.name,
76+
user_id=tool_context._invocation_context.user_id,
77+
state=state_dict,
78+
)
79+
80+
last_content = None
81+
last_grounding_metadata = None
82+
async with Aclosing(
83+
runner.run_async(
84+
user_id=session.user_id,
85+
session_id=session.id,
86+
new_message=content,
87+
)
88+
) as agen:
89+
async for event in agen:
90+
# Forward state delta to parent session.
91+
if event.actions.state_delta:
92+
tool_context.state.update(event.actions.state_delta)
93+
if event.content:
94+
last_content = event.content
95+
last_grounding_metadata = event.grounding_metadata
96+
97+
if not last_content:
98+
return ''
99+
merged_text = '\n'.join(p.text for p in last_content.parts if p.text)
100+
if isinstance(self.agent, LlmAgent) and self.agent.output_schema:
101+
tool_result = self.agent.output_schema.model_validate_json(
102+
merged_text
103+
).model_dump(exclude_none=True)
104+
else:
105+
tool_result = merged_text
106+
107+
if last_grounding_metadata:
108+
tool_context.state['temp:_adk_grounding_metadata'] = (
109+
last_grounding_metadata
110+
)
111+
return tool_result
112+
finally:
113+
await runner.close()

src/google/adk/tools/enterprise_search_agent_tool.py

Lines changed: 2 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -14,22 +14,12 @@
1414

1515
from __future__ import annotations
1616

17-
from typing import Any
1817
from typing import Union
1918

20-
from google.genai import types
21-
from typing_extensions import override
22-
2319
from ..agents.llm_agent import LlmAgent
24-
from ..memory.in_memory_memory_service import InMemoryMemoryService
2520
from ..models.base_llm import BaseLlm
26-
from ..runners import Runner
27-
from ..sessions.in_memory_session_service import InMemorySessionService
28-
from ..utils.context_utils import Aclosing
29-
from ._forwarding_artifact_service import ForwardingArtifactService
30-
from .agent_tool import AgentTool
21+
from ._search_agent_tool import _SearchAgentTool
3122
from .enterprise_search_tool import enterprise_web_search_tool
32-
from .tool_context import ToolContext
3323

3424

3525
def create_enterprise_search_agent(model: Union[str, BaseLlm]) -> LlmAgent:
@@ -50,7 +40,7 @@ def create_enterprise_search_agent(model: Union[str, BaseLlm]) -> LlmAgent:
5040
)
5141

5242

53-
class EnterpriseSearchAgentTool(AgentTool):
43+
class EnterpriseSearchAgentTool(_SearchAgentTool):
5444
"""A tool that wraps a sub-agent that only uses enterprise_web_search tool.
5545
5646
This is a workaround to support using enterprise_web_search tool with other tools.
@@ -63,79 +53,3 @@ class EnterpriseSearchAgentTool(AgentTool):
6353
def __init__(self, agent: LlmAgent):
6454
self.agent = agent
6555
super().__init__(agent=self.agent)
66-
67-
@override
68-
async def run_async(
69-
self,
70-
*,
71-
args: dict[str, Any],
72-
tool_context: ToolContext,
73-
) -> Any:
74-
from ..agents.llm_agent import LlmAgent
75-
76-
if isinstance(self.agent, LlmAgent) and self.agent.input_schema:
77-
input_value = self.agent.input_schema.model_validate(args)
78-
content = types.Content(
79-
role='user',
80-
parts=[
81-
types.Part.from_text(
82-
text=input_value.model_dump_json(exclude_none=True)
83-
)
84-
],
85-
)
86-
else:
87-
content = types.Content(
88-
role='user',
89-
parts=[types.Part.from_text(text=args['request'])],
90-
)
91-
runner = Runner(
92-
app_name=self.agent.name,
93-
agent=self.agent,
94-
artifact_service=ForwardingArtifactService(tool_context),
95-
session_service=InMemorySessionService(),
96-
memory_service=InMemoryMemoryService(),
97-
credential_service=tool_context._invocation_context.credential_service,
98-
plugins=list(tool_context._invocation_context.plugin_manager.plugins),
99-
)
100-
101-
state_dict = {
102-
k: v
103-
for k, v in tool_context.state.to_dict().items()
104-
if not k.startswith('_adk') # Filter out adk internal states
105-
}
106-
session = await runner.session_service.create_session(
107-
app_name=self.agent.name,
108-
user_id=tool_context._invocation_context.user_id,
109-
state=state_dict,
110-
)
111-
112-
last_content = None
113-
last_grounding_metadata = None
114-
async with Aclosing(
115-
runner.run_async(
116-
user_id=session.user_id, session_id=session.id, new_message=content
117-
)
118-
) as agen:
119-
async for event in agen:
120-
# Forward state delta to parent session.
121-
if event.actions.state_delta:
122-
tool_context.state.update(event.actions.state_delta)
123-
if event.content:
124-
last_content = event.content
125-
last_grounding_metadata = event.grounding_metadata
126-
127-
if not last_content:
128-
return ''
129-
merged_text = '\n'.join(p.text for p in last_content.parts if p.text)
130-
if isinstance(self.agent, LlmAgent) and self.agent.output_schema:
131-
tool_result = self.agent.output_schema.model_validate_json(
132-
merged_text
133-
).model_dump(exclude_none=True)
134-
else:
135-
tool_result = merged_text
136-
137-
if last_grounding_metadata:
138-
tool_context.state['temp:_adk_grounding_metadata'] = (
139-
last_grounding_metadata
140-
)
141-
return tool_result

src/google/adk/tools/enterprise_search_tool.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ class EnterpriseWebSearchTool(BaseTool):
3636
"""
3737

3838
def __init__(self, *, bypass_multi_tools_limit: bool = False):
39-
"""Initializes the Google search tool.
39+
"""Initializes the Enterprise web search tool.
4040
4141
Args:
4242
bypass_multi_tools_limit: Whether to bypass the multi tools limitation,

src/google/adk/tools/google_search_agent_tool.py

Lines changed: 2 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,12 @@
1414

1515
from __future__ import annotations
1616

17-
from typing import Any
1817
from typing import Union
1918

20-
from google.genai import types
21-
from typing_extensions import override
22-
2319
from ..agents.llm_agent import LlmAgent
24-
from ..memory.in_memory_memory_service import InMemoryMemoryService
2520
from ..models.base_llm import BaseLlm
26-
from ..utils.context_utils import Aclosing
27-
from ._forwarding_artifact_service import ForwardingArtifactService
28-
from .agent_tool import AgentTool
21+
from ._search_agent_tool import _SearchAgentTool
2922
from .google_search_tool import google_search
30-
from .tool_context import ToolContext
3123

3224

3325
def create_google_search_agent(model: Union[str, BaseLlm]) -> LlmAgent:
@@ -47,7 +39,7 @@ def create_google_search_agent(model: Union[str, BaseLlm]) -> LlmAgent:
4739
)
4840

4941

50-
class GoogleSearchAgentTool(AgentTool):
42+
class GoogleSearchAgentTool(_SearchAgentTool):
5143
"""A tool that wraps a sub-agent that only uses google_search tool.
5244
5345
This is a workaround to support using google_search tool with other tools.
@@ -60,81 +52,3 @@ class GoogleSearchAgentTool(AgentTool):
6052
def __init__(self, agent: LlmAgent):
6153
self.agent = agent
6254
super().__init__(agent=self.agent)
63-
64-
@override
65-
async def run_async(
66-
self,
67-
*,
68-
args: dict[str, Any],
69-
tool_context: ToolContext,
70-
) -> Any:
71-
from ..agents.llm_agent import LlmAgent
72-
from ..runners import Runner
73-
from ..sessions.in_memory_session_service import InMemorySessionService
74-
75-
if isinstance(self.agent, LlmAgent) and self.agent.input_schema:
76-
input_value = self.agent.input_schema.model_validate(args)
77-
content = types.Content(
78-
role='user',
79-
parts=[
80-
types.Part.from_text(
81-
text=input_value.model_dump_json(exclude_none=True)
82-
)
83-
],
84-
)
85-
else:
86-
content = types.Content(
87-
role='user',
88-
parts=[types.Part.from_text(text=args['request'])],
89-
)
90-
runner = Runner(
91-
app_name=self.agent.name,
92-
agent=self.agent,
93-
artifact_service=ForwardingArtifactService(tool_context),
94-
session_service=InMemorySessionService(),
95-
memory_service=InMemoryMemoryService(),
96-
credential_service=tool_context._invocation_context.credential_service,
97-
plugins=list(tool_context._invocation_context.plugin_manager.plugins),
98-
)
99-
100-
state_dict = {
101-
k: v
102-
for k, v in tool_context.state.to_dict().items()
103-
if not k.startswith('_adk') # Filter out adk internal states
104-
}
105-
session = await runner.session_service.create_session(
106-
app_name=self.agent.name,
107-
user_id=tool_context._invocation_context.user_id,
108-
state=state_dict,
109-
)
110-
111-
last_content = None
112-
last_grounding_metadata = None
113-
async with Aclosing(
114-
runner.run_async(
115-
user_id=session.user_id, session_id=session.id, new_message=content
116-
)
117-
) as agen:
118-
async for event in agen:
119-
# Forward state delta to parent session.
120-
if event.actions.state_delta:
121-
tool_context.state.update(event.actions.state_delta)
122-
if event.content:
123-
last_content = event.content
124-
last_grounding_metadata = event.grounding_metadata
125-
126-
if not last_content:
127-
return ''
128-
merged_text = '\n'.join(p.text for p in last_content.parts if p.text)
129-
if isinstance(self.agent, LlmAgent) and self.agent.output_schema:
130-
tool_result = self.agent.output_schema.model_validate_json(
131-
merged_text
132-
).model_dump(exclude_none=True)
133-
else:
134-
tool_result = merged_text
135-
136-
if last_grounding_metadata:
137-
tool_context.state['temp:_adk_grounding_metadata'] = (
138-
last_grounding_metadata
139-
)
140-
return tool_result

0 commit comments

Comments
 (0)