 from splitio.optional.loaders import asyncio
 
 _LOGGER = logging.getLogger(__name__)
-_ASYNC_SLEEP_SECONDS = 0.3
-
 
 class WorkerPool(object):
     """Worker pool class to implement single producer/multiple consumer."""
@@ -141,121 +139,94 @@ def _wait_workers_shutdown(self, event):
 class WorkerPoolAsync(object):
     """Worker pool async class to implement single producer/multiple consumer."""
 
+    _abort = object()
+
     def __init__(self, worker_count, worker_func):
         """
         Class constructor.
 
         :param worker_count: Number of workers for the pool.
         :param worker_func: Function to be executed by the workers whenever a message is fetched.
150150 """
151+ self ._semaphore = asyncio .Semaphore (worker_count )
152+ self ._queue = asyncio .Queue ()
153+ self ._handler = worker_func
154+ self ._aborted = False
151155 self ._failed = False
152- self ._running = False
153- self ._incoming = asyncio .Queue ()
154- self ._worker_count = worker_count
155- self ._worker_func = worker_func
156- self .current_workers = []
157156
+    async def _schedule_work(self):
+        """Fetch queued messages and dispatch each one to a worker task."""
+        while True:
+            message = await self._queue.get()
+            if message == self._abort:
+                self._aborted = True
+                return
+            asyncio.get_running_loop().create_task(self._do_work(message))
+
+    async def _do_work(self, message):
+        """Process a single message."""
+        try:
+            await self._semaphore.acquire()  # wait until there's a free worker
+            if self._aborted:  # check in case the pool was shut down while we were waiting for a worker
+                return
+            await self._handler(message._message)
+        except Exception:
+            _LOGGER.error("Something went wrong when processing message %s", message)
+            _LOGGER.debug('Original traceback: ', exc_info=True)
+            self._failed = True
+        message._complete.set()
+        self._semaphore.release()  # signal worker is idle
 
     def start(self):
         """Start the workers."""
-        self._running = True
-        self._worker_pool_task = asyncio.get_running_loop().create_task(self._wrapper())
+        self._task = asyncio.get_running_loop().create_task(self._schedule_work())
 
-    async def _safe_run(self, message):
+    async def submit_work(self, jobs):
         """
-        Execute the user funcion for a given message without raising exceptions.
-
-        :param func: User defined function.
-        :type func: callable
-        :param message: Message fetched from the queue.
-        :param message: object
+        Add new messages to the work queue.
 
-        :return True if no everything goes well. False otherwise.
-        :rtype bool
+        :param jobs: Messages to be processed by the workers.
+        :type jobs: list
         """
-        try:
-            await self._worker_func(message)
-            return True
-        except Exception:  # pylint: disable=broad-except
-            _LOGGER.error("Something went wrong when processing message %s", message)
-            _LOGGER.error('Original traceback: ', exc_info=True)
-            return False
+        self.jobs = jobs
+        if len(jobs) == 1:
+            wrapped = TaskCompletionWraper(jobs[0])
+            await self._queue.put(wrapped)
+            return wrapped
 
-    async def _wrapper(self):
-        """
-        Fetch message, execute tasks, and acknowledge results.
+        tasks = [TaskCompletionWraper(job) for job in jobs]
+        for w in tasks:
+            await self._queue.put(w)
 
-        :param worker_number: # (id) of worker whose function will be executed.
-        :type worker_number: int
-        :param func: User defined function.
-        :type func: callable.
-        """
-        self.current_workers = []
-        while self._running:
-            try:
-                if len(self.current_workers) == self._worker_count or self._incoming.qsize() == 0:
-                    await asyncio.sleep(_ASYNC_SLEEP_SECONDS)
-                    self._check_and_clean_workers()
-                    continue
-                message = await self._incoming.get()
-                # For some reason message can be None in python2 implementation of queue.
-                # This method must be both ignored and acknowledged with .task_done()
-                # otherwise .join() will halt.
-                if message is None:
-                    _LOGGER.debug('spurious message received. acking and ignoring.')
-                    continue
+        return BatchCompletionWrapper(tasks)
 
-                # If the task is successfully executed, the ack is done AFTERWARDS,
-                # to avoid race conditions on SDK initialization.
-                _LOGGER.debug("processing message '%s'", message)
-                self.current_workers.append([asyncio.get_running_loop().create_task(self._safe_run(message)), message])
+    async def stop(self, event=None):
+        """Abort all execution (except currently running handlers)."""
+        await self._queue.put(self._abort)
 
-                # check tasks status
-                self._check_and_clean_workers()
-            except queue.Empty:
-                # No message was fetched, just keep waiting.
-                pass
+    def pop_failed(self):
+        old = self._failed
+        self._failed = False
+        return old
 
-    def _check_and_clean_workers(self):
-        found_running = False
-        for task in self.current_workers:
-            if task[0].done():
-                self.current_workers.remove(task)
-                if not task[0].result():
-                    self._failed = True
-                    _LOGGER.error(
-                        ("Something went wrong during the execution, "
-                         "removing message \"%s\" from queue.",
-                         task[1])
-                    )
-            else:
-                found_running = True
-        return found_running
 
-    async def submit_work(self, message):
-        """
-        Add a new message to the work-queue.
+class TaskCompletionWraper:
+    """Task completion class"""
+    def __init__(self, message):
+        self._message = message
+        self._complete = asyncio.Event()
 
-        :param message: New message to add.
-        :type message: object.
-        """
-        await self._incoming.put(message)
-        _LOGGER.debug('queued message %s for processing.', message)
+    async def await_completion(self):
+        await self._complete.wait()
 
-    async def wait_for_completion(self):
-        """Block until the work queue is empty."""
-        _LOGGER.debug('waiting for all messages to be processed.')
-        if self._incoming.qsize() > 0:
-            await self._incoming.join()
-        _LOGGER.debug('all messages processed.')
-        old = self._failed
-        self._failed = False
-        self._running = False
-        return old
+    def _mark_as_complete(self):
+        self._complete.set()
 
-    async def stop(self, event=None):
-        """Stop all worker nodes."""
-        await self.wait_for_completion()
-        while self._check_and_clean_workers():
-            await asyncio.sleep(_ASYNC_SLEEP_SECONDS)
-        self._worker_pool_task.cancel()
+
+class BatchCompletionWrapper:
+    """Batch completion class"""
+    def __init__(self, tasks):
+        self._tasks = tasks
+
+    async def await_completion(self):
+        await asyncio.gather(*[task.await_completion() for task in self._tasks])
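
The rewrite replaces the old poll-and-sleep loop with a dispatcher task that blocks on an asyncio.Queue, an asyncio.Semaphore that caps concurrent handlers at worker_count, and the _abort sentinel for shutdown. A minimal usage sketch of how a caller could drive the new API is below; the handle_message coroutine and the message strings are placeholders (not part of this change), and WorkerPoolAsync is assumed to be importable from the patched module:

import asyncio

async def handle_message(message):
    # Placeholder handler: in the SDK this is the worker_func passed to the pool.
    print('processing', message)

async def main():
    pool = WorkerPoolAsync(5, handle_message)  # at most 5 handlers run concurrently
    pool.start()

    # submit_work() takes a list of jobs and returns a completion wrapper:
    # a TaskCompletionWraper for a single job, a BatchCompletionWrapper otherwise.
    batch = await pool.submit_work(['msg-1', 'msg-2', 'msg-3'])
    await batch.await_completion()  # resolves once every handler has run

    if pool.pop_failed():  # True if any handler raised since the last check
        print('at least one message failed')

    await pool.stop()  # enqueues the _abort sentinel; handlers already running finish

asyncio.run(main())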