Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 87 additions & 0 deletions cadence/_internal/rpc/retry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import asyncio
from dataclasses import dataclass
from typing import Callable, Any

from grpc import StatusCode
from grpc.aio import UnaryUnaryClientInterceptor, ClientCallDetails

from cadence.error import CadenceError, EntityNotExistsError

RETRYABLE_CODES = {
StatusCode.INTERNAL,
StatusCode.RESOURCE_EXHAUSTED,
StatusCode.ABORTED,
StatusCode.UNAVAILABLE
}

# No expiration interval, use the GRPC timeout value instead
@dataclass
class ExponentialRetryPolicy:
initial_interval: float
backoff_coefficient: float
max_interval: float
max_attempts: float

def next_delay(self, attempts: int, elapsed: float, expiration: float) -> float | None:
if elapsed >= expiration:
return None
if self.max_attempts != 0 and attempts >= self.max_attempts:
return None

backoff = min(self.initial_interval * pow(self.backoff_coefficient, attempts-1), self.max_interval)
if (elapsed + backoff) >= expiration:
return None

return backoff

DEFAULT_RETRY_POLICY = ExponentialRetryPolicy(initial_interval=0.02, backoff_coefficient=1.2, max_interval=6, max_attempts=0)
GET_WORKFLOW_HISTORY = b'/uber.cadence.api.v1.WorkflowAPI/GetWorkflowExecutionHistory'

class RetryInterceptor(UnaryUnaryClientInterceptor):
def __init__(self, retry_policy: ExponentialRetryPolicy = DEFAULT_RETRY_POLICY):
super().__init__()
self._retry_policy = retry_policy

async def intercept_unary_unary(
self,
continuation: Callable[[ClientCallDetails, Any], Any],
client_call_details: ClientCallDetails,
request: Any
) -> Any:
loop = asyncio.get_running_loop()
expiration_interval = client_call_details.timeout
start_time = loop.time()
deadline = start_time + expiration_interval

attempts = 0
while True:
remaining = deadline - loop.time()
# Namedtuple methods start with an underscore to avoid conflicts and aren't actually private
# noinspection PyProtectedMember
call_details = client_call_details._replace(timeout=remaining)
rpc_call = await continuation(call_details, request)
try:
# Return the result directly if success. GRPC will wrap it back into a UnaryUnaryCall
return await rpc_call
except CadenceError as e:
err = e

attempts += 1
elapsed = loop.time() - start_time
backoff = self._retry_policy.next_delay(attempts, elapsed, expiration_interval)
if not is_retryable(err, client_call_details) or backoff is None:
break

await asyncio.sleep(backoff)

# On policy expiration, return the most recent UnaryUnaryCall. It has the error we want
return rpc_call



def is_retryable(err: CadenceError, call_details: ClientCallDetails) -> bool:
# Handle requests to the passive side, matching the Go and Java Clients
if call_details.method == GET_WORKFLOW_HISTORY and isinstance(err, EntityNotExistsError):
return err.active_cluster is not None and err.current_cluster is not None and err.active_cluster != err.current_cluster

return err.code in RETRYABLE_CODES
20 changes: 3 additions & 17 deletions cadence/_internal/rpc/yarpc.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,9 @@
import collections
from typing import Any, Callable

from grpc.aio import Metadata
from grpc.aio import UnaryUnaryClientInterceptor, ClientCallDetails


class _ClientCallDetails(
collections.namedtuple(
"_ClientCallDetails", ("method", "timeout", "metadata", "credentials", "wait_for_ready")
),
ClientCallDetails,
):
pass

SERVICE_KEY = "rpc-service"
CALLER_KEY = "rpc-caller"
ENCODING_KEY = "rpc-encoding"
Expand Down Expand Up @@ -42,11 +33,6 @@ def _replace_details(self, client_call_details: ClientCallDetails) -> ClientCall
else:
metadata += self._metadata

return _ClientCallDetails(
method=client_call_details.method,
# YARPC seems to require a TTL value
timeout=client_call_details.timeout or 60.0,
metadata=metadata,
credentials=client_call_details.credentials,
wait_for_ready=client_call_details.wait_for_ready,
)
# Namedtuple methods start with an underscore to avoid conflicts and aren't actually private
# noinspection PyProtectedMember
return client_call_details._replace(metadata=metadata, timeout=client_call_details.timeout or 60.0)
4 changes: 3 additions & 1 deletion cadence/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from grpc import ChannelCredentials, Compression

from cadence._internal.rpc.error import CadenceErrorInterceptor
from cadence._internal.rpc.retry import RetryInterceptor
from cadence._internal.rpc.yarpc import YarpcMetadataInterceptor
from cadence.api.v1.service_domain_pb2_grpc import DomainAPIStub
from cadence.api.v1.service_worker_pb2_grpc import WorkerAPIStub
Expand Down Expand Up @@ -91,8 +92,9 @@ def _validate_and_copy_defaults(options: ClientOptions) -> ClientOptions:

def _create_channel(options: ClientOptions) -> Channel:
interceptors = list(options["interceptors"])
interceptors.append(CadenceErrorInterceptor())
interceptors.append(YarpcMetadataInterceptor(options["service_name"], options["caller_name"]))
interceptors.append(RetryInterceptor())
interceptors.append(CadenceErrorInterceptor())

if options["credentials"]:
return secure_channel(options["target"], options["credentials"], options["channel_arguments"], options["compression"], interceptors)
Expand Down
167 changes: 167 additions & 0 deletions tests/cadence/_internal/rpc/test_retry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
from concurrent import futures
from typing import Tuple, Type

import pytest
from google.protobuf import any_pb2
from google.rpc import status_pb2, code_pb2
from grpc import server
from grpc.aio import insecure_channel
from grpc_status.rpc_status import to_status

from cadence._internal.rpc.error import CadenceErrorInterceptor
from cadence.api.v1 import error_pb2, service_workflow_pb2_grpc

from cadence._internal.rpc.retry import ExponentialRetryPolicy, RetryInterceptor
from cadence.api.v1.service_workflow_pb2 import DescribeWorkflowExecutionResponse, \
DescribeWorkflowExecutionRequest, GetWorkflowExecutionHistoryRequest
from cadence.error import CadenceError, FeatureNotEnabledError, EntityNotExistsError

simple_policy = ExponentialRetryPolicy(initial_interval=1, backoff_coefficient=2, max_interval=10, max_attempts=6)

@pytest.mark.parametrize(
"policy,params,expected",
[
pytest.param(
simple_policy, (1, 0.0, 100.0), 1, id="happy path"
),
pytest.param(
simple_policy, (2, 0.0, 100.0), 2, id="second attempt"
),
pytest.param(
simple_policy, (3, 0.0, 100.0), 4, id="third attempt"
),
pytest.param(
simple_policy, (5, 0.0, 100.0), 10, id="capped by max_interval"
),
pytest.param(
simple_policy, (6, 0.0, 100.0), None, id="out of attempts"
),
pytest.param(
simple_policy, (1, 100.0, 100.0), None, id="timeout"
),
pytest.param(
simple_policy, (1, 99.0, 100.0), None, id="backoff causes timeout"
),
pytest.param(
ExponentialRetryPolicy(initial_interval=1, backoff_coefficient=1, max_interval=10, max_attempts=0), (100, 0.0, 100.0), 1, id="unlimited retries"
),
]
)
def test_next_delay(policy: ExponentialRetryPolicy, params: Tuple[int, float, float], expected: float | None):
assert policy.next_delay(*params) == expected


class FakeService(service_workflow_pb2_grpc.WorkflowAPIServicer):
def __init__(self) -> None:
super().__init__()
self.port = None
self.counter = 0

# Retryable only because it's GetWorkflowExecutionHistory
def GetWorkflowExecutionHistory(self, request: GetWorkflowExecutionHistoryRequest, context):
self.counter += 1

detail = any_pb2.Any()
detail.Pack(error_pb2.EntityNotExistsError(current_cluster=request.domain, active_cluster="active"))
status_proto = status_pb2.Status(
code=code_pb2.NOT_FOUND,
message="message",
details=[detail],
)
context.abort_with_status(to_status(status_proto))
# Unreachable


# Not retryable
def DescribeWorkflowExecution(self, request: DescribeWorkflowExecutionRequest, context):
self.counter += 1

if request.domain == "success":
return DescribeWorkflowExecutionResponse()
elif request.domain == "retryable":
code = code_pb2.RESOURCE_EXHAUSTED
elif request.domain == "maybe later":
if self.counter >= 3:
return DescribeWorkflowExecutionResponse()

code = code_pb2.RESOURCE_EXHAUSTED
else:
code = code_pb2.PERMISSION_DENIED

detail = any_pb2.Any()
detail.Pack(error_pb2.FeatureNotEnabledError(feature_flag="the flag"))
status_proto = status_pb2.Status(
code=code,
message="message",
details=[detail],
)
context.abort_with_status(to_status(status_proto))
# Unreachable


@pytest.fixture(scope="module")
def fake_service():
fake = FakeService()
sync_server = server(futures.ThreadPoolExecutor(max_workers=1))
service_workflow_pb2_grpc.add_WorkflowAPIServicer_to_server(fake, sync_server)
fake.port = sync_server.add_insecure_port("[::]:0")
sync_server.start()
yield fake
sync_server.stop(grace=None)

TEST_POLICY = ExponentialRetryPolicy(initial_interval=0, backoff_coefficient=0, max_interval=10, max_attempts=10)

@pytest.mark.usefixtures("fake_service")
@pytest.mark.parametrize(
"case,expected_calls,expected_err",
[
pytest.param(
"success", 1, None, id="happy path"
),
pytest.param(
"maybe later", 3, None, id="retries then success"
),
pytest.param(
"not retryable", 1, FeatureNotEnabledError, id="not retryable"
),
pytest.param(
"retryable", TEST_POLICY.max_attempts, FeatureNotEnabledError, id="retries exhausted"
),

]
)
@pytest.mark.asyncio
async def test_retryable_error(fake_service, case: str, expected_calls: int, expected_err: Type[CadenceError]):
fake_service.counter = 0
async with insecure_channel(f"[::]:{fake_service.port}", interceptors=[RetryInterceptor(TEST_POLICY), CadenceErrorInterceptor()]) as channel:
stub = service_workflow_pb2_grpc.WorkflowAPIStub(channel)
if expected_err:
with pytest.raises(expected_err):
await stub.DescribeWorkflowExecution(DescribeWorkflowExecutionRequest(domain=case), timeout=10)
else:
await stub.DescribeWorkflowExecution(DescribeWorkflowExecutionRequest(domain=case), timeout=10)

assert fake_service.counter == expected_calls

@pytest.mark.usefixtures("fake_service")
@pytest.mark.parametrize(
"case,expected_calls",
[
pytest.param(
"active", 1, id="not retryable"
),
pytest.param(
"not active", TEST_POLICY.max_attempts, id="retries exhausted"
),

]
)
@pytest.mark.asyncio
async def test_workflow_history(fake_service, case: str, expected_calls: int):
fake_service.counter = 0
async with insecure_channel(f"[::]:{fake_service.port}", interceptors=[RetryInterceptor(TEST_POLICY), CadenceErrorInterceptor()]) as channel:
stub = service_workflow_pb2_grpc.WorkflowAPIStub(channel)
with pytest.raises(EntityNotExistsError):
await stub.GetWorkflowExecutionHistory(GetWorkflowExecutionHistoryRequest(domain=case), timeout=10)

assert fake_service.counter == expected_calls