diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index 273581302..81468728c 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -2,6 +2,7 @@ import asyncio import concurrent.futures +import dataclasses import uuid from collections.abc import Awaitable, Callable from dataclasses import dataclass @@ -49,7 +50,17 @@ ) from temporalio.service import RPCError, RPCStatusCode from temporalio.testing import WorkflowEnvironment -from temporalio.worker import Worker +from temporalio.worker import ( + ExecuteNexusOperationCancelInput, + ExecuteNexusOperationStartInput, + Interceptor, + NexusOperationInboundInterceptor, + StartNexusOperationInput, + Worker, + WorkflowInboundInterceptor, + WorkflowInterceptorClassInput, + WorkflowOutboundInterceptor, +) from tests.helpers import find_free_port, new_worker from tests.helpers.metrics import PromMetricMatcher from tests.helpers.nexus import create_nexus_endpoint, make_nexus_endpoint_name @@ -109,6 +120,30 @@ class OpOutput: value: str +@dataclass +class HeaderTestOutput: + received_headers: dict[str, str] + + +@dataclass +class HeaderTestCallerWfInput: + headers: dict[str, str] + task_queue: str + + +@dataclass +class CancelHeaderTestCallerWfInput: + workflow_id: str + headers: dict[str, str] + task_queue: str + + +@dataclass +class WorkflowRunHeaderTestCallerWfInput: + headers: dict[str, str] + task_queue: str + + @dataclass class HandlerWfInput: op_input: OpInput @@ -126,6 +161,13 @@ class ServiceInterface: async_operation: nexusrpc.Operation[OpInput, HandlerWfOutput] +@nexusrpc.service +class HeaderTestService: + header_echo_operation: nexusrpc.Operation[None, HeaderTestOutput] + workflow_run_header_operation: nexusrpc.Operation[None, HeaderTestOutput] + cancellable_operation: nexusrpc.Operation[None, str] + + # ----------------------------------------------------------------------------- # Service implementation # @@ -218,6 +260,76 @@ async def async_operation( ) +@workflow.defn +class HeaderEchoWorkflow: + """A workflow that returns the headers it receives as input.""" + + @workflow.run + async def run(self, headers: dict[str, str]) -> HeaderTestOutput: + return HeaderTestOutput(received_headers=headers) + + +class CancellableOperationHandler(OperationHandler[None, str]): + """Operation handler that captures cancel headers.""" + + def __init__(self, cancel_headers_received: list[dict[str, str]]) -> None: + self._cancel_headers_received = cancel_headers_received + + async def start( + self, ctx: StartOperationContext, input: None + ) -> StartOperationResultAsync: + return StartOperationResultAsync("test-token") + + async def cancel(self, ctx: CancelOperationContext, token: str) -> None: + # Capture cancel headers for test verification + self._cancel_headers_received.append( + { + k: v + for k, v in ctx.headers.items() + if k.startswith("x-custom-") or k.startswith("x-interceptor-") + } + ) + + +@service_handler(service=HeaderTestService) +class HeaderTestServiceImpl: + def __init__(self) -> None: + self.cancel_headers_received: list[dict[str, str]] = [] + + @sync_operation + async def header_echo_operation( + self, ctx: StartOperationContext, _input: None + ) -> HeaderTestOutput: + # Return headers with "x-custom-" or "x-interceptor-" prefix for verification + return HeaderTestOutput( + received_headers={ + k: v + for k, v in ctx.headers.items() + if k.startswith("x-custom-") or k.startswith("x-interceptor-") + } + ) + + @workflow_run_operation + async def workflow_run_header_operation( + self, ctx: WorkflowRunOperationContext, _input: None + ) -> nexus.WorkflowHandle[HeaderTestOutput]: + # Filter headers and pass to backing workflow + filtered_headers = { + k: v + for k, v in ctx.headers.items() + if k.startswith("x-custom-") or k.startswith("x-interceptor-") + } + return await ctx.start_workflow( + HeaderEchoWorkflow.run, + filtered_headers, + id=str(uuid.uuid4()), + ) + + @operation_handler + def cancellable_operation(self) -> OperationHandler[None, str]: + return CancellableOperationHandler(self.cancel_headers_received) + + # ----------------------------------------------------------------------------- # Caller workflow # @@ -417,6 +529,59 @@ async def run( return CallerWfOutput(op_output=OpOutput(value=op_output.value)) +@workflow.defn +class HeaderTestCallerWorkflow: + @workflow.run + async def run(self, input: HeaderTestCallerWfInput) -> HeaderTestOutput: + nexus_client = workflow.create_nexus_client( + service=HeaderTestService, + endpoint=make_nexus_endpoint_name(input.task_queue), + ) + return await nexus_client.execute_operation( + HeaderTestService.header_echo_operation, + None, + headers=input.headers, + ) + + +@workflow.defn +class CancelHeaderTestCallerWorkflow: + """Workflow that starts a cancellable operation and then cancels it.""" + + @workflow.run + async def run(self, input: CancelHeaderTestCallerWfInput) -> None: + nexus_client = workflow.create_nexus_client( + service=HeaderTestService, + endpoint=make_nexus_endpoint_name(input.task_queue), + ) + op_handle = await nexus_client.start_operation( + HeaderTestService.cancellable_operation, + None, + headers=input.headers, + ) + # Request cancellation - this sends cancel headers to the handler + op_handle.cancel() + # Wait briefly to allow cancel request to be processed + await asyncio.sleep(0.1) + + +@workflow.defn +class WorkflowRunHeaderTestCallerWorkflow: + """Workflow that calls a workflow_run_operation and verifies headers.""" + + @workflow.run + async def run(self, input: WorkflowRunHeaderTestCallerWfInput) -> HeaderTestOutput: + nexus_client = workflow.create_nexus_client( + service=HeaderTestService, + endpoint=make_nexus_endpoint_name(input.task_queue), + ) + return await nexus_client.execute_operation( + HeaderTestService.workflow_run_header_operation, + None, + headers=input.headers, + ) + + # ----------------------------------------------------------------------------- # Tests # @@ -497,7 +662,241 @@ async def test_workflow_run_operation_happy_path( # TODO(nexus-preview): cross-namespace tests # TODO(nexus-preview): nexus endpoint pytest fixture? -# TODO(nexus-prerelease): test headers + + +# ----------------------------------------------------------------------------- +# Header tests +# + + +@dataclass +class HeaderModificationRecord: + original_headers: dict[str, str] + modified_headers: dict[str, str] + + +@dataclass +class CancelHeaderRecord: + original_headers: dict[str, str] + modified_headers: dict[str, str] + + +class HeaderModifyingNexusInterceptor(Interceptor): + def __init__(self) -> None: + self.header_records: list[HeaderModificationRecord] = [] + self.cancel_header_records: list[CancelHeaderRecord] = [] + + def intercept_nexus_operation( + self, next: NexusOperationInboundInterceptor + ) -> NexusOperationInboundInterceptor: + return _HeaderModifyingNexusInboundInterceptor(next, self) + + +class _HeaderModifyingNexusInboundInterceptor(NexusOperationInboundInterceptor): + def __init__( + self, + next: NexusOperationInboundInterceptor, + root: HeaderModifyingNexusInterceptor, + ): + super().__init__(next) + self._root = root + + async def execute_nexus_operation_start( + self, input: ExecuteNexusOperationStartInput + ) -> StartOperationResultSync[Any] | StartOperationResultAsync: + import dataclasses + + original_headers = dict(input.ctx.headers) + + # Modify headers: prefix values and add new header + modified_headers = { + k: f"interceptor-modified-{v}" if k.startswith("x-custom-") else v + for k, v in input.ctx.headers.items() + } + modified_headers["x-interceptor-added"] = "interceptor-value" + + self._root.header_records.append( + HeaderModificationRecord( + original_headers=original_headers, + modified_headers=modified_headers, + ) + ) + + input.ctx = dataclasses.replace(input.ctx, headers=modified_headers) + return await super().execute_nexus_operation_start(input) + + async def execute_nexus_operation_cancel( + self, input: ExecuteNexusOperationCancelInput + ) -> None: + import dataclasses + + original_headers = dict(input.ctx.headers) + + # Modify headers: prefix values and add new header + modified_headers = { + k: f"interceptor-modified-{v}" if k.startswith("x-custom-") else v + for k, v in input.ctx.headers.items() + } + modified_headers["x-interceptor-added"] = "cancel-interceptor-value" + + self._root.cancel_header_records.append( + CancelHeaderRecord( + original_headers=original_headers, + modified_headers=modified_headers, + ) + ) + + input.ctx = dataclasses.replace(input.ctx, headers=modified_headers) + return await super().execute_nexus_operation_cancel(input) + + +class HeaderAddingOutboundInterceptor(Interceptor): + """Outbound interceptor that adds a static header to Nexus operation requests.""" + + def workflow_interceptor_class( + self, input: WorkflowInterceptorClassInput + ) -> type[WorkflowInboundInterceptor] | None: + return _HeaderAddingWorkflowInboundInterceptor + + +class _HeaderAddingWorkflowInboundInterceptor(WorkflowInboundInterceptor): + def init(self, outbound: WorkflowOutboundInterceptor) -> None: + super().init(_HeaderAddingWorkflowOutboundInterceptor(outbound)) + + +class _HeaderAddingWorkflowOutboundInterceptor(WorkflowOutboundInterceptor): + async def start_nexus_operation( + self, input: StartNexusOperationInput + ) -> workflow.NexusOperationHandle: + existing_headers = dict(input.headers) if input.headers else {} + existing_headers["x-custom-outbound"] = "outbound-value" + input = dataclasses.replace(input, headers=existing_headers) + return await super().start_nexus_operation(input) + + +async def test_start_operation_headers( + client: Client, + env: WorkflowEnvironment, +): + """Test headers from workflow and interceptors are propagated to start operation handler.""" + if env.supports_time_skipping: + pytest.skip("Nexus tests don't work with time-skipping server") + + task_queue = str(uuid.uuid4()) + inbound_interceptor = HeaderModifyingNexusInterceptor() + + async with Worker( + client, + nexus_service_handlers=[HeaderTestServiceImpl()], + workflows=[HeaderTestCallerWorkflow], + task_queue=task_queue, + interceptors=[HeaderAddingOutboundInterceptor(), inbound_interceptor], + ): + await create_nexus_endpoint(task_queue, client) + + workflow_headers = {"x-custom-from-workflow": "workflow-value"} + result = await client.execute_workflow( + HeaderTestCallerWorkflow.run, + HeaderTestCallerWfInput(headers=workflow_headers, task_queue=task_queue), + id=str(uuid.uuid4()), + task_queue=task_queue, + ) + + # Verify inbound interceptor saw headers from workflow and outbound interceptor + assert len(inbound_interceptor.header_records) == 1 + record = inbound_interceptor.header_records[0] + assert record.original_headers.get("x-custom-from-workflow") == "workflow-value" + assert record.original_headers.get("x-custom-outbound") == "outbound-value" + + # Verify handler received headers modified by inbound interceptor + assert ( + result.received_headers.get("x-custom-from-workflow") + == "interceptor-modified-workflow-value" + ) + assert ( + result.received_headers.get("x-custom-outbound") + == "interceptor-modified-outbound-value" + ) + assert result.received_headers.get("x-interceptor-added") == "interceptor-value" + + +async def test_workflow_run_operation_headers( + client: Client, + env: WorkflowEnvironment, +): + """Test that headers are propagated to @workflow_run_operation handlers.""" + if env.supports_time_skipping: + pytest.skip("Nexus tests don't work with time-skipping server") + + task_queue = str(uuid.uuid4()) + test_headers = {"x-custom-workflow-run": "workflow-run-value"} + + async with Worker( + client, + nexus_service_handlers=[HeaderTestServiceImpl()], + workflows=[WorkflowRunHeaderTestCallerWorkflow, HeaderEchoWorkflow], + task_queue=task_queue, + ): + await create_nexus_endpoint(task_queue, client) + result = await client.execute_workflow( + WorkflowRunHeaderTestCallerWorkflow.run, + WorkflowRunHeaderTestCallerWfInput( + headers=test_headers, task_queue=task_queue + ), + id=str(uuid.uuid4()), + task_queue=task_queue, + ) + assert ( + result.received_headers.get("x-custom-workflow-run") == "workflow-run-value" + ) + + +async def test_cancel_operation_headers( + client: Client, + env: WorkflowEnvironment, +): + """Test headers from workflow and interceptor are propagated to cancel operation handler.""" + if env.supports_time_skipping: + pytest.skip("Nexus tests don't work with time-skipping server") + + task_queue = str(uuid.uuid4()) + workflow_id = str(uuid.uuid4()) + inbound_interceptor = HeaderModifyingNexusInterceptor() + service_handler = HeaderTestServiceImpl() + + async with Worker( + client, + nexus_service_handlers=[service_handler], + workflows=[CancelHeaderTestCallerWorkflow], + task_queue=task_queue, + interceptors=[inbound_interceptor], + ): + await create_nexus_endpoint(task_queue, client) + + workflow_headers = {"x-custom-cancel": "cancel-value"} + await client.execute_workflow( + CancelHeaderTestCallerWorkflow.run, + CancelHeaderTestCallerWfInput( + workflow_id=workflow_id, + headers=workflow_headers, + task_queue=task_queue, + ), + id=str(uuid.uuid4()), + task_queue=task_queue, + ) + + # Verify inbound interceptor saw cancel headers from workflow + assert len(inbound_interceptor.cancel_header_records) == 1 + record = inbound_interceptor.cancel_header_records[0] + assert record.original_headers.get("x-custom-cancel") == "cancel-value" + + # Verify handler received headers modified by inbound interceptor + assert len(service_handler.cancel_headers_received) == 1 + received = service_handler.cancel_headers_received[0] + assert received.get("x-custom-cancel") == "interceptor-modified-cancel-value" + assert received.get("x-interceptor-added") == "cancel-interceptor-value" + + @pytest.mark.parametrize("exception_in_operation_start", [False, True]) @pytest.mark.parametrize("request_cancel", [False, True]) @pytest.mark.parametrize(