|
17 | 17 | ) |
18 | 18 |
|
19 | 19 | from gooddata_flexconnect.function.function import FlexConnectFunction |
| 20 | +from gooddata_flexconnect.function.function_invocation import ( |
| 21 | + CancelInvocation, |
| 22 | + RetryInvocation, |
| 23 | + SubmitInvocation, |
| 24 | + extract_invocation_from_descriptor, |
| 25 | +) |
20 | 26 | from gooddata_flexconnect.function.function_registry import FlexConnectFunctionRegistry |
21 | 27 | from gooddata_flexconnect.function.function_task import FlexConnectFunctionTask |
22 | 28 |
|
@@ -67,48 +73,20 @@ def _create_fun_info(self, fun: type[FlexConnectFunction]) -> pyarrow.flight.Fli |
67 | 73 | total_records=-1, |
68 | 74 | ) |
69 | 75 |
|
70 | | - def _extract_invocation_payload( |
71 | | - self, descriptor: pyarrow.flight.FlightDescriptor |
72 | | - ) -> tuple[str, dict, Optional[tuple[str, ...]]]: |
73 | | - if descriptor.command is None or not len(descriptor.command): |
74 | | - raise ErrorInfo.bad_argument( |
75 | | - "Incorrect FlexConnect function invocation. Flight descriptor must contain command " |
76 | | - "with the invocation payload." |
77 | | - ) |
78 | | - |
79 | | - try: |
80 | | - payload = orjson.loads(descriptor.command) |
81 | | - except Exception: |
82 | | - raise ErrorInfo.bad_argument( |
83 | | - "Incorrect FlexConnect function invocation. The invocation payload is not a valid JSON." |
84 | | - ) |
85 | | - |
86 | | - fun = payload.get("functionName") |
87 | | - if fun is None or not len(fun): |
88 | | - raise ErrorInfo.bad_argument( |
89 | | - "Incorrect FlexConnect function invocation. The invocation payload does not specify 'functionName'." |
90 | | - ) |
91 | | - |
92 | | - parameters = payload.get("parameters") or {} |
93 | | - columns = parameters.get("columns") |
94 | | - |
95 | | - return fun, parameters, columns |
96 | | - |
97 | 76 | def _prepare_task( |
98 | 77 | self, |
99 | 78 | context: pyarrow.flight.ServerCallContext, |
100 | | - descriptor: pyarrow.flight.FlightDescriptor, |
| 79 | + submit_invocation: SubmitInvocation, |
101 | 80 | ) -> FlexConnectFunctionTask: |
102 | | - fun_name, parameters, columns = self._extract_invocation_payload(descriptor) |
103 | 81 | headers = self.call_info_middleware(context).headers |
104 | | - fun = self._registry.create_function(fun_name) |
| 82 | + fun = self._registry.create_function(submit_invocation.function_name) |
105 | 83 |
|
106 | 84 | return FlexConnectFunctionTask( |
107 | 85 | fun=fun, |
108 | | - parameters=parameters, |
109 | | - columns=columns, |
| 86 | + parameters=submit_invocation.parameters, |
| 87 | + columns=submit_invocation.columns, |
110 | 88 | headers=headers, |
111 | | - cmd=descriptor.command, |
| 89 | + cmd=submit_invocation.command, |
112 | 90 | ) |
113 | 91 |
|
114 | 92 | def _prepare_flight_info(self, task_result: TaskExecutionResult) -> pyarrow.flight.FlightInfo: |
@@ -156,44 +134,40 @@ def get_flight_info( |
156 | 134 | ) -> pyarrow.flight.FlightInfo: |
157 | 135 | structlog.contextvars.bind_contextvars(peer=context.peer()) |
158 | 136 |
|
159 | | - # first, check if the descriptor is a cancel descriptor |
160 | | - if descriptor.command is None or not len(descriptor.command): |
161 | | - raise ErrorInfo.bad_argument( |
162 | | - "Incorrect FlexConnect function invocation. Flight descriptor must contain command " |
163 | | - "with the invocation payload." |
164 | | - ) |
165 | | - |
166 | 137 | task_id: str |
167 | 138 | fun_name: Optional[str] = None |
| 139 | + invocation = extract_invocation_from_descriptor(descriptor) |
168 | 140 |
|
169 | | - if descriptor.command.startswith(b"c:"): |
170 | | - # cancel descriptor: just cancel the given task and raise cancellation exception |
171 | | - task_id = descriptor.command[2:].decode() |
172 | | - self._ctx.task_executor.cancel(task_id) |
| 141 | + if isinstance(invocation, CancelInvocation): |
| 142 | + # cancel the given task and raise cancellation exception |
| 143 | + self._ctx.task_executor.cancel(invocation.task_id) |
173 | 144 | raise ErrorInfo.for_reason( |
174 | 145 | ErrorCode.COMMAND_CANCELLED, "FlexConnect function invocation was cancelled." |
175 | 146 | ).to_cancelled_error() |
176 | | - elif descriptor.command.startswith(b"r:"): |
| 147 | + elif isinstance(invocation, RetryInvocation): |
177 | 148 | # retry descriptor: extract the task_id, do not submit it again and do one polling iteration |
178 | | - task_id = descriptor.command[2:].decode() |
| 149 | + task_id = invocation.task_id |
179 | 150 | # for retries, we also need to check the call deadline for the whole call duration |
180 | 151 | task_timestamp = self._ctx.task_executor.get_task_submitted_timestamp(task_id) |
181 | 152 | if task_timestamp is not None and time.perf_counter() - task_timestamp > self._call_deadline: |
182 | 153 | self._ctx.task_executor.cancel(task_id) |
183 | 154 | raise ErrorInfo.for_reason( |
184 | 155 | ErrorCode.TIMEOUT, f"GetFlightInfo timed out while waiting for task {task_id}." |
185 | 156 | ).to_timeout_error() |
186 | | - else: |
| 157 | + elif isinstance(invocation, SubmitInvocation): |
187 | 158 | # basic first-time submit: submit the task and do one polling iteration. |
188 | 159 | # do not check call deadline to give it a chance to wait for the result at least once |
189 | 160 | try: |
190 | | - task = self._prepare_task(context, descriptor) |
| 161 | + task = self._prepare_task(context, invocation) |
191 | 162 | self._ctx.task_executor.submit(task) |
192 | 163 | task_id = task.task_id |
193 | 164 | fun_name = task.fun_name |
194 | 165 | except Exception: |
195 | 166 | _LOGGER.error("flexconnect_fun_submit_failed", exc_info=True) |
196 | 167 | raise |
| 168 | + else: |
| 169 | + # can be replaced by assert_never when we are on 3.11 |
| 170 | + raise AssertionError |
197 | 171 |
|
198 | 172 | try: |
199 | 173 | task_result = self._ctx.task_executor.wait_for_result(task_id, timeout=self._poll_interval) |
|
0 commit comments