diff --git a/src/strands/__init__.py b/src/strands/__init__.py index 6026d4240..8a71ce4d8 100644 --- a/src/strands/__init__.py +++ b/src/strands/__init__.py @@ -4,6 +4,7 @@ from .agent.agent import Agent from .agent.base import AgentBase from .event_loop._retry import ModelRetryStrategy +from .hooks.decorator import hook from .tools.decorator import tool from .types.tools import ToolContext @@ -11,6 +12,7 @@ "Agent", "AgentBase", "agent", + "hook", "models", "ModelRetryStrategy", "tool", diff --git a/src/strands/hooks/__init__.py b/src/strands/hooks/__init__.py index 96c7f577b..599ff0fb6 100644 --- a/src/strands/hooks/__init__.py +++ b/src/strands/hooks/__init__.py @@ -5,7 +5,7 @@ built-in SDK components and user code to react to or modify agent behavior through strongly-typed event callbacks. -Example Usage: +Example Usage with Class-Based Hooks: ```python from strands.hooks import HookProvider, HookRegistry from strands.hooks.events import BeforeInvocationEvent, AfterInvocationEvent @@ -25,10 +25,24 @@ def log_end(self, event: AfterInvocationEvent) -> None: agent = Agent(hooks=[LoggingHooks()]) ``` -This replaces the older callback_handler approach with a more composable, -type-safe system that supports multiple subscribers per event type. +Example Usage with Decorator-Based Hooks: + ```python + from strands import Agent, hook + from strands.hooks import BeforeToolCallEvent + + @hook + def log_tool_calls(event: BeforeToolCallEvent) -> None: + '''Log all tool calls before execution.''' + print(f"Tool: {event.tool_use}") + + agent = Agent(hooks=[log_tool_calls]) + ``` + +This module supports both the class-based HookProvider approach and the newer +decorator-based @hook approach for maximum flexibility. """ +from .decorator import DecoratedFunctionHook, hook from .events import ( AfterInvocationEvent, AfterModelCallEvent, @@ -48,6 +62,7 @@ def log_end(self, event: AfterInvocationEvent) -> None: from .registry import BaseHookEvent, HookCallback, HookEvent, HookProvider, HookRegistry __all__ = [ + # Events "AgentInitializedEvent", "BeforeInvocationEvent", "BeforeToolCallEvent", @@ -56,15 +71,19 @@ def log_end(self, event: AfterInvocationEvent) -> None: "AfterModelCallEvent", "AfterInvocationEvent", "MessageAddedEvent", - "HookEvent", - "HookProvider", - "HookCallback", - "HookRegistry", - "HookEvent", - "BaseHookEvent", + # Multiagent events "AfterMultiAgentInvocationEvent", "AfterNodeCallEvent", "BeforeMultiAgentInvocationEvent", "BeforeNodeCallEvent", "MultiAgentInitializedEvent", + # Registry + "HookEvent", + "HookProvider", + "HookCallback", + "HookRegistry", + "BaseHookEvent", + # Decorator + "hook", + "DecoratedFunctionHook", ] diff --git a/src/strands/hooks/decorator.py b/src/strands/hooks/decorator.py new file mode 100644 index 000000000..89679ff19 --- /dev/null +++ b/src/strands/hooks/decorator.py @@ -0,0 +1,289 @@ +"""Hook decorator for defining hooks as functions. + +This module provides the @hook decorator that transforms Python functions into +HookProvider implementations with automatic event type detection from type hints. + +Example: + ```python + from strands import Agent, hook + from strands.hooks import BeforeToolCallEvent + + @hook + def log_tool_calls(event: BeforeToolCallEvent) -> None: + '''Log all tool calls before execution.''' + print(f"Tool: {event.tool_use}") + + agent = Agent(hooks=[log_tool_calls]) + ``` +""" + +import functools +import inspect +import types +from collections.abc import Callable +from dataclasses import dataclass +from typing import ( + Any, + Generic, + TypeVar, + Union, + cast, + get_args, + get_origin, + get_type_hints, +) + +from .registry import BaseHookEvent, HookCallback, HookProvider, HookRegistry + +TEvent = TypeVar("TEvent", bound=BaseHookEvent) + + +@dataclass +class HookMetadata: + """Metadata extracted from a decorated hook function. + + Attributes: + name: The name of the hook function. + description: Description extracted from the function's docstring. + event_types: List of event types this hook handles. + is_async: Whether the hook function is async. + """ + + name: str + description: str + event_types: list[type[BaseHookEvent]] + is_async: bool + + +class FunctionHookMetadata: + """Helper class to extract and manage function metadata for hook decoration.""" + + def __init__( + self, + func: Callable[..., Any], + ) -> None: + """Initialize with the function to process. + + Args: + func: The function to extract metadata from. + """ + self.func = func + self.signature = inspect.signature(func) + + # Validate and extract event types + self._event_types = self._resolve_event_types() + self._validate_event_types() + + def _resolve_event_types(self) -> list[type[BaseHookEvent]]: + """Resolve event types from type hints. + + Returns: + List of event types this hook handles. + + Raises: + ValueError: If no event type can be determined. + """ + # Try to extract from type hints + try: + type_hints = get_type_hints(self.func) + except Exception: + # get_type_hints can fail for various reasons (forward refs, etc.) + type_hints = {} + + # Find the first parameter's type hint (should be the event) + # Skip 'self' and 'cls' for class methods + params = list(self.signature.parameters.values()) + event_params = [p for p in params if p.name not in ("self", "cls")] + + if not event_params: + raise ValueError( + f"Hook function '{self.func.__name__}' must have at least one parameter " + "for the event with a type hint." + ) + + first_param = event_params[0] + event_type = type_hints.get(first_param.name) + + if event_type is None: + # Check annotation directly (for cases where get_type_hints fails) + if first_param.annotation is not inspect.Parameter.empty: + event_type = first_param.annotation + else: + raise ValueError( + f"Hook function '{self.func.__name__}' must have a type hint for the event parameter." + ) + + # Handle Union types (e.g., BeforeToolCallEvent | AfterToolCallEvent) + return self._extract_event_types_from_annotation(event_type) + + def _is_union_type(self, annotation: Any) -> bool: + """Check if annotation is a Union type (typing.Union or types.UnionType).""" + origin = get_origin(annotation) + if origin is Union: + return True + + # Python 3.10+ uses types.UnionType for `A | B` syntax + if isinstance(annotation, types.UnionType): + return True + + return False + + def _extract_event_types_from_annotation(self, annotation: Any) -> list[type[BaseHookEvent]]: + """Extract event types from a type annotation.""" + # Handle Union types (Union[A, B] or A | B) + if self._is_union_type(annotation): + args = get_args(annotation) + event_types = [] + for arg in args: + # Skip NoneType in Optional[X] + if arg is type(None): + continue + if isinstance(arg, type) and issubclass(arg, BaseHookEvent): + event_types.append(arg) + else: + raise ValueError(f"All types in Union must be subclasses of BaseHookEvent, got {arg}") + return event_types + + # Single type + if isinstance(annotation, type) and issubclass(annotation, BaseHookEvent): + return [annotation] + + raise ValueError(f"Event type must be a subclass of BaseHookEvent, got {annotation}") + + def _validate_event_types(self) -> None: + """Validate that all event types are valid.""" + if not self._event_types: + raise ValueError(f"Hook function '{self.func.__name__}' must handle at least one event type.") + + for event_type in self._event_types: + if not isinstance(event_type, type) or not issubclass(event_type, BaseHookEvent): + raise ValueError(f"Event type must be a subclass of BaseHookEvent, got {event_type}") + + def extract_metadata(self) -> HookMetadata: + """Extract metadata from the function to create hook specification.""" + return HookMetadata( + name=self.func.__name__, + description=inspect.getdoc(self.func) or self.func.__name__, + event_types=self._event_types, + is_async=inspect.iscoroutinefunction(self.func), + ) + + @property + def event_types(self) -> list[type[BaseHookEvent]]: + """Get the event types this hook handles.""" + return self._event_types + + +class DecoratedFunctionHook(HookProvider, Generic[TEvent]): + """A HookProvider that wraps a function decorated with @hook.""" + + _func: Callable[[TEvent], Any] + _metadata: FunctionHookMetadata + _hook_metadata: HookMetadata + + def __init__( + self, + func: Callable[[TEvent], Any], + metadata: FunctionHookMetadata, + ): + """Initialize the decorated function hook. + + Args: + func: The original function being decorated. + metadata: The FunctionHookMetadata object with extracted function information. + """ + self._func = func + self._metadata = metadata + self._hook_metadata = metadata.extract_metadata() + + # Preserve function metadata + functools.update_wrapper(wrapper=self, wrapped=self._func) + + def __get__(self, instance: Any, obj_type: type[Any] | None = None) -> "DecoratedFunctionHook[TEvent]": + """Descriptor protocol implementation for proper method binding.""" + if instance is not None and not inspect.ismethod(self._func): + # Create a bound method + bound_func = self._func.__get__(instance, instance.__class__) + return DecoratedFunctionHook(bound_func, self._metadata) + + return self + + def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None: + """Register callback functions for specific event types.""" + callback = cast(HookCallback[BaseHookEvent], self._func) + for event_type in self._metadata.event_types: + registry.add_callback(event_type, callback) + + def __call__(self, event: TEvent) -> Any: + """Allow direct invocation for testing.""" + return self._func(event) + + @property + def name(self) -> str: + """Get the name of the hook.""" + return self._hook_metadata.name + + @property + def description(self) -> str: + """Get the description of the hook.""" + return self._hook_metadata.description + + @property + def event_types(self) -> list[type[BaseHookEvent]]: + """Get the event types this hook handles.""" + return self._hook_metadata.event_types + + @property + def is_async(self) -> bool: + """Check if this hook is async.""" + return self._hook_metadata.is_async + + def __repr__(self) -> str: + """Return a string representation of the hook.""" + event_names = [e.__name__ for e in self._hook_metadata.event_types] + return f"DecoratedFunctionHook({self._hook_metadata.name}, events={event_names})" + + +# Type variable for the decorated function +F = TypeVar("F", bound=Callable[..., Any]) + + +def hook( + func: F | None = None, +) -> DecoratedFunctionHook[Any] | Callable[[F], DecoratedFunctionHook[Any]]: + """Decorator that transforms a function into a HookProvider. + + The decorated function can be passed directly to Agent(hooks=[...]). + Event types are automatically detected from the function's type hints. + + Args: + func: The function to decorate. + + Returns: + A DecoratedFunctionHook that implements HookProvider. + + Raises: + ValueError: If no event type can be determined from type hints. + ValueError: If event types are not subclasses of BaseHookEvent. + + Example: + ```python + from strands import Agent, hook + from strands.hooks import BeforeToolCallEvent + + @hook + def log_tool_calls(event: BeforeToolCallEvent) -> None: + print(f"Tool: {event.tool_use}") + + agent = Agent(hooks=[log_tool_calls]) + ``` + """ + + def decorator(f: F) -> DecoratedFunctionHook[Any]: + hook_meta = FunctionHookMetadata(f) + return DecoratedFunctionHook(f, hook_meta) + + if func is None: + return decorator + + return decorator(func) diff --git a/tests/strands/hooks/test_decorator.py b/tests/strands/hooks/test_decorator.py new file mode 100644 index 000000000..a14561716 --- /dev/null +++ b/tests/strands/hooks/test_decorator.py @@ -0,0 +1,521 @@ +"""Tests for the @hook decorator.""" + +import asyncio +import unittest.mock as mock +from unittest.mock import MagicMock + +import pytest + +from strands.hooks import ( + AfterToolCallEvent, + BeforeInvocationEvent, + BeforeNodeCallEvent, + BeforeToolCallEvent, + DecoratedFunctionHook, + HookProvider, + HookRegistry, + hook, +) +from strands.hooks.decorator import FunctionHookMetadata, HookMetadata + + +class TestHookDecorator: + """Tests for the @hook decorator function.""" + + def test_basic_decorator_with_type_hint(self): + """Test @hook with type hints extracts event type correctly.""" + + @hook + def my_hook(event: BeforeToolCallEvent) -> None: + pass + + assert isinstance(my_hook, DecoratedFunctionHook) + assert my_hook.name == "my_hook" + assert my_hook.event_types == [BeforeToolCallEvent] + assert not my_hook.is_async + + def test_decorator_with_union_type_hint(self): + """Test @hook with Union type hint extracts multiple event types.""" + + @hook + def my_hook(event: BeforeToolCallEvent | AfterToolCallEvent) -> None: + pass + + assert isinstance(my_hook, DecoratedFunctionHook) + assert set(my_hook.event_types) == {BeforeToolCallEvent, AfterToolCallEvent} + + def test_decorator_with_typing_union(self): + """Test @hook with typing.Union type hint.""" + + @hook + def my_hook(event: BeforeToolCallEvent | AfterToolCallEvent) -> None: + pass + + assert isinstance(my_hook, DecoratedFunctionHook) + assert set(my_hook.event_types) == {BeforeToolCallEvent, AfterToolCallEvent} + + def test_async_hook_detection(self): + """Test that async hooks are detected correctly.""" + + @hook + async def async_hook(event: BeforeToolCallEvent) -> None: + pass + + assert async_hook.is_async + + @hook + def sync_hook(event: BeforeToolCallEvent) -> None: + pass + + assert not sync_hook.is_async + + def test_docstring_extraction(self): + """Test that docstring is extracted as description.""" + + @hook + def documented_hook(event: BeforeToolCallEvent) -> None: + """This is a documented hook for testing.""" + pass + + assert documented_hook.description == "This is a documented hook for testing." + + def test_default_description(self): + """Test that function name is used when no docstring.""" + + @hook + def undocumented_hook(event: BeforeToolCallEvent) -> None: + pass + + assert undocumented_hook.description == "undocumented_hook" + + def test_direct_invocation(self): + """Test that decorated hooks can be called directly.""" + received_events = [] + + @hook + def my_hook(event: BeforeToolCallEvent) -> None: + received_events.append(event) + + mock_event = MagicMock(spec=BeforeToolCallEvent) + my_hook(mock_event) + + assert len(received_events) == 1 + assert received_events[0] is mock_event + + def test_hook_registration(self): + """Test that hooks register correctly with HookRegistry.""" + callback_called = [] + + @hook + def my_hook(event: BeforeToolCallEvent) -> None: + callback_called.append(event) + + registry = HookRegistry() + my_hook.register_hooks(registry) + + mock_agent = MagicMock() + event = BeforeToolCallEvent( + agent=mock_agent, + selected_tool=None, + tool_use={"toolUseId": "test-123", "name": "test_tool", "input": {}}, + invocation_state={}, + ) + + registry.invoke_callbacks(event) + + assert len(callback_called) == 1 + assert callback_called[0] is event + + def test_multi_event_registration(self): + """Test that multi-event hooks register for all event types.""" + events_received = [] + + @hook + def multi_hook(event: BeforeToolCallEvent | AfterToolCallEvent) -> None: + events_received.append(type(event).__name__) + + registry = HookRegistry() + multi_hook.register_hooks(registry) + + mock_agent = MagicMock() + mock_tool_use = {"toolUseId": "test-123", "name": "test_tool", "input": {}} + mock_result = {"toolUseId": "test-123", "status": "success", "content": []} + + before_event = BeforeToolCallEvent( + agent=mock_agent, + selected_tool=None, + tool_use=mock_tool_use, + invocation_state={}, + ) + after_event = AfterToolCallEvent( + agent=mock_agent, + selected_tool=None, + tool_use=mock_tool_use, + invocation_state={}, + result=mock_result, + ) + + registry.invoke_callbacks(before_event) + registry.invoke_callbacks(after_event) + + assert "BeforeToolCallEvent" in events_received + assert "AfterToolCallEvent" in events_received + + def test_repr(self): + """Test string representation of decorated hook.""" + + @hook + def my_hook(event: BeforeToolCallEvent) -> None: + pass + + repr_str = repr(my_hook) + assert "DecoratedFunctionHook" in repr_str + assert "my_hook" in repr_str + assert "BeforeToolCallEvent" in repr_str + + def test_hook_parentheses_no_args(self): + """Test @hook() syntax with empty parentheses.""" + + @hook() + def my_hook(event: BeforeToolCallEvent) -> None: + pass + + assert isinstance(my_hook, DecoratedFunctionHook) + assert my_hook.event_types == [BeforeToolCallEvent] + + def test_optional_type_hint_extracts_event_type(self): + """Test that Optional[EventType] correctly extracts the event type.""" + + @hook + def optional_hook(event: BeforeToolCallEvent | None) -> None: + pass + + assert isinstance(optional_hook, DecoratedFunctionHook) + assert optional_hook.event_types == [BeforeToolCallEvent] + + +class TestHookDecoratorErrors: + """Tests for error handling in @hook decorator.""" + + def test_no_parameters_error(self): + """Test error when function has no parameters.""" + with pytest.raises(ValueError, match="must have at least one parameter"): + + @hook + def no_params() -> None: + pass + + def test_no_type_hint_error(self): + """Test error when no type hint and no explicit event type.""" + with pytest.raises(ValueError, match="must have a type hint"): + + @hook + def no_hint(event) -> None: + pass + + def test_invalid_union_type_error(self): + """Test error when Union contains non-event types.""" + with pytest.raises(ValueError, match="must be subclasses of BaseHookEvent"): + + @hook + def invalid_union(event: BeforeToolCallEvent | str) -> None: # type: ignore + pass + + def test_invalid_annotation_not_event_type(self): + """Test error when annotation is a non-event class type.""" + + class NotAnEvent: + pass + + with pytest.raises(ValueError, match="must be a subclass of BaseHookEvent"): + + @hook + def invalid_hook(event: NotAnEvent) -> None: + pass + + +class TestFunctionHookMetadata: + """Tests for FunctionHookMetadata class.""" + + def test_metadata_extraction(self): + """Test metadata extraction from function.""" + + def my_func(event: BeforeToolCallEvent) -> None: + """A test hook function.""" + pass + + metadata = FunctionHookMetadata(my_func) + hook_meta = metadata.extract_metadata() + + assert isinstance(hook_meta, HookMetadata) + assert hook_meta.name == "my_func" + assert hook_meta.description == "A test hook function." + assert hook_meta.event_types == [BeforeToolCallEvent] + assert not hook_meta.is_async + + def test_event_types_property(self): + """Test FunctionHookMetadata.event_types property.""" + + def my_func(event: BeforeToolCallEvent) -> None: + pass + + metadata = FunctionHookMetadata(my_func) + assert metadata.event_types == [BeforeToolCallEvent] + + +class TestDecoratedFunctionHook: + """Tests for DecoratedFunctionHook class.""" + + def test_hook_provider_protocol(self): + """Test that DecoratedFunctionHook implements HookProvider.""" + + @hook + def my_hook(event: BeforeToolCallEvent) -> None: + pass + + assert hasattr(my_hook, "register_hooks") + assert callable(my_hook.register_hooks) + + def test_function_wrapper_preserves_metadata(self): + """Test that functools.wraps preserves function metadata.""" + + @hook + def original_function(event: BeforeToolCallEvent) -> None: + """Original docstring.""" + pass + + assert original_function.__name__ == "original_function" + assert original_function.__doc__ == "Original docstring." + + +class TestAsyncHooks: + """Tests for async hook support.""" + + def test_async_hook_direct_invocation(self): + """Test async hook direct invocation.""" + received_events = [] + + @hook + async def async_hook(event: BeforeToolCallEvent) -> None: + received_events.append(event) + + mock_event = MagicMock(spec=BeforeToolCallEvent) + asyncio.run(async_hook(mock_event)) + + assert len(received_events) == 1 + assert received_events[0] is mock_event + + def test_async_hook_via_registry(self): + """Test async hook when invoked via registry.""" + received_events = [] + + @hook + async def async_hook(event: BeforeToolCallEvent) -> None: + received_events.append(event) + + registry = HookRegistry() + async_hook.register_hooks(registry) + + mock_agent = MagicMock() + event = BeforeToolCallEvent( + agent=mock_agent, + selected_tool=None, + tool_use={"toolUseId": "test-123", "name": "test_tool", "input": {}}, + invocation_state={}, + ) + + async def run_callbacks(): + for callback in registry.get_callbacks_for(event): + result = callback(event) + if asyncio.iscoroutine(result): + await result + + asyncio.run(run_callbacks()) + + assert len(received_events) == 1 + + +class TestMixedHooksUsage: + """Tests for using decorated hooks alongside class-based hooks.""" + + def test_mixed_hooks_in_registry(self): + """Test using both decorator and class-based hooks together.""" + decorator_called = [] + class_called = [] + + @hook + def decorator_hook(event: BeforeInvocationEvent) -> None: + decorator_called.append(event) + + class ClassHook(HookProvider): + def register_hooks(self, registry: HookRegistry) -> None: + registry.add_callback(BeforeInvocationEvent, self.on_event) + + def on_event(self, event: BeforeInvocationEvent) -> None: + class_called.append(event) + + registry = HookRegistry() + registry.add_hook(decorator_hook) + registry.add_hook(ClassHook()) + + mock_agent = MagicMock() + event = BeforeInvocationEvent(agent=mock_agent) + + registry.invoke_callbacks(event) + + assert len(decorator_called) == 1 + assert len(class_called) == 1 + + +class TestDescriptorProtocol: + """Tests for the __get__ descriptor protocol implementation.""" + + def test_hook_as_class_method(self): + """Test that @hook works correctly on class methods.""" + results = [] + + class MyHooks: + def __init__(self, prefix: str): + self.prefix = prefix + + @hook + def my_hook(self, event: BeforeToolCallEvent) -> None: + results.append(f"{self.prefix}: {event}") + + hooks_instance = MyHooks("test") + bound_hook = hooks_instance.my_hook + + assert isinstance(bound_hook, DecoratedFunctionHook) + + mock_agent = MagicMock() + event = BeforeToolCallEvent( + agent=mock_agent, + selected_tool=None, + tool_use={"toolUseId": "test-123", "name": "test_tool", "input": {}}, + invocation_state={}, + ) + + bound_hook(event) + + assert len(results) == 1 + assert results[0].startswith("test:") + + def test_hook_class_method_via_registry(self): + """Test that class method hooks work with HookRegistry.""" + results = [] + + class MyHooks: + def __init__(self, name: str): + self.name = name + + @hook + def on_tool_call(self, event: BeforeToolCallEvent) -> None: + results.append({"name": self.name, "event": event}) + + hooks_instance = MyHooks("registry_test") + + registry = HookRegistry() + registry.add_hook(hooks_instance.on_tool_call) + + mock_agent = MagicMock() + event = BeforeToolCallEvent( + agent=mock_agent, + selected_tool=None, + tool_use={"toolUseId": "test-123", "name": "test_tool", "input": {}}, + invocation_state={}, + ) + + registry.invoke_callbacks(event) + + assert len(results) == 1 + assert results[0]["name"] == "registry_test" + assert results[0]["event"] is event + + def test_hook_accessed_via_class_returns_self(self): + """Test that accessing hook via class (not instance) returns the hook itself.""" + + class MyHooks: + @hook + def my_hook(self, event: BeforeToolCallEvent) -> None: + pass + + class_hook = MyHooks.my_hook + assert isinstance(class_hook, DecoratedFunctionHook) + + def test_hook_different_instances_are_independent(self): + """Test that hooks bound to different instances are independent.""" + results = [] + + class MyHooks: + def __init__(self, name: str): + self.name = name + + @hook + def my_hook(self, event: BeforeToolCallEvent) -> None: + results.append(self.name) + + hooks1 = MyHooks("first") + hooks2 = MyHooks("second") + + mock_agent = MagicMock() + event = BeforeToolCallEvent( + agent=mock_agent, + selected_tool=None, + tool_use={"toolUseId": "test-123", "name": "test_tool", "input": {}}, + invocation_state={}, + ) + + hooks1.my_hook(event) + hooks2.my_hook(event) + + assert results == ["first", "second"] + + +class TestMultiagentEvents: + """Tests for multiagent event support.""" + + def test_multiagent_hook_works(self): + """Test that hooks for multiagent events work correctly.""" + events_received = [] + + @hook + def node_hook(event: BeforeNodeCallEvent) -> None: + events_received.append(event) + + assert node_hook.event_types == [BeforeNodeCallEvent] + + mock_source = MagicMock() + event = BeforeNodeCallEvent( + source=mock_source, + node_id="test-node", + invocation_state={}, + ) + + node_hook(event) + + assert len(events_received) == 1 + assert events_received[0] is event + + +class TestEdgeCases: + """Edge case tests for remaining coverage gaps.""" + + def test_get_type_hints_exception_fallback(self): + """Test fallback when get_type_hints raises an exception.""" + + def func_with_annotation(event: BeforeToolCallEvent) -> None: + pass + + with mock.patch("strands.hooks.decorator.get_type_hints", side_effect=Exception("Type hint error")): + metadata = FunctionHookMetadata(func_with_annotation) + assert metadata.event_types == [BeforeToolCallEvent] + + def test_annotation_fallback_when_type_hints_empty(self): + """Test annotation is used when get_type_hints returns empty dict.""" + + def func_with_annotation(event: BeforeToolCallEvent) -> None: + pass + + with mock.patch("strands.hooks.decorator.get_type_hints", return_value={}): + metadata = FunctionHookMetadata(func_with_annotation) + assert metadata.event_types == [BeforeToolCallEvent]