-
Notifications
You must be signed in to change notification settings - Fork 653
feat(agent): add add_hook convenience method for hook callback registration #1706
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
20f337a
5009a02
9f0b8b5
0842479
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -37,10 +37,12 @@ | |
| AfterInvocationEvent, | ||
| AgentInitializedEvent, | ||
| BeforeInvocationEvent, | ||
| HookCallback, | ||
| HookProvider, | ||
| HookRegistry, | ||
| MessageAddedEvent, | ||
| ) | ||
| from ..hooks.registry import TEvent | ||
| from ..interrupt import _InterruptState | ||
| from ..models.bedrock import BedrockModel | ||
| from ..models.model import Model | ||
|
|
@@ -567,6 +569,50 @@ def cleanup(self) -> None: | |
| """ | ||
| self.tool_registry.cleanup() | ||
|
|
||
| def add_hook( | ||
| self, | ||
| callback: HookCallback[TEvent], | ||
| event_type: type[TEvent] | None = None, | ||
Unshure marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| ) -> None: | ||
| """Register a callback function for a specific event type. | ||
|
|
||
| This method supports two call patterns: | ||
| 1. ``add_hook(callback)`` - Event type inferred from callback's type hint | ||
| 2. ``add_hook(callback, event_type)`` - Event type specified explicitly | ||
|
|
||
| Callbacks can be either synchronous or asynchronous functions. | ||
|
|
||
| Args: | ||
| callback: The callback function to invoke when events of this type occur. | ||
| event_type: The class type of events this callback should handle. | ||
| If not provided, the event type will be inferred from the callback's | ||
| first parameter type hint. | ||
|
|
||
| Raises: | ||
| ValueError: If event_type is not provided and cannot be inferred from | ||
| the callback's type hints. | ||
|
|
||
| Example: | ||
| ```python | ||
| def log_model_call(event: BeforeModelCallEvent) -> None: | ||
| print(f"Calling model for agent: {event.agent.name}") | ||
|
|
||
| agent = Agent() | ||
|
|
||
| # With event type inferred from type hint | ||
| agent.add_hook(log_model_call) | ||
|
|
||
| # With explicit event type | ||
| agent.add_hook(log_model_call, BeforeModelCallEvent) | ||
| ``` | ||
| Docs: | ||
| https://strandsagents.com/latest/documentation/docs/user-guide/concepts/agents/hooks/ | ||
| """ | ||
| if event_type is not None: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. do we need an if/else condition here? |
||
| self.hooks.add_callback(event_type, callback) | ||
| else: | ||
| self.hooks.add_callback(callback) | ||
|
|
||
| def __del__(self) -> None: | ||
| """Clean up resources when agent is garbage collected.""" | ||
| # __del__ is called even when an exception is thrown in the constructor, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,7 +11,15 @@ | |
| import logging | ||
| from collections.abc import Awaitable, Generator | ||
| from dataclasses import dataclass | ||
| from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar, runtime_checkable | ||
| from typing import ( | ||
| TYPE_CHECKING, | ||
| Any, | ||
| Generic, | ||
| Protocol, | ||
| TypeVar, | ||
| get_type_hints, | ||
| runtime_checkable, | ||
| ) | ||
|
|
||
| from ..interrupt import Interrupt, InterruptException | ||
|
|
||
|
|
@@ -157,27 +165,116 @@ def __init__(self) -> None: | |
| """Initialize an empty hook registry.""" | ||
| self._registered_callbacks: dict[type, list[HookCallback]] = {} | ||
|
|
||
| def add_callback(self, event_type: type[TEvent], callback: HookCallback[TEvent]) -> None: | ||
| def add_callback( | ||
Unshure marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| self, | ||
| event_type: type[TEvent] | HookCallback[TEvent] | None = None, | ||
Unshure marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| callback: HookCallback[TEvent] | None = None, | ||
| ) -> None: | ||
| """Register a callback function for a specific event type. | ||
|
|
||
| This method supports two call patterns: | ||
| 1. ``add_callback(callback)`` - Event type inferred from callback's type hint | ||
| 2. ``add_callback(event_type, callback)`` - Event type specified explicitly | ||
|
|
||
| Args: | ||
| event_type: The class type of events this callback should handle. | ||
| When using the single-argument form, pass the callback here instead. | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can't we solve this another way? This seems like a bad practice to overload methods 😅 I would rather have an explicit enum for event type to auto-discovery, instead of doing this. |
||
| callback: The callback function to invoke when events of this type occur. | ||
|
|
||
| Raises: | ||
| ValueError: If event_type is not provided and cannot be inferred from | ||
| the callback's type hints, or if AgentInitializedEvent is registered | ||
| with an async callback. | ||
|
|
||
| Example: | ||
| ```python | ||
| def my_handler(event: StartRequestEvent): | ||
| print("Request started") | ||
|
|
||
| # With explicit event type | ||
| registry.add_callback(StartRequestEvent, my_handler) | ||
|
|
||
| # With event type inferred from type hint | ||
| registry.add_callback(my_handler) | ||
| ``` | ||
| """ | ||
| resolved_callback: HookCallback[TEvent] | ||
| resolved_event_type: type[TEvent] | ||
|
|
||
| # Support both add_callback(callback) and add_callback(event_type, callback) | ||
| if callback is None: | ||
| if event_type is None: | ||
| raise ValueError("callback is required") | ||
| # First argument is actually the callback, infer event_type | ||
| if callable(event_type) and not isinstance(event_type, type): | ||
| resolved_callback = event_type | ||
| resolved_event_type = self._infer_event_type(resolved_callback) | ||
| else: | ||
| raise ValueError("callback is required when event_type is a type") | ||
| elif event_type is None: | ||
| # callback provided but event_type is None - infer it | ||
| resolved_callback = callback | ||
| resolved_event_type = self._infer_event_type(callback) | ||
| else: | ||
| # Both provided - event_type should be a type | ||
| if isinstance(event_type, type): | ||
| resolved_callback = callback | ||
| resolved_event_type = event_type | ||
| else: | ||
| raise ValueError("event_type must be a type when callback is provided") | ||
|
|
||
| # Related issue: https://github.com/strands-agents/sdk-python/issues/330 | ||
| if event_type.__name__ == "AgentInitializedEvent" and inspect.iscoroutinefunction(callback): | ||
| if resolved_event_type.__name__ == "AgentInitializedEvent" and inspect.iscoroutinefunction(resolved_callback): | ||
| raise ValueError("AgentInitializedEvent can only be registered with a synchronous callback") | ||
|
|
||
| callbacks = self._registered_callbacks.setdefault(event_type, []) | ||
| callbacks.append(callback) | ||
| callbacks = self._registered_callbacks.setdefault(resolved_event_type, []) | ||
| callbacks.append(resolved_callback) | ||
|
|
||
| def _infer_event_type(self, callback: HookCallback[TEvent]) -> type[TEvent]: | ||
| """Infer the event type from a callback's type hints. | ||
|
|
||
| Args: | ||
| callback: The callback function to inspect. | ||
|
|
||
| Returns: | ||
| The event type inferred from the callback's first parameter type hint. | ||
|
|
||
| Raises: | ||
| ValueError: If the event type cannot be inferred from the callback's type hints. | ||
| """ | ||
| try: | ||
| hints = get_type_hints(callback) | ||
| except Exception as e: | ||
| logger.debug("callback=<%s>, error=<%s> | failed to get type hints", callback, e) | ||
| raise ValueError( | ||
| "failed to get type hints for callback | cannot infer event type, please provide event_type explicitly" | ||
| ) from e | ||
|
|
||
| # Get the first parameter's type hint | ||
| sig = inspect.signature(callback) | ||
| params = list(sig.parameters.values()) | ||
|
|
||
| if not params: | ||
| raise ValueError( | ||
| "callback has no parameters | cannot infer event type, please provide event_type explicitly" | ||
| ) | ||
|
|
||
| first_param = params[0] | ||
| type_hint = hints.get(first_param.name) | ||
|
|
||
| if type_hint is None: | ||
| raise ValueError( | ||
| f"parameter=<{first_param.name}> has no type hint | " | ||
| "cannot infer event type, please provide event_type explicitly" | ||
| ) | ||
|
|
||
| # Handle single type | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No union support?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Will address this as a follow-up once this is merged |
||
| if isinstance(type_hint, type) and issubclass(type_hint, BaseHookEvent): | ||
| return type_hint # type: ignore[return-value] | ||
|
|
||
| raise ValueError( | ||
| f"parameter=<{first_param.name}>, type=<{type_hint}> | type hint must be a subclass of BaseHookEvent" | ||
| ) | ||
|
|
||
| def add_hook(self, hook: HookProvider) -> None: | ||
| """Register all callbacks from a hook provider. | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.