Skip to content

Commit 64aad67

Browse files
committed
feat: make the polling extension opt-in
By default, the FlexConnect server now conforms to the Arrow Flight RPC spec: GetFlightInfo waits for the function result and returns the flight info directly. If the opt-in `x-quiver-pollable` header is present on the GetFlightInfo call, the server instead uses the polling extension used by GoodData, which allows for things like query cancellation. Ideally, we would use PollFlightInfo from the Arrow Flight RPC spec, but unfortunately it is not yet available in PyArrow. JIRA: CQ-1124; risk: low
1 parent 2541c61 commit 64aad67

File tree

3 files changed

+139
-48
lines changed

3 files changed

+139
-48
lines changed

gooddata-flexconnect/gooddata_flexconnect/function/flight_methods.py

Lines changed: 90 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
FlightServerMethods,
1414
ServerContext,
1515
TaskExecutionResult,
16+
TaskWaitTimeoutError,
1617
flight_server_methods,
1718
)
1819

@@ -21,13 +22,20 @@
2122
CancelInvocation,
2223
RetryInvocation,
2324
SubmitInvocation,
24-
extract_invocation_from_descriptor,
25+
extract_pollable_invocation_from_descriptor,
26+
extract_submit_invocation_from_descriptor,
2527
)
2628
from gooddata_flexconnect.function.function_registry import FlexConnectFunctionRegistry
2729
from gooddata_flexconnect.function.function_task import FlexConnectFunctionTask
2830

2931
_LOGGER = structlog.get_logger("gooddata_flexconnect.rpc")
3032

33+
POLLING_HEADER_NAME = "x-quiver-pollable"
34+
"""
35+
If this header is present on the get flight info call, the polling extension will be used.
36+
Otherwise the basic do get will be used.
37+
"""
38+
3139

3240
def _prepare_poll_error(task_id: str) -> pyarrow.flight.FlightError:
3341
return ErrorInfo.poll(
@@ -122,28 +130,69 @@ def _prepare_flight_info(
122130
total_bytes=-1,
123131
)
124132

125-
###################################################################
126-
# Implementation of Flight RPC methods
127-
###################################################################
128-
129-
def list_flights(
130-
self, context: pyarrow.flight.ServerCallContext, criteria: bytes
131-
) -> Generator[pyarrow.flight.FlightInfo, None, None]:
133+
def _get_flight_info_no_polling(
134+
self,
135+
context: pyarrow.flight.ServerCallContext,
136+
descriptor: pyarrow.flight.FlightDescriptor,
137+
) -> pyarrow.flight.FlightInfo:
138+
"""
139+
Basic DoGetInfo flow with no polling extension.
140+
This conforms to the mainline Arrow Flight RPC specification.
141+
"""
132142
structlog.contextvars.bind_contextvars(peer=context.peer())
133-
_LOGGER.info("list_flights", available_funs=self._registry.function_names)
143+
invocation = extract_submit_invocation_from_descriptor(descriptor)
134144

135-
return (self._create_fun_info(fun) for fun in self._registry.functions.values())
145+
task: Optional[FlexConnectFunctionTask] = None
136146

137-
def get_flight_info(
147+
try:
148+
task = self._prepare_task(context, invocation)
149+
self._ctx.task_executor.submit(task)
150+
151+
try:
152+
task_result = self._ctx.task_executor.wait_for_result(task.task_id, self._call_deadline)
153+
except TaskWaitTimeoutError:
154+
cancelled = self._ctx.task_executor.cancel(task.task_id)
155+
_LOGGER.warning(
156+
"flexconnect_fun_call_timeout", task_id=task.task_id, fun=task.fun_name, cancelled=cancelled
157+
)
158+
159+
raise ErrorInfo.for_reason(
160+
ErrorCode.TIMEOUT, f"GetFlightInfo timed out while waiting for task {task.task_id}."
161+
).to_timeout_error()
162+
163+
# if this bombs then there must be something really wrong because the task
164+
# was clearly submitted and code was waiting for its completion. this invariant
165+
# should not happen in this particular code path. The None return value may
166+
# be applicable one day when polling is in use and a request comes to check whether
167+
# particular task id finished
168+
assert task_result is not None
169+
170+
return self._prepare_flight_info(task_id=task.task_id, task_result=task_result)
171+
except Exception:
172+
if task is not None:
173+
_LOGGER.error(
174+
"get_flight_info_failed", task_id=task.task_id, fun=task.fun_name, exc_info=True, polling=False
175+
)
176+
else:
177+
_LOGGER.error("flexconnect_fun_submit_failed", exc_info=True, polling=False)
178+
raise
179+
180+
def _get_flight_info_polling(
138181
self,
139182
context: pyarrow.flight.ServerCallContext,
140183
descriptor: pyarrow.flight.FlightDescriptor,
141184
) -> pyarrow.flight.FlightInfo:
185+
"""
186+
DoGetInfo flow with polling extension.
187+
This extends the mainline Arrow Flight RPC specification with polling capabilities using the RetryInfo
188+
encoded into the FlightTimedOutError.extra_info.
189+
Ideally, we would use the mainline PollFlightInfo, but that has yet to be implemented in the PyArrow library.
190+
"""
142191
structlog.contextvars.bind_contextvars(peer=context.peer())
192+
invocation = extract_pollable_invocation_from_descriptor(descriptor)
143193

144194
task_id: str
145195
fun_name: Optional[str] = None
146-
invocation = extract_invocation_from_descriptor(descriptor)
147196

148197
if isinstance(invocation, CancelInvocation):
149198
# cancel the given task and raise cancellation exception
@@ -166,7 +215,7 @@ def get_flight_info(
166215
task_id = task.task_id
167216
fun_name = task.fun_name
168217
except Exception:
169-
_LOGGER.error("flexconnect_fun_submit_failed", exc_info=True)
218+
_LOGGER.error("flexconnect_fun_submit_failed", exc_info=True, polling=True)
170219
raise
171220
else:
172221
# can be replaced by assert_never when we are on 3.11
@@ -188,9 +237,36 @@ def get_flight_info(
188237
# how to poll for the results
189238
raise _prepare_poll_error(task_id)
190239
except Exception:
191-
_LOGGER.error("get_flight_info_failed", task_id=task_id, fun=fun_name, exc_info=True)
240+
_LOGGER.error("get_flight_info_failed", task_id=task_id, fun=fun_name, exc_info=True, polling=True)
192241
raise
193242

243+
###################################################################
244+
# Implementation of Flight RPC methods
245+
###################################################################
246+
247+
def list_flights(
248+
self, context: pyarrow.flight.ServerCallContext, criteria: bytes
249+
) -> Generator[pyarrow.flight.FlightInfo, None, None]:
250+
structlog.contextvars.bind_contextvars(peer=context.peer())
251+
_LOGGER.info("list_flights", available_funs=self._registry.function_names)
252+
253+
return (self._create_fun_info(fun) for fun in self._registry.functions.values())
254+
255+
def get_flight_info(
256+
self,
257+
context: pyarrow.flight.ServerCallContext,
258+
descriptor: pyarrow.flight.FlightDescriptor,
259+
) -> pyarrow.flight.FlightInfo:
260+
structlog.contextvars.bind_contextvars(peer=context.peer())
261+
262+
headers = self.call_info_middleware(context).headers
263+
allow_polling = headers.get(POLLING_HEADER_NAME) is not None
264+
265+
if allow_polling:
266+
return self._get_flight_info_polling(context, descriptor)
267+
else:
268+
return self._get_flight_info_no_polling(context, descriptor)
269+
194270
def do_get(
195271
self,
196272
context: pyarrow.flight.ServerCallContext,

gooddata-flexconnect/gooddata_flexconnect/function/function_invocation.py

Lines changed: 31 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -53,42 +53,51 @@ class SubmitInvocation:
5353
"""
5454

5555

56-
def extract_invocation_from_descriptor(
56+
def extract_submit_invocation_from_descriptor(descriptor: pyarrow.flight.FlightDescriptor) -> SubmitInvocation:
57+
"""
58+
Given a flight descriptor, extract the invocation information from it.
59+
Do not allow the polling-related variants.
60+
"""
61+
try:
62+
payload = orjson.loads(descriptor.command)
63+
except Exception:
64+
raise ErrorInfo.bad_argument(
65+
"Incorrect FlexConnect function invocation. The invocation payload is not a valid JSON."
66+
)
67+
68+
function_name = payload.get("functionName")
69+
if function_name is None or not len(function_name):
70+
raise ErrorInfo.bad_argument(
71+
"Incorrect FlexConnect function invocation. The invocation payload does not specify 'functionName'."
72+
)
73+
74+
parameters = payload.get("parameters") or {}
75+
columns = parameters.get("columns")
76+
77+
return SubmitInvocation(
78+
function_name=function_name, parameters=parameters, columns=columns, command=descriptor.command
79+
)
80+
81+
82+
def extract_pollable_invocation_from_descriptor(
5783
descriptor: pyarrow.flight.FlightDescriptor,
5884
) -> Union[RetryInvocation, CancelInvocation, SubmitInvocation]:
5985
"""
6086
Given a flight descriptor, extract the invocation information from it.
87+
Allow also the polling-related variants.
6188
"""
62-
6389
if descriptor.command is None or not len(descriptor.command):
6490
raise ErrorInfo.bad_argument(
6591
"Incorrect FlexConnect function invocation. Flight descriptor must contain command "
6692
"with the invocation payload."
6793
)
6894

95+
# we are in the polling-enabled realm: try parsing the retry and cancel descriptors first
6996
if descriptor.command.startswith(b"c:"):
7097
task_id = descriptor.command[2:].decode()
7198
return CancelInvocation(task_id)
7299
elif descriptor.command.startswith(b"r:"):
73100
task_id = descriptor.command[2:].decode()
74101
return RetryInvocation(task_id)
75-
else:
76-
try:
77-
payload = orjson.loads(descriptor.command)
78-
except Exception:
79-
raise ErrorInfo.bad_argument(
80-
"Incorrect FlexConnect function invocation. The invocation payload is not a valid JSON."
81-
)
82-
83-
function_name = payload.get("functionName")
84-
if function_name is None or not len(function_name):
85-
raise ErrorInfo.bad_argument(
86-
"Incorrect FlexConnect function invocation. The invocation payload does not specify 'functionName'."
87-
)
88-
89-
parameters = payload.get("parameters") or {}
90-
columns = parameters.get("columns")
91-
92-
return SubmitInvocation(
93-
function_name=function_name, parameters=parameters, columns=columns, command=descriptor.command
94-
)
102+
103+
return extract_submit_invocation_from_descriptor(descriptor)

gooddata-flexconnect/tests/server/test_flexconnect_server.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,18 @@
33
import orjson
44
import pyarrow.flight
55
import pytest
6+
from gooddata_flexconnect.function.flight_methods import POLLING_HEADER_NAME
67
from gooddata_flight_server import ErrorCode, ErrorInfo, RetryInfo
78

89
from tests.assert_error_info import assert_error_code
910
from tests.server.conftest import flexconnect_server
1011

1112

13+
@pytest.fixture
14+
def call_options_with_polling():
15+
return pyarrow.flight.FlightCallOptions(headers=[(POLLING_HEADER_NAME.encode(), b"true")])
16+
17+
1218
def test_basic_function():
1319
"""
1420
This function should return immediately when called, no polling necessary.
@@ -94,7 +100,7 @@ def test_basic_function_tls(tls_ca_cert):
94100
assert data.column_names == ["col1", "col2", "col3"]
95101

96102

97-
def test_function_with_polling():
103+
def test_function_with_polling(call_options_with_polling):
98104
"""
99105
Flight RPC implementation that invokes FlexConnect can return a polling info.
100106
@@ -114,7 +120,7 @@ def test_function_with_polling():
114120
# the function is set to sleep a bit longer than the polling interval,
115121
# so the first iteration returns retry info in the exception
116122
with pytest.raises(pyarrow.flight.FlightTimedOutError) as e:
117-
c.get_flight_info(descriptor)
123+
c.get_flight_info(descriptor, call_options_with_polling)
118124

119125
assert e.value is not None
120126
assert_error_code(ErrorCode.POLL, e.value)
@@ -124,21 +130,21 @@ def test_function_with_polling():
124130

125131
# use the retry info to poll again for the result,
126132
# now it should be ready and returned normally
127-
info = c.get_flight_info(retry_info.retry_descriptor)
133+
info = c.get_flight_info(retry_info.retry_descriptor, call_options_with_polling)
128134
data: pyarrow.Table = c.do_get(info.endpoints[0].ticket).read_all()
129135

130136
assert len(data) == 3
131137
assert data.column_names == ["col1", "col2", "col3"]
132138

133139
# also check that trying to cancel already completed task results in cancelled with correct code
134140
with pytest.raises(pyarrow.flight.FlightCancelledError) as e:
135-
c.get_flight_info(retry_info.cancel_descriptor)
141+
c.get_flight_info(retry_info.cancel_descriptor, call_options_with_polling)
136142

137143
assert e.value is not None
138144
assert_error_code(ErrorCode.COMMAND_CANCELLED, e.value)
139145

140146

141-
def test_function_with_call_deadline():
147+
def test_function_with_call_deadline(call_options_with_polling):
142148
"""
143149
Flight RPC implementation that invokes FlexConnect can be setup with
144150
deadline for the invocation duration (done by GetFlightInfo).
@@ -163,7 +169,7 @@ def test_function_with_call_deadline():
163169

164170
# the initial submit returns polling info
165171
with pytest.raises(pyarrow.flight.FlightTimedOutError) as e:
166-
c.get_flight_info(descriptor)
172+
c.get_flight_info(descriptor, call_options_with_polling)
167173

168174
assert e.value is not None
169175
assert_error_code(ErrorCode.POLL, e.value)
@@ -173,19 +179,19 @@ def test_function_with_call_deadline():
173179

174180
# the next poll still returns polling info
175181
with pytest.raises(pyarrow.flight.FlightTimedOutError) as e:
176-
c.get_flight_info(retry_info.retry_descriptor)
182+
c.get_flight_info(retry_info.retry_descriptor, call_options_with_polling)
177183

178184
assert e.value is not None
179185
assert_error_code(ErrorCode.POLL, e.value)
180186

181187
# the third one reaches the deadline so the Timeout code is returned instead
182188
with pytest.raises(pyarrow.flight.FlightTimedOutError) as e:
183-
c.get_flight_info(retry_info.retry_descriptor)
189+
c.get_flight_info(retry_info.retry_descriptor, call_options_with_polling)
184190

185191
assert_error_code(ErrorCode.TIMEOUT, e.value)
186192

187193

188-
def test_function_with_cancellation():
194+
def test_function_with_cancellation(call_options_with_polling):
189195
"""
190196
Run a long-running function and cancel it after one poll iteration.
191197
"""
@@ -202,7 +208,7 @@ def test_function_with_cancellation():
202208

203209
# the initial submit returns polling info
204210
with pytest.raises(pyarrow.flight.FlightTimedOutError) as e:
205-
c.get_flight_info(descriptor)
211+
c.get_flight_info(descriptor, call_options_with_polling)
206212

207213
assert e.value is not None
208214
assert_error_code(ErrorCode.POLL, e.value)
@@ -212,14 +218,14 @@ def test_function_with_cancellation():
212218

213219
# use the poll info to cancel the task
214220
with pytest.raises(pyarrow.flight.FlightCancelledError) as e:
215-
c.get_flight_info(retry_info.cancel_descriptor)
221+
c.get_flight_info(retry_info.cancel_descriptor, call_options_with_polling)
216222

217223
assert e.value is not None
218224
assert_error_code(ErrorCode.COMMAND_CANCELLED, e.value)
219225

220226
# even multiple cancellations return the same error
221227
with pytest.raises(pyarrow.flight.FlightCancelledError) as e:
222-
c.get_flight_info(retry_info.cancel_descriptor)
228+
c.get_flight_info(retry_info.cancel_descriptor, call_options_with_polling)
223229

224230
assert e.value is not None
225231
assert_error_code(ErrorCode.COMMAND_CANCELLED, e.value)

0 commit comments

Comments
 (0)