1919 InvokeModelRequestTypeDef ,
2020 InvokeModelResponseTypeDef ,
2121)
22+ from strands .models .bedrock import ModelThrottledException
2223
2324# Configure logger
2425logger = logging .getLogger (__name__ )
2526logger .setLevel (os .environ .get ("LOG_LEVEL" , "INFO" ))
2627
28+ # Default retryable error codes (matched against ClientError codes, exception class names, and as substrings of exception messages)
29+ DEFAULT_RETRYABLE_ERRORS = {
30+ "ThrottlingException" ,
31+ "throttlingException" ,
32+ "ModelThrottledException" , # Strands wrapper for throttling
33+ "ModelErrorException" ,
34+ "ValidationException" ,
35+ "ServiceQuotaExceededException" ,
36+ "RequestLimitExceeded" ,
37+ "TooManyRequestsException" ,
38+ "ServiceUnavailableException" ,
39+ "serviceUnavailableException" , # lowercase variant from EventStreamError
40+ "RequestTimeout" ,
41+ "RequestTimeoutException" ,
42+ }
43+
44+ # Default retryable exception types (caught by isinstance check)
45+ DEFAULT_RETRYABLE_EXCEPTION_TYPES : tuple [type [Exception ], ...] = (
46+ ModelThrottledException ,
47+ )
48+
2749
2850def async_exponential_backoff_retry [T , ** P ](
2951 max_retries : int = 5 ,
@@ -32,23 +54,13 @@ def async_exponential_backoff_retry[T, **P](
3254 exponential_base : float = 2.0 ,
3355 jitter : float = 0.1 ,
3456 retryable_errors : set [str ] | None = None ,
57+ retryable_exception_types : tuple [type [Exception ], ...] | None = None ,
3558) -> Callable [[Callable [P , Awaitable [T ]]], Callable [P , Awaitable [T ]]]:
36- if not retryable_errors :
37- retryable_errors = set (
38- [
39- "ThrottlingException" ,
40- "throttlingException" ,
41- "ModelErrorException" ,
42- "ValidationException" ,
43- "ServiceQuotaExceededException" ,
44- "RequestLimitExceeded" ,
45- "TooManyRequestsException" ,
46- "ServiceUnavailableException" ,
47- "serviceUnavailableException" , # lowercase variant from EventStreamError
48- "RequestTimeout" ,
49- "RequestTimeoutException" ,
50- ]
51- )
59+ # Use defaults only when the argument is None (an explicitly empty set/tuple is kept as-is)
60+ if retryable_errors is None :
61+ retryable_errors = DEFAULT_RETRYABLE_ERRORS
62+ if retryable_exception_types is None :
63+ retryable_exception_types = DEFAULT_RETRYABLE_EXCEPTION_TYPES
5264
5365 def decorator (func : Callable [P , Awaitable [T ]]) -> Callable [P , Awaitable [T ]]:
5466 @wraps (func )
@@ -104,7 +116,34 @@ def log_bedrock_invocation_error(error: Exception, attempt_num: int):
104116 await asyncio .sleep (sleep_time )
105117 delay = min (delay * exponential_base , max_delay )
106118 except Exception as e :
107- # Log bedrock invocation details for non-ClientError exceptions too
119+ # Check if this is a retryable exception type (e.g., Strands ModelThrottledException)
120+ is_retryable_type = retryable_exception_types and isinstance (
121+ e , retryable_exception_types
122+ )
123+
124+ # Also check if exception name or message contains retryable error patterns
125+ exception_name = type (e ).__name__
126+ exception_str = str (e )
127+ is_retryable_name = exception_name in retryable_errors or any (
128+ err in exception_str for err in retryable_errors
129+ )
130+
131+ if (
132+ is_retryable_type or is_retryable_name
133+ ) and attempt < max_retries - 1 :
134+ # Log and retry
135+ log_bedrock_invocation_error (e , attempt + 1 )
136+ jitter_value = random .uniform (- jitter , jitter )
137+ sleep_time = max (0.1 , delay * (1 + jitter_value ))
138+ logger .warning (
139+ f"{ exception_name } : { exception_str } encountered in { func .__name__ } . "
140+ f"Retrying in { sleep_time :.2f} seconds. Attempt { attempt + 1 } /{ max_retries } "
141+ )
142+ await asyncio .sleep (sleep_time )
143+ delay = min (delay * exponential_base , max_delay )
144+ continue
145+
146+ # Log bedrock invocation details for non-retryable exceptions, or when retries are exhausted
108147 log_bedrock_invocation_error (e , attempt + 1 )
109148 raise
110149
0 commit comments