From c61af8ca984039ade894b081dc967bc7d3822af3 Mon Sep 17 00:00:00 2001 From: Nate Mortensen Date: Wed, 27 Aug 2025 15:58:34 -0700 Subject: [PATCH] Add initial ActivityExecutor Create a basic activity executor capable of running functions as activities. Async functions are run as part of the main event loop, while sync ones are run in a thread pool executor. Type hints are expected for any parameters the activity takes, with non-typed parameters receiving the deserialized json value as a standard python type (dict, list, string, etc.). Plenty of work remains for full activity support, such as maintaining an activity context for heartbeating. --- cadence/_internal/activity/__init__.py | 9 + .../_internal/activity/_activity_executor.py | 81 +++++++++ cadence/_internal/type_utils.py | 19 ++ cadence/client.py | 6 + cadence/data_converter.py | 15 +- cadence/sample/client_example.py | 4 +- cadence/worker/_activity.py | 18 +- cadence/worker/_registry.py | 2 + cadence/worker/_worker.py | 5 +- .../activity/test_activity_executor.py | 172 ++++++++++++++++++ tests/cadence/_internal/test_type_utils.py | 70 +++++++ tests/cadence/data_converter_test.py | 13 ++ tests/cadence/worker/test_worker.py | 4 +- 13 files changed, 397 insertions(+), 21 deletions(-) create mode 100644 cadence/_internal/activity/__init__.py create mode 100644 cadence/_internal/activity/_activity_executor.py create mode 100644 cadence/_internal/type_utils.py create mode 100644 tests/cadence/_internal/activity/test_activity_executor.py create mode 100644 tests/cadence/_internal/test_type_utils.py diff --git a/cadence/_internal/activity/__init__.py b/cadence/_internal/activity/__init__.py new file mode 100644 index 0000000..073d53c --- /dev/null +++ b/cadence/_internal/activity/__init__.py @@ -0,0 +1,9 @@ + + +from ._activity_executor import ( + ActivityExecutor +) + +__all__ = [ + "ActivityExecutor", +] diff --git a/cadence/_internal/activity/_activity_executor.py
b/cadence/_internal/activity/_activity_executor.py new file mode 100644 index 0000000..64c3c52 --- /dev/null +++ b/cadence/_internal/activity/_activity_executor.py @@ -0,0 +1,81 @@ +import asyncio +import inspect +from concurrent.futures import ThreadPoolExecutor +from logging import getLogger +from traceback import format_exception +from typing import Any, Callable + +from cadence._internal.type_utils import get_fn_parameters +from cadence.api.v1.common_pb2 import Failure +from cadence.api.v1.service_worker_pb2 import PollForActivityTaskResponse, RespondActivityTaskFailedRequest, \ + RespondActivityTaskCompletedRequest +from cadence.client import Client + +_logger = getLogger(__name__) + +class ActivityExecutor: + def __init__(self, client: Client, task_list: str, identity: str, max_workers: int, registry: Callable[[str], Callable]): + self._client = client + self._data_converter = client.data_converter + self._registry = registry + self._identity = identity + self._thread_pool = ThreadPoolExecutor(max_workers=max_workers, + thread_name_prefix=f'{task_list}-activity-') + + async def execute(self, task: PollForActivityTaskResponse): + activity_type = task.activity_type.name + try: + activity_fn = self._registry(activity_type) + except KeyError as e: + _logger.error("Activity type not found.", extra={'activity_type': activity_type}) + await self._report_failure(task, e) + return + + await self._execute_fn(activity_fn, task) + + async def _execute_fn(self, activity_fn: Callable[[Any], Any], task: PollForActivityTaskResponse): + try: + type_hints = get_fn_parameters(activity_fn) + params = await self._client.data_converter.from_data(task.input, type_hints) + if inspect.iscoroutinefunction(activity_fn): + result = await activity_fn(*params) + else: + result = await self._invoke_sync_activity(activity_fn, params) + await self._report_success(task, result) + except Exception as e: + await self._report_failure(task, e) + + async def _invoke_sync_activity(self, 
activity_fn: Callable[[Any], Any], params: list[Any]) -> Any: + loop = asyncio.get_running_loop() + return await loop.run_in_executor(self._thread_pool, activity_fn, *params) + + async def _report_failure(self, task: PollForActivityTaskResponse, error: Exception): + try: + await self._client.worker_stub.RespondActivityTaskFailed(RespondActivityTaskFailedRequest( + task_token=task.task_token, + failure=_to_failure(error), + identity=self._identity, + )) + except Exception: + _logger.exception('Exception reporting activity failure') + + async def _report_success(self, task: PollForActivityTaskResponse, result: Any): + as_payload = await self._data_converter.to_data(result) + + try: + await self._client.worker_stub.RespondActivityTaskCompleted(RespondActivityTaskCompletedRequest( + task_token=task.task_token, + result=as_payload, + identity=self._identity, + )) + except Exception: + _logger.exception('Exception reporting activity complete') + + +def _to_failure(exception: Exception) -> Failure: + stacktrace = "".join(format_exception(exception)) + + return Failure( + reason=type(exception).__name__, + details=stacktrace.encode(), + ) diff --git a/cadence/_internal/type_utils.py b/cadence/_internal/type_utils.py new file mode 100644 index 0000000..84fd07c --- /dev/null +++ b/cadence/_internal/type_utils.py @@ -0,0 +1,19 @@ +from inspect import signature, Parameter +from typing import Callable, List, Type, get_type_hints + +def get_fn_parameters(fn: Callable) -> List[Type | None]: + args = signature(fn).parameters + hints = get_type_hints(fn) + result = [] + for name, param in args.items(): + if param.kind in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD): + type_hint = hints.get(name, None) + result.append(type_hint) + + return result + +def validate_fn_parameters(fn: Callable) -> None: + args = signature(fn).parameters + for name, param in args.items(): + if param.kind not in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD): + raise 
ValueError(f"Parameters must be positional. {name} is {param.kind}, and not valid") \ No newline at end of file diff --git a/cadence/client.py b/cadence/client.py index bfca66b..ef5f542 100644 --- a/cadence/client.py +++ b/cadence/client.py @@ -5,10 +5,13 @@ from cadence.api.v1.service_worker_pb2_grpc import WorkerAPIStub from grpc.aio import Channel +from cadence.data_converter import DataConverter + class ClientOptions(TypedDict, total=False): domain: str identity: str + data_converter: DataConverter class Client: def __init__(self, channel: Channel, options: ClientOptions) -> None: @@ -17,6 +20,9 @@ def __init__(self, channel: Channel, options: ClientOptions) -> None: self._options = options self._identity = options["identity"] if "identity" in options else f"{os.getpid()}@{socket.gethostname()}" + @property + def data_converter(self) -> DataConverter: + return self._options["data_converter"] @property def domain(self) -> str: diff --git a/cadence/data_converter.py b/cadence/data_converter.py index ca54712..e88680f 100644 --- a/cadence/data_converter.py +++ b/cadence/data_converter.py @@ -9,7 +9,7 @@ class DataConverter(Protocol): @abstractmethod - async def from_data(self, payload: Payload, type_hints: List[Type]) -> List[Any]: + async def from_data(self, payload: Payload, type_hints: List[Type | None]) -> List[Any]: raise NotImplementedError() @abstractmethod @@ -23,7 +23,10 @@ def __init__(self) -> None: self._fallback_decoder = JSONDecoder(strict=False) - async def from_data(self, payload: Payload, type_hints: List[Type]) -> List[Any]: + async def from_data(self, payload: Payload, type_hints: List[Type | None]) -> List[Any]: + if not payload.data: + return DefaultDataConverter._convert_into([], type_hints) + if len(type_hints) > 1: payload_str = payload.data.decode() # Handle payloads from the Go client, which are a series of json objects rather than a json array @@ -37,7 +40,7 @@ async def from_data(self, payload: Payload, type_hints: List[Type]) -> 
List[Any] return DefaultDataConverter._convert_into([as_value], type_hints) - def _decode_whitespace_delimited(self, payload: str, type_hints: List[Type]) -> List[Any]: + def _decode_whitespace_delimited(self, payload: str, type_hints: List[Type | None]) -> List[Any]: results: List[Any] = [] start, end = 0, len(payload) while start < end and len(results) < len(type_hints): @@ -49,10 +52,12 @@ def _decode_whitespace_delimited(self, payload: str, type_hints: List[Type]) -> return DefaultDataConverter._convert_into(results, type_hints) @staticmethod - def _convert_into(values: List[Any], type_hints: List[Type]) -> List[Any]: + def _convert_into(values: List[Any], type_hints: List[Type | None]) -> List[Any]: results: List[Any] = [] for i, type_hint in enumerate(type_hints): - if i < len(values): + if not type_hint: + value = values[i] + elif i < len(values): value = convert(values[i], type_hint) else: value = DefaultDataConverter._get_default(type_hint) diff --git a/cadence/sample/client_example.py b/cadence/sample/client_example.py index 64b9be2..556691c 100644 --- a/cadence/sample/client_example.py +++ b/cadence/sample/client_example.py @@ -4,7 +4,7 @@ from cadence.client import Client, ClientOptions from cadence._internal.rpc.metadata import MetadataInterceptor -from cadence.worker import Worker +from cadence.worker import Worker, Registry async def main(): @@ -15,7 +15,7 @@ async def main(): metadata["rpc-caller"] = "nate" async with insecure_channel("localhost:7833", interceptors=[MetadataInterceptor(metadata)]) as channel: client = Client(channel, ClientOptions(domain="foo")) - worker = Worker(client, "task_list") + worker = Worker(client, "task_list", Registry()) await worker.run() if __name__ == '__main__': diff --git a/cadence/worker/_activity.py b/cadence/worker/_activity.py index ec1ef7e..2e7591f 100644 --- a/cadence/worker/_activity.py +++ b/cadence/worker/_activity.py @@ -1,21 +1,23 @@ import asyncio from typing import Optional -from 
cadence.api.v1.common_pb2 import Failure -from cadence.api.v1.service_worker_pb2 import PollForActivityTaskResponse, PollForActivityTaskRequest, \ - RespondActivityTaskFailedRequest +from cadence._internal.activity import ActivityExecutor +from cadence.api.v1.service_worker_pb2 import PollForActivityTaskResponse, PollForActivityTaskRequest from cadence.api.v1.tasklist_pb2 import TaskList, TaskListKind from cadence.client import Client +from cadence.worker._registry import Registry from cadence.worker._types import WorkerOptions, _LONG_POLL_TIMEOUT from cadence.worker._poller import Poller class ActivityWorker: - def __init__(self, client: Client, task_list: str, options: WorkerOptions) -> None: + def __init__(self, client: Client, task_list: str, registry: Registry, options: WorkerOptions) -> None: self._client = client self._task_list = task_list self._identity = options["identity"] - permits = asyncio.Semaphore(options["max_concurrent_activity_execution_size"]) + max_concurrent = options["max_concurrent_activity_execution_size"] + permits = asyncio.Semaphore(max_concurrent) + self._executor = ActivityExecutor(self._client, self._task_list, options["identity"], max_concurrent, registry.get_activity) self._poller = Poller[PollForActivityTaskResponse](options["activity_task_pollers"], permits, self._poll, self._execute) # TODO: Local dispatch, local activities, actually running activities, etc @@ -35,9 +37,5 @@ async def _poll(self) -> Optional[PollForActivityTaskResponse]: return None async def _execute(self, task: PollForActivityTaskResponse) -> None: - await self._client.worker_stub.RespondActivityTaskFailed(RespondActivityTaskFailedRequest( - task_token=task.task_token, - identity=self._identity, - failure=Failure(reason='error', details=b'not implemented'), - )) + await self._executor.execute(task) diff --git a/cadence/worker/_registry.py b/cadence/worker/_registry.py index 4ba0972..1f5d03f 100644 --- a/cadence/worker/_registry.py +++ 
b/cadence/worker/_registry.py @@ -8,6 +8,7 @@ import logging from typing import Callable, Dict, Optional, Unpack, TypedDict +from cadence._internal.type_utils import validate_fn_parameters logger = logging.getLogger(__name__) @@ -107,6 +108,7 @@ def activity( options = RegisterActivityOptions(**kwargs) def decorator(f: Callable) -> Callable: + validate_fn_parameters(f) activity_name = options.get('name') or f.__name__ if activity_name in self._activities: diff --git a/cadence/worker/_worker.py b/cadence/worker/_worker.py index 8d8932a..70ce364 100644 --- a/cadence/worker/_worker.py +++ b/cadence/worker/_worker.py @@ -3,6 +3,7 @@ from typing import Unpack, cast from cadence.client import Client +from cadence.worker._registry import Registry from cadence.worker._activity import ActivityWorker from cadence.worker._decision import DecisionWorker from cadence.worker._types import WorkerOptions, _DEFAULT_WORKER_OPTIONS @@ -10,14 +11,14 @@ class Worker: - def __init__(self, client: Client, task_list: str, **kwargs: Unpack[WorkerOptions]) -> None: + def __init__(self, client: Client, task_list: str, registry: Registry, **kwargs: Unpack[WorkerOptions]) -> None: self._client = client self._task_list = task_list options = WorkerOptions(**kwargs) _validate_and_copy_defaults(client, task_list, options) self._options = options - self._activity_worker = ActivityWorker(client, task_list, options) + self._activity_worker = ActivityWorker(client, task_list, registry, options) self._decision_worker = DecisionWorker(client, task_list, options) diff --git a/tests/cadence/_internal/activity/test_activity_executor.py b/tests/cadence/_internal/activity/test_activity_executor.py new file mode 100644 index 0000000..31e29ce --- /dev/null +++ b/tests/cadence/_internal/activity/test_activity_executor.py @@ -0,0 +1,172 @@ +import asyncio +from unittest.mock import Mock, AsyncMock, PropertyMock + +import pytest + +from cadence import Client +from cadence._internal.activity import 
ActivityExecutor +from cadence.api.v1.common_pb2 import WorkflowExecution, ActivityType, Payload, Failure +from cadence.api.v1.service_worker_pb2 import RespondActivityTaskCompletedResponse, PollForActivityTaskResponse, \ + RespondActivityTaskCompletedRequest, RespondActivityTaskFailedResponse, RespondActivityTaskFailedRequest +from cadence.data_converter import DefaultDataConverter + + +@pytest.fixture +def client() -> Client: + client = Mock(spec=Client) + client.worker_stub = AsyncMock() + type(client).data_converter = PropertyMock(return_value=DefaultDataConverter()) + return client + + +@pytest.mark.asyncio +async def test_activity_async_success(client): + worker_stub = client.worker_stub + worker_stub.RespondActivityTaskCompleted = AsyncMock(return_value=RespondActivityTaskCompletedResponse()) + + async def activity_fn(): + return "success" + + executor = ActivityExecutor(client, 'task_list', 'identity', 1, lambda name: activity_fn) + + await executor.execute(fake_task("any", "")) + + worker_stub.RespondActivityTaskCompleted.assert_called_once_with(RespondActivityTaskCompletedRequest( + task_token=b'task_token', + result=Payload(data='"success"'.encode()), + identity='identity', + )) + +@pytest.mark.asyncio +async def test_activity_async_failure(client): + worker_stub = client.worker_stub + worker_stub.RespondActivityTaskFailed = AsyncMock(return_value=RespondActivityTaskFailedResponse()) + + async def activity_fn(): + raise KeyError("failure") + + executor = ActivityExecutor(client, 'task_list', 'identity', 1, lambda name: activity_fn) + + await executor.execute(fake_task("any", "")) + + worker_stub.RespondActivityTaskFailed.assert_called_once() + + call = worker_stub.RespondActivityTaskFailed.call_args[0][0] + + # Confirm it's a stacktrace, then clear it since it is different on every machine + assert 'raise KeyError("failure")' in call.failure.details.decode() + call.failure.details = bytes() + assert call == RespondActivityTaskFailedRequest( + 
task_token=b'task_token', + failure=Failure( + reason="KeyError", + ), + identity='identity', + ) + +@pytest.mark.asyncio +async def test_activity_args(client): + worker_stub = client.worker_stub + worker_stub.RespondActivityTaskCompleted = AsyncMock(return_value=RespondActivityTaskCompletedResponse()) + + async def activity_fn(first: str, second: str): + return " ".join([first, second]) + + executor = ActivityExecutor(client, 'task_list', 'identity', 1, lambda name: activity_fn) + + await executor.execute(fake_task("any", '["hello", "world"]')) + + worker_stub.RespondActivityTaskCompleted.assert_called_once_with(RespondActivityTaskCompletedRequest( + task_token=b'task_token', + result=Payload(data='"hello world"'.encode()), + identity='identity', + )) + + +@pytest.mark.asyncio +async def test_activity_sync_success(client): + worker_stub = client.worker_stub + worker_stub.RespondActivityTaskCompleted = AsyncMock(return_value=RespondActivityTaskCompletedResponse()) + + def activity_fn(): + try: + asyncio.get_running_loop() + except RuntimeError: + return "success" + raise RuntimeError("expected to be running outside of the event loop") + + executor = ActivityExecutor(client, 'task_list', 'identity', 1, lambda name: activity_fn) + + await executor.execute(fake_task("any", "")) + + worker_stub.RespondActivityTaskCompleted.assert_called_once_with(RespondActivityTaskCompletedRequest( + task_token=b'task_token', + result=Payload(data='"success"'.encode()), + identity='identity', + )) + +@pytest.mark.asyncio +async def test_activity_sync_failure(client): + worker_stub = client.worker_stub + worker_stub.RespondActivityTaskFailed = AsyncMock(return_value=RespondActivityTaskFailedResponse()) + + def activity_fn(): + raise KeyError("failure") + + executor = ActivityExecutor(client, 'task_list', 'identity', 1, lambda name: activity_fn) + + await executor.execute(fake_task("any", "")) + + worker_stub.RespondActivityTaskFailed.assert_called_once() + + call = 
worker_stub.RespondActivityTaskFailed.call_args[0][0] + + # Confirm it's a stacktrace, then clear it since it is different on every machine + assert 'raise KeyError("failure")' in call.failure.details.decode() + call.failure.details = bytes() + assert call == RespondActivityTaskFailedRequest( + task_token=b'task_token', + failure=Failure( + reason="KeyError", + ), + identity='identity', + ) + +@pytest.mark.asyncio +async def test_activity_unknown(client): + worker_stub = client.worker_stub + worker_stub.RespondActivityTaskFailed = AsyncMock(return_value=RespondActivityTaskFailedResponse()) + + def registry(name: str): + raise KeyError(f"unknown activity: {name}") + + executor = ActivityExecutor(client, 'task_list', 'identity', 1, registry) + + await executor.execute(fake_task("any", "")) + + worker_stub.RespondActivityTaskFailed.assert_called_once() + + call = worker_stub.RespondActivityTaskFailed.call_args[0][0] + + assert 'unknown activity: any' in call.failure.details.decode() + call.failure.details = bytes() + assert call == RespondActivityTaskFailedRequest( + task_token=b'task_token', + failure=Failure( + reason="KeyError", + ), + identity='identity', + ) + +def fake_task(activity_type: str, input_json: str) -> PollForActivityTaskResponse: + return PollForActivityTaskResponse( + task_token=b'task_token', + workflow_execution=WorkflowExecution( + workflow_id="workflow_id", + run_id="run_id", + ), + activity_id="activity_id", + activity_type=ActivityType(name=activity_type), + input=Payload(data=input_json.encode()), + attempt=0, + ) \ No newline at end of file diff --git a/tests/cadence/_internal/test_type_utils.py b/tests/cadence/_internal/test_type_utils.py new file mode 100644 index 0000000..9e35e81 --- /dev/null +++ b/tests/cadence/_internal/test_type_utils.py @@ -0,0 +1,70 @@ +from typing import Callable, Type + +import pytest + +from cadence._internal.type_utils import get_fn_parameters, validate_fn_parameters + + +def _single_param(name: str): + ... 
+ +def _multiple_param(name: str, other: 'str'): + ... + +def _with_args(name:str, *args): + ... + +def _with_kwargs(name:str, **kwargs): + ... + +def _strictly_positional(name: str, other: str, *args, **kwargs): + ... + +def _keyword_only(*args, foo: str): + ... + + +@pytest.mark.parametrize( + "fn,expected", + [ + pytest.param( + _single_param, [str], id="single param" + ), + pytest.param( + _multiple_param, [str, str], id="multiple param" + ), + pytest.param( + _strictly_positional, [str, str], id="strictly positional" + ), + pytest.param( + _keyword_only, [], id="keyword only" + ), + ] +) +def test_get_fn_parameters(fn: Callable, expected: list[Type]): + params = get_fn_parameters(fn) + assert params == expected + +@pytest.mark.parametrize( + "fn,expected", + [ + pytest.param( + _single_param, None, id="single param" + ), + pytest.param( + _multiple_param, None, id="multiple param" + ), + pytest.param( + _with_args, ValueError, id="with args" + ), + pytest.param( + _with_kwargs, ValueError, id="with kwargs" + ), + ] +) +def test_validate_fn_parameters(fn: Callable, expected: Type[Exception]): + if expected: + with pytest.raises(expected): + validate_fn_parameters(fn) + else: + validate_fn_parameters(fn) \ No newline at end of file diff --git a/tests/cadence/data_converter_test.py b/tests/cadence/data_converter_test.py index 88aecd3..cbc7ba6 100644 --- a/tests/cadence/data_converter_test.py +++ b/tests/cadence/data_converter_test.py @@ -60,6 +60,19 @@ class _TestDataClass: '[{"foo": "bar"},{"bar": 100},["hello"],"world"]', [_TestDataClass, _TestDataClass, list[str], str], [_TestDataClass(foo="bar"), _TestDataClass(bar=100), ["hello"], "world"], id="json array mix" ), + pytest.param( + "", [], [], id="no input expected" + ), + pytest.param( + "", [str], [None], id="no input unexpected" + ), + pytest.param( + '["hello world", {"foo":"bar"}, 7]', [None, None, None], ["hello world", {"foo":"bar"}, 7], id="no type hints" + ), + pytest.param( + '"hello" "world" 
"goodbye"', [str, str], ["hello", "world"], + id="extra content" + ), ] ) @pytest.mark.asyncio diff --git a/tests/cadence/worker/test_worker.py b/tests/cadence/worker/test_worker.py index 5a3667e..951ae6d 100644 --- a/tests/cadence/worker/test_worker.py +++ b/tests/cadence/worker/test_worker.py @@ -7,7 +7,7 @@ from cadence.api.v1.service_worker_pb2 import PollForDecisionTaskRequest, PollForActivityTaskRequest from cadence.api.v1.tasklist_pb2 import TaskList, TaskListKind from cadence.client import Client -from cadence.worker import Worker +from cadence.worker import Worker, Registry @pytest.mark.asyncio @@ -29,7 +29,7 @@ async def poll(_, timeout=0.0): type(client).domain = PropertyMock(return_value="domain") type(client).identity = PropertyMock(return_value="identity") - worker = Worker(client, "task_list", activity_task_pollers=1, decision_task_pollers=1, identity="identity") + worker = Worker(client, "task_list", Registry(), activity_task_pollers=1, decision_task_pollers=1, identity="identity") task = asyncio.create_task(worker.run())