diff --git a/src/strands/agent/__init__.py b/src/strands/agent/__init__.py index 6618d3328..78a87e00d 100644 --- a/src/strands/agent/__init__.py +++ b/src/strands/agent/__init__.py @@ -4,6 +4,7 @@ - Agent: The main interface for interacting with AI models and tools - ConversationManager: Classes for managing conversation history and context windows +- Retry Strategies: Configurable retry behavior for model calls """ from .agent import Agent @@ -14,6 +15,7 @@ SlidingWindowConversationManager, SummarizingConversationManager, ) +from .retry import ModelRetryStrategy, NoopRetryStrategy __all__ = [ "Agent", @@ -22,4 +24,6 @@ "NullConversationManager", "SlidingWindowConversationManager", "SummarizingConversationManager", + "ModelRetryStrategy", + "NoopRetryStrategy", ] diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 9e726ca0b..cf56cd459 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -124,6 +124,7 @@ def __init__( hooks: Optional[list[HookProvider]] = None, session_manager: Optional[SessionManager] = None, tool_executor: Optional[ToolExecutor] = None, + retry_strategy: Optional[HookProvider] = None, ): """Initialize the Agent with the specified configuration. @@ -173,6 +174,9 @@ def __init__( session_manager: Manager for handling agent sessions including conversation history and state. If provided, enables session-based persistence and state management. tool_executor: Definition of tool execution strategy (e.g., sequential, concurrent, etc.). + retry_strategy: Strategy for retrying model calls on throttling or other transient errors. + Defaults to ModelRetryStrategy with max_attempts=6, initial_delay=4s, max_delay=240s. + Pass NoopRetryStrategy to disable retries, or implement a custom HookProvider for custom retry logic. Raises: ValueError: If agent id contains path separators. @@ -245,6 +249,11 @@ def __init__( self._interrupt_state = _InterruptState() + # Initialize retry strategy + from .retry import ModelRetryStrategy + + self._retry_strategy = retry_strategy if retry_strategy is not None else ModelRetryStrategy() + # Initialize session management functionality self._session_manager = session_manager if self._session_manager: @@ -253,6 +262,9 @@ def __init__( # Allow conversation_managers to subscribe to hooks self.hooks.add_hook(self.conversation_manager) + # Register retry strategy as a hook + self.hooks.add_hook(self._retry_strategy) + self.tool_executor = tool_executor or ConcurrentToolExecutor() if hooks: @@ -289,6 +301,15 @@ def system_prompt(self, value: str | list[SystemContentBlock] | None) -> None: """ self._system_prompt, self._system_prompt_content = self._initialize_system_prompt(value) + @property + def retry_strategy(self) -> HookProvider: + """Get the retry strategy for this agent. + + Returns: + The retry strategy hook provider. + """ + return self._retry_strategy + @property def tool(self) -> _ToolCaller: """Call tool as a function. diff --git a/src/strands/agent/retry.py b/src/strands/agent/retry.py new file mode 100644 index 000000000..5d5467367 --- /dev/null +++ b/src/strands/agent/retry.py @@ -0,0 +1,212 @@ +"""Retry strategy implementations for handling model throttling and other retry scenarios. + +This module provides hook-based retry strategies that can be configured on the Agent +to control retry behavior for model invocations. Retry strategies implement the +HookProvider protocol and register callbacks for AfterModelCallEvent to determine +when and how to retry failed model calls. +""" + +import asyncio +import logging +from typing import Any + +from ..types.exceptions import ModelThrottledException +from ..hooks.events import AfterInvocationEvent, AfterModelCallEvent +from ..hooks.registry import HookProvider, HookRegistry + +logger = logging.getLogger(__name__) + + +class ModelRetryStrategy(HookProvider): + """Default retry strategy for model throttling with exponential backoff. + + This strategy implements automatic retry logic for model throttling exceptions, + using exponential backoff to handle rate limiting gracefully. It retries + model calls when ModelThrottledException is raised, up to a configurable + maximum number of attempts. + + The delay between retries starts at initial_delay and doubles after each + retry, up to a maximum of max_delay. The strategy automatically resets + its state after a successful model call. + + Example: + ```python + from strands import Agent + from strands.hooks import ModelRetryStrategy + + # Use custom retry parameters + retry_strategy = ModelRetryStrategy( + max_attempts=3, + initial_delay=2, + max_delay=60 + ) + agent = Agent(retry_strategy=retry_strategy) + ``` + + Attributes: + max_attempts: Maximum number of retry attempts before giving up. + initial_delay: Initial delay in seconds before the first retry. + max_delay: Maximum delay in seconds between retries. + current_attempt: Current retry attempt counter (resets on success). + current_delay: Current delay value for exponential backoff. + """ + + def __init__( + self, + max_attempts: int = 6, + initial_delay: int = 4, + max_delay: int = 240, + ): + """Initialize the retry strategy with the specified parameters. + + Args: + max_attempts: Maximum number of retry attempts. Defaults to 6. + initial_delay: Initial delay in seconds before retrying. Defaults to 4. + max_delay: Maximum delay in seconds between retries. Defaults to 240 (4 minutes). + """ + self._max_attempts = max_attempts + self._initial_delay = initial_delay + self._max_delay = max_delay + self._current_attempt = 0 + self._did_trigger_retry = False + + def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None: + """Register callbacks for AfterModelCallEvent and AfterInvocationEvent. + + Args: + registry: The hook registry to register callbacks with. + **kwargs: Additional keyword arguments for future extensibility. + """ + registry.add_callback(AfterModelCallEvent, self._handle_after_model_call) + registry.add_callback(AfterInvocationEvent, self._handle_after_invocation) + + def _calculate_delay(self) -> float: + """Calculate the current retry delay based on attempt number. + + Uses exponential backoff: initial_delay * (2 ** attempt), capped at max_delay. + + Returns: + The delay in seconds for the current attempt. + """ + if self._current_attempt == 0: + return self._initial_delay + delay = self._initial_delay * (2 ** (self._current_attempt - 1)) + return min(delay, self._max_delay) + + @property + def _current_delay(self) -> float: + """Get the current retry delay (for backwards compatibility with EventLoopThrottleEvent). + + This property is private and only exists for backwards compatibility with EventLoopThrottleEvent. + External code should not access this property. + """ + return self._calculate_delay() + + def _reset_retry_state(self) -> None: + """Reset retry state to initial values.""" + self._current_attempt = 0 + self._did_trigger_retry = False + + async def _handle_after_invocation(self, event: AfterInvocationEvent) -> None: + """Reset retry state after invocation completes. + + Args: + event: The AfterInvocationEvent signaling invocation completion. + """ + self._reset_retry_state() + + async def _handle_after_model_call(self, event: AfterModelCallEvent) -> None: + """Handle model call completion and determine if retry is needed. + + This callback is invoked after each model call. If the call failed with + a ModelThrottledException and we haven't exceeded max_attempts, it sets + event.retry to True and sleeps for the current delay before returning. + + On successful calls, it resets the retry state to prepare for future calls. + + Args: + event: The AfterModelCallEvent containing call results or exception. + """ + # If already retrying, skip processing (another hook may have triggered retry) + if event.retry: + return + + # If model call succeeded, reset retry state + if event.stop_response is not None: + logger.debug( + "stop_reason=<%s> | model call succeeded, resetting retry state", + event.stop_response.stop_reason, + ) + self._reset_retry_state() + return + + # Check if we have an exception and reset state if no exception + if event.exception is None: + self._reset_retry_state() + return + + # Only retry on ModelThrottledException + if not isinstance(event.exception, ModelThrottledException): + return + + # Increment attempt counter first + self._current_attempt += 1 + + # Check if we've exceeded max attempts + if self._current_attempt >= self._max_attempts: + logger.debug( + "current_attempt=<%d>, max_attempts=<%d> | max retry attempts reached, not retrying", + self._current_attempt, + self._max_attempts, + ) + self._did_trigger_retry = False + return + + # Calculate delay for this attempt + delay = self._calculate_delay() + + # Retry the model call + logger.debug( + "retry_delay_seconds=<%s>, max_attempts=<%s>, current_attempt=<%s> " + "| throttling exception encountered | delaying before next retry", + delay, + self._max_attempts, + self._current_attempt, + ) + + # Sleep for current delay + await asyncio.sleep(delay) + + # Set retry flag and track that this strategy triggered it + event.retry = True + self._did_trigger_retry = True + + +class NoopRetryStrategy(HookProvider): + """No-op retry strategy that disables automatic retries. + + This strategy can be used when you want to explicitly disable retry behavior + and handle errors directly in your application code. It implements the + HookProvider protocol but does not register any callbacks. + + Example: + ```python + from strands import Agent + from strands.hooks import NoopRetryStrategy + + # Disable automatic retries + agent = Agent(retry_strategy=NoopRetryStrategy()) + ``` + """ + + def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None: + """Register hooks (no-op implementation). + + This method intentionally does nothing, as this strategy disables retries. + + Args: + registry: The hook registry to register callbacks with. + **kwargs: Additional keyword arguments for future extensibility. + """ + # Intentionally empty - no callbacks to register + pass diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index fcb530a0d..6f1bfa41b 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -51,10 +51,6 @@ logger = logging.getLogger(__name__) -MAX_ATTEMPTS = 6 -INITIAL_DELAY = 4 -MAX_DELAY = 240 # 4 minutes - def _has_tool_use_in_latest_message(messages: "Messages") -> bool: """Check if the latest message contains any ToolUse content blocks. @@ -315,9 +311,9 @@ async def _handle_model_execution( stream_trace = Trace("stream_messages", parent_id=cycle_trace.id) cycle_trace.add_child(stream_trace) - # Retry loop for handling throttling exceptions - current_delay = INITIAL_DELAY - for attempt in range(MAX_ATTEMPTS): + # Retry loop - actual retry logic is handled by retry_strategy hook + # Hooks control when to stop retrying via the event.retry flag + while True: model_id = agent.model.config.get("model_id") if hasattr(agent.model, "config") else None model_invoke_span = tracer.start_model_invoke_span( messages=agent.messages, @@ -364,10 +360,14 @@ async def _handle_model_execution( # Check if hooks want to retry the model call if after_model_call_event.retry: logger.debug( - "stop_reason=<%s>, retry_requested=, attempt=<%d> | hook requested model retry", + "stop_reason=<%s>, retry_requested= | hook requested model retry", stop_reason, - attempt + 1, ) + # Emit EventLoopThrottleEvent for backwards compatibility if ModelRetryStrategy triggered retry + from ..agent.retry import ModelRetryStrategy + + if isinstance(agent.retry_strategy, ModelRetryStrategy) and agent.retry_strategy._did_trigger_retry: + yield EventLoopThrottleEvent(delay=agent.retry_strategy._current_delay) continue # Retry the model call if stop_reason == "max_tokens": @@ -390,31 +390,18 @@ async def _handle_model_execution( # Check if hooks want to retry the model call if after_model_call_event.retry: logger.debug( - "exception=<%s>, retry_requested=, attempt=<%d> | hook requested model retry", + "exception=<%s>, retry_requested= | hook requested model retry", type(e).__name__, - attempt + 1, ) - continue # Retry the model call - - if isinstance(e, ModelThrottledException): - if attempt + 1 == MAX_ATTEMPTS: - yield ForceStopEvent(reason=e) - raise e + # Emit EventLoopThrottleEvent for backwards compatibility if ModelRetryStrategy triggered retry + from ..agent.retry import ModelRetryStrategy - logger.debug( - "retry_delay_seconds=<%s>, max_attempts=<%s>, current_attempt=<%s> " - "| throttling exception encountered " - "| delaying before next retry", - current_delay, - MAX_ATTEMPTS, - attempt + 1, - ) - await asyncio.sleep(current_delay) - current_delay = min(current_delay * 2, MAX_DELAY) + if isinstance(agent.retry_strategy, ModelRetryStrategy) and agent.retry_strategy._did_trigger_retry: + yield EventLoopThrottleEvent(delay=agent.retry_strategy._current_delay) + continue # Retry the model call - yield EventLoopThrottleEvent(delay=current_delay) - else: - raise e + # No retry requested, raise the exception + raise e try: # Add message in trace and mark the end of the stream messages trace diff --git a/tests/strands/agent/conftest.py b/tests/strands/agent/conftest.py new file mode 100644 index 000000000..d3af90dc8 --- /dev/null +++ b/tests/strands/agent/conftest.py @@ -0,0 +1,22 @@ +"""Fixtures for agent tests.""" + +import asyncio +from unittest.mock import AsyncMock + +import pytest + + +@pytest.fixture +def mock_sleep(monkeypatch): + """Mock asyncio.sleep to avoid delays in tests and track sleep calls.""" + sleep_calls = [] + + async def _mock_sleep(delay): + sleep_calls.append(delay) + + mock = AsyncMock(side_effect=_mock_sleep) + monkeypatch.setattr(asyncio, "sleep", mock) + + # Return both the mock and the sleep_calls list for verification + mock.sleep_calls = sleep_calls + return mock diff --git a/tests/strands/agent/hooks/test_agent_events.py b/tests/strands/agent/hooks/test_agent_events.py index 7b189a5c6..2cc9a5420 100644 --- a/tests/strands/agent/hooks/test_agent_events.py +++ b/tests/strands/agent/hooks/test_agent_events.py @@ -513,7 +513,8 @@ async def test_event_loop_cycle_text_response_throttling_early_end( {"event_loop_throttled_delay": 32, **common_props}, {"event_loop_throttled_delay": 64, **common_props}, {"event_loop_throttled_delay": 128, **common_props}, - {"force_stop": True, "force_stop_reason": "ThrottlingException | ConverseStream"}, + # Note: force_stop event is no longer emitted with hook-based retry strategy + # The exception is raised after max attempts without emitting force_stop ] assert tru_events == exp_events diff --git a/tests/strands/agent/test_agent_retry.py b/tests/strands/agent/test_agent_retry.py new file mode 100644 index 000000000..d3fdf7939 --- /dev/null +++ b/tests/strands/agent/test_agent_retry.py @@ -0,0 +1,175 @@ +"""Integration tests for Agent retry_strategy parameter.""" + +from unittest.mock import Mock + +import pytest + +from strands import Agent +from strands.agent.retry import ModelRetryStrategy, NoopRetryStrategy +from strands.types.exceptions import ModelThrottledException +from tests.fixtures.mocked_model_provider import MockedModelProvider + + +# Agent Retry Strategy Initialization Tests + + +def test_agent_with_default_retry_strategy(): + """Test that Agent uses ModelRetryStrategy by default when retry_strategy=None.""" + agent = Agent() + + # Should have a retry_strategy + assert hasattr(agent, "retry_strategy") + assert agent.retry_strategy is not None + + # Should be ModelRetryStrategy with default parameters + assert isinstance(agent.retry_strategy, ModelRetryStrategy) + assert agent.retry_strategy._max_attempts == 6 + assert agent.retry_strategy._initial_delay == 4 + assert agent.retry_strategy._max_delay == 240 + + +def test_agent_with_custom_model_retry_strategy(): + """Test Agent initialization with custom ModelRetryStrategy parameters.""" + custom_strategy = ModelRetryStrategy(max_attempts=3, initial_delay=2, max_delay=60) + agent = Agent(retry_strategy=custom_strategy) + + assert agent.retry_strategy is custom_strategy + assert agent.retry_strategy._max_attempts == 3 + assert agent.retry_strategy._initial_delay == 2 + assert agent.retry_strategy._max_delay == 60 + + +def test_agent_with_noop_retry_strategy(): + """Test Agent initialization with NoopRetryStrategy.""" + noop_strategy = NoopRetryStrategy() + agent = Agent(retry_strategy=noop_strategy) + + assert agent.retry_strategy is noop_strategy + assert isinstance(agent.retry_strategy, NoopRetryStrategy) + + +def test_retry_strategy_registered_as_hook(): + """Test that retry_strategy is registered with the hook system.""" + custom_strategy = ModelRetryStrategy(max_attempts=3) + agent = Agent(retry_strategy=custom_strategy) + + # Verify retry strategy callback is registered + from strands.hooks import AfterModelCallEvent + + callbacks = list(agent.hooks.get_callbacks_for(AfterModelCallEvent(agent=agent, exception=None))) + + # Should have at least one callback (from retry strategy) + assert len(callbacks) > 0 + + # Verify one of the callbacks is from the retry strategy + assert any(callback.__self__ is custom_strategy if hasattr(callback, "__self__") else False for callback in callbacks) + + +# Agent Retry Behavior Tests + + +@pytest.mark.asyncio +async def test_agent_retries_with_default_strategy(mock_sleep): + """Test that Agent retries on throttling with default ModelRetryStrategy.""" + # Create a model that fails twice with throttling, then succeeds + model = Mock() + model.stream.side_effect = [ + ModelThrottledException("ThrottlingException"), + ModelThrottledException("ThrottlingException"), + MockedModelProvider([{"role": "assistant", "content": [{"text": "Success after retries"}]}]).stream([]), + ] + + agent = Agent(model=model) + + result = agent.stream_async("test prompt") + events = [event async for event in result] + + # Should have succeeded after retries - just check we got events + assert len(events) > 0 + + # Should have slept twice (for two retries) + assert len(mock_sleep.sleep_calls) == 2 + # First retry: 4 seconds + assert mock_sleep.sleep_calls[0] == 4 + # Second retry: 8 seconds (exponential backoff) + assert mock_sleep.sleep_calls[1] == 8 + + +@pytest.mark.asyncio +async def test_agent_no_retry_with_noop_strategy(): + """Test that Agent does not retry with NoopRetryStrategy.""" + # Create a model that always fails with throttling + model = Mock() + model.stream.side_effect = ModelThrottledException("ThrottlingException") + + agent = Agent(model=model, retry_strategy=NoopRetryStrategy()) + + # Should raise exception immediately without retry + with pytest.raises(ModelThrottledException): + result = agent.stream_async("test prompt") + # Consume the stream to trigger the exception + _ = [event async for event in result] + + +@pytest.mark.asyncio +async def test_agent_respects_max_attempts(mock_sleep): + """Test that Agent respects max_attempts in retry strategy.""" + # Create a model that always fails + model = Mock() + model.stream.side_effect = ModelThrottledException("ThrottlingException") + + # Use custom strategy with max 2 attempts + custom_strategy = ModelRetryStrategy(max_attempts=2, initial_delay=1, max_delay=60) + agent = Agent(model=model, retry_strategy=custom_strategy) + + with pytest.raises(ModelThrottledException): + result = agent.stream_async("test prompt") + _ = [event async for event in result] + + # Should have attempted max_attempts times, which means (max_attempts - 1) sleeps + # Attempt 0: fail, sleep + # Attempt 1: fail, no more attempts + assert len(mock_sleep.sleep_calls) == 1 + + +# Backwards Compatibility Tests + + +@pytest.mark.asyncio +async def test_event_loop_throttle_event_emitted(mock_sleep): + """Test that EventLoopThrottleEvent is still emitted for backwards compatibility.""" + # Create a model that fails once with throttling, then succeeds + model = Mock() + model.stream.side_effect = [ + ModelThrottledException("ThrottlingException"), + MockedModelProvider([{"role": "assistant", "content": [{"text": "Success"}]}]).stream([]), + ] + + agent = Agent(model=model) + + result = agent.stream_async("test prompt") + events = [event async for event in result] + + # Should have EventLoopThrottleEvent in the stream + throttle_events = [e for e in events if "event_loop_throttled_delay" in e] + assert len(throttle_events) > 0 + + # Should have the correct delay value + assert throttle_events[0]["event_loop_throttled_delay"] > 0 + + +@pytest.mark.asyncio +async def test_no_throttle_event_with_noop_strategy(): + """Test that EventLoopThrottleEvent is not emitted with NoopRetryStrategy.""" + # Create a model that succeeds immediately + model = MockedModelProvider([{"role": "assistant", "content": [{"text": "Success"}]}]) + + agent = Agent(model=model, retry_strategy=NoopRetryStrategy()) + + result = agent.stream_async("test prompt") + events = [event async for event in result] + + # Should not have any EventLoopThrottleEvent + throttle_events = [e for e in events if "event_loop_throttled_delay" in e] + assert len(throttle_events) == 0 + diff --git a/tests/strands/agent/test_retry.py b/tests/strands/agent/test_retry.py new file mode 100644 index 000000000..48fa8ab68 --- /dev/null +++ b/tests/strands/agent/test_retry.py @@ -0,0 +1,217 @@ +"""Unit tests for retry strategy implementations.""" + +from unittest.mock import Mock + +import pytest + +from strands.hooks import AfterModelCallEvent, HookRegistry +from strands.agent.retry import ModelRetryStrategy, NoopRetryStrategy +from strands.types.exceptions import ModelThrottledException + + +# ModelRetryStrategy Tests + + +def test_model_retry_strategy_init_with_defaults(): + """Test ModelRetryStrategy initialization with default parameters.""" + strategy = ModelRetryStrategy() + assert strategy._max_attempts == 6 + assert strategy._initial_delay == 4 + assert strategy._max_delay == 240 + assert strategy._current_attempt == 0 + assert strategy._calculate_delay() == 4 + + +def test_model_retry_strategy_init_with_custom_parameters(): + """Test ModelRetryStrategy initialization with custom parameters.""" + strategy = ModelRetryStrategy(max_attempts=3, initial_delay=2, max_delay=60) + assert strategy._max_attempts == 3 + assert strategy._initial_delay == 2 + assert strategy._max_delay == 60 + assert strategy._current_attempt == 0 + assert strategy._calculate_delay() == 2 + + +def test_model_retry_strategy_register_hooks(): + """Test that ModelRetryStrategy registers AfterModelCallEvent callback.""" + strategy = ModelRetryStrategy() + registry = HookRegistry() + + strategy.register_hooks(registry) + + # Verify callback was registered + assert AfterModelCallEvent in registry._registered_callbacks + assert len(registry._registered_callbacks[AfterModelCallEvent]) == 1 + + +@pytest.mark.asyncio +async def test_model_retry_strategy_retry_on_throttle_exception_first_attempt(mock_sleep): + """Test retry behavior on first ModelThrottledException.""" + strategy = ModelRetryStrategy(max_attempts=3, initial_delay=2, max_delay=60) + mock_agent = Mock() + + event = AfterModelCallEvent( + agent=mock_agent, + exception=ModelThrottledException("Throttled"), + ) + + await strategy._handle_after_model_call(event) + + # Should set retry to True + assert event.retry is True + # Should sleep for initial_delay + assert mock_sleep.sleep_calls == [2] + # Should increment attempt + assert strategy._current_attempt == 1 + assert strategy._calculate_delay() == 4 + + +@pytest.mark.asyncio +async def test_model_retry_strategy_exponential_backoff(mock_sleep): + """Test exponential backoff calculation.""" + strategy = ModelRetryStrategy(max_attempts=5, initial_delay=2, max_delay=16) + mock_agent = Mock() + + # Simulate multiple retries + for _ in range(4): + event = AfterModelCallEvent( + agent=mock_agent, + exception=ModelThrottledException("Throttled"), + ) + await strategy._handle_after_model_call(event) + assert event.retry is True + + # Verify exponential backoff with max_delay cap + # 2, 4, 8, 16 (capped) + assert mock_sleep.sleep_calls == [2, 4, 8, 16] + # Delay should be capped at max_delay + assert strategy._calculate_delay() == 16 + + +@pytest.mark.asyncio +async def test_model_retry_strategy_no_retry_after_max_attempts(mock_sleep): + """Test that retry is not set after reaching max_attempts.""" + strategy = ModelRetryStrategy(max_attempts=2, initial_delay=2, max_delay=60) + mock_agent = Mock() + + # First attempt + event1 = AfterModelCallEvent( + agent=mock_agent, + exception=ModelThrottledException("Throttled"), + ) + await strategy._handle_after_model_call(event1) + assert event1.retry is True + assert strategy._current_attempt == 1 + + # Second attempt (at max_attempts) + event2 = AfterModelCallEvent( + agent=mock_agent, + exception=ModelThrottledException("Throttled"), + ) + await strategy._handle_after_model_call(event2) + # Should NOT retry after reaching max_attempts + assert event2.retry is False + assert strategy._current_attempt == 2 + + +@pytest.mark.asyncio +async def test_model_retry_strategy_no_retry_on_non_throttle_exception(): + """Test that retry is not set for non-throttling exceptions.""" + strategy = ModelRetryStrategy() + mock_agent = Mock() + + event = AfterModelCallEvent( + agent=mock_agent, + exception=ValueError("Some other error"), + ) + + await strategy._handle_after_model_call(event) + + # Should not retry on non-throttling exceptions + assert event.retry is False + assert strategy._current_attempt == 0 + + +@pytest.mark.asyncio +async def test_model_retry_strategy_no_retry_on_success(): + """Test that retry is not set when model call succeeds.""" + strategy = ModelRetryStrategy() + mock_agent = Mock() + + event = AfterModelCallEvent( + agent=mock_agent, + stop_response=AfterModelCallEvent.ModelStopResponse( + message={"role": "assistant", "content": [{"text": "Success"}]}, + stop_reason="end_turn", + ), + ) + + await strategy._handle_after_model_call(event) + + # Should not retry on success + assert event.retry is False + + +@pytest.mark.asyncio +async def test_model_retry_strategy_reset_on_success(mock_sleep): + """Test that strategy resets attempt counter on successful call.""" + strategy = ModelRetryStrategy(max_attempts=3, initial_delay=2, max_delay=60) + mock_agent = Mock() + + # First failure + event1 = AfterModelCallEvent( + agent=mock_agent, + exception=ModelThrottledException("Throttled"), + ) + await strategy._handle_after_model_call(event1) + assert event1.retry is True + assert strategy._current_attempt == 1 + + # Success - should reset + event2 = AfterModelCallEvent( + agent=mock_agent, + stop_response=AfterModelCallEvent.ModelStopResponse( + message={"role": "assistant", "content": [{"text": "Success"}]}, + stop_reason="end_turn", + ), + ) + await strategy._handle_after_model_call(event2) + assert event2.retry is False + # Should reset to initial state + assert strategy._current_attempt == 0 + assert strategy._calculate_delay() == 2 + + +# NoopRetryStrategy Tests + + +def test_noop_retry_strategy_register_hooks_does_nothing(): + """Test that NoopRetryStrategy does not register any callbacks.""" + strategy = NoopRetryStrategy() + registry = HookRegistry() + + strategy.register_hooks(registry) + + # Verify no callbacks were registered + assert len(registry._registered_callbacks) == 0 + + +@pytest.mark.asyncio +async def test_noop_retry_strategy_no_retry_on_throttle_exception(): + """Test that NoopRetryStrategy does not retry on throttle exceptions.""" + strategy = NoopRetryStrategy() + registry = HookRegistry() + strategy.register_hooks(registry) + + mock_agent = Mock() + event = AfterModelCallEvent( + agent=mock_agent, + exception=ModelThrottledException("Throttled"), + ) + + # Invoke callbacks (should be none registered) + await registry.invoke_callbacks_async(event) + + # event.retry should still be False (default) + assert event.retry is False + diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 6b23bd592..17606c014 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -116,7 +116,13 @@ def tool_stream(tool): @pytest.fixture def hook_registry(): - return HookRegistry() + from strands.agent.retry import ModelRetryStrategy + + registry = HookRegistry() + # Register default retry strategy + retry_strategy = ModelRetryStrategy() + retry_strategy.register_hooks(registry) + return registry @pytest.fixture @@ -133,6 +139,8 @@ def tool_executor(): @pytest.fixture def agent(model, system_prompt, messages, tool_registry, thread_pool, hook_registry, tool_executor): + from strands.agent.retry import ModelRetryStrategy + mock = unittest.mock.Mock(name="agent") mock.__class__ = Agent mock.config.cache_points = [] @@ -147,6 +155,7 @@ def agent(model, system_prompt, messages, tool_registry, thread_pool, hook_regis mock.tool_executor = tool_executor mock._interrupt_state = _InterruptState() mock.trace_attributes = {} + mock.retry_strategy = ModelRetryStrategy() return mock @@ -692,7 +701,9 @@ async def test_event_loop_tracing_with_throttling_exception( ] # Mock the time.sleep function to speed up the test - with patch("strands.event_loop.event_loop.asyncio.sleep", new_callable=unittest.mock.AsyncMock): + import asyncio + + with patch.object(asyncio, "sleep", new_callable=unittest.mock.AsyncMock): stream = strands.event_loop.event_loop.event_loop_cycle( agent=agent, invocation_state={}, @@ -855,15 +866,21 @@ async def test_event_loop_cycle_exception_model_hooks(mock_sleep, agent, model, # 1st call - throttled assert next(events) == BeforeModelCallEvent(agent=agent) - assert next(events) == AfterModelCallEvent(agent=agent, stop_response=None, exception=exception) + expected_after = AfterModelCallEvent(agent=agent, stop_response=None, exception=exception) + expected_after.retry = True + assert next(events) == expected_after # 2nd call - throttled assert next(events) == BeforeModelCallEvent(agent=agent) - assert next(events) == AfterModelCallEvent(agent=agent, stop_response=None, exception=exception) + expected_after = AfterModelCallEvent(agent=agent, stop_response=None, exception=exception) + expected_after.retry = True + assert next(events) == expected_after # 3rd call - throttled assert next(events) == BeforeModelCallEvent(agent=agent) - assert next(events) == AfterModelCallEvent(agent=agent, stop_response=None, exception=exception) + expected_after = AfterModelCallEvent(agent=agent, stop_response=None, exception=exception) + expected_after.retry = True + assert next(events) == expected_after # 4th call - successful assert next(events) == BeforeModelCallEvent(agent=agent)