From 20f337abcf8f2b2d23517f1e50627f2586a9f9e7 Mon Sep 17 00:00:00 2001 From: Strands Agent <217235299+strands-agent@users.noreply.github.com> Date: Fri, 13 Feb 2026 21:31:25 +0000 Subject: [PATCH 1/4] feat(agent): add add_hook convenience method for hook registration Add a public add_hook method to the Agent class that provides a cleaner API for registering hook callbacks. This is a convenience method that delegates to self.hooks.add_callback but provides a more discoverable interface for users. Changes: - Add Agent.add_hook(callback, event_type=None) method - Support inferring event_type from callback's type hint - Update HookRegistry.add_callback to support type inference - Raise clear ValueError when type cannot be inferred The method: - Accepts both sync and async callbacks - Uses proper type hints with TEvent and HookCallback - Includes comprehensive Google-style docstring with examples Resolves #1686 --- src/strands/agent/agent.py | 40 ++++++++++ src/strands/hooks/registry.py | 109 +++++++++++++++++++++++++-- tests/strands/agent/test_agent.py | 75 ++++++++++++++++++ tests/strands/hooks/test_registry.py | 44 +++++++++++ 4 files changed, 262 insertions(+), 6 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 567a92b4a..9ff54995c 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -37,10 +37,12 @@ AfterInvocationEvent, AgentInitializedEvent, BeforeInvocationEvent, + HookCallback, HookProvider, HookRegistry, MessageAddedEvent, ) +from ..hooks.registry import TEvent from ..interrupt import _InterruptState from ..models.bedrock import BedrockModel from ..models.model import Model @@ -567,6 +569,44 @@ def cleanup(self) -> None: """ self.tool_registry.cleanup() + def add_hook( + self, + callback: HookCallback[TEvent], + event_type: type[TEvent] | None = None, + ) -> None: + """Register a callback function for a specific event type. + + Callbacks can be either synchronous or asynchronous functions. + + Args: + callback: The callback function to invoke when events of this type occur. + event_type: The class type of events this callback should handle. + If not provided, the event type will be inferred from the callback's + first parameter type hint. + + Raises: + ValueError: If event_type is not provided and cannot be inferred from + the callback's type hints. + + Example: + ```python + def log_model_call(event: BeforeModelCallEvent) -> None: + print(f"Calling model for agent: {event.agent.name}") + + agent = Agent() + + # With event type inferred from type hint + agent.add_hook(log_model_call) + + # With explicit event type + agent.add_hook(log_model_call, BeforeModelCallEvent) + ``` + """ + if event_type is not None: + self.hooks.add_callback(event_type, callback) + else: + self.hooks.add_callback(callback) + def __del__(self) -> None: """Clean up resources when agent is garbage collected.""" # __del__ is called even when an exception is thrown in the constructor, diff --git a/src/strands/hooks/registry.py b/src/strands/hooks/registry.py index 309e3ba76..ca3016969 100644 --- a/src/strands/hooks/registry.py +++ b/src/strands/hooks/registry.py @@ -11,7 +11,15 @@ import logging from collections.abc import Awaitable, Generator from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar, runtime_checkable +from typing import ( + TYPE_CHECKING, + Any, + Generic, + Protocol, + TypeVar, + get_type_hints, + runtime_checkable, +) from ..interrupt import Interrupt, InterruptException @@ -157,27 +165,116 @@ def __init__(self) -> None: """Initialize an empty hook registry.""" self._registered_callbacks: dict[type, list[HookCallback]] = {} - def add_callback(self, event_type: type[TEvent], callback: HookCallback[TEvent]) -> None: + def add_callback( + self, + event_type: type[TEvent] | HookCallback[TEvent] | None = None, + callback: HookCallback[TEvent] | None = None, + ) -> None: """Register a callback function for a specific event type. Args: - event_type: The class type of events this callback should handle. + event_type: The class type of events this callback should handle, or the + callback function itself. If a callback is passed as this argument, + the event type will be inferred from its first parameter's type hint. + Note: In a future v2 release, the argument order will change to + (callback, event_type) for a cleaner API. callback: The callback function to invoke when events of this type occur. + Can be passed as the first positional argument if event_type is omitted. + + Raises: + ValueError: If event_type is not provided and cannot be inferred from + the callback's type hints, or if AgentInitializedEvent is registered + with an async callback. Example: ```python def my_handler(event: StartRequestEvent): print("Request started") + # With explicit event type registry.add_callback(StartRequestEvent, my_handler) + + # With event type inferred from type hint + registry.add_callback(my_handler) ``` """ + resolved_callback: HookCallback[TEvent] + resolved_event_type: type[TEvent] + + # Support both add_callback(callback) and add_callback(event_type, callback) + if callback is None: + if event_type is None: + raise ValueError("callback is required") + # First argument is actually the callback, infer event_type + if callable(event_type) and not isinstance(event_type, type): + resolved_callback = event_type + resolved_event_type = self._infer_event_type(resolved_callback) + else: + raise ValueError("callback is required when event_type is a type") + elif event_type is None: + # callback provided but event_type is None - infer it + resolved_callback = callback + resolved_event_type = self._infer_event_type(callback) + else: + # Both provided - event_type should be a type + if isinstance(event_type, type): + resolved_callback = callback + resolved_event_type = event_type + else: + raise ValueError("event_type must be a type when callback is provided") + # Related issue: https://github.com/strands-agents/sdk-python/issues/330 - if event_type.__name__ == "AgentInitializedEvent" and inspect.iscoroutinefunction(callback): + if resolved_event_type.__name__ == "AgentInitializedEvent" and inspect.iscoroutinefunction(resolved_callback): raise ValueError("AgentInitializedEvent can only be registered with a synchronous callback") - callbacks = self._registered_callbacks.setdefault(event_type, []) - callbacks.append(callback) + callbacks = self._registered_callbacks.setdefault(resolved_event_type, []) + callbacks.append(resolved_callback) + + def _infer_event_type(self, callback: HookCallback[TEvent]) -> type[TEvent]: + """Infer the event type from a callback's type hints. + + Args: + callback: The callback function to inspect. + + Returns: + The event type inferred from the callback's first parameter type hint. + + Raises: + ValueError: If the event type cannot be inferred from the callback's type hints. + """ + try: + hints = get_type_hints(callback) + except Exception as e: + logger.debug("callback=<%s>, error=<%s> | failed to get type hints", callback, e) + raise ValueError( + "failed to get type hints for callback | cannot infer event type, please provide event_type explicitly" + ) from e + + # Get the first parameter's type hint + sig = inspect.signature(callback) + params = list(sig.parameters.values()) + + if not params: + raise ValueError( + "callback has no parameters | cannot infer event type, please provide event_type explicitly" + ) + + first_param = params[0] + type_hint = hints.get(first_param.name) + + if type_hint is None: + raise ValueError( + f"parameter=<{first_param.name}> has no type hint | " + "cannot infer event type, please provide event_type explicitly" + ) + + # Handle single type + if isinstance(type_hint, type) and issubclass(type_hint, BaseHookEvent): + return type_hint # type: ignore[return-value] + + raise ValueError( + f"parameter=<{first_param.name}>, type=<{type_hint}> | type hint must be a subclass of BaseHookEvent" + ) def add_hook(self, hook: HookProvider) -> None: """Register all callbacks from a hook provider. diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index eb039185c..e74da55d3 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -2475,3 +2475,78 @@ def agent_tool(tool_context: ToolContext) -> str: ], "role": "user", } + + +def test_agent_add_hook_registers_callback(): + """Test that add_hook registers a callback with the hooks registry.""" + from strands.hooks import BeforeModelCallEvent + + agent = Agent(model=MockedModelProvider([{"role": "assistant", "content": [{"text": "response"}]}])) + callback = unittest.mock.Mock() + + agent.add_hook(callback, BeforeModelCallEvent) + + # Verify callback was registered by checking it gets invoked + agent("test prompt") + callback.assert_called_once() + # Verify it was called with the correct event type + call_args = callback.call_args[0] + assert isinstance(call_args[0], BeforeModelCallEvent) + + +def test_agent_add_hook_delegates_to_hooks_add_callback(): + """Test that add_hook delegates to self.hooks.add_callback.""" + from strands.hooks import BeforeInvocationEvent + + agent = Agent(model=MockedModelProvider([{"role": "assistant", "content": [{"text": "response"}]}])) + callback = unittest.mock.Mock() + + # Spy on the hooks.add_callback method + with unittest.mock.patch.object(agent.hooks, "add_callback") as mock_add_callback: + agent.add_hook(callback, BeforeInvocationEvent) + mock_add_callback.assert_called_once_with(BeforeInvocationEvent, callback) + + +@pytest.mark.asyncio +async def test_agent_add_hook_works_with_async_callback(): + """Test that add_hook works with async callbacks.""" + from strands.hooks import BeforeModelCallEvent + + agent = Agent(model=MockedModelProvider([{"role": "assistant", "content": [{"text": "response"}]}])) + async_callback = unittest.mock.AsyncMock() + + agent.add_hook(async_callback, BeforeModelCallEvent) + + # Use stream_async to invoke the agent with async support + _ = [event async for event in agent.stream_async("test prompt")] + async_callback.assert_called_once() + # Verify it was called with the correct event type + call_args = async_callback.call_args[0] + assert isinstance(call_args[0], BeforeModelCallEvent) + + +def test_agent_add_hook_infers_event_type_from_callback(): + """Test that add_hook infers event type from callback type hint.""" + from strands.hooks import BeforeModelCallEvent + + agent = Agent(model=MockedModelProvider([{"role": "assistant", "content": [{"text": "response"}]}])) + + def typed_callback(event: BeforeModelCallEvent) -> None: + pass + + # Register without explicit event_type - should infer from type hint + agent.add_hook(typed_callback) + + # Verify callback was registered by checking it gets invoked + agent("test prompt") + + +def test_agent_add_hook_raises_error_when_no_type_hint(): + """Test that add_hook raises error when event type cannot be inferred.""" + agent = Agent(model=MockedModelProvider([{"role": "assistant", "content": [{"text": "response"}]}])) + + def untyped_callback(event): + pass + + with pytest.raises(ValueError, match="cannot infer event type"): + agent.add_hook(untyped_callback) diff --git a/tests/strands/hooks/test_registry.py b/tests/strands/hooks/test_registry.py index 3daf41734..5793e6d13 100644 --- a/tests/strands/hooks/test_registry.py +++ b/tests/strands/hooks/test_registry.py @@ -87,3 +87,47 @@ def test_hook_registry_invoke_callbacks_coroutine(registry, agent): with pytest.raises(RuntimeError, match=r"use invoke_callbacks_async to invoke async callback"): registry.invoke_callbacks(BeforeInvocationEvent(agent=agent)) + + +def test_hook_registry_add_callback_infers_event_type(registry): + """Test that add_callback infers event type from callback type hint.""" + + def typed_callback(event: BeforeInvocationEvent) -> None: + pass + + # Register without explicit event_type - should infer from type hint + registry.add_callback(typed_callback) + + # Verify callback was registered + assert BeforeInvocationEvent in registry._registered_callbacks + assert typed_callback in registry._registered_callbacks[BeforeInvocationEvent] + + +def test_hook_registry_add_callback_raises_error_no_type_hint(registry): + """Test that add_callback raises error when type hint is missing.""" + + def untyped_callback(event): + pass + + with pytest.raises(ValueError, match="cannot infer event type"): + registry.add_callback(untyped_callback) + + +def test_hook_registry_add_callback_raises_error_invalid_type_hint(registry): + """Test that add_callback raises error when type hint is not a BaseHookEvent subclass.""" + + def invalid_callback(event: str) -> None: + pass + + with pytest.raises(ValueError, match="must be a subclass of BaseHookEvent"): + registry.add_callback(invalid_callback) + + +def test_hook_registry_add_callback_raises_error_no_parameters(registry): + """Test that add_callback raises error when callback has no parameters.""" + + def no_param_callback() -> None: + pass + + with pytest.raises(ValueError, match="callback has no parameters"): + registry.add_callback(no_param_callback) From 5009a02a9e5da8f7aacb14b40f6c4c48694344de Mon Sep 17 00:00:00 2001 From: Strands Agent <217235299+strands-agent@users.noreply.github.com> Date: Mon, 16 Feb 2026 17:09:16 +0000 Subject: [PATCH 2/4] refactor(hooks): add @overload decorators for clearer type hints - Add @overload decorators to HookRegistry.add_callback() for clearer type hints - Add @overload decorators to Agent.add_hook() for clearer type hints - Improves IDE support and documentation of supported call patterns - Addresses review feedback from PR review --- src/strands/agent/agent.py | 11 +++++++++++ src/strands/hooks/registry.py | 19 +++++++++++++------ 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 9ff54995c..3a6aec098 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -19,6 +19,7 @@ TypeVar, Union, cast, + overload, ) from opentelemetry import trace as trace_api @@ -569,6 +570,12 @@ def cleanup(self) -> None: """ self.tool_registry.cleanup() + @overload + def add_hook(self, callback: HookCallback[TEvent]) -> None: ... + + @overload + def add_hook(self, callback: HookCallback[TEvent], event_type: type[TEvent]) -> None: ... + def add_hook( self, callback: HookCallback[TEvent], @@ -576,6 +583,10 @@ def add_hook( ) -> None: """Register a callback function for a specific event type. + This method supports two call patterns: + 1. ``add_hook(callback)`` - Event type inferred from callback's type hint + 2. ``add_hook(callback, event_type)`` - Event type specified explicitly + Callbacks can be either synchronous or asynchronous functions. Args: diff --git a/src/strands/hooks/registry.py b/src/strands/hooks/registry.py index ca3016969..7886fc9c0 100644 --- a/src/strands/hooks/registry.py +++ b/src/strands/hooks/registry.py @@ -18,6 +18,7 @@ Protocol, TypeVar, get_type_hints, + overload, runtime_checkable, ) @@ -165,6 +166,12 @@ def __init__(self) -> None: """Initialize an empty hook registry.""" self._registered_callbacks: dict[type, list[HookCallback]] = {} + @overload + def add_callback(self, callback: HookCallback[TEvent]) -> None: ... + + @overload + def add_callback(self, event_type: type[TEvent], callback: HookCallback[TEvent]) -> None: ... + def add_callback( self, event_type: type[TEvent] | HookCallback[TEvent] | None = None, @@ -172,14 +179,14 @@ def add_callback( ) -> None: """Register a callback function for a specific event type. + This method supports two call patterns: + 1. ``add_callback(callback)`` - Event type inferred from callback's type hint + 2. ``add_callback(event_type, callback)`` - Event type specified explicitly + Args: - event_type: The class type of events this callback should handle, or the - callback function itself. If a callback is passed as this argument, - the event type will be inferred from its first parameter's type hint. - Note: In a future v2 release, the argument order will change to - (callback, event_type) for a cleaner API. + event_type: The class type of events this callback should handle. + When using the single-argument form, pass the callback here instead. callback: The callback function to invoke when events of this type occur. - Can be passed as the first positional argument if event_type is omitted. Raises: ValueError: If event_type is not provided and cannot be inferred from From 9f0b8b53509fd5295bc7b9bd68918c654364a955 Mon Sep 17 00:00:00 2001 From: Nicholas Clegg Date: Mon, 16 Feb 2026 15:22:59 -0500 Subject: [PATCH 3/4] Update method overloads --- src/strands/agent/agent.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 3a6aec098..3ddfbfeb6 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -19,7 +19,6 @@ TypeVar, Union, cast, - overload, ) from opentelemetry import trace as trace_api @@ -570,12 +569,6 @@ def cleanup(self) -> None: """ self.tool_registry.cleanup() - @overload - def add_hook(self, callback: HookCallback[TEvent]) -> None: ... - - @overload - def add_hook(self, callback: HookCallback[TEvent], event_type: type[TEvent]) -> None: ... - def add_hook( self, callback: HookCallback[TEvent], From 084247933b1f91344529584c27edf26e12a66386 Mon Sep 17 00:00:00 2001 From: Nicholas Clegg Date: Tue, 17 Feb 2026 15:21:05 -0500 Subject: [PATCH 4/4] Address pr feedback --- src/strands/agent/agent.py | 2 + src/strands/hooks/registry.py | 7 ---- tests/strands/agent/test_agent.py | 18 +++----- tests/strands/hooks/test_registry.py | 62 ++++++++++++++++++++++++++++ 4 files changed, 70 insertions(+), 19 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 3ddfbfeb6..0766a7983 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -605,6 +605,8 @@ def log_model_call(event: BeforeModelCallEvent) -> None: # With explicit event type agent.add_hook(log_model_call, BeforeModelCallEvent) ``` + Docs: + https://strandsagents.com/latest/documentation/docs/user-guide/concepts/agents/hooks/ """ if event_type is not None: self.hooks.add_callback(event_type, callback) diff --git a/src/strands/hooks/registry.py b/src/strands/hooks/registry.py index 7886fc9c0..3cedb23ad 100644 --- a/src/strands/hooks/registry.py +++ b/src/strands/hooks/registry.py @@ -18,7 +18,6 @@ Protocol, TypeVar, get_type_hints, - overload, runtime_checkable, ) @@ -166,12 +165,6 @@ def __init__(self) -> None: """Initialize an empty hook registry.""" self._registered_callbacks: dict[type, list[HookCallback]] = {} - @overload - def add_callback(self, callback: HookCallback[TEvent]) -> None: ... - - @overload - def add_callback(self, event_type: type[TEvent], callback: HookCallback[TEvent]) -> None: ... - def add_callback( self, event_type: type[TEvent] | HookCallback[TEvent] | None = None, diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index e74da55d3..503b86c6e 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -20,7 +20,7 @@ from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager from strands.agent.state import AgentState from strands.handlers.callback_handler import PrintingCallbackHandler, null_callback_handler -from strands.hooks import BeforeToolCallEvent +from strands.hooks import BeforeInvocationEvent, BeforeModelCallEvent, BeforeToolCallEvent from strands.interrupt import Interrupt from strands.models.bedrock import DEFAULT_BEDROCK_MODEL_ID, BedrockModel from strands.session.repository_session_manager import RepositorySessionManager @@ -2479,8 +2479,6 @@ def agent_tool(tool_context: ToolContext) -> str: def test_agent_add_hook_registers_callback(): """Test that add_hook registers a callback with the hooks registry.""" - from strands.hooks import BeforeModelCallEvent - agent = Agent(model=MockedModelProvider([{"role": "assistant", "content": [{"text": "response"}]}])) callback = unittest.mock.Mock() @@ -2496,8 +2494,6 @@ def test_agent_add_hook_registers_callback(): def test_agent_add_hook_delegates_to_hooks_add_callback(): """Test that add_hook delegates to self.hooks.add_callback.""" - from strands.hooks import BeforeInvocationEvent - agent = Agent(model=MockedModelProvider([{"role": "assistant", "content": [{"text": "response"}]}])) callback = unittest.mock.Mock() @@ -2510,7 +2506,6 @@ def test_agent_add_hook_delegates_to_hooks_add_callback(): @pytest.mark.asyncio async def test_agent_add_hook_works_with_async_callback(): """Test that add_hook works with async callbacks.""" - from strands.hooks import BeforeModelCallEvent agent = Agent(model=MockedModelProvider([{"role": "assistant", "content": [{"text": "response"}]}])) async_callback = unittest.mock.AsyncMock() @@ -2527,19 +2522,18 @@ async def test_agent_add_hook_works_with_async_callback(): def test_agent_add_hook_infers_event_type_from_callback(): """Test that add_hook infers event type from callback type hint.""" - from strands.hooks import BeforeModelCallEvent - agent = Agent(model=MockedModelProvider([{"role": "assistant", "content": [{"text": "response"}]}])) + callback_invoked = [] def typed_callback(event: BeforeModelCallEvent) -> None: - pass + callback_invoked.append(event) - # Register without explicit event_type - should infer from type hint agent.add_hook(typed_callback) - - # Verify callback was registered by checking it gets invoked agent("test prompt") + assert len(callback_invoked) == 1 + assert isinstance(callback_invoked[0], BeforeModelCallEvent) + def test_agent_add_hook_raises_error_when_no_type_hint(): """Test that add_hook raises error when event type cannot be inferred.""" diff --git a/tests/strands/hooks/test_registry.py b/tests/strands/hooks/test_registry.py index 5793e6d13..e4c5c8874 100644 --- a/tests/strands/hooks/test_registry.py +++ b/tests/strands/hooks/test_registry.py @@ -131,3 +131,65 @@ def no_param_callback() -> None: with pytest.raises(ValueError, match="callback has no parameters"): registry.add_callback(no_param_callback) + + +def test_hook_registry_add_callback_raises_error_when_callback_is_none(registry): + """Test that add_callback raises error when callback is None and event_type is None.""" + with pytest.raises(ValueError, match="callback is required"): + registry.add_callback(None, None) + + +def test_hook_registry_add_callback_raises_error_when_event_type_is_type_without_callback(registry): + """Test that add_callback raises error when event_type is a type but callback is None.""" + with pytest.raises(ValueError, match="callback is required when event_type is a type"): + registry.add_callback(BeforeInvocationEvent, None) + + +def test_hook_registry_add_callback_raises_error_when_event_type_is_callable_with_callback(registry): + """Test that add_callback raises error when event_type is callable and callback is provided.""" + + def callback1(event: BeforeInvocationEvent) -> None: + pass + + def callback2(event: BeforeInvocationEvent) -> None: + pass + + with pytest.raises(ValueError, match="event_type must be a type when callback is provided"): + registry.add_callback(callback1, callback2) + + +def test_hook_registry_add_callback_infers_event_type_when_callback_provided_without_event_type(registry): + """Test that add_callback infers event type when callback is provided but event_type is None.""" + + def typed_callback(event: BeforeInvocationEvent) -> None: + pass + + registry.add_callback(None, typed_callback) + + assert BeforeInvocationEvent in registry._registered_callbacks + assert typed_callback in registry._registered_callbacks[BeforeInvocationEvent] + + +def test_hook_registry_add_callback_with_explicit_event_type_and_callback(registry): + """Test that add_callback works with explicit event_type and callback.""" + + def callback(event: BeforeInvocationEvent) -> None: + pass + + registry.add_callback(BeforeInvocationEvent, callback) + + assert BeforeInvocationEvent in registry._registered_callbacks + assert callback in registry._registered_callbacks[BeforeInvocationEvent] + + +def test_hook_registry_add_callback_raises_error_on_type_hints_failure(registry): + """Test that add_callback raises error when get_type_hints fails.""" + + class BadCallback: + def __call__(self, event: "NonExistentType") -> None: # noqa: F821 + pass + + callback = BadCallback() + + with pytest.raises(ValueError, match="failed to get type hints for callback"): + registry.add_callback(callback)