Skip to content
4 changes: 4 additions & 0 deletions src/strands/agent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

- Agent: The main interface for interacting with AI models and tools
- ConversationManager: Classes for managing conversation history and context windows
- Retry Strategies: Configurable retry behavior for model calls
"""

from .agent import Agent
Expand All @@ -14,6 +15,7 @@
SlidingWindowConversationManager,
SummarizingConversationManager,
)
from .retry import ModelRetryStrategy, NoopRetryStrategy

__all__ = [
"Agent",
Expand All @@ -22,4 +24,6 @@
"NullConversationManager",
"SlidingWindowConversationManager",
"SummarizingConversationManager",
"ModelRetryStrategy",
"NoopRetryStrategy",
]
21 changes: 21 additions & 0 deletions src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ def __init__(
hooks: Optional[list[HookProvider]] = None,
session_manager: Optional[SessionManager] = None,
tool_executor: Optional[ToolExecutor] = None,
retry_strategy: Optional[HookProvider] = None,
):
"""Initialize the Agent with the specified configuration.

Expand Down Expand Up @@ -173,6 +174,9 @@ def __init__(
session_manager: Manager for handling agent sessions including conversation history and state.
If provided, enables session-based persistence and state management.
tool_executor: Definition of tool execution strategy (e.g., sequential, concurrent, etc.).
retry_strategy: Strategy for retrying model calls on throttling or other transient errors.
Defaults to ModelRetryStrategy with max_attempts=6, initial_delay=4s, max_delay=240s.
Pass NoopRetryStrategy to disable retries, or implement a custom HookProvider for custom retry logic.

Raises:
ValueError: If agent id contains path separators.
Expand Down Expand Up @@ -245,6 +249,11 @@ def __init__(

self._interrupt_state = _InterruptState()

# Initialize retry strategy
from .retry import ModelRetryStrategy

self._retry_strategy = retry_strategy if retry_strategy is not None else ModelRetryStrategy()

# Initialize session management functionality
self._session_manager = session_manager
if self._session_manager:
Expand All @@ -253,6 +262,9 @@ def __init__(
# Allow conversation_managers to subscribe to hooks
self.hooks.add_hook(self.conversation_manager)

# Register retry strategy as a hook
self.hooks.add_hook(self._retry_strategy)

self.tool_executor = tool_executor or ConcurrentToolExecutor()

if hooks:
Expand Down Expand Up @@ -289,6 +301,15 @@ def system_prompt(self, value: str | list[SystemContentBlock] | None) -> None:
"""
self._system_prompt, self._system_prompt_content = self._initialize_system_prompt(value)

@property
def retry_strategy(self) -> HookProvider:
"""Get the retry strategy for this agent.

Returns:
The retry strategy hook provider.
"""
return self._retry_strategy

@property
def tool(self) -> _ToolCaller:
"""Call tool as a function.
Expand Down
212 changes: 212 additions & 0 deletions src/strands/agent/retry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
"""Retry strategy implementations for handling model throttling and other retry scenarios.

This module provides hook-based retry strategies that can be configured on the Agent
to control retry behavior for model invocations. Retry strategies implement the
HookProvider protocol and register callbacks for AfterModelCallEvent to determine
when and how to retry failed model calls.
"""

import asyncio
import logging
from typing import Any

from ..types.exceptions import ModelThrottledException
from ..hooks.events import AfterInvocationEvent, AfterModelCallEvent
from ..hooks.registry import HookProvider, HookRegistry

logger = logging.getLogger(__name__)


class ModelRetryStrategy(HookProvider):
"""Default retry strategy for model throttling with exponential backoff.

This strategy implements automatic retry logic for model throttling exceptions,
using exponential backoff to handle rate limiting gracefully. It retries
model calls when ModelThrottledException is raised, up to a configurable
maximum number of attempts.

The delay between retries starts at initial_delay and doubles after each
retry, up to a maximum of max_delay. The strategy automatically resets
its state after a successful model call.

Example:
```python
from strands import Agent
from strands.hooks import ModelRetryStrategy

# Use custom retry parameters
retry_strategy = ModelRetryStrategy(
max_attempts=3,
initial_delay=2,
max_delay=60
)
agent = Agent(retry_strategy=retry_strategy)
```

Attributes:
max_attempts: Maximum number of retry attempts before giving up.
initial_delay: Initial delay in seconds before the first retry.
max_delay: Maximum delay in seconds between retries.
current_attempt: Current retry attempt counter (resets on success).
current_delay: Current delay value for exponential backoff.
"""

def __init__(
self,
max_attempts: int = 6,
initial_delay: int = 4,
max_delay: int = 240,
):
"""Initialize the retry strategy with the specified parameters.

Args:
max_attempts: Maximum number of retry attempts. Defaults to 6.
initial_delay: Initial delay in seconds before retrying. Defaults to 4.
max_delay: Maximum delay in seconds between retries. Defaults to 240 (4 minutes).
"""
self._max_attempts = max_attempts
self._initial_delay = initial_delay
self._max_delay = max_delay
self._current_attempt = 0
self._did_trigger_retry = False

def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None:
"""Register callbacks for AfterModelCallEvent and AfterInvocationEvent.

Args:
registry: The hook registry to register callbacks with.
**kwargs: Additional keyword arguments for future extensibility.
"""
registry.add_callback(AfterModelCallEvent, self._handle_after_model_call)
registry.add_callback(AfterInvocationEvent, self._handle_after_invocation)

def _calculate_delay(self) -> float:
"""Calculate the current retry delay based on attempt number.

Uses exponential backoff: initial_delay * (2 ** attempt), capped at max_delay.

Returns:
The delay in seconds for the current attempt.
"""
if self._current_attempt == 0:
return self._initial_delay
delay = self._initial_delay * (2 ** (self._current_attempt - 1))
return min(delay, self._max_delay)

@property
def _current_delay(self) -> float:
"""Get the current retry delay (for backwards compatibility with EventLoopThrottleEvent).

This property is private and only exists for backwards compatibility with EventLoopThrottleEvent.
External code should not access this property.
"""
return self._calculate_delay()

def _reset_retry_state(self) -> None:
"""Reset retry state to initial values."""
self._current_attempt = 0
self._did_trigger_retry = False

async def _handle_after_invocation(self, event: AfterInvocationEvent) -> None:
"""Reset retry state after invocation completes.

Args:
event: The AfterInvocationEvent signaling invocation completion.
"""
self._reset_retry_state()

async def _handle_after_model_call(self, event: AfterModelCallEvent) -> None:
"""Handle model call completion and determine if retry is needed.

This callback is invoked after each model call. If the call failed with
a ModelThrottledException and we haven't exceeded max_attempts, it sets
event.retry to True and sleeps for the current delay before returning.

On successful calls, it resets the retry state to prepare for future calls.

Args:
event: The AfterModelCallEvent containing call results or exception.
"""
# If already retrying, skip processing (another hook may have triggered retry)
if event.retry:
return

# If model call succeeded, reset retry state
if event.stop_response is not None:
logger.debug(
"stop_reason=<%s> | model call succeeded, resetting retry state",
event.stop_response.stop_reason,
)
self._reset_retry_state()
return

# Check if we have an exception and reset state if no exception
if event.exception is None:
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this case, we should do the same "resetting" as _handle_after_invocation; in fact we shoul dhave a common method for that

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Extracted _reset_retry_state() method for common reset logic.

self._reset_retry_state()
return

# Only retry on ModelThrottledException
if not isinstance(event.exception, ModelThrottledException):
return

# Increment attempt counter first
self._current_attempt += 1

# Check if we've exceeded max attempts
if self._current_attempt >= self._max_attempts:
logger.debug(
"current_attempt=<%d>, max_attempts=<%d> | max retry attempts reached, not retrying",
self._current_attempt,
self._max_attempts,
)
self._did_trigger_retry = False
return

# Calculate delay for this attempt
delay = self._calculate_delay()

# Retry the model call
logger.debug(
"retry_delay_seconds=<%s>, max_attempts=<%s>, current_attempt=<%s> "
"| throttling exception encountered | delaying before next retry",
delay,
self._max_attempts,
self._current_attempt,
)

# Sleep for current delay
await asyncio.sleep(delay)

# Set retry flag and track that this strategy triggered it
event.retry = True
self._did_trigger_retry = True


class NoopRetryStrategy(HookProvider):
"""No-op retry strategy that disables automatic retries.

This strategy can be used when you want to explicitly disable retry behavior
and handle errors directly in your application code. It implements the
HookProvider protocol but does not register any callbacks.

Example:
```python
from strands import Agent
from strands.hooks import NoopRetryStrategy

# Disable automatic retries
agent = Agent(retry_strategy=NoopRetryStrategy())
```
"""

def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None:
"""Register hooks (no-op implementation).

This method intentionally does nothing, as this strategy disables retries.

Args:
registry: The hook registry to register callbacks with.
**kwargs: Additional keyword arguments for future extensibility.
"""
# Intentionally empty - no callbacks to register
pass
47 changes: 17 additions & 30 deletions src/strands/event_loop/event_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,6 @@

logger = logging.getLogger(__name__)

MAX_ATTEMPTS = 6
INITIAL_DELAY = 4
MAX_DELAY = 240 # 4 minutes


def _has_tool_use_in_latest_message(messages: "Messages") -> bool:
"""Check if the latest message contains any ToolUse content blocks.
Expand Down Expand Up @@ -315,9 +311,9 @@ async def _handle_model_execution(
stream_trace = Trace("stream_messages", parent_id=cycle_trace.id)
cycle_trace.add_child(stream_trace)

# Retry loop for handling throttling exceptions
current_delay = INITIAL_DELAY
for attempt in range(MAX_ATTEMPTS):
# Retry loop - actual retry logic is handled by retry_strategy hook
# Hooks control when to stop retrying via the event.retry flag
while True:
model_id = agent.model.config.get("model_id") if hasattr(agent.model, "config") else None
model_invoke_span = tracer.start_model_invoke_span(
messages=agent.messages,
Expand Down Expand Up @@ -364,10 +360,14 @@ async def _handle_model_execution(
# Check if hooks want to retry the model call
if after_model_call_event.retry:
logger.debug(
"stop_reason=<%s>, retry_requested=<True>, attempt=<%d> | hook requested model retry",
"stop_reason=<%s>, retry_requested=<True> | hook requested model retry",
stop_reason,
attempt + 1,
)
# Emit EventLoopThrottleEvent for backwards compatibility if ModelRetryStrategy triggered retry
from ..agent.retry import ModelRetryStrategy

if isinstance(agent.retry_strategy, ModelRetryStrategy) and agent.retry_strategy._did_trigger_retry:
yield EventLoopThrottleEvent(delay=agent.retry_strategy._current_delay)
continue # Retry the model call

if stop_reason == "max_tokens":
Expand All @@ -390,31 +390,18 @@ async def _handle_model_execution(
# Check if hooks want to retry the model call
if after_model_call_event.retry:
logger.debug(
"exception=<%s>, retry_requested=<True>, attempt=<%d> | hook requested model retry",
"exception=<%s>, retry_requested=<True> | hook requested model retry",
type(e).__name__,
attempt + 1,
)
continue # Retry the model call

if isinstance(e, ModelThrottledException):
if attempt + 1 == MAX_ATTEMPTS:
yield ForceStopEvent(reason=e)
raise e
# Emit EventLoopThrottleEvent for backwards compatibility if ModelRetryStrategy triggered retry
from ..agent.retry import ModelRetryStrategy

logger.debug(
"retry_delay_seconds=<%s>, max_attempts=<%s>, current_attempt=<%s> "
"| throttling exception encountered "
"| delaying before next retry",
current_delay,
MAX_ATTEMPTS,
attempt + 1,
)
await asyncio.sleep(current_delay)
current_delay = min(current_delay * 2, MAX_DELAY)
if isinstance(agent.retry_strategy, ModelRetryStrategy) and agent.retry_strategy._did_trigger_retry:
yield EventLoopThrottleEvent(delay=agent.retry_strategy._current_delay)
continue # Retry the model call

yield EventLoopThrottleEvent(delay=current_delay)
else:
raise e
# No retry requested, raise the exception
raise e

try:
# Add message in trace and mark the end of the stream messages trace
Expand Down
22 changes: 22 additions & 0 deletions tests/strands/agent/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""Fixtures for agent tests."""

import asyncio
from unittest.mock import AsyncMock

import pytest


@pytest.fixture
def mock_sleep(monkeypatch):
"""Mock asyncio.sleep to avoid delays in tests and track sleep calls."""
sleep_calls = []

async def _mock_sleep(delay):
sleep_calls.append(delay)

mock = AsyncMock(side_effect=_mock_sleep)
monkeypatch.setattr(asyncio, "sleep", mock)

# Return both the mock and the sleep_calls list for verification
mock.sleep_calls = sleep_calls
return mock
Loading