1414
1515from __future__ import annotations
1616
17- from typing import Any
1817from typing import Union
1918
20- from google .genai import types
21- from typing_extensions import override
22-
2319from ..agents .llm_agent import LlmAgent
24- from ..memory .in_memory_memory_service import InMemoryMemoryService
2520from ..models .base_llm import BaseLlm
26- from ..runners import Runner
27- from ..sessions .in_memory_session_service import InMemorySessionService
28- from ..utils .context_utils import Aclosing
29- from ._forwarding_artifact_service import ForwardingArtifactService
30- from .agent_tool import AgentTool
21+ from ._search_agent_tool import _SearchAgentTool
3122from .enterprise_search_tool import enterprise_web_search_tool
32- from .tool_context import ToolContext
3323
3424
3525def create_enterprise_search_agent (model : Union [str , BaseLlm ]) -> LlmAgent :
@@ -50,7 +40,7 @@ def create_enterprise_search_agent(model: Union[str, BaseLlm]) -> LlmAgent:
5040 )
5141
5242
53- class EnterpriseSearchAgentTool (AgentTool ):
43+ class EnterpriseSearchAgentTool (_SearchAgentTool ):
5444 """A tool that wraps a sub-agent that only uses enterprise_web_search tool.
5545
5646 This is a workaround to support using enterprise_web_search tool with other tools.
@@ -63,79 +53,3 @@ class EnterpriseSearchAgentTool(AgentTool):
6353 def __init__ (self , agent : LlmAgent ):
6454 self .agent = agent
6555 super ().__init__ (agent = self .agent )
66-
67- @override
68- async def run_async (
69- self ,
70- * ,
71- args : dict [str , Any ],
72- tool_context : ToolContext ,
73- ) -> Any :
74- from ..agents .llm_agent import LlmAgent
75-
76- if isinstance (self .agent , LlmAgent ) and self .agent .input_schema :
77- input_value = self .agent .input_schema .model_validate (args )
78- content = types .Content (
79- role = 'user' ,
80- parts = [
81- types .Part .from_text (
82- text = input_value .model_dump_json (exclude_none = True )
83- )
84- ],
85- )
86- else :
87- content = types .Content (
88- role = 'user' ,
89- parts = [types .Part .from_text (text = args ['request' ])],
90- )
91- runner = Runner (
92- app_name = self .agent .name ,
93- agent = self .agent ,
94- artifact_service = ForwardingArtifactService (tool_context ),
95- session_service = InMemorySessionService (),
96- memory_service = InMemoryMemoryService (),
97- credential_service = tool_context ._invocation_context .credential_service ,
98- plugins = list (tool_context ._invocation_context .plugin_manager .plugins ),
99- )
100-
101- state_dict = {
102- k : v
103- for k , v in tool_context .state .to_dict ().items ()
104- if not k .startswith ('_adk' ) # Filter out adk internal states
105- }
106- session = await runner .session_service .create_session (
107- app_name = self .agent .name ,
108- user_id = tool_context ._invocation_context .user_id ,
109- state = state_dict ,
110- )
111-
112- last_content = None
113- last_grounding_metadata = None
114- async with Aclosing (
115- runner .run_async (
116- user_id = session .user_id , session_id = session .id , new_message = content
117- )
118- ) as agen :
119- async for event in agen :
120- # Forward state delta to parent session.
121- if event .actions .state_delta :
122- tool_context .state .update (event .actions .state_delta )
123- if event .content :
124- last_content = event .content
125- last_grounding_metadata = event .grounding_metadata
126-
127- if not last_content :
128- return ''
129- merged_text = '\n ' .join (p .text for p in last_content .parts if p .text )
130- if isinstance (self .agent , LlmAgent ) and self .agent .output_schema :
131- tool_result = self .agent .output_schema .model_validate_json (
132- merged_text
133- ).model_dump (exclude_none = True )
134- else :
135- tool_result = merged_text
136-
137- if last_grounding_metadata :
138- tool_context .state ['temp:_adk_grounding_metadata' ] = (
139- last_grounding_metadata
140- )
141- return tool_result
0 commit comments