Skip to content

Commit c3cdaf2

Browse files
committed
Added refactored workerpool and updated sync.segment
1 parent a79e72e commit c3cdaf2

File tree

4 files changed

+96
-109
lines changed

4 files changed

+96
-109
lines changed

splitio/sync/segment.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -357,12 +357,11 @@ async def synchronize_segments(self, segment_names = None, dont_wait = False):
357357
if segment_names is None:
358358
segment_names = await self._feature_flag_storage.get_segment_names()
359359

360-
for segment_name in segment_names:
361-
await self._worker_pool.submit_work(segment_name)
360+
jobs = await self._worker_pool.submit_work(segment_names)
362361
if (dont_wait):
363362
return True
364-
await asyncio.sleep(.5)
365-
return not await self._worker_pool.wait_for_completion()
363+
await jobs.await_completion()
364+
return not self._worker_pool.pop_failed()
366365

367366
async def segment_exist_in_storage(self, segment_name):
368367
"""

splitio/tasks/util/workerpool.py

Lines changed: 66 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@
77
from splitio.optional.loaders import asyncio
88

99
_LOGGER = logging.getLogger(__name__)
10-
_ASYNC_SLEEP_SECONDS = 0.3
11-
1210

1311
class WorkerPool(object):
1412
"""Worker pool class to implement single producer/multiple consumer."""
@@ -141,121 +139,94 @@ def _wait_workers_shutdown(self, event):
141139
class WorkerPoolAsync(object):
142140
"""Worker pool async class to implement single producer/multiple consumer."""
143141

142+
_abort = object()
143+
144144
def __init__(self, worker_count, worker_func):
145145
"""
146146
Class constructor.
147147
148148
:param worker_count: Number of workers for the pool.
149149
:type worker_func: Function to be executed by the workers whenever a message is fetched.
150150
"""
151+
self._semaphore = asyncio.Semaphore(worker_count)
152+
self._queue = asyncio.Queue()
153+
self._handler = worker_func
154+
self._aborted = False
151155
self._failed = False
152-
self._running = False
153-
self._incoming = asyncio.Queue()
154-
self._worker_count = worker_count
155-
self._worker_func = worker_func
156-
self.current_workers = []
157156

157+
async def _schedule_work(self):
158+
"""wrap the message handler execution."""
159+
while True:
160+
message = await self._queue.get()
161+
if message == self._abort:
162+
self._aborted = True
163+
return
164+
asyncio.get_running_loop().create_task(self._do_work(message))
165+
166+
async def _do_work(self, message):
167+
"""process a single message."""
168+
try:
169+
await self._semaphore.acquire() # wait until "there's a free worker"
170+
if self._aborted: # check in case the pool was shutdown while we were waiting for a worker
171+
return
172+
await self._handler(message._message)
173+
except Exception:
174+
_LOGGER.error("Something went wrong when processing message %s", message)
175+
_LOGGER.debug('Original traceback: ', exc_info=True)
176+
self._failed = True
177+
message._complete.set()
178+
self._semaphore.release() # signal worker is idle
158179

159180
def start(self):
160181
"""Start the workers."""
161-
self._running = True
162-
self._worker_pool_task = asyncio.get_running_loop().create_task(self._wrapper())
182+
self._task = asyncio.get_running_loop().create_task(self._schedule_work())
163183

164-
async def _safe_run(self, message):
184+
async def submit_work(self, jobs):
165185
"""
166-
Execute the user funcion for a given message without raising exceptions.
167-
168-
:param func: User defined function.
169-
:type func: callable
170-
:param message: Message fetched from the queue.
171-
:param message: object
186+
Add a new message to the work-queue.
172187
173-
:return True if no everything goes well. False otherwise.
174-
:rtype bool
188+
:param message: New message to add.
189+
:type message: object.
175190
"""
176-
try:
177-
await self._worker_func(message)
178-
return True
179-
except Exception: # pylint: disable=broad-except
180-
_LOGGER.error("Something went wrong when processing message %s", message)
181-
_LOGGER.error('Original traceback: ', exc_info=True)
182-
return False
191+
self.jobs = jobs
192+
if len(jobs) == 1:
193+
wrapped = TaskCompletionWraper(jobs[0])
194+
await self._queue.put(wrapped)
195+
return wrapped
183196

184-
async def _wrapper(self):
185-
"""
186-
Fetch message, execute tasks, and acknowledge results.
197+
tasks = [TaskCompletionWraper(job) for job in jobs]
198+
for w in tasks:
199+
await self._queue.put(w)
187200

188-
:param worker_number: # (id) of worker whose function will be executed.
189-
:type worker_number: int
190-
:param func: User defined function.
191-
:type func: callable.
192-
"""
193-
self.current_workers = []
194-
while self._running:
195-
try:
196-
if len(self.current_workers) == self._worker_count or self._incoming.qsize() == 0:
197-
await asyncio.sleep(_ASYNC_SLEEP_SECONDS)
198-
self._check_and_clean_workers()
199-
continue
200-
message = await self._incoming.get()
201-
# For some reason message can be None in python2 implementation of queue.
202-
# This method must be both ignored and acknowledged with .task_done()
203-
# otherwise .join() will halt.
204-
if message is None:
205-
_LOGGER.debug('spurious message received. acking and ignoring.')
206-
continue
201+
return BatchCompletionWrapper(tasks)
207202

208-
# If the task is successfully executed, the ack is done AFTERWARDS,
209-
# to avoid race conditions on SDK initialization.
210-
_LOGGER.debug("processing message '%s'", message)
211-
self.current_workers.append([asyncio.get_running_loop().create_task(self._safe_run(message)), message])
203+
async def stop(self, event=None):
204+
"""abort all execution (except currently running handlers)."""
205+
await self._queue.put(self._abort)
212206

213-
# check tasks status
214-
self._check_and_clean_workers()
215-
except queue.Empty:
216-
# No message was fetched, just keep waiting.
217-
pass
207+
def pop_failed(self):
208+
old = self._failed
209+
self._failed = False
210+
return old
218211

219-
def _check_and_clean_workers(self):
220-
found_running = False
221-
for task in self.current_workers:
222-
if task[0].done():
223-
self.current_workers.remove(task)
224-
if not task[0].result():
225-
self._failed = True
226-
_LOGGER.error(
227-
("Something went wrong during the execution, "
228-
"removing message \"%s\" from queue.",
229-
task[1])
230-
)
231-
else:
232-
found_running = True
233-
return found_running
234212

235-
async def submit_work(self, message):
236-
"""
237-
Add a new message to the work-queue.
213+
class TaskCompletionWraper:
214+
"""Task completion class"""
215+
def __init__(self, message):
216+
self._message = message
217+
self._complete = asyncio.Event()
238218

239-
:param message: New message to add.
240-
:type message: object.
241-
"""
242-
await self._incoming.put(message)
243-
_LOGGER.debug('queued message %s for processing.', message)
219+
async def await_completion(self):
220+
await self._complete.wait()
244221

245-
async def wait_for_completion(self):
246-
"""Block until the work queue is empty."""
247-
_LOGGER.debug('waiting for all messages to be processed.')
248-
if self._incoming.qsize() > 0:
249-
await self._incoming.join()
250-
_LOGGER.debug('all messages processed.')
251-
old = self._failed
252-
self._failed = False
253-
self._running = False
254-
return old
222+
def _mark_as_complete(self):
223+
self._complete.set()
255224

256-
async def stop(self, event=None):
257-
"""Stop all worker nodes."""
258-
await self.wait_for_completion()
259-
while self._check_and_clean_workers():
260-
await asyncio.sleep(_ASYNC_SLEEP_SECONDS)
261-
self._worker_pool_task.cancel()
225+
226+
class BatchCompletionWrapper:
227+
"""Batch completion class"""
228+
def __init__(self, tasks):
229+
self._tasks = tasks
230+
231+
async def await_completion(self):
232+
await asyncio.gather(*[task.await_completion() for task in self._tasks])

tests/sync/test_segments_synchronizer.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -202,17 +202,22 @@ async def get_segment_names():
202202
split_storage.get_segment_names = get_segment_names
203203

204204
storage = mocker.Mock(spec=SegmentStorage)
205-
async def get_change_number():
205+
async def get_change_number(*args):
206206
return -1
207207
storage.get_change_number = get_change_number
208208

209+
async def put(*args):
210+
pass
211+
storage.put = put
212+
209213
api = mocker.Mock()
210-
async def run(x):
214+
async def run(*args):
211215
raise APIException("something broke")
212216
api.fetch_segment = run
213217

214218
segments_synchronizer = SegmentSynchronizerAsync(api, split_storage, storage)
215219
assert not await segments_synchronizer.synchronize_segments()
220+
await segments_synchronizer.shutdown()
216221

217222
@pytest.mark.asyncio
218223
async def test_synchronize_segments(self, mocker):
@@ -295,6 +300,8 @@ async def fetch_segment_mock(segment_name, change_number, fetch_options):
295300
assert segment.name in segments_to_validate
296301
segments_to_validate.remove(segment.name)
297302

303+
await segments_synchronizer.shutdown()
304+
298305
@pytest.mark.asyncio
299306
async def test_synchronize_segment(self, mocker):
300307
"""Test particular segment update."""
@@ -339,6 +346,8 @@ async def fetch_segment_mock(segment_name, change_number, fetch_options):
339346
assert (self.segment[0], self.change[0], self.options[0]) == ('segmentA', -1, FetchOptions(True))
340347
assert (self.segment[1], self.change[1], self.options[1]) == ('segmentA', 123, FetchOptions(True))
341348

349+
await segments_synchronizer.shutdown()
350+
342351
@pytest.mark.asyncio
343352
async def test_synchronize_segment_cdn(self, mocker):
344353
"""Test particular segment update cdn bypass."""
@@ -401,14 +410,18 @@ async def fetch_segment_mock(segment_name, change_number, fetch_options):
401410
await segments_synchronizer.synchronize_segment('segmentA', 12345)
402411
assert (self.segment[7], self.change[7], self.options[7]) == ('segmentA', 12345, FetchOptions(True, 1234))
403412
assert len(self.segment) == 8 # 2 ok + BACKOFF(2 since==till + 2 re-attempts) + CDN(2 since==till)
413+
await segments_synchronizer.shutdown()
404414

405415
@pytest.mark.asyncio
406416
async def test_recreate(self, mocker):
407417
"""Test recreate logic."""
408418
segments_synchronizer = SegmentSynchronizerAsync(mocker.Mock(), mocker.Mock(), mocker.Mock())
409419
current_pool = segments_synchronizer._worker_pool
420+
await segments_synchronizer.shutdown()
410421
segments_synchronizer.recreate()
422+
411423
assert segments_synchronizer._worker_pool != current_pool
424+
await segments_synchronizer.shutdown()
412425

413426

414427
class LocalSegmentsSynchronizerTests(object):

tests/tasks/util/test_workerpool.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -89,14 +89,16 @@ async def worker_func(num):
8989

9090
wpool = workerpool.WorkerPoolAsync(10, worker_func)
9191
wpool.start()
92+
jobs = []
9293
for num in range(0, 11):
93-
await wpool.submit_work(str(num))
94+
jobs.append(str(num))
9495

95-
await asyncio.sleep(1)
96+
task = await wpool.submit_work(jobs)
97+
await task.await_completion()
9698
await wpool.stop()
97-
assert wpool._running == False
9899
for num in range(0, 11):
99100
assert str(num) in calls
101+
assert not wpool.pop_failed()
100102

101103
@pytest.mark.asyncio
102104
async def test_fail_in_msg_doesnt_break(self):
@@ -114,9 +116,10 @@ async def do_work(self, work):
114116
wpool = workerpool.WorkerPoolAsync(50, worker.do_work)
115117
wpool.start()
116118
for num in range(0, 100):
117-
await wpool.submit_work(str(num))
119+
await wpool.submit_work([str(num)])
118120
await asyncio.sleep(1)
119121
await wpool.stop()
122+
assert wpool.pop_failed()
120123

121124
for num in range(0, 100):
122125
if num != 55:
@@ -138,9 +141,10 @@ async def do_work(self, work):
138141
worker = Worker()
139142
wpool = workerpool.WorkerPoolAsync(50, worker.do_work)
140143
wpool.start()
144+
jobs = []
141145
for num in range(0, 100):
142-
await wpool.submit_work(str(num))
143-
144-
await asyncio.sleep(1)
145-
await wpool.wait_for_completion()
146+
jobs.append(str(num))
147+
task = await wpool.submit_work(jobs)
148+
await task.await_completion()
149+
await wpool.stop()
146150
assert len(worker.worked) == 100

0 commit comments

Comments
 (0)