Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 43 additions & 2 deletions src/google/adk/cli/cli_tools_click.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,12 @@
import os
from pathlib import Path
import tempfile
import textwrap
from typing import Optional
from typing import Optional, TYPE_CHECKING

if TYPE_CHECKING:
from ..apps.app import App

import textwrap
import click
from click.core import ParameterSource
from fastapi import FastAPI
Expand Down Expand Up @@ -515,6 +518,35 @@ def cli_run(
)
)

def _load_app_from_module(module_path: str) -> Optional['App']:
  """Try to load an App instance from the agent module.

  The module is imported and its members are scanned with
  ``inspect.getmembers``, which yields members sorted by name. When a module
  defines several App instances, the alphabetically first one is returned.

  Loading is best effort: any failure to import or scan the module is logged
  at debug level and reported as "no App" rather than raised.

  Args:
    module_path: Python module path (e.g., 'my_package.my_agent')

  Returns:
    App instance if found, None if the module cannot be imported or does not
    expose an App instance.
  """
  # Function-scope imports: `inspect` is imported here alongside `importlib`
  # so this helper does not depend on what the module imports at top level.
  import importlib
  import inspect

  from ..apps.app import App

  try:
    module = importlib.import_module(module_path)

    # Find the first attribute (in name order) that is an instance of App.
    for name, candidate in inspect.getmembers(module):
      if isinstance(candidate, App):
        logger.info("Loaded App instance '%s' from %s", name, module_path)
        return candidate

    logger.debug("No App instance found in %s", module_path)

  except (ImportError, AttributeError, TypeError, ValueError) as e:
    # import_module raises TypeError for a dot-prefixed name without a package
    # (e.g. a relative filesystem path such as './my_agent') and ValueError
    # for an empty name. Both mean "no App available", not a fatal error.
    logger.debug("Could not load App from module %s: %s", module_path, e)

  return None



def eval_options():
"""Decorator to add common eval options to click commands."""
Expand Down Expand Up @@ -733,10 +765,19 @@ def cli_eval(
)

try:
# Try to load App if available (for plugin support like ReflectAndRetryToolPlugin)
app = _load_app_from_module(agent_module_file_path)

if app:
logger.info("Using App instance for evaluation (plugins will be applied)")
else:
logger.info("No App found, using root_agent directly")

eval_service = LocalEvalService(
root_agent=root_agent,
eval_sets_manager=eval_sets_manager,
eval_set_results_manager=eval_set_results_manager,
app=app, # NEW: Pass app if available
user_simulator_provider=user_simulator_provider,
)

Expand Down
222 changes: 159 additions & 63 deletions src/google/adk/evaluation/evaluation_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@

import copy
import importlib
from typing import Any
from typing import AsyncGenerator
from typing import Optional
from typing import Any, AsyncGenerator, Optional, TYPE_CHECKING

if TYPE_CHECKING:
from ..apps.app import App

import uuid

from google.genai.types import Content
Expand All @@ -39,6 +41,7 @@
from .app_details import AgentDetails
from .app_details import AppDetails
from .eval_case import EvalCase
from .eval_case import IntermediateData
from .eval_case import Invocation
from .eval_case import InvocationEvent
from .eval_case import InvocationEvents
Expand Down Expand Up @@ -155,6 +158,54 @@ async def _process_query(
reset_func=reset_func,
initial_session=initial_session,
)

@staticmethod
async def _run_user_simulation_loop(
    runner: Runner,
    user_id: str,
    session_id: str,
    user_simulator: UserSimulator,
    request_intercepter_plugin: _RequestIntercepterPlugin,
) -> list[Invocation]:
  """Drive the runner with simulated user turns and build eval invocations.

  The user simulator is asked for the next message until it reports a status
  other than SUCCESS. Every message it produces is sent through the runner
  and the resulting events are accumulated in order.

  Args:
    runner: Runner used to execute each simulated user message.
    user_id: User identifier forwarded to the runner.
    session_id: Session identifier forwarded to the runner.
    user_simulator: Source of the next user message; it is given a deep copy
      of all events collected so far.
    request_intercepter_plugin: Plugin handed to
      `_get_app_details_by_invocation_id` together with the collected events.

  Returns:
    The invocations produced by `convert_events_to_eval_invocations` from the
    collected events and the per-invocation app details.
  """
  collected_events = []

  while True:
    # The simulator only ever sees a deep copy of the event history.
    simulator_turn = await user_simulator.get_next_user_message(
        copy.deepcopy(collected_events)
    )
    if simulator_turn.status != UserSimulatorStatus.SUCCESS:
      # Any non-success status ends the conversation.
      break

    turn_events = (
        EvaluationGenerator._generate_inferences_for_single_user_invocation(
            runner, user_id, session_id, simulator_turn.user_message
        )
    )
    async for turn_event in turn_events:
      collected_events.append(turn_event)

  # Pair the collected events with what the intercepter plugin captured.
  details_by_invocation = EvaluationGenerator._get_app_details_by_invocation_id(
      collected_events, request_intercepter_plugin
  )

  return EvaluationGenerator.convert_events_to_eval_invocations(
      collected_events, details_by_invocation
  )


@staticmethod
async def _generate_inferences_for_single_user_invocation(
Expand Down Expand Up @@ -195,74 +246,59 @@ async def _generate_inferences_from_root_agent(
artifact_service: Optional[BaseArtifactService] = None,
memory_service: Optional[BaseMemoryService] = None,
) -> list[Invocation]:
"""Scrapes the root agent in coordination with the user simulator."""

if not session_service:
session_service = InMemorySessionService()
"""Scrapes the root agent in coordination with the user simulator."""

if not memory_service:
memory_service = InMemoryMemoryService()
if not session_service:
session_service = InMemorySessionService()

app_name = (
initial_session.app_name if initial_session else "EvaluationGenerator"
)
user_id = initial_session.user_id if initial_session else "test_user_id"
session_id = session_id if session_id else str(uuid.uuid4())

_ = await session_service.create_session(
app_name=app_name,
user_id=user_id,
state=initial_session.state if initial_session else {},
session_id=session_id,
)
if not memory_service:
memory_service = InMemoryMemoryService()

if not artifact_service:
artifact_service = InMemoryArtifactService()
app_name = (
initial_session.app_name if initial_session else "EvaluationGenerator"
)
user_id = initial_session.user_id if initial_session else "test_user_id"
session_id = session_id if session_id else str(uuid.uuid4())

_ = await session_service.create_session(
app_name=app_name,
user_id=user_id,
state=initial_session.state if initial_session else {},
session_id=session_id,
)

# Reset agent state for each query
if callable(reset_func):
reset_func()
if not artifact_service:
artifact_service = InMemoryArtifactService()

request_intercepter_plugin = _RequestIntercepterPlugin(
name="request_intercepter_plugin"
)
# We ensure that there is some kind of retries on the llm_requests that are
# generated from the Agent. This is done to make inferencing step of evals
# more resilient to temporary model failures.
ensure_retry_options_plugin = EnsureRetryOptionsPlugin(
name="ensure_retry_options"
)
async with Runner(
app_name=app_name,
agent=root_agent,
artifact_service=artifact_service,
session_service=session_service,
memory_service=memory_service,
plugins=[request_intercepter_plugin, ensure_retry_options_plugin],
) as runner:
events = []
while True:
next_user_message = await user_simulator.get_next_user_message(
copy.deepcopy(events)
)
if next_user_message.status == UserSimulatorStatus.SUCCESS:
async for (
event
) in EvaluationGenerator._generate_inferences_for_single_user_invocation(
runner, user_id, session_id, next_user_message.user_message
):
events.append(event)
else: # no message generated
break
# Reset agent state for each query
if callable(reset_func):
reset_func()

app_details_by_invocation_id = (
EvaluationGenerator._get_app_details_by_invocation_id(
events, request_intercepter_plugin
)
request_intercepter_plugin = _RequestIntercepterPlugin(
name="request_intercepter_plugin"
)
return EvaluationGenerator.convert_events_to_eval_invocations(
events, app_details_by_invocation_id
# We ensure that there is some kind of retries on the llm_requests that are
# generated from the Agent. This is done to make inferencing step of evals
# more resilient to temporary model failures.
ensure_retry_options_plugin = EnsureRetryOptionsPlugin(
name="ensure_retry_options"
)
async with Runner(
app_name=app_name,
agent=root_agent,
artifact_service=artifact_service,
session_service=session_service,
memory_service=memory_service,
plugins=[request_intercepter_plugin, ensure_retry_options_plugin],
) as runner:
return await EvaluationGenerator._run_user_simulation_loop(
runner=runner,
user_id=user_id,
session_id=session_id,
user_simulator=user_simulator,
request_intercepter_plugin=request_intercepter_plugin,
)


@staticmethod
def convert_events_to_eval_invocations(
Expand Down Expand Up @@ -325,6 +361,65 @@ def convert_events_to_eval_invocations(
)

return invocations

@staticmethod
async def _generate_inferences_from_app(
    app: 'App',
    user_simulator: 'UserSimulator',
    initial_session: Optional['SessionInput'],
    session_id: str,
    session_service: 'BaseSessionService',
    artifact_service: 'BaseArtifactService',
    memory_service: 'BaseMemoryService',
) -> list['Invocation']:
  """Generate inferences by running the user simulation through an App.

  A session is created first, then the App is deep-copied, the eval-specific
  plugins are appended to the copy, and a Runner built from that copy drives
  the user simulation loop. The App passed in is not modified.

  Args:
    app: App whose configuration (including its plugins) is used for the run.
    user_simulator: Produces the user messages for the conversation.
    initial_session: Optional source of app name, user id and initial state
      for the session; defaults are used when it is falsy.
    session_id: Identifier of the session to create and run against.
    session_service: Service used to create the session; also given to the
      Runner.
    artifact_service: Artifact service given to the Runner.
    memory_service: Memory service given to the Runner.

  Returns:
    The invocations returned by `_run_user_simulation_loop`.
  """
  # Resolve the session identity and state, with defaults when no initial
  # session was supplied.
  if initial_session:
    user_id = initial_session.user_id
    app_name = initial_session.app_name
    initial_state = initial_session.state
  else:
    user_id = 'test_user_id'
    app_name = app.name
    initial_state = {}

  # NOTE(review): the Runner below receives no explicit app_name. Confirm it
  # resolves sessions under the same app_name used here when
  # initial_session.app_name differs from app.name.
  await session_service.create_session(
      app_name=app_name,
      user_id=user_id,
      session_id=session_id,
      state=initial_state,
  )

  # Eval-specific plugins; the intercepter is needed later for app_details.
  request_intercepter_plugin = _RequestIntercepterPlugin(
      name="request_intercepter_plugin"
  )
  ensure_retry_options_plugin = EnsureRetryOptionsPlugin(
      name="ensure_retry_options"
  )

  # Work on a deep copy so the plugins are appended to the copy only.
  app_for_runner = app.model_copy(deep=True)

  # Plugin names are snapshotted once, before anything is appended. A plugin
  # is skipped when the App already carries one with the same name.
  # NOTE(review): if a plugin named "request_intercepter_plugin" already
  # exists, the instance created above is never attached to the Runner but is
  # still used to look up app details — confirm this is intended.
  names_already_present = {plugin.name for plugin in app_for_runner.plugins}
  for eval_plugin in (request_intercepter_plugin, ensure_retry_options_plugin):
    if eval_plugin.name not in names_already_present:
      app_for_runner.plugins.append(eval_plugin)

  # Building the Runner from the App copy keeps the App's own plugins active.
  async with Runner(
      app=app_for_runner,
      session_service=session_service,
      artifact_service=artifact_service,
      memory_service=memory_service,
  ) as runner:
    invocations = await EvaluationGenerator._run_user_simulation_loop(
        runner=runner,
        user_id=user_id,
        session_id=session_id,
        user_simulator=user_simulator,
        request_intercepter_plugin=request_intercepter_plugin,
    )
    return invocations




@staticmethod
def _get_app_details_by_invocation_id(
Expand Down Expand Up @@ -413,3 +508,4 @@ def _process_query_with_session(session_data, data):
responses[index]["actual_tool_use"] = actual_tool_uses
responses[index]["response"] = response
return responses

Loading