From f58c4ef224737c0cf27fe3cdf2808deca9eda7ee Mon Sep 17 00:00:00 2001 From: Dan Homola Date: Wed, 21 May 2025 12:45:58 +0200 Subject: [PATCH 1/4] feat: add support for polling in FlexConnect server Add a polling mechanism to the FlexConnect functions so that long-running tasks can be polled for and canceled. The TaskExecutor now supports returning a timestamp of when a particular task was submitted. This is to keep track of the call deadline breaches. JIRA: CQ-1124 risk: low --- .../function/flight_methods.py | 106 +++++++++++++----- gooddata-flexconnect/tests/server/conftest.py | 3 +- .../tests/server/funs/fun3.py | 2 +- .../tests/server/funs/fun4.py | 39 +++++++ .../tests/server/test_flexconnect_server.py | 90 ++++++++++++++- .../tasks/task_executor.py | 9 ++ .../tasks/thread_task_executor.py | 8 ++ 7 files changed, 226 insertions(+), 31 deletions(-) create mode 100644 gooddata-flexconnect/tests/server/funs/fun4.py diff --git a/gooddata-flexconnect/gooddata_flexconnect/function/flight_methods.py b/gooddata-flexconnect/gooddata_flexconnect/function/flight_methods.py index 67e11de9c..849d4b404 100644 --- a/gooddata-flexconnect/gooddata_flexconnect/function/flight_methods.py +++ b/gooddata-flexconnect/gooddata_flexconnect/function/flight_methods.py @@ -1,4 +1,5 @@ # (C) 2024 GoodData Corporation +import time from collections.abc import Generator from typing import Optional @@ -12,7 +13,6 @@ FlightServerMethods, ServerContext, TaskExecutionResult, - TaskWaitTimeoutError, flight_server_methods, ) @@ -23,11 +23,26 @@ _LOGGER = structlog.get_logger("gooddata_flexconnect.rpc") +def _prepare_poll_error(task_id: str) -> pyarrow.flight.FlightError: + return ErrorInfo.poll( + flight_info=None, + cancel_descriptor=pyarrow.flight.FlightDescriptor.for_command(f"c:{task_id}".encode()), + retry_descriptor=pyarrow.flight.FlightDescriptor.for_command(f"r:{task_id}".encode()), + ) + + class _FlexConnectServerMethods(FlightServerMethods): - def __init__(self, ctx: ServerContext, 
registry: FlexConnectFunctionRegistry, call_deadline_ms: float) -> None: + def __init__( + self, + ctx: ServerContext, + registry: FlexConnectFunctionRegistry, + call_deadline_ms: float, + poll_interval_ms: float, + ) -> None: self._ctx = ctx self._registry = registry self._call_deadline = call_deadline_ms / 1000 + self._poll_interval = poll_interval_ms / 1000 @staticmethod def _create_descriptor(fun_name: str, metadata: Optional[dict]) -> pyarrow.flight.FlightDescriptor: @@ -140,39 +155,53 @@ def get_flight_info( descriptor: pyarrow.flight.FlightDescriptor, ) -> pyarrow.flight.FlightInfo: structlog.contextvars.bind_contextvars(peer=context.peer()) - task: Optional[FlexConnectFunctionTask] = None - try: - task = self._prepare_task(context, descriptor) - self._ctx.task_executor.submit(task) + # first, check if the descriptor is a cancel descriptor + if descriptor.command is None or not len(descriptor.command): + raise ErrorInfo.bad_argument( + "Incorrect FlexConnect function invocation. Flight descriptor must contain command " + "with the invocation payload." + ) - try: - # XXX: this should be enhanced to implement polling - task_result = self._ctx.task_executor.wait_for_result(task.task_id, self._call_deadline) - except TaskWaitTimeoutError: - cancelled = self._ctx.task_executor.cancel(task.task_id) - _LOGGER.warning( - "flexconnect_fun_call_timeout", task_id=task.task_id, fun=task.fun_name, cancelled=cancelled - ) + task_id: str + fun_name: Optional[str] = None + if descriptor.command.startswith(b"c:"): + # cancel descriptor: just cancel the given task and raise cancellation exception + task_id = descriptor.command[2:].decode() + self._ctx.task_executor.cancel(task_id) + raise ErrorInfo.for_reason( + ErrorCode.COMMAND_CANCELLED, "FlexConnect function invocation was cancelled." 
+ ).to_cancelled_error() + elif descriptor.command.startswith(b"r:"): + # retry descriptor: extract the task_id, do not submit it again and do one polling iteration + task_id = descriptor.command[2:].decode() + # for retries, we also need to check the call deadline for the whole call duration + task_timestamp = self._ctx.task_executor.get_task_submitted_timestamp(task_id) + if task_timestamp is not None and time.perf_counter() - task_timestamp > self._call_deadline: + self._ctx.task_executor.cancel(task_id) raise ErrorInfo.for_reason( - ErrorCode.TIMEOUT, f"GetFlightInfo timed out while waiting for task {task.task_id}." + ErrorCode.TIMEOUT, f"GetFlightInfo timed out while waiting for task {task_id}." ).to_timeout_error() + else: + # basic first-time submit: submit the task and do one polling iteration. + # do not check call deadline to give it a chance to wait for the result at least once + try: + task = self._prepare_task(context, descriptor) + self._ctx.task_executor.submit(task) + task_id = task.task_id + fun_name = task.fun_name + except Exception: + _LOGGER.error("flexconnect_fun_submit_failed", exc_info=True) + raise - # if this bombs then there must be something really wrong because the task - # was clearly submitted and code was waiting for its completion. this invariant - # should not happen in this particular code path. 
The None return value may - # be applicable one day when polling is in use and a request comes to check whether - # particular task id finished - assert task_result is not None - + try: + task_result = self._ctx.task_executor.wait_for_result(task_id, timeout=self._poll_interval) return self._prepare_flight_info(task_result) + except TimeoutError: + raise _prepare_poll_error(task_id) except Exception: - if task is not None: - _LOGGER.error("get_flight_info_failed", task_id=task.task_id, fun=task.fun_name, exc_info=True) - else: - _LOGGER.error("flexconnect_fun_submit_failed", exc_info=True) - + _LOGGER.error("get_flight_info_failed", task_id=task_id, fun=fun_name, exc_info=True) raise def do_get( @@ -201,7 +230,9 @@ def do_get( _FLEX_CONNECT_CONFIG_SECTION = "flexconnect" _FLEX_CONNECT_FUNCTION_LIST = "functions" _FLEX_CONNECT_CALL_DEADLINE_MS = "call_deadline_ms" +_FLEX_CONNECT_POLLING_INTERVAL_MS = "polling_interval_ms" _DEFAULT_FLEX_CONNECT_CALL_DEADLINE_MS = 180_000 +_DEFAULT_FLEX_CONNECT_POLLING_INTERVAL_MS = 2000 def _read_call_deadline_ms(ctx: ServerContext) -> int: @@ -223,6 +254,24 @@ def _read_call_deadline_ms(ctx: ServerContext) -> int: ) +def _read_polling_interval_ms(ctx: ServerContext) -> int: + polling_interval = ctx.settings.get(f"{_FLEX_CONNECT_CONFIG_SECTION}.{_FLEX_CONNECT_POLLING_INTERVAL_MS}") + if polling_interval is None: + return _DEFAULT_FLEX_CONNECT_POLLING_INTERVAL_MS + + try: + polling_interval = int(polling_interval) + if polling_interval <= 0: + raise ValueError() + return polling_interval + except ValueError: + raise ValueError( + f"Value of {_FLEX_CONNECT_CONFIG_SECTION}.{_FLEX_CONNECT_POLLING_INTERVAL_MS} must " + f"be a positive number - duration, in milliseconds, that FlexConnect function " + f"waits for the result during one polling iteration." 
+ ) + + @flight_server_methods def create_flexconnect_flight_methods(ctx: ServerContext) -> FlightServerMethods: """ @@ -236,8 +285,9 @@ def create_flexconnect_flight_methods(ctx: ServerContext) -> FlightServerMethods """ modules = list(ctx.settings.get(f"{_FLEX_CONNECT_CONFIG_SECTION}.{_FLEX_CONNECT_FUNCTION_LIST}") or []) call_deadline_ms = _read_call_deadline_ms(ctx) + polling_interval_ms = _read_polling_interval_ms(ctx) _LOGGER.info("flexconnect_init", modules=modules) registry = FlexConnectFunctionRegistry().load(ctx, modules) - return _FlexConnectServerMethods(ctx, registry, call_deadline_ms) + return _FlexConnectServerMethods(ctx, registry, call_deadline_ms, polling_interval_ms) diff --git a/gooddata-flexconnect/tests/server/conftest.py b/gooddata-flexconnect/tests/server/conftest.py index 8e87f3ed6..fc4632948 100644 --- a/gooddata-flexconnect/tests/server/conftest.py +++ b/gooddata-flexconnect/tests/server/conftest.py @@ -80,7 +80,8 @@ def flexconnect_server( funs = f"[{funs}]" os.environ["GOODDATA_FLIGHT_FLEXCONNECT__FUNCTIONS"] = funs - os.environ["GOODDATA_FLIGHT_FLEXCONNECT__CALL_DEADLINE_MS"] = "500" + os.environ["GOODDATA_FLIGHT_FLEXCONNECT__CALL_DEADLINE_MS"] = "1000" + os.environ["GOODDATA_FLIGHT_FLEXCONNECT__POLLING_INTERVAL_MS"] = "500" with server(create_flexconnect_flight_methods, tls, mtls) as s: yield s diff --git a/gooddata-flexconnect/tests/server/funs/fun3.py b/gooddata-flexconnect/tests/server/funs/fun3.py index 42ac365ef..e3ac73c31 100644 --- a/gooddata-flexconnect/tests/server/funs/fun3.py +++ b/gooddata-flexconnect/tests/server/funs/fun3.py @@ -27,7 +27,7 @@ def call( ) -> ArrowData: # sleep is intentionally setup to be longer than the deadline for # the function invocation (see conftest.py // flexconnect_server fixture) - time.sleep(1) + time.sleep(1.5) return pyarrow.table( data={ diff --git a/gooddata-flexconnect/tests/server/funs/fun4.py b/gooddata-flexconnect/tests/server/funs/fun4.py new file mode 100644 index 000000000..b42faae77 
--- /dev/null +++ b/gooddata-flexconnect/tests/server/funs/fun4.py @@ -0,0 +1,39 @@ +# (C) 2024 GoodData Corporation +import time +from typing import Optional + +import pyarrow +from gooddata_flexconnect.function.function import FlexConnectFunction +from gooddata_flight_server import ArrowData + +_DATA: Optional[pyarrow.Table] = None + + +class _PollableFun(FlexConnectFunction): + Name = "PollableFun" + Schema = pyarrow.schema( + fields=[ + pyarrow.field("col1", pyarrow.int64()), + pyarrow.field("col2", pyarrow.string()), + pyarrow.field("col3", pyarrow.bool_()), + ] + ) + + def call( + self, + parameters: dict, + columns: tuple[str, ...], + headers: dict[str, list[str]], + ) -> ArrowData: + # sleep is intentionally setup to be longer than one polling interval + # (see conftest.py // flexconnect_server fixture) + time.sleep(0.7) + + return pyarrow.table( + data={ + "col1": [1, 2, 3], + "col2": ["a", "b", "c"], + "col3": [True, False, True], + }, + schema=self.Schema, + ) diff --git a/gooddata-flexconnect/tests/server/test_flexconnect_server.py b/gooddata-flexconnect/tests/server/test_flexconnect_server.py index 42fc3ffdb..5a41e5b1d 100644 --- a/gooddata-flexconnect/tests/server/test_flexconnect_server.py +++ b/gooddata-flexconnect/tests/server/test_flexconnect_server.py @@ -1,8 +1,9 @@ # (C) 2024 GoodData Corporation + import orjson import pyarrow.flight import pytest -from gooddata_flight_server import ErrorCode +from gooddata_flight_server import ErrorCode, ErrorInfo, RetryInfo from tests.assert_error_info import assert_error_code from tests.server.conftest import flexconnect_server @@ -89,6 +90,43 @@ def test_basic_function_tls(tls_ca_cert): assert data.column_names == ["col1", "col2", "col3"] +def test_function_with_polling(): + """ + Flight RPC implementation that invokes FlexConnect can return a polling info. + + This way, the client can poll for results that take longer to complete. 
+ """ + with flexconnect_server(["tests.server.funs.fun4"]) as s: + c = pyarrow.flight.FlightClient(s.location) + descriptor = pyarrow.flight.FlightDescriptor.for_command( + orjson.dumps( + { + "functionName": "PollableFun", + "parameters": {"test1": 1, "test2": 2, "test3": 3}, + } + ) + ) + + # the function is set to sleep a bit longer than the polling interval, + # so the first iteration returns retry info in the exception + with pytest.raises(pyarrow.flight.FlightTimedOutError) as e: + c.get_flight_info(descriptor) + + assert e.value is not None + assert_error_code(ErrorCode.POLL, e.value) + + error_info = ErrorInfo.from_bytes(e.value.extra_info) + retry_info = RetryInfo.from_bytes(error_info.body) + + # use the retry info to poll again for the result, + # now it should be ready and returned normally + info = c.get_flight_info(retry_info.retry_descriptor) + data: pyarrow.Table = c.do_get(info.endpoints[0].ticket).read_all() + + assert len(data) == 3 + assert data.column_names == ["col1", "col2", "col3"] + + def test_function_with_call_deadline(): """ Flight RPC implementation that invokes FlexConnect can be setup with @@ -115,4 +153,54 @@ def test_function_with_call_deadline(): with pytest.raises(pyarrow.flight.FlightTimedOutError) as e: c.get_flight_info(descriptor) + assert e.value is not None + assert_error_code(ErrorCode.POLL, e.value) + + error_info = ErrorInfo.from_bytes(e.value.extra_info) + retry_info = RetryInfo.from_bytes(error_info.body) + + # poll twice to reach the call deadline + with pytest.raises(pyarrow.flight.FlightTimedOutError) as e: + c.get_flight_info(retry_info.retry_descriptor) + + assert e.value is not None + assert_error_code(ErrorCode.POLL, e.value) + + error_info = ErrorInfo.from_bytes(e.value.extra_info) + retry_info = RetryInfo.from_bytes(error_info.body) + + with pytest.raises(pyarrow.flight.FlightTimedOutError) as e: + c.get_flight_info(retry_info.retry_descriptor) + + # and then ensure the timeout error is returned 
assert_error_code(ErrorCode.TIMEOUT, e.value) + + +def test_function_with_cancelation(): + """ + Run a long-running function and cancel it after one poll iteration. + """ + with flexconnect_server(["tests.server.funs.fun3"]) as s: + c = pyarrow.flight.FlightClient(s.location) + descriptor = pyarrow.flight.FlightDescriptor.for_command( + orjson.dumps( + { + "functionName": "LongRunningFun", + "parameters": {"test1": 1, "test2": 2, "test3": 3}, + } + ) + ) + + with pytest.raises(pyarrow.flight.FlightTimedOutError) as e: + c.get_flight_info(descriptor) + + assert e.value is not None + assert_error_code(ErrorCode.POLL, e.value) + + error_info = ErrorInfo.from_bytes(e.value.extra_info) + retry_info = RetryInfo.from_bytes(error_info.body) + + with pytest.raises(pyarrow.flight.FlightCancelledError) as e: + c.get_flight_info(retry_info.cancel_descriptor) + + assert e.value is not None diff --git a/gooddata-flight-server/gooddata_flight_server/tasks/task_executor.py b/gooddata-flight-server/gooddata_flight_server/tasks/task_executor.py index 9c2b76d7b..58ef99fd1 100644 --- a/gooddata-flight-server/gooddata_flight_server/tasks/task_executor.py +++ b/gooddata-flight-server/gooddata_flight_server/tasks/task_executor.py @@ -54,6 +54,15 @@ def submit( """ raise NotImplementedError + @abc.abstractmethod + def get_task_submitted_timestamp(self, task_id: str) -> Optional[float]: + """ + Returns the timestamp of when the task with the given id was submitted. 
+ :param task_id: task id to get the timestamp for + :return: Timestamp, in seconds, of when the task was submitted or None if there is no such task. NOTE(review): callers compare this value against time.perf_counter(), so it presumably is a perf_counter reading rather than seconds since epoch - confirm + """ + raise NotImplementedError + @abc.abstractmethod def wait_for_result(self, task_id: str, timeout: Optional[float] = None) -> Optional[TaskExecutionResult]: """ diff --git a/gooddata-flight-server/gooddata_flight_server/tasks/thread_task_executor.py b/gooddata-flight-server/gooddata_flight_server/tasks/thread_task_executor.py index f5cfe408b..343978367 100644 --- a/gooddata-flight-server/gooddata_flight_server/tasks/thread_task_executor.py +++ b/gooddata-flight-server/gooddata_flight_server/tasks/thread_task_executor.py @@ -565,6 +565,14 @@ def submit( execution.start() self._metrics.queue_size.set(self._queue_size) + def get_task_submitted_timestamp(self, task_id: str) -> Optional[float]: + with self._task_lock: + execution = self._executions.get(task_id) + + if execution is not None: + return execution.stats.created + return None + def wait_for_result(self, task_id: str, timeout: Optional[float] = None) -> Optional[TaskExecutionResult]: with self._task_lock: execution = self._executions.get(task_id) From e64c5e51f2e18d15e98d4fb1963839c63ebda517 Mon Sep 17 00:00:00 2001 From: Dan Homola Date: Wed, 21 May 2025 14:48:59 +0200 Subject: [PATCH 2/4] refactor: extract the payload parsing Encapsulate the different invocation types so that we can change the actual representation later if needed. 
JIRA: CQ-1124 risk: low --- .../function/flight_methods.py | 70 +++++--------- .../function/function_invocation.py | 94 +++++++++++++++++++ 2 files changed, 116 insertions(+), 48 deletions(-) create mode 100644 gooddata-flexconnect/gooddata_flexconnect/function/function_invocation.py diff --git a/gooddata-flexconnect/gooddata_flexconnect/function/flight_methods.py b/gooddata-flexconnect/gooddata_flexconnect/function/flight_methods.py index 849d4b404..04859da68 100644 --- a/gooddata-flexconnect/gooddata_flexconnect/function/flight_methods.py +++ b/gooddata-flexconnect/gooddata_flexconnect/function/flight_methods.py @@ -17,6 +17,12 @@ ) from gooddata_flexconnect.function.function import FlexConnectFunction +from gooddata_flexconnect.function.function_invocation import ( + CancelInvocation, + RetryInvocation, + SubmitInvocation, + extract_invocation_from_descriptor, +) from gooddata_flexconnect.function.function_registry import FlexConnectFunctionRegistry from gooddata_flexconnect.function.function_task import FlexConnectFunctionTask @@ -67,48 +73,20 @@ def _create_fun_info(self, fun: type[FlexConnectFunction]) -> pyarrow.flight.Fli total_records=-1, ) - def _extract_invocation_payload( - self, descriptor: pyarrow.flight.FlightDescriptor - ) -> tuple[str, dict, Optional[tuple[str, ...]]]: - if descriptor.command is None or not len(descriptor.command): - raise ErrorInfo.bad_argument( - "Incorrect FlexConnect function invocation. Flight descriptor must contain command " - "with the invocation payload." - ) - - try: - payload = orjson.loads(descriptor.command) - except Exception: - raise ErrorInfo.bad_argument( - "Incorrect FlexConnect function invocation. The invocation payload is not a valid JSON." - ) - - fun = payload.get("functionName") - if fun is None or not len(fun): - raise ErrorInfo.bad_argument( - "Incorrect FlexConnect function invocation. The invocation payload does not specify 'functionName'." 
- ) - - parameters = payload.get("parameters") or {} - columns = parameters.get("columns") - - return fun, parameters, columns - def _prepare_task( self, context: pyarrow.flight.ServerCallContext, - descriptor: pyarrow.flight.FlightDescriptor, + submit_invocation: SubmitInvocation, ) -> FlexConnectFunctionTask: - fun_name, parameters, columns = self._extract_invocation_payload(descriptor) headers = self.call_info_middleware(context).headers - fun = self._registry.create_function(fun_name) + fun = self._registry.create_function(submit_invocation.function_name) return FlexConnectFunctionTask( fun=fun, - parameters=parameters, - columns=columns, + parameters=submit_invocation.parameters, + columns=submit_invocation.columns, headers=headers, - cmd=descriptor.command, + cmd=submit_invocation.command, ) def _prepare_flight_info(self, task_result: TaskExecutionResult) -> pyarrow.flight.FlightInfo: @@ -156,26 +134,19 @@ def get_flight_info( ) -> pyarrow.flight.FlightInfo: structlog.contextvars.bind_contextvars(peer=context.peer()) - # first, check if the descriptor is a cancel descriptor - if descriptor.command is None or not len(descriptor.command): - raise ErrorInfo.bad_argument( - "Incorrect FlexConnect function invocation. Flight descriptor must contain command " - "with the invocation payload." - ) - task_id: str fun_name: Optional[str] = None + invocation = extract_invocation_from_descriptor(descriptor) - if descriptor.command.startswith(b"c:"): - # cancel descriptor: just cancel the given task and raise cancellation exception - task_id = descriptor.command[2:].decode() - self._ctx.task_executor.cancel(task_id) + if isinstance(invocation, CancelInvocation): + # cancel the given task and raise cancellation exception + self._ctx.task_executor.cancel(invocation.task_id) raise ErrorInfo.for_reason( ErrorCode.COMMAND_CANCELLED, "FlexConnect function invocation was cancelled." 
).to_cancelled_error() - elif descriptor.command.startswith(b"r:"): + elif isinstance(invocation, RetryInvocation): # retry descriptor: extract the task_id, do not submit it again and do one polling iteration - task_id = descriptor.command[2:].decode() + task_id = invocation.task_id # for retries, we also need to check the call deadline for the whole call duration task_timestamp = self._ctx.task_executor.get_task_submitted_timestamp(task_id) if task_timestamp is not None and time.perf_counter() - task_timestamp > self._call_deadline: @@ -183,17 +154,20 @@ def get_flight_info( raise ErrorInfo.for_reason( ErrorCode.TIMEOUT, f"GetFlightInfo timed out while waiting for task {task_id}." ).to_timeout_error() - else: + elif isinstance(invocation, SubmitInvocation): # basic first-time submit: submit the task and do one polling iteration. # do not check call deadline to give it a chance to wait for the result at least once try: - task = self._prepare_task(context, descriptor) + task = self._prepare_task(context, invocation) self._ctx.task_executor.submit(task) task_id = task.task_id fun_name = task.fun_name except Exception: _LOGGER.error("flexconnect_fun_submit_failed", exc_info=True) raise + else: + # can be replaced by assert_never when we are on 3.11 + raise AssertionError try: task_result = self._ctx.task_executor.wait_for_result(task_id, timeout=self._poll_interval) diff --git a/gooddata-flexconnect/gooddata_flexconnect/function/function_invocation.py b/gooddata-flexconnect/gooddata_flexconnect/function/function_invocation.py new file mode 100644 index 000000000..dadf79c9d --- /dev/null +++ b/gooddata-flexconnect/gooddata_flexconnect/function/function_invocation.py @@ -0,0 +1,94 @@ +# (C) 2025 GoodData Corporation +from dataclasses import dataclass +from typing import Optional, Union + +import orjson +import pyarrow.flight +from gooddata_flight_server import ErrorInfo + + +@dataclass(frozen=True) +class RetryInvocation: + """ + Indicates that the getting the results 
of the given task should be retried. + """ + + task_id: str + + +@dataclass(frozen=True) +class CancelInvocation: + """ + Indicates that the given task should be cancelled. + """ + + task_id: str + + +@dataclass(frozen=True) +class SubmitInvocation: + """ + Indicates that the given task should be submitted for processing. + """ + + command: bytes + """ + The raw command that was sent to the Flight Server. + """ + + function_name: str + """ + The name of the FlexConnect function to invoke. + """ + + parameters: dict + """ + Parameters to pass to the FlexConnect function. + """ + + columns: Optional[tuple[str, ...]] + """ + Columns to get from the FlexConnect function result. + This may be used for column trimming by the function: the function must return at least those columns. + """ + + +def extract_invocation_from_descriptor( + descriptor: pyarrow.flight.FlightDescriptor, +) -> Union[RetryInvocation, CancelInvocation, SubmitInvocation]: + """ + Given a flight descriptor, extract the invocation information from it. + """ + + if descriptor.command is None or not len(descriptor.command): + raise ErrorInfo.bad_argument( + "Incorrect FlexConnect function invocation. Flight descriptor must contain command " + "with the invocation payload." + ) + + if descriptor.command.startswith(b"c:"): + task_id = descriptor.command[2:].decode() + return CancelInvocation(task_id) + elif descriptor.command.startswith(b"r:"): + task_id = descriptor.command[2:].decode() + return RetryInvocation(task_id) + else: + try: + payload = orjson.loads(descriptor.command) + except Exception: + raise ErrorInfo.bad_argument( + "Incorrect FlexConnect function invocation. The invocation payload is not a valid JSON." + ) + + function_name = payload.get("functionName") + if function_name is None or not len(function_name): + raise ErrorInfo.bad_argument( + "Incorrect FlexConnect function invocation. The invocation payload does not specify 'functionName'." 
+ ) + + parameters = payload.get("parameters") or {} + columns = parameters.get("columns") + + return SubmitInvocation( + function_name=function_name, parameters=parameters, columns=columns, command=descriptor.command + ) From ce2fdb1a141c95a281881cb9d3313cbd041dab4a Mon Sep 17 00:00:00 2001 From: Dan Homola Date: Thu, 22 May 2025 13:23:33 +0200 Subject: [PATCH 3/4] fix: handle cancel return values in FlexConnect When cancelling, be more transparent about whether the cancellation actually happened. Also handle missing tasks: that now yields BAD_ARGUMENT exceptions. JIRA: CQ-1124 risk: low --- .../function/flight_methods.py | 39 ++++++++++++------- gooddata-flexconnect/tests/server/conftest.py | 2 +- .../tests/server/funs/fun3.py | 2 +- .../tests/server/test_flexconnect_server.py | 31 ++++++++++++--- .../tasks/thread_task_executor.py | 4 +- 5 files changed, 55 insertions(+), 23 deletions(-) diff --git a/gooddata-flexconnect/gooddata_flexconnect/function/flight_methods.py b/gooddata-flexconnect/gooddata_flexconnect/function/flight_methods.py index 04859da68..b575ab12e 100644 --- a/gooddata-flexconnect/gooddata_flexconnect/function/flight_methods.py +++ b/gooddata-flexconnect/gooddata_flexconnect/function/flight_methods.py @@ -89,14 +89,21 @@ def _prepare_task( cmd=submit_invocation.command, ) - def _prepare_flight_info(self, task_result: TaskExecutionResult) -> pyarrow.flight.FlightInfo: + def _prepare_flight_info( + self, task_id: str, task_result: Optional[TaskExecutionResult] + ) -> pyarrow.flight.FlightInfo: + if task_result is None: + raise ErrorInfo.for_reason( + ErrorCode.BAD_ARGUMENT, f"Task with id '{task_id}' does not exist." + ).to_user_error() + if task_result.error is not None: raise task_result.error.as_flight_error() if task_result.cancelled: raise ErrorInfo.for_reason( ErrorCode.COMMAND_CANCELLED, - f"FlexConnect function invocation was cancelled. Invocation task was: '{task_result.task_id}'.", + f"FlexConnect function invocation was cancelled. 
Invocation task was: '{task_id}'.", ).to_server_error() result = task_result.result @@ -107,7 +114,7 @@ def _prepare_flight_info(self, task_result: TaskExecutionResult) -> pyarrow.flig descriptor=pyarrow.flight.FlightDescriptor.for_command(task_result.cmd), endpoints=[ pyarrow.flight.FlightEndpoint( - ticket=pyarrow.flight.Ticket(ticket=orjson.dumps({"task_id": task_result.task_id})), + ticket=pyarrow.flight.Ticket(ticket=orjson.dumps({"task_id": task_id})), locations=[self._ctx.location], ) ], @@ -140,20 +147,16 @@ def get_flight_info( if isinstance(invocation, CancelInvocation): # cancel the given task and raise cancellation exception - self._ctx.task_executor.cancel(invocation.task_id) + if self._ctx.task_executor.cancel(invocation.task_id): + raise ErrorInfo.for_reason( + ErrorCode.COMMAND_CANCELLED, "FlexConnect function invocation was cancelled." + ).to_cancelled_error() raise ErrorInfo.for_reason( - ErrorCode.COMMAND_CANCELLED, "FlexConnect function invocation was cancelled." + ErrorCode.COMMAND_CANCEL_NOT_POSSIBLE, "FlexConnect function invocation could not be cancelled." ).to_cancelled_error() elif isinstance(invocation, RetryInvocation): # retry descriptor: extract the task_id, do not submit it again and do one polling iteration task_id = invocation.task_id - # for retries, we also need to check the call deadline for the whole call duration - task_timestamp = self._ctx.task_executor.get_task_submitted_timestamp(task_id) - if task_timestamp is not None and time.perf_counter() - task_timestamp > self._call_deadline: - self._ctx.task_executor.cancel(task_id) - raise ErrorInfo.for_reason( - ErrorCode.TIMEOUT, f"GetFlightInfo timed out while waiting for task {task_id}." - ).to_timeout_error() elif isinstance(invocation, SubmitInvocation): # basic first-time submit: submit the task and do one polling iteration. 
# do not check call deadline to give it a chance to wait for the result at least once @@ -171,8 +174,18 @@ def get_flight_info( try: task_result = self._ctx.task_executor.wait_for_result(task_id, timeout=self._poll_interval) - return self._prepare_flight_info(task_result) + return self._prepare_flight_info(task_id, task_result) except TimeoutError: + # first, check the call deadline for the whole call duration + task_timestamp = self._ctx.task_executor.get_task_submitted_timestamp(task_id) + if task_timestamp is not None and time.perf_counter() - task_timestamp > self._call_deadline: + self._ctx.task_executor.cancel(task_id) + raise ErrorInfo.for_reason( + ErrorCode.TIMEOUT, f"GetFlightInfo timed out while waiting for task {task_id}." + ).to_timeout_error() + + # if the result is not ready, and we still have time, indicate to the client + # how to poll for the results raise _prepare_poll_error(task_id) except Exception: _LOGGER.error("get_flight_info_failed", task_id=task_id, fun=fun_name, exc_info=True) diff --git a/gooddata-flexconnect/tests/server/conftest.py b/gooddata-flexconnect/tests/server/conftest.py index fc4632948..a84cbe9c0 100644 --- a/gooddata-flexconnect/tests/server/conftest.py +++ b/gooddata-flexconnect/tests/server/conftest.py @@ -80,7 +80,7 @@ def flexconnect_server( funs = f"[{funs}]" os.environ["GOODDATA_FLIGHT_FLEXCONNECT__FUNCTIONS"] = funs - os.environ["GOODDATA_FLIGHT_FLEXCONNECT__CALL_DEADLINE_MS"] = "1000" + os.environ["GOODDATA_FLIGHT_FLEXCONNECT__CALL_DEADLINE_MS"] = "1200" os.environ["GOODDATA_FLIGHT_FLEXCONNECT__POLLING_INTERVAL_MS"] = "500" with server(create_flexconnect_flight_methods, tls, mtls) as s: diff --git a/gooddata-flexconnect/tests/server/funs/fun3.py b/gooddata-flexconnect/tests/server/funs/fun3.py index e3ac73c31..6ec64496b 100644 --- a/gooddata-flexconnect/tests/server/funs/fun3.py +++ b/gooddata-flexconnect/tests/server/funs/fun3.py @@ -27,7 +27,7 @@ def call( ) -> ArrowData: # sleep is intentionally setup to be longer 
than the deadline for # the function invocation (see conftest.py // flexconnect_server fixture) - time.sleep(1.5) + time.sleep(2) return pyarrow.table( data={ diff --git a/gooddata-flexconnect/tests/server/test_flexconnect_server.py b/gooddata-flexconnect/tests/server/test_flexconnect_server.py index 5a41e5b1d..f3007b710 100644 --- a/gooddata-flexconnect/tests/server/test_flexconnect_server.py +++ b/gooddata-flexconnect/tests/server/test_flexconnect_server.py @@ -10,6 +10,10 @@ def test_basic_function(): + """ + This function should return immediately when called, no polling necessary. + :return: + """ with flexconnect_server(["tests.server.funs.fun1"]) as s: c = pyarrow.flight.FlightClient(s.location) fun_infos = list(c.list_flights()) @@ -126,6 +130,13 @@ def test_function_with_polling(): assert len(data) == 3 assert data.column_names == ["col1", "col2", "col3"] + # also check that trying to cancel already completed task results in cancelled with correct code + with pytest.raises(pyarrow.flight.FlightCancelledError) as e: + c.get_flight_info(retry_info.cancel_descriptor) + + assert e.value is not None + assert_error_code(ErrorCode.COMMAND_CANCELLED, e.value) + def test_function_with_call_deadline(): """ @@ -150,6 +161,7 @@ def test_function_with_call_deadline(): ) ) + # the initial submit returns polling info with pytest.raises(pyarrow.flight.FlightTimedOutError) as e: c.get_flight_info(descriptor) @@ -159,24 +171,21 @@ def test_function_with_call_deadline(): error_info = ErrorInfo.from_bytes(e.value.extra_info) retry_info = RetryInfo.from_bytes(error_info.body) - # poll twice to reach the call deadline + # the next poll still returns polling info with pytest.raises(pyarrow.flight.FlightTimedOutError) as e: c.get_flight_info(retry_info.retry_descriptor) assert e.value is not None assert_error_code(ErrorCode.POLL, e.value) - error_info = ErrorInfo.from_bytes(e.value.extra_info) - retry_info = RetryInfo.from_bytes(error_info.body) - + # the third one reaches the 
deadline so the Timeout code is returned instead with pytest.raises(pyarrow.flight.FlightTimedOutError) as e: c.get_flight_info(retry_info.retry_descriptor) - # and then ensure the timeout error is returned assert_error_code(ErrorCode.TIMEOUT, e.value) -def test_function_with_cancelation(): +def test_function_with_cancellation(): """ Run a long-running function and cancel it after one poll iteration. """ @@ -191,6 +200,7 @@ def test_function_with_cancelation(): ) ) + # the initial submit returns polling info with pytest.raises(pyarrow.flight.FlightTimedOutError) as e: c.get_flight_info(descriptor) @@ -200,7 +210,16 @@ def test_function_with_cancelation(): error_info = ErrorInfo.from_bytes(e.value.extra_info) retry_info = RetryInfo.from_bytes(error_info.body) + # use the poll info to cancel the task + with pytest.raises(pyarrow.flight.FlightCancelledError) as e: + c.get_flight_info(retry_info.cancel_descriptor) + + assert e.value is not None + assert_error_code(ErrorCode.COMMAND_CANCELLED, e.value) + + # even multiple cancellations return the same error with pytest.raises(pyarrow.flight.FlightCancelledError) as e: c.get_flight_info(retry_info.cancel_descriptor) assert e.value is not None + assert_error_code(ErrorCode.COMMAND_CANCELLED, e.value) diff --git a/gooddata-flight-server/gooddata_flight_server/tasks/thread_task_executor.py b/gooddata-flight-server/gooddata_flight_server/tasks/thread_task_executor.py index 343978367..6005c2354 100644 --- a/gooddata-flight-server/gooddata_flight_server/tasks/thread_task_executor.py +++ b/gooddata-flight-server/gooddata_flight_server/tasks/thread_task_executor.py @@ -605,8 +605,8 @@ def cancel(self, task_id: str) -> bool: return True if execution is None: - # the task was not and is not running - cancel not possible - return False + # the task was not and is not running - cancel not necessary + return True return execution.cancel() From 173aa5b331c68b79b1b2dc9132d72f6ff1184e39 Mon Sep 17 00:00:00 2001 From: Dan Homola Date: 
Tue, 27 May 2025 13:31:15 +0200 Subject: [PATCH 4/4] feat: make the polling extension opt-in By default, the FlexConnect will conform to the Arrow Flight RPC spec. However, if an opt-in header is present, it will use the polling extension used by GoodData. This allows for things like query cancellation. Ideally, we would use the PollFlightInfo from the Arrow Flight RPC but unfortunately it is not yet available in PyArrow. JIRA: CQ-1124 risk: low --- .../function/flight_methods.py | 104 +++++++++++++++--- .../function/function_invocation.py | 53 +++++---- .../tests/server/test_flexconnect_server.py | 92 +++++++++++++--- 3 files changed, 199 insertions(+), 50 deletions(-) diff --git a/gooddata-flexconnect/gooddata_flexconnect/function/flight_methods.py b/gooddata-flexconnect/gooddata_flexconnect/function/flight_methods.py index b575ab12e..863f992b3 100644 --- a/gooddata-flexconnect/gooddata_flexconnect/function/flight_methods.py +++ b/gooddata-flexconnect/gooddata_flexconnect/function/flight_methods.py @@ -13,6 +13,7 @@ FlightServerMethods, ServerContext, TaskExecutionResult, + TaskWaitTimeoutError, flight_server_methods, ) @@ -21,13 +22,20 @@ CancelInvocation, RetryInvocation, SubmitInvocation, - extract_invocation_from_descriptor, + extract_pollable_invocation_from_descriptor, + extract_submit_invocation_from_descriptor, ) from gooddata_flexconnect.function.function_registry import FlexConnectFunctionRegistry from gooddata_flexconnect.function.function_task import FlexConnectFunctionTask _LOGGER = structlog.get_logger("gooddata_flexconnect.rpc") +POLLING_HEADER_NAME = "x-quiver-pollable" +""" +If this header is present on the get flight info call, the polling extension will be used. +Otherwise the basic do get will be used. 
+""" + def _prepare_poll_error(task_id: str) -> pyarrow.flight.FlightError: return ErrorInfo.poll( @@ -122,28 +130,69 @@ def _prepare_flight_info( total_bytes=-1, ) - ################################################################### - # Implementation of Flight RPC methods - ################################################################### - - def list_flights( - self, context: pyarrow.flight.ServerCallContext, criteria: bytes - ) -> Generator[pyarrow.flight.FlightInfo, None, None]: + def _get_flight_info_no_polling( + self, + context: pyarrow.flight.ServerCallContext, + descriptor: pyarrow.flight.FlightDescriptor, + ) -> pyarrow.flight.FlightInfo: + """ + Basic DoGetInfo flow with no polling extension. + This conforms to the mainline Arrow Flight RPC specification. + """ structlog.contextvars.bind_contextvars(peer=context.peer()) - _LOGGER.info("list_flights", available_funs=self._registry.function_names) + invocation = extract_submit_invocation_from_descriptor(descriptor) - return (self._create_fun_info(fun) for fun in self._registry.functions.values()) + task: Optional[FlexConnectFunctionTask] = None - def get_flight_info( + try: + task = self._prepare_task(context, invocation) + self._ctx.task_executor.submit(task) + + try: + task_result = self._ctx.task_executor.wait_for_result(task.task_id, self._call_deadline) + except TaskWaitTimeoutError: + cancelled = self._ctx.task_executor.cancel(task.task_id) + _LOGGER.warning( + "flexconnect_fun_call_timeout", task_id=task.task_id, fun=task.fun_name, cancelled=cancelled + ) + + raise ErrorInfo.for_reason( + ErrorCode.TIMEOUT, f"GetFlightInfo timed out while waiting for task {task.task_id}." + ).to_timeout_error() + + # if this bombs then there must be something really wrong because the task + # was clearly submitted and code was waiting for its completion. this invariant + # should not happen in this particular code path. 
The None return value may + # be applicable one day when polling is in use and a request comes to check whether + # particular task id finished + assert task_result is not None + + return self._prepare_flight_info(task_id=task.task_id, task_result=task_result) + except Exception: + if task is not None: + _LOGGER.error( + "get_flight_info_failed", task_id=task.task_id, fun=task.fun_name, exc_info=True, polling=False + ) + else: + _LOGGER.error("flexconnect_fun_submit_failed", exc_info=True, polling=False) + raise + + def _get_flight_info_polling( self, context: pyarrow.flight.ServerCallContext, descriptor: pyarrow.flight.FlightDescriptor, ) -> pyarrow.flight.FlightInfo: + """ + DoGetInfo flow with polling extension. + This extends the mainline Arrow Flight RPC specification with polling capabilities using the RetryInfo + encoded into the FlightTimedOutError.extra_info. + Ideally, we would use the mainline PollFlightInfo, but that has yet to be implemented in the PyArrow library. + """ structlog.contextvars.bind_contextvars(peer=context.peer()) + invocation = extract_pollable_invocation_from_descriptor(descriptor) task_id: str fun_name: Optional[str] = None - invocation = extract_invocation_from_descriptor(descriptor) if isinstance(invocation, CancelInvocation): # cancel the given task and raise cancellation exception @@ -166,7 +215,7 @@ def get_flight_info( task_id = task.task_id fun_name = task.fun_name except Exception: - _LOGGER.error("flexconnect_fun_submit_failed", exc_info=True) + _LOGGER.error("flexconnect_fun_submit_failed", exc_info=True, polling=True) raise else: # can be replaced by assert_never when we are on 3.11 @@ -188,9 +237,36 @@ def get_flight_info( # how to poll for the results raise _prepare_poll_error(task_id) except Exception: - _LOGGER.error("get_flight_info_failed", task_id=task_id, fun=fun_name, exc_info=True) + _LOGGER.error("get_flight_info_failed", task_id=task_id, fun=fun_name, exc_info=True, polling=True) raise + 
################################################################### + # Implementation of Flight RPC methods + ################################################################### + + def list_flights( + self, context: pyarrow.flight.ServerCallContext, criteria: bytes + ) -> Generator[pyarrow.flight.FlightInfo, None, None]: + structlog.contextvars.bind_contextvars(peer=context.peer()) + _LOGGER.info("list_flights", available_funs=self._registry.function_names) + + return (self._create_fun_info(fun) for fun in self._registry.functions.values()) + + def get_flight_info( + self, + context: pyarrow.flight.ServerCallContext, + descriptor: pyarrow.flight.FlightDescriptor, + ) -> pyarrow.flight.FlightInfo: + structlog.contextvars.bind_contextvars(peer=context.peer()) + + headers = self.call_info_middleware(context).headers + allow_polling = headers.get(POLLING_HEADER_NAME) is not None + + if allow_polling: + return self._get_flight_info_polling(context, descriptor) + else: + return self._get_flight_info_no_polling(context, descriptor) + def do_get( self, context: pyarrow.flight.ServerCallContext, diff --git a/gooddata-flexconnect/gooddata_flexconnect/function/function_invocation.py b/gooddata-flexconnect/gooddata_flexconnect/function/function_invocation.py index dadf79c9d..8647bf0ac 100644 --- a/gooddata-flexconnect/gooddata_flexconnect/function/function_invocation.py +++ b/gooddata-flexconnect/gooddata_flexconnect/function/function_invocation.py @@ -53,42 +53,51 @@ class SubmitInvocation: """ -def extract_invocation_from_descriptor( +def extract_submit_invocation_from_descriptor(descriptor: pyarrow.flight.FlightDescriptor) -> SubmitInvocation: + """ + Given a flight descriptor, extract the invocation information from it. + Do not allow the polling-related variants. + """ + try: + payload = orjson.loads(descriptor.command) + except Exception: + raise ErrorInfo.bad_argument( + "Incorrect FlexConnect function invocation. The invocation payload is not a valid JSON." 
+ ) + + function_name = payload.get("functionName") + if function_name is None or not len(function_name): + raise ErrorInfo.bad_argument( + "Incorrect FlexConnect function invocation. The invocation payload does not specify 'functionName'." + ) + + parameters = payload.get("parameters") or {} + columns = parameters.get("columns") + + return SubmitInvocation( + function_name=function_name, parameters=parameters, columns=columns, command=descriptor.command + ) + + +def extract_pollable_invocation_from_descriptor( descriptor: pyarrow.flight.FlightDescriptor, ) -> Union[RetryInvocation, CancelInvocation, SubmitInvocation]: """ Given a flight descriptor, extract the invocation information from it. + Allow also the polling-related variants. """ - if descriptor.command is None or not len(descriptor.command): raise ErrorInfo.bad_argument( "Incorrect FlexConnect function invocation. Flight descriptor must contain command " "with the invocation payload." ) + # we are in the polling-enabled realm: try parsing the retry and cancel descriptors first if descriptor.command.startswith(b"c:"): task_id = descriptor.command[2:].decode() return CancelInvocation(task_id) elif descriptor.command.startswith(b"r:"): task_id = descriptor.command[2:].decode() return RetryInvocation(task_id) - else: - try: - payload = orjson.loads(descriptor.command) - except Exception: - raise ErrorInfo.bad_argument( - "Incorrect FlexConnect function invocation. The invocation payload is not a valid JSON." - ) - - function_name = payload.get("functionName") - if function_name is None or not len(function_name): - raise ErrorInfo.bad_argument( - "Incorrect FlexConnect function invocation. The invocation payload does not specify 'functionName'." 
- ) - - parameters = payload.get("parameters") or {} - columns = parameters.get("columns") - - return SubmitInvocation( - function_name=function_name, parameters=parameters, columns=columns, command=descriptor.command - ) + + return extract_submit_invocation_from_descriptor(descriptor) diff --git a/gooddata-flexconnect/tests/server/test_flexconnect_server.py b/gooddata-flexconnect/tests/server/test_flexconnect_server.py index f3007b710..831370adf 100644 --- a/gooddata-flexconnect/tests/server/test_flexconnect_server.py +++ b/gooddata-flexconnect/tests/server/test_flexconnect_server.py @@ -3,16 +3,21 @@ import orjson import pyarrow.flight import pytest +from gooddata_flexconnect.function.flight_methods import POLLING_HEADER_NAME from gooddata_flight_server import ErrorCode, ErrorInfo, RetryInfo from tests.assert_error_info import assert_error_code from tests.server.conftest import flexconnect_server +@pytest.fixture +def call_options_with_polling(): + return pyarrow.flight.FlightCallOptions(headers=[(POLLING_HEADER_NAME.encode(), b"true")]) + + def test_basic_function(): """ - This function should return immediately when called, no polling necessary. - :return: + This function should return immediately when called, no polling allowed. """ with flexconnect_server(["tests.server.funs.fun1"]) as s: c = pyarrow.flight.FlightClient(s.location) @@ -94,7 +99,66 @@ def test_basic_function_tls(tls_ca_cert): assert data.column_names == ["col1", "col2", "col3"] -def test_function_with_polling(): +def test_cancellable_function(call_options_with_polling): + """ + This function should return immediately when called, no polling necessary even if enabled. 
+ """ + with flexconnect_server(["tests.server.funs.fun1"]) as s: + c = pyarrow.flight.FlightClient(s.location) + fun_infos = list(c.list_flights()) + assert len(fun_infos) == 1 + fun_info: pyarrow.flight.FlightInfo = fun_infos[0] + + assert fun_info.schema.names == ["col1", "col2", "col3"] + assert fun_info.descriptor.command is not None + assert len(fun_info.descriptor.command) + cmd = orjson.loads(fun_info.descriptor.command) + assert cmd["functionName"] == "SimpleFun1" + + descriptor = pyarrow.flight.FlightDescriptor.for_command( + orjson.dumps( + { + "functionName": "SimpleFun1", + "parameters": {"test1": 1, "test2": 2, "test3": 3}, + } + ) + ) + info = c.get_flight_info(descriptor, call_options_with_polling) + data: pyarrow.Table = c.do_get(info.endpoints[0].ticket).read_all() + + assert len(data) == 3 + assert data.column_names == ["col1", "col2", "col3"] + + +def test_cancellable_function_tls(tls_ca_cert, call_options_with_polling): + with flexconnect_server(["tests.server.funs.fun1"], tls=True) as s: + c = pyarrow.flight.FlightClient(s.location, tls_root_certs=tls_ca_cert) + fun_infos = list(c.list_flights()) + assert len(fun_infos) == 1 + fun_info: pyarrow.flight.FlightInfo = fun_infos[0] + + assert fun_info.schema.names == ["col1", "col2", "col3"] + assert fun_info.descriptor.command is not None + assert len(fun_info.descriptor.command) + cmd = orjson.loads(fun_info.descriptor.command) + assert cmd["functionName"] == "SimpleFun1" + + descriptor = pyarrow.flight.FlightDescriptor.for_command( + orjson.dumps( + { + "functionName": "SimpleFun1", + "parameters": {"test1": 1, "test2": 2, "test3": 3}, + } + ) + ) + info = c.get_flight_info(descriptor, call_options_with_polling) + data: pyarrow.Table = c.do_get(info.endpoints[0].ticket).read_all() + + assert len(data) == 3 + assert data.column_names == ["col1", "col2", "col3"] + + +def test_cancellable_function_with_polling(call_options_with_polling): """ Flight RPC implementation that invokes FlexConnect can 
return a polling info. @@ -114,7 +178,7 @@ def test_function_with_polling(): # the function is set to sleep a bit longer than the polling interval, # so the first iteration returns retry info in the exception with pytest.raises(pyarrow.flight.FlightTimedOutError) as e: - c.get_flight_info(descriptor) + c.get_flight_info(descriptor, call_options_with_polling) assert e.value is not None assert_error_code(ErrorCode.POLL, e.value) @@ -124,7 +188,7 @@ def test_function_with_polling(): # use the retry info to poll again for the result, # now it should be ready and returned normally - info = c.get_flight_info(retry_info.retry_descriptor) + info = c.get_flight_info(retry_info.retry_descriptor, call_options_with_polling) data: pyarrow.Table = c.do_get(info.endpoints[0].ticket).read_all() assert len(data) == 3 @@ -132,13 +196,13 @@ def test_function_with_polling(): # also check that trying to cancel already completed task results in cancelled with correct code with pytest.raises(pyarrow.flight.FlightCancelledError) as e: - c.get_flight_info(retry_info.cancel_descriptor) + c.get_flight_info(retry_info.cancel_descriptor, call_options_with_polling) assert e.value is not None assert_error_code(ErrorCode.COMMAND_CANCELLED, e.value) -def test_function_with_call_deadline(): +def test_cancellable_function_with_call_deadline(call_options_with_polling): """ Flight RPC implementation that invokes FlexConnect can be setup with deadline for the invocation duration (done by GetFlightInfo). 
@@ -163,7 +227,7 @@ def test_function_with_call_deadline(): # the initial submit returns polling info with pytest.raises(pyarrow.flight.FlightTimedOutError) as e: - c.get_flight_info(descriptor) + c.get_flight_info(descriptor, call_options_with_polling) assert e.value is not None assert_error_code(ErrorCode.POLL, e.value) @@ -173,19 +237,19 @@ def test_function_with_call_deadline(): # the next poll still returns polling info with pytest.raises(pyarrow.flight.FlightTimedOutError) as e: - c.get_flight_info(retry_info.retry_descriptor) + c.get_flight_info(retry_info.retry_descriptor, call_options_with_polling) assert e.value is not None assert_error_code(ErrorCode.POLL, e.value) # the third one reaches the deadline so the Timeout code is returned instead with pytest.raises(pyarrow.flight.FlightTimedOutError) as e: - c.get_flight_info(retry_info.retry_descriptor) + c.get_flight_info(retry_info.retry_descriptor, call_options_with_polling) assert_error_code(ErrorCode.TIMEOUT, e.value) -def test_function_with_cancellation(): +def test_cancellable_function_with_cancellation(call_options_with_polling): """ Run a long-running function and cancel it after one poll iteration. 
""" @@ -202,7 +266,7 @@ def test_function_with_cancellation(): # the initial submit returns polling info with pytest.raises(pyarrow.flight.FlightTimedOutError) as e: - c.get_flight_info(descriptor) + c.get_flight_info(descriptor, call_options_with_polling) assert e.value is not None assert_error_code(ErrorCode.POLL, e.value) @@ -212,14 +276,14 @@ def test_function_with_cancellation(): # use the poll info to cancel the task with pytest.raises(pyarrow.flight.FlightCancelledError) as e: - c.get_flight_info(retry_info.cancel_descriptor) + c.get_flight_info(retry_info.cancel_descriptor, call_options_with_polling) assert e.value is not None assert_error_code(ErrorCode.COMMAND_CANCELLED, e.value) # even multiple cancellations return the same error with pytest.raises(pyarrow.flight.FlightCancelledError) as e: - c.get_flight_info(retry_info.cancel_descriptor) + c.get_flight_info(retry_info.cancel_descriptor, call_options_with_polling) assert e.value is not None assert_error_code(ErrorCode.COMMAND_CANCELLED, e.value)