From df827f327e9e22b22f5ca86699f0d37326d79fae Mon Sep 17 00:00:00 2001 From: "ivanmkc@google.com" Date: Thu, 7 Aug 2025 16:56:11 -0400 Subject: [PATCH 1/3] feat(plugins): add thread-safe LogCollectorPlugin for event logging This commit introduces the `LogCollectorPlugin`, a plugin that collects execution logs from ADK callbacks in asynchronous environments. Logs are organized by session ID for retrieval. The plugin uses `asyncio.Lock` to ensure thread-safe logging. Logs are stored in a `defaultdict(list)` indexed by session ID and can be retrieved with the `get_logs_by_session` method. The plugin implements the following callback hooks to log contextual data: - `on_user_message` - `before_run`, `after_run` - `before_agent`, `after_agent` - `before_model`, `after_model`, `on_model_error` - `before_tool`, `after_tool`, `on_tool_error` - `on_event` Each log entry includes the callback type and relevant data, such as invocation ID, agent name, tool name, function call ID, arguments, and results or errors. --- .../adk/plugins/log_collector_plugin.py | 250 ++++++++++++++ .../plugins/test_log_collector_plugin.py | 322 ++++++++++++++++++ 2 files changed, 572 insertions(+) create mode 100644 src/google/adk/plugins/log_collector_plugin.py create mode 100644 tests/unittests/plugins/test_log_collector_plugin.py diff --git a/src/google/adk/plugins/log_collector_plugin.py b/src/google/adk/plugins/log_collector_plugin.py new file mode 100644 index 0000000000..6b91c50bd8 --- /dev/null +++ b/src/google/adk/plugins/log_collector_plugin.py @@ -0,0 +1,250 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import asyncio +from collections import defaultdict +from typing import Any, Optional, Dict, List, TYPE_CHECKING + +from google.adk.agents.base_agent import BaseAgent +from google.adk.agents.callback_context import CallbackContext +from google.adk.events.event import Event +from google.adk.models.llm_request import LlmRequest +from google.adk.models.llm_response import LlmResponse +from google.adk.plugins.base_plugin import BasePlugin +from google.adk.tools.base_tool import BaseTool +from google.adk.tools.tool_context import ToolContext +from google.genai import types + +if TYPE_CHECKING: + from google.adk.agents.invocation_context import InvocationContext + + +class LogCollectorPlugin(BasePlugin): + """ + A plugin to programmatically and safely collect execution details from all + callbacks in async environments, organized by session ID. + """ + + def __init__(self, name: str = "log_collector"): + super().__init__(name) + self.logs: Dict[str, List[Dict[str, Any]]] = defaultdict(list) + self._lock = asyncio.Lock() + + async def _log_entry(self, session_id: str, callback_type: str, data: Dict[str, Any]): + log_entry = {"callback_type": callback_type, **data} + async with self._lock: + self.logs[session_id].append(log_entry) + + async def on_user_message_callback( + self, *, invocation_context: "InvocationContext", user_message: types.Content + ) -> Optional[types.Content]: + session_id = invocation_context.session.id + await self._log_entry( + session_id, + "on_user_message", + { + "invocation_id": invocation_context.invocation_id, + "user_message": user_message.parts[0].text, + }, + ) + return None + + async def before_run_callback( + self, *, invocation_context: "InvocationContext" + ) -> Optional[types.Content]: + session_id = invocation_context.session.id + await self._log_entry( + session_id, + "before_run", + { + "invocation_id": invocation_context.invocation_id, + "agent_name": invocation_context.agent.name, + }, + ) + return None + + async def after_run_callback( + self, *, invocation_context: "InvocationContext" + ) -> None: + session_id = invocation_context.session.id + await self._log_entry( + session_id, + "after_run", + { + "invocation_id": invocation_context.invocation_id, + "agent_name": invocation_context.agent.name, + }, + ) + return None + + async def before_agent_callback( + self, *, agent: BaseAgent, callback_context: CallbackContext + ) -> Optional[types.Content]: + session_id = callback_context._invocation_context.session.id + await self._log_entry( + session_id, + "before_agent", + { + "agent_name": agent.name, + "invocation_id": callback_context.invocation_id, + }, + ) + return None + + async def after_agent_callback( + self, *, agent: BaseAgent, callback_context: CallbackContext + ) -> Optional[types.Content]: + session_id = callback_context._invocation_context.session.id + await self._log_entry( + session_id, + "after_agent", + { + "agent_name": agent.name, + "invocation_id": callback_context.invocation_id, + }, + ) + return None + + async def before_model_callback( + self, *, callback_context: CallbackContext, llm_request: LlmRequest + ) -> Optional[LlmResponse]: + session_id = callback_context._invocation_context.session.id + await self._log_entry( + session_id, + "before_model", + { + "agent_name": callback_context.agent_name, + "request": llm_request.model_dump(), + }, + ) + return None + + async def after_model_callback( + self, *, callback_context: CallbackContext, llm_response: LlmResponse + ) -> Optional[LlmResponse]: + session_id = callback_context._invocation_context.session.id + await self._log_entry( + session_id, + "after_model", + { + "agent_name": callback_context.agent_name, + "response": llm_response.model_dump(), + }, + ) + return None + + async def on_model_error_callback( + self, + *, + callback_context: CallbackContext, + llm_request: LlmRequest, + error: Exception, + ) -> Optional[LlmResponse]: + session_id = callback_context._invocation_context.session.id + await self._log_entry( + session_id, + "on_model_error", + { + "agent_name": callback_context.agent_name, + "request": llm_request.model_dump(), + "error": str(error), + }, + ) + return None + + async def on_event_callback( + self, *, invocation_context: "InvocationContext", event: Event + ) -> Optional[Event]: + session_id = invocation_context.session.id + await self._log_entry( + session_id, + "on_event", + { + "event_id": event.id, + "author": event.author, + "content": event.content.parts[0].text, + "is_final": event.is_final_response(), + }, + ) + return None + + async def before_tool_callback( + self, + *, + tool: BaseTool, + tool_args: Dict[str, Any], + tool_context: ToolContext, + ) -> Optional[Dict]: + session_id = tool_context._invocation_context.session.id + await self._log_entry( + session_id, + "before_tool", + { + "tool_name": tool.name, + "agent_name": tool_context.agent_name, + "function_call_id": tool_context.function_call_id, + "args": tool_args, + }, + ) + return None + + async def after_tool_callback( + self, + *, + tool: BaseTool, + tool_args: Dict[str, Any], + tool_context: ToolContext, + result: Dict, + ) -> Optional[Dict]: + session_id = tool_context._invocation_context.session.id + await self._log_entry( + session_id, + "after_tool", + { + "tool_name": tool.name, + "agent_name": tool_context.agent_name, + "function_call_id": tool_context.function_call_id, + "args": tool_args, + "result": result, + }, + ) + return None + + async def on_tool_error_callback( + self, + *, + tool: BaseTool, + tool_args: Dict[str, Any], + tool_context: ToolContext, + error: Exception, + ) -> Optional[Dict]: + session_id = tool_context._invocation_context.session.id + await self._log_entry( + session_id, + "on_tool_error", + { + "tool_name": tool.name, + "agent_name": tool_context.agent_name, + "function_call_id": tool_context.function_call_id, + "args": tool_args, + "error": str(error), + }, + ) + return None + + def get_logs_by_session(self, session_id: str) -> List[Dict[str, Any]]: + """Retrieve all logs for a specific session.""" + return self.logs.get(session_id, []) diff --git a/tests/unittests/plugins/test_log_collector_plugin.py b/tests/unittests/plugins/test_log_collector_plugin.py new file mode 100644 index 0000000000..6935e9cca1 --- /dev/null +++ b/tests/unittests/plugins/test_log_collector_plugin.py @@ -0,0 +1,322 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the LogCollectorPlugin.""" + +from __future__ import annotations + +from unittest.mock import Mock + +from google.adk.agents.base_agent import BaseAgent +from google.adk.agents.callback_context import CallbackContext +from google.adk.agents.invocation_context import InvocationContext +from google.adk.events.event import Event +from google.adk.models.llm_request import LlmRequest +from google.adk.models.llm_response import LlmResponse +from google.adk.plugins import LogCollectorPlugin +from google.adk.sessions.session import Session +from google.adk.tools.base_tool import BaseTool +from google.adk.tools.tool_context import ToolContext +from google.genai import types +import pytest + + +@pytest.fixture +def plugin() -> LogCollectorPlugin: + """Provides a clean LogCollectorPlugin instance for each test.""" + return LogCollectorPlugin() + + +def create_mock_invocation_context(session_id: str) -> Mock: + mock_context = Mock(spec=InvocationContext) + mock_context.session = Mock(spec=Session) + mock_context.session.id = session_id + return mock_context + + +def create_mock_callback_context(session_id: str) -> Mock: + mock_context = Mock(spec=CallbackContext) + mock_context._invocation_context = create_mock_invocation_context(session_id) + return mock_context + + +def create_mock_tool_context(session_id: str) -> Mock: + mock_context = Mock(spec=ToolContext) + mock_context._invocation_context = create_mock_invocation_context(session_id) + return mock_context + + +@pytest.mark.asyncio +async def test_on_user_message_callback(plugin: LogCollectorPlugin): + mock_context = create_mock_invocation_context("session1") + mock_context.invocation_id = "inv1" + user_message = types.Content(parts=[types.Part(text="Hello")]) + + await plugin.on_user_message_callback( + invocation_context=mock_context, user_message=user_message + ) + + assert len(plugin.logs["session1"]) == 1 + log = plugin.logs["session1"][0] + assert log["callback_type"] == "on_user_message" + assert log["invocation_id"] == "inv1" + assert log["user_message"] == "Hello" + + +@pytest.mark.asyncio +async def test_before_agent_callback(plugin: LogCollectorPlugin): + mock_agent = Mock(spec=BaseAgent) + mock_agent.name = "test_agent" + mock_context = create_mock_callback_context("session1") + mock_context.invocation_id = "inv1" + + await plugin.before_agent_callback(agent=mock_agent, callback_context=mock_context) + + assert len(plugin.logs["session1"]) == 1 + log = plugin.logs["session1"][0] + assert log["callback_type"] == "before_agent" + assert log["agent_name"] == "test_agent" + assert log["invocation_id"] == "inv1" + + +@pytest.mark.asyncio +async def test_after_agent_callback(plugin: LogCollectorPlugin): + mock_agent = Mock(spec=BaseAgent) + mock_agent.name = "test_agent" + mock_context = create_mock_callback_context("session1") + mock_context.invocation_id = "inv1" + + await plugin.after_agent_callback(agent=mock_agent, callback_context=mock_context) + + assert len(plugin.logs["session1"]) == 1 + log = plugin.logs["session1"][0] + assert log["callback_type"] == "after_agent" + assert log["agent_name"] == "test_agent" + assert log["invocation_id"] == "inv1" + + +@pytest.mark.asyncio +async def test_before_model_callback(plugin: LogCollectorPlugin): + mock_context = create_mock_callback_context("session1") + mock_context.agent_name = "test_agent" + mock_request = Mock(spec=LlmRequest) + mock_request.model_dump.return_value = {"model": "gemini"} + + await plugin.before_model_callback( + callback_context=mock_context, llm_request=mock_request + ) + + assert len(plugin.logs["session1"]) == 1 + log = plugin.logs["session1"][0] + assert log["callback_type"] == "before_model" + assert log["agent_name"] == "test_agent" + assert log["request"] == {"model": "gemini"} + + +@pytest.mark.asyncio +async def test_after_model_callback(plugin: LogCollectorPlugin): + mock_context = create_mock_callback_context("session1") + mock_context.agent_name = "test_agent" + mock_response = Mock(spec=LlmResponse) + mock_response.model_dump.return_value = {"text": "response"} + + await plugin.after_model_callback( + callback_context=mock_context, llm_response=mock_response + ) + + assert len(plugin.logs["session1"]) == 1 + log = plugin.logs["session1"][0] + assert log["callback_type"] == "after_model" + assert log["agent_name"] == "test_agent" + assert log["response"] == {"text": "response"} + + +@pytest.mark.asyncio +async def test_on_event_callback(plugin: LogCollectorPlugin): + mock_context = create_mock_invocation_context("session1") + mock_event = Mock(spec=Event) + mock_event.id = "event1" + mock_event.author = "test_author" + mock_event.content = Mock(spec=types.Content) + mock_event.content.parts = [types.Part(text="event content")] + mock_event.is_final_response.return_value = True + + await plugin.on_event_callback(invocation_context=mock_context, event=mock_event) + + assert len(plugin.logs["session1"]) == 1 + log = plugin.logs["session1"][0] + assert log["callback_type"] == "on_event" + assert log["event_id"] == "event1" + assert log["author"] == "test_author" + assert log["content"] == "event content" + assert log["is_final"] is True + + +@pytest.mark.asyncio +async def test_before_run_callback(plugin: LogCollectorPlugin): + mock_context = create_mock_invocation_context("session1") + mock_context.invocation_id = "inv1" + mock_context.agent = Mock(spec=BaseAgent) + mock_context.agent.name = "test_agent" + + await plugin.before_run_callback(invocation_context=mock_context) + + assert len(plugin.logs["session1"]) == 1 + log = plugin.logs["session1"][0] + assert log["callback_type"] == "before_run" + assert log["invocation_id"] == "inv1" + assert log["agent_name"] == "test_agent" + + +@pytest.mark.asyncio +async def test_after_run_callback(plugin: LogCollectorPlugin): + mock_context = create_mock_invocation_context("session1") + mock_context.invocation_id = "inv1" + mock_context.agent = Mock(spec=BaseAgent) + mock_context.agent.name = "test_agent" + + await plugin.after_run_callback(invocation_context=mock_context) + + assert len(plugin.logs["session1"]) == 1 + log = plugin.logs["session1"][0] + assert log["callback_type"] == "after_run" + assert log["invocation_id"] == "inv1" + assert log["agent_name"] == "test_agent" + + +@pytest.mark.asyncio +async def test_on_model_error_callback(plugin: LogCollectorPlugin): + mock_context = create_mock_callback_context("session1") + mock_context.agent_name = "test_agent" + mock_request = Mock(spec=LlmRequest) + mock_request.model_dump.return_value = {"model": "gemini"} + error = ValueError("test error") + + await plugin.on_model_error_callback( + callback_context=mock_context, llm_request=mock_request, error=error + ) + + assert len(plugin.logs["session1"]) == 1 + log = plugin.logs["session1"][0] + assert log["callback_type"] == "on_model_error" + assert log["agent_name"] == "test_agent" + assert log["request"] == {"model": "gemini"} + assert log["error"] == "test error" + + +@pytest.mark.asyncio +async def test_before_tool_callback(plugin: LogCollectorPlugin): + mock_tool = Mock(spec=BaseTool) + mock_tool.name = "test_tool" + mock_context = create_mock_tool_context("session1") + mock_context.agent_name = "test_agent" + mock_context.function_call_id = "func1" + tool_args = {"arg1": "value1"} + + await plugin.before_tool_callback( + tool=mock_tool, tool_args=tool_args, tool_context=mock_context + ) + + assert len(plugin.logs["session1"]) == 1 + log = plugin.logs["session1"][0] + assert log["callback_type"] == "before_tool" + assert log["tool_name"] == "test_tool" + assert log["agent_name"] == "test_agent" + assert log["function_call_id"] == "func1" + assert log["args"] == {"arg1": "value1"} + + +@pytest.mark.asyncio +async def test_after_tool_callback(plugin: LogCollectorPlugin): + mock_tool = Mock(spec=BaseTool) + mock_tool.name = "test_tool" + mock_context = create_mock_tool_context("session1") + mock_context.agent_name = "test_agent" + mock_context.function_call_id = "func1" + tool_args = {"arg1": "value1"} + result = {"result": "success"} + + await plugin.after_tool_callback( + tool=mock_tool, + tool_args=tool_args, + tool_context=mock_context, + result=result, + ) + + assert len(plugin.logs["session1"]) == 1 + log = plugin.logs["session1"][0] + assert log["callback_type"] == "after_tool" + assert log["tool_name"] == "test_tool" + assert log["agent_name"] == "test_agent" + assert log["function_call_id"] == "func1" + assert log["args"] == {"arg1": "value1"} + assert log["result"] == {"result": "success"} + + +@pytest.mark.asyncio +async def test_on_tool_error_callback(plugin: LogCollectorPlugin): + mock_tool = Mock(spec=BaseTool) + mock_tool.name = "test_tool" + mock_context = create_mock_tool_context("session1") + mock_context.agent_name = "test_agent" + mock_context.function_call_id = "func1" + tool_args = {"arg1": "value1"} + error = ValueError("tool error") + + await plugin.on_tool_error_callback( + tool=mock_tool, + tool_args=tool_args, + tool_context=mock_context, + error=error, + ) + + assert len(plugin.logs["session1"]) == 1 + log = plugin.logs["session1"][0] + assert log["callback_type"] == "on_tool_error" + assert log["tool_name"] == "test_tool" + assert log["agent_name"] == "test_agent" + assert log["function_call_id"] == "func1" + assert log["args"] == {"arg1": "value1"} + assert log["error"] == "tool error" + + +@pytest.mark.asyncio +async def test_multiple_sessions(plugin: LogCollectorPlugin): + mock_context1 = create_mock_invocation_context("session1") + mock_context1.invocation_id = "inv1" + user_message1 = types.Content(parts=[types.Part(text="Hello from session 1")]) + + mock_context2 = create_mock_invocation_context("session2") + mock_context2.invocation_id = "inv2" + user_message2 = types.Content(parts=[types.Part(text="Hello from session 2")]) + + await plugin.on_user_message_callback( + invocation_context=mock_context1, user_message=user_message1 + ) + await plugin.on_user_message_callback( + invocation_context=mock_context2, user_message=user_message2 + ) + + assert len(plugin.logs["session1"]) == 1 + assert len(plugin.logs["session2"]) == 1 + + log1 = plugin.logs["session1"][0] + assert log1["callback_type"] == "on_user_message" + assert log1["invocation_id"] == "inv1" + assert log1["user_message"] == "Hello from session 1" + + log2 = plugin.logs["session2"][0] + assert log2["callback_type"] == "on_user_message" + assert log2["invocation_id"] == "inv2" + assert log2["user_message"] == "Hello from session 2" From 4696b685022f8f5ed39d1a8b16d20fb0acc626ca Mon Sep 17 00:00:00 2001 From: "ivanmkc@google.com" Date: Fri, 8 Aug 2025 17:50:06 -0400 Subject: [PATCH 2/3] Add example usage --- .../adk/plugins/log_collector_plugin.py | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/src/google/adk/plugins/log_collector_plugin.py b/src/google/adk/plugins/log_collector_plugin.py index 6b91c50bd8..7c5b05e94c 100644 --- a/src/google/adk/plugins/log_collector_plugin.py +++ b/src/google/adk/plugins/log_collector_plugin.py @@ -36,6 +36,35 @@ class LogCollectorPlugin(BasePlugin): """ A plugin to programmatically and safely collect execution details from all callbacks in async environments, organized by session ID. + + The `session_id` is a user-defined string that you pass to the `Session` + object when you create it. This allows you to group all related logs for a + particular interaction or conversation. + + Example usage: + >>> import asyncio + >>> from google.adk.agents import Agent + >>> from google.adk.plugins import LogCollectorPlugin + >>> from google.adk.runners import InMemoryRunner + >>> + >>> async def main(): + ... log_plugin = LogCollectorPlugin() + ... agent = Agent( + ... # ... other agent parameters + ... ) + ... runner = InMemoryRunner(agent=agent, plugins=[log_plugin]) + ... session = await runner.session_service.create_session( + ... app_name=runner.app_name, user_id="test_user" + ... ) + ... # Run the agent with the session + ... # await runner.run_async(...) + ... # Retrieve logs for the specific session + ... session_logs = log_plugin.get_logs_by_session(session.id) + ... print(session_logs) + >>> + >>> if __name__ == "__main__": + ... asyncio.run(main()) + """ def __init__(self, name: str = "log_collector"): From 6d172d337d53e4fc41615160eff07254f729cef3 Mon Sep 17 00:00:00 2001 From: "ivanmkc@google.com" Date: Fri, 8 Aug 2025 18:49:12 -0400 Subject: [PATCH 3/3] Fixed import issue and formatting --- src/google/adk/plugins/__init__.py | 3 +- .../adk/plugins/log_collector_plugin.py | 17 +- .../plugins/test_log_collector_plugin.py | 420 +++++++++--------- 3 files changed, 228 insertions(+), 212 deletions(-) diff --git a/src/google/adk/plugins/__init__.py b/src/google/adk/plugins/__init__.py index b0c771ede5..dabe6e967c 100644 --- a/src/google/adk/plugins/__init__.py +++ b/src/google/adk/plugins/__init__.py @@ -13,5 +13,6 @@ # limitations under the License. from .base_plugin import BasePlugin +from .log_collector_plugin import LogCollectorPlugin -__all__ = ['BasePlugin'] +__all__ = ['BasePlugin', 'LogCollectorPlugin'] diff --git a/src/google/adk/plugins/log_collector_plugin.py b/src/google/adk/plugins/log_collector_plugin.py index 7c5b05e94c..c05121e0bb 100644 --- a/src/google/adk/plugins/log_collector_plugin.py +++ b/src/google/adk/plugins/log_collector_plugin.py @@ -16,7 +16,11 @@ import asyncio from collections import defaultdict -from typing import Any, Optional, Dict, List, TYPE_CHECKING +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import TYPE_CHECKING from google.adk.agents.base_agent import BaseAgent from google.adk.agents.callback_context import CallbackContext @@ -29,7 +33,7 @@ from google.genai import types if TYPE_CHECKING: - from google.adk.agents.invocation_context import InvocationContext + from google.adk.agents.invocation_context import InvocationContext class LogCollectorPlugin(BasePlugin): @@ -72,13 +76,18 @@ def __init__(self, name: str = "log_collector"): self.logs: Dict[str, List[Dict[str, Any]]] = defaultdict(list) self._lock = asyncio.Lock() - async def _log_entry(self, session_id: str, callback_type: str, data: Dict[str, Any]): + async def _log_entry( + self, session_id: str, callback_type: str, data: Dict[str, Any] + ): log_entry = {"callback_type": callback_type, **data} async with self._lock: self.logs[session_id].append(log_entry) async def on_user_message_callback( - self, *, invocation_context: "InvocationContext", user_message: types.Content + self, + *, + invocation_context: "InvocationContext", + user_message: types.Content, ) -> Optional[types.Content]: session_id = invocation_context.session.id await self._log_entry( diff --git a/tests/unittests/plugins/test_log_collector_plugin.py b/tests/unittests/plugins/test_log_collector_plugin.py index 6935e9cca1..904e129cda 100644 --- a/tests/unittests/plugins/test_log_collector_plugin.py +++ b/tests/unittests/plugins/test_log_collector_plugin.py @@ -39,284 +39,290 @@ def plugin() -> LogCollectorPlugin: def create_mock_invocation_context(session_id: str) -> Mock: - mock_context = Mock(spec=InvocationContext) - mock_context.session = Mock(spec=Session) - mock_context.session.id = session_id - return mock_context + mock_context = Mock(spec=InvocationContext) + mock_context.session = Mock(spec=Session) + mock_context.session.id = session_id + return mock_context def create_mock_callback_context(session_id: str) -> Mock: - mock_context = Mock(spec=CallbackContext) - mock_context._invocation_context = create_mock_invocation_context(session_id) - return mock_context + mock_context = Mock(spec=CallbackContext) + mock_context._invocation_context = create_mock_invocation_context(session_id) + return mock_context def create_mock_tool_context(session_id: str) -> Mock: - mock_context = Mock(spec=ToolContext) - mock_context._invocation_context = create_mock_invocation_context(session_id) - return mock_context + mock_context = Mock(spec=ToolContext) + mock_context._invocation_context = create_mock_invocation_context(session_id) + return mock_context @pytest.mark.asyncio async def test_on_user_message_callback(plugin: LogCollectorPlugin): - mock_context = create_mock_invocation_context("session1") - mock_context.invocation_id = "inv1" - user_message = types.Content(parts=[types.Part(text="Hello")]) + mock_context = create_mock_invocation_context("session1") + mock_context.invocation_id = "inv1" + user_message = types.Content(parts=[types.Part(text="Hello")]) - await plugin.on_user_message_callback( - invocation_context=mock_context, user_message=user_message - ) + await plugin.on_user_message_callback( + invocation_context=mock_context, user_message=user_message + ) - assert len(plugin.logs["session1"]) == 1 - log = plugin.logs["session1"][0] - assert log["callback_type"] == "on_user_message" - assert log["invocation_id"] == "inv1" - assert log["user_message"] == "Hello" + assert len(plugin.logs["session1"]) == 1 + log = plugin.logs["session1"][0] + assert log["callback_type"] == "on_user_message" + assert log["invocation_id"] == "inv1" + assert log["user_message"] == "Hello" @pytest.mark.asyncio async def test_before_agent_callback(plugin: LogCollectorPlugin): - mock_agent = Mock(spec=BaseAgent) - mock_agent.name = "test_agent" - mock_context = create_mock_callback_context("session1") - mock_context.invocation_id = "inv1" + mock_agent = Mock(spec=BaseAgent) + mock_agent.name = "test_agent" + mock_context = create_mock_callback_context("session1") + mock_context.invocation_id = "inv1" - await plugin.before_agent_callback(agent=mock_agent, callback_context=mock_context) + await plugin.before_agent_callback( + agent=mock_agent, callback_context=mock_context + ) - assert len(plugin.logs["session1"]) == 1 - log = plugin.logs["session1"][0] - assert log["callback_type"] == "before_agent" - assert log["agent_name"] == "test_agent" - assert log["invocation_id"] == "inv1" + assert len(plugin.logs["session1"]) == 1 + log = plugin.logs["session1"][0] + assert log["callback_type"] == "before_agent" + assert log["agent_name"] == "test_agent" + assert log["invocation_id"] == "inv1" @pytest.mark.asyncio async def test_after_agent_callback(plugin: LogCollectorPlugin): - mock_agent = Mock(spec=BaseAgent) - mock_agent.name = "test_agent" - mock_context = create_mock_callback_context("session1") - mock_context.invocation_id = "inv1" + mock_agent = Mock(spec=BaseAgent) + mock_agent.name = "test_agent" + mock_context = create_mock_callback_context("session1") + mock_context.invocation_id = "inv1" - await plugin.after_agent_callback(agent=mock_agent, callback_context=mock_context) + await plugin.after_agent_callback( + agent=mock_agent, callback_context=mock_context + ) - assert len(plugin.logs["session1"]) == 1 - log = plugin.logs["session1"][0] - assert log["callback_type"] == "after_agent" - assert log["agent_name"] == "test_agent" - assert log["invocation_id"] == "inv1" + assert len(plugin.logs["session1"]) == 1 + log = plugin.logs["session1"][0] + assert log["callback_type"] == "after_agent" + assert log["agent_name"] == "test_agent" + assert log["invocation_id"] == "inv1" @pytest.mark.asyncio async def test_before_model_callback(plugin: LogCollectorPlugin): - mock_context = create_mock_callback_context("session1") - mock_context.agent_name = "test_agent" - mock_request = Mock(spec=LlmRequest) - mock_request.model_dump.return_value = {"model": "gemini"} + mock_context = create_mock_callback_context("session1") + mock_context.agent_name = "test_agent" + mock_request = Mock(spec=LlmRequest) + mock_request.model_dump.return_value = {"model": "gemini"} - await plugin.before_model_callback( - callback_context=mock_context, llm_request=mock_request - ) + await plugin.before_model_callback( + callback_context=mock_context, llm_request=mock_request + ) - assert len(plugin.logs["session1"]) == 1 - log = plugin.logs["session1"][0] - assert log["callback_type"] == "before_model" - assert log["agent_name"] == "test_agent" - assert log["request"] == {"model": "gemini"} + assert len(plugin.logs["session1"]) == 1 + log = plugin.logs["session1"][0] + assert log["callback_type"] == "before_model" + assert log["agent_name"] == "test_agent" + assert log["request"] == {"model": "gemini"} @pytest.mark.asyncio async def test_after_model_callback(plugin: LogCollectorPlugin): - mock_context = create_mock_callback_context("session1") - mock_context.agent_name = "test_agent" - mock_response = Mock(spec=LlmResponse) - mock_response.model_dump.return_value = {"text": "response"} + mock_context = create_mock_callback_context("session1") + mock_context.agent_name = "test_agent" + mock_response = Mock(spec=LlmResponse) + mock_response.model_dump.return_value = {"text": "response"} - await plugin.after_model_callback( - callback_context=mock_context, llm_response=mock_response - ) + await plugin.after_model_callback( + callback_context=mock_context, llm_response=mock_response + ) - assert len(plugin.logs["session1"]) == 1 - log = plugin.logs["session1"][0] - assert log["callback_type"] == "after_model" - assert log["agent_name"] == "test_agent" - assert log["response"] == {"text": "response"} + assert len(plugin.logs["session1"]) == 1 + log = plugin.logs["session1"][0] + assert log["callback_type"] == "after_model" + assert log["agent_name"] == "test_agent" + assert log["response"] == {"text": "response"} @pytest.mark.asyncio async def test_on_event_callback(plugin: LogCollectorPlugin): - mock_context = create_mock_invocation_context("session1") - mock_event = Mock(spec=Event) - mock_event.id = "event1" - mock_event.author = "test_author" - mock_event.content = Mock(spec=types.Content) - mock_event.content.parts = [types.Part(text="event content")] - mock_event.is_final_response.return_value = True - - await plugin.on_event_callback(invocation_context=mock_context, event=mock_event) - - assert len(plugin.logs["session1"]) == 1 - log = plugin.logs["session1"][0] - assert log["callback_type"] == "on_event" - assert log["event_id"] == "event1" - assert log["author"] == "test_author" - assert log["content"] == "event content" - assert log["is_final"] is True + mock_context = create_mock_invocation_context("session1") + mock_event = Mock(spec=Event) + mock_event.id = "event1" + mock_event.author = "test_author" + mock_event.content = Mock(spec=types.Content) + mock_event.content.parts = [types.Part(text="event content")] + mock_event.is_final_response.return_value = True + + await plugin.on_event_callback( + invocation_context=mock_context, event=mock_event + ) + + assert len(plugin.logs["session1"]) == 1 + log = plugin.logs["session1"][0] + assert log["callback_type"] == "on_event" + assert log["event_id"] == "event1" + assert log["author"] == "test_author" + assert log["content"] == "event content" + assert log["is_final"] is True @pytest.mark.asyncio async def test_before_run_callback(plugin: LogCollectorPlugin): - mock_context = create_mock_invocation_context("session1") - mock_context.invocation_id = "inv1" - mock_context.agent = Mock(spec=BaseAgent) - mock_context.agent.name = "test_agent" + mock_context = create_mock_invocation_context("session1") + mock_context.invocation_id = "inv1" + mock_context.agent = Mock(spec=BaseAgent) + mock_context.agent.name = "test_agent" - await plugin.before_run_callback(invocation_context=mock_context) + await plugin.before_run_callback(invocation_context=mock_context) - assert len(plugin.logs["session1"]) == 1 - log = plugin.logs["session1"][0] - assert log["callback_type"] == "before_run" - assert log["invocation_id"] == "inv1" - assert log["agent_name"] == "test_agent" + assert len(plugin.logs["session1"]) == 1 + log = plugin.logs["session1"][0] + assert log["callback_type"] == "before_run" + assert log["invocation_id"] == "inv1" + assert log["agent_name"] == "test_agent" @pytest.mark.asyncio async def test_after_run_callback(plugin: LogCollectorPlugin): - mock_context = create_mock_invocation_context("session1") - mock_context.invocation_id = "inv1" - mock_context.agent = Mock(spec=BaseAgent) - mock_context.agent.name = "test_agent" + mock_context = create_mock_invocation_context("session1") + mock_context.invocation_id = "inv1" + mock_context.agent = Mock(spec=BaseAgent) + mock_context.agent.name = "test_agent" - await plugin.after_run_callback(invocation_context=mock_context) + await plugin.after_run_callback(invocation_context=mock_context) - assert len(plugin.logs["session1"]) == 1 - log = plugin.logs["session1"][0] - assert log["callback_type"] == "after_run" - assert log["invocation_id"] == "inv1" - assert log["agent_name"] == "test_agent" + assert len(plugin.logs["session1"]) == 1 + log = plugin.logs["session1"][0] + assert log["callback_type"] == "after_run" + assert log["invocation_id"] == "inv1" + assert log["agent_name"] == "test_agent" @pytest.mark.asyncio async def test_on_model_error_callback(plugin: LogCollectorPlugin): - mock_context = create_mock_callback_context("session1") - mock_context.agent_name = "test_agent" - mock_request = Mock(spec=LlmRequest) - mock_request.model_dump.return_value = {"model": "gemini"} - error = ValueError("test error") + mock_context = create_mock_callback_context("session1") + mock_context.agent_name = "test_agent" + mock_request = Mock(spec=LlmRequest) + mock_request.model_dump.return_value = {"model": "gemini"} + error = ValueError("test error") - await plugin.on_model_error_callback( - callback_context=mock_context, llm_request=mock_request, error=error - ) + await plugin.on_model_error_callback( + callback_context=mock_context, llm_request=mock_request, error=error + ) - assert len(plugin.logs["session1"]) == 1 - log = plugin.logs["session1"][0] - assert log["callback_type"] == "on_model_error" - assert log["agent_name"] == "test_agent" - assert log["request"] == {"model": "gemini"} - assert log["error"] == "test error" + assert len(plugin.logs["session1"]) == 1 + log = plugin.logs["session1"][0] + assert log["callback_type"] == "on_model_error" + assert log["agent_name"] == "test_agent" + assert log["request"] == {"model": "gemini"} + assert log["error"] == "test error" @pytest.mark.asyncio async def test_before_tool_callback(plugin: LogCollectorPlugin): - mock_tool = Mock(spec=BaseTool) - mock_tool.name = "test_tool" - mock_context = create_mock_tool_context("session1") - mock_context.agent_name = "test_agent" - mock_context.function_call_id = "func1" - tool_args = {"arg1": "value1"} - - await plugin.before_tool_callback( - tool=mock_tool, tool_args=tool_args, tool_context=mock_context - ) - - assert len(plugin.logs["session1"]) == 1 - log = plugin.logs["session1"][0] - assert log["callback_type"] == "before_tool" - assert log["tool_name"] == "test_tool" - assert log["agent_name"] == "test_agent" - assert log["function_call_id"] == "func1" - assert log["args"] == {"arg1": "value1"} + mock_tool = Mock(spec=BaseTool) + mock_tool.name = "test_tool" + mock_context = create_mock_tool_context("session1") + mock_context.agent_name = "test_agent" + mock_context.function_call_id = "func1" + tool_args = {"arg1": "value1"} + + await plugin.before_tool_callback( + tool=mock_tool, tool_args=tool_args, tool_context=mock_context + ) + + assert len(plugin.logs["session1"]) == 1 + log = plugin.logs["session1"][0] + assert log["callback_type"] == "before_tool" + assert log["tool_name"] == "test_tool" + assert log["agent_name"] == "test_agent" + assert log["function_call_id"] == "func1" + assert log["args"] == {"arg1": "value1"} @pytest.mark.asyncio async def test_after_tool_callback(plugin: LogCollectorPlugin): - mock_tool = Mock(spec=BaseTool) - mock_tool.name = "test_tool" - mock_context = create_mock_tool_context("session1") - mock_context.agent_name = "test_agent" - mock_context.function_call_id = "func1" - tool_args = {"arg1": "value1"} - result = {"result": "success"} - - await plugin.after_tool_callback( - tool=mock_tool, - tool_args=tool_args, - tool_context=mock_context, - result=result, - ) - - assert len(plugin.logs["session1"]) == 1 - log = plugin.logs["session1"][0] - assert log["callback_type"] == "after_tool" - assert log["tool_name"] == "test_tool" - assert log["agent_name"] == "test_agent" - assert log["function_call_id"] == "func1" - assert log["args"] == {"arg1": "value1"} - assert log["result"] == {"result": "success"} + mock_tool = Mock(spec=BaseTool) + mock_tool.name = "test_tool" + mock_context = create_mock_tool_context("session1") + mock_context.agent_name = "test_agent" + mock_context.function_call_id = "func1" + tool_args = {"arg1": "value1"} + result = {"result": "success"} + + await plugin.after_tool_callback( + tool=mock_tool, + tool_args=tool_args, + tool_context=mock_context, + result=result, + ) + + assert len(plugin.logs["session1"]) == 1 + log = plugin.logs["session1"][0] + assert log["callback_type"] == "after_tool" + assert log["tool_name"] == "test_tool" + assert log["agent_name"] == "test_agent" + assert log["function_call_id"] == "func1" + assert log["args"] == {"arg1": "value1"} + assert log["result"] == {"result": "success"} @pytest.mark.asyncio async def test_on_tool_error_callback(plugin: LogCollectorPlugin): - mock_tool = Mock(spec=BaseTool) - mock_tool.name = "test_tool" - mock_context = create_mock_tool_context("session1") - mock_context.agent_name = "test_agent" - mock_context.function_call_id = "func1" - tool_args = {"arg1": "value1"} - error = ValueError("tool error") - - await plugin.on_tool_error_callback( - tool=mock_tool, - tool_args=tool_args, - tool_context=mock_context, - error=error, - ) - - assert len(plugin.logs["session1"]) == 1 - log = plugin.logs["session1"][0] - assert log["callback_type"] == "on_tool_error" - assert log["tool_name"] == "test_tool" - assert log["agent_name"] == "test_agent" - assert log["function_call_id"] == "func1" - assert log["args"] == {"arg1": "value1"} - assert log["error"] == "tool error" + mock_tool = Mock(spec=BaseTool) + mock_tool.name = "test_tool" + mock_context = create_mock_tool_context("session1") + mock_context.agent_name = "test_agent" + mock_context.function_call_id = "func1" + tool_args = {"arg1": "value1"} + error = ValueError("tool error") + + await plugin.on_tool_error_callback( + tool=mock_tool, + tool_args=tool_args, + tool_context=mock_context, + error=error, + ) + + assert len(plugin.logs["session1"]) == 1 + log = plugin.logs["session1"][0] + assert log["callback_type"] == "on_tool_error" + assert log["tool_name"] == "test_tool" + assert log["agent_name"] == "test_agent" + assert log["function_call_id"] == "func1" + assert log["args"] == {"arg1": "value1"} + assert log["error"] == "tool error" @pytest.mark.asyncio async def test_multiple_sessions(plugin: LogCollectorPlugin): - mock_context1 = create_mock_invocation_context("session1") - mock_context1.invocation_id = "inv1" - user_message1 = types.Content(parts=[types.Part(text="Hello from session 1")]) - - mock_context2 = create_mock_invocation_context("session2") - mock_context2.invocation_id = "inv2" - user_message2 = types.Content(parts=[types.Part(text="Hello from session 2")]) - - await plugin.on_user_message_callback( - invocation_context=mock_context1, user_message=user_message1 - ) - await plugin.on_user_message_callback( - invocation_context=mock_context2, user_message=user_message2 - ) - - assert len(plugin.logs["session1"]) == 1 - assert len(plugin.logs["session2"]) == 1 - - log1 = plugin.logs["session1"][0] - assert log1["callback_type"] == "on_user_message" - assert log1["invocation_id"] == "inv1" - assert log1["user_message"] == "Hello from session 1" - - log2 = plugin.logs["session2"][0] - assert log2["callback_type"] == "on_user_message" - assert log2["invocation_id"] == "inv2" - assert log2["user_message"] == "Hello from session 2" + mock_context1 = create_mock_invocation_context("session1") + mock_context1.invocation_id = "inv1" + user_message1 = types.Content(parts=[types.Part(text="Hello from session 1")]) + + mock_context2 = create_mock_invocation_context("session2") + mock_context2.invocation_id = "inv2" + user_message2 = types.Content(parts=[types.Part(text="Hello from session 2")]) + + await plugin.on_user_message_callback( + invocation_context=mock_context1, user_message=user_message1 + ) + await plugin.on_user_message_callback( + invocation_context=mock_context2, user_message=user_message2 + ) + + assert len(plugin.logs["session1"]) == 1 + assert len(plugin.logs["session2"]) == 1 + + log1 = plugin.logs["session1"][0] + assert log1["callback_type"] == "on_user_message" + assert log1["invocation_id"] == "inv1" + assert log1["user_message"] == "Hello from session 1" + + log2 = plugin.logs["session2"][0] + assert log2["callback_type"] == "on_user_message" + assert log2["invocation_id"] == "inv2" + assert log2["user_message"] == "Hello from session 2"