Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions cadence/_internal/activity/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@


from ._activity_executor import (
ActivityExecutor
)

__all__ = [
"ActivityExecutor",
]
81 changes: 81 additions & 0 deletions cadence/_internal/activity/_activity_executor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import asyncio
import inspect
from concurrent.futures import ThreadPoolExecutor
from logging import getLogger
from traceback import format_exception
from typing import Any, Callable

from cadence._internal.type_utils import get_fn_parameters
from cadence.api.v1.common_pb2 import Failure
from cadence.api.v1.service_worker_pb2 import PollForActivityTaskResponse, RespondActivityTaskFailedRequest, \
RespondActivityTaskCompletedRequest
from cadence.client import Client

_logger = getLogger(__name__)

class ActivityExecutor:
def __init__(self, client: Client, task_list: str, identity: str, max_workers: int, registry: Callable[[str], Callable]):
self._client = client
self._data_converter = client.data_converter
self._registry = registry
self._identity = identity
self._thread_pool = ThreadPoolExecutor(max_workers=max_workers,
thread_name_prefix=f'{task_list}-activity-')

async def execute(self, task: PollForActivityTaskResponse):
activity_type = task.activity_type.name
try:
activity_fn = self._registry(activity_type)
except KeyError as e:
_logger.error("Activity type not found.", extra={'activity_type': activity_type})
await self._report_failure(task, e)
return

await self._execute_fn(activity_fn, task)

async def _execute_fn(self, activity_fn: Callable[[Any], Any], task: PollForActivityTaskResponse):
try:
type_hints = get_fn_parameters(activity_fn)
params = await self._client.data_converter.from_data(task.input, type_hints)
if inspect.iscoroutinefunction(activity_fn):
result = await activity_fn(*params)
else:
result = await self._invoke_sync_activity(activity_fn, params)
await self._report_success(task, result)
except Exception as e:
await self._report_failure(task, e)

async def _invoke_sync_activity(self, activity_fn: Callable[[Any], Any], params: list[Any]) -> Any:
loop = asyncio.get_running_loop()
return await loop.run_in_executor(self._thread_pool, activity_fn, *params)

async def _report_failure(self, task: PollForActivityTaskResponse, error: Exception):
try:
await self._client.worker_stub.RespondActivityTaskFailed(RespondActivityTaskFailedRequest(
task_token=task.task_token,
failure=_to_failure(error),
identity=self._identity,
))
except Exception:
_logger.exception('Exception reporting activity failure')

async def _report_success(self, task: PollForActivityTaskResponse, result: Any):
as_payload = await self._data_converter.to_data(result)

try:
await self._client.worker_stub.RespondActivityTaskCompleted(RespondActivityTaskCompletedRequest(
task_token=task.task_token,
result=as_payload,
identity=self._identity,
))
except Exception:
_logger.exception('Exception reporting activity complete')


def _to_failure(exception: Exception) -> Failure:
stacktrace = "".join(format_exception(exception))

return Failure(
reason=type(exception).__name__,
details=stacktrace.encode(),
)
19 changes: 19 additions & 0 deletions cadence/_internal/type_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from inspect import signature, Parameter
from typing import Callable, List, Type, get_type_hints

def get_fn_parameters(fn: Callable) -> List[Type | None]:
args = signature(fn).parameters
hints = get_type_hints(fn)
result = []
for name, param in args.items():
if param.kind in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD):
type_hint = hints.get(name, None)
result.append(type_hint)

return result

def validate_fn_parameters(fn: Callable) -> None:
args = signature(fn).parameters
for name, param in args.items():
if param.kind not in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD):
raise ValueError(f"Parameters must be positional. {name} is {param.kind}, and not valid")
6 changes: 6 additions & 0 deletions cadence/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@
from cadence.api.v1.service_worker_pb2_grpc import WorkerAPIStub
from grpc.aio import Channel

from cadence.data_converter import DataConverter


class ClientOptions(TypedDict, total=False):
domain: str
identity: str
data_converter: DataConverter

class Client:
def __init__(self, channel: Channel, options: ClientOptions) -> None:
Expand All @@ -17,6 +20,9 @@ def __init__(self, channel: Channel, options: ClientOptions) -> None:
self._options = options
self._identity = options["identity"] if "identity" in options else f"{os.getpid()}@{socket.gethostname()}"

@property
def data_converter(self) -> DataConverter:
return self._options["data_converter"]

@property
def domain(self) -> str:
Expand Down
15 changes: 10 additions & 5 deletions cadence/data_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
class DataConverter(Protocol):

@abstractmethod
async def from_data(self, payload: Payload, type_hints: List[Type]) -> List[Any]:
async def from_data(self, payload: Payload, type_hints: List[Type | None]) -> List[Any]:
raise NotImplementedError()

@abstractmethod
Expand All @@ -23,7 +23,10 @@ def __init__(self) -> None:
self._fallback_decoder = JSONDecoder(strict=False)


async def from_data(self, payload: Payload, type_hints: List[Type]) -> List[Any]:
async def from_data(self, payload: Payload, type_hints: List[Type | None]) -> List[Any]:
if not payload.data:
return DefaultDataConverter._convert_into([], type_hints)

if len(type_hints) > 1:
payload_str = payload.data.decode()
# Handle payloads from the Go client, which are a series of json objects rather than a json array
Expand All @@ -37,7 +40,7 @@ async def from_data(self, payload: Payload, type_hints: List[Type]) -> List[Any]
return DefaultDataConverter._convert_into([as_value], type_hints)


def _decode_whitespace_delimited(self, payload: str, type_hints: List[Type]) -> List[Any]:
def _decode_whitespace_delimited(self, payload: str, type_hints: List[Type | None]) -> List[Any]:
results: List[Any] = []
start, end = 0, len(payload)
while start < end and len(results) < len(type_hints):
Expand All @@ -49,10 +52,12 @@ def _decode_whitespace_delimited(self, payload: str, type_hints: List[Type]) ->
return DefaultDataConverter._convert_into(results, type_hints)

@staticmethod
def _convert_into(values: List[Any], type_hints: List[Type]) -> List[Any]:
def _convert_into(values: List[Any], type_hints: List[Type | None]) -> List[Any]:
results: List[Any] = []
for i, type_hint in enumerate(type_hints):
if i < len(values):
if not type_hint:
value = values[i]
elif i < len(values):
value = convert(values[i], type_hint)
else:
value = DefaultDataConverter._get_default(type_hint)
Expand Down
4 changes: 2 additions & 2 deletions cadence/sample/client_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from cadence.client import Client, ClientOptions
from cadence._internal.rpc.metadata import MetadataInterceptor
from cadence.worker import Worker
from cadence.worker import Worker, Registry


async def main():
Expand All @@ -15,7 +15,7 @@ async def main():
metadata["rpc-caller"] = "nate"
async with insecure_channel("localhost:7833", interceptors=[MetadataInterceptor(metadata)]) as channel:
client = Client(channel, ClientOptions(domain="foo"))
worker = Worker(client, "task_list")
worker = Worker(client, "task_list", Registry())
await worker.run()

if __name__ == '__main__':
Expand Down
18 changes: 8 additions & 10 deletions cadence/worker/_activity.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,23 @@
import asyncio
from typing import Optional

from cadence.api.v1.common_pb2 import Failure
from cadence.api.v1.service_worker_pb2 import PollForActivityTaskResponse, PollForActivityTaskRequest, \
RespondActivityTaskFailedRequest
from cadence._internal.activity import ActivityExecutor
from cadence.api.v1.service_worker_pb2 import PollForActivityTaskResponse, PollForActivityTaskRequest
from cadence.api.v1.tasklist_pb2 import TaskList, TaskListKind
from cadence.client import Client
from cadence.worker._registry import Registry
from cadence.worker._types import WorkerOptions, _LONG_POLL_TIMEOUT
from cadence.worker._poller import Poller


class ActivityWorker:
def __init__(self, client: Client, task_list: str, options: WorkerOptions) -> None:
def __init__(self, client: Client, task_list: str, registry: Registry, options: WorkerOptions) -> None:
self._client = client
self._task_list = task_list
self._identity = options["identity"]
permits = asyncio.Semaphore(options["max_concurrent_activity_execution_size"])
max_concurrent = options["max_concurrent_activity_execution_size"]
permits = asyncio.Semaphore(max_concurrent)
self._executor = ActivityExecutor(self._client, self._task_list, options["identity"], max_concurrent, registry.get_activity)
self._poller = Poller[PollForActivityTaskResponse](options["activity_task_pollers"], permits, self._poll, self._execute)
# TODO: Local dispatch, local activities, actually running activities, etc

Expand All @@ -35,9 +37,5 @@ async def _poll(self) -> Optional[PollForActivityTaskResponse]:
return None

async def _execute(self, task: PollForActivityTaskResponse) -> None:
await self._client.worker_stub.RespondActivityTaskFailed(RespondActivityTaskFailedRequest(
task_token=task.task_token,
identity=self._identity,
failure=Failure(reason='error', details=b'not implemented'),
))
await self._executor.execute(task)

2 changes: 2 additions & 0 deletions cadence/worker/_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import logging
from typing import Callable, Dict, Optional, Unpack, TypedDict
from cadence._internal.type_utils import validate_fn_parameters


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -107,6 +108,7 @@ def activity(
options = RegisterActivityOptions(**kwargs)

def decorator(f: Callable) -> Callable:
validate_fn_parameters(f)
activity_name = options.get('name') or f.__name__

if activity_name in self._activities:
Expand Down
5 changes: 3 additions & 2 deletions cadence/worker/_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,22 @@
from typing import Unpack, cast

from cadence.client import Client
from cadence.worker._registry import Registry
from cadence.worker._activity import ActivityWorker
from cadence.worker._decision import DecisionWorker
from cadence.worker._types import WorkerOptions, _DEFAULT_WORKER_OPTIONS


class Worker:

def __init__(self, client: Client, task_list: str, **kwargs: Unpack[WorkerOptions]) -> None:
def __init__(self, client: Client, task_list: str, registry: Registry, **kwargs: Unpack[WorkerOptions]) -> None:
self._client = client
self._task_list = task_list

options = WorkerOptions(**kwargs)
_validate_and_copy_defaults(client, task_list, options)
self._options = options
self._activity_worker = ActivityWorker(client, task_list, options)
self._activity_worker = ActivityWorker(client, task_list, registry, options)
self._decision_worker = DecisionWorker(client, task_list, options)


Expand Down
Loading