diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 567a92b4a..0766a7983 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -37,10 +37,12 @@ AfterInvocationEvent, AgentInitializedEvent, BeforeInvocationEvent, + HookCallback, HookProvider, HookRegistry, MessageAddedEvent, ) +from ..hooks.registry import TEvent from ..interrupt import _InterruptState from ..models.bedrock import BedrockModel from ..models.model import Model @@ -567,6 +569,50 @@ def cleanup(self) -> None: """ self.tool_registry.cleanup() + def add_hook( + self, + callback: HookCallback[TEvent], + event_type: type[TEvent] | None = None, + ) -> None: + """Register a callback function for a specific event type. + + This method supports two call patterns: + 1. ``add_hook(callback)`` - Event type inferred from callback's type hint + 2. ``add_hook(callback, event_type)`` - Event type specified explicitly + + Callbacks can be either synchronous or asynchronous functions. + + Args: + callback: The callback function to invoke when events of this type occur. + event_type: The class type of events this callback should handle. + If not provided, the event type will be inferred from the callback's + first parameter type hint. + + Raises: + ValueError: If event_type is not provided and cannot be inferred from + the callback's type hints. + + Example: + ```python + def log_model_call(event: BeforeModelCallEvent) -> None: + print(f"Calling model for agent: {event.agent.name}") + + agent = Agent() + + # With event type inferred from type hint + agent.add_hook(log_model_call) + + # With explicit event type + agent.add_hook(log_model_call, BeforeModelCallEvent) + ``` + Docs: + https://strandsagents.com/latest/documentation/docs/user-guide/concepts/agents/hooks/ + """ + if event_type is not None: + self.hooks.add_callback(event_type, callback) + else: + self.hooks.add_callback(callback) + def __del__(self) -> None: """Clean up resources when agent is garbage collected.""" # __del__ is called even when an exception is thrown in the constructor, diff --git a/src/strands/hooks/registry.py b/src/strands/hooks/registry.py index 309e3ba76..3cedb23ad 100644 --- a/src/strands/hooks/registry.py +++ b/src/strands/hooks/registry.py @@ -11,7 +11,15 @@ import logging from collections.abc import Awaitable, Generator from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar, runtime_checkable +from typing import ( + TYPE_CHECKING, + Any, + Generic, + Protocol, + TypeVar, + get_type_hints, + runtime_checkable, +) from ..interrupt import Interrupt, InterruptException @@ -157,27 +165,116 @@ def __init__(self) -> None: """Initialize an empty hook registry.""" self._registered_callbacks: dict[type, list[HookCallback]] = {} - def add_callback(self, event_type: type[TEvent], callback: HookCallback[TEvent]) -> None: + def add_callback( + self, + event_type: type[TEvent] | HookCallback[TEvent] | None = None, + callback: HookCallback[TEvent] | None = None, + ) -> None: """Register a callback function for a specific event type. + This method supports two call patterns: + 1. ``add_callback(callback)`` - Event type inferred from callback's type hint + 2. ``add_callback(event_type, callback)`` - Event type specified explicitly + Args: event_type: The class type of events this callback should handle. + When using the single-argument form, pass the callback here instead. callback: The callback function to invoke when events of this type occur. + Raises: + ValueError: If event_type is not provided and cannot be inferred from + the callback's type hints, or if AgentInitializedEvent is registered + with an async callback. + Example: ```python def my_handler(event: StartRequestEvent): print("Request started") + # With explicit event type registry.add_callback(StartRequestEvent, my_handler) + + # With event type inferred from type hint + registry.add_callback(my_handler) ``` """ + resolved_callback: HookCallback[TEvent] + resolved_event_type: type[TEvent] + + # Support both add_callback(callback) and add_callback(event_type, callback) + if callback is None: + if event_type is None: + raise ValueError("callback is required") + # First argument is actually the callback, infer event_type + if callable(event_type) and not isinstance(event_type, type): + resolved_callback = event_type + resolved_event_type = self._infer_event_type(resolved_callback) + else: + raise ValueError("callback is required when event_type is a type") + elif event_type is None: + # callback provided but event_type is None - infer it + resolved_callback = callback + resolved_event_type = self._infer_event_type(callback) + else: + # Both provided - event_type should be a type + if isinstance(event_type, type): + resolved_callback = callback + resolved_event_type = event_type + else: + raise ValueError("event_type must be a type when callback is provided") + # Related issue: https://github.com/strands-agents/sdk-python/issues/330 - if event_type.__name__ == "AgentInitializedEvent" and inspect.iscoroutinefunction(callback): + if resolved_event_type.__name__ == "AgentInitializedEvent" and inspect.iscoroutinefunction(resolved_callback): raise ValueError("AgentInitializedEvent can only be registered with a synchronous callback") - callbacks = self._registered_callbacks.setdefault(event_type, []) - callbacks.append(callback) + callbacks = self._registered_callbacks.setdefault(resolved_event_type, []) + callbacks.append(resolved_callback) + + def _infer_event_type(self, callback: HookCallback[TEvent]) -> type[TEvent]: + """Infer the event type from a callback's type hints. + + Args: + callback: The callback function to inspect. + + Returns: + The event type inferred from the callback's first parameter type hint. + + Raises: + ValueError: If the event type cannot be inferred from the callback's type hints. + """ + try: + hints = get_type_hints(callback) + except Exception as e: + logger.debug("callback=<%s>, error=<%s> | failed to get type hints", callback, e) + raise ValueError( + "failed to get type hints for callback | cannot infer event type, please provide event_type explicitly" + ) from e + + # Get the first parameter's type hint + sig = inspect.signature(callback) + params = list(sig.parameters.values()) + + if not params: + raise ValueError( + "callback has no parameters | cannot infer event type, please provide event_type explicitly" + ) + + first_param = params[0] + type_hint = hints.get(first_param.name) + + if type_hint is None: + raise ValueError( + f"parameter=<{first_param.name}> has no type hint | " + "cannot infer event type, please provide event_type explicitly" + ) + + # Handle single type + if isinstance(type_hint, type) and issubclass(type_hint, BaseHookEvent): + return type_hint # type: ignore[return-value] + + raise ValueError( + f"parameter=<{first_param.name}>, type=<{type_hint}> | type hint must be a subclass of BaseHookEvent" + ) def add_hook(self, hook: HookProvider) -> None: """Register all callbacks from a hook provider. diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index eb039185c..503b86c6e 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -20,7 +20,7 @@ from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager from strands.agent.state import AgentState from strands.handlers.callback_handler import PrintingCallbackHandler, null_callback_handler -from strands.hooks import BeforeToolCallEvent +from strands.hooks import BeforeInvocationEvent, BeforeModelCallEvent, BeforeToolCallEvent from strands.interrupt import Interrupt from strands.models.bedrock import DEFAULT_BEDROCK_MODEL_ID, BedrockModel from strands.session.repository_session_manager import RepositorySessionManager @@ -2475,3 +2475,72 @@ def agent_tool(tool_context: ToolContext) -> str: ], "role": "user", } + + +def test_agent_add_hook_registers_callback(): + """Test that add_hook registers a callback with the hooks registry.""" + agent = Agent(model=MockedModelProvider([{"role": "assistant", "content": [{"text": "response"}]}])) + callback = unittest.mock.Mock() + + agent.add_hook(callback, BeforeModelCallEvent) + + # Verify callback was registered by checking it gets invoked + agent("test prompt") + callback.assert_called_once() + # Verify it was called with the correct event type + call_args = callback.call_args[0] + assert isinstance(call_args[0], BeforeModelCallEvent) + + +def test_agent_add_hook_delegates_to_hooks_add_callback(): + """Test that add_hook delegates to self.hooks.add_callback.""" + agent = Agent(model=MockedModelProvider([{"role": "assistant", "content": [{"text": "response"}]}])) + callback = unittest.mock.Mock() + + # Spy on the hooks.add_callback method + with unittest.mock.patch.object(agent.hooks, "add_callback") as mock_add_callback: + agent.add_hook(callback, BeforeInvocationEvent) + mock_add_callback.assert_called_once_with(BeforeInvocationEvent, callback) + + +@pytest.mark.asyncio +async def test_agent_add_hook_works_with_async_callback(): + """Test that add_hook works with async callbacks.""" + + agent = Agent(model=MockedModelProvider([{"role": "assistant", "content": [{"text": "response"}]}])) + async_callback = unittest.mock.AsyncMock() + + agent.add_hook(async_callback, BeforeModelCallEvent) + + # Use stream_async to invoke the agent with async support + _ = [event async for event in agent.stream_async("test prompt")] + async_callback.assert_called_once() + # Verify it was called with the correct event type + call_args = async_callback.call_args[0] + assert isinstance(call_args[0], BeforeModelCallEvent) + + +def test_agent_add_hook_infers_event_type_from_callback(): + """Test that add_hook infers event type from callback type hint.""" + agent = Agent(model=MockedModelProvider([{"role": "assistant", "content": [{"text": "response"}]}])) + callback_invoked = [] + + def typed_callback(event: BeforeModelCallEvent) -> None: + callback_invoked.append(event) + + agent.add_hook(typed_callback) + agent("test prompt") + + assert len(callback_invoked) == 1 + assert isinstance(callback_invoked[0], BeforeModelCallEvent) + + +def test_agent_add_hook_raises_error_when_no_type_hint(): + """Test that add_hook raises error when event type cannot be inferred.""" + agent = Agent(model=MockedModelProvider([{"role": "assistant", "content": [{"text": "response"}]}])) + + def untyped_callback(event): + pass + + with pytest.raises(ValueError, match="cannot infer event type"): + agent.add_hook(untyped_callback) diff --git a/tests/strands/hooks/test_registry.py b/tests/strands/hooks/test_registry.py index 3daf41734..e4c5c8874 100644 --- a/tests/strands/hooks/test_registry.py +++ b/tests/strands/hooks/test_registry.py @@ -87,3 +87,109 @@ def test_hook_registry_invoke_callbacks_coroutine(registry, agent): with pytest.raises(RuntimeError, match=r"use invoke_callbacks_async to invoke async callback"): registry.invoke_callbacks(BeforeInvocationEvent(agent=agent)) + + +def test_hook_registry_add_callback_infers_event_type(registry): + """Test that add_callback infers event type from callback type hint.""" + + def typed_callback(event: BeforeInvocationEvent) -> None: + pass + + # Register without explicit event_type - should infer from type hint + registry.add_callback(typed_callback) + + # Verify callback was registered + assert BeforeInvocationEvent in registry._registered_callbacks + assert typed_callback in registry._registered_callbacks[BeforeInvocationEvent] + + +def test_hook_registry_add_callback_raises_error_no_type_hint(registry): + """Test that add_callback raises error when type hint is missing.""" + + def untyped_callback(event): + pass + + with pytest.raises(ValueError, match="cannot infer event type"): + registry.add_callback(untyped_callback) + + +def test_hook_registry_add_callback_raises_error_invalid_type_hint(registry): + """Test that add_callback raises error when type hint is not a BaseHookEvent subclass.""" + + def invalid_callback(event: str) -> None: + pass + + with pytest.raises(ValueError, match="must be a subclass of BaseHookEvent"): + registry.add_callback(invalid_callback) + + +def test_hook_registry_add_callback_raises_error_no_parameters(registry): + """Test that add_callback raises error when callback has no parameters.""" + + def no_param_callback() -> None: + pass + + with pytest.raises(ValueError, match="callback has no parameters"): + registry.add_callback(no_param_callback) + + +def test_hook_registry_add_callback_raises_error_when_callback_is_none(registry): + """Test that add_callback raises error when callback is None and event_type is None.""" + with pytest.raises(ValueError, match="callback is required"): + registry.add_callback(None, None) + + +def test_hook_registry_add_callback_raises_error_when_event_type_is_type_without_callback(registry): + """Test that add_callback raises error when event_type is a type but callback is None.""" + with pytest.raises(ValueError, match="callback is required when event_type is a type"): + registry.add_callback(BeforeInvocationEvent, None) + + +def test_hook_registry_add_callback_raises_error_when_event_type_is_callable_with_callback(registry): + """Test that add_callback raises error when event_type is callable and callback is provided.""" + + def callback1(event: BeforeInvocationEvent) -> None: + pass + + def callback2(event: BeforeInvocationEvent) -> None: + pass + + with pytest.raises(ValueError, match="event_type must be a type when callback is provided"): + registry.add_callback(callback1, callback2) + + +def test_hook_registry_add_callback_infers_event_type_when_callback_provided_without_event_type(registry): + """Test that add_callback infers event type when callback is provided but event_type is None.""" + + def typed_callback(event: BeforeInvocationEvent) -> None: + pass + + registry.add_callback(None, typed_callback) + + assert BeforeInvocationEvent in registry._registered_callbacks + assert typed_callback in registry._registered_callbacks[BeforeInvocationEvent] + + +def test_hook_registry_add_callback_with_explicit_event_type_and_callback(registry): + """Test that add_callback works with explicit event_type and callback.""" + + def callback(event: BeforeInvocationEvent) -> None: + pass + + registry.add_callback(BeforeInvocationEvent, callback) + + assert BeforeInvocationEvent in registry._registered_callbacks + assert callback in registry._registered_callbacks[BeforeInvocationEvent] + + +def test_hook_registry_add_callback_raises_error_on_type_hints_failure(registry): + """Test that add_callback raises error when get_type_hints fails.""" + + class BadCallback: + def __call__(self, event: "NonExistentType") -> None: # noqa: F821 + pass + + callback = BadCallback() + + with pytest.raises(ValueError, match="failed to get type hints for callback"): + registry.add_callback(callback)