2 changes: 2 additions & 0 deletions src/strands/__init__.py
@@ -3,6 +3,7 @@
from . import agent, models, telemetry, types
from .agent.agent import Agent
from .agent.base import AgentBase
from .event_loop._retry import ModelRetryStrategy
from .tools.decorator import tool
from .types.tools import ToolContext

@@ -11,6 +12,7 @@
"AgentBase",
"agent",
"models",
"ModelRetryStrategy",
"tool",
"ToolContext",
"types",
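
With the export above, the new class is importable directly from the package root. A minimal sketch; the defaults noted in the comment come from `_retry.py` further down in this diff:

```python
from strands import ModelRetryStrategy

# Defaults per _retry.py: max_attempts=6, initial_delay=4, max_delay=240 (seconds).
strategy = ModelRetryStrategy()
```
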
3 changes: 3 additions & 0 deletions src/strands/agent/__init__.py
@@ -4,8 +4,10 @@

- Agent: The main interface for interacting with AI models and tools
- ConversationManager: Classes for managing conversation history and context windows
- Retry Strategies: Configurable retry behavior for model calls
"""

from ..event_loop._retry import ModelRetryStrategy
from .agent import Agent
from .agent_result import AgentResult
from .base import AgentBase
@@ -24,4 +26,5 @@
"NullConversationManager",
"SlidingWindowConversationManager",
"SummarizingConversationManager",
"ModelRetryStrategy",
]
30 changes: 29 additions & 1 deletion src/strands/agent/agent.py
@@ -26,7 +26,8 @@

from .. import _identifier
from .._async import run_async
from ..event_loop.event_loop import event_loop_cycle
from ..event_loop._retry import ModelRetryStrategy
from ..event_loop.event_loop import INITIAL_DELAY, MAX_ATTEMPTS, MAX_DELAY, event_loop_cycle
from ..tools._tool_helpers import generate_missing_tool_result_content

if TYPE_CHECKING:
@@ -118,6 +119,7 @@ def __init__(
hooks: list[HookProvider] | None = None,
session_manager: SessionManager | None = None,
tool_executor: ToolExecutor | None = None,
retry_strategy: ModelRetryStrategy | None = None,
):
"""Initialize the Agent with the specified configuration.

@@ -167,6 +169,9 @@ def __init__(
session_manager: Manager for handling agent sessions including conversation history and state.
If provided, enables session-based persistence and state management.
tool_executor: Definition of tool execution strategy (e.g., sequential, concurrent, etc.).
retry_strategy: Strategy for retrying model calls on throttling or other transient errors.
Defaults to ModelRetryStrategy with max_attempts=6, initial_delay=4s, max_delay=240s.
Implement a custom HookProvider (registered via hooks) for custom retry logic; passing None falls back to the default strategy, and ModelRetryStrategy(max_attempts=1) disables retries.

Raises:
ValueError: If agent id contains path separators.
@@ -244,6 +249,17 @@ def __init__(
# separate event loops in different threads, so asyncio.Lock wouldn't work
self._invocation_lock = threading.Lock()

# In the future, we'll have a RetryStrategy base class but until
# that API is determined we only allow ModelRetryStrategy
if retry_strategy and type(retry_strategy) is not ModelRetryStrategy:
raise ValueError("retry_strategy must be an instance of ModelRetryStrategy")

self._retry_strategy = (
retry_strategy
if retry_strategy is not None
else ModelRetryStrategy(max_attempts=MAX_ATTEMPTS, max_delay=MAX_DELAY, initial_delay=INITIAL_DELAY)
)

# Initialize session management functionality
self._session_manager = session_manager
if self._session_manager:
@@ -252,6 +268,9 @@
# Allow conversation_managers to subscribe to hooks
self.hooks.add_hook(self.conversation_manager)

# Register retry strategy as a hook
self.hooks.add_hook(self._retry_strategy)

self.tool_executor = tool_executor or ConcurrentToolExecutor()

if hooks:
@@ -288,6 +307,15 @@ def system_prompt(self, value: str | list[SystemContentBlock] | None) -> None:
"""
self._system_prompt, self._system_prompt_content = self._initialize_system_prompt(value)

@property
def retry_strategy(self) -> HookProvider:
"""Get the retry strategy for this agent.

Returns:
The retry strategy hook provider.
"""
return self._retry_strategy

@property
def tool(self) -> _ToolCaller:
"""Call tool as a function.
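
Taken together, the constructor changes above behave roughly as in this sketch. The keyword values are illustrative, and an otherwise default `Agent()` configuration is assumed:

```python
from strands import Agent, ModelRetryStrategy

# Explicit strategy: fewer attempts and a shorter backoff than the defaults.
agent = Agent(retry_strategy=ModelRetryStrategy(max_attempts=3, initial_delay=2, max_delay=30))

# Omitting retry_strategy (or passing None) falls back to the default
# ModelRetryStrategy (max_attempts=6, initial_delay=4, max_delay=240).
default_agent = Agent()
assert isinstance(default_agent.retry_strategy, ModelRetryStrategy)

# Any other retry_strategy value raises ValueError until a RetryStrategy base
# class exists (see the in-code comment above); custom retry hooks are passed
# through the hooks parameter instead.
```
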
157 changes: 157 additions & 0 deletions src/strands/event_loop/_retry.py
@@ -0,0 +1,157 @@
"""Retry strategy implementations for handling model throttling and other retry scenarios.

This module provides hook-based retry strategies that can be configured on the Agent
to control retry behavior for model invocations. Retry strategies implement the
HookProvider protocol and register callbacks for AfterModelCallEvent to determine
when and how to retry failed model calls.
"""

import asyncio
import logging
from typing import Any

from ..hooks.events import AfterInvocationEvent, AfterModelCallEvent
from ..hooks.registry import HookProvider, HookRegistry
from ..types._events import EventLoopThrottleEvent, TypedEvent
from ..types.exceptions import ModelThrottledException

logger = logging.getLogger(__name__)


class ModelRetryStrategy(HookProvider):
"""Default retry strategy for model throttling with exponential backoff.

Retries model calls on ModelThrottledException using exponential backoff.
Delay doubles after each attempt: initial_delay, initial_delay*2, initial_delay*4,
etc., capped at max_delay. State resets after successful calls.

With defaults (initial_delay=4, max_delay=240, max_attempts=6), delays are:
4s → 8s → 16s → 32s → 64s (5 retries before giving up on the 6th attempt).

Args:
max_attempts: Total model attempts before re-raising the exception.
initial_delay: Base delay in seconds for the first retry; doubles after each subsequent retry.
max_delay: Upper bound in seconds for the exponential backoff.
"""

def __init__(
self,
*,
max_attempts: int = 6,
initial_delay: int = 4,
max_delay: int = 240,
):
"""Initialize the retry strategy.

Args:
max_attempts: Total model attempts before re-raising the exception. Defaults to 6.
initial_delay: Base delay in seconds for the first retry; doubles after each subsequent retry.
Defaults to 4.
max_delay: Upper bound in seconds for the exponential backoff. Defaults to 240.
"""
self._max_attempts = max_attempts
self._initial_delay = initial_delay
self._max_delay = max_delay
self._current_attempt = 0
self._backwards_compatible_event_to_yield: TypedEvent | None = None

def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None:
"""Register callbacks for AfterModelCallEvent and AfterInvocationEvent.

Args:
registry: The hook registry to register callbacks with.
**kwargs: Additional keyword arguments for future extensibility.
"""
registry.add_callback(AfterModelCallEvent, self._handle_after_model_call)
registry.add_callback(AfterInvocationEvent, self._handle_after_invocation)

def _calculate_delay(self, attempt: int) -> int:
"""Calculate retry delay using exponential backoff.

Args:
attempt: The attempt number (0-indexed) to calculate delay for.

Returns:
Delay in seconds for the given attempt.
"""
delay: int = self._initial_delay * (2**attempt)
return min(delay, self._max_delay)

def _reset_retry_state(self) -> None:
"""Reset retry state to initial values."""
self._current_attempt = 0

async def _handle_after_invocation(self, event: AfterInvocationEvent) -> None:
"""Reset retry state after invocation completes.

Args:
event: The AfterInvocationEvent signaling invocation completion.
"""
self._reset_retry_state()

async def _handle_after_model_call(self, event: AfterModelCallEvent) -> None:
"""Handle model call completion and determine if retry is needed.

This callback is invoked after each model call. If the call failed with
a ModelThrottledException and we haven't exceeded max_attempts, it sets
event.retry to True and sleeps for the current delay before returning.

On successful calls, it resets the retry state to prepare for future calls.

Args:
event: The AfterModelCallEvent containing call results or exception.
"""
delay = self._calculate_delay(self._current_attempt)

self._backwards_compatible_event_to_yield = None

# If already retrying, skip processing (another hook may have triggered retry)
if event.retry:
return

# If model call succeeded, reset retry state
if event.stop_response is not None:
logger.debug(
"stop_reason=<%s> | model call succeeded, resetting retry state",
event.stop_response.stop_reason,
)
self._reset_retry_state()
return

# No stop response and no exception: reset retry state and return
if event.exception is None:
self._reset_retry_state()
return

# Only retry on ModelThrottledException
if not isinstance(event.exception, ModelThrottledException):
return

# Increment attempt counter first
self._current_attempt += 1

# Check if we've exceeded max attempts
if self._current_attempt >= self._max_attempts:
logger.debug(
"current_attempt=<%d>, max_attempts=<%d> | max retry attempts reached, not retrying",
self._current_attempt,
self._max_attempts,
)
return

self._backwards_compatible_event_to_yield = EventLoopThrottleEvent(delay=delay)

# Retry the model call
logger.debug(
"retry_delay_seconds=<%s>, max_attempts=<%s>, current_attempt=<%s> "
"| throttling exception encountered | delaying before next retry",
delay,
self._max_attempts,
self._current_attempt,
)

# Sleep for current delay
await asyncio.sleep(delay)

# Set retry flag so the event loop re-invokes the model
event.retry = True
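
The schedule produced by `_calculate_delay` can be checked in isolation; a standalone sketch that simply reproduces the arithmetic above:

```python
def calculate_delay(attempt: int, initial_delay: int = 4, max_delay: int = 240) -> int:
    """Mirror of ModelRetryStrategy._calculate_delay: exponential backoff capped at max_delay."""
    return min(initial_delay * (2**attempt), max_delay)

# Delays for the 5 retries performed under the default max_attempts=6.
print([calculate_delay(attempt) for attempt in range(5)])  # [4, 8, 16, 32, 64]

# The cap applies once 4 * 2**attempt exceeds 240.
print(calculate_delay(6))  # 240
```
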
48 changes: 19 additions & 29 deletions src/strands/event_loop/event_loop.py
@@ -8,7 +8,6 @@
4. Manage recursive execution cycles
"""

import asyncio
import logging
import uuid
from collections.abc import AsyncGenerator
@@ -23,7 +22,6 @@
from ..tools.structured_output._structured_output_context import StructuredOutputContext
from ..types._events import (
EventLoopStopEvent,
EventLoopThrottleEvent,
ForceStopEvent,
ModelMessageEvent,
ModelStopReason,
@@ -39,12 +37,12 @@
ContextWindowOverflowException,
EventLoopException,
MaxTokensReachedException,
ModelThrottledException,
StructuredOutputException,
)
from ..types.streaming import StopReason
from ..types.tools import ToolResult, ToolUse
from ._recover_message_on_max_tokens_reached import recover_message_on_max_tokens_reached
from ._retry import ModelRetryStrategy
from .streaming import stream_messages

if TYPE_CHECKING:
@@ -316,9 +314,9 @@ async def _handle_model_execution(
stream_trace = Trace("stream_messages", parent_id=cycle_trace.id)
cycle_trace.add_child(stream_trace)

# Retry loop for handling throttling exceptions
current_delay = INITIAL_DELAY
for attempt in range(MAX_ATTEMPTS):
# Retry loop - actual retry logic is handled by retry_strategy hook
# Hooks control when to stop retrying via the event.retry flag
while True:
model_id = agent.model.config.get("model_id") if hasattr(agent.model, "config") else None
model_invoke_span = tracer.start_model_invoke_span(
messages=agent.messages,
@@ -366,9 +364,8 @@
# Check if hooks want to retry the model call
if after_model_call_event.retry:
logger.debug(
"stop_reason=<%s>, retry_requested=<True>, attempt=<%d> | hook requested model retry",
"stop_reason=<%s>, retry_requested=<True> | hook requested model retry",
stop_reason,
attempt + 1,
)
continue # Retry the model call

@@ -389,34 +386,27 @@
)
await agent.hooks.invoke_callbacks_async(after_model_call_event)

# Emit backwards-compatible events if retry strategy supports it
# (prior to making the retry strategy configurable, this is what we emitted)

if (
isinstance(agent.retry_strategy, ModelRetryStrategy)
and agent.retry_strategy._backwards_compatible_event_to_yield
):
yield agent.retry_strategy._backwards_compatible_event_to_yield

# Check if hooks want to retry the model call
if after_model_call_event.retry:
logger.debug(
"exception=<%s>, retry_requested=<True>, attempt=<%d> | hook requested model retry",
"exception=<%s>, retry_requested=<True> | hook requested model retry",
type(e).__name__,
attempt + 1,
)
continue # Retry the model call

if isinstance(e, ModelThrottledException):
if attempt + 1 == MAX_ATTEMPTS:
yield ForceStopEvent(reason=e)
raise e

logger.debug(
"retry_delay_seconds=<%s>, max_attempts=<%s>, current_attempt=<%s> "
"| throttling exception encountered "
"| delaying before next retry",
current_delay,
MAX_ATTEMPTS,
attempt + 1,
)
await asyncio.sleep(current_delay)
current_delay = min(current_delay * 2, MAX_DELAY)
continue # Retry the model call

yield EventLoopThrottleEvent(delay=current_delay)
else:
raise e
# No retry requested, raise the exception
yield ForceStopEvent(reason=e)
raise e

try:
# Add message in trace and mark the end of the stream messages trace
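
Since the loop now retries whenever any hook sets `event.retry`, a custom policy can live in an ordinary HookProvider. A minimal sketch; the "retry once on any exception" policy and class name are invented for illustration, while the import paths and event attributes are the ones used in `_retry.py` above:

```python
from typing import Any

from strands.hooks.events import AfterModelCallEvent
from strands.hooks.registry import HookProvider, HookRegistry


class RetryOnceOnAnyError(HookProvider):
    """Illustrative hook: retry a failed model call at most once, whatever the exception."""

    def __init__(self) -> None:
        self._retried = False

    def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None:
        registry.add_callback(AfterModelCallEvent, self._on_after_model_call)

    async def _on_after_model_call(self, event: AfterModelCallEvent) -> None:
        # Skip if another hook already requested a retry, or if the call succeeded.
        if event.retry or event.exception is None:
            return
        if not self._retried:
            self._retried = True
            event.retry = True  # picked up by the `while True` loop above
```

Such a provider would be passed through the Agent's hooks list; the retry_strategy parameter itself currently accepts only ModelRetryStrategy.
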
22 changes: 22 additions & 0 deletions tests/strands/agent/conftest.py
@@ -0,0 +1,22 @@
"""Fixtures for agent tests."""

import asyncio
from unittest.mock import AsyncMock

import pytest


@pytest.fixture
def mock_sleep(monkeypatch):
"""Mock asyncio.sleep to avoid delays in tests and track sleep calls."""
sleep_calls = []

async def _mock_sleep(delay):
sleep_calls.append(delay)

mock = AsyncMock(side_effect=_mock_sleep)
monkeypatch.setattr(asyncio, "sleep", mock)

# Return both the mock and the sleep_calls list for verification
mock.sleep_calls = sleep_calls
return mock
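
A sketch of how a test could consume the fixture; an async-capable test runner such as pytest-asyncio is assumed, and the awaited delays are just examples:

```python
import asyncio


async def test_sleep_calls_are_recorded(mock_sleep):
    # With the fixture active, asyncio.sleep returns immediately and records its argument.
    await asyncio.sleep(4)
    await asyncio.sleep(8)

    assert mock_sleep.sleep_calls == [4, 8]
    assert mock_sleep.await_count == 2
```
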