From 3de2ba9996e54c7600a42fdc17406ba9e0e92a9a Mon Sep 17 00:00:00 2001 From: Nate Mortensen Date: Fri, 5 Sep 2025 14:36:32 -0700 Subject: [PATCH] Add RPC Retries via interceptor Add RPC retries via an interceptor, with exponential backoff matching both the Go and Java clients. The approach here differs from both however in two main ways: 1. Adding these via interceptor makes it implicit, while both the Go and Java client require it to be explicit. 2. The specific requests to retry are based on GRPC error codes, rather than explicitly listing non-retryable errors and retrying everything by default. This seems like a more sustainable approach, since nearly every error type is non-retryable. A newly introduced error type would require a client update to mark it non-retryable before it could safely be used. Any time the python client doesn't recognize an error it gets mapped to just CadenceError, so new errors can safely be added. --- cadence/_internal/rpc/retry.py | 87 +++++++++++ cadence/_internal/rpc/yarpc.py | 20 +-- cadence/client.py | 4 +- tests/cadence/_internal/rpc/test_retry.py | 167 ++++++++++++++++++++++ 4 files changed, 260 insertions(+), 18 deletions(-) create mode 100644 cadence/_internal/rpc/retry.py create mode 100644 tests/cadence/_internal/rpc/test_retry.py diff --git a/cadence/_internal/rpc/retry.py b/cadence/_internal/rpc/retry.py new file mode 100644 index 0000000..7e1f280 --- /dev/null +++ b/cadence/_internal/rpc/retry.py @@ -0,0 +1,87 @@ +import asyncio +from dataclasses import dataclass +from typing import Callable, Any + +from grpc import StatusCode +from grpc.aio import UnaryUnaryClientInterceptor, ClientCallDetails + +from cadence.error import CadenceError, EntityNotExistsError + +RETRYABLE_CODES = { + StatusCode.INTERNAL, + StatusCode.RESOURCE_EXHAUSTED, + StatusCode.ABORTED, + StatusCode.UNAVAILABLE +} + +# No expiration interval, use the GRPC timeout value instead +@dataclass +class ExponentialRetryPolicy: + initial_interval: float + backoff_coefficient: float + max_interval: float + max_attempts: float + + def next_delay(self, attempts: int, elapsed: float, expiration: float) -> float | None: + if elapsed >= expiration: + return None + if self.max_attempts != 0 and attempts >= self.max_attempts: + return None + + backoff = min(self.initial_interval * pow(self.backoff_coefficient, attempts-1), self.max_interval) + if (elapsed + backoff) >= expiration: + return None + + return backoff + +DEFAULT_RETRY_POLICY = ExponentialRetryPolicy(initial_interval=0.02, backoff_coefficient=1.2, max_interval=6, max_attempts=0) +GET_WORKFLOW_HISTORY = b'/uber.cadence.api.v1.WorkflowAPI/GetWorkflowExecutionHistory' + +class RetryInterceptor(UnaryUnaryClientInterceptor): + def __init__(self, retry_policy: ExponentialRetryPolicy = DEFAULT_RETRY_POLICY): + super().__init__() + self._retry_policy = retry_policy + + async def intercept_unary_unary( + self, + continuation: Callable[[ClientCallDetails, Any], Any], + client_call_details: ClientCallDetails, + request: Any + ) -> Any: + loop = asyncio.get_running_loop() + expiration_interval = client_call_details.timeout + start_time = loop.time() + deadline = start_time + expiration_interval + + attempts = 0 + while True: + remaining = deadline - loop.time() + # Namedtuple methods start with an underscore to avoid conflicts and aren't actually private + # noinspection PyProtectedMember + call_details = client_call_details._replace(timeout=remaining) + rpc_call = await continuation(call_details, request) + try: + # Return the result directly if success. GRPC will wrap it back into a UnaryUnaryCall + return await rpc_call + except CadenceError as e: + err = e + + attempts += 1 + elapsed = loop.time() - start_time + backoff = self._retry_policy.next_delay(attempts, elapsed, expiration_interval) + if not is_retryable(err, client_call_details) or backoff is None: + break + + await asyncio.sleep(backoff) + + # On policy expiration, return the most recent UnaryUnaryCall. It has the error we want + return rpc_call + + + +def is_retryable(err: CadenceError, call_details: ClientCallDetails) -> bool: + # Handle requests to the passive side, matching the Go and Java Clients + if call_details.method == GET_WORKFLOW_HISTORY and isinstance(err, EntityNotExistsError): + return err.active_cluster is not None and err.current_cluster is not None and err.active_cluster != err.current_cluster + + return err.code in RETRYABLE_CODES diff --git a/cadence/_internal/rpc/yarpc.py b/cadence/_internal/rpc/yarpc.py index 1c2cdbb..42f7994 100644 --- a/cadence/_internal/rpc/yarpc.py +++ b/cadence/_internal/rpc/yarpc.py @@ -1,18 +1,9 @@ -import collections from typing import Any, Callable from grpc.aio import Metadata from grpc.aio import UnaryUnaryClientInterceptor, ClientCallDetails -class _ClientCallDetails( - collections.namedtuple( - "_ClientCallDetails", ("method", "timeout", "metadata", "credentials", "wait_for_ready") - ), - ClientCallDetails, -): - pass - SERVICE_KEY = "rpc-service" CALLER_KEY = "rpc-caller" ENCODING_KEY = "rpc-encoding" @@ -42,11 +33,6 @@ def _replace_details(self, client_call_details: ClientCallDetails) -> ClientCall else: metadata += self._metadata - return _ClientCallDetails( - method=client_call_details.method, - # YARPC seems to require a TTL value - timeout=client_call_details.timeout or 60.0, - metadata=metadata, - credentials=client_call_details.credentials, - wait_for_ready=client_call_details.wait_for_ready, - ) + # Namedtuple methods start with an underscore to avoid conflicts and aren't actually private + # noinspection PyProtectedMember + return client_call_details._replace(metadata=metadata, timeout=client_call_details.timeout or 60.0) diff --git a/cadence/client.py b/cadence/client.py index 3a0c5e7..8294da5 100644 --- a/cadence/client.py +++ b/cadence/client.py @@ -5,6 +5,7 @@ from grpc import ChannelCredentials, Compression from cadence._internal.rpc.error import CadenceErrorInterceptor +from cadence._internal.rpc.retry import RetryInterceptor from cadence._internal.rpc.yarpc import YarpcMetadataInterceptor from cadence.api.v1.service_domain_pb2_grpc import DomainAPIStub from cadence.api.v1.service_worker_pb2_grpc import WorkerAPIStub @@ -91,8 +92,9 @@ def _validate_and_copy_defaults(options: ClientOptions) -> ClientOptions: def _create_channel(options: ClientOptions) -> Channel: interceptors = list(options["interceptors"]) - interceptors.append(CadenceErrorInterceptor()) interceptors.append(YarpcMetadataInterceptor(options["service_name"], options["caller_name"])) + interceptors.append(RetryInterceptor()) + interceptors.append(CadenceErrorInterceptor()) if options["credentials"]: return secure_channel(options["target"], options["credentials"], options["channel_arguments"], options["compression"], interceptors) diff --git a/tests/cadence/_internal/rpc/test_retry.py b/tests/cadence/_internal/rpc/test_retry.py new file mode 100644 index 0000000..1874341 --- /dev/null +++ b/tests/cadence/_internal/rpc/test_retry.py @@ -0,0 +1,167 @@ +from concurrent import futures +from typing import Tuple, Type + +import pytest +from google.protobuf import any_pb2 +from google.rpc import status_pb2, code_pb2 +from grpc import server +from grpc.aio import insecure_channel +from grpc_status.rpc_status import to_status + +from cadence._internal.rpc.error import CadenceErrorInterceptor +from cadence.api.v1 import error_pb2, service_workflow_pb2_grpc + +from cadence._internal.rpc.retry import ExponentialRetryPolicy, RetryInterceptor +from cadence.api.v1.service_workflow_pb2 import DescribeWorkflowExecutionResponse, \ + DescribeWorkflowExecutionRequest, GetWorkflowExecutionHistoryRequest +from cadence.error import CadenceError, FeatureNotEnabledError, EntityNotExistsError + +simple_policy = ExponentialRetryPolicy(initial_interval=1, backoff_coefficient=2, max_interval=10, max_attempts=6) + +@pytest.mark.parametrize( + "policy,params,expected", + [ + pytest.param( + simple_policy, (1, 0.0, 100.0), 1, id="happy path" + ), + pytest.param( + simple_policy, (2, 0.0, 100.0), 2, id="second attempt" + ), + pytest.param( + simple_policy, (3, 0.0, 100.0), 4, id="third attempt" + ), + pytest.param( + simple_policy, (5, 0.0, 100.0), 10, id="capped by max_interval" + ), + pytest.param( + simple_policy, (6, 0.0, 100.0), None, id="out of attempts" + ), + pytest.param( + simple_policy, (1, 100.0, 100.0), None, id="timeout" + ), + pytest.param( + simple_policy, (1, 99.0, 100.0), None, id="backoff causes timeout" + ), + pytest.param( + ExponentialRetryPolicy(initial_interval=1, backoff_coefficient=1, max_interval=10, max_attempts=0), (100, 0.0, 100.0), 1, id="unlimited retries" + ), + ] +) +def test_next_delay(policy: ExponentialRetryPolicy, params: Tuple[int, float, float], expected: float | None): + assert policy.next_delay(*params) == expected + + +class FakeService(service_workflow_pb2_grpc.WorkflowAPIServicer): + def __init__(self) -> None: + super().__init__() + self.port = None + self.counter = 0 + + # Retryable only because it's GetWorkflowExecutionHistory + def GetWorkflowExecutionHistory(self, request: GetWorkflowExecutionHistoryRequest, context): + self.counter += 1 + + detail = any_pb2.Any() + detail.Pack(error_pb2.EntityNotExistsError(current_cluster=request.domain, active_cluster="active")) + status_proto = status_pb2.Status( + code=code_pb2.NOT_FOUND, + message="message", + details=[detail], + ) + context.abort_with_status(to_status(status_proto)) + # Unreachable + + + # Not retryable + def DescribeWorkflowExecution(self, request: DescribeWorkflowExecutionRequest, context): + self.counter += 1 + + if request.domain == "success": + return DescribeWorkflowExecutionResponse() + elif request.domain == "retryable": + code = code_pb2.RESOURCE_EXHAUSTED + elif request.domain == "maybe later": + if self.counter >= 3: + return DescribeWorkflowExecutionResponse() + + code = code_pb2.RESOURCE_EXHAUSTED + else: + code = code_pb2.PERMISSION_DENIED + + detail = any_pb2.Any() + detail.Pack(error_pb2.FeatureNotEnabledError(feature_flag="the flag")) + status_proto = status_pb2.Status( + code=code, + message="message", + details=[detail], + ) + context.abort_with_status(to_status(status_proto)) + # Unreachable + + +@pytest.fixture(scope="module") +def fake_service(): + fake = FakeService() + sync_server = server(futures.ThreadPoolExecutor(max_workers=1)) + service_workflow_pb2_grpc.add_WorkflowAPIServicer_to_server(fake, sync_server) + fake.port = sync_server.add_insecure_port("[::]:0") + sync_server.start() + yield fake + sync_server.stop(grace=None) + +TEST_POLICY = ExponentialRetryPolicy(initial_interval=0, backoff_coefficient=0, max_interval=10, max_attempts=10) + +@pytest.mark.usefixtures("fake_service") +@pytest.mark.parametrize( + "case,expected_calls,expected_err", + [ + pytest.param( + "success", 1, None, id="happy path" + ), + pytest.param( + "maybe later", 3, None, id="retries then success" + ), + pytest.param( + "not retryable", 1, FeatureNotEnabledError, id="not retryable" + ), + pytest.param( + "retryable", TEST_POLICY.max_attempts, FeatureNotEnabledError, id="retries exhausted" + ), + + ] +) +@pytest.mark.asyncio +async def test_retryable_error(fake_service, case: str, expected_calls: int, expected_err: Type[CadenceError]): + fake_service.counter = 0 + async with insecure_channel(f"[::]:{fake_service.port}", interceptors=[RetryInterceptor(TEST_POLICY), CadenceErrorInterceptor()]) as channel: + stub = service_workflow_pb2_grpc.WorkflowAPIStub(channel) + if expected_err: + with pytest.raises(expected_err): + await stub.DescribeWorkflowExecution(DescribeWorkflowExecutionRequest(domain=case), timeout=10) + else: + await stub.DescribeWorkflowExecution(DescribeWorkflowExecutionRequest(domain=case), timeout=10) + + assert fake_service.counter == expected_calls + +@pytest.mark.usefixtures("fake_service") +@pytest.mark.parametrize( + "case,expected_calls", + [ + pytest.param( + "active", 1, id="not retryable" + ), + pytest.param( + "not active", TEST_POLICY.max_attempts, id="retries exhausted" + ), + + ] +) +@pytest.mark.asyncio +async def test_workflow_history(fake_service, case: str, expected_calls: int): + fake_service.counter = 0 + async with insecure_channel(f"[::]:{fake_service.port}", interceptors=[RetryInterceptor(TEST_POLICY), CadenceErrorInterceptor()]) as channel: + stub = service_workflow_pb2_grpc.WorkflowAPIStub(channel) + with pytest.raises(EntityNotExistsError): + await stub.GetWorkflowExecutionHistory(GetWorkflowExecutionHistoryRequest(domain=case), timeout=10) + + assert fake_service.counter == expected_calls \ No newline at end of file