feat(agent): add configurable retry_strategy for model calls #1424
Open: zastrowm wants to merge 15 commits into strands-agents:main from zastrowm:configure_agent_retry
Commits:
fdac712 feat(agent): add configurable retry_strategy for model calls (strands-agent)
ef83070 Yield backwards compatible events + update initial delay to be correct (zastrowm)
9caab19 Always emit a ForceStopEvent when an exception bubbles (zastrowm)
fb78e58 Condense the doc strings down (zastrowm)
bfc71ab Remove NoopRetryStrategy (zastrowm)
4162947 Merges upstream main into configure_agent_retry branch (zastrowm)
d154c66 Adds strict type validation for Agent retry strategy (zastrowm)
6ea2fe2 fix: move imports to top (zastrowm)
1dc7cad Merge remote-tracking branch 'upstream/main' into configure_agent_retry (zastrowm)
14dc7e8 Tweaks after merge (zastrowm)
21d688f Move retry functionality into event_loop though still exposed at top … (zastrowm)
b122ec2 fix: linting error (zastrowm)
c71078c Rename retry file to include underscore (zastrowm)
083be7e fix: local imports (zastrowm)
71c5109 Remove one line (zastrowm)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,157 @@ | ||
| """Retry strategy implementations for handling model throttling and other retry scenarios. | ||
| | ||
| This module provides hook-based retry strategies that can be configured on the Agent | ||
| to control retry behavior for model invocations. Retry strategies implement the | ||
| HookProvider protocol and register callbacks for AfterModelCallEvent to determine | ||
| when and how to retry failed model calls. | ||
| """ | ||
| | ||
| import asyncio | ||
| import logging | ||
| from typing import Any | ||
| | ||
| from ..hooks.events import AfterInvocationEvent, AfterModelCallEvent | ||
| from ..hooks.registry import HookProvider, HookRegistry | ||
| from ..types._events import EventLoopThrottleEvent, TypedEvent | ||
| from ..types.exceptions import ModelThrottledException | ||
| | ||
| logger = logging.getLogger(__name__) | ||
| | ||
| | ||
| class ModelRetryStrategy(HookProvider): | ||
| """Default retry strategy for model throttling with exponential backoff. | ||
| | ||
| Retries model calls on ModelThrottledException using exponential backoff. | ||
| Delay doubles after each attempt: initial_delay, initial_delay*2, initial_delay*4, | ||
| etc., capped at max_delay. State resets after successful calls. | ||
| | ||
| With defaults (initial_delay=4, max_delay=240, max_attempts=6), delays are: | ||
| 4s → 8s → 16s → 32s → 64s (5 retries before giving up on the 6th attempt). | ||
| | ||
| Args: | ||
| max_attempts: Total model attempts before re-raising the exception. | ||
| initial_delay: Delay in seconds before the first retry; doubles on each subsequent retry. | ||
| max_delay: Upper bound in seconds for the exponential backoff. | ||
| """ | ||
| | ||
| def __init__( | ||
| self, | ||
| *, | ||
| max_attempts: int = 6, | ||
| initial_delay: int = 4, | ||
| max_delay: int = 240, | ||
| ): | ||
| """Initialize the retry strategy. | ||
| | ||
| Args: | ||
| max_attempts: Total model attempts before re-raising the exception. Defaults to 6. | ||
| initial_delay: Delay in seconds before the first retry; doubles on each subsequent retry. | ||
| Defaults to 4. | ||
| max_delay: Upper bound in seconds for the exponential backoff. Defaults to 240. | ||
| """ | ||
| self._max_attempts = max_attempts | ||
| self._initial_delay = initial_delay | ||
| self._max_delay = max_delay | ||
| self._current_attempt = 0 | ||
| self._backwards_compatible_event_to_yield: TypedEvent | None = None | ||
| | ||
| def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None: | ||
| """Register callbacks for AfterModelCallEvent and AfterInvocationEvent. | ||
| | ||
| Args: | ||
| registry: The hook registry to register callbacks with. | ||
| **kwargs: Additional keyword arguments for future extensibility. | ||
| """ | ||
| registry.add_callback(AfterModelCallEvent, self._handle_after_model_call) | ||
| registry.add_callback(AfterInvocationEvent, self._handle_after_invocation) | ||
| | ||
| def _calculate_delay(self, attempt: int) -> int: | ||
| """Calculate retry delay using exponential backoff. | ||
| | ||
| Args: | ||
| attempt: The attempt number (0-indexed) to calculate delay for. | ||
| | ||
| Returns: | ||
| Delay in seconds for the given attempt. | ||
| """ | ||
| delay: int = self._initial_delay * (2**attempt) | ||
| return min(delay, self._max_delay) | ||
| | ||
| def _reset_retry_state(self) -> None: | ||
| """Reset retry state to initial values.""" | ||
| self._current_attempt = 0 | ||
| | ||
| async def _handle_after_invocation(self, event: AfterInvocationEvent) -> None: | ||
| """Reset retry state after invocation completes. | ||
| | ||
| Args: | ||
| event: The AfterInvocationEvent signaling invocation completion. | ||
| """ | ||
| self._reset_retry_state() | ||
| | ||
| async def _handle_after_model_call(self, event: AfterModelCallEvent) -> None: | ||
| """Handle model call completion and determine if retry is needed. | ||
| | ||
| This callback is invoked after each model call. If the call failed with | ||
| a ModelThrottledException and we haven't exceeded max_attempts, it sets | ||
| event.retry to True and sleeps for the current delay before returning. | ||
| | ||
| On successful calls, it resets the retry state to prepare for future calls. | ||
| | ||
| Args: | ||
| event: The AfterModelCallEvent containing call results or exception. | ||
| """ | ||
| delay = self._calculate_delay(self._current_attempt) | ||
| | ||
| self._backwards_compatible_event_to_yield = None | ||
| | ||
| # If already retrying, skip processing (another hook may have triggered retry) | ||
| if event.retry: | ||
| return | ||
| | ||
| # If model call succeeded, reset retry state | ||
| if event.stop_response is not None: | ||
| logger.debug( | ||
| "stop_reason=<%s> | model call succeeded, resetting retry state", | ||
| event.stop_response.stop_reason, | ||
| ) | ||
| self._reset_retry_state() | ||
| return | ||
| | ||
| # Check if we have an exception and reset state if no exception | ||
| if event.exception is None: | ||
| self._reset_retry_state() | ||
| return | ||
| | ||
| # Only retry on ModelThrottledException | ||
| if not isinstance(event.exception, ModelThrottledException): | ||
| return | ||
| | ||
| # Increment attempt counter first | ||
| self._current_attempt += 1 | ||
| | ||
| # Check if we've exceeded max attempts | ||
| if self._current_attempt >= self._max_attempts: | ||
| logger.debug( | ||
| "current_attempt=<%d>, max_attempts=<%d> | max retry attempts reached, not retrying", | ||
| self._current_attempt, | ||
| self._max_attempts, | ||
| ) | ||
| return | ||
| | ||
| self._backwards_compatible_event_to_yield = EventLoopThrottleEvent(delay=delay) | ||
| | ||
| # Retry the model call | ||
| logger.debug( | ||
| "retry_delay_seconds=<%s>, max_attempts=<%s>, current_attempt=<%s> " | ||
| "| throttling exception encountered | delaying before next retry", | ||
| delay, | ||
| self._max_attempts, | ||
| self._current_attempt, | ||
| ) | ||
| | ||
| # Sleep for current delay | ||
| await asyncio.sleep(delay) | ||
| | ||
| # Set the retry flag so the model call is attempted again | ||
| event.retry = True |
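
The delay schedule described in the class docstring can be checked in isolation. The sketch below is not part of the PR: `backoff_delays` is a hypothetical helper that copies the ordering used in `_handle_after_model_call` (compute the delay from the current attempt counter, increment the counter, then stop once `max_attempts` is reached) and returns the delays that would actually be slept.

```python
def backoff_delays(
    initial_delay: int = 4,
    max_delay: int = 240,
    max_attempts: int = 6,
) -> list[int]:
    """Return the delays slept before each retry (hypothetical helper)."""
    delays: list[int] = []
    current_attempt = 0
    while True:
        # The delay is derived from the counter *before* it is incremented,
        # so the first retry waits exactly initial_delay seconds.
        delay = min(initial_delay * (2**current_attempt), max_delay)
        current_attempt += 1
        if current_attempt >= max_attempts:
            # Attempts exhausted: the handler returns without sleeping
            # and the exception is re-raised.
            break
        delays.append(delay)
    return delays


print(backoff_delays())              # [4, 8, 16, 32, 64]
print(backoff_delays(max_delay=20))  # [4, 8, 16, 20, 20]
```

With the defaults this yields five sleeps, matching the "5 retries before giving up on the 6th attempt" note in the docstring; lowering `max_delay` shows where the cap takes effect.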
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,22 @@ | ||
| """Fixtures for agent tests.""" | ||
| | ||
| import asyncio | ||
| from unittest.mock import AsyncMock | ||
| | ||
| import pytest | ||
| | ||
| | ||
| @pytest.fixture | ||
| def mock_sleep(monkeypatch): | ||
| """Mock asyncio.sleep to avoid delays in tests and track sleep calls.""" | ||
| sleep_calls = [] | ||
| | ||
| async def _mock_sleep(delay): | ||
| sleep_calls.append(delay) | ||
| | ||
| mock = AsyncMock(side_effect=_mock_sleep) | ||
| monkeypatch.setattr(asyncio, "sleep", mock) | ||
| | ||
| # Return both the mock and the sleep_calls list for verification | ||
| mock.sleep_calls = sleep_calls | ||
| return mock |
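
The same sleep-mocking pattern can be tried outside pytest. The following is a minimal sketch, not code from the PR: it uses `unittest.mock.patch.object` in place of the `monkeypatch` fixture, and `retry_with_backoff` and `flaky_call` are hypothetical stand-ins for the agent's retry loop and a throttled model call.

```python
import asyncio
from unittest.mock import AsyncMock, patch


async def retry_with_backoff(call, *, max_attempts=3, initial_delay=4):
    """Hypothetical helper: retry `call` on RuntimeError, doubling the delay."""
    for attempt in range(max_attempts):
        try:
            return await call()
        except RuntimeError:
            if attempt == max_attempts - 1:
                raise  # out of attempts, let the error propagate
            await asyncio.sleep(initial_delay * 2**attempt)


def run_demo():
    sleep_calls = []

    async def _mock_sleep(delay):
        # Record the requested delay instead of actually waiting.
        sleep_calls.append(delay)

    # Two simulated throttling failures followed by a success.
    outcomes = iter([RuntimeError("throttled"), RuntimeError("throttled"), "ok"])

    async def flaky_call():
        outcome = next(outcomes)
        if isinstance(outcome, Exception):
            raise outcome
        return outcome

    # Replace asyncio.sleep only for the duration of the run.
    with patch.object(asyncio, "sleep", AsyncMock(side_effect=_mock_sleep)) as mock:
        result = asyncio.run(retry_with_backoff(flaky_call))
    return result, sleep_calls, mock.await_count


print(run_demo())  # ('ok', [4, 8], 2)
```

Recording the delays in a plain list, as the fixture does with `mock.sleep_calls`, lets a test assert on the exact backoff sequence without waiting for real time to pass.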