Skip to content

Commit 08d5bd0

Browse files
author
arch
committed
refactoring
1 parent 3244f75 commit 08d5bd0

File tree

1 file changed

+62
-127
lines changed

1 file changed

+62
-127
lines changed

funscript_editor/algorithms/funscriptgenerator.py

Lines changed: 62 additions & 127 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from queue import Queue
1111
from pynput.keyboard import Key, Listener
1212
from dataclasses import dataclass
13-
from PyQt5 import QtGui, QtCore, QtWidgets
13+
from PyQt5 import QtCore
1414
from matplotlib.figure import Figure
1515
from datetime import datetime
1616
from scipy.interpolate import interp1d
@@ -25,7 +25,6 @@
2525

2626
import funscript_editor.algorithms.signalprocessing as sp
2727
import numpy as np
28-
import matplotlib.pyplot as plt
2928

3029

3130
@dataclass
@@ -109,7 +108,7 @@ def __init__(self,
109108
logger = logging.getLogger(__name__)
110109

111110

112-
def determine_preview_scaling(self, frame_width, frame_height) -> float:
111+
def determine_preview_scaling(self, frame_width, frame_height) -> None:
113112
""" Determine the scaling for current monitor setup
114113
115114
Args:
@@ -356,28 +355,22 @@ def scale_score(self, status: str, metric : str = 'y') -> None:
356355
status (str): a status/info message to display in the window
357356
metric (str): scale the 'y' or 'x' score
358357
"""
359-
if len(self.score['y']) < 2: return
360-
361-
cap = cv2.VideoCapture(self.params.video_path)
358+
if metric not in self.score.keys():
359+
self.logger.error("key %s is not in score dict", metric)
360+
return
362361

363-
if metric == 'euclideanDistance':
364-
min_frame = np.argmin(np.array(self.score['euclideanDistance'])) + self.params.start_frame
365-
max_frame = np.argmax(np.array(self.score['euclideanDistance'])) + self.params.start_frame
366-
elif metric == 'x':
367-
min_frame = np.argmin(np.array(self.score['x'])) + self.params.start_frame
368-
max_frame = np.argmax(np.array(self.score['x'])) + self.params.start_frame
369-
else:
370-
min_frame = np.argmin(np.array(self.score['y'])) + self.params.start_frame
371-
max_frame = np.argmax(np.array(self.score['y'])) + self.params.start_frame
362+
if len(self.score[metric]) < 2: return
363+
min_frame = np.argmin(np.array(self.score[metric])) + self.params.start_frame
364+
max_frame = np.argmax(np.array(self.score[metric])) + self.params.start_frame
372365

366+
cap = cv2.VideoCapture(self.params.video_path)
373367
cap.set(cv2.CAP_PROP_POS_FRAMES, min_frame)
374-
successMin, imgMin = cap.read()
368+
success_min, imgMin = cap.read()
375369
cap.set(cv2.CAP_PROP_POS_FRAMES, max_frame)
376-
successMax, imgMax = cap.read()
377-
370+
success_max, imgMax = cap.read()
378371
cap.release()
379372

380-
if successMin and successMax:
373+
if success_min and success_max:
381374
if 'vr' in self.params.projection.split('_'):
382375
if 'sbs' in self.params.projection.split('_'):
383376
imgMin = imgMin[:, :int(imgMin.shape[1]/2)]
@@ -390,41 +383,23 @@ def scale_score(self, status: str, metric : str = 'y') -> None:
390383
scale = PROJECTION[self.params.projection]['parameter']['width'] / float(1.75*imgMax.shape[1])
391384
else:
392385
scale = PROJECTION[self.params.projection]['parameter']['height'] / float(1.75*imgMax.shape[0])
386+
393387
imgMin = cv2.resize(imgMin, None, fx=scale, fy=scale)
394388
imgMax = cv2.resize(imgMax, None, fx=scale, fy=scale)
395389

396-
if metric == 'y':
397-
title_min = "Bottom"
398-
elif metric == 'x':
399-
title_min = "Left"
400-
else:
401-
title_min = "Minimum"
402-
403-
if metric == 'y':
404-
title_max = "Top"
405-
elif metric == 'x':
406-
title_max = "Right"
407-
else:
408-
title_max = "Maximum"
409-
410390
(desired_min, desired_max) = self.min_max_selector(
411391
image_min = imgMin,
412392
image_max = imgMax,
413393
info = status,
414-
title_min = title_min,
415-
title_max = title_max
394+
title_min = metric + " Minimum",
395+
title_max = metric + " Maximum"
416396
)
417397
else:
418398
self.logger.warning("Determine min and max failed")
419399
desired_min = 0
420400
desired_max = 99
421401

422-
if metric == 'euclideanDistance':
423-
self.score['euclideanDistance'] = sp.scale_signal(self.score['euclideanDistance'], desired_min, desired_max)
424-
elif metric == 'x':
425-
self.score['x'] = sp.scale_signal(self.score['x'], desired_min, desired_max)
426-
else:
427-
self.score['y'] = sp.scale_signal(self.score['y'], desired_min, desired_max)
402+
self.score[metric] = sp.scale_signal(self.score[metric], desired_min, desired_max)
428403

429404

430405
def plot_y_score(self, name: str, idx_list: list, dpi : int = 300) -> None:
@@ -533,7 +508,7 @@ def get_vr_projection_config(self, image :np.ndarray) -> dict:
533508

534509
preview = self.drawText(preview, "Press 'q' to use current selected region of interest)",
535510
y = 50, color = (255, 0, 0))
536-
preview = self.drawText(preview, "Use 'w', 's' to move up/down to the region of interest",
511+
preview = self.drawText(preview, "VR Projection: Use 'w', 's' to move up/down to the region of interest",
537512
y = 75, color = (0, 255, 0))
538513

539514
cv2.imshow(self.window_name, self.preview_scaling(preview))
@@ -813,76 +788,54 @@ def finished(self, status: str, success :bool) -> None:
813788
self.funscriptCompleted.emit(self.funscript, status, success)
814789

815790

816-
def apply_shift(self, frame_number, position: str) -> int:
791+
def apply_shift(self, frame_number: int, metric: str, position: str) -> int:
817792
""" Apply shift to predicted frame positions
818793
819794
Args:
820-
position (str): is max or min
821-
"""
822-
if self.params.metric == 'euclideanDistance':
823-
shift_a = self.params.shift_top_points
824-
elif self.params.metric == 'x':
825-
shift_a = self.params.shift_right_points
826-
else:
827-
shift_a = self.params.shift_top_points
795+
frame_number (int): relative frame number
796+
metric (str): metric to apply the shift
797+
position (str): keyword ['max', 'min', 'None']
828798
829-
if self.params.metric == 'euclideanDistance':
830-
shift_b = self.params.shift_bottom_points
831-
elif self.params.metric == 'x':
832-
shift_b = self.params.shift_left_points
833-
else:
834-
shift_b = self.params.shift_bottom_points
799+
Returns:
800+
int: real frame position
801+
"""
802+
shift_max = self.params.shift_top_points if metric == 'y' else self.params.shift_right_points
803+
shift_min = self.params.shift_bottom_points if metric == 'y' else self.params.shift_left_points
835804

836-
if position in ['max', 'top', 'right'] :
837-
if frame_number >= -1*shift_a \
838-
and frame_number + shift_a < len(self.score['y']): \
839-
return self.params.start_frame + frame_number + shift_a
805+
if position in ['max'] :
806+
if frame_number >= -1*shift_max \
807+
and frame_number + shift_max < len(self.score[metric]): \
808+
return self.params.start_frame + frame_number + shift_max
840809

841-
if position in ['min', 'bottom', 'left']:
842-
if frame_number >= -1*shift_b \
843-
and frame_number + shift_b < len(self.score['y']): \
844-
return self.params.start_frame + frame_number + shift_b
810+
if position in ['min']:
811+
if frame_number >= -1*shift_min \
812+
and frame_number + shift_min < len(self.score[metric]): \
813+
return self.params.start_frame + frame_number + shift_min
845814

846815
return self.params.start_frame + frame_number
847816

848817

849-
def get_score_with_offset(self, idx_dict) -> list:
818+
def get_score_with_offset(self, idx_dict: dict, metric: str) -> list:
850819
""" Apply the offsets form config file
851820
852821
Args:
853822
idx_dict (dict): the idx dictionary with {'min':[], 'max':[]} idx lists
823+
metric (str): the metric for the score calculation
854824
855825
Returns:
856826
list: score with offset
857827
"""
858-
if self.params.metric == 'euclideanDistance':
859-
offset_a = self.params.top_points_offset
860-
elif self.params.metric == 'x':
861-
offset_a = self.params.right_points_offset
862-
else:
863-
offset_a = self.params.top_points_offset
864-
865-
if self.params.metric == 'euclideanDistance':
866-
offset_b = self.params.bottom_points_offset
867-
elif self.params.metric == 'x':
868-
offset_b = self.params.left_points_offset
869-
else:
870-
offset_b = self.params.bottom_points_offset
871-
872-
if self.params.metric == 'euclideanDistance':
873-
score = copy.deepcopy(self.score['euclideanDistance'])
874-
elif self.params.metric == 'x':
875-
score = copy.deepcopy(self.score['x'])
876-
else:
877-
score = copy.deepcopy(self.score['y'])
828+
offset_max = self.params.top_points_offset if metric == 'y' else self.params.right_points_offset
829+
offset_min = self.params.bottom_points_offset if metric == 'y' else self.params.left_points_offset
878830

831+
score = copy.deepcopy(self.score[metric])
879832
score_min, score_max = min(score), max(score)
880833

881834
for idx in idx_dict['max']:
882-
score[idx] = max(( score_min, min((score_max, score[idx] + offset_a)) ))
835+
score[idx] = max(( score_min, min((score_max, score[idx] + offset_max)) ))
883836

884837
for idx in idx_dict['min']:
885-
score[idx] = max(( score_min, min((score_max, score[idx] + offset_b)) ))
838+
score[idx] = max(( score_min, min((score_max, score[idx] + offset_min)) ))
886839

887840
return score
888841

@@ -908,20 +861,20 @@ def apply_kalman_filter(self) -> None:
908861
self.bboxes['Men'][idx] = (prediction[0], prediction[1], item[2], item[3])
909862

910863

911-
def determin_change_points(self) -> dict:
864+
def determine_change_points(self, metric: str) -> dict:
912865
""" Determine all change points
913866
867+
Args:
868+
metric (str): from which metric you want to have the chainge points
869+
914870
Returns:
915871
dict: all local max and min points in score {'min':[idx1, idx2, ...], 'max':[idx1, idx2, ...]}
916872
"""
917-
self.logger.info("Determine change points")
918-
if self.params.metric == 'euclideanDistance':
919-
idx_dict = sp.get_local_max_and_min_idx(self.score['euclideanDistance'], round(self.video_info.fps))
920-
elif self.params.metric == 'x':
921-
idx_dict = sp.get_local_max_and_min_idx(self.score['x'], round(self.video_info.fps))
922-
else:
923-
idx_dict = sp.get_local_max_and_min_idx(self.score['y'], round(self.video_info.fps))
924-
return idx_dict
873+
self.logger.info("Determine change points for %s", metric)
874+
if metric not in self.score.keys():
875+
self.logger.error("key %s not in score metrics dict", metric)
876+
return dict()
877+
return sp.get_local_max_and_min_idx(self.score[metric], round(self.video_info.fps))
925878

926879

927880
def create_funscript(self, idx_dict: dict) -> None:
@@ -932,50 +885,33 @@ def create_funscript(self, idx_dict: dict) -> None:
932885
{'min':[idx1, idx2, ...], 'max':[idx1, idx2, ...]}
933886
"""
934887
if self.params.raw_output:
935-
if self.params.metric == 'euclideanDistance':
936-
output_score = copy.deepcopy(self.score['euclideanDistance'])
937-
elif self.params.metric == 'x':
938-
output_score = copy.deepcopy(self.score['x'])
939-
else:
940-
output_score = copy.deepcopy(self.score['y'])
941-
888+
output_score = copy.deepcopy(self.score[self.params.metric])
942889
for idx in range(len(output_score)):
943890
self.funscript.add_action(
944891
output_score[idx],
945-
FFmpegStream.frame_to_millisec(self.apply_shift(idx, 'none'), self.video_info.fps)
892+
FFmpegStream.frame_to_millisec(self.apply_shift(idx, self.params.metric, 'None'), self.video_info.fps)
946893
)
947894

948895
else:
949-
output_score = self.get_score_with_offset(idx_dict)
896+
output_score = self.get_score_with_offset(idx_dict, self.params.metric)
950897

951-
if self.params.metric == 'euclideanDistance':
952-
threshold_a = self.params.bottom_threshold
953-
elif self.params.metric == 'x':
954-
threshold_a = self.params.left_threshold
955-
else:
956-
threshold_a = self.params.bottom_threshold
957-
958-
if self.params.metric == 'euclideanDistance':
959-
threshold_b = self.params.top_threshold
960-
elif self.params.metric == 'x':
961-
threshold_b = self.params.right_threshold
962-
else:
963-
threshold_b = self.params.top_threshold
898+
threshold_min = self.params.bottom_threshold if self.params.metric == 'y' else self.params.left_threshold
899+
threshold_max = self.params.top_threshold if self.params.metric == 'y' else self.params.right_threshold
964900

965901
for idx in idx_dict['min']:
966902
self.funscript.add_action(
967903
min(output_score) \
968-
if output_score[idx] < min(output_score) + threshold_a \
904+
if output_score[idx] < min(output_score) + threshold_min \
969905
else round(output_score[idx]),
970-
FFmpegStream.frame_to_millisec(self.apply_shift(idx, 'min'), self.video_info.fps)
906+
FFmpegStream.frame_to_millisec(self.apply_shift(idx, self.params.metric, 'min'), self.video_info.fps)
971907
)
972908

973909
for idx in idx_dict['max']:
974910
self.funscript.add_action(
975911
max(output_score) \
976-
if output_score[idx] > max(output_score) - threshold_b \
912+
if output_score[idx] > max(output_score) - threshold_max \
977913
else round(output_score[idx]),
978-
FFmpegStream.frame_to_millisec(self.apply_shift(idx, 'max'), self.video_info.fps)
914+
FFmpegStream.frame_to_millisec(self.apply_shift(idx, self.params.metric, 'max'), self.video_info.fps)
979915
)
980916

981917

@@ -993,22 +929,21 @@ def run(self) -> None:
993929
if self.params.raw_output:
994930
self.logger.warning("Raw output is enabled!")
995931

996-
# NOTE: score['y'] and score['x'] should have the same number size so it should be enouth to check one score length
997932
with Listener(on_press=self.on_key_press) as _:
998933
status = self.tracking()
999934

1000935
if self.params.use_kalman_filter:
1001936
self.apply_kalman_filter()
1002937

1003-
if len(self.score['y']) >= HYPERPARAMETER['min_frames']:
938+
if len(self.score[self.params.metric]) >= HYPERPARAMETER['min_frames']:
1004939
self.logger.info("Scale score")
1005940
self.scale_score(status, metric=self.params.metric)
1006941

1007-
if len(self.score['y']) < HYPERPARAMETER['min_frames']:
942+
if len(self.score[self.params.metric]) < HYPERPARAMETER['min_frames']:
1008943
self.finished(status + ' -> Tracking time insufficient', False)
1009944
return
1010945

1011-
idx_dict = self.determin_change_points()
946+
idx_dict = self.determine_change_points(self.params.metric)
1012947

1013948
if False:
1014949
idx_list = [x for k in ['min', 'max'] for x in idx_dict[k]]

0 commit comments

Comments
 (0)