Skip to content

Commit d4e6312

Browse files
committed
improve retry mechanism
1 parent bb78c71 commit d4e6312

File tree

1 file changed

+56
-17
lines changed

1 file changed

+56
-17
lines changed

lib/idp_common_pkg/idp_common/utils/bedrock_utils.py

Lines changed: 56 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,33 @@
1919
InvokeModelRequestTypeDef,
2020
InvokeModelResponseTypeDef,
2121
)
22+
from strands.models.bedrock import ModelThrottledException
2223

2324
# Configure logger
2425
logger = logging.getLogger(__name__)
2526
logger.setLevel(os.environ.get("LOG_LEVEL", "INFO"))
2627

28+
# Default retryable error codes (matched against ClientError codes and exception messages)
29+
DEFAULT_RETRYABLE_ERRORS = {
30+
"ThrottlingException",
31+
"throttlingException",
32+
"ModelThrottledException", # Strands wrapper for throttling
33+
"ModelErrorException",
34+
"ValidationException",
35+
"ServiceQuotaExceededException",
36+
"RequestLimitExceeded",
37+
"TooManyRequestsException",
38+
"ServiceUnavailableException",
39+
"serviceUnavailableException", # lowercase variant from EventStreamError
40+
"RequestTimeout",
41+
"RequestTimeoutException",
42+
}
43+
44+
# Default retryable exception types (caught by isinstance check)
45+
DEFAULT_RETRYABLE_EXCEPTION_TYPES: tuple[type[Exception], ...] = (
46+
ModelThrottledException,
47+
)
48+
2749

2850
def async_exponential_backoff_retry[T, **P](
2951
max_retries: int = 5,
@@ -32,23 +54,13 @@ def async_exponential_backoff_retry[T, **P](
3254
exponential_base: float = 2.0,
3355
jitter: float = 0.1,
3456
retryable_errors: set[str] | None = None,
57+
retryable_exception_types: tuple[type[Exception], ...] | None = None,
3558
) -> Callable[[Callable[P, Awaitable[T]]], Callable[P, Awaitable[T]]]:
36-
if not retryable_errors:
37-
retryable_errors = set(
38-
[
39-
"ThrottlingException",
40-
"throttlingException",
41-
"ModelErrorException",
42-
"ValidationException",
43-
"ServiceQuotaExceededException",
44-
"RequestLimitExceeded",
45-
"TooManyRequestsException",
46-
"ServiceUnavailableException",
47-
"serviceUnavailableException", # lowercase variant from EventStreamError
48-
"RequestTimeout",
49-
"RequestTimeoutException",
50-
]
51-
)
59+
# Use defaults if not provided
60+
if retryable_errors is None:
61+
retryable_errors = DEFAULT_RETRYABLE_ERRORS
62+
if retryable_exception_types is None:
63+
retryable_exception_types = DEFAULT_RETRYABLE_EXCEPTION_TYPES
5264

5365
def decorator(func: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]:
5466
@wraps(func)
@@ -104,7 +116,34 @@ def log_bedrock_invocation_error(error: Exception, attempt_num: int):
104116
await asyncio.sleep(sleep_time)
105117
delay = min(delay * exponential_base, max_delay)
106118
except Exception as e:
107-
# Log bedrock invocation details for non-ClientError exceptions too
119+
# Check if this is a retryable exception type (e.g., Strands ModelThrottledException)
120+
is_retryable_type = retryable_exception_types and isinstance(
121+
e, retryable_exception_types
122+
)
123+
124+
# Also check if exception name or message contains retryable error patterns
125+
exception_name = type(e).__name__
126+
exception_str = str(e)
127+
is_retryable_name = exception_name in retryable_errors or any(
128+
err in exception_str for err in retryable_errors
129+
)
130+
131+
if (
132+
is_retryable_type or is_retryable_name
133+
) and attempt < max_retries - 1:
134+
# Log and retry
135+
log_bedrock_invocation_error(e, attempt + 1)
136+
jitter_value = random.uniform(-jitter, jitter)
137+
sleep_time = max(0.1, delay * (1 + jitter_value))
138+
logger.warning(
139+
f"{exception_name}: {exception_str} encountered in {func.__name__}. "
140+
f"Retrying in {sleep_time:.2f} seconds. Attempt {attempt + 1}/{max_retries}"
141+
)
142+
await asyncio.sleep(sleep_time)
143+
delay = min(delay * exponential_base, max_delay)
144+
continue
145+
146+
# Log bedrock invocation details for non-retryable exceptions
108147
log_bedrock_invocation_error(e, attempt + 1)
109148
raise
110149

0 commit comments

Comments
 (0)