Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
225 changes: 169 additions & 56 deletions gooddata-flexconnect/gooddata_flexconnect/function/flight_methods.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# (C) 2024 GoodData Corporation
import time
from collections.abc import Generator
from typing import Optional

Expand All @@ -17,17 +18,45 @@
)

from gooddata_flexconnect.function.function import FlexConnectFunction
from gooddata_flexconnect.function.function_invocation import (
CancelInvocation,
RetryInvocation,
SubmitInvocation,
extract_pollable_invocation_from_descriptor,
extract_submit_invocation_from_descriptor,
)
from gooddata_flexconnect.function.function_registry import FlexConnectFunctionRegistry
from gooddata_flexconnect.function.function_task import FlexConnectFunctionTask

# Module-wide structlog logger; events from this file are emitted under "gooddata_flexconnect.rpc".
_LOGGER = structlog.get_logger("gooddata_flexconnect.rpc")

# Only the presence of this header on the call is checked, never its value.
POLLING_HEADER_NAME = "x-quiver-pollable"
"""
If this header is present on the get flight info call, the polling extension will be used.
Otherwise the basic do get will be used.
"""


def _prepare_poll_error(task_id: str) -> pyarrow.flight.FlightError:
    """
    Build the error that tells the client how to keep polling for a task.

    The error carries no flight info. It carries two command descriptors derived from the task id:
    ``c:<task_id>`` for cancelling the task and ``r:<task_id>`` for retrying (polling again).

    :param task_id: identifier of the task whose result is not available yet
    :return: error created by ``ErrorInfo.poll``; the caller is expected to raise it
    """
    make_descriptor = pyarrow.flight.FlightDescriptor.for_command
    cancel_descriptor = make_descriptor(f"c:{task_id}".encode())
    retry_descriptor = make_descriptor(f"r:{task_id}".encode())

    return ErrorInfo.poll(
        flight_info=None,
        cancel_descriptor=cancel_descriptor,
        retry_descriptor=retry_descriptor,
    )


class _FlexConnectServerMethods(FlightServerMethods):
def __init__(self, ctx: ServerContext, registry: FlexConnectFunctionRegistry, call_deadline_ms: float) -> None:
def __init__(
    self,
    ctx: ServerContext,
    registry: FlexConnectFunctionRegistry,
    call_deadline_ms: float,
    poll_interval_ms: float,
) -> None:
    """
    Store the server context, the function registry and the two timing settings.

    :param ctx: server context; stored as-is
    :param registry: registry of FlexConnect functions; stored as-is
    :param call_deadline_ms: call deadline in milliseconds
    :param poll_interval_ms: length of one polling iteration in milliseconds
    """
    millis_per_second = 1000

    self._ctx, self._registry = ctx, registry

    # both durations come in as milliseconds and are kept as seconds
    self._call_deadline = call_deadline_ms / millis_per_second
    self._poll_interval = poll_interval_ms / millis_per_second

@staticmethod
def _create_descriptor(fun_name: str, metadata: Optional[dict]) -> pyarrow.flight.FlightDescriptor:
Expand All @@ -52,58 +81,37 @@ def _create_fun_info(self, fun: type[FlexConnectFunction]) -> pyarrow.flight.Fli
total_records=-1,
)

def _extract_invocation_payload(
    self, descriptor: pyarrow.flight.FlightDescriptor
) -> tuple[str, dict, Optional[tuple[str, ...]]]:
    """
    Parse the FlexConnect invocation payload out of the flight descriptor's command.

    The command is expected to be a JSON object with a non-empty string ``functionName`` and
    an optional ``parameters`` object. The ``columns`` entry of the parameters, if any, is
    returned separately.

    :param descriptor: flight descriptor whose ``command`` holds the JSON payload
    :return: tuple of (function name, parameters, columns); parameters default to an empty dict
     and columns to None when they are not present
    :raises: the error created by ``ErrorInfo.bad_argument`` when the command is missing or empty,
     is not valid JSON, is not a JSON object, does not name the function, or carries
     parameters that are not a JSON object
    """
    if not descriptor.command:
        raise ErrorInfo.bad_argument(
            "Incorrect FlexConnect function invocation. Flight descriptor must contain command "
            "with the invocation payload."
        )

    try:
        payload = orjson.loads(descriptor.command)
    except Exception as e:
        # keep the parser error as the cause so that it is not lost from logs
        raise ErrorInfo.bad_argument(
            "Incorrect FlexConnect function invocation. The invocation payload is not a valid JSON."
        ) from e

    # valid JSON may still be a list, string or number; those have no `.get` and
    # would otherwise surface as an AttributeError instead of a bad argument error
    if not isinstance(payload, dict):
        raise ErrorInfo.bad_argument(
            "Incorrect FlexConnect function invocation. The invocation payload must be a JSON object."
        )

    fun = payload.get("functionName")
    # a non-string name (number, list, ...) is treated the same as a missing name
    if not isinstance(fun, str) or not fun:
        raise ErrorInfo.bad_argument(
            "Incorrect FlexConnect function invocation. The invocation payload does not specify 'functionName'."
        )

    parameters = payload.get("parameters") or {}
    if not isinstance(parameters, dict):
        raise ErrorInfo.bad_argument(
            "Incorrect FlexConnect function invocation. The 'parameters' in the invocation payload "
            "must be a JSON object."
        )

    columns = parameters.get("columns")

    return fun, parameters, columns

def _prepare_task(
self,
context: pyarrow.flight.ServerCallContext,
descriptor: pyarrow.flight.FlightDescriptor,
submit_invocation: SubmitInvocation,
) -> FlexConnectFunctionTask:
fun_name, parameters, columns = self._extract_invocation_payload(descriptor)
headers = self.call_info_middleware(context).headers
fun = self._registry.create_function(fun_name)
fun = self._registry.create_function(submit_invocation.function_name)

return FlexConnectFunctionTask(
fun=fun,
parameters=parameters,
columns=columns,
parameters=submit_invocation.parameters,
columns=submit_invocation.columns,
headers=headers,
cmd=descriptor.command,
cmd=submit_invocation.command,
)

def _prepare_flight_info(self, task_result: TaskExecutionResult) -> pyarrow.flight.FlightInfo:
def _prepare_flight_info(
self, task_id: str, task_result: Optional[TaskExecutionResult]
) -> pyarrow.flight.FlightInfo:
if task_result is None:
raise ErrorInfo.for_reason(
ErrorCode.BAD_ARGUMENT, f"Task with id '{task_id}' does not exist."
).to_user_error()

if task_result.error is not None:
raise task_result.error.as_flight_error()

if task_result.cancelled:
raise ErrorInfo.for_reason(
ErrorCode.COMMAND_CANCELLED,
f"FlexConnect function invocation was cancelled. Invocation task was: '{task_result.task_id}'.",
f"FlexConnect function invocation was cancelled. Invocation task was: '{task_id}'.",
).to_server_error()

result = task_result.result
Expand All @@ -114,40 +122,33 @@ def _prepare_flight_info(self, task_result: TaskExecutionResult) -> pyarrow.flig
descriptor=pyarrow.flight.FlightDescriptor.for_command(task_result.cmd),
endpoints=[
pyarrow.flight.FlightEndpoint(
ticket=pyarrow.flight.Ticket(ticket=orjson.dumps({"task_id": task_result.task_id})),
ticket=pyarrow.flight.Ticket(ticket=orjson.dumps({"task_id": task_id})),
locations=[self._ctx.location],
)
],
total_records=-1,
total_bytes=-1,
)

###################################################################
# Implementation of Flight RPC methods
###################################################################

def list_flights(
    self, context: pyarrow.flight.ServerCallContext, criteria: bytes
) -> Generator[pyarrow.flight.FlightInfo, None, None]:
    """
    List the functions held by the registry, one flight info per function.

    :param context: server call context; its peer is bound to the logging context
    :param criteria: listing criteria; not used by this implementation
    :return: generator producing a flight info for every function in the registry
    """
    structlog.contextvars.bind_contextvars(peer=context.peer())
    _LOGGER.info("list_flights", available_funs=self._registry.function_names)

    registered_funs = self._registry.functions.values()

    return (self._create_fun_info(fun_cls) for fun_cls in registered_funs)

def get_flight_info(
def _get_flight_info_no_polling(
self,
context: pyarrow.flight.ServerCallContext,
descriptor: pyarrow.flight.FlightDescriptor,
) -> pyarrow.flight.FlightInfo:
"""
Basic DoGetInfo flow with no polling extension.
This conforms to the mainline Arrow Flight RPC specification.
"""
structlog.contextvars.bind_contextvars(peer=context.peer())
invocation = extract_submit_invocation_from_descriptor(descriptor)

task: Optional[FlexConnectFunctionTask] = None

try:
task = self._prepare_task(context, descriptor)
task = self._prepare_task(context, invocation)
self._ctx.task_executor.submit(task)

try:
# XXX: this should be enhanced to implement polling
task_result = self._ctx.task_executor.wait_for_result(task.task_id, self._call_deadline)
except TaskWaitTimeoutError:
cancelled = self._ctx.task_executor.cancel(task.task_id)
Expand All @@ -166,15 +167,106 @@ def get_flight_info(
# particular task id finished
assert task_result is not None

return self._prepare_flight_info(task_result)
return self._prepare_flight_info(task_id=task.task_id, task_result=task_result)
except Exception:
if task is not None:
_LOGGER.error("get_flight_info_failed", task_id=task.task_id, fun=task.fun_name, exc_info=True)
_LOGGER.error(
"get_flight_info_failed", task_id=task.task_id, fun=task.fun_name, exc_info=True, polling=False
)
else:
_LOGGER.error("flexconnect_fun_submit_failed", exc_info=True)
_LOGGER.error("flexconnect_fun_submit_failed", exc_info=True, polling=False)
raise

def _get_flight_info_polling(
    self,
    context: pyarrow.flight.ServerCallContext,
    descriptor: pyarrow.flight.FlightDescriptor,
) -> pyarrow.flight.FlightInfo:
    """
    GetFlightInfo flow that uses the polling extension.

    The descriptor is decoded into one of three invocations: cancel, retry (poll again) or
    first-time submit. A cancel invocation always ends with a cancellation error. For the other two,
    the method waits up to one poll interval for the task result. If the result is not ready,
    it raises the poll error that carries the cancel and retry descriptors - unless the call deadline,
    measured from the task's submit timestamp, has already passed; then the task is cancelled and
    a timeout error is raised.

    This extends the mainline Arrow Flight RPC specification using the RetryInfo encoded into
    the FlightTimedOutError.extra_info. The mainline PollFlightInfo would be the ideal choice, but it
    has yet to be implemented in the PyArrow library.
    """
    structlog.contextvars.bind_contextvars(peer=context.peer())
    invocation = extract_pollable_invocation_from_descriptor(descriptor)

    if isinstance(invocation, CancelInvocation):
        # a cancel request never produces flight info; both outcomes are reported as errors
        if not self._ctx.task_executor.cancel(invocation.task_id):
            raise ErrorInfo.for_reason(
                ErrorCode.COMMAND_CANCEL_NOT_POSSIBLE, "FlexConnect function invocation could not be cancelled."
            ).to_cancelled_error()

        raise ErrorInfo.for_reason(
            ErrorCode.COMMAND_CANCELLED, "FlexConnect function invocation was cancelled."
        ).to_cancelled_error()

    task_id: str
    fun_name: Optional[str] = None

    if isinstance(invocation, RetryInvocation):
        # the task is already submitted; just take its id and do one polling iteration
        task_id = invocation.task_id
    elif isinstance(invocation, SubmitInvocation):
        # first-time submit; the call deadline is not checked here so that the task
        # gets at least one chance to deliver the result
        try:
            new_task = self._prepare_task(context, invocation)
            self._ctx.task_executor.submit(new_task)
            task_id = new_task.task_id
            fun_name = new_task.fun_name
        except Exception:
            _LOGGER.error("flexconnect_fun_submit_failed", exc_info=True, polling=True)
            raise
    else:
        # can be replaced by assert_never when we are on 3.11
        raise AssertionError

    try:
        task_result = self._ctx.task_executor.wait_for_result(task_id, timeout=self._poll_interval)
        return self._prepare_flight_info(task_id=task_id, task_result=task_result)
    except TimeoutError:
        # the deadline applies to the whole call duration, counted from the task submit
        submitted_at = self._ctx.task_executor.get_task_submitted_timestamp(task_id)
        deadline_exceeded = submitted_at is not None and time.perf_counter() - submitted_at > self._call_deadline

        if deadline_exceeded:
            self._ctx.task_executor.cancel(task_id)
            raise ErrorInfo.for_reason(
                ErrorCode.TIMEOUT, f"GetFlightInfo timed out while waiting for task {task_id}."
            ).to_timeout_error()

        # there is still time left; tell the client how to poll for the result
        raise _prepare_poll_error(task_id)
    except Exception:
        _LOGGER.error("get_flight_info_failed", task_id=task_id, fun=fun_name, exc_info=True, polling=True)
        raise

###################################################################
# Implementation of Flight RPC methods
###################################################################

def list_flights(
    self, context: pyarrow.flight.ServerCallContext, criteria: bytes
) -> Generator[pyarrow.flight.FlightInfo, None, None]:
    """
    List the functions held by the registry, one flight info per function.

    :param context: server call context; its peer is bound to the logging context
    :param criteria: listing criteria; not used by this implementation
    :return: generator producing a flight info for every function in the registry
    """
    structlog.contextvars.bind_contextvars(peer=context.peer())
    _LOGGER.info("list_flights", available_funs=self._registry.function_names)

    registered_funs = self._registry.functions.values()

    return (self._create_fun_info(fun_cls) for fun_cls in registered_funs)

def get_flight_info(
    self,
    context: pyarrow.flight.ServerCallContext,
    descriptor: pyarrow.flight.FlightDescriptor,
) -> pyarrow.flight.FlightInfo:
    """
    Serve the GetFlightInfo call, choosing between the polling and the non-polling flow.

    The polling flow is used when the call headers contain the polling header (its value
    does not matter); otherwise the non-polling flow handles the call.

    :param context: server call context; its peer is bound to the logging context
    :param descriptor: flight descriptor passed through to the selected flow
    :return: flight info produced by the selected flow
    """
    structlog.contextvars.bind_contextvars(peer=context.peer())

    call_headers = self.call_info_middleware(context).headers

    if call_headers.get(POLLING_HEADER_NAME) is None:
        return self._get_flight_info_no_polling(context, descriptor)

    return self._get_flight_info_polling(context, descriptor)

def do_get(
self,
context: pyarrow.flight.ServerCallContext,
Expand All @@ -201,7 +293,9 @@ def do_get(
# Name of the settings section; the keys below are looked up as "<section>.<key>".
_FLEX_CONNECT_CONFIG_SECTION = "flexconnect"
# Key of the setting that lists the modules passed to the function registry's load().
_FLEX_CONNECT_FUNCTION_LIST = "functions"
# Key of the setting with the call deadline, in milliseconds.
_FLEX_CONNECT_CALL_DEADLINE_MS = "call_deadline_ms"
# Key of the setting with the duration, in milliseconds, to wait for the result
# during one polling iteration.
_FLEX_CONNECT_POLLING_INTERVAL_MS = "polling_interval_ms"
# NOTE(review): presumably the call deadline used when the setting is absent (180 s) - confirm
_DEFAULT_FLEX_CONNECT_CALL_DEADLINE_MS = 180_000
# Polling interval (2 s) returned when the polling interval setting is absent.
_DEFAULT_FLEX_CONNECT_POLLING_INTERVAL_MS = 2000


def _read_call_deadline_ms(ctx: ServerContext) -> int:
Expand All @@ -223,6 +317,24 @@ def _read_call_deadline_ms(ctx: ServerContext) -> int:
)


def _read_polling_interval_ms(ctx: ServerContext) -> int:
    """
    Read the polling interval, in milliseconds, from the server settings.

    :param ctx: server context whose settings are consulted
    :return: the configured interval converted to int, or the default interval
     when the setting is not present
    :raises ValueError: when the configured value cannot be converted to an integer
     or is not positive
    """
    setting_name = f"{_FLEX_CONNECT_CONFIG_SECTION}.{_FLEX_CONNECT_POLLING_INTERVAL_MS}"

    raw_value = ctx.settings.get(setting_name)
    if raw_value is None:
        return _DEFAULT_FLEX_CONNECT_POLLING_INTERVAL_MS

    error_message = (
        f"Value of {setting_name} must "
        f"be a positive number - duration, in milliseconds, that FlexConnect function "
        f"waits for the result during one polling iteration."
    )

    try:
        polling_interval = int(raw_value)
    except (TypeError, ValueError) as e:
        # int() raises TypeError for non-scalar values (lists, tables); report it
        # as the same configuration error instead of leaking a bare TypeError
        raise ValueError(error_message) from e

    if polling_interval <= 0:
        raise ValueError(error_message)

    return polling_interval


@flight_server_methods
def create_flexconnect_flight_methods(ctx: ServerContext) -> FlightServerMethods:
"""
Expand All @@ -236,8 +348,9 @@ def create_flexconnect_flight_methods(ctx: ServerContext) -> FlightServerMethods
"""
modules = list(ctx.settings.get(f"{_FLEX_CONNECT_CONFIG_SECTION}.{_FLEX_CONNECT_FUNCTION_LIST}") or [])
call_deadline_ms = _read_call_deadline_ms(ctx)
polling_interval_ms = _read_polling_interval_ms(ctx)

_LOGGER.info("flexconnect_init", modules=modules)
registry = FlexConnectFunctionRegistry().load(ctx, modules)

return _FlexConnectServerMethods(ctx, registry, call_deadline_ms)
return _FlexConnectServerMethods(ctx, registry, call_deadline_ms, polling_interval_ms)
Loading
Loading