Skip to content

Commit 7f901fa

Browse files
author
arch
committed
use mp queue
1 parent 55770b9 commit 7f901fa

File tree

1 file changed

+9
-8
lines changed

1 file changed

+9
-8
lines changed

funscript_editor/algorithms/funscriptgenerator.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ class FunscriptGeneratorParameter:
7676
max_threshold: float = float(HYPERPARAMETER['max_threshold'])
7777

7878

79-
def merge_score(item: list, number_of_trackers: int) -> list:
79+
def merge_score(item: list, number_of_trackers: int, return_queue: mp.Queue) -> None:
8080
""" Merge score for given number of trackers
8181
8282
Note:
@@ -92,14 +92,14 @@ def merge_score(item: list, number_of_trackers: int) -> list:
9292
list: merged score
9393
"""
9494
if number_of_trackers == 1:
95-
return item[0] if len(item) > 0 else []
95+
return_queue.put(item[0] if len(item) > 0 else [])
9696
else:
9797
max_frame_number = max([len(item[i]) for i in range(number_of_trackers)])
9898
arr = np.ma.empty((max_frame_number,number_of_trackers))
9999
arr.mask = True
100100
for tracker_number in range(number_of_trackers):
101101
arr[:item[tracker_number].shape[0],tracker_number] = item[tracker_number]
102-
return list(filter(None.__ne__, arr.mean(axis=1).tolist()))
102+
return_queue.put(list(filter(None.__ne__, arr.mean(axis=1).tolist())))
103103

104104

105105
class FunscriptGeneratorThread(QtCore.QThread):
@@ -414,14 +414,15 @@ def calculate_score(self, bboxes) -> None:
414414
score['y'][tracker_number] = np.array([max([x[1] for x in bboxes['Woman'][tracker_number]]) - w[1] for w in bboxes['Woman'][tracker_number]])
415415

416416
self.logger.info("Merge Scores")
417-
pool, handle = {}, {}
417+
pool, queue = {}, {}
418418
for metric in score.keys():
419-
pool[metric] = mp.Pool(1)
420-
handle[metric] = pool[metric].starmap_async(merge_score, [(score[metric], self.params.number_of_trackers)])
419+
queue[metric] = mp.Queue()
420+
pool[metric] = mp.Process(target=merge_score, args=(score[metric], self.params.number_of_trackers, queue[metric], ))
421+
pool[metric].start()
421422

422423
for metric in score.keys():
423-
score[metric] = handle[metric].get()[0]
424-
pool[metric].close()
424+
pool[metric].join()
425+
score[metric] = queue[metric].get()
425426

426427
self.logger.info("Scale Score to 0 - 100")
427428
for metric in score.keys():

0 commit comments

Comments
 (0)