|
9 | 9 | import sys |
10 | 10 | import numpy as np |
11 | 11 | import pyqtgraph as pg |
| 12 | +from scipy.signal import savgol_filter, find_peaks, find_peaks_cwt |
12 | 13 |
|
13 | 14 | from funscript_editor.algorithms.signal import Signal,SignalParameter |
14 | 15 | from funscript_editor.ui.cut_tracking_result import Slider |
@@ -36,9 +37,12 @@ def __init__(self, metric, raw_score, video_info, parent=None): |
36 | 37 |
|
37 | 38 | self.tabs = QtWidgets.QTabWidget() |
38 | 39 | self.tabs_content = {} |
| 40 | + |
39 | 41 | self.add_rdp_tab() |
40 | 42 | self.add_vw_tab() |
41 | 43 | self.add_custom_tab() |
| 44 | + self.add_auto_tab() |
| 45 | + |
42 | 46 | self.verticalLayout.addWidget(self.tabs) |
43 | 47 | self.tabs.currentChanged.connect(self.update_plot) |
44 | 48 |
|
@@ -155,6 +159,14 @@ def add_custom_tab(self): |
155 | 159 | self.tabs_content[tab_name]["main"].setLayout(self.tabs_content[tab_name]["main"].layout) |
156 | 160 | self.tabs.addTab(self.tabs_content[tab_name]["main"], tab_name) |
157 | 161 |
|
| 162 | + def add_auto_tab(self): |
| 163 | + tab_name = "Developer" |
| 164 | + self.tabs_content[tab_name] = {"main": QtWidgets.QWidget(), "widgets": {}} |
| 165 | + self.tabs_content[tab_name]["main"].layout = QtWidgets.QVBoxLayout(self) |
| 166 | + |
| 167 | + self.tabs_content[tab_name]["main"].setLayout(self.tabs_content[tab_name]["main"].layout) |
| 168 | + self.tabs.addTab(self.tabs_content[tab_name]["main"], tab_name) |
| 169 | + |
158 | 170 |
|
159 | 171 | def get_current_tab_name(self) -> str: |
160 | 172 | return self.tabs.tabText(self.tabs.currentIndex()) |
@@ -245,6 +257,32 @@ def update_plot(self): |
245 | 257 | self.result_val = [val for idx,val in enumerate(score) if idx in self.result_idx] |
246 | 258 | self.curve_result.setData(self.result_idx, self.result_val) |
247 | 259 | return |
| 260 | + |
| 261 | + if current_tab_name == "Developer": |
| 262 | + smothed_score = savgol_filter(self.raw_score, 5, 2) |
| 263 | + |
| 264 | + max_idx, _ = find_peaks(smothed_score) |
| 265 | + min_idx, _ = find_peaks([100.0-1.0 * x for x in smothed_score]) |
| 266 | + |
| 267 | + d1 = savgol_filter(np.diff(smothed_score, 1).tolist(), 5, 2) |
| 268 | + d2 = savgol_filter(np.diff(d1, 1).tolist(), 5, 2) |
| 269 | + |
| 270 | + d2_max_idx, _ = find_peaks(d2) |
| 271 | + d2_min_idx, _ = find_peaks([-1.0*x for x in d2]) |
| 272 | + |
| 273 | + print("min_idx", min_idx) |
| 274 | + print("max_idx", max_idx) |
| 275 | + print("d2_min_idx", d2_min_idx) |
| 276 | + print("d2_max_idx", d2_max_idx) |
| 277 | + |
| 278 | + all_idx = list(max_idx) + list(min_idx) + list(d2_max_idx) + list(d2_min_idx) |
| 279 | + |
| 280 | + self.result_idx = list(set(all_idx)) |
| 281 | + self.result_idx.sort() |
| 282 | + |
| 283 | + self.result_val = [val for idx,val in enumerate(smothed_score) if idx in self.result_idx] |
| 284 | + self.curve_result.setData(self.result_idx, self.result_val) |
| 285 | + |
248 | 286 | except Exception as ex: |
249 | 287 | self.logger.critical("Invalid Values in Postprocessing Widget", exc_info=ex) |
250 | 288 |
|
0 commit comments