Skip to content

Commit a76304a

Browse files
committed
Fix: Use App (with plugins) for eval when available
- Extend LocalEvalService to accept optional App parameter - Route evaluation through App so plugins are applied - Add _generate_inferences_from_app() to EvaluationGenerator - Update CLI eval command to load and pass App - Add helper to load App from agent module Fixes #3833
1 parent 6ab87da commit a76304a

File tree

3 files changed

+144
-14
lines changed

3 files changed

+144
-14
lines changed

src/google/adk/cli/cli_tools_click.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,11 @@
2222
import logging
2323
import os
2424
import tempfile
25-
from typing import Optional
25+
from typing import Optional, TYPE_CHECKING
26+
27+
if TYPE_CHECKING:
28+
from ..apps.app import App
29+
2630

2731
import click
2832
from click.core import ParameterSource
@@ -279,6 +283,34 @@ def cli_run(
279283
)
280284
)
281285

286+
def _load_app_from_module(module_path: str) -> Optional['App']:
287+
"""Try to load an App instance from the agent module.
288+
289+
Args:
290+
module_path: Python module path (e.g., 'my_package.my_agent')
291+
292+
Returns:
293+
App instance if found, None otherwise
294+
"""
295+
try:
296+
import importlib
297+
module = importlib.import_module(module_path)
298+
299+
# Check for 'app' attribute (most common convention)
300+
if hasattr(module, 'app'):
301+
from ..apps.app import App
302+
candidate = getattr(module, 'app')
303+
if isinstance(candidate, App):
304+
logger.info(f"Loaded App instance from {module_path}")
305+
return candidate
306+
307+
logger.debug(f"No App instance found in {module_path}")
308+
309+
except (ImportError, AttributeError) as e:
310+
logger.debug(f"Could not load App from module {module_path}: {e}")
311+
312+
return None
313+
282314

283315
@main.command("eval", cls=HelpfulCommand)
284316
@click.argument(
@@ -471,10 +503,19 @@ def cli_eval(
471503
)
472504

473505
try:
506+
# Try to load App if available (for plugin support like ReflectAndRetryToolPlugin)
507+
app = _load_app_from_module(agent_module_file_path)
508+
509+
if app:
510+
logger.info("Using App instance for evaluation (plugins will be applied)")
511+
else:
512+
logger.info("No App found, using root_agent directly")
513+
474514
eval_service = LocalEvalService(
475515
root_agent=root_agent,
476516
eval_sets_manager=eval_sets_manager,
477517
eval_set_results_manager=eval_set_results_manager,
518+
app=app, # NEW: Pass app if available
478519
)
479520

480521
inference_results = asyncio.run(

src/google/adk/evaluation/evaluation_generator.py

Lines changed: 70 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,11 @@
1515
from __future__ import annotations
1616

1717
import importlib
18-
from typing import Any
19-
from typing import Optional
18+
from typing import Any, Optional, TYPE_CHECKING
19+
20+
if TYPE_CHECKING:
21+
from ..apps.app import App
22+
2023
import uuid
2124

2225
from pydantic import BaseModel
@@ -220,6 +223,71 @@ async def _generate_inferences_from_root_agent(
220223
)
221224

222225
return response_invocations
226+
227+
@staticmethod
228+
async def _generate_inferences_from_app(
229+
invocations: list['Invocation'],
230+
app: 'App',
231+
initial_session: Optional['SessionInput'],
232+
session_id: str,
233+
session_service: 'BaseSessionService',
234+
artifact_service: 'BaseArtifactService',
235+
) -> list['Invocation']:
236+
"""Generate inferences by invoking through App (preserving plugins)."""
237+
238+
actual_invocations = []
239+
240+
# Determine user_id consistently
241+
user_id = 'test_user_id'
242+
if initial_session and initial_session.user_id is not None:
243+
user_id = initial_session.user_id
244+
245+
# Initialize session if provided
246+
if initial_session:
247+
app_name = initial_session.app_name if initial_session.app_name else app.name
248+
await session_service.create_session(
249+
app_name=app_name,
250+
user_id=user_id,
251+
session_id=session_id,
252+
state=initial_session.state if initial_session.state else {},
253+
)
254+
255+
# Run each invocation through the app
256+
for expected_invocation in invocations:
257+
user_content = expected_invocation.user_content
258+
259+
# Invoke through App (this applies all plugins)
260+
response = await app.run(
261+
user_id=user_id,
262+
session_id=session_id,
263+
new_message=user_content,
264+
)
265+
266+
# Extract response similar to existing implementation
267+
final_response = None
268+
tool_uses = []
269+
invocation_id = ""
270+
271+
async for event in response:
272+
invocation_id = invocation_id or event.invocation_id
273+
274+
if event.is_final_response() and event.content and event.content.parts:
275+
final_response = event.content
276+
elif event.get_function_calls():
277+
for call in event.get_function_calls():
278+
tool_uses.append(call)
279+
280+
actual_invocations.append(
281+
Invocation(
282+
invocation_id=invocation_id,
283+
user_content=user_content,
284+
final_response=final_response,
285+
intermediate_data=IntermediateData(tool_uses=tool_uses),
286+
)
287+
)
288+
289+
return actual_invocations
290+
223291

224292
@staticmethod
225293
def _process_query_with_session(session_data, data):

src/google/adk/evaluation/local_eval_service.py

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@
2020
from typing import AsyncGenerator
2121
from typing import Callable
2222
from typing import Optional
23+
from typing import TYPE_CHECKING
24+
25+
if TYPE_CHECKING:
26+
from ..apps.app import App
2327
import uuid
2428

2529
from typing_extensions import override
@@ -38,6 +42,7 @@
3842
from .base_eval_service import InferenceResult
3943
from .base_eval_service import InferenceStatus
4044
from .eval_case import Invocation
45+
from .eval_case import SessionInput
4146
from .eval_metrics import EvalMetric
4247
from .eval_metrics import EvalMetricResult
4348
from .eval_metrics import EvalMetricResultPerInvocation
@@ -73,9 +78,11 @@ def __init__(
7378
artifact_service: Optional[BaseArtifactService] = None,
7479
eval_set_results_manager: Optional[EvalSetResultsManager] = None,
7580
session_id_supplier: Callable[[], str] = _get_session_id,
81+
app: Optional['App'] = None,
7682
):
7783
self._root_agent = root_agent
7884
self._eval_sets_manager = eval_sets_manager
85+
self._app = app
7986
metric_evaluator_registry = (
8087
metric_evaluator_registry or DEFAULT_METRIC_EVALUATOR_REGISTRY
8188
)
@@ -364,23 +371,37 @@ async def _perform_inference_sigle_eval_item(
364371
)
365372

366373
try:
367-
inferences = (
368-
await EvaluationGenerator._generate_inferences_from_root_agent(
369-
invocations=eval_case.conversation,
370-
root_agent=root_agent,
371-
initial_session=initial_session,
372-
session_id=session_id,
373-
session_service=self._session_service,
374-
artifact_service=self._artifact_service,
374+
# Use App if available (so plugins like ReflectAndRetryToolPlugin run)
375+
if self._app is not None:
376+
inferences = (
377+
await EvaluationGenerator._generate_inferences_from_app(
378+
invocations=eval_case.conversation,
379+
app=self._app,
380+
initial_session=initial_session,
381+
session_id=session_id,
382+
session_service=self._session_service,
383+
artifact_service=self._artifact_service,
384+
)
385+
)
386+
else:
387+
# Fallback to direct root_agent usage (existing behavior)
388+
inferences = (
389+
await EvaluationGenerator._generate_inferences_from_root_agent(
390+
invocations=eval_case.conversation,
391+
root_agent=root_agent,
392+
initial_session=initial_session,
393+
session_id=session_id,
394+
session_service=self._session_service,
395+
artifact_service=self._artifact_service,
396+
)
375397
)
376-
)
377398

378399
inference_result.inferences = inferences
379400
inference_result.status = InferenceStatus.SUCCESS
380401

381402
return inference_result
382403
except Exception as e:
383-
# We intentionally catch the Exception as we don't failures to affect
404+
# We intentionally catch the Exception as we don't want failures to affect
384405
# other inferences.
385406
logger.error(
386407
'Inference failed for eval case `%s` with error %s',
@@ -389,4 +410,4 @@ async def _perform_inference_sigle_eval_item(
389410
)
390411
inference_result.status = InferenceStatus.FAILURE
391412
inference_result.error_message = str(e)
392-
return inference_result
413+
return inference_result

0 commit comments

Comments
 (0)