diff --git a/src/strands/__init__.py b/src/strands/__init__.py index bc17497a0..6026d4240 100644 --- a/src/strands/__init__.py +++ b/src/strands/__init__.py @@ -3,6 +3,7 @@ from . import agent, models, telemetry, types from .agent.agent import Agent from .agent.base import AgentBase +from .event_loop._retry import ModelRetryStrategy from .tools.decorator import tool from .types.tools import ToolContext @@ -11,6 +12,7 @@ "AgentBase", "agent", "models", + "ModelRetryStrategy", "tool", "ToolContext", "types", diff --git a/src/strands/agent/__init__.py b/src/strands/agent/__init__.py index c00623dc2..2e40866a9 100644 --- a/src/strands/agent/__init__.py +++ b/src/strands/agent/__init__.py @@ -4,8 +4,10 @@ - Agent: The main interface for interacting with AI models and tools - ConversationManager: Classes for managing conversation history and context windows +- Retry Strategies: Configurable retry behavior for model calls """ +from ..event_loop._retry import ModelRetryStrategy from .agent import Agent from .agent_result import AgentResult from .base import AgentBase @@ -24,4 +26,5 @@ "NullConversationManager", "SlidingWindowConversationManager", "SummarizingConversationManager", + "ModelRetryStrategy", ] diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index b58b55f24..dbab30b64 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -26,7 +26,8 @@ from .. import _identifier from .._async import run_async -from ..event_loop.event_loop import event_loop_cycle +from ..event_loop._retry import ModelRetryStrategy +from ..event_loop.event_loop import INITIAL_DELAY, MAX_ATTEMPTS, MAX_DELAY, event_loop_cycle from ..tools._tool_helpers import generate_missing_tool_result_content if TYPE_CHECKING: @@ -118,6 +119,7 @@ def __init__( hooks: list[HookProvider] | None = None, session_manager: SessionManager | None = None, tool_executor: ToolExecutor | None = None, + retry_strategy: ModelRetryStrategy | None = None, ): """Initialize the Agent with the specified configuration. @@ -167,6 +169,9 @@ def __init__( session_manager: Manager for handling agent sessions including conversation history and state. If provided, enables session-based persistence and state management. tool_executor: Definition of tool execution strategy (e.g., sequential, concurrent, etc.). + retry_strategy: Strategy for retrying model calls on throttling or other transient errors. + Defaults to ModelRetryStrategy with max_attempts=6, initial_delay=4s, max_delay=240s. + Implement a custom HookProvider for custom retry logic, or pass None to disable retries. Raises: ValueError: If agent id contains path separators. @@ -244,6 +249,17 @@ def __init__( # separate event loops in different threads, so asyncio.Lock wouldn't work self._invocation_lock = threading.Lock() + # In the future, we'll have a RetryStrategy base class but until + # that API is determined we only allow ModelRetryStrategy + if retry_strategy and type(retry_strategy) is not ModelRetryStrategy: + raise ValueError("retry_strategy must be an instance of ModelRetryStrategy") + + self._retry_strategy = ( + retry_strategy + if retry_strategy is not None + else ModelRetryStrategy(max_attempts=MAX_ATTEMPTS, max_delay=MAX_DELAY, initial_delay=INITIAL_DELAY) + ) + # Initialize session management functionality self._session_manager = session_manager if self._session_manager: @@ -252,6 +268,9 @@ def __init__( # Allow conversation_managers to subscribe to hooks self.hooks.add_hook(self.conversation_manager) + # Register retry strategy as a hook + self.hooks.add_hook(self._retry_strategy) + self.tool_executor = tool_executor or ConcurrentToolExecutor() if hooks: @@ -288,6 +307,15 @@ def system_prompt(self, value: str | list[SystemContentBlock] | None) -> None: """ self._system_prompt, self._system_prompt_content = self._initialize_system_prompt(value) + @property + def retry_strategy(self) -> HookProvider: + """Get the retry strategy for this agent. + + Returns: + The retry strategy hook provider. + """ + return self._retry_strategy + @property def tool(self) -> _ToolCaller: """Call tool as a function. diff --git a/src/strands/event_loop/_retry.py b/src/strands/event_loop/_retry.py new file mode 100644 index 000000000..04a6101b8 --- /dev/null +++ b/src/strands/event_loop/_retry.py @@ -0,0 +1,157 @@ +"""Retry strategy implementations for handling model throttling and other retry scenarios. + +This module provides hook-based retry strategies that can be configured on the Agent +to control retry behavior for model invocations. Retry strategies implement the +HookProvider protocol and register callbacks for AfterModelCallEvent to determine +when and how to retry failed model calls. +""" + +import asyncio +import logging +from typing import Any + +from ..hooks.events import AfterInvocationEvent, AfterModelCallEvent +from ..hooks.registry import HookProvider, HookRegistry +from ..types._events import EventLoopThrottleEvent, TypedEvent +from ..types.exceptions import ModelThrottledException + +logger = logging.getLogger(__name__) + + +class ModelRetryStrategy(HookProvider): + """Default retry strategy for model throttling with exponential backoff. + + Retries model calls on ModelThrottledException using exponential backoff. + Delay doubles after each attempt: initial_delay, initial_delay*2, initial_delay*4, + etc., capped at max_delay. State resets after successful calls. + + With defaults (initial_delay=4, max_delay=240, max_attempts=6), delays are: + 4s → 8s → 16s → 32s → 64s (5 retries before giving up on the 6th attempt). + + Args: + max_attempts: Total model attempts before re-raising the exception. + initial_delay: Base delay in seconds; used for first two retries, then doubles. + max_delay: Upper bound in seconds for the exponential backoff. + """ + + def __init__( + self, + *, + max_attempts: int = 6, + initial_delay: int = 4, + max_delay: int = 240, + ): + """Initialize the retry strategy. + + Args: + max_attempts: Total model attempts before re-raising the exception. Defaults to 6. + initial_delay: Base delay in seconds; used for first two retries, then doubles. + Defaults to 4. + max_delay: Upper bound in seconds for the exponential backoff. Defaults to 240. + """ + self._max_attempts = max_attempts + self._initial_delay = initial_delay + self._max_delay = max_delay + self._current_attempt = 0 + self._backwards_compatible_event_to_yield: TypedEvent | None = None + + def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None: + """Register callbacks for AfterModelCallEvent and AfterInvocationEvent. + + Args: + registry: The hook registry to register callbacks with. + **kwargs: Additional keyword arguments for future extensibility. + """ + registry.add_callback(AfterModelCallEvent, self._handle_after_model_call) + registry.add_callback(AfterInvocationEvent, self._handle_after_invocation) + + def _calculate_delay(self, attempt: int) -> int: + """Calculate retry delay using exponential backoff. + + Args: + attempt: The attempt number (0-indexed) to calculate delay for. + + Returns: + Delay in seconds for the given attempt. + """ + delay: int = self._initial_delay * (2**attempt) + return min(delay, self._max_delay) + + def _reset_retry_state(self) -> None: + """Reset retry state to initial values.""" + self._current_attempt = 0 + + async def _handle_after_invocation(self, event: AfterInvocationEvent) -> None: + """Reset retry state after invocation completes. + + Args: + event: The AfterInvocationEvent signaling invocation completion. + """ + self._reset_retry_state() + + async def _handle_after_model_call(self, event: AfterModelCallEvent) -> None: + """Handle model call completion and determine if retry is needed. + + This callback is invoked after each model call. If the call failed with + a ModelThrottledException and we haven't exceeded max_attempts, it sets + event.retry to True and sleeps for the current delay before returning. + + On successful calls, it resets the retry state to prepare for future calls. + + Args: + event: The AfterModelCallEvent containing call results or exception. + """ + delay = self._calculate_delay(self._current_attempt) + + self._backwards_compatible_event_to_yield = None + + # If already retrying, skip processing (another hook may have triggered retry) + if event.retry: + return + + # If model call succeeded, reset retry state + if event.stop_response is not None: + logger.debug( + "stop_reason=<%s> | model call succeeded, resetting retry state", + event.stop_response.stop_reason, + ) + self._reset_retry_state() + return + + # Check if we have an exception and reset state if no exception + if event.exception is None: + self._reset_retry_state() + return + + # Only retry on ModelThrottledException + if not isinstance(event.exception, ModelThrottledException): + return + + # Increment attempt counter first + self._current_attempt += 1 + + # Check if we've exceeded max attempts + if self._current_attempt >= self._max_attempts: + logger.debug( + "current_attempt=<%d>, max_attempts=<%d> | max retry attempts reached, not retrying", + self._current_attempt, + self._max_attempts, + ) + return + + self._backwards_compatible_event_to_yield = EventLoopThrottleEvent(delay=delay) + + # Retry the model call + logger.debug( + "retry_delay_seconds=<%s>, max_attempts=<%s>, current_attempt=<%s> " + "| throttling exception encountered | delaying before next retry", + delay, + self._max_attempts, + self._current_attempt, + ) + + # Sleep for current delay + await asyncio.sleep(delay) + + # Set retry flag and track that this strategy triggered it + event.retry = True diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index 99c8f5179..5f1ad5b60 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -8,7 +8,6 @@ 4. Manage recursive execution cycles """ -import asyncio import logging import uuid from collections.abc import AsyncGenerator @@ -23,7 +22,6 @@ from ..tools.structured_output._structured_output_context import StructuredOutputContext from ..types._events import ( EventLoopStopEvent, - EventLoopThrottleEvent, ForceStopEvent, ModelMessageEvent, ModelStopReason, @@ -39,12 +37,12 @@ ContextWindowOverflowException, EventLoopException, MaxTokensReachedException, - ModelThrottledException, StructuredOutputException, ) from ..types.streaming import StopReason from ..types.tools import ToolResult, ToolUse from ._recover_message_on_max_tokens_reached import recover_message_on_max_tokens_reached +from ._retry import ModelRetryStrategy from .streaming import stream_messages if TYPE_CHECKING: @@ -316,9 +314,9 @@ async def _handle_model_execution( stream_trace = Trace("stream_messages", parent_id=cycle_trace.id) cycle_trace.add_child(stream_trace) - # Retry loop for handling throttling exceptions - current_delay = INITIAL_DELAY - for attempt in range(MAX_ATTEMPTS): + # Retry loop - actual retry logic is handled by retry_strategy hook + # Hooks control when to stop retrying via the event.retry flag + while True: model_id = agent.model.config.get("model_id") if hasattr(agent.model, "config") else None model_invoke_span = tracer.start_model_invoke_span( messages=agent.messages, @@ -366,9 +364,8 @@ async def _handle_model_execution( # Check if hooks want to retry the model call if after_model_call_event.retry: logger.debug( - "stop_reason=<%s>, retry_requested=, attempt=<%d> | hook requested model retry", + "stop_reason=<%s>, retry_requested= | hook requested model retry", stop_reason, - attempt + 1, ) continue # Retry the model call @@ -389,34 +386,27 @@ async def _handle_model_execution( ) await agent.hooks.invoke_callbacks_async(after_model_call_event) + # Emit backwards-compatible events if retry strategy supports it + # (prior to making the retry strategy configurable, this is what we emitted) + + if ( + isinstance(agent.retry_strategy, ModelRetryStrategy) + and agent.retry_strategy._backwards_compatible_event_to_yield + ): + yield agent.retry_strategy._backwards_compatible_event_to_yield + # Check if hooks want to retry the model call if after_model_call_event.retry: logger.debug( - "exception=<%s>, retry_requested=, attempt=<%d> | hook requested model retry", + "exception=<%s>, retry_requested= | hook requested model retry", type(e).__name__, - attempt + 1, ) - continue # Retry the model call - if isinstance(e, ModelThrottledException): - if attempt + 1 == MAX_ATTEMPTS: - yield ForceStopEvent(reason=e) - raise e - - logger.debug( - "retry_delay_seconds=<%s>, max_attempts=<%s>, current_attempt=<%s> " - "| throttling exception encountered " - "| delaying before next retry", - current_delay, - MAX_ATTEMPTS, - attempt + 1, - ) - await asyncio.sleep(current_delay) - current_delay = min(current_delay * 2, MAX_DELAY) + continue # Retry the model call - yield EventLoopThrottleEvent(delay=current_delay) - else: - raise e + # No retry requested, raise the exception + yield ForceStopEvent(reason=e) + raise e try: # Add message in trace and mark the end of the stream messages trace diff --git a/tests/strands/agent/conftest.py b/tests/strands/agent/conftest.py new file mode 100644 index 000000000..d3af90dc8 --- /dev/null +++ b/tests/strands/agent/conftest.py @@ -0,0 +1,22 @@ +"""Fixtures for agent tests.""" + +import asyncio +from unittest.mock import AsyncMock + +import pytest + + +@pytest.fixture +def mock_sleep(monkeypatch): + """Mock asyncio.sleep to avoid delays in tests and track sleep calls.""" + sleep_calls = [] + + async def _mock_sleep(delay): + sleep_calls.append(delay) + + mock = AsyncMock(side_effect=_mock_sleep) + monkeypatch.setattr(asyncio, "sleep", mock) + + # Return both the mock and the sleep_calls list for verification + mock.sleep_calls = sleep_calls + return mock diff --git a/tests/strands/agent/hooks/test_agent_events.py b/tests/strands/agent/hooks/test_agent_events.py index 7b189a5c6..f511c7019 100644 --- a/tests/strands/agent/hooks/test_agent_events.py +++ b/tests/strands/agent/hooks/test_agent_events.py @@ -1,6 +1,6 @@ import asyncio import unittest.mock -from unittest.mock import ANY, MagicMock, call +from unittest.mock import ANY, AsyncMock, MagicMock, call, patch import pytest from pydantic import BaseModel @@ -34,9 +34,7 @@ async def streaming_tool(): @pytest.fixture def mock_sleep(): - with unittest.mock.patch.object( - strands.event_loop.event_loop.asyncio, "sleep", new_callable=unittest.mock.AsyncMock - ) as mock: + with patch.object(strands.event_loop._retry.asyncio, "sleep", new_callable=AsyncMock) as mock: yield mock @@ -359,8 +357,8 @@ async def test_stream_e2e_throttle_and_redact(alist, mock_sleep): {"arg1": 1013, "init_event_loop": True}, {"start": True}, {"start_event_loop": True}, + {"event_loop_throttled_delay": 4, **throttle_props}, {"event_loop_throttled_delay": 8, **throttle_props}, - {"event_loop_throttled_delay": 16, **throttle_props}, {"event": {"messageStart": {"role": "assistant"}}}, {"event": {"redactContent": {"redactUserContentMessage": "BLOCKED!"}}}, {"event": {"contentBlockStart": {"start": {}}}}, @@ -508,11 +506,11 @@ async def test_event_loop_cycle_text_response_throttling_early_end( {"init_event_loop": True, "arg1": 1013}, {"start": True}, {"start_event_loop": True}, + {"event_loop_throttled_delay": 4, **common_props}, {"event_loop_throttled_delay": 8, **common_props}, {"event_loop_throttled_delay": 16, **common_props}, {"event_loop_throttled_delay": 32, **common_props}, {"event_loop_throttled_delay": 64, **common_props}, - {"event_loop_throttled_delay": 128, **common_props}, {"force_stop": True, "force_stop_reason": "ThrottlingException | ConverseStream"}, ] diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 81ce65989..df6f1f10d 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -8,7 +8,8 @@ import time import unittest.mock import warnings -from typing import Any, AsyncGenerator +from collections.abc import AsyncGenerator +from typing import Any from uuid import uuid4 import pytest diff --git a/tests/strands/agent/test_agent_hooks.py b/tests/strands/agent/test_agent_hooks.py index 00b9d368a..3946f5cd3 100644 --- a/tests/strands/agent/test_agent_hooks.py +++ b/tests/strands/agent/test_agent_hooks.py @@ -104,7 +104,7 @@ class User(BaseModel): @pytest.fixture def mock_sleep(): - with patch.object(strands.event_loop.event_loop.asyncio, "sleep", new_callable=AsyncMock) as mock: + with patch.object(strands.event_loop._retry.asyncio, "sleep", new_callable=AsyncMock) as mock: yield mock diff --git a/tests/strands/agent/test_agent_retry.py b/tests/strands/agent/test_agent_retry.py new file mode 100644 index 000000000..17f5a976b --- /dev/null +++ b/tests/strands/agent/test_agent_retry.py @@ -0,0 +1,162 @@ +"""Integration tests for Agent retry_strategy parameter.""" + +from unittest.mock import Mock + +import pytest + +from strands import Agent, ModelRetryStrategy +from strands.event_loop.event_loop import INITIAL_DELAY, MAX_ATTEMPTS, MAX_DELAY +from strands.hooks import AfterModelCallEvent +from strands.types.exceptions import ModelThrottledException +from tests.fixtures.mocked_model_provider import MockedModelProvider + +# Agent Retry Strategy Initialization Tests + + +def test_agent_with_default_retry_strategy(): + """Test that Agent uses ModelRetryStrategy by default when retry_strategy=None.""" + agent = Agent() + + # Should have a retry_strategy + assert hasattr(agent, "retry_strategy") + assert agent.retry_strategy is not None + + # Should be ModelRetryStrategy with default parameters + assert isinstance(agent.retry_strategy, ModelRetryStrategy) + assert agent.retry_strategy._max_attempts == 6 + assert agent.retry_strategy._initial_delay == 4 + assert agent.retry_strategy._max_delay == 240 + + +def test_agent_with_custom_model_retry_strategy(): + """Test Agent initialization with custom ModelRetryStrategy parameters.""" + custom_strategy = ModelRetryStrategy(max_attempts=3, initial_delay=2, max_delay=60) + agent = Agent(retry_strategy=custom_strategy) + + assert agent.retry_strategy is custom_strategy + assert agent.retry_strategy._max_attempts == 3 + assert agent.retry_strategy._initial_delay == 2 + assert agent.retry_strategy._max_delay == 60 + + +def test_agent_rejects_invalid_retry_strategy_type(): + """Test that Agent raises ValueError for non-ModelRetryStrategy retry_strategy.""" + + class FakeRetryStrategy: + pass + + with pytest.raises(ValueError, match="retry_strategy must be an instance of ModelRetryStrategy"): + Agent(retry_strategy=FakeRetryStrategy()) + + +def test_agent_rejects_subclass_of_model_retry_strategy(): + """Test that Agent rejects subclasses of ModelRetryStrategy (strict type check).""" + + class CustomRetryStrategy(ModelRetryStrategy): + pass + + with pytest.raises(ValueError, match="retry_strategy must be an instance of ModelRetryStrategy"): + Agent(retry_strategy=CustomRetryStrategy()) + + +def test_agent_default_retry_strategy_uses_event_loop_constants(): + """Test that default retry strategy uses constants from event_loop module.""" + agent = Agent() + + assert agent.retry_strategy._max_attempts == MAX_ATTEMPTS + assert agent.retry_strategy._initial_delay == INITIAL_DELAY + assert agent.retry_strategy._max_delay == MAX_DELAY + + +def test_retry_strategy_registered_as_hook(): + """Test that retry_strategy is registered with the hook system.""" + custom_strategy = ModelRetryStrategy(max_attempts=3) + agent = Agent(retry_strategy=custom_strategy) + + # Verify retry strategy callback is registered + callbacks = list(agent.hooks.get_callbacks_for(AfterModelCallEvent(agent=agent, exception=None))) + + # Should have at least one callback (from retry strategy) + assert len(callbacks) > 0 + + # Verify one of the callbacks is from the retry strategy + assert any( + callback.__self__ is custom_strategy if hasattr(callback, "__self__") else False for callback in callbacks + ) + + +# Agent Retry Behavior Tests + + +@pytest.mark.asyncio +async def test_agent_retries_with_default_strategy(mock_sleep): + """Test that Agent retries on throttling with default ModelRetryStrategy.""" + # Create a model that fails twice with throttling, then succeeds + model = Mock() + model.stream.side_effect = [ + ModelThrottledException("ThrottlingException"), + ModelThrottledException("ThrottlingException"), + MockedModelProvider([{"role": "assistant", "content": [{"text": "Success after retries"}]}]).stream([]), + ] + + agent = Agent(model=model) + + result = agent.stream_async("test prompt") + events = [event async for event in result] + + # Should have succeeded after retries - just check we got events + assert len(events) > 0 + + # Should have slept twice (for two retries) + assert len(mock_sleep.sleep_calls) == 2 + # First retry: 4 seconds + assert mock_sleep.sleep_calls[0] == 4 + # Second retry: 8 seconds (exponential backoff) + assert mock_sleep.sleep_calls[1] == 8 + + +@pytest.mark.asyncio +async def test_agent_respects_max_attempts(mock_sleep): + """Test that Agent respects max_attempts in retry strategy.""" + # Create a model that always fails + model = Mock() + model.stream.side_effect = ModelThrottledException("ThrottlingException") + + # Use custom strategy with max 2 attempts + custom_strategy = ModelRetryStrategy(max_attempts=2, initial_delay=1, max_delay=60) + agent = Agent(model=model, retry_strategy=custom_strategy) + + with pytest.raises(ModelThrottledException): + result = agent.stream_async("test prompt") + _ = [event async for event in result] + + # Should have attempted max_attempts times, which means (max_attempts - 1) sleeps + # Attempt 0: fail, sleep + # Attempt 1: fail, no more attempts + assert len(mock_sleep.sleep_calls) == 1 + + +# Backwards Compatibility Tests + + +@pytest.mark.asyncio +async def test_event_loop_throttle_event_emitted(mock_sleep): + """Test that EventLoopThrottleEvent is still emitted for backwards compatibility.""" + # Create a model that fails once with throttling, then succeeds + model = Mock() + model.stream.side_effect = [ + ModelThrottledException("ThrottlingException"), + MockedModelProvider([{"role": "assistant", "content": [{"text": "Success"}]}]).stream([]), + ] + + agent = Agent(model=model) + + result = agent.stream_async("test prompt") + events = [event async for event in result] + + # Should have EventLoopThrottleEvent in the stream + throttle_events = [e for e in events if "event_loop_throttled_delay" in e] + assert len(throttle_events) > 0 + + # Should have the correct delay value + assert throttle_events[0]["event_loop_throttled_delay"] > 0 diff --git a/tests/strands/agent/test_retry.py b/tests/strands/agent/test_retry.py new file mode 100644 index 000000000..830c1b5b8 --- /dev/null +++ b/tests/strands/agent/test_retry.py @@ -0,0 +1,328 @@ +"""Unit tests for retry strategy implementations.""" + +from unittest.mock import Mock + +import pytest + +from strands import ModelRetryStrategy +from strands.hooks import AfterInvocationEvent, AfterModelCallEvent, HookRegistry +from strands.types._events import EventLoopThrottleEvent +from strands.types.exceptions import ModelThrottledException + +# ModelRetryStrategy Tests + + +def test_model_retry_strategy_init_with_defaults(): + """Test ModelRetryStrategy initialization with default parameters.""" + strategy = ModelRetryStrategy() + assert strategy._max_attempts == 6 + assert strategy._initial_delay == 4 + assert strategy._max_delay == 240 + assert strategy._current_attempt == 0 + + +def test_model_retry_strategy_init_with_custom_parameters(): + """Test ModelRetryStrategy initialization with custom parameters.""" + strategy = ModelRetryStrategy(max_attempts=3, initial_delay=2, max_delay=60) + assert strategy._max_attempts == 3 + assert strategy._initial_delay == 2 + assert strategy._max_delay == 60 + assert strategy._current_attempt == 0 + + +def test_model_retry_strategy_calculate_delay_with_different_attempts(): + """Test _calculate_delay returns correct exponential backoff for different attempt numbers.""" + strategy = ModelRetryStrategy(initial_delay=2, max_delay=32) + + # Test exponential backoff: 2 * (2^attempt) + assert strategy._calculate_delay(0) == 2 # 2 * 2^0 = 2 + assert strategy._calculate_delay(1) == 4 # 2 * 2^1 = 4 + assert strategy._calculate_delay(2) == 8 # 2 * 2^2 = 8 + assert strategy._calculate_delay(3) == 16 # 2 * 2^3 = 16 + assert strategy._calculate_delay(4) == 32 # 2 * 2^4 = 32 (at max) + assert strategy._calculate_delay(5) == 32 # 2 * 2^5 = 64, capped at 32 + assert strategy._calculate_delay(10) == 32 # Large attempt, still capped + + +def test_model_retry_strategy_calculate_delay_respects_max_delay(): + """Test _calculate_delay respects max_delay cap.""" + strategy = ModelRetryStrategy(initial_delay=10, max_delay=50) + + assert strategy._calculate_delay(0) == 10 # 10 * 2^0 = 10 + assert strategy._calculate_delay(1) == 20 # 10 * 2^1 = 20 + assert strategy._calculate_delay(2) == 40 # 10 * 2^2 = 40 + assert strategy._calculate_delay(3) == 50 # 10 * 2^3 = 80, capped at 50 + assert strategy._calculate_delay(4) == 50 # 10 * 2^4 = 160, capped at 50 + + +def test_model_retry_strategy_register_hooks(): + """Test that ModelRetryStrategy registers AfterModelCallEvent and AfterInvocationEvent callbacks.""" + strategy = ModelRetryStrategy() + registry = HookRegistry() + + strategy.register_hooks(registry) + + # Verify AfterModelCallEvent callback was registered + assert AfterModelCallEvent in registry._registered_callbacks + assert len(registry._registered_callbacks[AfterModelCallEvent]) == 1 + + # Verify AfterInvocationEvent callback was registered + assert AfterInvocationEvent in registry._registered_callbacks + assert len(registry._registered_callbacks[AfterInvocationEvent]) == 1 + + +@pytest.mark.asyncio +async def test_model_retry_strategy_retry_on_throttle_exception_first_attempt(mock_sleep): + """Test retry behavior on first ModelThrottledException.""" + strategy = ModelRetryStrategy(max_attempts=3, initial_delay=2, max_delay=60) + mock_agent = Mock() + + event = AfterModelCallEvent( + agent=mock_agent, + exception=ModelThrottledException("Throttled"), + ) + + await strategy._handle_after_model_call(event) + + # Should set retry to True + assert event.retry is True + # Should sleep for initial_delay (attempt 0: 2 * 2^0 = 2) + assert mock_sleep.sleep_calls == [2] + assert mock_sleep.sleep_calls[0] == strategy._calculate_delay(0) + # Should increment attempt + assert strategy._current_attempt == 1 + + +@pytest.mark.asyncio +async def test_model_retry_strategy_exponential_backoff(mock_sleep): + """Test exponential backoff calculation.""" + strategy = ModelRetryStrategy(max_attempts=5, initial_delay=2, max_delay=16) + mock_agent = Mock() + + # Simulate multiple retries + for _ in range(4): + event = AfterModelCallEvent( + agent=mock_agent, + exception=ModelThrottledException("Throttled"), + ) + await strategy._handle_after_model_call(event) + assert event.retry is True + + # Verify exponential backoff with max_delay cap + # attempt 0: 2*2^0=2, attempt 1: 2*2^1=4, attempt 2: 2*2^2=8, attempt 3: 2*2^3=16 (capped) + assert mock_sleep.sleep_calls == [2, 4, 8, 16] + for i, sleep_delay in enumerate(mock_sleep.sleep_calls): + assert sleep_delay == strategy._calculate_delay(i) + + +@pytest.mark.asyncio +async def test_model_retry_strategy_no_retry_after_max_attempts(mock_sleep): + """Test that retry is not set after reaching max_attempts.""" + strategy = ModelRetryStrategy(max_attempts=2, initial_delay=2, max_delay=60) + mock_agent = Mock() + + # First attempt + event1 = AfterModelCallEvent( + agent=mock_agent, + exception=ModelThrottledException("Throttled"), + ) + await strategy._handle_after_model_call(event1) + assert event1.retry is True + assert strategy._current_attempt == 1 + + # Second attempt (at max_attempts) + event2 = AfterModelCallEvent( + agent=mock_agent, + exception=ModelThrottledException("Throttled"), + ) + await strategy._handle_after_model_call(event2) + # Should NOT retry after reaching max_attempts + assert event2.retry is False + assert strategy._current_attempt == 2 + + +@pytest.mark.asyncio +async def test_model_retry_strategy_no_retry_on_non_throttle_exception(): + """Test that retry is not set for non-throttling exceptions.""" + strategy = ModelRetryStrategy() + mock_agent = Mock() + + event = AfterModelCallEvent( + agent=mock_agent, + exception=ValueError("Some other error"), + ) + + await strategy._handle_after_model_call(event) + + # Should not retry on non-throttling exceptions + assert event.retry is False + assert strategy._current_attempt == 0 + + +@pytest.mark.asyncio +async def test_model_retry_strategy_no_retry_on_success(): + """Test that retry is not set when model call succeeds.""" + strategy = ModelRetryStrategy() + mock_agent = Mock() + + event = AfterModelCallEvent( + agent=mock_agent, + stop_response=AfterModelCallEvent.ModelStopResponse( + message={"role": "assistant", "content": [{"text": "Success"}]}, + stop_reason="end_turn", + ), + ) + + await strategy._handle_after_model_call(event) + + # Should not retry on success + assert event.retry is False + + +@pytest.mark.asyncio +async def test_model_retry_strategy_reset_on_success(mock_sleep): + """Test that strategy resets attempt counter on successful call.""" + strategy = ModelRetryStrategy(max_attempts=3, initial_delay=2, max_delay=60) + mock_agent = Mock() + + # First failure + event1 = AfterModelCallEvent( + agent=mock_agent, + exception=ModelThrottledException("Throttled"), + ) + await strategy._handle_after_model_call(event1) + assert event1.retry is True + assert strategy._current_attempt == 1 + # Should sleep for initial_delay (attempt 0: 2 * 2^0 = 2) + assert mock_sleep.sleep_calls == [2] + assert mock_sleep.sleep_calls[0] == strategy._calculate_delay(0) + + # Success - should reset + event2 = AfterModelCallEvent( + agent=mock_agent, + stop_response=AfterModelCallEvent.ModelStopResponse( + message={"role": "assistant", "content": [{"text": "Success"}]}, + stop_reason="end_turn", + ), + ) + await strategy._handle_after_model_call(event2) + assert event2.retry is False + # Should reset to initial state + assert strategy._current_attempt == 0 + assert strategy._calculate_delay(0) == 2 + + +@pytest.mark.asyncio +async def test_model_retry_strategy_skips_if_already_retrying(): + """Test that strategy skips processing if event.retry is already True.""" + strategy = ModelRetryStrategy(max_attempts=3, initial_delay=2, max_delay=60) + mock_agent = Mock() + + event = AfterModelCallEvent( + agent=mock_agent, + exception=ModelThrottledException("Throttled"), + ) + # Simulate another hook already set retry to True + event.retry = True + + await strategy._handle_after_model_call(event) + + # Should not modify state since another hook already triggered retry + assert strategy._current_attempt == 0 + assert event.retry is True + + +@pytest.mark.asyncio +async def test_model_retry_strategy_reset_on_after_invocation(): + """Test that strategy resets state on AfterInvocationEvent.""" + strategy = ModelRetryStrategy(max_attempts=3, initial_delay=2, max_delay=60) + mock_agent = Mock() + + # Simulate some retry attempts + strategy._current_attempt = 3 + + event = AfterInvocationEvent(agent=mock_agent, result=Mock()) + await strategy._handle_after_invocation(event) + + # Should reset to initial state + assert strategy._current_attempt == 0 + + +@pytest.mark.asyncio +async def test_model_retry_strategy_backwards_compatible_event_set_on_retry(mock_sleep): + """Test that _backwards_compatible_event_to_yield is set when retrying.""" + strategy = ModelRetryStrategy(max_attempts=3, initial_delay=2, max_delay=60) + mock_agent = Mock() + + event = AfterModelCallEvent( + agent=mock_agent, + exception=ModelThrottledException("Throttled"), + ) + + await strategy._handle_after_model_call(event) + + # Should have set the backwards compatible event + assert strategy._backwards_compatible_event_to_yield is not None + assert isinstance(strategy._backwards_compatible_event_to_yield, EventLoopThrottleEvent) + assert strategy._backwards_compatible_event_to_yield["event_loop_throttled_delay"] == 2 + + +@pytest.mark.asyncio +async def test_model_retry_strategy_backwards_compatible_event_cleared_on_success(): + """Test that _backwards_compatible_event_to_yield is cleared on success.""" + strategy = ModelRetryStrategy(max_attempts=3, initial_delay=2, max_delay=60) + mock_agent = Mock() + + # Set a previous backwards compatible event + strategy._backwards_compatible_event_to_yield = EventLoopThrottleEvent(delay=2) + + event = AfterModelCallEvent( + agent=mock_agent, + stop_response=AfterModelCallEvent.ModelStopResponse( + message={"role": "assistant", "content": [{"text": "Success"}]}, + stop_reason="end_turn", + ), + ) + + await strategy._handle_after_model_call(event) + + # Should have cleared the backwards compatible event + assert strategy._backwards_compatible_event_to_yield is None + + +@pytest.mark.asyncio +async def test_model_retry_strategy_backwards_compatible_event_not_set_on_max_attempts(mock_sleep): + """Test that _backwards_compatible_event_to_yield is not set when max attempts reached.""" + strategy = ModelRetryStrategy(max_attempts=1, initial_delay=2, max_delay=60) + mock_agent = Mock() + + event = AfterModelCallEvent( + agent=mock_agent, + exception=ModelThrottledException("Throttled"), + ) + + await strategy._handle_after_model_call(event) + + # Should not have set the backwards compatible event since max attempts reached + assert strategy._backwards_compatible_event_to_yield is None + assert event.retry is False + + +@pytest.mark.asyncio +async def test_model_retry_strategy_no_retry_when_no_exception_and_no_stop_response(): + """Test that retry is not set when there's no exception and no stop_response.""" + strategy = ModelRetryStrategy() + mock_agent = Mock() + + # Event with neither exception nor stop_response + event = AfterModelCallEvent( + agent=mock_agent, + exception=None, + stop_response=None, + ) + + await strategy._handle_after_model_call(event) + + # Should not retry and should reset state + assert event.retry is False + assert strategy._current_attempt == 0 diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 639e60ea0..d4afd579b 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -1,3 +1,4 @@ +import asyncio import concurrent import unittest.mock from unittest.mock import ANY, AsyncMock, MagicMock, call, patch @@ -7,6 +8,7 @@ import strands import strands.telemetry from strands import Agent +from strands.event_loop._retry import ModelRetryStrategy from strands.hooks import ( AfterModelCallEvent, BeforeModelCallEvent, @@ -31,9 +33,7 @@ @pytest.fixture def mock_sleep(): - with unittest.mock.patch.object( - strands.event_loop.event_loop.asyncio, "sleep", new_callable=unittest.mock.AsyncMock - ) as mock: + with patch.object(strands.event_loop._retry.asyncio, "sleep", new_callable=AsyncMock) as mock: yield mock @@ -116,7 +116,11 @@ def tool_stream(tool): @pytest.fixture def hook_registry(): - return HookRegistry() + registry = HookRegistry() + # Register default retry strategy + retry_strategy = ModelRetryStrategy() + retry_strategy.register_hooks(registry) + return registry @pytest.fixture @@ -147,6 +151,7 @@ def agent(model, system_prompt, messages, tool_registry, thread_pool, hook_regis mock.tool_executor = tool_executor mock._interrupt_state = _InterruptState() mock.trace_attributes = {} + mock.retry_strategy = ModelRetryStrategy() return mock @@ -693,7 +698,7 @@ async def test_event_loop_tracing_with_throttling_exception( ] # Mock the time.sleep function to speed up the test - with patch("strands.event_loop.event_loop.asyncio.sleep", new_callable=unittest.mock.AsyncMock): + with patch.object(asyncio, "sleep", new_callable=unittest.mock.AsyncMock): stream = strands.event_loop.event_loop.event_loop_cycle( agent=agent, invocation_state={}, @@ -856,15 +861,21 @@ async def test_event_loop_cycle_exception_model_hooks(mock_sleep, agent, model, # 1st call - throttled assert next(events) == BeforeModelCallEvent(agent=agent) - assert next(events) == AfterModelCallEvent(agent=agent, stop_response=None, exception=exception) + expected_after = AfterModelCallEvent(agent=agent, stop_response=None, exception=exception) + expected_after.retry = True + assert next(events) == expected_after # 2nd call - throttled assert next(events) == BeforeModelCallEvent(agent=agent) - assert next(events) == AfterModelCallEvent(agent=agent, stop_response=None, exception=exception) + expected_after = AfterModelCallEvent(agent=agent, stop_response=None, exception=exception) + expected_after.retry = True + assert next(events) == expected_after # 3rd call - throttled assert next(events) == BeforeModelCallEvent(agent=agent) - assert next(events) == AfterModelCallEvent(agent=agent, stop_response=None, exception=exception) + expected_after = AfterModelCallEvent(agent=agent, stop_response=None, exception=exception) + expected_after.retry = True + assert next(events) == expected_after # 4th call - successful assert next(events) == BeforeModelCallEvent(agent=agent)