From c9866a0bafb90dad745e9a24232da39d0ad4cdb8 Mon Sep 17 00:00:00 2001 From: Nate Mortensen Date: Tue, 9 Sep 2025 08:42:06 -0700 Subject: [PATCH] Add ActivityContext User code needs to be able to retrieve information about the running activity, and in the future will be able to interact with it (heartbeating). Create the initial structure for providing it. --- .../_internal/activity/_activity_executor.py | 59 +++++++----- cadence/_internal/activity/_context.py | 48 ++++++++++ cadence/activity.py | 61 +++++++++++++ .../activity/test_activity_executor.py | 91 ++++++++++++++++--- 4 files changed, 225 insertions(+), 34 deletions(-) create mode 100644 cadence/_internal/activity/_context.py create mode 100644 cadence/activity.py diff --git a/cadence/_internal/activity/_activity_executor.py b/cadence/_internal/activity/_activity_executor.py index 64c3c52..f9efba0 100644 --- a/cadence/_internal/activity/_activity_executor.py +++ b/cadence/_internal/activity/_activity_executor.py @@ -1,11 +1,13 @@ -import asyncio import inspect from concurrent.futures import ThreadPoolExecutor from logging import getLogger from traceback import format_exception from typing import Any, Callable +from google.protobuf.duration import to_timedelta +from google.protobuf.timestamp import to_datetime -from cadence._internal.type_utils import get_fn_parameters +from cadence._internal.activity._context import _Context, _SyncContext +from cadence.activity import ActivityInfo from cadence.api.v1.common_pb2 import Failure from cadence.api.v1.service_worker_pb2 import PollForActivityTaskResponse, RespondActivityTaskFailedRequest, \ RespondActivityTaskCompletedRequest @@ -19,35 +21,31 @@ def __init__(self, client: Client, task_list: str, identity: str, max_workers: i self._data_converter = client.data_converter self._registry = registry self._identity = identity + self._task_list = task_list self._thread_pool = ThreadPoolExecutor(max_workers=max_workers, thread_name_prefix=f'{task_list}-activity-') async 
def _create_context(self, task: PollForActivityTaskResponse) -> _Context:
    """Build the context object that will run the activity for ``task``.

    The activity function is looked up in the registry by the task's
    activity type name. Coroutine functions are wrapped in ``_Context``;
    every other callable is wrapped in ``_SyncContext`` together with this
    executor's thread pool.

    Raises:
        KeyError: if the registry lookup raises ``KeyError`` for the
            task's activity type.
    """
    type_name = task.activity_type.name
    try:
        fn = self._registry(type_name)
    except KeyError:
        # Surface a uniform message and suppress the registry's own
        # exception as chained context.
        raise KeyError(f"Activity type not found: {type_name}") from None

    activity_info = self._create_info(task)

    if not inspect.iscoroutinefunction(fn):
        return _SyncContext(self._client, activity_info, fn, self._thread_pool)
    return _Context(self._client, activity_info, fn)
class _Context(ActivityContext):
    """Activity context used to run coroutine activity functions."""

    def __init__(self, client: Client, info: ActivityInfo, activity_fn: Callable[[Any], Any]):
        self._client, self._info, self._activity_fn = client, info, activity_fn

    async def execute(self, payload: Payload) -> Any:
        """Decode ``payload`` into arguments and await the activity.

        The activity function is awaited while this context is activated,
        and its result is returned.
        """
        args = await self._to_params(payload)
        with self._activate():
            outcome = await self._activity_fn(*args)
        return outcome

    async def _to_params(self, payload: Payload) -> list[Any]:
        """Convert ``payload`` into the activity function's arguments."""
        hints = get_fn_parameters(self._activity_fn)
        converter = self._client.data_converter
        return await converter.from_data(payload, hints)

    def client(self) -> Client:
        """Return the client this context was created with."""
        return self._client

    def info(self) -> ActivityInfo:
        """Return the ``ActivityInfo`` this context was created with."""
        return self._info
class ActivityContext(ABC):
    """Ambient context of the activity that is currently executing.

    Concrete subclasses provide ``info()`` and ``client()``. The active
    instance is stored in a class-level ``ContextVar`` while the
    ``_activate()`` context manager is entered, and can be looked up with
    the static ``get()`` / ``is_set()`` helpers.
    """

    _var: ContextVar['ActivityContext'] = ContextVar("activity")

    @abstractmethod
    def info(self) -> 'ActivityInfo':
        """Return the ``ActivityInfo`` describing the running activity."""
        ...

    @abstractmethod
    def client(self) -> 'Client':
        """Return the ``Client`` available to the running activity."""
        ...

    @contextmanager
    def _activate(self) -> Iterator[None]:
        """Install this context as the current one for the ``with`` body.

        The previous value of the context variable is restored when the
        body exits, whether it returns normally or raises.
        """
        token = ActivityContext._var.set(self)
        try:
            yield None
        finally:
            # An exception raised inside the ``with`` body is re-raised at
            # the ``yield``; without ``finally`` the reset would be skipped
            # and a failed activity would leave its context installed.
            ActivityContext._var.reset(token)

    @staticmethod
    def is_set() -> bool:
        """Return ``True`` when an activity context is currently installed."""
        return ActivityContext._var.get(None) is not None

    @staticmethod
    def get() -> 'ActivityContext':
        """Return the currently installed activity context.

        Raises:
            LookupError: if no context is installed (the variable has no
                default value).
        """
        return ActivityContext._var.get()
AsyncMock(return_value=RespondActivityTaskFailedResponse()) @@ -64,7 +66,6 @@ async def activity_fn(): identity='identity', ) -@pytest.mark.asyncio async def test_activity_args(client): worker_stub = client.worker_stub worker_stub.RespondActivityTaskCompleted = AsyncMock(return_value=RespondActivityTaskCompletedResponse()) @@ -82,8 +83,6 @@ async def activity_fn(first: str, second: str): identity='identity', )) - -@pytest.mark.asyncio async def test_activity_sync_success(client): worker_stub = client.worker_stub worker_stub.RespondActivityTaskCompleted = AsyncMock(return_value=RespondActivityTaskCompletedResponse()) @@ -105,7 +104,6 @@ def activity_fn(): identity='identity', )) -@pytest.mark.asyncio async def test_activity_sync_failure(client): worker_stub = client.worker_stub worker_stub.RespondActivityTaskFailed = AsyncMock(return_value=RespondActivityTaskFailedResponse()) @@ -132,7 +130,6 @@ def activity_fn(): identity='identity', ) -@pytest.mark.asyncio async def test_activity_unknown(client): worker_stub = client.worker_stub worker_stub.RespondActivityTaskFailed = AsyncMock(return_value=RespondActivityTaskFailedResponse()) @@ -148,7 +145,7 @@ def registry(name: str): call = worker_stub.RespondActivityTaskFailed.call_args[0][0] - assert 'unknown activity: any' in call.failure.details.decode() + assert 'Activity type not found: any' in call.failure.details.decode() call.failure.details = bytes() assert call == RespondActivityTaskFailedRequest( task_token=b'task_token', @@ -158,9 +155,70 @@ def registry(name: str): identity='identity', ) +async def test_activity_context(client): + worker_stub = client.worker_stub + worker_stub.RespondActivityTaskCompleted = AsyncMock(return_value=RespondActivityTaskCompletedResponse()) + + async def activity_fn(): + assert fake_info("activity_type") == activity.info() + assert activity.in_activity() + assert activity.client() is not None + return "success" + + executor = ActivityExecutor(client, 'task_list', 'identity', 1, lambda 
def fake_info(activity_type: str) -> ActivityInfo:
    """Build the ``ActivityInfo`` the tests compare against ``activity.info()``.

    Every field except ``activity_type`` is a fixed value.
    """
    expected_fields = dict(
        task_token=b'task_token',
        workflow_type="workflow_type",
        workflow_domain="workflow_domain",
        workflow_id="workflow_id",
        workflow_run_id="run_id",
        activity_id="activity_id",
        activity_type=activity_type,
        task_list="task_list",
        heartbeat_timeout=timedelta(seconds=1),
        scheduled_timestamp=datetime(2020, 1, 2, 3),
        started_timestamp=datetime(2020, 1, 2, 4),
        start_to_close_timeout=timedelta(seconds=2),
        attempt=1,
    )
    return ActivityInfo(**expected_fields)
def from_datetime(time: datetime) -> Timestamp:
    """Return a protobuf ``Timestamp`` populated from ``time``."""
    stamp = Timestamp()
    stamp.FromDatetime(time)
    return stamp