Skip to content

Commit fb48dd5

Browse files
committed
fix: handle cancel return values in FlexConnect
When cancelling, be more transparent about whether the cancellation actually happened. Also handle missing tasks: requesting a task that does not exist now raises a BAD_ARGUMENT error. JIRA: CQ-1124; risk: low
1 parent 0a40f24 commit fb48dd5

File tree

5 files changed

+55
-23
lines changed

5 files changed

+55
-23
lines changed

gooddata-flexconnect/gooddata_flexconnect/function/flight_methods.py

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -89,14 +89,21 @@ def _prepare_task(
8989
cmd=submit_invocation.command,
9090
)
9191

92-
def _prepare_flight_info(self, task_result: TaskExecutionResult) -> pyarrow.flight.FlightInfo:
92+
def _prepare_flight_info(
93+
self, task_id: str, task_result: Optional[TaskExecutionResult]
94+
) -> pyarrow.flight.FlightInfo:
95+
if task_result is None:
96+
raise ErrorInfo.for_reason(
97+
ErrorCode.BAD_ARGUMENT, f"Task with id '{task_id}' does not exist."
98+
).to_user_error()
99+
93100
if task_result.error is not None:
94101
raise task_result.error.as_flight_error()
95102

96103
if task_result.cancelled:
97104
raise ErrorInfo.for_reason(
98105
ErrorCode.COMMAND_CANCELLED,
99-
f"FlexConnect function invocation was cancelled. Invocation task was: '{task_result.task_id}'.",
106+
f"FlexConnect function invocation was cancelled. Invocation task was: '{task_id}'.",
100107
).to_server_error()
101108

102109
result = task_result.result
@@ -107,7 +114,7 @@ def _prepare_flight_info(self, task_result: TaskExecutionResult) -> pyarrow.flig
107114
descriptor=pyarrow.flight.FlightDescriptor.for_command(task_result.cmd),
108115
endpoints=[
109116
pyarrow.flight.FlightEndpoint(
110-
ticket=pyarrow.flight.Ticket(ticket=orjson.dumps({"task_id": task_result.task_id})),
117+
ticket=pyarrow.flight.Ticket(ticket=orjson.dumps({"task_id": task_id})),
111118
locations=[self._ctx.location],
112119
)
113120
],
@@ -140,20 +147,16 @@ def get_flight_info(
140147

141148
if isinstance(invocation, CancelInvocation):
142149
# cancel the given task and raise cancellation exception
143-
self._ctx.task_executor.cancel(invocation.task_id)
150+
if self._ctx.task_executor.cancel(invocation.task_id):
151+
raise ErrorInfo.for_reason(
152+
ErrorCode.COMMAND_CANCELLED, "FlexConnect function invocation was cancelled."
153+
).to_cancelled_error()
144154
raise ErrorInfo.for_reason(
145-
ErrorCode.COMMAND_CANCELLED, "FlexConnect function invocation was cancelled."
155+
ErrorCode.COMMAND_CANCEL_NOT_POSSIBLE, "FlexConnect function invocation could not be cancelled."
146156
).to_cancelled_error()
147157
elif isinstance(invocation, RetryInvocation):
148158
# retry descriptor: extract the task_id, do not submit it again and do one polling iteration
149159
task_id = invocation.task_id
150-
# for retries, we also need to check the call deadline for the whole call duration
151-
task_timestamp = self._ctx.task_executor.get_task_submitted_timestamp(task_id)
152-
if task_timestamp is not None and time.perf_counter() - task_timestamp > self._call_deadline:
153-
self._ctx.task_executor.cancel(task_id)
154-
raise ErrorInfo.for_reason(
155-
ErrorCode.TIMEOUT, f"GetFlightInfo timed out while waiting for task {task_id}."
156-
).to_timeout_error()
157160
elif isinstance(invocation, SubmitInvocation):
158161
# basic first-time submit: submit the task and do one polling iteration.
159162
# do not check call deadline to give it a chance to wait for the result at least once
@@ -171,8 +174,18 @@ def get_flight_info(
171174

172175
try:
173176
task_result = self._ctx.task_executor.wait_for_result(task_id, timeout=self._poll_interval)
174-
return self._prepare_flight_info(task_result)
177+
return self._prepare_flight_info(task_id, task_result)
175178
except TimeoutError:
179+
# first, check the call deadline for the whole call duration
180+
task_timestamp = self._ctx.task_executor.get_task_submitted_timestamp(task_id)
181+
if task_timestamp is not None and time.perf_counter() - task_timestamp > self._call_deadline:
182+
self._ctx.task_executor.cancel(task_id)
183+
raise ErrorInfo.for_reason(
184+
ErrorCode.TIMEOUT, f"GetFlightInfo timed out while waiting for task {task_id}."
185+
).to_timeout_error()
186+
187+
# if the result is not ready, and we still have time, indicate to the client
188+
# how to poll for the results
176189
raise _prepare_poll_error(task_id)
177190
except Exception:
178191
_LOGGER.error("get_flight_info_failed", task_id=task_id, fun=fun_name, exc_info=True)

gooddata-flexconnect/tests/server/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def flexconnect_server(
8080
funs = f"[{funs}]"
8181

8282
os.environ["GOODDATA_FLIGHT_FLEXCONNECT__FUNCTIONS"] = funs
83-
os.environ["GOODDATA_FLIGHT_FLEXCONNECT__CALL_DEADLINE_MS"] = "1000"
83+
os.environ["GOODDATA_FLIGHT_FLEXCONNECT__CALL_DEADLINE_MS"] = "1200"
8484
os.environ["GOODDATA_FLIGHT_FLEXCONNECT__POLLING_INTERVAL_MS"] = "500"
8585

8686
with server(create_flexconnect_flight_methods, tls, mtls) as s:

gooddata-flexconnect/tests/server/funs/fun3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def call(
2727
) -> ArrowData:
2828
# sleep is intentionally set up to be longer than the deadline for
2929
# the function invocation (see conftest.py // flexconnect_server fixture)
30-
time.sleep(1.5)
30+
time.sleep(2)
3131

3232
return pyarrow.table(
3333
data={

gooddata-flexconnect/tests/server/test_flexconnect_server.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@
1010

1111

1212
def test_basic_function():
13+
"""
14+
This function should return immediately when called, no polling necessary.
15+
:return:
16+
"""
1317
with flexconnect_server(["tests.server.funs.fun1"]) as s:
1418
c = pyarrow.flight.FlightClient(s.location)
1519
fun_infos = list(c.list_flights())
@@ -126,6 +130,13 @@ def test_function_with_polling():
126130
assert len(data) == 3
127131
assert data.column_names == ["col1", "col2", "col3"]
128132

133+
# also check that trying to cancel an already completed task results in a cancelled error with the correct code
134+
with pytest.raises(pyarrow.flight.FlightCancelledError) as e:
135+
c.get_flight_info(retry_info.cancel_descriptor)
136+
137+
assert e.value is not None
138+
assert_error_code(ErrorCode.COMMAND_CANCELLED, e.value)
139+
129140

130141
def test_function_with_call_deadline():
131142
"""
@@ -150,6 +161,7 @@ def test_function_with_call_deadline():
150161
)
151162
)
152163

164+
# the initial submit returns polling info
153165
with pytest.raises(pyarrow.flight.FlightTimedOutError) as e:
154166
c.get_flight_info(descriptor)
155167

@@ -159,24 +171,21 @@ def test_function_with_call_deadline():
159171
error_info = ErrorInfo.from_bytes(e.value.extra_info)
160172
retry_info = RetryInfo.from_bytes(error_info.body)
161173

162-
# poll twice to reach the call deadline
174+
# the next poll still returns polling info
163175
with pytest.raises(pyarrow.flight.FlightTimedOutError) as e:
164176
c.get_flight_info(retry_info.retry_descriptor)
165177

166178
assert e.value is not None
167179
assert_error_code(ErrorCode.POLL, e.value)
168180

169-
error_info = ErrorInfo.from_bytes(e.value.extra_info)
170-
retry_info = RetryInfo.from_bytes(error_info.body)
171-
181+
# the third one reaches the deadline so the Timeout code is returned instead
172182
with pytest.raises(pyarrow.flight.FlightTimedOutError) as e:
173183
c.get_flight_info(retry_info.retry_descriptor)
174184

175-
# and then ensure the timeout error is returned
176185
assert_error_code(ErrorCode.TIMEOUT, e.value)
177186

178187

179-
def test_function_with_cancelation():
188+
def test_function_with_cancellation():
180189
"""
181190
Run a long-running function and cancel it after one poll iteration.
182191
"""
@@ -191,6 +200,7 @@ def test_function_with_cancelation():
191200
)
192201
)
193202

203+
# the initial submit returns polling info
194204
with pytest.raises(pyarrow.flight.FlightTimedOutError) as e:
195205
c.get_flight_info(descriptor)
196206

@@ -200,7 +210,16 @@ def test_function_with_cancelation():
200210
error_info = ErrorInfo.from_bytes(e.value.extra_info)
201211
retry_info = RetryInfo.from_bytes(error_info.body)
202212

213+
# use the poll info to cancel the task
214+
with pytest.raises(pyarrow.flight.FlightCancelledError) as e:
215+
c.get_flight_info(retry_info.cancel_descriptor)
216+
217+
assert e.value is not None
218+
assert_error_code(ErrorCode.COMMAND_CANCELLED, e.value)
219+
220+
# even multiple cancellations return the same error
203221
with pytest.raises(pyarrow.flight.FlightCancelledError) as e:
204222
c.get_flight_info(retry_info.cancel_descriptor)
205223

206224
assert e.value is not None
225+
assert_error_code(ErrorCode.COMMAND_CANCELLED, e.value)

gooddata-flight-server/gooddata_flight_server/tasks/thread_task_executor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -605,8 +605,8 @@ def cancel(self, task_id: str) -> bool:
605605
return True
606606

607607
if execution is None:
608-
# the task was not and is not running - cancel not possible
609-
return False
608+
# the task was not and is not running - cancel not necessary
609+
return True
610610

611611
return execution.cancel()
612612

0 commit comments

Comments
 (0)