Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,12 @@
AfterInvocationEvent,
AgentInitializedEvent,
BeforeInvocationEvent,
HookCallback,
HookProvider,
HookRegistry,
MessageAddedEvent,
)
from ..hooks.registry import TEvent
from ..interrupt import _InterruptState
from ..models.bedrock import BedrockModel
from ..models.model import Model
Expand Down Expand Up @@ -567,6 +569,50 @@ def cleanup(self) -> None:
"""
self.tool_registry.cleanup()

def add_hook(
self,
callback: HookCallback[TEvent],
event_type: type[TEvent] | None = None,
) -> None:
"""Register a callback function for a specific event type.

This method supports two call patterns:
1. ``add_hook(callback)`` - Event type inferred from callback's type hint
2. ``add_hook(callback, event_type)`` - Event type specified explicitly

Callbacks can be either synchronous or asynchronous functions.

Args:
callback: The callback function to invoke when events of this type occur.
event_type: The class type of events this callback should handle.
If not provided, the event type will be inferred from the callback's
first parameter type hint.

Raises:
ValueError: If event_type is not provided and cannot be inferred from
the callback's type hints.

Example:
```python
def log_model_call(event: BeforeModelCallEvent) -> None:
print(f"Calling model for agent: {event.agent.name}")

agent = Agent()

# With event type inferred from type hint
agent.add_hook(log_model_call)

# With explicit event type
agent.add_hook(log_model_call, BeforeModelCallEvent)
```
Docs:
https://strandsagents.com/latest/documentation/docs/user-guide/concepts/agents/hooks/
"""
if event_type is not None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need an if/else condition here?

self.hooks.add_callback(event_type, callback)
else:
self.hooks.add_callback(callback)

def __del__(self) -> None:
"""Clean up resources when agent is garbage collected."""
# __del__ is called even when an exception is thrown in the constructor,
Expand Down
107 changes: 102 additions & 5 deletions src/strands/hooks/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,15 @@
import logging
from collections.abc import Awaitable, Generator
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar, runtime_checkable
from typing import (
TYPE_CHECKING,
Any,
Generic,
Protocol,
TypeVar,
get_type_hints,
runtime_checkable,
)

from ..interrupt import Interrupt, InterruptException

Expand Down Expand Up @@ -157,27 +165,116 @@ def __init__(self) -> None:
"""Initialize an empty hook registry."""
self._registered_callbacks: dict[type, list[HookCallback]] = {}

def add_callback(self, event_type: type[TEvent], callback: HookCallback[TEvent]) -> None:
def add_callback(
self,
event_type: type[TEvent] | HookCallback[TEvent] | None = None,
callback: HookCallback[TEvent] | None = None,
) -> None:
"""Register a callback function for a specific event type.

This method supports two call patterns:
1. ``add_callback(callback)`` - Event type inferred from callback's type hint
2. ``add_callback(event_type, callback)`` - Event type specified explicitly

Args:
event_type: The class type of events this callback should handle.
When using the single-argument form, pass the callback here instead.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't we solve this another way? This seems like a bad practice to overload methods 😅

I would rather have an explicit enum for event type to auto-discovery, instead of doing this.

callback: The callback function to invoke when events of this type occur.

Raises:
ValueError: If event_type is not provided and cannot be inferred from
the callback's type hints, or if AgentInitializedEvent is registered
with an async callback.

Example:
```python
def my_handler(event: StartRequestEvent):
print("Request started")

# With explicit event type
registry.add_callback(StartRequestEvent, my_handler)

# With event type inferred from type hint
registry.add_callback(my_handler)
```
"""
resolved_callback: HookCallback[TEvent]
resolved_event_type: type[TEvent]

# Support both add_callback(callback) and add_callback(event_type, callback)
if callback is None:
if event_type is None:
raise ValueError("callback is required")
# First argument is actually the callback, infer event_type
if callable(event_type) and not isinstance(event_type, type):
resolved_callback = event_type
resolved_event_type = self._infer_event_type(resolved_callback)
else:
raise ValueError("callback is required when event_type is a type")
elif event_type is None:
# callback provided but event_type is None - infer it
resolved_callback = callback
resolved_event_type = self._infer_event_type(callback)
else:
# Both provided - event_type should be a type
if isinstance(event_type, type):
resolved_callback = callback
resolved_event_type = event_type
else:
raise ValueError("event_type must be a type when callback is provided")

# Related issue: https://github.com/strands-agents/sdk-python/issues/330
if event_type.__name__ == "AgentInitializedEvent" and inspect.iscoroutinefunction(callback):
if resolved_event_type.__name__ == "AgentInitializedEvent" and inspect.iscoroutinefunction(resolved_callback):
raise ValueError("AgentInitializedEvent can only be registered with a synchronous callback")

callbacks = self._registered_callbacks.setdefault(event_type, [])
callbacks.append(callback)
callbacks = self._registered_callbacks.setdefault(resolved_event_type, [])
callbacks.append(resolved_callback)

def _infer_event_type(self, callback: HookCallback[TEvent]) -> type[TEvent]:
"""Infer the event type from a callback's type hints.

Args:
callback: The callback function to inspect.

Returns:
The event type inferred from the callback's first parameter type hint.

Raises:
ValueError: If the event type cannot be inferred from the callback's type hints.
"""
try:
hints = get_type_hints(callback)
except Exception as e:
logger.debug("callback=<%s>, error=<%s> | failed to get type hints", callback, e)
raise ValueError(
"failed to get type hints for callback | cannot infer event type, please provide event_type explicitly"
) from e

# Get the first parameter's type hint
sig = inspect.signature(callback)
params = list(sig.parameters.values())

if not params:
raise ValueError(
"callback has no parameters | cannot infer event type, please provide event_type explicitly"
)

first_param = params[0]
type_hint = hints.get(first_param.name)

if type_hint is None:
raise ValueError(
f"parameter=<{first_param.name}> has no type hint | "
"cannot infer event type, please provide event_type explicitly"
)

# Handle single type
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No union support?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will address this as a follow-up once this is merged

if isinstance(type_hint, type) and issubclass(type_hint, BaseHookEvent):
return type_hint # type: ignore[return-value]

raise ValueError(
f"parameter=<{first_param.name}>, type=<{type_hint}> | type hint must be a subclass of BaseHookEvent"
)

def add_hook(self, hook: HookProvider) -> None:
"""Register all callbacks from a hook provider.
Expand Down
71 changes: 70 additions & 1 deletion tests/strands/agent/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager
from strands.agent.state import AgentState
from strands.handlers.callback_handler import PrintingCallbackHandler, null_callback_handler
from strands.hooks import BeforeToolCallEvent
from strands.hooks import BeforeInvocationEvent, BeforeModelCallEvent, BeforeToolCallEvent
from strands.interrupt import Interrupt
from strands.models.bedrock import DEFAULT_BEDROCK_MODEL_ID, BedrockModel
from strands.session.repository_session_manager import RepositorySessionManager
Expand Down Expand Up @@ -2475,3 +2475,72 @@ def agent_tool(tool_context: ToolContext) -> str:
],
"role": "user",
}


def test_agent_add_hook_registers_callback():
"""Test that add_hook registers a callback with the hooks registry."""
agent = Agent(model=MockedModelProvider([{"role": "assistant", "content": [{"text": "response"}]}]))
callback = unittest.mock.Mock()

agent.add_hook(callback, BeforeModelCallEvent)

# Verify callback was registered by checking it gets invoked
agent("test prompt")
callback.assert_called_once()
# Verify it was called with the correct event type
call_args = callback.call_args[0]
assert isinstance(call_args[0], BeforeModelCallEvent)


def test_agent_add_hook_delegates_to_hooks_add_callback():
"""Test that add_hook delegates to self.hooks.add_callback."""
agent = Agent(model=MockedModelProvider([{"role": "assistant", "content": [{"text": "response"}]}]))
callback = unittest.mock.Mock()

# Spy on the hooks.add_callback method
with unittest.mock.patch.object(agent.hooks, "add_callback") as mock_add_callback:
agent.add_hook(callback, BeforeInvocationEvent)
mock_add_callback.assert_called_once_with(BeforeInvocationEvent, callback)


@pytest.mark.asyncio
async def test_agent_add_hook_works_with_async_callback():
"""Test that add_hook works with async callbacks."""

agent = Agent(model=MockedModelProvider([{"role": "assistant", "content": [{"text": "response"}]}]))
async_callback = unittest.mock.AsyncMock()

agent.add_hook(async_callback, BeforeModelCallEvent)

# Use stream_async to invoke the agent with async support
_ = [event async for event in agent.stream_async("test prompt")]
async_callback.assert_called_once()
# Verify it was called with the correct event type
call_args = async_callback.call_args[0]
assert isinstance(call_args[0], BeforeModelCallEvent)


def test_agent_add_hook_infers_event_type_from_callback():
"""Test that add_hook infers event type from callback type hint."""
agent = Agent(model=MockedModelProvider([{"role": "assistant", "content": [{"text": "response"}]}]))
callback_invoked = []

def typed_callback(event: BeforeModelCallEvent) -> None:
callback_invoked.append(event)

agent.add_hook(typed_callback)
agent("test prompt")

assert len(callback_invoked) == 1
assert isinstance(callback_invoked[0], BeforeModelCallEvent)


def test_agent_add_hook_raises_error_when_no_type_hint():
"""Test that add_hook raises error when event type cannot be inferred."""
agent = Agent(model=MockedModelProvider([{"role": "assistant", "content": [{"text": "response"}]}]))

def untyped_callback(event):
pass

with pytest.raises(ValueError, match="cannot infer event type"):
agent.add_hook(untyped_callback)
106 changes: 106 additions & 0 deletions tests/strands/hooks/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,109 @@ def test_hook_registry_invoke_callbacks_coroutine(registry, agent):

with pytest.raises(RuntimeError, match=r"use invoke_callbacks_async to invoke async callback"):
registry.invoke_callbacks(BeforeInvocationEvent(agent=agent))


def test_hook_registry_add_callback_infers_event_type(registry):
"""Test that add_callback infers event type from callback type hint."""

def typed_callback(event: BeforeInvocationEvent) -> None:
pass

# Register without explicit event_type - should infer from type hint
registry.add_callback(typed_callback)

# Verify callback was registered
assert BeforeInvocationEvent in registry._registered_callbacks
assert typed_callback in registry._registered_callbacks[BeforeInvocationEvent]


def test_hook_registry_add_callback_raises_error_no_type_hint(registry):
"""Test that add_callback raises error when type hint is missing."""

def untyped_callback(event):
pass

with pytest.raises(ValueError, match="cannot infer event type"):
registry.add_callback(untyped_callback)


def test_hook_registry_add_callback_raises_error_invalid_type_hint(registry):
"""Test that add_callback raises error when type hint is not a BaseHookEvent subclass."""

def invalid_callback(event: str) -> None:
pass

with pytest.raises(ValueError, match="must be a subclass of BaseHookEvent"):
registry.add_callback(invalid_callback)


def test_hook_registry_add_callback_raises_error_no_parameters(registry):
"""Test that add_callback raises error when callback has no parameters."""

def no_param_callback() -> None:
pass

with pytest.raises(ValueError, match="callback has no parameters"):
registry.add_callback(no_param_callback)


def test_hook_registry_add_callback_raises_error_when_callback_is_none(registry):
"""Test that add_callback raises error when callback is None and event_type is None."""
with pytest.raises(ValueError, match="callback is required"):
registry.add_callback(None, None)


def test_hook_registry_add_callback_raises_error_when_event_type_is_type_without_callback(registry):
"""Test that add_callback raises error when event_type is a type but callback is None."""
with pytest.raises(ValueError, match="callback is required when event_type is a type"):
registry.add_callback(BeforeInvocationEvent, None)


def test_hook_registry_add_callback_raises_error_when_event_type_is_callable_with_callback(registry):
"""Test that add_callback raises error when event_type is callable and callback is provided."""

def callback1(event: BeforeInvocationEvent) -> None:
pass

def callback2(event: BeforeInvocationEvent) -> None:
pass

with pytest.raises(ValueError, match="event_type must be a type when callback is provided"):
registry.add_callback(callback1, callback2)


def test_hook_registry_add_callback_infers_event_type_when_callback_provided_without_event_type(registry):
"""Test that add_callback infers event type when callback is provided but event_type is None."""

def typed_callback(event: BeforeInvocationEvent) -> None:
pass

registry.add_callback(None, typed_callback)

assert BeforeInvocationEvent in registry._registered_callbacks
assert typed_callback in registry._registered_callbacks[BeforeInvocationEvent]


def test_hook_registry_add_callback_with_explicit_event_type_and_callback(registry):
"""Test that add_callback works with explicit event_type and callback."""

def callback(event: BeforeInvocationEvent) -> None:
pass

registry.add_callback(BeforeInvocationEvent, callback)

assert BeforeInvocationEvent in registry._registered_callbacks
assert callback in registry._registered_callbacks[BeforeInvocationEvent]


def test_hook_registry_add_callback_raises_error_on_type_hints_failure(registry):
"""Test that add_callback raises error when get_type_hints fails."""

class BadCallback:
def __call__(self, event: "NonExistentType") -> None: # noqa: F821
pass

callback = BadCallback()

with pytest.raises(ValueError, match="failed to get type hints for callback"):
registry.add_callback(callback)
Loading