diff --git a/gooddata-flexconnect/gooddata_flexconnect/function/flight_methods.py b/gooddata-flexconnect/gooddata_flexconnect/function/flight_methods.py index 67e11de9c..863f992b3 100644 --- a/gooddata-flexconnect/gooddata_flexconnect/function/flight_methods.py +++ b/gooddata-flexconnect/gooddata_flexconnect/function/flight_methods.py @@ -1,4 +1,5 @@ # (C) 2024 GoodData Corporation +import time from collections.abc import Generator from typing import Optional @@ -17,17 +18,45 @@ ) from gooddata_flexconnect.function.function import FlexConnectFunction +from gooddata_flexconnect.function.function_invocation import ( + CancelInvocation, + RetryInvocation, + SubmitInvocation, + extract_pollable_invocation_from_descriptor, + extract_submit_invocation_from_descriptor, +) from gooddata_flexconnect.function.function_registry import FlexConnectFunctionRegistry from gooddata_flexconnect.function.function_task import FlexConnectFunctionTask _LOGGER = structlog.get_logger("gooddata_flexconnect.rpc") +POLLING_HEADER_NAME = "x-quiver-pollable" +""" +If this header is present on the get flight info call, the polling extension will be used. +Otherwise the basic do get will be used. +""" + + +def _prepare_poll_error(task_id: str) -> pyarrow.flight.FlightError: + return ErrorInfo.poll( + flight_info=None, + cancel_descriptor=pyarrow.flight.FlightDescriptor.for_command(f"c:{task_id}".encode()), + retry_descriptor=pyarrow.flight.FlightDescriptor.for_command(f"r:{task_id}".encode()), + ) + class _FlexConnectServerMethods(FlightServerMethods): - def __init__(self, ctx: ServerContext, registry: FlexConnectFunctionRegistry, call_deadline_ms: float) -> None: + def __init__( + self, + ctx: ServerContext, + registry: FlexConnectFunctionRegistry, + call_deadline_ms: float, + poll_interval_ms: float, + ) -> None: self._ctx = ctx self._registry = registry self._call_deadline = call_deadline_ms / 1000 + self._poll_interval = poll_interval_ms / 1000 @staticmethod def _create_descriptor(fun_name: str, metadata: Optional[dict]) -> pyarrow.flight.FlightDescriptor: @@ -52,58 +81,37 @@ def _create_fun_info(self, fun: type[FlexConnectFunction]) -> pyarrow.flight.Fli total_records=-1, ) - def _extract_invocation_payload( - self, descriptor: pyarrow.flight.FlightDescriptor - ) -> tuple[str, dict, Optional[tuple[str, ...]]]: - if descriptor.command is None or not len(descriptor.command): - raise ErrorInfo.bad_argument( - "Incorrect FlexConnect function invocation. Flight descriptor must contain command " - "with the invocation payload." - ) - - try: - payload = orjson.loads(descriptor.command) - except Exception: - raise ErrorInfo.bad_argument( - "Incorrect FlexConnect function invocation. The invocation payload is not a valid JSON." - ) - - fun = payload.get("functionName") - if fun is None or not len(fun): - raise ErrorInfo.bad_argument( - "Incorrect FlexConnect function invocation. The invocation payload does not specify 'functionName'." - ) - - parameters = payload.get("parameters") or {} - columns = parameters.get("columns") - - return fun, parameters, columns - def _prepare_task( self, context: pyarrow.flight.ServerCallContext, - descriptor: pyarrow.flight.FlightDescriptor, + submit_invocation: SubmitInvocation, ) -> FlexConnectFunctionTask: - fun_name, parameters, columns = self._extract_invocation_payload(descriptor) headers = self.call_info_middleware(context).headers - fun = self._registry.create_function(fun_name) + fun = self._registry.create_function(submit_invocation.function_name) return FlexConnectFunctionTask( fun=fun, - parameters=parameters, - columns=columns, + parameters=submit_invocation.parameters, + columns=submit_invocation.columns, headers=headers, - cmd=descriptor.command, + cmd=submit_invocation.command, ) - def _prepare_flight_info(self, task_result: TaskExecutionResult) -> pyarrow.flight.FlightInfo: + def _prepare_flight_info( + self, task_id: str, task_result: Optional[TaskExecutionResult] + ) -> pyarrow.flight.FlightInfo: + if task_result is None: + raise ErrorInfo.for_reason( + ErrorCode.BAD_ARGUMENT, f"Task with id '{task_id}' does not exist." + ).to_user_error() + if task_result.error is not None: raise task_result.error.as_flight_error() if task_result.cancelled: raise ErrorInfo.for_reason( ErrorCode.COMMAND_CANCELLED, - f"FlexConnect function invocation was cancelled. Invocation task was: '{task_result.task_id}'.", + f"FlexConnect function invocation was cancelled. Invocation task was: '{task_id}'.", ).to_server_error() result = task_result.result @@ -114,7 +122,7 @@ def _prepare_flight_info(self, task_result: TaskExecutionResult) -> pyarrow.flig descriptor=pyarrow.flight.FlightDescriptor.for_command(task_result.cmd), endpoints=[ pyarrow.flight.FlightEndpoint( - ticket=pyarrow.flight.Ticket(ticket=orjson.dumps({"task_id": task_result.task_id})), + ticket=pyarrow.flight.Ticket(ticket=orjson.dumps({"task_id": task_id})), locations=[self._ctx.location], ) ], @@ -122,32 +130,25 @@ def _prepare_flight_info(self, task_result: TaskExecutionResult) -> pyarrow.flig total_bytes=-1, ) - ################################################################### - # Implementation of Flight RPC methods - ################################################################### - - def list_flights( - self, context: pyarrow.flight.ServerCallContext, criteria: bytes - ) -> Generator[pyarrow.flight.FlightInfo, None, None]: - structlog.contextvars.bind_contextvars(peer=context.peer()) - _LOGGER.info("list_flights", available_funs=self._registry.function_names) - - return (self._create_fun_info(fun) for fun in self._registry.functions.values()) - - def get_flight_info( + def _get_flight_info_no_polling( self, context: pyarrow.flight.ServerCallContext, descriptor: pyarrow.flight.FlightDescriptor, ) -> pyarrow.flight.FlightInfo: + """ + Basic DoGetInfo flow with no polling extension. + This conforms to the mainline Arrow Flight RPC specification. + """ structlog.contextvars.bind_contextvars(peer=context.peer()) + invocation = extract_submit_invocation_from_descriptor(descriptor) + task: Optional[FlexConnectFunctionTask] = None try: - task = self._prepare_task(context, descriptor) + task = self._prepare_task(context, invocation) self._ctx.task_executor.submit(task) try: - # XXX: this should be enhanced to implement polling task_result = self._ctx.task_executor.wait_for_result(task.task_id, self._call_deadline) except TaskWaitTimeoutError: cancelled = self._ctx.task_executor.cancel(task.task_id) @@ -166,15 +167,106 @@ def get_flight_info( # particular task id finished assert task_result is not None - return self._prepare_flight_info(task_result) + return self._prepare_flight_info(task_id=task.task_id, task_result=task_result) except Exception: if task is not None: - _LOGGER.error("get_flight_info_failed", task_id=task.task_id, fun=task.fun_name, exc_info=True) + _LOGGER.error( + "get_flight_info_failed", task_id=task.task_id, fun=task.fun_name, exc_info=True, polling=False + ) else: - _LOGGER.error("flexconnect_fun_submit_failed", exc_info=True) + _LOGGER.error("flexconnect_fun_submit_failed", exc_info=True, polling=False) + raise + + def _get_flight_info_polling( + self, + context: pyarrow.flight.ServerCallContext, + descriptor: pyarrow.flight.FlightDescriptor, + ) -> pyarrow.flight.FlightInfo: + """ + DoGetInfo flow with polling extension. + This extends the mainline Arrow Flight RPC specification with polling capabilities using the RetryInfo + encoded into the FlightTimedOutError.extra_info. + Ideally, we would use the mainline PollFlightInfo, but that has yet to be implemented in the PyArrow library. + """ + structlog.contextvars.bind_contextvars(peer=context.peer()) + invocation = extract_pollable_invocation_from_descriptor(descriptor) + + task_id: str + fun_name: Optional[str] = None + + if isinstance(invocation, CancelInvocation): + # cancel the given task and raise cancellation exception + if self._ctx.task_executor.cancel(invocation.task_id): + raise ErrorInfo.for_reason( + ErrorCode.COMMAND_CANCELLED, "FlexConnect function invocation was cancelled." + ).to_cancelled_error() + raise ErrorInfo.for_reason( + ErrorCode.COMMAND_CANCEL_NOT_POSSIBLE, "FlexConnect function invocation could not be cancelled." + ).to_cancelled_error() + elif isinstance(invocation, RetryInvocation): + # retry descriptor: extract the task_id, do not submit it again and do one polling iteration + task_id = invocation.task_id + elif isinstance(invocation, SubmitInvocation): + # basic first-time submit: submit the task and do one polling iteration. + # do not check call deadline to give it a chance to wait for the result at least once + try: + task = self._prepare_task(context, invocation) + self._ctx.task_executor.submit(task) + task_id = task.task_id + fun_name = task.fun_name + except Exception: + _LOGGER.error("flexconnect_fun_submit_failed", exc_info=True, polling=True) + raise + else: + # can be replaced by assert_never when we are on 3.11 + raise AssertionError + + try: + task_result = self._ctx.task_executor.wait_for_result(task_id, timeout=self._poll_interval) + return self._prepare_flight_info(task_id, task_result) + except TimeoutError: + # first, check the call deadline for the whole call duration + task_timestamp = self._ctx.task_executor.get_task_submitted_timestamp(task_id) + if task_timestamp is not None and time.perf_counter() - task_timestamp > self._call_deadline: + self._ctx.task_executor.cancel(task_id) + raise ErrorInfo.for_reason( + ErrorCode.TIMEOUT, f"GetFlightInfo timed out while waiting for task {task_id}." + ).to_timeout_error() + # if the result is not ready, and we still have time, indicate to the client + # how to poll for the results + raise _prepare_poll_error(task_id) + except Exception: + _LOGGER.error("get_flight_info_failed", task_id=task_id, fun=fun_name, exc_info=True, polling=True) raise + ################################################################### + # Implementation of Flight RPC methods + ################################################################### + + def list_flights( + self, context: pyarrow.flight.ServerCallContext, criteria: bytes + ) -> Generator[pyarrow.flight.FlightInfo, None, None]: + structlog.contextvars.bind_contextvars(peer=context.peer()) + _LOGGER.info("list_flights", available_funs=self._registry.function_names) + + return (self._create_fun_info(fun) for fun in self._registry.functions.values()) + + def get_flight_info( + self, + context: pyarrow.flight.ServerCallContext, + descriptor: pyarrow.flight.FlightDescriptor, + ) -> pyarrow.flight.FlightInfo: + structlog.contextvars.bind_contextvars(peer=context.peer()) + + headers = self.call_info_middleware(context).headers + allow_polling = headers.get(POLLING_HEADER_NAME) is not None + + if allow_polling: + return self._get_flight_info_polling(context, descriptor) + else: + return self._get_flight_info_no_polling(context, descriptor) + def do_get( self, context: pyarrow.flight.ServerCallContext, @@ -201,7 +293,9 @@ def do_get( _FLEX_CONNECT_CONFIG_SECTION = "flexconnect" _FLEX_CONNECT_FUNCTION_LIST = "functions" _FLEX_CONNECT_CALL_DEADLINE_MS = "call_deadline_ms" +_FLEX_CONNECT_POLLING_INTERVAL_MS = "polling_interval_ms" _DEFAULT_FLEX_CONNECT_CALL_DEADLINE_MS = 180_000 +_DEFAULT_FLEX_CONNECT_POLLING_INTERVAL_MS = 2000 def _read_call_deadline_ms(ctx: ServerContext) -> int: @@ -223,6 +317,24 @@ def _read_call_deadline_ms(ctx: ServerContext) -> int: ) +def _read_polling_interval_ms(ctx: ServerContext) -> int: + polling_interval = ctx.settings.get(f"{_FLEX_CONNECT_CONFIG_SECTION}.{_FLEX_CONNECT_POLLING_INTERVAL_MS}") + if polling_interval is None: + return _DEFAULT_FLEX_CONNECT_POLLING_INTERVAL_MS + + try: + polling_interval = int(polling_interval) + if polling_interval <= 0: + raise ValueError() + return polling_interval + except ValueError: + raise ValueError( + f"Value of {_FLEX_CONNECT_CONFIG_SECTION}.{_FLEX_CONNECT_POLLING_INTERVAL_MS} must " + f"be a positive number - duration, in milliseconds, that FlexConnect function " + f"waits for the result during one polling iteration." + ) + + @flight_server_methods def create_flexconnect_flight_methods(ctx: ServerContext) -> FlightServerMethods: """ @@ -236,8 +348,9 @@ def create_flexconnect_flight_methods(ctx: ServerContext) -> FlightServerMethods """ modules = list(ctx.settings.get(f"{_FLEX_CONNECT_CONFIG_SECTION}.{_FLEX_CONNECT_FUNCTION_LIST}") or []) call_deadline_ms = _read_call_deadline_ms(ctx) + polling_interval_ms = _read_polling_interval_ms(ctx) _LOGGER.info("flexconnect_init", modules=modules) registry = FlexConnectFunctionRegistry().load(ctx, modules) - return _FlexConnectServerMethods(ctx, registry, call_deadline_ms) + return _FlexConnectServerMethods(ctx, registry, call_deadline_ms, polling_interval_ms) diff --git a/gooddata-flexconnect/gooddata_flexconnect/function/function_invocation.py b/gooddata-flexconnect/gooddata_flexconnect/function/function_invocation.py new file mode 100644 index 000000000..8647bf0ac --- /dev/null +++ b/gooddata-flexconnect/gooddata_flexconnect/function/function_invocation.py @@ -0,0 +1,103 @@ +# (C) 2025 GoodData Corporation +from dataclasses import dataclass +from typing import Optional, Union + +import orjson +import pyarrow.flight +from gooddata_flight_server import ErrorInfo + + +@dataclass(frozen=True) +class RetryInvocation: + """ + Indicates that the getting the results of the given task should be retried. + """ + + task_id: str + + +@dataclass(frozen=True) +class CancelInvocation: + """ + Indicates that the given task should be cancelled. + """ + + task_id: str + + +@dataclass(frozen=True) +class SubmitInvocation: + """ + Indicates that the given task should be submitted for processing. + """ + + command: bytes + """ + The raw command that was sent to the Flight Server. + """ + + function_name: str + """ + The name of the FlexConnect function to invoke. + """ + + parameters: dict + """ + Parameters to pass to the FlexConnect function. + """ + + columns: Optional[tuple[str, ...]] + """ + Columns to get from the FlexConnect function result. + This may be used for column trimming by the function: the function must return at least those columns. + """ + + +def extract_submit_invocation_from_descriptor(descriptor: pyarrow.flight.FlightDescriptor) -> SubmitInvocation: + """ + Given a flight descriptor, extract the invocation information from it. + Do not allow the polling-related variants. + """ + try: + payload = orjson.loads(descriptor.command) + except Exception: + raise ErrorInfo.bad_argument( + "Incorrect FlexConnect function invocation. The invocation payload is not a valid JSON." + ) + + function_name = payload.get("functionName") + if function_name is None or not len(function_name): + raise ErrorInfo.bad_argument( + "Incorrect FlexConnect function invocation. The invocation payload does not specify 'functionName'." + ) + + parameters = payload.get("parameters") or {} + columns = parameters.get("columns") + + return SubmitInvocation( + function_name=function_name, parameters=parameters, columns=columns, command=descriptor.command + ) + + +def extract_pollable_invocation_from_descriptor( + descriptor: pyarrow.flight.FlightDescriptor, +) -> Union[RetryInvocation, CancelInvocation, SubmitInvocation]: + """ + Given a flight descriptor, extract the invocation information from it. + Allow also the polling-related variants. + """ + if descriptor.command is None or not len(descriptor.command): + raise ErrorInfo.bad_argument( + "Incorrect FlexConnect function invocation. Flight descriptor must contain command " + "with the invocation payload." + ) + + # we are in the polling-enabled realm: try parsing the retry and cancel descriptors first + if descriptor.command.startswith(b"c:"): + task_id = descriptor.command[2:].decode() + return CancelInvocation(task_id) + elif descriptor.command.startswith(b"r:"): + task_id = descriptor.command[2:].decode() + return RetryInvocation(task_id) + + return extract_submit_invocation_from_descriptor(descriptor) diff --git a/gooddata-flexconnect/tests/server/conftest.py b/gooddata-flexconnect/tests/server/conftest.py index 8e87f3ed6..a84cbe9c0 100644 --- a/gooddata-flexconnect/tests/server/conftest.py +++ b/gooddata-flexconnect/tests/server/conftest.py @@ -80,7 +80,8 @@ def flexconnect_server( funs = f"[{funs}]" os.environ["GOODDATA_FLIGHT_FLEXCONNECT__FUNCTIONS"] = funs - os.environ["GOODDATA_FLIGHT_FLEXCONNECT__CALL_DEADLINE_MS"] = "500" + os.environ["GOODDATA_FLIGHT_FLEXCONNECT__CALL_DEADLINE_MS"] = "1200" + os.environ["GOODDATA_FLIGHT_FLEXCONNECT__POLLING_INTERVAL_MS"] = "500" with server(create_flexconnect_flight_methods, tls, mtls) as s: yield s diff --git a/gooddata-flexconnect/tests/server/funs/fun3.py b/gooddata-flexconnect/tests/server/funs/fun3.py index 42ac365ef..6ec64496b 100644 --- a/gooddata-flexconnect/tests/server/funs/fun3.py +++ b/gooddata-flexconnect/tests/server/funs/fun3.py @@ -27,7 +27,7 @@ def call( ) -> ArrowData: # sleep is intentionally setup to be longer than the deadline for # the function invocation (see conftest.py // flexconnect_server fixture) - time.sleep(1) + time.sleep(2) return pyarrow.table( data={ diff --git a/gooddata-flexconnect/tests/server/funs/fun4.py b/gooddata-flexconnect/tests/server/funs/fun4.py new file mode 100644 index 000000000..b42faae77 --- /dev/null +++ b/gooddata-flexconnect/tests/server/funs/fun4.py @@ -0,0 +1,39 @@ +# (C) 2024 GoodData Corporation +import time +from typing import Optional + +import pyarrow +from gooddata_flexconnect.function.function import FlexConnectFunction +from gooddata_flight_server import ArrowData + +_DATA: Optional[pyarrow.Table] = None + + +class _PollableFun(FlexConnectFunction): + Name = "PollableFun" + Schema = pyarrow.schema( + fields=[ + pyarrow.field("col1", pyarrow.int64()), + pyarrow.field("col2", pyarrow.string()), + pyarrow.field("col3", pyarrow.bool_()), + ] + ) + + def call( + self, + parameters: dict, + columns: tuple[str, ...], + headers: dict[str, list[str]], + ) -> ArrowData: + # sleep is intentionally setup to be longer than one polling interval + # (see conftest.py // flexconnect_server fixture) + time.sleep(0.7) + + return pyarrow.table( + data={ + "col1": [1, 2, 3], + "col2": ["a", "b", "c"], + "col3": [True, False, True], + }, + schema=self.Schema, + ) diff --git a/gooddata-flexconnect/tests/server/test_flexconnect_server.py b/gooddata-flexconnect/tests/server/test_flexconnect_server.py index 42fc3ffdb..831370adf 100644 --- a/gooddata-flexconnect/tests/server/test_flexconnect_server.py +++ b/gooddata-flexconnect/tests/server/test_flexconnect_server.py @@ -1,14 +1,24 @@ # (C) 2024 GoodData Corporation + import orjson import pyarrow.flight import pytest -from gooddata_flight_server import ErrorCode +from gooddata_flexconnect.function.flight_methods import POLLING_HEADER_NAME +from gooddata_flight_server import ErrorCode, ErrorInfo, RetryInfo from tests.assert_error_info import assert_error_code from tests.server.conftest import flexconnect_server +@pytest.fixture +def call_options_with_polling(): + return pyarrow.flight.FlightCallOptions(headers=[(POLLING_HEADER_NAME.encode(), b"true")]) + + def test_basic_function(): + """ + This function should return immediately when called, no polling allowed. + """ with flexconnect_server(["tests.server.funs.fun1"]) as s: c = pyarrow.flight.FlightClient(s.location) fun_infos = list(c.list_flights()) @@ -89,7 +99,110 @@ def test_basic_function_tls(tls_ca_cert): assert data.column_names == ["col1", "col2", "col3"] -def test_function_with_call_deadline(): +def test_cancellable_function(call_options_with_polling): + """ + This function should return immediately when called, no polling necessary even if enabled. + """ + with flexconnect_server(["tests.server.funs.fun1"]) as s: + c = pyarrow.flight.FlightClient(s.location) + fun_infos = list(c.list_flights()) + assert len(fun_infos) == 1 + fun_info: pyarrow.flight.FlightInfo = fun_infos[0] + + assert fun_info.schema.names == ["col1", "col2", "col3"] + assert fun_info.descriptor.command is not None + assert len(fun_info.descriptor.command) + cmd = orjson.loads(fun_info.descriptor.command) + assert cmd["functionName"] == "SimpleFun1" + + descriptor = pyarrow.flight.FlightDescriptor.for_command( + orjson.dumps( + { + "functionName": "SimpleFun1", + "parameters": {"test1": 1, "test2": 2, "test3": 3}, + } + ) + ) + info = c.get_flight_info(descriptor, call_options_with_polling) + data: pyarrow.Table = c.do_get(info.endpoints[0].ticket).read_all() + + assert len(data) == 3 + assert data.column_names == ["col1", "col2", "col3"] + + +def test_cancellable_function_tls(tls_ca_cert, call_options_with_polling): + with flexconnect_server(["tests.server.funs.fun1"], tls=True) as s: + c = pyarrow.flight.FlightClient(s.location, tls_root_certs=tls_ca_cert) + fun_infos = list(c.list_flights()) + assert len(fun_infos) == 1 + fun_info: pyarrow.flight.FlightInfo = fun_infos[0] + + assert fun_info.schema.names == ["col1", "col2", "col3"] + assert fun_info.descriptor.command is not None + assert len(fun_info.descriptor.command) + cmd = orjson.loads(fun_info.descriptor.command) + assert cmd["functionName"] == "SimpleFun1" + + descriptor = pyarrow.flight.FlightDescriptor.for_command( + orjson.dumps( + { + "functionName": "SimpleFun1", + "parameters": {"test1": 1, "test2": 2, "test3": 3}, + } + ) + ) + info = c.get_flight_info(descriptor, call_options_with_polling) + data: pyarrow.Table = c.do_get(info.endpoints[0].ticket).read_all() + + assert len(data) == 3 + assert data.column_names == ["col1", "col2", "col3"] + + +def test_cancellable_function_with_polling(call_options_with_polling): + """ + Flight RPC implementation that invokes FlexConnect can return a polling info. + + This way, the client can poll for results that take longer to complete. + """ + with flexconnect_server(["tests.server.funs.fun4"]) as s: + c = pyarrow.flight.FlightClient(s.location) + descriptor = pyarrow.flight.FlightDescriptor.for_command( + orjson.dumps( + { + "functionName": "PollableFun", + "parameters": {"test1": 1, "test2": 2, "test3": 3}, + } + ) + ) + + # the function is set to sleep a bit longer than the polling interval, + # so the first iteration returns retry info in the exception + with pytest.raises(pyarrow.flight.FlightTimedOutError) as e: + c.get_flight_info(descriptor, call_options_with_polling) + + assert e.value is not None + assert_error_code(ErrorCode.POLL, e.value) + + error_info = ErrorInfo.from_bytes(e.value.extra_info) + retry_info = RetryInfo.from_bytes(error_info.body) + + # use the retry info to poll again for the result, + # now it should be ready and returned normally + info = c.get_flight_info(retry_info.retry_descriptor, call_options_with_polling) + data: pyarrow.Table = c.do_get(info.endpoints[0].ticket).read_all() + + assert len(data) == 3 + assert data.column_names == ["col1", "col2", "col3"] + + # also check that trying to cancel already completed task results in cancelled with correct code + with pytest.raises(pyarrow.flight.FlightCancelledError) as e: + c.get_flight_info(retry_info.cancel_descriptor, call_options_with_polling) + + assert e.value is not None + assert_error_code(ErrorCode.COMMAND_CANCELLED, e.value) + + +def test_cancellable_function_with_call_deadline(call_options_with_polling): """ Flight RPC implementation that invokes FlexConnect can be setup with deadline for the invocation duration (done by GetFlightInfo). @@ -112,7 +225,65 @@ def test_function_with_call_deadline(): ) ) + # the initial submit returns polling info + with pytest.raises(pyarrow.flight.FlightTimedOutError) as e: + c.get_flight_info(descriptor, call_options_with_polling) + + assert e.value is not None + assert_error_code(ErrorCode.POLL, e.value) + + error_info = ErrorInfo.from_bytes(e.value.extra_info) + retry_info = RetryInfo.from_bytes(error_info.body) + + # the next poll still returns polling info with pytest.raises(pyarrow.flight.FlightTimedOutError) as e: - c.get_flight_info(descriptor) + c.get_flight_info(retry_info.retry_descriptor, call_options_with_polling) + + assert e.value is not None + assert_error_code(ErrorCode.POLL, e.value) + + # the third one reaches the deadline so the Timeout code is returned instead + with pytest.raises(pyarrow.flight.FlightTimedOutError) as e: + c.get_flight_info(retry_info.retry_descriptor, call_options_with_polling) assert_error_code(ErrorCode.TIMEOUT, e.value) + + +def test_cancellable_function_with_cancellation(call_options_with_polling): + """ + Run a long-running function and cancel it after one poll iteration. + """ + with flexconnect_server(["tests.server.funs.fun3"]) as s: + c = pyarrow.flight.FlightClient(s.location) + descriptor = pyarrow.flight.FlightDescriptor.for_command( + orjson.dumps( + { + "functionName": "LongRunningFun", + "parameters": {"test1": 1, "test2": 2, "test3": 3}, + } + ) + ) + + # the initial submit returns polling info + with pytest.raises(pyarrow.flight.FlightTimedOutError) as e: + c.get_flight_info(descriptor, call_options_with_polling) + + assert e.value is not None + assert_error_code(ErrorCode.POLL, e.value) + + error_info = ErrorInfo.from_bytes(e.value.extra_info) + retry_info = RetryInfo.from_bytes(error_info.body) + + # use the poll info to cancel the task + with pytest.raises(pyarrow.flight.FlightCancelledError) as e: + c.get_flight_info(retry_info.cancel_descriptor, call_options_with_polling) + + assert e.value is not None + assert_error_code(ErrorCode.COMMAND_CANCELLED, e.value) + + # even multiple cancellations return the same error + with pytest.raises(pyarrow.flight.FlightCancelledError) as e: + c.get_flight_info(retry_info.cancel_descriptor, call_options_with_polling) + + assert e.value is not None + assert_error_code(ErrorCode.COMMAND_CANCELLED, e.value) diff --git a/gooddata-flight-server/gooddata_flight_server/tasks/task_executor.py b/gooddata-flight-server/gooddata_flight_server/tasks/task_executor.py index 9c2b76d7b..58ef99fd1 100644 --- a/gooddata-flight-server/gooddata_flight_server/tasks/task_executor.py +++ b/gooddata-flight-server/gooddata_flight_server/tasks/task_executor.py @@ -54,6 +54,15 @@ def submit( """ raise NotImplementedError + @abc.abstractmethod + def get_task_submitted_timestamp(self, task_id: str) -> Optional[float]: + """ + Returns the timestamp of when the task with the given id was submitted. + :param task_id: task id to get the timestamp for + :return: Timestamp in seconds since epoch of when the task was submitted or None if there is no such task + """ + raise NotImplementedError + @abc.abstractmethod def wait_for_result(self, task_id: str, timeout: Optional[float] = None) -> Optional[TaskExecutionResult]: """ diff --git a/gooddata-flight-server/gooddata_flight_server/tasks/thread_task_executor.py b/gooddata-flight-server/gooddata_flight_server/tasks/thread_task_executor.py index f5cfe408b..6005c2354 100644 --- a/gooddata-flight-server/gooddata_flight_server/tasks/thread_task_executor.py +++ b/gooddata-flight-server/gooddata_flight_server/tasks/thread_task_executor.py @@ -565,6 +565,14 @@ def submit( execution.start() self._metrics.queue_size.set(self._queue_size) + def get_task_submitted_timestamp(self, task_id: str) -> Optional[float]: + with self._task_lock: + execution = self._executions.get(task_id) + + if execution is not None: + return execution.stats.created + return None + def wait_for_result(self, task_id: str, timeout: Optional[float] = None) -> Optional[TaskExecutionResult]: with self._task_lock: execution = self._executions.get(task_id) @@ -597,8 +605,8 @@ def cancel(self, task_id: str) -> bool: return True if execution is None: - # the task was not and is not running - cancel not possible - return False + # the task was not and is not running - cancel not necessary + return True return execution.cancel()