diff --git a/cadence/_internal/activity/_activity_executor.py b/cadence/_internal/activity/_activity_executor.py index 64c3c52..f9efba0 100644 --- a/cadence/_internal/activity/_activity_executor.py +++ b/cadence/_internal/activity/_activity_executor.py @@ -1,11 +1,13 @@ -import asyncio import inspect from concurrent.futures import ThreadPoolExecutor from logging import getLogger from traceback import format_exception from typing import Any, Callable +from google.protobuf.duration import to_timedelta +from google.protobuf.timestamp import to_datetime -from cadence._internal.type_utils import get_fn_parameters +from cadence._internal.activity._context import _Context, _SyncContext +from cadence.activity import ActivityInfo from cadence.api.v1.common_pb2 import Failure from cadence.api.v1.service_worker_pb2 import PollForActivityTaskResponse, RespondActivityTaskFailedRequest, \ RespondActivityTaskCompletedRequest @@ -19,35 +21,31 @@ def __init__(self, client: Client, task_list: str, identity: str, max_workers: i self._data_converter = client.data_converter self._registry = registry self._identity = identity + self._task_list = task_list self._thread_pool = ThreadPoolExecutor(max_workers=max_workers, thread_name_prefix=f'{task_list}-activity-') async def execute(self, task: PollForActivityTaskResponse): - activity_type = task.activity_type.name - try: - activity_fn = self._registry(activity_type) - except KeyError as e: - _logger.error("Activity type not found.", extra={'activity_type': activity_type}) - await self._report_failure(task, e) - return - - await self._execute_fn(activity_fn, task) - - async def _execute_fn(self, activity_fn: Callable[[Any], Any], task: PollForActivityTaskResponse): try: - type_hints = get_fn_parameters(activity_fn) - params = await self._client.data_converter.from_data(task.input, type_hints) - if inspect.iscoroutinefunction(activity_fn): - result = await activity_fn(*params) - else: - result = await self._invoke_sync_activity(activity_fn, params) + context = self._create_context(task) + result = await context.execute(task.input) await self._report_success(task, result) except Exception as e: await self._report_failure(task, e) - async def _invoke_sync_activity(self, activity_fn: Callable[[Any], Any], params: list[Any]) -> Any: - loop = asyncio.get_running_loop() - return await loop.run_in_executor(self._thread_pool, activity_fn, *params) + def _create_context(self, task: PollForActivityTaskResponse) -> _Context: + activity_type = task.activity_type.name + try: + activity_fn = self._registry(activity_type) + except KeyError: + raise KeyError(f"Activity type not found: {activity_type}") from None + + info = self._create_info(task) + + if inspect.iscoroutinefunction(activity_fn): + return _Context(self._client, info, activity_fn) + else: + return _SyncContext(self._client, info, activity_fn, self._thread_pool) async def _report_failure(self, task: PollForActivityTaskResponse, error: Exception): try: @@ -71,6 +69,23 @@ async def _report_success(self, task: PollForActivityTaskResponse, result: Any): except Exception: _logger.exception('Exception reporting activity complete') + def _create_info(self, task: PollForActivityTaskResponse) -> ActivityInfo: + return ActivityInfo( + task_token=task.task_token, + workflow_type=task.workflow_type.name, + workflow_domain=task.workflow_domain, + workflow_id=task.workflow_execution.workflow_id, + workflow_run_id=task.workflow_execution.run_id, + activity_id=task.activity_id, + activity_type=task.activity_type.name, + task_list=self._task_list, + heartbeat_timeout=to_timedelta(task.heartbeat_timeout), + scheduled_timestamp=to_datetime(task.scheduled_time), + started_timestamp=to_datetime(task.started_time), + start_to_close_timeout=to_timedelta(task.start_to_close_timeout), + attempt=task.attempt, + ) + def _to_failure(exception: Exception) -> Failure: stacktrace = "".join(format_exception(exception)) diff --git a/cadence/_internal/activity/_context.py b/cadence/_internal/activity/_context.py new file mode 100644 index 0000000..208b859 --- /dev/null +++ b/cadence/_internal/activity/_context.py @@ -0,0 +1,48 @@ +import asyncio +from concurrent.futures.thread import ThreadPoolExecutor +from typing import Callable, Any + +from cadence import Client +from cadence._internal.type_utils import get_fn_parameters +from cadence.activity import ActivityInfo, ActivityContext +from cadence.api.v1.common_pb2 import Payload + + +class _Context(ActivityContext): + def __init__(self, client: Client, info: ActivityInfo, activity_fn: Callable[[Any], Any]): + self._client = client + self._info = info + self._activity_fn = activity_fn + + async def execute(self, payload: Payload) -> Any: + params = await self._to_params(payload) + with self._activate(): + return await self._activity_fn(*params) + + async def _to_params(self, payload: Payload) -> list[Any]: + type_hints = get_fn_parameters(self._activity_fn) + return await self._client.data_converter.from_data(payload, type_hints) + + def client(self) -> Client: + return self._client + + def info(self) -> ActivityInfo: + return self._info + +class _SyncContext(_Context): + def __init__(self, client: Client, info: ActivityInfo, activity_fn: Callable[[Any], Any], executor: ThreadPoolExecutor): + super().__init__(client, info, activity_fn) + self._executor = executor + + async def execute(self, payload: Payload) -> Any: + params = await self._to_params(payload) + loop = asyncio.get_running_loop() + return await loop.run_in_executor(self._executor, self._run, params) + + def _run(self, args: list[Any]) -> Any: + with self._activate(): + return self._activity_fn(*args) + + def client(self) -> Client: + raise RuntimeError("client is only supported in async activities") + diff --git a/cadence/activity.py b/cadence/activity.py new file mode 100644 index 0000000..0f71fb0 --- /dev/null +++ b/cadence/activity.py @@ -0,0 +1,61 @@ +from abc import ABC, abstractmethod +from contextlib import contextmanager +from contextvars import ContextVar +from dataclasses import dataclass +from datetime import timedelta, datetime +from typing import Iterator + +from cadence import Client + + +@dataclass(frozen=True) +class ActivityInfo: + task_token: bytes + workflow_type: str + workflow_domain: str + workflow_id: str + workflow_run_id: str + activity_id: str + activity_type: str + task_list: str + heartbeat_timeout: timedelta + scheduled_timestamp: datetime + started_timestamp: datetime + start_to_close_timeout: timedelta + attempt: int + +def client() -> Client: + return ActivityContext.get().client() + +def in_activity() -> bool: + return ActivityContext.is_set() + +def info() -> ActivityInfo: + return ActivityContext.get().info() + + + +class ActivityContext(ABC): + _var: ContextVar['ActivityContext'] = ContextVar("activity") + + @abstractmethod + def info(self) -> ActivityInfo: + ... + + @abstractmethod + def client(self) -> Client: + ... + + @contextmanager + def _activate(self) -> Iterator[None]: + token = ActivityContext._var.set(self) + yield None + ActivityContext._var.reset(token) + + @staticmethod + def is_set() -> bool: + return ActivityContext._var.get(None) is not None + + @staticmethod + def get() -> 'ActivityContext': + return ActivityContext._var.get() diff --git a/tests/cadence/_internal/activity/test_activity_executor.py b/tests/cadence/_internal/activity/test_activity_executor.py index 31e29ce..89b95e2 100644 --- a/tests/cadence/_internal/activity/test_activity_executor.py +++ b/tests/cadence/_internal/activity/test_activity_executor.py @@ -1,11 +1,15 @@ import asyncio +from datetime import timedelta, datetime from unittest.mock import Mock, AsyncMock, PropertyMock import pytest +from google.protobuf.timestamp_pb2 import Timestamp +from google.protobuf.duration import from_timedelta -from cadence import Client +from cadence import activity, Client from cadence._internal.activity import ActivityExecutor -from cadence.api.v1.common_pb2 import WorkflowExecution, ActivityType, Payload, Failure +from cadence.activity import ActivityInfo +from cadence.api.v1.common_pb2 import WorkflowExecution, ActivityType, Payload, Failure, WorkflowType from cadence.api.v1.service_worker_pb2 import RespondActivityTaskCompletedResponse, PollForActivityTaskResponse, \ RespondActivityTaskCompletedRequest, RespondActivityTaskFailedResponse, RespondActivityTaskFailedRequest from cadence.data_converter import DefaultDataConverter @@ -19,7 +23,6 @@ def client() -> Client: return client -@pytest.mark.asyncio async def test_activity_async_success(client): worker_stub = client.worker_stub worker_stub.RespondActivityTaskCompleted = AsyncMock(return_value=RespondActivityTaskCompletedResponse()) @@ -37,7 +40,6 @@ async def activity_fn(): identity='identity', )) -@pytest.mark.asyncio async def test_activity_async_failure(client): worker_stub = client.worker_stub worker_stub.RespondActivityTaskFailed = AsyncMock(return_value=RespondActivityTaskFailedResponse()) @@ -64,7 +66,6 @@ async def activity_fn(): identity='identity', ) -@pytest.mark.asyncio async def test_activity_args(client): worker_stub = client.worker_stub worker_stub.RespondActivityTaskCompleted = AsyncMock(return_value=RespondActivityTaskCompletedResponse()) @@ -82,8 +83,6 @@ async def activity_fn(first: str, second: str): identity='identity', )) - -@pytest.mark.asyncio async def test_activity_sync_success(client): worker_stub = client.worker_stub worker_stub.RespondActivityTaskCompleted = AsyncMock(return_value=RespondActivityTaskCompletedResponse()) @@ -105,7 +104,6 @@ def activity_fn(): identity='identity', )) -@pytest.mark.asyncio async def test_activity_sync_failure(client): worker_stub = client.worker_stub worker_stub.RespondActivityTaskFailed = AsyncMock(return_value=RespondActivityTaskFailedResponse()) @@ -132,7 +130,6 @@ def activity_fn(): identity='identity', ) -@pytest.mark.asyncio async def test_activity_unknown(client): worker_stub = client.worker_stub worker_stub.RespondActivityTaskFailed = AsyncMock(return_value=RespondActivityTaskFailedResponse()) @@ -148,7 +145,7 @@ def registry(name: str): call = worker_stub.RespondActivityTaskFailed.call_args[0][0] - assert 'unknown activity: any' in call.failure.details.decode() + assert 'Activity type not found: any' in call.failure.details.decode() call.failure.details = bytes() assert call == RespondActivityTaskFailedRequest( task_token=b'task_token', @@ -158,9 +155,70 @@ def registry(name: str): identity='identity', ) +async def test_activity_context(client): + worker_stub = client.worker_stub + worker_stub.RespondActivityTaskCompleted = AsyncMock(return_value=RespondActivityTaskCompletedResponse()) + + async def activity_fn(): + assert fake_info("activity_type") == activity.info() + assert activity.in_activity() + assert activity.client() is not None + return "success" + + executor = ActivityExecutor(client, 'task_list', 'identity', 1, lambda name: activity_fn) + + await executor.execute(fake_task("activity_type", "")) + + worker_stub.RespondActivityTaskCompleted.assert_called_once_with(RespondActivityTaskCompletedRequest( + task_token=b'task_token', + result=Payload(data='"success"'.encode()), + identity='identity', + )) + +async def test_activity_context_sync(client): + worker_stub = client.worker_stub + worker_stub.RespondActivityTaskCompleted = AsyncMock(return_value=RespondActivityTaskCompletedResponse()) + + def activity_fn(): + assert fake_info("activity_type") == activity.info() + assert activity.in_activity() + with pytest.raises(RuntimeError): + activity.client() + return "success" + + executor = ActivityExecutor(client, 'task_list', 'identity', 1, lambda name: activity_fn) + + await executor.execute(fake_task("activity_type", "")) + + worker_stub.RespondActivityTaskCompleted.assert_called_once_with(RespondActivityTaskCompletedRequest( + task_token=b'task_token', + result=Payload(data='"success"'.encode()), + identity='identity', + )) + + +def fake_info(activity_type: str) -> ActivityInfo: + return ActivityInfo( + task_token=b'task_token', + workflow_domain="workflow_domain", + workflow_id="workflow_id", + workflow_run_id="run_id", + activity_id="activity_id", + activity_type=activity_type, + attempt=1, + workflow_type="workflow_type", + task_list="task_list", + heartbeat_timeout=timedelta(seconds=1), + scheduled_timestamp=datetime(2020, 1, 2 ,3), + started_timestamp=datetime(2020, 1, 2 ,4), + start_to_close_timeout=timedelta(seconds=2), + ) + def fake_task(activity_type: str, input_json: str) -> PollForActivityTaskResponse: return PollForActivityTaskResponse( task_token=b'task_token', + workflow_domain="workflow_domain", + workflow_type=WorkflowType(name="workflow_type"), workflow_execution=WorkflowExecution( workflow_id="workflow_id", run_id="run_id", @@ -168,5 +226,14 @@ def fake_task(activity_type: str, input_json: str) -> PollForActivityTaskRespons activity_id="activity_id", activity_type=ActivityType(name=activity_type), input=Payload(data=input_json.encode()), - attempt=0, - ) \ No newline at end of file + attempt=1, + heartbeat_timeout=from_timedelta(timedelta(seconds=1)), + scheduled_time=from_datetime(datetime(2020, 1, 2, 3)), + started_time=from_datetime(datetime(2020, 1, 2, 4)), + start_to_close_timeout=from_timedelta(timedelta(seconds=2)), + ) + +def from_datetime(time: datetime) -> Timestamp: + t = Timestamp() + t.FromDatetime(time) + return t