diff --git a/cadence/_internal/rpc/error.py b/cadence/_internal/rpc/error.py new file mode 100644 index 0000000..d2dbd14 --- /dev/null +++ b/cadence/_internal/rpc/error.py @@ -0,0 +1,121 @@ +from typing import Callable, Any, Optional, Generator, TypeVar + +import grpc +from google.rpc.status_pb2 import Status # type: ignore +from grpc.aio import UnaryUnaryClientInterceptor, ClientCallDetails, AioRpcError, UnaryUnaryCall, Metadata +from grpc_status.rpc_status import from_call # type: ignore + +from cadence.api.v1 import error_pb2 +from cadence import error + + +RequestType = TypeVar("RequestType") +ResponseType = TypeVar("ResponseType") +DoneCallbackType = Callable[[Any], None] + + +# A UnaryUnaryCall is an awaitable type returned by GRPC's aio support. +# We need to take the UnaryUnaryCall we receive and return one that remaps the exception. +# It doesn't have any functions to compose operations together, so our only option is to wrap it. +# If the interceptor directly throws an exception other than AioRpcError it breaks GRPC +class CadenceErrorUnaryUnaryCall(UnaryUnaryCall[RequestType, ResponseType]): + + def __init__(self, wrapped: UnaryUnaryCall[RequestType, ResponseType]): + super().__init__() + self._wrapped = wrapped + + def __await__(self) -> Generator[Any, None, ResponseType]: + try: + response = yield from self._wrapped.__await__() # type: ResponseType + return response + except AioRpcError as e: + raise map_error(e) + + async def initial_metadata(self) -> Metadata: + return await self._wrapped.initial_metadata() + + async def trailing_metadata(self) -> Metadata: + return await self._wrapped.trailing_metadata() + + async def code(self) -> grpc.StatusCode: + return await self._wrapped.code() + + async def details(self) -> str: + return await self._wrapped.details() # type: ignore + + async def wait_for_connection(self) -> None: + await self._wrapped.wait_for_connection() + + def cancelled(self) -> bool: + return self._wrapped.cancelled() # type: ignore + + def done(self) -> bool: + return self._wrapped.done() # type: ignore + + def time_remaining(self) -> Optional[float]: + return self._wrapped.time_remaining() # type: ignore + + def cancel(self) -> bool: + return self._wrapped.cancel() # type: ignore + + def add_done_callback(self, callback: DoneCallbackType) -> None: + self._wrapped.add_done_callback(callback) + + +class CadenceErrorInterceptor(UnaryUnaryClientInterceptor): + + async def intercept_unary_unary( + self, + continuation: Callable[[ClientCallDetails, Any], Any], + client_call_details: ClientCallDetails, + request: Any + ) -> Any: + rpc_call = await continuation(client_call_details, request) + return CadenceErrorUnaryUnaryCall(rpc_call) + + + + +def map_error(e: AioRpcError) -> error.CadenceError: + status: Status | None = from_call(e) + if not status or not status.details: + return error.CadenceError(e.details(), e.code()) + + details = status.details[0] + if details.Is(error_pb2.WorkflowExecutionAlreadyStartedError.DESCRIPTOR): + already_started = error_pb2.WorkflowExecutionAlreadyStartedError() + details.Unpack(already_started) + return error.WorkflowExecutionAlreadyStartedError(e.details(), e.code(), already_started.start_request_id, already_started.run_id) + elif details.Is(error_pb2.EntityNotExistsError.DESCRIPTOR): + not_exists = error_pb2.EntityNotExistsError() + details.Unpack(not_exists) + return error.EntityNotExistsError(e.details(), e.code(), not_exists.current_cluster, not_exists.active_cluster, list(not_exists.active_clusters)) + elif details.Is(error_pb2.WorkflowExecutionAlreadyCompletedError.DESCRIPTOR): + return error.WorkflowExecutionAlreadyCompletedError(e.details(), e.code()) + elif details.Is(error_pb2.DomainNotActiveError.DESCRIPTOR): + not_active = error_pb2.DomainNotActiveError() + details.Unpack(not_active) + return error.DomainNotActiveError(e.details(), e.code(), not_active.domain, not_active.current_cluster, not_active.active_cluster, list(not_active.active_clusters)) + elif details.Is(error_pb2.ClientVersionNotSupportedError.DESCRIPTOR): + not_supported = error_pb2.ClientVersionNotSupportedError() + details.Unpack(not_supported) + return error.ClientVersionNotSupportedError(e.details(), e.code(), not_supported.feature_version, not_supported.client_impl, not_supported.supported_versions) + elif details.Is(error_pb2.FeatureNotEnabledError.DESCRIPTOR): + not_enabled = error_pb2.FeatureNotEnabledError() + details.Unpack(not_enabled) + return error.FeatureNotEnabledError(e.details(), e.code(), not_enabled.feature_flag) + elif details.Is(error_pb2.CancellationAlreadyRequestedError.DESCRIPTOR): + return error.CancellationAlreadyRequestedError(e.details(), e.code()) + elif details.Is(error_pb2.DomainAlreadyExistsError.DESCRIPTOR): + return error.DomainAlreadyExistsError(e.details(), e.code()) + elif details.Is(error_pb2.LimitExceededError.DESCRIPTOR): + return error.LimitExceededError(e.details(), e.code()) + elif details.Is(error_pb2.QueryFailedError.DESCRIPTOR): + return error.QueryFailedError(e.details(), e.code()) + elif details.Is(error_pb2.ServiceBusyError.DESCRIPTOR): + service_busy = error_pb2.ServiceBusyError() + details.Unpack(service_busy) + return error.ServiceBusyError(e.details(), e.code(), service_busy.reason) + else: + return error.CadenceError(e.details(), e.code()) + diff --git a/cadence/client.py b/cadence/client.py index 7feb242..3f085dd 100644 --- a/cadence/client.py +++ b/cadence/client.py @@ -4,6 +4,7 @@ from grpc import ChannelCredentials, Compression +from cadence._internal.rpc.error import CadenceErrorInterceptor from cadence._internal.rpc.yarpc import YarpcMetadataInterceptor from cadence.api.v1.service_worker_pb2_grpc import WorkerAPIStub from grpc.aio import Channel, ClientInterceptor, secure_channel, insecure_channel @@ -75,6 +76,7 @@ def _validate_and_copy_defaults(options: ClientOptions) -> ClientOptions: def _create_channel(options: ClientOptions) -> Channel: interceptors = list(options["interceptors"]) + interceptors.append(CadenceErrorInterceptor()) interceptors.append(YarpcMetadataInterceptor(options["service_name"], options["caller_name"])) if options["credentials"]: diff --git a/cadence/error.py b/cadence/error.py new file mode 100644 index 0000000..a7ea5fd --- /dev/null +++ b/cadence/error.py @@ -0,0 +1,65 @@ +import grpc + + +class CadenceError(Exception): + + def __init__(self, message: str, code: grpc.StatusCode, *args): + super().__init__(message, code, *args) + self.code = code + pass + + +class WorkflowExecutionAlreadyStartedError(CadenceError): + + def __init__(self, message: str, code: grpc.StatusCode, start_request_id: str, run_id: str) -> None: + super().__init__(message, code, start_request_id, run_id) + self.start_request_id = start_request_id + self.run_id = run_id + +class EntityNotExistsError(CadenceError): + + def __init__(self, message: str, code: grpc.StatusCode, current_cluster: str, active_cluster: str, active_clusters: list[str]) -> None: + super().__init__(message, code, current_cluster, active_cluster, active_clusters) + self.current_cluster = current_cluster + self.active_cluster = active_cluster + self.active_clusters = active_clusters + +class WorkflowExecutionAlreadyCompletedError(CadenceError): + pass + +class DomainNotActiveError(CadenceError): + def __init__(self, message: str, code: grpc.StatusCode, domain: str, current_cluster: str, active_cluster: str, active_clusters: list[str]) -> None: + super().__init__(message, code, domain, current_cluster, active_cluster, active_clusters) + self.domain = domain + self.current_cluster = current_cluster + self.active_cluster = active_cluster + self.active_clusters = active_clusters + +class ClientVersionNotSupportedError(CadenceError): + def __init__(self, message: str, code: grpc.StatusCode, feature_version: str, client_impl: str, supported_versions: str) -> None: + super().__init__(message, code, feature_version, client_impl, supported_versions) + self.feature_version = feature_version + self.client_impl = client_impl + self.supported_versions = supported_versions + +class FeatureNotEnabledError(CadenceError): + def __init__(self, message: str, code: grpc.StatusCode, feature_flag: str) -> None: + super().__init__(message, code, feature_flag) + self.feature_flag = feature_flag + +class CancellationAlreadyRequestedError(CadenceError): + pass + +class DomainAlreadyExistsError(CadenceError): + pass + +class LimitExceededError(CadenceError): + pass + +class QueryFailedError(CadenceError): + pass + +class ServiceBusyError(CadenceError): + def __init__(self, message: str, code: grpc.StatusCode, reason: str) -> None: + super().__init__(message, code, reason) + self.reason = reason diff --git a/pyproject.toml b/pyproject.toml index de3ecc2..9b3a27e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ classifiers = [ requires-python = ">=3.11,<3.14" dependencies = [ "grpcio==1.71.2", + "grpcio-status>=1.71.2", "msgspec>=0.19.0", "protobuf==5.29.1", "typing-extensions>=4.0.0", diff --git a/tests/cadence/_internal/rpc/test_error.py b/tests/cadence/_internal/rpc/test_error.py new file mode 100644 index 0000000..8ca0c3e --- /dev/null +++ b/tests/cadence/_internal/rpc/test_error.py @@ -0,0 +1,122 @@ +from concurrent import futures + +import pytest +from google.protobuf import any_pb2 +from google.rpc import code_pb2, status_pb2 +from grpc import Status, StatusCode, server +from grpc.aio import insecure_channel +from grpc_status.rpc_status import to_status + +from cadence._internal.rpc.error import CadenceErrorInterceptor +from cadence.api.v1 import error_pb2, service_meta_pb2_grpc +from cadence import error +from google.protobuf.message import Message + +from cadence.api.v1.service_meta_pb2 import HealthRequest, HealthResponse +from cadence.error import CadenceError + + +class FakeService(service_meta_pb2_grpc.MetaAPIServicer): + def __init__(self) -> None: + super().__init__() + self.status: Status | None = None + self.port: int | None = None + + def Health(self, request, context): + if temp := self.status: + self.status = None + context.abort_with_status(temp) + return HealthResponse(ok=True) + + +@pytest.fixture(scope="module") +def fake_service(): + fake = FakeService() + sync_server = server(futures.ThreadPoolExecutor(max_workers=1)) + service_meta_pb2_grpc.add_MetaAPIServicer_to_server(fake, sync_server) + fake.port = sync_server.add_insecure_port("[::]:0") + sync_server.start() + yield fake + sync_server.stop(grace=None) + +@pytest.mark.usefixtures("fake_service") +@pytest.mark.parametrize( + "err,expected", + [ + pytest.param(None, None,id="no error"), + pytest.param( + error_pb2.WorkflowExecutionAlreadyStartedError(start_request_id="start_request", run_id="run_id"), + error.WorkflowExecutionAlreadyStartedError(message="message", code=StatusCode.INVALID_ARGUMENT, start_request_id="start_request", run_id="run_id"), + id="WorkflowExecutionAlreadyStartedError"), + pytest.param( + error_pb2.EntityNotExistsError(current_cluster="current_cluster", active_cluster="active_cluster", active_clusters=["active_clusters"]), + error.EntityNotExistsError(message="message", code=StatusCode.INVALID_ARGUMENT, current_cluster="current_cluster", active_cluster="active_cluster", active_clusters=["active_clusters"]), + id="EntityNotExistsError"), + pytest.param( + error_pb2.WorkflowExecutionAlreadyCompletedError(), + error.WorkflowExecutionAlreadyCompletedError(message="message", code=StatusCode.INVALID_ARGUMENT), + id="WorkflowExecutionAlreadyCompletedError"), + pytest.param( + error_pb2.DomainNotActiveError(domain="domain", current_cluster="current_cluster", active_cluster="active_cluster", active_clusters=["active_clusters"]), + error.DomainNotActiveError(message="message", code=StatusCode.INVALID_ARGUMENT, domain="domain", current_cluster="current_cluster", active_cluster="active_cluster", active_clusters=["active_clusters"]), + id="DomainNotActiveError"), + pytest.param( + error_pb2.ClientVersionNotSupportedError(feature_version="feature_version", client_impl="client_impl", supported_versions="supported_versions"), + error.ClientVersionNotSupportedError(message="message", code=StatusCode.INVALID_ARGUMENT, feature_version="feature_version", client_impl="client_impl", supported_versions="supported_versions"), + id="ClientVersionNotSupportedError"), + pytest.param( + error_pb2.FeatureNotEnabledError(feature_flag="feature_flag"), + error.FeatureNotEnabledError(message="message", code=StatusCode.INVALID_ARGUMENT,feature_flag="feature_flag"), + id="FeatureNotEnabledError"), + pytest.param( + error_pb2.CancellationAlreadyRequestedError(), + error.CancellationAlreadyRequestedError(message="message", code=StatusCode.INVALID_ARGUMENT), + id="CancellationAlreadyRequestedError"), + pytest.param( + error_pb2.DomainAlreadyExistsError(), + error.DomainAlreadyExistsError(message="message", code=StatusCode.INVALID_ARGUMENT), + id="DomainAlreadyExistsError"), + pytest.param( + error_pb2.LimitExceededError(), + error.LimitExceededError(message="message", code=StatusCode.INVALID_ARGUMENT), + id="LimitExceededError"), + pytest.param( + error_pb2.QueryFailedError(), + error.QueryFailedError(message="message", code=StatusCode.INVALID_ARGUMENT), + id="QueryFailedError"), + pytest.param( + error_pb2.ServiceBusyError(reason="reason"), + error.ServiceBusyError(message="message", code=StatusCode.INVALID_ARGUMENT, reason="reason"), + id="ServiceBusyError"), + pytest.param( + to_status(status_pb2.Status(code=code_pb2.PERMISSION_DENIED, message="no permission")), + error.CadenceError(message="no permission", code=StatusCode.PERMISSION_DENIED), + id="unknown error type"), + ] +) +@pytest.mark.asyncio +async def test_map_error(fake_service, err: Message | Status, expected: CadenceError): + async with insecure_channel(f"[::]:{fake_service.port}", interceptors=[CadenceErrorInterceptor()]) as channel: + stub = service_meta_pb2_grpc.MetaAPIStub(channel) + if expected is None: + response = await stub.Health(HealthRequest(), timeout=1) + assert response == HealthResponse(ok=True) + else: + if isinstance(err, Message): + fake_service.status = details_to_status(err) + else: + fake_service.status = err + with pytest.raises(type(expected)) as exc_info: + await stub.Health(HealthRequest(), timeout=1) + assert exc_info.value.args == expected.args + +def details_to_status(message: Message) -> Status: + detail = any_pb2.Any() + detail.Pack(message) + status_proto = status_pb2.Status( + code=code_pb2.INVALID_ARGUMENT, + message="message", + details=[detail], + ) + return to_status(status_proto) + diff --git a/uv.lock b/uv.lock index cf85812..90d9b36 100644 --- a/uv.lock +++ b/uv.lock @@ -153,6 +153,7 @@ version = "0.1.0" source = { editable = "." } dependencies = [ { name = "grpcio" }, + { name = "grpcio-status" }, { name = "msgspec" }, { name = "protobuf" }, { name = "typing-extensions" }, @@ -186,6 +187,7 @@ requires-dist = [ { name = "black", marker = "extra == 'dev'", specifier = ">=23.0.0" }, { name = "flake8", marker = "extra == 'dev'", specifier = ">=6.0.0" }, { name = "grpcio", specifier = "==1.71.2" }, + { name = "grpcio-status", specifier = ">=1.71.2" }, { name = "grpcio-tools", marker = "extra == 'dev'", specifier = "==1.71.2" }, { name = "isort", marker = "extra == 'dev'", specifier = ">=5.12.0" }, { name = "msgspec", specifier = ">=0.19.0" }, @@ -466,6 +468,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ee/45/b82e3c16be2182bff01179db177fe144d58b5dc787a7d4492c6ed8b9317f/frozenlist-1.7.0-py3-none-any.whl", hash = "sha256:9a5af342e34f7e97caf8c995864c7a396418ae2859cc6fdf1b1073020d516a7e", size = 13106, upload-time = "2025-06-09T23:02:34.204Z" }, ] +[[package]] +name = "googleapis-common-protos" +version = "1.70.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "protobuf" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/39/24/33db22342cf4a2ea27c9955e6713140fedd51e8b141b5ce5260897020f1a/googleapis_common_protos-1.70.0.tar.gz", hash = "sha256:0e1b44e0ea153e6594f9f394fef15193a68aaaea2d843f83e2742717ca753257", size = 145903, upload-time = "2025-04-14T10:17:02.924Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/86/f1/62a193f0227cf15a920390abe675f386dec35f7ae3ffe6da582d3ade42c7/googleapis_common_protos-1.70.0-py3-none-any.whl", hash = "sha256:b8bfcca8c25a2bb253e0e0b0adaf8c00773e5e6af6fd92397576680b807e0fd8", size = 294530, upload-time = "2025-04-14T10:17:01.271Z" }, +] + [[package]] name = "grpcio" version = "1.71.2" @@ -504,6 +518,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a0/63/8de0b14892c07aad98f61bf140ff95c5f51086e058ae6da40599d60f0a04/grpcio-1.71.2-cp313-cp313-win_amd64.whl", hash = "sha256:54a9bdd5f94ce1512e3cc37f2f84a776cfbaa07222764129ebc2a54f803ebd70", size = 4211232, upload-time = "2025-06-28T04:19:30.95Z" }, ] +[[package]] +name = "grpcio-status" +version = "1.71.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "googleapis-common-protos" }, + { name = "grpcio" }, + { name = "protobuf" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/fd/d1/b6e9877fedae3add1afdeae1f89d1927d296da9cf977eca0eb08fb8a460e/grpcio_status-1.71.2.tar.gz", hash = "sha256:c7a97e176df71cdc2c179cd1847d7fc86cca5832ad12e9798d7fed6b7a1aab50", size = 13677, upload-time = "2025-06-28T04:24:05.426Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/67/58/317b0134129b556a93a3b0afe00ee675b5657f0155509e22fcb853bafe2d/grpcio_status-1.71.2-py3-none-any.whl", hash = "sha256:803c98cb6a8b7dc6dbb785b1111aed739f241ab5e9da0bba96888aa74704cfd3", size = 14424, upload-time = "2025-06-28T04:23:42.136Z" }, +] + [[package]] name = "grpcio-tools" version = "1.71.2"