diff --git a/cadence/_internal/rpc/retry.py b/cadence/_internal/rpc/retry.py new file mode 100644 index 0000000..7e1f280 --- /dev/null +++ b/cadence/_internal/rpc/retry.py @@ -0,0 +1,87 @@ +import asyncio +from dataclasses import dataclass +from typing import Callable, Any + +from grpc import StatusCode +from grpc.aio import UnaryUnaryClientInterceptor, ClientCallDetails + +from cadence.error import CadenceError, EntityNotExistsError + +RETRYABLE_CODES = { + StatusCode.INTERNAL, + StatusCode.RESOURCE_EXHAUSTED, + StatusCode.ABORTED, + StatusCode.UNAVAILABLE +} + +# No expiration interval, use the GRPC timeout value instead +@dataclass +class ExponentialRetryPolicy: + initial_interval: float + backoff_coefficient: float + max_interval: float + max_attempts: float + + def next_delay(self, attempts: int, elapsed: float, expiration: float) -> float | None: + if elapsed >= expiration: + return None + if self.max_attempts != 0 and attempts >= self.max_attempts: + return None + + backoff = min(self.initial_interval * pow(self.backoff_coefficient, attempts-1), self.max_interval) + if (elapsed + backoff) >= expiration: + return None + + return backoff + +DEFAULT_RETRY_POLICY = ExponentialRetryPolicy(initial_interval=0.02, backoff_coefficient=1.2, max_interval=6, max_attempts=0) +GET_WORKFLOW_HISTORY = b'/uber.cadence.api.v1.WorkflowAPI/GetWorkflowExecutionHistory' + +class RetryInterceptor(UnaryUnaryClientInterceptor): + def __init__(self, retry_policy: ExponentialRetryPolicy = DEFAULT_RETRY_POLICY): + super().__init__() + self._retry_policy = retry_policy + + async def intercept_unary_unary( + self, + continuation: Callable[[ClientCallDetails, Any], Any], + client_call_details: ClientCallDetails, + request: Any + ) -> Any: + loop = asyncio.get_running_loop() + expiration_interval = client_call_details.timeout + start_time = loop.time() + deadline = start_time + expiration_interval + + attempts = 0 + while True: + remaining = deadline - loop.time() + # Namedtuple methods start with an underscore to avoid conflicts and aren't actually private + # noinspection PyProtectedMember + call_details = client_call_details._replace(timeout=remaining) + rpc_call = await continuation(call_details, request) + try: + # Return the result directly if success. GRPC will wrap it back into a UnaryUnaryCall + return await rpc_call + except CadenceError as e: + err = e + + attempts += 1 + elapsed = loop.time() - start_time + backoff = self._retry_policy.next_delay(attempts, elapsed, expiration_interval) + if not is_retryable(err, client_call_details) or backoff is None: + break + + await asyncio.sleep(backoff) + + # On policy expiration, return the most recent UnaryUnaryCall. It has the error we want + return rpc_call + + + +def is_retryable(err: CadenceError, call_details: ClientCallDetails) -> bool: + # Handle requests to the passive side, matching the Go and Java Clients + if call_details.method == GET_WORKFLOW_HISTORY and isinstance(err, EntityNotExistsError): + return err.active_cluster is not None and err.current_cluster is not None and err.active_cluster != err.current_cluster + + return err.code in RETRYABLE_CODES diff --git a/cadence/_internal/rpc/yarpc.py b/cadence/_internal/rpc/yarpc.py index 1c2cdbb..42f7994 100644 --- a/cadence/_internal/rpc/yarpc.py +++ b/cadence/_internal/rpc/yarpc.py @@ -1,18 +1,9 @@ -import collections from typing import Any, Callable from grpc.aio import Metadata from grpc.aio import UnaryUnaryClientInterceptor, ClientCallDetails -class _ClientCallDetails( - collections.namedtuple( - "_ClientCallDetails", ("method", "timeout", "metadata", "credentials", "wait_for_ready") - ), - ClientCallDetails, -): - pass - SERVICE_KEY = "rpc-service" CALLER_KEY = "rpc-caller" ENCODING_KEY = "rpc-encoding" @@ -42,11 +33,6 @@ def _replace_details(self, client_call_details: ClientCallDetails) -> ClientCall else: metadata += self._metadata - return _ClientCallDetails( - method=client_call_details.method, - # YARPC seems to require a TTL value - timeout=client_call_details.timeout or 60.0, - metadata=metadata, - credentials=client_call_details.credentials, - wait_for_ready=client_call_details.wait_for_ready, - ) + # Namedtuple methods start with an underscore to avoid conflicts and aren't actually private + # noinspection PyProtectedMember + return client_call_details._replace(metadata=metadata, timeout=client_call_details.timeout or 60.0) diff --git a/cadence/client.py b/cadence/client.py index 3a0c5e7..8294da5 100644 --- a/cadence/client.py +++ b/cadence/client.py @@ -5,6 +5,7 @@ from grpc import ChannelCredentials, Compression from cadence._internal.rpc.error import CadenceErrorInterceptor +from cadence._internal.rpc.retry import RetryInterceptor from cadence._internal.rpc.yarpc import YarpcMetadataInterceptor from cadence.api.v1.service_domain_pb2_grpc import DomainAPIStub from cadence.api.v1.service_worker_pb2_grpc import WorkerAPIStub @@ -91,8 +92,9 @@ def _validate_and_copy_defaults(options: ClientOptions) -> ClientOptions: def _create_channel(options: ClientOptions) -> Channel: interceptors = list(options["interceptors"]) - interceptors.append(CadenceErrorInterceptor()) interceptors.append(YarpcMetadataInterceptor(options["service_name"], options["caller_name"])) + interceptors.append(RetryInterceptor()) + interceptors.append(CadenceErrorInterceptor()) if options["credentials"]: return secure_channel(options["target"], options["credentials"], options["channel_arguments"], options["compression"], interceptors) diff --git a/tests/cadence/_internal/rpc/test_retry.py b/tests/cadence/_internal/rpc/test_retry.py new file mode 100644 index 0000000..1874341 --- /dev/null +++ b/tests/cadence/_internal/rpc/test_retry.py @@ -0,0 +1,167 @@ +from concurrent import futures +from typing import Tuple, Type + +import pytest +from google.protobuf import any_pb2 +from google.rpc import status_pb2, code_pb2 +from grpc import server +from grpc.aio import insecure_channel +from grpc_status.rpc_status import to_status + +from cadence._internal.rpc.error import CadenceErrorInterceptor +from cadence.api.v1 import error_pb2, service_workflow_pb2_grpc + +from cadence._internal.rpc.retry import ExponentialRetryPolicy, RetryInterceptor +from cadence.api.v1.service_workflow_pb2 import DescribeWorkflowExecutionResponse, \ + DescribeWorkflowExecutionRequest, GetWorkflowExecutionHistoryRequest +from cadence.error import CadenceError, FeatureNotEnabledError, EntityNotExistsError + +simple_policy = ExponentialRetryPolicy(initial_interval=1, backoff_coefficient=2, max_interval=10, max_attempts=6) + +@pytest.mark.parametrize( + "policy,params,expected", + [ + pytest.param( + simple_policy, (1, 0.0, 100.0), 1, id="happy path" + ), + pytest.param( + simple_policy, (2, 0.0, 100.0), 2, id="second attempt" + ), + pytest.param( + simple_policy, (3, 0.0, 100.0), 4, id="third attempt" + ), + pytest.param( + simple_policy, (5, 0.0, 100.0), 10, id="capped by max_interval" + ), + pytest.param( + simple_policy, (6, 0.0, 100.0), None, id="out of attempts" + ), + pytest.param( + simple_policy, (1, 100.0, 100.0), None, id="timeout" + ), + pytest.param( + simple_policy, (1, 99.0, 100.0), None, id="backoff causes timeout" + ), + pytest.param( + ExponentialRetryPolicy(initial_interval=1, backoff_coefficient=1, max_interval=10, max_attempts=0), (100, 0.0, 100.0), 1, id="unlimited retries" + ), + ] +) +def test_next_delay(policy: ExponentialRetryPolicy, params: Tuple[int, float, float], expected: float | None): + assert policy.next_delay(*params) == expected + + +class FakeService(service_workflow_pb2_grpc.WorkflowAPIServicer): + def __init__(self) -> None: + super().__init__() + self.port = None + self.counter = 0 + + # Retryable only because it's GetWorkflowExecutionHistory + def GetWorkflowExecutionHistory(self, request: GetWorkflowExecutionHistoryRequest, context): + self.counter += 1 + + detail = any_pb2.Any() + detail.Pack(error_pb2.EntityNotExistsError(current_cluster=request.domain, active_cluster="active")) + status_proto = status_pb2.Status( + code=code_pb2.NOT_FOUND, + message="message", + details=[detail], + ) + context.abort_with_status(to_status(status_proto)) + # Unreachable + + + # Not retryable + def DescribeWorkflowExecution(self, request: DescribeWorkflowExecutionRequest, context): + self.counter += 1 + + if request.domain == "success": + return DescribeWorkflowExecutionResponse() + elif request.domain == "retryable": + code = code_pb2.RESOURCE_EXHAUSTED + elif request.domain == "maybe later": + if self.counter >= 3: + return DescribeWorkflowExecutionResponse() + + code = code_pb2.RESOURCE_EXHAUSTED + else: + code = code_pb2.PERMISSION_DENIED + + detail = any_pb2.Any() + detail.Pack(error_pb2.FeatureNotEnabledError(feature_flag="the flag")) + status_proto = status_pb2.Status( + code=code, + message="message", + details=[detail], + ) + context.abort_with_status(to_status(status_proto)) + # Unreachable + + +@pytest.fixture(scope="module") +def fake_service(): + fake = FakeService() + sync_server = server(futures.ThreadPoolExecutor(max_workers=1)) + service_workflow_pb2_grpc.add_WorkflowAPIServicer_to_server(fake, sync_server) + fake.port = sync_server.add_insecure_port("[::]:0") + sync_server.start() + yield fake + sync_server.stop(grace=None) + +TEST_POLICY = ExponentialRetryPolicy(initial_interval=0, backoff_coefficient=0, max_interval=10, max_attempts=10) + +@pytest.mark.usefixtures("fake_service") +@pytest.mark.parametrize( + "case,expected_calls,expected_err", + [ + pytest.param( + "success", 1, None, id="happy path" + ), + pytest.param( + "maybe later", 3, None, id="retries then success" + ), + pytest.param( + "not retryable", 1, FeatureNotEnabledError, id="not retryable" + ), + pytest.param( + "retryable", TEST_POLICY.max_attempts, FeatureNotEnabledError, id="retries exhausted" + ), + + ] +) +@pytest.mark.asyncio +async def test_retryable_error(fake_service, case: str, expected_calls: int, expected_err: Type[CadenceError]): + fake_service.counter = 0 + async with insecure_channel(f"[::]:{fake_service.port}", interceptors=[RetryInterceptor(TEST_POLICY), CadenceErrorInterceptor()]) as channel: + stub = service_workflow_pb2_grpc.WorkflowAPIStub(channel) + if expected_err: + with pytest.raises(expected_err): + await stub.DescribeWorkflowExecution(DescribeWorkflowExecutionRequest(domain=case), timeout=10) + else: + await stub.DescribeWorkflowExecution(DescribeWorkflowExecutionRequest(domain=case), timeout=10) + + assert fake_service.counter == expected_calls + +@pytest.mark.usefixtures("fake_service") +@pytest.mark.parametrize( + "case,expected_calls", + [ + pytest.param( + "active", 1, id="not retryable" + ), + pytest.param( + "not active", TEST_POLICY.max_attempts, id="retries exhausted" + ), + + ] +) +@pytest.mark.asyncio +async def test_workflow_history(fake_service, case: str, expected_calls: int): + fake_service.counter = 0 + async with insecure_channel(f"[::]:{fake_service.port}", interceptors=[RetryInterceptor(TEST_POLICY), CadenceErrorInterceptor()]) as channel: + stub = service_workflow_pb2_grpc.WorkflowAPIStub(channel) + with pytest.raises(EntityNotExistsError): + await stub.GetWorkflowExecutionHistory(GetWorkflowExecutionHistoryRequest(domain=case), timeout=10) + + assert fake_service.counter == expected_calls \ No newline at end of file