27 changes: 27 additions & 0 deletions src/google/adk/evaluation/base_eval_service.py
@@ -17,8 +17,10 @@
from abc import ABC
from abc import abstractmethod
from enum import Enum
from typing import Any
from typing import AsyncGenerator
from typing import Optional
from typing import TYPE_CHECKING

from pydantic import alias_generators
from pydantic import BaseModel
@@ -29,6 +31,9 @@
from .eval_metrics import EvalMetric
from .eval_result import EvalCaseResult

if TYPE_CHECKING:
from ..plugins.base_plugin import BasePlugin


class EvaluateConfig(BaseModel):
"""Contains configurations needed to run evaluations."""
@@ -81,6 +86,28 @@ class InferenceConfig(BaseModel):
could also overwhelm those tools.""",
)

plugins: list[Any] = Field(
Contributor comment (severity: high):

The plugins field in InferenceConfig is currently typed as list[Any]. While BasePlugin is conditionally imported for TYPE_CHECKING, this creates a type inconsistency with downstream functions (e.g., in evaluation_generator.py) that expect list[BasePlugin]. For better type safety and consistency, consider importing BasePlugin directly (not under TYPE_CHECKING) and typing this field as list[BasePlugin]. This would allow Pydantic to perform runtime validation and ensure type alignment throughout the codebase.

Suggested change:
-  plugins: list[Any] = Field(
+  plugins: list[BasePlugin] = Field(
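To make the point about runtime validation concrete, the following is a minimal, self-contained sketch (hypothetical Plugin, LooseConfig, and StrictConfig classes, not ADK code): list[Any] performs no element validation, whereas a concrete element type lets Pydantic reject wrongly typed entries. If BasePlugin is a plain class rather than a Pydantic model, an arbitrary_types_allowed config would also be needed, as assumed below.

from typing import Any

from pydantic import BaseModel
from pydantic import ConfigDict


class Plugin:
    """Stand-in for a plugin base class such as BasePlugin."""

    def __init__(self, name: str) -> None:
        self.name = name


class LooseConfig(BaseModel):
    # list[Any] performs no element validation; a stray string slips through.
    plugins: list[Any] = []


class StrictConfig(BaseModel):
    # Allows a plain (non-Pydantic) class as a field type; Pydantic then
    # enforces an isinstance() check on every element at validation time.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    plugins: list[Plugin] = []


LooseConfig(plugins=["oops"])         # accepted silently
StrictConfig(plugins=[Plugin("ok")])  # accepted
# StrictConfig(plugins=["oops"])      # would raise pydantic.ValidationError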

default_factory=list,
description="""Additional plugins to use during evaluation inference.

These plugins are added to the built-in evaluation plugins
(_RequestIntercepterPlugin and EnsureRetryOptionsPlugin).

Common use cases:
- ReflectAndRetryToolPlugin: Automatically retry failed tool calls with
reflection
- Custom logging or monitoring plugins
- State management plugins

Example:
from google.adk.plugins import ReflectAndRetryToolPlugin

config = InferenceConfig(
plugins=[ReflectAndRetryToolPlugin(max_retries=3)]
)
""",
)


class InferenceRequest(BaseModel):
"""Represent a request to perform inferences for the eval cases in an eval set."""
45 changes: 42 additions & 3 deletions src/google/adk/evaluation/evaluation_generator.py
@@ -19,6 +19,7 @@
from typing import Any
from typing import AsyncGenerator
from typing import Optional
from typing import TYPE_CHECKING
import uuid

from google.genai.types import Content
@@ -49,6 +50,9 @@
from .simulation.user_simulator import UserSimulator
from .simulation.user_simulator_provider import UserSimulatorProvider

if TYPE_CHECKING:
from ..plugins.base_plugin import BasePlugin

_USER_AUTHOR = "user"
_DEFAULT_AUTHOR = "agent"

@@ -72,6 +76,7 @@ async def generate_responses(
agent_module_path: str,
repeat_num: int = 3,
agent_name: str = None,
plugins: Optional[list[BasePlugin]] = None,
) -> list[EvalCaseResponses]:
"""Returns evaluation responses for the given dataset and agent.

@@ -82,6 +87,8 @@
usually done to remove uncertainty that a single run may bring.
agent_name: The name of the agent that should be evaluated. This is
usually the sub-agent.
plugins: Optional list of additional plugins to use during evaluation.
These will be added to the built-in evaluation plugins.
"""
results = []

@@ -95,6 +102,7 @@
user_simulator,
agent_name,
eval_case.session_input,
plugins,
)
responses.append(response_invocations)

@@ -136,8 +144,17 @@ async def _process_query(
user_simulator: UserSimulator,
agent_name: Optional[str] = None,
initial_session: Optional[SessionInput] = None,
plugins: Optional[list[BasePlugin]] = None,
) -> list[Invocation]:
"""Process a query using the agent and evaluation dataset."""
"""Process a query using the agent and evaluation dataset.

Args:
module_name: Path to the module containing the agent.
user_simulator: User simulator for the evaluation.
agent_name: Optional name of specific sub-agent to evaluate.
initial_session: Optional initial session state.
plugins: Optional list of additional plugins to use during evaluation.
"""
module_path = f"{module_name}"
agent_module = importlib.import_module(module_path)
root_agent = agent_module.agent.root_agent
@@ -154,6 +171,7 @@
user_simulator=user_simulator,
reset_func=reset_func,
initial_session=initial_session,
plugins=plugins,
)

@staticmethod
@@ -194,8 +212,22 @@ async def _generate_inferences_from_root_agent(
session_service: Optional[BaseSessionService] = None,
artifact_service: Optional[BaseArtifactService] = None,
memory_service: Optional[BaseMemoryService] = None,
plugins: Optional[list[BasePlugin]] = None,
) -> list[Invocation]:
"""Scrapes the root agent in coordination with the user simulator."""
"""Scrapes the root agent in coordination with the user simulator.

Args:
root_agent: The agent to evaluate.
user_simulator: User simulator for the evaluation.
reset_func: Optional reset function to call before evaluation.
initial_session: Optional initial session state.
session_id: Optional session ID to use.
session_service: Optional session service to use.
artifact_service: Optional artifact service to use.
memory_service: Optional memory service to use.
plugins: Optional list of additional plugins to use during evaluation.
These will be added to the built-in evaluation plugins.
"""

if not session_service:
session_service = InMemorySessionService()
@@ -223,6 +255,7 @@ async def _generate_inferences_from_root_agent(
if callable(reset_func):
reset_func()

# Build plugin list: start with built-in eval plugins
request_intercepter_plugin = _RequestIntercepterPlugin(
name="request_intercepter_plugin"
)
@@ -232,13 +265,19 @@
ensure_retry_options_plugin = EnsureRetryOptionsPlugin(
name="ensure_retry_options"
)

# Merge built-in plugins with user-provided plugins
all_plugins = [request_intercepter_plugin, ensure_retry_options_plugin]
if plugins:
all_plugins.extend(plugins)

async with Runner(
app_name=app_name,
agent=root_agent,
artifact_service=artifact_service,
session_service=session_service,
memory_service=memory_service,
plugins=[request_intercepter_plugin, ensure_retry_options_plugin],
plugins=all_plugins,
) as runner:
events = []
while True:
4 changes: 4 additions & 0 deletions src/google/adk/evaluation/local_eval_service.py
@@ -482,6 +482,9 @@ async def _perform_inference_single_eval_item(

try:
with client_label_context(EVAL_CLIENT_LABEL):
# Extract plugins from inference config
plugins = inference_request.inference_config.plugins or []

inferences = (
await EvaluationGenerator._generate_inferences_from_root_agent(
root_agent=root_agent,
Expand All @@ -491,6 +494,7 @@ async def _perform_inference_single_eval_item(
session_service=self._session_service,
artifact_service=self._artifact_service,
memory_service=self._memory_service,
plugins=plugins,
)
)

44 changes: 44 additions & 0 deletions tests/unittests/evaluation/test_evaluation_generator.py
@@ -455,3 +455,47 @@ async def mock_generate_inferences_side_effect(
mock_generate_inferences.assert_called_once()
called_with_content = mock_generate_inferences.call_args.args[3]
assert called_with_content.parts[0].text == "message 1"

@pytest.mark.asyncio
async def test_generates_inferences_with_custom_plugins(
self, mocker, mock_runner, mock_session_service
):
"""Tests that custom plugins are merged with built-in plugins."""
from google.adk.plugins.base_plugin import BasePlugin
Contributor comment (severity: medium):

It's generally good practice to place all import statements at the top of the file for better readability and consistency, unless there's a specific reason (like avoiding circular imports or very heavy imports only needed in rare paths). Moving this import to the top of the file would align with standard Python style guidelines.
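A minimal sketch of the suggested layout (placement only; the BasePlugin import path is the one already used in the test, and the rest of the module's imports are assumed to stay unchanged):

# Hypothetical top-of-file import block for test_evaluation_generator.py:
import pytest

from google.adk.plugins.base_plugin import BasePlugin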


mock_agent = mocker.MagicMock()
mock_user_sim = mocker.MagicMock(spec=UserSimulator)

# Mock user simulator will stop immediately
async def get_next_user_message_side_effect(*args, **kwargs):
return NextUserMessage(status=UserSimulatorStatus.STOP_SIGNAL_DETECTED)

mock_user_sim.get_next_user_message = mocker.AsyncMock(
side_effect=get_next_user_message_side_effect
)

# Create a custom plugin
custom_plugin = mocker.MagicMock(spec=BasePlugin)
custom_plugin.name = "custom_test_plugin"

await EvaluationGenerator._generate_inferences_from_root_agent(
root_agent=mock_agent,
user_simulator=mock_user_sim,
plugins=[custom_plugin],
)

# Verify Runner was created with merged plugins
# Built-in plugins: _RequestIntercepterPlugin, EnsureRetryOptionsPlugin
# Custom plugins: custom_test_plugin
runner_call_args = mock_runner.call_args
assert runner_call_args is not None
plugins_passed = runner_call_args[1]["plugins"]
assert len(plugins_passed) == 3, (
"Expected 3 plugins: 2 built-in + 1 custom"
)

# Verify built-in plugins are present
plugin_names = [p.name for p in plugins_passed]
assert "request_intercepter_plugin" in plugin_names
assert "ensure_retry_options" in plugin_names
assert "custom_test_plugin" in plugin_names
47 changes: 47 additions & 0 deletions tests/unittests/evaluation/test_local_eval_service.py
@@ -256,6 +256,53 @@ async def test_perform_inference_with_case_ids(
)


@pytest.mark.asyncio
async def test_perform_inference_with_custom_plugins(
eval_service,
dummy_agent,
mock_eval_sets_manager,
mocker,
):
"""Tests that custom plugins are passed through to EvaluationGenerator."""
from google.adk.plugins.base_plugin import BasePlugin
Contributor comment (severity: medium):

Similar to the previous test file, it's generally good practice to place all import statements at the top of the file for better readability and consistency. Moving this import to the top of the file would align with standard Python style guidelines.


eval_set = EvalSet(
eval_set_id="test_eval_set",
eval_cases=[
EvalCase(eval_id="case1", conversation=[], session_input=None),
],
)
mock_eval_sets_manager.get_eval_set.return_value = eval_set

# Create a custom plugin
custom_plugin = mocker.MagicMock(spec=BasePlugin)
custom_plugin.name = "custom_test_plugin"

# Mock the EvaluationGenerator call to verify plugins are passed
mock_generate_inferences = mocker.patch(
"google.adk.evaluation.local_eval_service.EvaluationGenerator._generate_inferences_from_root_agent",
return_value=[],
)

inference_request = InferenceRequest(
app_name="test_app",
eval_set_id="test_eval_set",
inference_config=InferenceConfig(
parallelism=1, plugins=[custom_plugin]
),
)

results = []
async for result in eval_service.perform_inference(inference_request):
results.append(result)

# Verify that plugins were passed to EvaluationGenerator
mock_generate_inferences.assert_called_once()
call_kwargs = mock_generate_inferences.call_args.kwargs
assert "plugins" in call_kwargs
assert call_kwargs["plugins"] == [custom_plugin]


@pytest.mark.asyncio
async def test_perform_inference_eval_set_not_found(
eval_service,