Skip to content

Commit de04b17

Browse files
authored
Merge pull request #1043 from no23reason/dho/cq-1124-polling-better
feat: add support for polling in FlexConnect server
2 parents f51941e + 173aa5b commit de04b17

File tree

8 files changed

+507
-63
lines changed

8 files changed

+507
-63
lines changed

gooddata-flexconnect/gooddata_flexconnect/function/flight_methods.py

Lines changed: 169 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# (C) 2024 GoodData Corporation
2+
import time
23
from collections.abc import Generator
34
from typing import Optional
45

@@ -17,17 +18,45 @@
1718
)
1819

1920
from gooddata_flexconnect.function.function import FlexConnectFunction
21+
from gooddata_flexconnect.function.function_invocation import (
22+
CancelInvocation,
23+
RetryInvocation,
24+
SubmitInvocation,
25+
extract_pollable_invocation_from_descriptor,
26+
extract_submit_invocation_from_descriptor,
27+
)
2028
from gooddata_flexconnect.function.function_registry import FlexConnectFunctionRegistry
2129
from gooddata_flexconnect.function.function_task import FlexConnectFunctionTask
2230

2331
_LOGGER = structlog.get_logger("gooddata_flexconnect.rpc")
2432

33+
POLLING_HEADER_NAME = "x-quiver-pollable"
34+
"""
35+
If this header is present on the get flight info call, the polling extension will be used.
36+
Otherwise the basic do get will be used.
37+
"""
38+
39+
40+
def _prepare_poll_error(task_id: str) -> pyarrow.flight.FlightError:
41+
return ErrorInfo.poll(
42+
flight_info=None,
43+
cancel_descriptor=pyarrow.flight.FlightDescriptor.for_command(f"c:{task_id}".encode()),
44+
retry_descriptor=pyarrow.flight.FlightDescriptor.for_command(f"r:{task_id}".encode()),
45+
)
46+
2547

2648
class _FlexConnectServerMethods(FlightServerMethods):
27-
def __init__(self, ctx: ServerContext, registry: FlexConnectFunctionRegistry, call_deadline_ms: float) -> None:
49+
def __init__(
50+
self,
51+
ctx: ServerContext,
52+
registry: FlexConnectFunctionRegistry,
53+
call_deadline_ms: float,
54+
poll_interval_ms: float,
55+
) -> None:
2856
self._ctx = ctx
2957
self._registry = registry
3058
self._call_deadline = call_deadline_ms / 1000
59+
self._poll_interval = poll_interval_ms / 1000
3160

3261
@staticmethod
3362
def _create_descriptor(fun_name: str, metadata: Optional[dict]) -> pyarrow.flight.FlightDescriptor:
@@ -52,58 +81,37 @@ def _create_fun_info(self, fun: type[FlexConnectFunction]) -> pyarrow.flight.Fli
5281
total_records=-1,
5382
)
5483

55-
def _extract_invocation_payload(
56-
self, descriptor: pyarrow.flight.FlightDescriptor
57-
) -> tuple[str, dict, Optional[tuple[str, ...]]]:
58-
if descriptor.command is None or not len(descriptor.command):
59-
raise ErrorInfo.bad_argument(
60-
"Incorrect FlexConnect function invocation. Flight descriptor must contain command "
61-
"with the invocation payload."
62-
)
63-
64-
try:
65-
payload = orjson.loads(descriptor.command)
66-
except Exception:
67-
raise ErrorInfo.bad_argument(
68-
"Incorrect FlexConnect function invocation. The invocation payload is not a valid JSON."
69-
)
70-
71-
fun = payload.get("functionName")
72-
if fun is None or not len(fun):
73-
raise ErrorInfo.bad_argument(
74-
"Incorrect FlexConnect function invocation. The invocation payload does not specify 'functionName'."
75-
)
76-
77-
parameters = payload.get("parameters") or {}
78-
columns = parameters.get("columns")
79-
80-
return fun, parameters, columns
81-
8284
def _prepare_task(
8385
self,
8486
context: pyarrow.flight.ServerCallContext,
85-
descriptor: pyarrow.flight.FlightDescriptor,
87+
submit_invocation: SubmitInvocation,
8688
) -> FlexConnectFunctionTask:
87-
fun_name, parameters, columns = self._extract_invocation_payload(descriptor)
8889
headers = self.call_info_middleware(context).headers
89-
fun = self._registry.create_function(fun_name)
90+
fun = self._registry.create_function(submit_invocation.function_name)
9091

9192
return FlexConnectFunctionTask(
9293
fun=fun,
93-
parameters=parameters,
94-
columns=columns,
94+
parameters=submit_invocation.parameters,
95+
columns=submit_invocation.columns,
9596
headers=headers,
96-
cmd=descriptor.command,
97+
cmd=submit_invocation.command,
9798
)
9899

99-
def _prepare_flight_info(self, task_result: TaskExecutionResult) -> pyarrow.flight.FlightInfo:
100+
def _prepare_flight_info(
101+
self, task_id: str, task_result: Optional[TaskExecutionResult]
102+
) -> pyarrow.flight.FlightInfo:
103+
if task_result is None:
104+
raise ErrorInfo.for_reason(
105+
ErrorCode.BAD_ARGUMENT, f"Task with id '{task_id}' does not exist."
106+
).to_user_error()
107+
100108
if task_result.error is not None:
101109
raise task_result.error.as_flight_error()
102110

103111
if task_result.cancelled:
104112
raise ErrorInfo.for_reason(
105113
ErrorCode.COMMAND_CANCELLED,
106-
f"FlexConnect function invocation was cancelled. Invocation task was: '{task_result.task_id}'.",
114+
f"FlexConnect function invocation was cancelled. Invocation task was: '{task_id}'.",
107115
).to_server_error()
108116

109117
result = task_result.result
@@ -114,40 +122,33 @@ def _prepare_flight_info(self, task_result: TaskExecutionResult) -> pyarrow.flig
114122
descriptor=pyarrow.flight.FlightDescriptor.for_command(task_result.cmd),
115123
endpoints=[
116124
pyarrow.flight.FlightEndpoint(
117-
ticket=pyarrow.flight.Ticket(ticket=orjson.dumps({"task_id": task_result.task_id})),
125+
ticket=pyarrow.flight.Ticket(ticket=orjson.dumps({"task_id": task_id})),
118126
locations=[self._ctx.location],
119127
)
120128
],
121129
total_records=-1,
122130
total_bytes=-1,
123131
)
124132

125-
###################################################################
126-
# Implementation of Flight RPC methods
127-
###################################################################
128-
129-
def list_flights(
130-
self, context: pyarrow.flight.ServerCallContext, criteria: bytes
131-
) -> Generator[pyarrow.flight.FlightInfo, None, None]:
132-
structlog.contextvars.bind_contextvars(peer=context.peer())
133-
_LOGGER.info("list_flights", available_funs=self._registry.function_names)
134-
135-
return (self._create_fun_info(fun) for fun in self._registry.functions.values())
136-
137-
def get_flight_info(
133+
def _get_flight_info_no_polling(
138134
self,
139135
context: pyarrow.flight.ServerCallContext,
140136
descriptor: pyarrow.flight.FlightDescriptor,
141137
) -> pyarrow.flight.FlightInfo:
138+
"""
139+
Basic DoGetInfo flow with no polling extension.
140+
This conforms to the mainline Arrow Flight RPC specification.
141+
"""
142142
structlog.contextvars.bind_contextvars(peer=context.peer())
143+
invocation = extract_submit_invocation_from_descriptor(descriptor)
144+
143145
task: Optional[FlexConnectFunctionTask] = None
144146

145147
try:
146-
task = self._prepare_task(context, descriptor)
148+
task = self._prepare_task(context, invocation)
147149
self._ctx.task_executor.submit(task)
148150

149151
try:
150-
# XXX: this should be enhanced to implement polling
151152
task_result = self._ctx.task_executor.wait_for_result(task.task_id, self._call_deadline)
152153
except TaskWaitTimeoutError:
153154
cancelled = self._ctx.task_executor.cancel(task.task_id)
@@ -166,15 +167,106 @@ def get_flight_info(
166167
# particular task id finished
167168
assert task_result is not None
168169

169-
return self._prepare_flight_info(task_result)
170+
return self._prepare_flight_info(task_id=task.task_id, task_result=task_result)
170171
except Exception:
171172
if task is not None:
172-
_LOGGER.error("get_flight_info_failed", task_id=task.task_id, fun=task.fun_name, exc_info=True)
173+
_LOGGER.error(
174+
"get_flight_info_failed", task_id=task.task_id, fun=task.fun_name, exc_info=True, polling=False
175+
)
173176
else:
174-
_LOGGER.error("flexconnect_fun_submit_failed", exc_info=True)
177+
_LOGGER.error("flexconnect_fun_submit_failed", exc_info=True, polling=False)
178+
raise
179+
180+
def _get_flight_info_polling(
181+
self,
182+
context: pyarrow.flight.ServerCallContext,
183+
descriptor: pyarrow.flight.FlightDescriptor,
184+
) -> pyarrow.flight.FlightInfo:
185+
"""
186+
DoGetInfo flow with polling extension.
187+
This extends the mainline Arrow Flight RPC specification with polling capabilities using the RetryInfo
188+
encoded into the FlightTimedOutError.extra_info.
189+
Ideally, we would use the mainline PollFlightInfo, but that has yet to be implemented in the PyArrow library.
190+
"""
191+
structlog.contextvars.bind_contextvars(peer=context.peer())
192+
invocation = extract_pollable_invocation_from_descriptor(descriptor)
193+
194+
task_id: str
195+
fun_name: Optional[str] = None
196+
197+
if isinstance(invocation, CancelInvocation):
198+
# cancel the given task and raise cancellation exception
199+
if self._ctx.task_executor.cancel(invocation.task_id):
200+
raise ErrorInfo.for_reason(
201+
ErrorCode.COMMAND_CANCELLED, "FlexConnect function invocation was cancelled."
202+
).to_cancelled_error()
203+
raise ErrorInfo.for_reason(
204+
ErrorCode.COMMAND_CANCEL_NOT_POSSIBLE, "FlexConnect function invocation could not be cancelled."
205+
).to_cancelled_error()
206+
elif isinstance(invocation, RetryInvocation):
207+
# retry descriptor: extract the task_id, do not submit it again and do one polling iteration
208+
task_id = invocation.task_id
209+
elif isinstance(invocation, SubmitInvocation):
210+
# basic first-time submit: submit the task and do one polling iteration.
211+
# do not check call deadline to give it a chance to wait for the result at least once
212+
try:
213+
task = self._prepare_task(context, invocation)
214+
self._ctx.task_executor.submit(task)
215+
task_id = task.task_id
216+
fun_name = task.fun_name
217+
except Exception:
218+
_LOGGER.error("flexconnect_fun_submit_failed", exc_info=True, polling=True)
219+
raise
220+
else:
221+
# can be replaced by assert_never when we are on 3.11
222+
raise AssertionError
223+
224+
try:
225+
task_result = self._ctx.task_executor.wait_for_result(task_id, timeout=self._poll_interval)
226+
return self._prepare_flight_info(task_id, task_result)
227+
except TimeoutError:
228+
# first, check the call deadline for the whole call duration
229+
task_timestamp = self._ctx.task_executor.get_task_submitted_timestamp(task_id)
230+
if task_timestamp is not None and time.perf_counter() - task_timestamp > self._call_deadline:
231+
self._ctx.task_executor.cancel(task_id)
232+
raise ErrorInfo.for_reason(
233+
ErrorCode.TIMEOUT, f"GetFlightInfo timed out while waiting for task {task_id}."
234+
).to_timeout_error()
175235

236+
# if the result is not ready, and we still have time, indicate to the client
237+
# how to poll for the results
238+
raise _prepare_poll_error(task_id)
239+
except Exception:
240+
_LOGGER.error("get_flight_info_failed", task_id=task_id, fun=fun_name, exc_info=True, polling=True)
176241
raise
177242

243+
###################################################################
244+
# Implementation of Flight RPC methods
245+
###################################################################
246+
247+
def list_flights(
248+
self, context: pyarrow.flight.ServerCallContext, criteria: bytes
249+
) -> Generator[pyarrow.flight.FlightInfo, None, None]:
250+
structlog.contextvars.bind_contextvars(peer=context.peer())
251+
_LOGGER.info("list_flights", available_funs=self._registry.function_names)
252+
253+
return (self._create_fun_info(fun) for fun in self._registry.functions.values())
254+
255+
def get_flight_info(
256+
self,
257+
context: pyarrow.flight.ServerCallContext,
258+
descriptor: pyarrow.flight.FlightDescriptor,
259+
) -> pyarrow.flight.FlightInfo:
260+
structlog.contextvars.bind_contextvars(peer=context.peer())
261+
262+
headers = self.call_info_middleware(context).headers
263+
allow_polling = headers.get(POLLING_HEADER_NAME) is not None
264+
265+
if allow_polling:
266+
return self._get_flight_info_polling(context, descriptor)
267+
else:
268+
return self._get_flight_info_no_polling(context, descriptor)
269+
178270
def do_get(
179271
self,
180272
context: pyarrow.flight.ServerCallContext,
@@ -201,7 +293,9 @@ def do_get(
201293
_FLEX_CONNECT_CONFIG_SECTION = "flexconnect"
202294
_FLEX_CONNECT_FUNCTION_LIST = "functions"
203295
_FLEX_CONNECT_CALL_DEADLINE_MS = "call_deadline_ms"
296+
_FLEX_CONNECT_POLLING_INTERVAL_MS = "polling_interval_ms"
204297
_DEFAULT_FLEX_CONNECT_CALL_DEADLINE_MS = 180_000
298+
_DEFAULT_FLEX_CONNECT_POLLING_INTERVAL_MS = 2000
205299

206300

207301
def _read_call_deadline_ms(ctx: ServerContext) -> int:
@@ -223,6 +317,24 @@ def _read_call_deadline_ms(ctx: ServerContext) -> int:
223317
)
224318

225319

320+
def _read_polling_interval_ms(ctx: ServerContext) -> int:
321+
polling_interval = ctx.settings.get(f"{_FLEX_CONNECT_CONFIG_SECTION}.{_FLEX_CONNECT_POLLING_INTERVAL_MS}")
322+
if polling_interval is None:
323+
return _DEFAULT_FLEX_CONNECT_POLLING_INTERVAL_MS
324+
325+
try:
326+
polling_interval = int(polling_interval)
327+
if polling_interval <= 0:
328+
raise ValueError()
329+
return polling_interval
330+
except ValueError:
331+
raise ValueError(
332+
f"Value of {_FLEX_CONNECT_CONFIG_SECTION}.{_FLEX_CONNECT_POLLING_INTERVAL_MS} must "
333+
f"be a positive number - duration, in milliseconds, that FlexConnect function "
334+
f"waits for the result during one polling iteration."
335+
)
336+
337+
226338
@flight_server_methods
227339
def create_flexconnect_flight_methods(ctx: ServerContext) -> FlightServerMethods:
228340
"""
@@ -236,8 +348,9 @@ def create_flexconnect_flight_methods(ctx: ServerContext) -> FlightServerMethods
236348
"""
237349
modules = list(ctx.settings.get(f"{_FLEX_CONNECT_CONFIG_SECTION}.{_FLEX_CONNECT_FUNCTION_LIST}") or [])
238350
call_deadline_ms = _read_call_deadline_ms(ctx)
351+
polling_interval_ms = _read_polling_interval_ms(ctx)
239352

240353
_LOGGER.info("flexconnect_init", modules=modules)
241354
registry = FlexConnectFunctionRegistry().load(ctx, modules)
242355

243-
return _FlexConnectServerMethods(ctx, registry, call_deadline_ms)
356+
return _FlexConnectServerMethods(ctx, registry, call_deadline_ms, polling_interval_ms)

0 commit comments

Comments
 (0)