From b94ec19634147e8d17eea030e1a22206237edec7 Mon Sep 17 00:00:00 2001 From: jayy-77 <1427jay@gmail.com> Date: Thu, 22 Jan 2026 18:30:16 +0530 Subject: [PATCH] additional plugins for evals --- .../adk/evaluation/base_eval_service.py | 27 +++++++++++ .../adk/evaluation/evaluation_generator.py | 45 ++++++++++++++++-- .../adk/evaluation/local_eval_service.py | 4 ++ .../evaluation/test_evaluation_generator.py | 44 +++++++++++++++++ .../evaluation/test_local_eval_service.py | 47 +++++++++++++++++++ 5 files changed, 164 insertions(+), 3 deletions(-) diff --git a/src/google/adk/evaluation/base_eval_service.py b/src/google/adk/evaluation/base_eval_service.py index bb1c3b23a4..9232c8ce62 100644 --- a/src/google/adk/evaluation/base_eval_service.py +++ b/src/google/adk/evaluation/base_eval_service.py @@ -17,8 +17,10 @@ from abc import ABC from abc import abstractmethod from enum import Enum +from typing import Any from typing import AsyncGenerator from typing import Optional +from typing import TYPE_CHECKING from pydantic import alias_generators from pydantic import BaseModel @@ -29,6 +31,9 @@ from .eval_metrics import EvalMetric from .eval_result import EvalCaseResult +if TYPE_CHECKING: + from ..plugins.base_plugin import BasePlugin + class EvaluateConfig(BaseModel): """Contains configurations needed to run evaluations.""" @@ -81,6 +86,28 @@ class InferenceConfig(BaseModel): could also overwhelm those tools.""", ) + plugins: list[Any] = Field( + default_factory=list, + description="""Additional plugins to use during evaluation inference. + +These plugins are added to the built-in evaluation plugins +(_RequestIntercepterPlugin and EnsureRetryOptionsPlugin). + +Common use cases: +- ReflectAndRetryToolPlugin: Automatically retry failed tool calls with + reflection +- Custom logging or monitoring plugins +- State management plugins + +Example: + from google.adk.plugins import ReflectAndRetryToolPlugin + + config = InferenceConfig( + plugins=[ReflectAndRetryToolPlugin(max_retries=3)] + ) +""", + ) + class InferenceRequest(BaseModel): """Represent a request to perform inferences for the eval cases in an eval set.""" diff --git a/src/google/adk/evaluation/evaluation_generator.py b/src/google/adk/evaluation/evaluation_generator.py index 5d8b48c150..c98a9181b9 100644 --- a/src/google/adk/evaluation/evaluation_generator.py +++ b/src/google/adk/evaluation/evaluation_generator.py @@ -19,6 +19,7 @@ from typing import Any from typing import AsyncGenerator from typing import Optional +from typing import TYPE_CHECKING import uuid from google.genai.types import Content @@ -49,6 +50,9 @@ from .simulation.user_simulator import UserSimulator from .simulation.user_simulator_provider import UserSimulatorProvider +if TYPE_CHECKING: + from ..plugins.base_plugin import BasePlugin + _USER_AUTHOR = "user" _DEFAULT_AUTHOR = "agent" @@ -72,6 +76,7 @@ async def generate_responses( agent_module_path: str, repeat_num: int = 3, agent_name: str = None, + plugins: Optional[list[BasePlugin]] = None, ) -> list[EvalCaseResponses]: """Returns evaluation responses for the given dataset and agent. @@ -82,6 +87,8 @@ async def generate_responses( usually done to remove uncertainty that a single run may bring. agent_name: The name of the agent that should be evaluated. This is usually the sub-agent. + plugins: Optional list of additional plugins to use during evaluation. + These will be added to the built-in evaluation plugins. 
""" results = [] @@ -95,6 +102,7 @@ async def generate_responses( user_simulator, agent_name, eval_case.session_input, + plugins, ) responses.append(response_invocations) @@ -136,8 +144,17 @@ async def _process_query( user_simulator: UserSimulator, agent_name: Optional[str] = None, initial_session: Optional[SessionInput] = None, + plugins: Optional[list[BasePlugin]] = None, ) -> list[Invocation]: - """Process a query using the agent and evaluation dataset.""" + """Process a query using the agent and evaluation dataset. + + Args: + module_name: Path to the module containing the agent. + user_simulator: User simulator for the evaluation. + agent_name: Optional name of specific sub-agent to evaluate. + initial_session: Optional initial session state. + plugins: Optional list of additional plugins to use during evaluation. + """ module_path = f"{module_name}" agent_module = importlib.import_module(module_path) root_agent = agent_module.agent.root_agent @@ -154,6 +171,7 @@ async def _process_query( user_simulator=user_simulator, reset_func=reset_func, initial_session=initial_session, + plugins=plugins, ) @staticmethod @@ -194,8 +212,22 @@ async def _generate_inferences_from_root_agent( session_service: Optional[BaseSessionService] = None, artifact_service: Optional[BaseArtifactService] = None, memory_service: Optional[BaseMemoryService] = None, + plugins: Optional[list[BasePlugin]] = None, ) -> list[Invocation]: - """Scrapes the root agent in coordination with the user simulator.""" + """Scrapes the root agent in coordination with the user simulator. + + Args: + root_agent: The agent to evaluate. + user_simulator: User simulator for the evaluation. + reset_func: Optional reset function to call before evaluation. + initial_session: Optional initial session state. + session_id: Optional session ID to use. + session_service: Optional session service to use. + artifact_service: Optional artifact service to use. + memory_service: Optional memory service to use. + plugins: Optional list of additional plugins to use during evaluation. + These will be added to the built-in evaluation plugins. 
+ """ if not session_service: session_service = InMemorySessionService() @@ -223,6 +255,7 @@ async def _generate_inferences_from_root_agent( if callable(reset_func): reset_func() + # Build plugin list: start with built-in eval plugins request_intercepter_plugin = _RequestIntercepterPlugin( name="request_intercepter_plugin" ) @@ -232,13 +265,19 @@ async def _generate_inferences_from_root_agent( ensure_retry_options_plugin = EnsureRetryOptionsPlugin( name="ensure_retry_options" ) + + # Merge built-in plugins with user-provided plugins + all_plugins = [request_intercepter_plugin, ensure_retry_options_plugin] + if plugins: + all_plugins.extend(plugins) + async with Runner( app_name=app_name, agent=root_agent, artifact_service=artifact_service, session_service=session_service, memory_service=memory_service, - plugins=[request_intercepter_plugin, ensure_retry_options_plugin], + plugins=all_plugins, ) as runner: events = [] while True: diff --git a/src/google/adk/evaluation/local_eval_service.py b/src/google/adk/evaluation/local_eval_service.py index 7031266e27..cd82bf6feb 100644 --- a/src/google/adk/evaluation/local_eval_service.py +++ b/src/google/adk/evaluation/local_eval_service.py @@ -482,6 +482,9 @@ async def _perform_inference_single_eval_item( try: with client_label_context(EVAL_CLIENT_LABEL): + # Extract plugins from inference config + plugins = inference_request.inference_config.plugins or [] + inferences = ( await EvaluationGenerator._generate_inferences_from_root_agent( root_agent=root_agent, @@ -491,6 +494,7 @@ async def _perform_inference_single_eval_item( session_service=self._session_service, artifact_service=self._artifact_service, memory_service=self._memory_service, + plugins=plugins, ) ) diff --git a/tests/unittests/evaluation/test_evaluation_generator.py b/tests/unittests/evaluation/test_evaluation_generator.py index 873239e7f4..f2dd289ee7 100644 --- a/tests/unittests/evaluation/test_evaluation_generator.py +++ b/tests/unittests/evaluation/test_evaluation_generator.py @@ -455,3 +455,47 @@ async def mock_generate_inferences_side_effect( mock_generate_inferences.assert_called_once() called_with_content = mock_generate_inferences.call_args.args[3] assert called_with_content.parts[0].text == "message 1" + + @pytest.mark.asyncio + async def test_generates_inferences_with_custom_plugins( + self, mocker, mock_runner, mock_session_service + ): + """Tests that custom plugins are merged with built-in plugins.""" + from google.adk.plugins.base_plugin import BasePlugin + + mock_agent = mocker.MagicMock() + mock_user_sim = mocker.MagicMock(spec=UserSimulator) + + # Mock user simulator will stop immediately + async def get_next_user_message_side_effect(*args, **kwargs): + return NextUserMessage(status=UserSimulatorStatus.STOP_SIGNAL_DETECTED) + + mock_user_sim.get_next_user_message = mocker.AsyncMock( + side_effect=get_next_user_message_side_effect + ) + + # Create a custom plugin + custom_plugin = mocker.MagicMock(spec=BasePlugin) + custom_plugin.name = "custom_test_plugin" + + await EvaluationGenerator._generate_inferences_from_root_agent( + root_agent=mock_agent, + user_simulator=mock_user_sim, + plugins=[custom_plugin], + ) + + # Verify Runner was created with merged plugins + # Built-in plugins: _RequestIntercepterPlugin, EnsureRetryOptionsPlugin + # Custom plugins: custom_test_plugin + runner_call_args = mock_runner.call_args + assert runner_call_args is not None + plugins_passed = runner_call_args[1]["plugins"] + assert len(plugins_passed) == 3, ( + "Expected 3 plugins: 2 built-in + 
1 custom" + ) + + # Verify built-in plugins are present + plugin_names = [p.name for p in plugins_passed] + assert "request_intercepter_plugin" in plugin_names + assert "ensure_retry_options" in plugin_names + assert "custom_test_plugin" in plugin_names diff --git a/tests/unittests/evaluation/test_local_eval_service.py b/tests/unittests/evaluation/test_local_eval_service.py index 08ef2aa8b0..d1c94097b3 100644 --- a/tests/unittests/evaluation/test_local_eval_service.py +++ b/tests/unittests/evaluation/test_local_eval_service.py @@ -256,6 +256,53 @@ async def test_perform_inference_with_case_ids( ) +@pytest.mark.asyncio +async def test_perform_inference_with_custom_plugins( + eval_service, + dummy_agent, + mock_eval_sets_manager, + mocker, +): + """Tests that custom plugins are passed through to EvaluationGenerator.""" + from google.adk.plugins.base_plugin import BasePlugin + + eval_set = EvalSet( + eval_set_id="test_eval_set", + eval_cases=[ + EvalCase(eval_id="case1", conversation=[], session_input=None), + ], + ) + mock_eval_sets_manager.get_eval_set.return_value = eval_set + + # Create a custom plugin + custom_plugin = mocker.MagicMock(spec=BasePlugin) + custom_plugin.name = "custom_test_plugin" + + # Mock the EvaluationGenerator call to verify plugins are passed + mock_generate_inferences = mocker.patch( + "google.adk.evaluation.local_eval_service.EvaluationGenerator._generate_inferences_from_root_agent", + return_value=[], + ) + + inference_request = InferenceRequest( + app_name="test_app", + eval_set_id="test_eval_set", + inference_config=InferenceConfig( + parallelism=1, plugins=[custom_plugin] + ), + ) + + results = [] + async for result in eval_service.perform_inference(inference_request): + results.append(result) + + # Verify that plugins were passed to EvaluationGenerator + mock_generate_inferences.assert_called_once() + call_kwargs = mock_generate_inferences.call_args.kwargs + assert "plugins" in call_kwargs + assert call_kwargs["plugins"] == [custom_plugin] + + @pytest.mark.asyncio async def test_perform_inference_eval_set_not_found( eval_service,