11# (C) 2024 GoodData Corporation
2+ import time
23from collections .abc import Generator
34from typing import Optional
45
1718)
1819
1920from gooddata_flexconnect .function .function import FlexConnectFunction
21+ from gooddata_flexconnect .function .function_invocation import (
22+ CancelInvocation ,
23+ RetryInvocation ,
24+ SubmitInvocation ,
25+ extract_pollable_invocation_from_descriptor ,
26+ extract_submit_invocation_from_descriptor ,
27+ )
2028from gooddata_flexconnect .function .function_registry import FlexConnectFunctionRegistry
2129from gooddata_flexconnect .function .function_task import FlexConnectFunctionTask
2230
# Structured logger used by the FlexConnect Flight RPC method implementations in this module.
_LOGGER = structlog.get_logger("gooddata_flexconnect.rpc")

# Request header checked by `get_flight_info`. Only the presence of the header matters;
# its value is not inspected.
POLLING_HEADER_NAME = "x-quiver-pollable"
"""
If this header is present on the get flight info call, the polling extension will be used.
Otherwise the basic do get will be used.
"""
38+
39+
def _prepare_poll_error(task_id: str) -> pyarrow.flight.FlightError:
    """
    Build the error that tells a polling client how to continue waiting for a task.

    The error is produced by `ErrorInfo.poll` without any flight info. It carries two
    command descriptors derived from the task id: `c:<task_id>` for cancellation and
    `r:<task_id>` for retrying.

    :param task_id: identifier of the task the client is waiting for
    :return: the error object; the caller is expected to raise it
    """
    cancel_command = f"c:{task_id}".encode()
    retry_command = f"r:{task_id}".encode()

    return ErrorInfo.poll(
        flight_info=None,
        cancel_descriptor=pyarrow.flight.FlightDescriptor.for_command(cancel_command),
        retry_descriptor=pyarrow.flight.FlightDescriptor.for_command(retry_command),
    )
46+
2547
2648class _FlexConnectServerMethods (FlightServerMethods ):
def __init__(
    self,
    ctx: ServerContext,
    registry: FlexConnectFunctionRegistry,
    call_deadline_ms: float,
    poll_interval_ms: float,
) -> None:
    """
    Store the collaborators and timing settings used by the Flight RPC methods.

    Both durations are passed in milliseconds and kept on the instance in seconds.

    :param ctx: server context; stored as-is
    :param registry: registry of FlexConnect functions; stored as-is
    :param call_deadline_ms: call deadline, in milliseconds
    :param poll_interval_ms: duration of one polling iteration, in milliseconds
    """
    ms_per_second = 1000

    self._ctx = ctx
    self._registry = registry
    self._call_deadline = call_deadline_ms / ms_per_second
    self._poll_interval = poll_interval_ms / ms_per_second
3160
3261 @staticmethod
3362 def _create_descriptor (fun_name : str , metadata : Optional [dict ]) -> pyarrow .flight .FlightDescriptor :
@@ -52,58 +81,37 @@ def _create_fun_info(self, fun: type[FlexConnectFunction]) -> pyarrow.flight.Fli
5281 total_records = - 1 ,
5382 )
5483
55- def _extract_invocation_payload (
56- self , descriptor : pyarrow .flight .FlightDescriptor
57- ) -> tuple [str , dict , Optional [tuple [str , ...]]]:
58- if descriptor .command is None or not len (descriptor .command ):
59- raise ErrorInfo .bad_argument (
60- "Incorrect FlexConnect function invocation. Flight descriptor must contain command "
61- "with the invocation payload."
62- )
63-
64- try :
65- payload = orjson .loads (descriptor .command )
66- except Exception :
67- raise ErrorInfo .bad_argument (
68- "Incorrect FlexConnect function invocation. The invocation payload is not a valid JSON."
69- )
70-
71- fun = payload .get ("functionName" )
72- if fun is None or not len (fun ):
73- raise ErrorInfo .bad_argument (
74- "Incorrect FlexConnect function invocation. The invocation payload does not specify 'functionName'."
75- )
76-
77- parameters = payload .get ("parameters" ) or {}
78- columns = parameters .get ("columns" )
79-
80- return fun , parameters , columns
81-
def _prepare_task(
    self,
    context: pyarrow.flight.ServerCallContext,
    submit_invocation: SubmitInvocation,
) -> FlexConnectFunctionTask:
    """
    Create the task that will run the FlexConnect function named in the invocation.

    The call headers are read first, then the function is created through the
    registry; the task is assembled from the function, the invocation's parameters,
    columns and command, and the headers.

    :param context: server call context of the current Flight RPC call
    :param submit_invocation: parsed submit invocation
    :return: task ready to be submitted to the task executor
    """
    call_headers = self.call_info_middleware(context).headers
    function_instance = self._registry.create_function(submit_invocation.function_name)

    task = FlexConnectFunctionTask(
        fun=function_instance,
        parameters=submit_invocation.parameters,
        columns=submit_invocation.columns,
        headers=call_headers,
        cmd=submit_invocation.command,
    )
    return task
9899
99- def _prepare_flight_info (self , task_result : TaskExecutionResult ) -> pyarrow .flight .FlightInfo :
100+ def _prepare_flight_info (
101+ self , task_id : str , task_result : Optional [TaskExecutionResult ]
102+ ) -> pyarrow .flight .FlightInfo :
103+ if task_result is None :
104+ raise ErrorInfo .for_reason (
105+ ErrorCode .BAD_ARGUMENT , f"Task with id '{ task_id } ' does not exist."
106+ ).to_user_error ()
107+
100108 if task_result .error is not None :
101109 raise task_result .error .as_flight_error ()
102110
103111 if task_result .cancelled :
104112 raise ErrorInfo .for_reason (
105113 ErrorCode .COMMAND_CANCELLED ,
106- f"FlexConnect function invocation was cancelled. Invocation task was: '{ task_result . task_id } '." ,
114+ f"FlexConnect function invocation was cancelled. Invocation task was: '{ task_id } '." ,
107115 ).to_server_error ()
108116
109117 result = task_result .result
@@ -114,40 +122,33 @@ def _prepare_flight_info(self, task_result: TaskExecutionResult) -> pyarrow.flig
114122 descriptor = pyarrow .flight .FlightDescriptor .for_command (task_result .cmd ),
115123 endpoints = [
116124 pyarrow .flight .FlightEndpoint (
117- ticket = pyarrow .flight .Ticket (ticket = orjson .dumps ({"task_id" : task_result . task_id })),
125+ ticket = pyarrow .flight .Ticket (ticket = orjson .dumps ({"task_id" : task_id })),
118126 locations = [self ._ctx .location ],
119127 )
120128 ],
121129 total_records = - 1 ,
122130 total_bytes = - 1 ,
123131 )
124132
125- ###################################################################
126- # Implementation of Flight RPC methods
127- ###################################################################
128-
129- def list_flights (
130- self , context : pyarrow .flight .ServerCallContext , criteria : bytes
131- ) -> Generator [pyarrow .flight .FlightInfo , None , None ]:
132- structlog .contextvars .bind_contextvars (peer = context .peer ())
133- _LOGGER .info ("list_flights" , available_funs = self ._registry .function_names )
134-
135- return (self ._create_fun_info (fun ) for fun in self ._registry .functions .values ())
136-
137- def get_flight_info (
133+ def _get_flight_info_no_polling (
138134 self ,
139135 context : pyarrow .flight .ServerCallContext ,
140136 descriptor : pyarrow .flight .FlightDescriptor ,
141137 ) -> pyarrow .flight .FlightInfo :
138+ """
139+ Basic DoGetInfo flow with no polling extension.
140+ This conforms to the mainline Arrow Flight RPC specification.
141+ """
142142 structlog .contextvars .bind_contextvars (peer = context .peer ())
143+ invocation = extract_submit_invocation_from_descriptor (descriptor )
144+
143145 task : Optional [FlexConnectFunctionTask ] = None
144146
145147 try :
146- task = self ._prepare_task (context , descriptor )
148+ task = self ._prepare_task (context , invocation )
147149 self ._ctx .task_executor .submit (task )
148150
149151 try :
150- # XXX: this should be enhanced to implement polling
151152 task_result = self ._ctx .task_executor .wait_for_result (task .task_id , self ._call_deadline )
152153 except TaskWaitTimeoutError :
153154 cancelled = self ._ctx .task_executor .cancel (task .task_id )
@@ -166,15 +167,106 @@ def get_flight_info(
166167 # particular task id finished
167168 assert task_result is not None
168169
169- return self ._prepare_flight_info (task_result )
170+ return self ._prepare_flight_info (task_id = task . task_id , task_result = task_result )
170171 except Exception :
171172 if task is not None :
172- _LOGGER .error ("get_flight_info_failed" , task_id = task .task_id , fun = task .fun_name , exc_info = True )
173+ _LOGGER .error (
174+ "get_flight_info_failed" , task_id = task .task_id , fun = task .fun_name , exc_info = True , polling = False
175+ )
173176 else :
174- _LOGGER .error ("flexconnect_fun_submit_failed" , exc_info = True )
177+ _LOGGER .error ("flexconnect_fun_submit_failed" , exc_info = True , polling = False )
178+ raise
179+
def _get_flight_info_polling(
    self,
    context: pyarrow.flight.ServerCallContext,
    descriptor: pyarrow.flight.FlightDescriptor,
) -> pyarrow.flight.FlightInfo:
    """
    DoGetInfo flow with polling extension.
    This extends the mainline Arrow Flight RPC specification with polling capabilities using the RetryInfo
    encoded into the FlightTimedOutError.extra_info.
    Ideally, we would use the mainline PollFlightInfo, but that has yet to be implemented in the PyArrow library.

    The descriptor is parsed into one of three invocations:

    - cancel: the task is cancelled and a cancellation error is always raised
    - retry: the already submitted task is polled once more
    - submit: a new task is submitted and then polled once

    One polling iteration waits up to the configured poll interval for the task result.

    :param context: server call context of the current Flight RPC call
    :param descriptor: flight descriptor carrying the invocation
    :return: flight info for the finished task
    :raises pyarrow.flight.FlightError: when the task was cancelled, failed, exceeded the call
        deadline, or is not finished yet (in which case the error tells the client how to poll again)
    """
    structlog.contextvars.bind_contextvars(peer=context.peer())
    invocation = extract_pollable_invocation_from_descriptor(descriptor)

    task_id: str
    # only known for first-time submits; retry invocations carry just the task id
    fun_name: Optional[str] = None

    if isinstance(invocation, CancelInvocation):
        # cancel the given task and raise cancellation exception
        if self._ctx.task_executor.cancel(invocation.task_id):
            raise ErrorInfo.for_reason(
                ErrorCode.COMMAND_CANCELLED, "FlexConnect function invocation was cancelled."
            ).to_cancelled_error()
        raise ErrorInfo.for_reason(
            ErrorCode.COMMAND_CANCEL_NOT_POSSIBLE, "FlexConnect function invocation could not be cancelled."
        ).to_cancelled_error()
    elif isinstance(invocation, RetryInvocation):
        # retry descriptor: extract the task_id, do not submit it again and do one polling iteration
        task_id = invocation.task_id
    elif isinstance(invocation, SubmitInvocation):
        # basic first-time submit: submit the task and do one polling iteration.
        # do not check call deadline to give it a chance to wait for the result at least once
        try:
            task = self._prepare_task(context, invocation)
            self._ctx.task_executor.submit(task)
            task_id = task.task_id
            fun_name = task.fun_name
        except Exception:
            _LOGGER.error("flexconnect_fun_submit_failed", exc_info=True, polling=True)
            raise
    else:
        # can be replaced by assert_never when we are on 3.11
        raise AssertionError(f"Unexpected invocation type: {type(invocation).__name__}")

    try:
        task_result = self._ctx.task_executor.wait_for_result(task_id, timeout=self._poll_interval)
        return self._prepare_flight_info(task_id, task_result)
    except (TaskWaitTimeoutError, TimeoutError):
        # The non-polling flow catches TaskWaitTimeoutError from the same wait_for_result call.
        # Both classes are caught here so this path is taken whether or not
        # TaskWaitTimeoutError derives from the builtin TimeoutError.
        #
        # Errors raised from within this handler are not seen by the `except Exception`
        # clause below, so a "keep polling" response is not logged as a failure.

        # first, check the call deadline for the whole call duration
        # NOTE(review): assumes the submitted timestamp is taken from time.perf_counter() - confirm
        # against the task executor implementation
        task_timestamp = self._ctx.task_executor.get_task_submitted_timestamp(task_id)
        if task_timestamp is not None and time.perf_counter() - task_timestamp > self._call_deadline:
            self._ctx.task_executor.cancel(task_id)
            raise ErrorInfo.for_reason(
                ErrorCode.TIMEOUT, f"GetFlightInfo timed out while waiting for task {task_id}."
            ).to_timeout_error()

        # if the result is not ready, and we still have time, indicate to the client
        # how to poll for the results
        raise _prepare_poll_error(task_id)
    except Exception:
        _LOGGER.error("get_flight_info_failed", task_id=task_id, fun=fun_name, exc_info=True, polling=True)
        raise
177242
243+ ###################################################################
244+ # Implementation of Flight RPC methods
245+ ###################################################################
246+
def list_flights(
    self, context: pyarrow.flight.ServerCallContext, criteria: bytes
) -> Generator[pyarrow.flight.FlightInfo, None, None]:
    """
    List the registered FlexConnect functions, one flight info per function.

    :param context: server call context of the current Flight RPC call
    :param criteria: listing criteria; not used by this implementation
    :return: generator producing a flight info for every function in the registry
    """
    structlog.contextvars.bind_contextvars(peer=context.peer())
    _LOGGER.info("list_flights", available_funs=self._registry.function_names)

    registered_functions = self._registry.functions.values()
    return (self._create_fun_info(function_cls) for function_cls in registered_functions)
254+
def get_flight_info(
    self,
    context: pyarrow.flight.ServerCallContext,
    descriptor: pyarrow.flight.FlightDescriptor,
) -> pyarrow.flight.FlightInfo:
    """
    Dispatch the call to the polling or to the non-polling flow.

    The polling flow is used whenever the call headers contain the
    `POLLING_HEADER_NAME` header; otherwise the non-polling flow handles the call.

    :param context: server call context of the current Flight RPC call
    :param descriptor: flight descriptor carrying the invocation
    :return: flight info produced by the selected flow
    """
    structlog.contextvars.bind_contextvars(peer=context.peer())

    request_headers = self.call_info_middleware(context).headers
    if request_headers.get(POLLING_HEADER_NAME) is None:
        return self._get_flight_info_no_polling(context, descriptor)

    return self._get_flight_info_polling(context, descriptor)
269+
178270 def do_get (
179271 self ,
180272 context : pyarrow .flight .ServerCallContext ,
@@ -201,7 +293,9 @@ def do_get(
# Name of the settings section; settings are looked up as "<section>.<key>".
_FLEX_CONNECT_CONFIG_SECTION = "flexconnect"
# Key of the setting that lists the modules from which FlexConnect functions are loaded.
_FLEX_CONNECT_FUNCTION_LIST = "functions"
# Key of the call deadline setting (milliseconds).
_FLEX_CONNECT_CALL_DEADLINE_MS = "call_deadline_ms"
# Key of the setting for the duration of one polling iteration (milliseconds).
_FLEX_CONNECT_POLLING_INTERVAL_MS = "polling_interval_ms"
# presumably the call deadline used when the setting is absent - TODO confirm against _read_call_deadline_ms
_DEFAULT_FLEX_CONNECT_CALL_DEADLINE_MS = 180_000
# Polling interval, in milliseconds, used when the setting is absent.
_DEFAULT_FLEX_CONNECT_POLLING_INTERVAL_MS = 2000
205299
206300
207301def _read_call_deadline_ms (ctx : ServerContext ) -> int :
@@ -223,6 +317,24 @@ def _read_call_deadline_ms(ctx: ServerContext) -> int:
223317 )
224318
225319
def _read_polling_interval_ms(ctx: ServerContext) -> int:
    """
    Read the duration of one polling iteration from the server settings.

    :param ctx: server context whose settings are consulted
    :return: the configured interval in milliseconds, or the default when the setting is absent
    :raises ValueError: when the configured value cannot be converted to a positive integer
    """
    setting_name = f"{_FLEX_CONNECT_CONFIG_SECTION}.{_FLEX_CONNECT_POLLING_INTERVAL_MS}"
    raw_value = ctx.settings.get(setting_name)
    if raw_value is None:
        return _DEFAULT_FLEX_CONNECT_POLLING_INTERVAL_MS

    error_message = (
        f"Value of {setting_name} must "
        "be a positive number - duration, in milliseconds, that FlexConnect function "
        "waits for the result during one polling iteration."
    )

    try:
        polling_interval = int(raw_value)
    except (TypeError, ValueError) as e:
        # int() raises TypeError rather than ValueError for values such as lists or dicts;
        # report both as the same, descriptive configuration error
        raise ValueError(error_message) from e

    if polling_interval <= 0:
        raise ValueError(error_message)

    return polling_interval
336+
337+
226338@flight_server_methods
227339def create_flexconnect_flight_methods (ctx : ServerContext ) -> FlightServerMethods :
228340 """
@@ -236,8 +348,9 @@ def create_flexconnect_flight_methods(ctx: ServerContext) -> FlightServerMethods
236348 """
237349 modules = list (ctx .settings .get (f"{ _FLEX_CONNECT_CONFIG_SECTION } .{ _FLEX_CONNECT_FUNCTION_LIST } " ) or [])
238350 call_deadline_ms = _read_call_deadline_ms (ctx )
351+ polling_interval_ms = _read_polling_interval_ms (ctx )
239352
240353 _LOGGER .info ("flexconnect_init" , modules = modules )
241354 registry = FlexConnectFunctionRegistry ().load (ctx , modules )
242355
243- return _FlexConnectServerMethods (ctx , registry , call_deadline_ms )
356+ return _FlexConnectServerMethods (ctx , registry , call_deadline_ms , polling_interval_ms )
0 commit comments