Skip to content

Commit 6677bce

Browse files
author
arch
committed
add more sliders
1 parent 19431f5 commit 6677bce

File tree

3 files changed

+107
-84
lines changed

3 files changed

+107
-84
lines changed

funscript_editor/algorithms/signal.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,20 +13,20 @@
1313

1414
@dataclass
1515
class SignalParameter:
16-
local_min_max_filter_len: int = int(HYPERPARAMETER['signal']['local_max_min_filter_len'])
16+
additional_points_merge_time_threshold_in_ms: float
17+
additional_points_merge_distance_threshold: float
18+
high_second_derivative_points_threshold: float
19+
distance_minimization_threshold: float
20+
local_min_max_filter_len: int
21+
direction_change_filter_len: int
1722
avg_sec_for_local_min_max_extraction: float = float(HYPERPARAMETER['signal']['avg_sec_for_local_min_max_extraction'])
18-
additional_points_merge_time_threshold_in_ms: float = float(HYPERPARAMETER['signal']['additional_points_merge_time_threshold_in_ms'])
19-
additional_points_merge_distance_threshold: float = float(HYPERPARAMETER['signal']['additional_points_merge_distance_threshold'])
20-
distance_minimization_threshold: float = float(HYPERPARAMETER['signal']['distance_minimization_threshold'])
21-
high_second_derivative_points_threshold: float = float(HYPERPARAMETER['signal']['high_second_derivative_points_threshold'])
22-
direction_change_filter_len: int = int(HYPERPARAMETER['signal']['direction_change_filter_len'])
2323
min_evenly_intermediate_interframes: int = int(HYPERPARAMETER['signal']['min_evenly_intermediate_interframes'])
2424

2525

2626
class Signal:
2727

28-
def __init__(self, fps):
29-
self.params = SignalParameter()
28+
def __init__(self, params, fps):
29+
self.params = params
3030
self.fps = fps
3131
self.logger = logging.getLogger(__name__)
3232

funscript_editor/config/hyperparameter.yaml

Lines changed: 1 addition & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -9,34 +9,14 @@
99
min_frames: 90
1010

1111
# reaction time of the user to stop the tracking when scene changed or tracking box shifts
12-
user_reaction_time_in_milliseconds: 1000
12+
user_reaction_time_in_milliseconds: 800
1313

1414

1515
# Signal Processing Hyperparameter
1616
signal:
17-
18-
# Specify the window size for the calculation of smothed tracking signal for the local min and max search.
19-
# For slow actions in your video, i recommend to set this to '3' to reduce the false positive predictions.
20-
local_max_min_filter_len: 1
21-
2217
# Specify the window size for the calculation of the reference value for the local min and max search.
2318
avg_sec_for_local_min_max_extraction: 2.0
2419

25-
# threshold value to predict an distance minimization point candidate
26-
distance_minimization_threshold: 16.0
27-
28-
# threshold value tor get additional points by comparing second derivative with the rolling standard deviation and given threshold
29-
high_second_derivative_points_threshold: 1.2
30-
31-
# filter length to detect a direction change
32-
direction_change_filter_len: 3
33-
34-
# threshold value in milliseconds to merge additional points
35-
additional_points_merge_time_threshold_in_ms: 60
36-
37-
# threshold value to merge additional points
38-
additional_points_merge_distance_threshold: 8.0
39-
4020
# min interframes without an additional datapoint for the evenly intermediate algorithm
4121
min_evenly_intermediate_interframes: 2
4222

funscript_editor/ui/postprocessing.py

Lines changed: 98 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@
44
simplify_coords_vw_idx,
55
)
66

7+
import funscript_editor.utils.logging as logging
78
import copy
89
import numpy as np
910
import pyqtgraph as pg
1011

11-
from funscript_editor.algorithms.signal import Signal
12+
from funscript_editor.algorithms.signal import Signal,SignalParameter
1213
from funscript_editor.ui.cut_tracking_result import Slider
1314

1415
class QHLine(QtWidgets.QFrame):
@@ -20,6 +21,7 @@ def __init__(self):
2021
class PostprocessingWidget(QtWidgets.QWidget):
2122
def __init__(self, metric, raw_score, video_info, parent=None):
2223
super(QtWidgets.QWidget, self).__init__(parent=parent)
24+
self.logger = logging.getLogger(__name__)
2325
pg.setConfigOption("background","w")
2426
self.verticalLayout = QtWidgets.QVBoxLayout(self)
2527

@@ -94,6 +96,9 @@ def add_custom_tab(self):
9496
self.tabs_content[tab_name]["widgets"]["points"].addItems(["Local Min Max", "Direction Changed"])
9597
self.tabs_content[tab_name]["widgets"]["points"].currentIndexChanged.connect(self.update_plot)
9698

99+
self.tabs_content[tab_name]["widgets"]["filterLen"] = Slider("Filter Len", 10, 2)
100+
self.tabs_content[tab_name]["widgets"]["filterLen"].slider.valueChanged.connect(self.update_plot)
101+
97102
self.tabs_content[tab_name]["widgets"]["high_second_derivate"] = QtWidgets.QCheckBox("High Second Derivate")
98103
self.tabs_content[tab_name]["widgets"]["high_second_derivate"].stateChanged.connect(self.update_plot)
99104

@@ -106,6 +111,18 @@ def add_custom_tab(self):
106111
self.tabs_content[tab_name]["widgets"]["runs"] = Slider("Max additional Points", 8, 2)
107112
self.tabs_content[tab_name]["widgets"]["runs"].slider.valueChanged.connect(self.update_plot)
108113

114+
self.tabs_content[tab_name]["widgets"]["mergeThresholdMs"] = Slider("Merge Threshold Time in ms", 1000, 60)
115+
self.tabs_content[tab_name]["widgets"]["mergeThresholdMs"].slider.valueChanged.connect(self.update_plot)
116+
117+
self.tabs_content[tab_name]["widgets"]["mergeThresholdDistance"] = Slider("Merge Threshold Distance", 100, 12)
118+
self.tabs_content[tab_name]["widgets"]["mergeThresholdDistance"].slider.valueChanged.connect(self.update_plot)
119+
120+
self.tabs_content[tab_name]["widgets"]["highSecondDerivateThreshold"] = Slider("Threshold", 100, 12)
121+
self.tabs_content[tab_name]["widgets"]["highSecondDerivateThreshold"].slider.valueChanged.connect(self.update_plot)
122+
123+
self.tabs_content[tab_name]["widgets"]["distanzMinimizationThreshold"] = Slider("Threshold", 100, 16)
124+
self.tabs_content[tab_name]["widgets"]["distanzMinimizationThreshold"].slider.valueChanged.connect(self.update_plot)
125+
109126
self.tabs_content[tab_name]["widgets"]["lower"] = Slider("Lower Offset", 100, 0)
110127
self.tabs_content[tab_name]["widgets"]["lower"].slider.valueChanged.connect(self.update_plot)
111128

@@ -114,12 +131,20 @@ def add_custom_tab(self):
114131

115132
self.tabs_content[tab_name]["main"].layout.addWidget(QtWidgets.QLabel("Points:"))
116133
self.tabs_content[tab_name]["main"].layout.addWidget(self.tabs_content[tab_name]["widgets"]["points"])
134+
self.tabs_content[tab_name]["main"].layout.addWidget(self.tabs_content[tab_name]["widgets"]["filterLen"])
117135
self.tabs_content[tab_name]["main"].layout.addWidget(QHLine())
118136
self.tabs_content[tab_name]["main"].layout.addWidget(QtWidgets.QLabel("Additinal Points:"))
137+
self.tabs_content[tab_name]["main"].layout.addWidget(self.tabs_content[tab_name]["widgets"]["runs"])
138+
self.tabs_content[tab_name]["main"].layout.addWidget(self.tabs_content[tab_name]["widgets"]["mergeThresholdMs"])
139+
self.tabs_content[tab_name]["main"].layout.addWidget(self.tabs_content[tab_name]["widgets"]["mergeThresholdDistance"])
140+
self.tabs_content[tab_name]["main"].layout.addWidget(QHLine())
119141
self.tabs_content[tab_name]["main"].layout.addWidget(self.tabs_content[tab_name]["widgets"]["high_second_derivate"])
142+
self.tabs_content[tab_name]["main"].layout.addWidget(self.tabs_content[tab_name]["widgets"]["highSecondDerivateThreshold"])
143+
self.tabs_content[tab_name]["main"].layout.addWidget(QHLine())
120144
self.tabs_content[tab_name]["main"].layout.addWidget(self.tabs_content[tab_name]["widgets"]["distance_minimization"])
145+
self.tabs_content[tab_name]["main"].layout.addWidget(self.tabs_content[tab_name]["widgets"]["distanzMinimizationThreshold"])
146+
self.tabs_content[tab_name]["main"].layout.addWidget(QHLine())
121147
self.tabs_content[tab_name]["main"].layout.addWidget(self.tabs_content[tab_name]["widgets"]["evenly_intermediate"])
122-
self.tabs_content[tab_name]["main"].layout.addWidget(self.tabs_content[tab_name]["widgets"]["runs"])
123148
self.tabs_content[tab_name]["main"].layout.addWidget(QHLine())
124149
self.tabs_content[tab_name]["main"].layout.addWidget(QtWidgets.QLabel("Offset:"))
125150
self.tabs_content[tab_name]["main"].layout.addWidget(self.tabs_content[tab_name]["widgets"]["lower"])
@@ -142,57 +167,75 @@ def confirm(self):
142167
def update_plot(self):
143168
current_tab_name = self.get_current_tab_name()
144169

145-
if current_tab_name == "Ramer–Douglas–Peucker":
146-
self.result_idx = simplify_coords_idx(self.raw_score_np, float(self.tabs_content[current_tab_name]["widgets"]["epsilon"].x) / 10.0)
147-
self.result_val = [val for idx,val in enumerate(self.raw_score) if idx in self.result_idx]
148-
self.curve_result.setData(self.result_idx, self.result_val)
149-
return
150-
151-
if current_tab_name == "Visvalingam-Whyatt":
152-
self.result_idx = simplify_coords_vw_idx(self.raw_score_np, float(self.tabs_content[current_tab_name]["widgets"]["epsilon"].x) / 1.0)
153-
self.result_val = [val for idx,val in enumerate(self.raw_score) if idx in self.result_idx]
154-
self.curve_result.setData(self.result_idx, self.result_val)
155-
return
156-
157-
if current_tab_name == "Custom":
158-
base_algo = self.tabs_content[current_tab_name]["widgets"]["points"].currentText()
159-
runs = self.tabs_content[current_tab_name]["widgets"]["runs"].x
160-
offset_lower = self.tabs_content[current_tab_name]["widgets"]["lower"].x
161-
offset_upper = self.tabs_content[current_tab_name]["widgets"]["upper"].x
162-
163-
base_point_algorithm = Signal.BasePointAlgorithm.local_min_max
164-
if base_algo == 'Direction Changed':
165-
base_point_algorithm = Signal.BasePointAlgorithm.direction_changes
166-
167-
additional_points_algorithms = []
168-
if self.tabs_content[current_tab_name]["widgets"]["high_second_derivate"].isChecked():
169-
additional_points_algorithms.append(Signal.AdditionalPointAlgorithm.high_second_derivative)
170-
171-
if self.tabs_content[current_tab_name]["widgets"]["distance_minimization"].isChecked():
172-
additional_points_algorithms.append(Signal.AdditionalPointAlgorithm.distance_minimization)
173-
174-
if self.tabs_content[current_tab_name]["widgets"]["evenly_intermediate"].isChecked():
175-
additional_points_algorithms.append(Signal.AdditionalPointAlgorithm.evenly_intermediate)
176-
177-
signal = Signal(self.video_info.fps)
178-
self.result_idx = signal.decimate(
179-
self.raw_score,
180-
base_point_algorithm,
181-
additional_points_algorithms,
182-
additional_points_repetitions = runs
183-
)
184-
185-
categorized = signal.categorize_points(self.raw_score, self.result_idx)
186-
187-
score = copy.deepcopy(self.raw_score)
188-
score_min, score_max = min(score), max(score)
189-
190-
for idx in categorized['upper']:
191-
score[idx] = max(( score_min, min((score_max, score[idx] + offset_upper)) ))
192-
193-
for idx in categorized['lower']:
194-
score[idx] = max(( score_min, min((score_max, score[idx] - offset_lower)) ))
170+
try:
171+
if current_tab_name == "Ramer–Douglas–Peucker":
172+
self.result_idx = simplify_coords_idx(self.raw_score_np, float(self.tabs_content[current_tab_name]["widgets"]["epsilon"].x) / 10.0)
173+
self.result_val = [val for idx,val in enumerate(self.raw_score) if idx in self.result_idx]
174+
self.curve_result.setData(self.result_idx, self.result_val)
175+
return
176+
177+
if current_tab_name == "Visvalingam-Whyatt":
178+
self.result_idx = simplify_coords_vw_idx(self.raw_score_np, float(self.tabs_content[current_tab_name]["widgets"]["epsilon"].x) / 1.0)
179+
self.result_val = [val for idx,val in enumerate(self.raw_score) if idx in self.result_idx]
180+
self.curve_result.setData(self.result_idx, self.result_val)
181+
return
182+
183+
if current_tab_name == "Custom":
184+
base_algo = self.tabs_content[current_tab_name]["widgets"]["points"].currentText()
185+
runs = self.tabs_content[current_tab_name]["widgets"]["runs"].x
186+
offset_lower = self.tabs_content[current_tab_name]["widgets"]["lower"].x
187+
offset_upper = self.tabs_content[current_tab_name]["widgets"]["upper"].x
188+
mergeThresholdMs = self.tabs_content[current_tab_name]["widgets"]["mergeThresholdMs"].x
189+
mergeThresholdDistance = self.tabs_content[current_tab_name]["widgets"]["mergeThresholdDistance"].x
190+
highSecondDerivateThreshold = self.tabs_content[current_tab_name]["widgets"]["highSecondDerivateThreshold"].x / 10.0
191+
distanzMinimizationThreshold = self.tabs_content[current_tab_name]["widgets"]["distanzMinimizationThreshold"].x
192+
filterLen = self.tabs_content[current_tab_name]["widgets"]["filterLen"].x + 1
193+
194+
base_point_algorithm = Signal.BasePointAlgorithm.local_min_max
195+
if base_algo == 'Direction Changed':
196+
base_point_algorithm = Signal.BasePointAlgorithm.direction_changes
197+
198+
additional_points_algorithms = []
199+
if self.tabs_content[current_tab_name]["widgets"]["high_second_derivate"].isChecked():
200+
additional_points_algorithms.append(Signal.AdditionalPointAlgorithm.high_second_derivative)
201+
202+
if self.tabs_content[current_tab_name]["widgets"]["distance_minimization"].isChecked():
203+
additional_points_algorithms.append(Signal.AdditionalPointAlgorithm.distance_minimization)
204+
205+
if self.tabs_content[current_tab_name]["widgets"]["evenly_intermediate"].isChecked():
206+
additional_points_algorithms.append(Signal.AdditionalPointAlgorithm.evenly_intermediate)
207+
208+
signal = Signal(SignalParameter(
209+
additional_points_merge_time_threshold_in_ms = mergeThresholdMs,
210+
additional_points_merge_distance_threshold = mergeThresholdDistance,
211+
high_second_derivative_points_threshold = highSecondDerivateThreshold,
212+
distance_minimization_threshold = distanzMinimizationThreshold,
213+
local_min_max_filter_len = filterLen,
214+
direction_change_filter_len = filterLen
215+
), self.video_info.fps
216+
)
217+
218+
self.result_idx = signal.decimate(
219+
self.raw_score,
220+
base_point_algorithm,
221+
additional_points_algorithms,
222+
additional_points_repetitions = runs
223+
)
224+
225+
categorized = signal.categorize_points(self.raw_score, self.result_idx)
226+
227+
score = copy.deepcopy(self.raw_score)
228+
score_min, score_max = min(score), max(score)
229+
230+
for idx in categorized['upper']:
231+
score[idx] = max(( score_min, min((score_max, score[idx] + offset_upper)) ))
232+
233+
for idx in categorized['lower']:
234+
score[idx] = max(( score_min, min((score_max, score[idx] - offset_lower)) ))
235+
236+
self.result_val = [val for idx,val in enumerate(score) if idx in self.result_idx]
237+
self.curve_result.setData(self.result_idx, self.result_val)
238+
return
239+
except Exception as ex:
240+
self.logger.critical("Invalid Values in Postprocessing Widget", exc_info=ex)
195241

196-
self.result_val = [val for idx,val in enumerate(score) if idx in self.result_idx]
197-
self.curve_result.setData(self.result_idx, self.result_val)
198-
return

0 commit comments

Comments
 (0)