
Commit 040e949

Merge pull request #424 from splitio/async-sync-segment-workerpool
Added workerpool and sync.segment async classes
2 parents 8580abb + f79bf8b commit 040e949

4 files changed: +597 -11 lines changed

splitio/sync/segment.py

Lines changed: 187 additions & 7 deletions
@@ -10,34 +10,36 @@
 from splitio.util.backoff import Backoff
 from splitio.optional.loaders import asyncio, aiofiles
 from splitio.sync import util
+from splitio.optional.loaders import asyncio

 _LOGGER = logging.getLogger(__name__)


 _ON_DEMAND_FETCH_BACKOFF_BASE = 10  # backoff base starting at 10 seconds
 _ON_DEMAND_FETCH_BACKOFF_MAX_WAIT = 60  # don't sleep for more than 1 minute
 _ON_DEMAND_FETCH_BACKOFF_MAX_RETRIES = 10
+_MAX_WORKERS = 10


 class SegmentSynchronizer(object):
-    def __init__(self, segment_api, split_storage, segment_storage):
+    def __init__(self, segment_api, feature_flag_storage, segment_storage):
         """
         Class constructor.

         :param segment_api: API to retrieve segments from backend.
         :type segment_api: splitio.api.SegmentApi

-        :param split_storage: Feature Flag Storage.
-        :type split_storage: splitio.storage.InMemorySplitStorage
+        :param feature_flag_storage: Feature Flag Storage.
+        :type feature_flag_storage: splitio.storage.InMemorySplitStorage

         :param segment_storage: Segment storage reference.
         :type segment_storage: splitio.storage.SegmentStorage

         """
         self._api = segment_api
-        self._split_storage = split_storage
+        self._feature_flag_storage = feature_flag_storage
         self._segment_storage = segment_storage
-        self._worker_pool = workerpool.WorkerPool(10, self.synchronize_segment)
+        self._worker_pool = workerpool.WorkerPool(_MAX_WORKERS, self.synchronize_segment)
         self._worker_pool.start()
         self._backoff = Backoff(
             _ON_DEMAND_FETCH_BACKOFF_BASE,
@@ -48,7 +50,7 @@ def recreate(self):
         Create worker_pool on forked processes.

         """
-        self._worker_pool = workerpool.WorkerPool(10, self.synchronize_segment)
+        self._worker_pool = workerpool.WorkerPool(_MAX_WORKERS, self.synchronize_segment)
         self._worker_pool.start()

     def shutdown(self):
@@ -176,7 +178,7 @@ def synchronize_segments(self, segment_names = None, dont_wait = False):
         :rtype: bool
         """
         if segment_names is None:
-            segment_names = self._split_storage.get_segment_names()
+            segment_names = self._feature_flag_storage.get_segment_names()

         for segment_name in segment_names:
             self._worker_pool.submit_work(segment_name)
@@ -196,6 +198,184 @@ def segment_exist_in_storage(self, segment_name):
         """
         return self._segment_storage.get(segment_name) != None

+
+class SegmentSynchronizerAsync(object):
+    def __init__(self, segment_api, feature_flag_storage, segment_storage):
+        """
+        Class constructor.
+
+        :param segment_api: API to retrieve segments from backend.
+        :type segment_api: splitio.api.SegmentApi
+
+        :param feature_flag_storage: Feature Flag Storage.
+        :type feature_flag_storage: splitio.storage.InMemorySplitStorage
+
+        :param segment_storage: Segment storage reference.
+        :type segment_storage: splitio.storage.SegmentStorage
+
+        """
+        self._api = segment_api
+        self._feature_flag_storage = feature_flag_storage
+        self._segment_storage = segment_storage
+        self._worker_pool = workerpool.WorkerPoolAsync(_MAX_WORKERS, self.synchronize_segment)
+        self._worker_pool.start()
+        self._backoff = Backoff(
+            _ON_DEMAND_FETCH_BACKOFF_BASE,
+            _ON_DEMAND_FETCH_BACKOFF_MAX_WAIT)
+
+    def recreate(self):
+        """
+        Create worker_pool on forked processes.
+
+        """
+        self._worker_pool = workerpool.WorkerPoolAsync(_MAX_WORKERS, self.synchronize_segment)
+        self._worker_pool.start()
+
+    async def shutdown(self):
+        """
+        Shutdown worker_pool.
+
+        """
+        await self._worker_pool.stop()
+
+    async def _fetch_until(self, segment_name, fetch_options, till=None):
+        """
+        Hit endpoint, update storage and return when since==till.
+
+        :param segment_name: Name of the segment to update.
+        :type segment_name: str
+
+        :param fetch_options: Fetch options for getting segment definitions.
+        :type fetch_options: splitio.api.FetchOptions
+
+        :param till: Passed till from Streaming.
+        :type till: int
+
+        :return: last change number
+        :rtype: int
+        """
+        while True:  # Fetch until since==till
+            change_number = await self._segment_storage.get_change_number(segment_name)
+            if change_number is None:
+                change_number = -1
+            if till is not None and till < change_number:
+                # the passed till is less than change_number, no need to perform updates
+                return change_number
+
+            try:
+                segment_changes = await self._api.fetch_segment(segment_name, change_number,
+                                                                fetch_options)
+            except APIException as exc:
+                _LOGGER.error('Exception raised while fetching segment %s', segment_name)
+                _LOGGER.debug('Exception information: ', exc_info=True)
+                raise exc
+
+            if change_number == -1:  # first time fetching the segment
+                new_segment = segments.from_raw(segment_changes)
+                await self._segment_storage.put(new_segment)
+            else:
+                await self._segment_storage.update(
+                    segment_name,
+                    segment_changes['added'],
+                    segment_changes['removed'],
+                    segment_changes['till']
+                )
+
+            if segment_changes['till'] == segment_changes['since']:
+                return segment_changes['till']
+
+    async def _attempt_segment_sync(self, segment_name, fetch_options, till=None):
+        """
+        Hit endpoint, update storage and return True if sync is complete.
+
+        :param segment_name: Name of the segment to update.
+        :type segment_name: str
+
+        :param fetch_options: Fetch options for getting segment definitions.
+        :type fetch_options: splitio.api.FetchOptions
+
+        :param till: Passed till from Streaming.
+        :type till: int
+
+        :return: Flags to check if it should perform bypass or operation ended
+        :rtype: bool, int, int
+        """
+        self._backoff.reset()
+        remaining_attempts = _ON_DEMAND_FETCH_BACKOFF_MAX_RETRIES
+        while True:
+            remaining_attempts -= 1
+            change_number = await self._fetch_until(segment_name, fetch_options, till)
+            if till is None or till <= change_number:
+                return True, remaining_attempts, change_number
+            elif remaining_attempts <= 0:
+                return False, remaining_attempts, change_number
+            how_long = self._backoff.get()
+            await asyncio.sleep(how_long)
+
+    async def synchronize_segment(self, segment_name, till=None):
+        """
+        Update a segment from queue.
+
+        :param segment_name: Name of the segment to update.
+        :type segment_name: str
+
+        :param till: ChangeNumber received.
+        :type till: int
+
+        :return: True if no error occurs. False otherwise.
+        :rtype: bool
+        """
+        fetch_options = FetchOptions(True)  # Set Cache-Control to no-cache
+        successful_sync, remaining_attempts, change_number = await self._attempt_segment_sync(segment_name, fetch_options, till)
+        attempts = _ON_DEMAND_FETCH_BACKOFF_MAX_RETRIES - remaining_attempts
+        if successful_sync:  # successful sync
+            _LOGGER.debug('Refresh completed in %d attempts.', attempts)
+            return True
+        with_cdn_bypass = FetchOptions(True, change_number)  # Set flag for bypassing CDN
+        without_cdn_successful_sync, remaining_attempts, change_number = await self._attempt_segment_sync(segment_name, with_cdn_bypass, till)
+        without_cdn_attempts = _ON_DEMAND_FETCH_BACKOFF_MAX_RETRIES - remaining_attempts
+        if without_cdn_successful_sync:
+            _LOGGER.debug('Refresh completed bypassing the CDN in %d attempts.',
+                          without_cdn_attempts)
+            return True
+        _LOGGER.debug('No changes fetched after %d attempts with CDN bypassed.',
+                      without_cdn_attempts)
+        return False
+
+    async def synchronize_segments(self, segment_names = None, dont_wait = False):
+        """
+        Submit all current segments and, depending on the dont_wait flag, wait for them to finish, then set the ready flag.
+
+        :param segment_names: Optional, array of segment names to update.
+        :type segment_names: {str}
+
+        :param dont_wait: Optional, instruct the function not to wait for task completion.
+        :type dont_wait: bool
+
+        :return: True if no error occurs or the dont_wait flag is True. False otherwise.
+        :rtype: bool
+        """
+        if segment_names is None:
+            segment_names = await self._feature_flag_storage.get_segment_names()
+
+        jobs = await self._worker_pool.submit_work(segment_names)
+        if (dont_wait):
+            return True
+        return await jobs.await_completion()
+
+    async def segment_exist_in_storage(self, segment_name):
+        """
+        Check if a segment exists in the storage.
+
+        :param segment_name: Name of the segment
+        :type segment_name: str
+
+        :return: True if segment exists. False otherwise.
+        :rtype: bool
+        """
+        return await self._segment_storage.get(segment_name) != None
+
+
 class LocalSegmentSynchronizerBase(object):
     """Localhost mode segment base synchronizer."""

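Below is a minimal usage sketch, not part of the commit, showing how the new SegmentSynchronizerAsync might be driven from an event loop. The main() wrapper and the segment_api, feature_flag_storage, and segment_storage arguments are hypothetical stand-ins for the SDK's async API client and storages.

import asyncio

from splitio.sync.segment import SegmentSynchronizerAsync

async def main(segment_api, feature_flag_storage, segment_storage):
    # The constructor starts a WorkerPoolAsync sized by _MAX_WORKERS (10),
    # so it must run inside an already-running event loop.
    synchronizer = SegmentSynchronizerAsync(segment_api, feature_flag_storage, segment_storage)
    try:
        # With no names passed, segment names are read from the feature flag
        # storage and fanned out across the worker pool.
        ok = await synchronizer.synchronize_segments()
        print('all segments synced' if ok else 'some segments failed to sync')
    finally:
        await synchronizer.shutdown()  # enqueues the pool's abort sentinel

# asyncio.run(main(segment_api, feature_flag_storage, segment_storage))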
splitio/tasks/util/workerpool.py

Lines changed: 93 additions & 1 deletion
@@ -4,10 +4,10 @@
 from threading import Thread, Event
 import queue

+from splitio.optional.loaders import asyncio

 _LOGGER = logging.getLogger(__name__)

-
 class WorkerPool(object):
     """Worker pool class to implement single producer/multiple consumer."""

@@ -134,3 +134,95 @@ def _wait_workers_shutdown(self, event):
         for worker_event in self._worker_events:
             worker_event.wait()
         event.set()
+
+
+class WorkerPoolAsync(object):
+    """Worker pool async class to implement single producer/multiple consumer."""
+
+    _abort = object()
+
+    def __init__(self, worker_count, worker_func):
+        """
+        Class constructor.
+
+        :param worker_count: Number of workers for the pool.
+        :param worker_func: Function to be executed by the workers whenever a message is fetched.
+        """
+        self._semaphore = asyncio.Semaphore(worker_count)
+        self._queue = asyncio.Queue()
+        self._handler = worker_func
+        self._aborted = False
+
+    async def _schedule_work(self):
+        """Wrap the message handler execution."""
+        while True:
+            message = await self._queue.get()
+            if message == self._abort:
+                self._aborted = True
+                return
+            asyncio.get_running_loop().create_task(self._do_work(message))
+
+    async def _do_work(self, message):
+        """Process a single message."""
+        try:
+            await self._semaphore.acquire()  # wait until "there's a free worker"
+            if self._aborted:  # check in case the pool was shut down while we were waiting for a worker
+                return
+            await self._handler(message._message)
+        except Exception:
+            _LOGGER.error("Something went wrong when processing message %s", message)
+            _LOGGER.debug('Original traceback: ', exc_info=True)
+            message._failed = True
+        message._complete.set()
+        self._semaphore.release()  # signal worker is idle
+
+    def start(self):
+        """Start the workers."""
+        self._task = asyncio.get_running_loop().create_task(self._schedule_work())
+
+    async def submit_work(self, jobs):
+        """
+        Add new messages to the work-queue.
+
+        :param jobs: New messages to add.
+        :type jobs: list
+        """
+        self.jobs = jobs
+        if len(jobs) == 1:
+            wrapped = TaskCompletionWraper(jobs[0])
+            await self._queue.put(wrapped)
+            return wrapped

+        tasks = [TaskCompletionWraper(job) for job in jobs]
+        for w in tasks:
+            await self._queue.put(w)
+
+        return BatchCompletionWrapper(tasks)
+
+    async def stop(self, event=None):
+        """Abort all execution (except currently running handlers)."""
+        await self._queue.put(self._abort)
+
+
+class TaskCompletionWraper:
+    """Task completion class"""
+    def __init__(self, message):
+        self._message = message
+        self._complete = asyncio.Event()
+        self._failed = False
+
+    async def await_completion(self):
+        await self._complete.wait()
+
+    def _mark_as_complete(self):
+        self._complete.set()
+
+
+class BatchCompletionWrapper:
+    """Batch completion class"""
+    def __init__(self, tasks):
+        self._tasks = tasks
+
+    async def await_completion(self):
+        await asyncio.gather(*[task.await_completion() for task in self._tasks])
+        return not any(task._failed for task in self._tasks)

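For the pool itself, here is a standalone sketch, again not from the commit, that exercises WorkerPoolAsync with a trivial async handler. submit_work() returns a completion wrapper whose await_completion() resolves once every submitted job has been processed; the batch variant also reports whether any handler raised.

import asyncio

from splitio.tasks.util.workerpool import WorkerPoolAsync

async def handler(name):
    await asyncio.sleep(0.1)  # simulate I/O-bound work per message
    print('processed', name)

async def main():
    pool = WorkerPoolAsync(3, handler)   # at most 3 handlers run concurrently
    pool.start()                         # needs a running event loop
    batch = await pool.submit_work(['a', 'b', 'c', 'd'])
    ok = await batch.await_completion()  # True only if no handler raised
    print('all succeeded:', ok)
    await pool.stop()                    # enqueue the abort sentinel

asyncio.run(main())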