Skip to content

Commit 0a40f24

Browse files
committed
refactor: extract the payload parsing
Encapsulate the different invocation types so that we can change the actual representation later if needed. JIRA: CQ-1124 risk: low
1 parent 3705c59 commit 0a40f24

File tree

2 files changed

+116
-48
lines changed

2 files changed

+116
-48
lines changed

gooddata-flexconnect/gooddata_flexconnect/function/flight_methods.py

Lines changed: 22 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,12 @@
1717
)
1818

1919
from gooddata_flexconnect.function.function import FlexConnectFunction
20+
from gooddata_flexconnect.function.function_invocation import (
21+
CancelInvocation,
22+
RetryInvocation,
23+
SubmitInvocation,
24+
extract_invocation_from_descriptor,
25+
)
2026
from gooddata_flexconnect.function.function_registry import FlexConnectFunctionRegistry
2127
from gooddata_flexconnect.function.function_task import FlexConnectFunctionTask
2228

@@ -67,48 +73,20 @@ def _create_fun_info(self, fun: type[FlexConnectFunction]) -> pyarrow.flight.Fli
6773
total_records=-1,
6874
)
6975

70-
def _extract_invocation_payload(
71-
self, descriptor: pyarrow.flight.FlightDescriptor
72-
) -> tuple[str, dict, Optional[tuple[str, ...]]]:
73-
if descriptor.command is None or not len(descriptor.command):
74-
raise ErrorInfo.bad_argument(
75-
"Incorrect FlexConnect function invocation. Flight descriptor must contain command "
76-
"with the invocation payload."
77-
)
78-
79-
try:
80-
payload = orjson.loads(descriptor.command)
81-
except Exception:
82-
raise ErrorInfo.bad_argument(
83-
"Incorrect FlexConnect function invocation. The invocation payload is not a valid JSON."
84-
)
85-
86-
fun = payload.get("functionName")
87-
if fun is None or not len(fun):
88-
raise ErrorInfo.bad_argument(
89-
"Incorrect FlexConnect function invocation. The invocation payload does not specify 'functionName'."
90-
)
91-
92-
parameters = payload.get("parameters") or {}
93-
columns = parameters.get("columns")
94-
95-
return fun, parameters, columns
96-
9776
def _prepare_task(
9877
self,
9978
context: pyarrow.flight.ServerCallContext,
100-
descriptor: pyarrow.flight.FlightDescriptor,
79+
submit_invocation: SubmitInvocation,
10180
) -> FlexConnectFunctionTask:
102-
fun_name, parameters, columns = self._extract_invocation_payload(descriptor)
10381
headers = self.call_info_middleware(context).headers
104-
fun = self._registry.create_function(fun_name)
82+
fun = self._registry.create_function(submit_invocation.function_name)
10583

10684
return FlexConnectFunctionTask(
10785
fun=fun,
108-
parameters=parameters,
109-
columns=columns,
86+
parameters=submit_invocation.parameters,
87+
columns=submit_invocation.columns,
11088
headers=headers,
111-
cmd=descriptor.command,
89+
cmd=submit_invocation.command,
11290
)
11391

11492
def _prepare_flight_info(self, task_result: TaskExecutionResult) -> pyarrow.flight.FlightInfo:
@@ -156,44 +134,40 @@ def get_flight_info(
156134
) -> pyarrow.flight.FlightInfo:
157135
structlog.contextvars.bind_contextvars(peer=context.peer())
158136

159-
# first, check if the descriptor is a cancel descriptor
160-
if descriptor.command is None or not len(descriptor.command):
161-
raise ErrorInfo.bad_argument(
162-
"Incorrect FlexConnect function invocation. Flight descriptor must contain command "
163-
"with the invocation payload."
164-
)
165-
166137
task_id: str
167138
fun_name: Optional[str] = None
139+
invocation = extract_invocation_from_descriptor(descriptor)
168140

169-
if descriptor.command.startswith(b"c:"):
170-
# cancel descriptor: just cancel the given task and raise cancellation exception
171-
task_id = descriptor.command[2:].decode()
172-
self._ctx.task_executor.cancel(task_id)
141+
if isinstance(invocation, CancelInvocation):
142+
# cancel the given task and raise cancellation exception
143+
self._ctx.task_executor.cancel(invocation.task_id)
173144
raise ErrorInfo.for_reason(
174145
ErrorCode.COMMAND_CANCELLED, "FlexConnect function invocation was cancelled."
175146
).to_cancelled_error()
176-
elif descriptor.command.startswith(b"r:"):
147+
elif isinstance(invocation, RetryInvocation):
177148
# retry descriptor: extract the task_id, do not submit it again and do one polling iteration
178-
task_id = descriptor.command[2:].decode()
149+
task_id = invocation.task_id
179150
# for retries, we also need to check the call deadline for the whole call duration
180151
task_timestamp = self._ctx.task_executor.get_task_submitted_timestamp(task_id)
181152
if task_timestamp is not None and time.perf_counter() - task_timestamp > self._call_deadline:
182153
self._ctx.task_executor.cancel(task_id)
183154
raise ErrorInfo.for_reason(
184155
ErrorCode.TIMEOUT, f"GetFlightInfo timed out while waiting for task {task_id}."
185156
).to_timeout_error()
186-
else:
157+
elif isinstance(invocation, SubmitInvocation):
187158
# basic first-time submit: submit the task and do one polling iteration.
188159
# do not check call deadline to give it a chance to wait for the result at least once
189160
try:
190-
task = self._prepare_task(context, descriptor)
161+
task = self._prepare_task(context, invocation)
191162
self._ctx.task_executor.submit(task)
192163
task_id = task.task_id
193164
fun_name = task.fun_name
194165
except Exception:
195166
_LOGGER.error("flexconnect_fun_submit_failed", exc_info=True)
196167
raise
168+
else:
169+
# unreachable: every invocation type is handled above; can be replaced by typing.assert_never once Python 3.11+ is required
170+
raise AssertionError
197171

198172
try:
199173
task_result = self._ctx.task_executor.wait_for_result(task_id, timeout=self._poll_interval)
gooddata-flexconnect/gooddata_flexconnect/function/function_invocation.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
# (C) 2025 GoodData Corporation
2+
from dataclasses import dataclass
3+
from typing import Optional, Union
4+
5+
import orjson
6+
import pyarrow.flight
7+
from gooddata_flight_server import ErrorInfo
8+
9+
10+
@dataclass(frozen=True)
11+
class RetryInvocation:
12+
"""
13+
Indicates that getting the results of the given task should be retried.
14+
"""
15+
16+
task_id: str
17+
18+
19+
@dataclass(frozen=True)
20+
class CancelInvocation:
21+
"""
22+
Indicates that the given task should be cancelled.
23+
"""
24+
25+
task_id: str
26+
27+
28+
@dataclass(frozen=True)
29+
class SubmitInvocation:
30+
"""
31+
Indicates that the given task should be submitted for processing.
32+
"""
33+
34+
command: bytes
35+
"""
36+
The raw command that was sent to the Flight Server.
37+
"""
38+
39+
function_name: str
40+
"""
41+
The name of the FlexConnect function to invoke.
42+
"""
43+
44+
parameters: dict
45+
"""
46+
Parameters to pass to the FlexConnect function.
47+
"""
48+
49+
columns: Optional[tuple[str, ...]]
50+
"""
51+
Columns to get from the FlexConnect function result.
52+
This may be used for column trimming by the function: the function must return at least those columns.
53+
"""
54+
55+
56+
def extract_invocation_from_descriptor(
57+
descriptor: pyarrow.flight.FlightDescriptor,
58+
) -> Union[RetryInvocation, CancelInvocation, SubmitInvocation]:
59+
"""
60+
Given a flight descriptor, extract the invocation information from it.
61+
"""
62+
63+
if descriptor.command is None or not len(descriptor.command):
64+
raise ErrorInfo.bad_argument(
65+
"Incorrect FlexConnect function invocation. Flight descriptor must contain command "
66+
"with the invocation payload."
67+
)
68+
69+
if descriptor.command.startswith(b"c:"):
70+
task_id = descriptor.command[2:].decode()
71+
return CancelInvocation(task_id)
72+
elif descriptor.command.startswith(b"r:"):
73+
task_id = descriptor.command[2:].decode()
74+
return RetryInvocation(task_id)
75+
else:
76+
try:
77+
payload = orjson.loads(descriptor.command)
78+
except Exception:
79+
raise ErrorInfo.bad_argument(
80+
"Incorrect FlexConnect function invocation. The invocation payload is not a valid JSON."
81+
)
82+
83+
function_name = payload.get("functionName")
84+
if function_name is None or not len(function_name):
85+
raise ErrorInfo.bad_argument(
86+
"Incorrect FlexConnect function invocation. The invocation payload does not specify 'functionName'."
87+
)
88+
89+
parameters = payload.get("parameters") or {}
90+
columns = parameters.get("columns")
91+
92+
return SubmitInvocation(
93+
function_name=function_name, parameters=parameters, columns=columns, command=descriptor.command
94+
)

0 commit comments

Comments
 (0)