Skip to content

Commit e752f4e

Browse files
author
arch
committed
add custom post processing tab
1 parent 28c6dc2 commit e752f4e

File tree

5 files changed

+74
-14
lines changed

5 files changed

+74
-14
lines changed

funscript_editor/algorithms/signal.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ class SignalParameter:
2020
distance_minimization_threshold: float = float(HYPERPARAMETER['signal']['distance_minimization_threshold'])
2121
high_second_derivative_points_threshold: float = float(HYPERPARAMETER['signal']['high_second_derivative_points_threshold'])
2222
direction_change_filter_len: int = int(HYPERPARAMETER['signal']['direction_change_filter_len'])
23-
additional_points_repetitions: int = int(HYPERPARAMETER['signal']['additional_points_repetitions'])
2423
min_evenly_intermediate_interframes: int = int(HYPERPARAMETER['signal']['min_evenly_intermediate_interframes'])
2524

2625

@@ -583,13 +582,15 @@ def apply_manual_shift(self, point_group: dict, max_idx: int, shift: dict = {'mi
583582
def decimate(self,
584583
signal: list,
585584
base_point_algorithm: BasePointAlgorithm,
586-
additional_points_algorithms: List[AdditionalPointAlgorithm]) -> list:
585+
additional_points_algorithms: List[AdditionalPointAlgorithm],
586+
additional_points_repetitions: int = 2) -> list:
587587
""" Compute the decimated signal with given algorithms
588588
589589
Args:
590590
signal (list): raw signal
591591
base_point_algorithm (BasePointAlgorithm): algorithm to determine the base points
592592
additional_points_algorithms (List[AdditionalPointAlgorithm]): list with algorithms to determine additional points
593+
additional_points_repetitions: number of runs for the additional points algorithm (max number of points that will be insert between 2 base points)
593594
594595
Returns:
595596
list: indexes for decimated signal
@@ -601,7 +602,7 @@ def decimate(self,
601602
else:
602603
raise NotImplementedError("Selected Base Point Algorithm is not implemented")
603604

604-
for run_idx in range(self.params.additional_points_repetitions):
605+
for run_idx in range(additional_points_repetitions):
605606
self.logger.info("Run Additional Points Algorithms #%d", run_idx+1)
606607
len_before_merge = len(decimated_indexes)
607608
for algo in additional_points_algorithms:

funscript_editor/config/hyperparameter.yaml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,6 @@ signal:
4040
# min interframes without an additional datapoint for the evenly intermediate algorithm
4141
min_evenly_intermediate_interframes: 2
4242

43-
# number of runs for the additional points algorithm (max number of points that will be insert between 2 base points)
44-
additional_points_repetitions: 2
45-
4643

4744
# Scene Detector Hyperparameter
4845
scene_detector:

funscript_editor/config/settings.yaml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,6 @@ notification_sound: 'off'
2626
# - 'THRESHOLD': Detects fades in video.
2727
scene_detector: 'CSV'
2828

29-
# Output an position for each frame. This option disable post processing!
30-
raw_output: False
31-
3229
# Force dark ui theme
3330
dark_theme: False
3431

funscript_editor/ui/funscript_generator_window.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ def __next_postprocessing(self, last_metric, idx_keep, val_keep):
189189
continue
190190

191191
if found_last:
192-
self.postprocessing_widget = PostprocessingWidget(metric, self.score[metric])
192+
self.postprocessing_widget = PostprocessingWidget(metric, self.score[metric], self.video_info)
193193
self.postprocessing_widget.postprocessingCompleted.connect(self.__next_postprocessing)
194194
self.postprocessing_widget.show()
195195
return
@@ -205,7 +205,7 @@ def __funscript_generated(self, funscripts, msg, success) -> None:
205205
if len(funscripts) > 1:
206206
self.__logger.warning("Multiaxis output for build-in UI is not implemented")
207207
for item in funscripts[first_metric].get_actions():
208-
self.output_file.add_action(item['pos'], item['at'], SETTINGS['raw_output'])
208+
self.output_file.add_action(item['pos'], item['at'], True)
209209
self.funscriptCompleted.emit(self.output_file, msg, success)
210210
else:
211211
os.makedirs(os.path.dirname(self.output_file), exist_ok=True)

funscript_editor/ui/postprocessing.py

Lines changed: 68 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,18 @@
66

77
import numpy as np
88
import pyqtgraph as pg
9+
10+
from funscript_editor.algorithms.signal import Signal
911
from funscript_editor.ui.cut_tracking_result import Slider
1012

13+
1114
class PostprocessingWidget(QtWidgets.QWidget):
12-
def __init__(self, metric, raw_score, parent=None):
15+
def __init__(self, metric, raw_score, video_info, parent=None):
1316
super(QtWidgets.QWidget, self).__init__(parent=parent)
1417
pg.setConfigOption("background","w")
1518
self.verticalLayout = QtWidgets.QVBoxLayout(self)
1619

20+
self.video_info = video_info
1721
self.metric = metric
1822
self.raw_score_idx = [x for x in range(len(raw_score))]
1923
self.raw_score = raw_score
@@ -25,6 +29,7 @@ def __init__(self, metric, raw_score, parent=None):
2529
self.tabs_content = {}
2630
self.add_rdp_tab()
2731
self.add_vw_tab()
32+
self.add_custom_tab()
2833
self.verticalLayout.addWidget(self.tabs)
2934
self.tabs.currentChanged.connect(self.update_plot)
3035

@@ -67,13 +72,44 @@ def add_vw_tab(self):
6772
tab_name = "Visvalingam-Whyatt"
6873
self.tabs_content[tab_name] = {"main": QtWidgets.QWidget(), "widgets": {}}
6974
self.tabs_content[tab_name]["main"].layout = QtWidgets.QVBoxLayout(self)
70-
self.tabs_content[tab_name]["widgets"]["epsilon"] = Slider("Epsilon", 100, 10)
75+
self.tabs_content[tab_name]["widgets"]["epsilon"] = Slider("Epsilon", 200, 50)
7176
self.tabs_content[tab_name]["main"].layout.addWidget(self.tabs_content[tab_name]["widgets"]["epsilon"])
7277
self.tabs_content[tab_name]["widgets"]["epsilon"].slider.valueChanged.connect(self.update_plot)
7378
self.tabs_content[tab_name]["main"].setLayout(self.tabs_content[tab_name]["main"].layout)
7479
self.tabs.addTab(self.tabs_content[tab_name]["main"], tab_name)
7580

7681

82+
def add_custom_tab(self):
83+
tab_name = "Custom"
84+
self.tabs_content[tab_name] = {"main": QtWidgets.QWidget(), "widgets": {}}
85+
self.tabs_content[tab_name]["main"].layout = QtWidgets.QVBoxLayout(self)
86+
87+
self.tabs_content[tab_name]["widgets"]["points"] = QtWidgets.QComboBox()
88+
self.tabs_content[tab_name]["widgets"]["points"].addItems(["Local Min Max", "Direction Changed"])
89+
self.tabs_content[tab_name]["widgets"]["points"].currentIndexChanged.connect(self.update_plot)
90+
91+
self.tabs_content[tab_name]["widgets"]["high_second_derivate"] = QtWidgets.QCheckBox("High Second Derivate")
92+
self.tabs_content[tab_name]["widgets"]["high_second_derivate"].stateChanged.connect(self.update_plot)
93+
94+
self.tabs_content[tab_name]["widgets"]["distance_minimization"] = QtWidgets.QCheckBox("Distance Minimization")
95+
self.tabs_content[tab_name]["widgets"]["distance_minimization"].stateChanged.connect(self.update_plot)
96+
97+
self.tabs_content[tab_name]["widgets"]["evenly_intermediate"] = QtWidgets.QCheckBox("Evenly Intermediate")
98+
self.tabs_content[tab_name]["widgets"]["evenly_intermediate"].stateChanged.connect(self.update_plot)
99+
100+
self.tabs_content[tab_name]["widgets"]["runs"] = Slider("Additionl Points Algorthm Runs", 8, 2)
101+
self.tabs_content[tab_name]["widgets"]["runs"].slider.valueChanged.connect(self.update_plot)
102+
103+
self.tabs_content[tab_name]["main"].layout.addWidget(self.tabs_content[tab_name]["widgets"]["points"])
104+
self.tabs_content[tab_name]["main"].layout.addWidget(self.tabs_content[tab_name]["widgets"]["high_second_derivate"])
105+
self.tabs_content[tab_name]["main"].layout.addWidget(self.tabs_content[tab_name]["widgets"]["distance_minimization"])
106+
self.tabs_content[tab_name]["main"].layout.addWidget(self.tabs_content[tab_name]["widgets"]["evenly_intermediate"])
107+
self.tabs_content[tab_name]["main"].layout.addWidget(self.tabs_content[tab_name]["widgets"]["runs"])
108+
109+
self.tabs_content[tab_name]["main"].setLayout(self.tabs_content[tab_name]["main"].layout)
110+
self.tabs.addTab(self.tabs_content[tab_name]["main"], tab_name)
111+
112+
77113
def get_current_tab_name(self) -> str:
78114
return self.tabs.tabText(self.tabs.currentIndex())
79115

@@ -94,7 +130,36 @@ def update_plot(self):
94130
return
95131

96132
if current_tab_name == "Visvalingam-Whyatt":
97-
self.result_idx = simplify_coords_vw_idx(self.raw_score_np, float(self.tabs_content[current_tab_name]["widgets"]["epsilon"].x) / 10.0)
133+
self.result_idx = simplify_coords_vw_idx(self.raw_score_np, float(self.tabs_content[current_tab_name]["widgets"]["epsilon"].x) / 1.0)
134+
self.result_val = [val for idx,val in enumerate(self.raw_score) if idx in self.result_idx]
135+
self.curve_result.setData(self.result_idx, self.result_val)
136+
return
137+
138+
if current_tab_name == "Custom":
139+
base_algo = self.tabs_content[current_tab_name]["widgets"]["points"].currentText()
140+
runs = self.tabs_content[current_tab_name]["widgets"]["runs"].x
141+
142+
base_point_algorithm = Signal.BasePointAlgorithm.local_min_max
143+
if base_algo == 'Direction Changed':
144+
base_point_algorithm = Signal.BasePointAlgorithm.direction_changes
145+
146+
additional_points_algorithms = []
147+
if self.tabs_content[current_tab_name]["widgets"]["high_second_derivate"].isChecked():
148+
additional_points_algorithms.append(Signal.AdditionalPointAlgorithm.high_second_derivative)
149+
150+
if self.tabs_content[current_tab_name]["widgets"]["distance_minimization"].isChecked():
151+
additional_points_algorithms.append(Signal.AdditionalPointAlgorithm.distance_minimization)
152+
153+
if self.tabs_content[current_tab_name]["widgets"]["evenly_intermediate"].isChecked():
154+
additional_points_algorithms.append(Signal.AdditionalPointAlgorithm.evenly_intermediate)
155+
156+
signal = Signal(self.video_info.fps)
157+
self.result_idx = signal.decimate(
158+
self.raw_score,
159+
base_point_algorithm,
160+
additional_points_algorithms,
161+
additional_points_repetitions = runs
162+
)
98163
self.result_val = [val for idx,val in enumerate(self.raw_score) if idx in self.result_idx]
99164
self.curve_result.setData(self.result_idx, self.result_val)
100165
return

0 commit comments

Comments
 (0)