-
Notifications
You must be signed in to change notification settings - Fork 0
feat(agent): add configurable retry_strategy for model calls #17
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
69a9162
cfec198
cc20be0
b5a2d3d
0c4b4f4
587449a
d777c79
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,212 @@ | ||
| """Retry strategy implementations for handling model throttling and other retry scenarios. | ||
|
|
||
| This module provides hook-based retry strategies that can be configured on the Agent | ||
| to control retry behavior for model invocations. Retry strategies implement the | ||
| HookProvider protocol and register callbacks for AfterModelCallEvent to determine | ||
| when and how to retry failed model calls. | ||
| """ | ||
|
|
||
| import asyncio | ||
| import logging | ||
| from typing import Any | ||
|
|
||
| from ..types.exceptions import ModelThrottledException | ||
| from ..hooks.events import AfterInvocationEvent, AfterModelCallEvent | ||
| from ..hooks.registry import HookProvider, HookRegistry | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
class ModelRetryStrategy(HookProvider):
    """Exponential-backoff retry strategy for model throttling.

    Registers hook callbacks that retry a model call whenever it fails with
    ``ModelThrottledException``. The first retry waits ``initial_delay``
    seconds and every subsequent retry doubles the wait, capped at
    ``max_delay``. Retrying stops once ``max_attempts`` attempts have been
    made. Backoff state is reset automatically after a successful model call
    and again at the end of each agent invocation.

    Example:
        ```python
        from strands import Agent
        from strands.hooks import ModelRetryStrategy

        # Use custom retry parameters
        retry_strategy = ModelRetryStrategy(
            max_attempts=3,
            initial_delay=2,
            max_delay=60
        )
        agent = Agent(retry_strategy=retry_strategy)
        ```
    """

    def __init__(
        self,
        max_attempts: int = 6,
        initial_delay: int = 4,
        max_delay: int = 240,
    ):
        """Initialize the retry strategy.

        Args:
            max_attempts: Maximum number of retry attempts. Defaults to 6.
            initial_delay: Initial delay in seconds before retrying. Defaults to 4.
            max_delay: Maximum delay in seconds between retries. Defaults to 240 (4 minutes).
        """
        self._max_attempts = max_attempts
        self._initial_delay = initial_delay
        self._max_delay = max_delay
        # Number of throttled attempts seen so far; reset on success.
        self._current_attempt = 0
        # True while a retry triggered by this strategy is in flight.
        self._did_trigger_retry = False

    def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None:
        """Register callbacks for AfterModelCallEvent and AfterInvocationEvent.

        Args:
            registry: The hook registry to register callbacks with.
            **kwargs: Additional keyword arguments for future extensibility.
        """
        registry.add_callback(AfterModelCallEvent, self._handle_after_model_call)
        registry.add_callback(AfterInvocationEvent, self._handle_after_invocation)

    def _calculate_delay(self) -> float:
        """Return the backoff delay for the current attempt.

        Attempt 0 waits ``initial_delay`` seconds; attempt ``n`` (n >= 1)
        waits ``initial_delay * 2 ** (n - 1)``, never exceeding ``max_delay``.

        Returns:
            The delay in seconds for the current attempt.
        """
        attempt = self._current_attempt
        if attempt == 0:
            return self._initial_delay
        return min(self._initial_delay * 2 ** (attempt - 1), self._max_delay)

    @property
    def _current_delay(self) -> float:
        """Current retry delay (backwards compatibility with EventLoopThrottleEvent).

        Private; exists only so EventLoopThrottleEvent can read the delay.
        External code should not access this property.
        """
        return self._calculate_delay()

    def _reset_retry_state(self) -> None:
        """Reset the attempt counter and retry flag to their initial values."""
        self._current_attempt = 0
        self._did_trigger_retry = False

    async def _handle_after_invocation(self, event: AfterInvocationEvent) -> None:
        """Reset retry state after invocation completes.

        Args:
            event: The AfterInvocationEvent signaling invocation completion.
        """
        self._reset_retry_state()

    async def _handle_after_model_call(self, event: AfterModelCallEvent) -> None:
        """Decide whether a completed model call should be retried.

        Invoked after each model call. When the call failed with a
        ``ModelThrottledException`` and fewer than ``max_attempts`` attempts
        have been made, sleeps for the backoff delay and sets ``event.retry``.
        Successful calls reset the retry state.

        Args:
            event: The AfterModelCallEvent containing call results or exception.
        """
        if event.retry:
            # Another hook already scheduled a retry; nothing to do here.
            return

        if event.stop_response is not None:
            # Successful call: drop any accumulated backoff state.
            logger.debug(
                "stop_reason=<%s> | model call succeeded, resetting retry state",
                event.stop_response.stop_reason,
            )
            self._reset_retry_state()
            return

        if event.exception is None:
            # No response and no exception: treat like success for
            # retry-state purposes.
            self._reset_retry_state()
            return

        if not isinstance(event.exception, ModelThrottledException):
            # Only throttling errors are retried; everything else propagates.
            return

        # Count this failed attempt before checking the limit.
        self._current_attempt += 1

        if self._current_attempt >= self._max_attempts:
            logger.debug(
                "current_attempt=<%d>, max_attempts=<%d> | max retry attempts reached, not retrying",
                self._current_attempt,
                self._max_attempts,
            )
            self._did_trigger_retry = False
            return

        wait_seconds = self._calculate_delay()
        logger.debug(
            "retry_delay_seconds=<%s>, max_attempts=<%s>, current_attempt=<%s> "
            "| throttling exception encountered | delaying before next retry",
            wait_seconds,
            self._max_attempts,
            self._current_attempt,
        )

        # Back off before asking the event loop to re-invoke the model.
        await asyncio.sleep(wait_seconds)

        # Signal the retry and remember that this strategy triggered it.
        event.retry = True
        self._did_trigger_retry = True
|
|
||
|
|
||
class NoopRetryStrategy(HookProvider):
    """Retry strategy that performs no retries at all.

    Use this when automatic retry behavior should be explicitly disabled so
    that errors can be handled directly by application code. It satisfies the
    HookProvider protocol but registers no callbacks.

    Example:
        ```python
        from strands import Agent
        from strands.hooks import NoopRetryStrategy

        # Disable automatic retries
        agent = Agent(retry_strategy=NoopRetryStrategy())
        ```
    """

    def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None:
        """Register hooks (no-op implementation).

        Deliberately does nothing: this strategy disables retries by
        registering no callbacks.

        Args:
            registry: The hook registry to register callbacks with.
            **kwargs: Additional keyword arguments for future extensibility.
        """
        # No callbacks to register - retries are intentionally disabled.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,22 @@ | ||
| """Fixtures for agent tests.""" | ||
|
|
||
| import asyncio | ||
| from unittest.mock import AsyncMock | ||
|
|
||
| import pytest | ||
|
|
||
|
|
||
@pytest.fixture
def mock_sleep(monkeypatch):
    """Patch asyncio.sleep with an AsyncMock that records requested delays."""
    recorded = []

    async def _record(delay):
        recorded.append(delay)

    patched = AsyncMock(side_effect=_record)
    monkeypatch.setattr(asyncio, "sleep", patched)

    # Expose the recorded delays so tests can verify backoff behavior.
    patched.sleep_calls = recorded
    return patched
Uh oh!
There was an error while loading. Please reload this page.