Skip to content

Commit 62247eb

Browse files
author
arch
committed
improve prediction for fast action
1 parent f828508 commit 62247eb

File tree

2 files changed

+12
-2
lines changed

2 files changed

+12
-2
lines changed

funscript_editor/algorithms/signal.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
@dataclass
1414
class SignalParameter:
15+
local_min_max_filter_len: int = int(HYPERPARAMETER['signal']['local_max_min_filter_len'])
1516
avg_sec_for_local_min_max_extraction: float = float(HYPERPARAMETER['signal']['avg_sec_for_local_min_max_extraction'])
1617
additional_points_merge_time_threshold_in_ms: float = float(HYPERPARAMETER['signal']['additional_points_merge_time_threshold_in_ms'])
1718
additional_points_merge_distance_threshold: float = float(HYPERPARAMETER['signal']['additional_points_merge_distance_threshold'])
@@ -296,7 +297,7 @@ def get_edge_points(self, signal: list, base_points: list, threshold: float = 25
296297

297298

298299

299-
def get_local_min_max_points(self, signal: list, filter_len: int = 3) -> list:
300+
def get_local_min_max_points(self, signal: list, filter_len: int = 1) -> list:
300301
""" Get the local max and min positions in given signal
301302
302303
Args:
@@ -306,6 +307,10 @@ def get_local_min_max_points(self, signal: list, filter_len: int = 3) -> list:
306307
Returns:
307308
list: with local max and min indexes
308309
"""
310+
filter_len = max((1, filter_len))
311+
if filter_len % 2 == 0:
312+
filter_len += 1
313+
309314
avg = Signal.moving_average(signal, w=round(self.fps * self.params.avg_sec_for_local_min_max_extraction))
310315
smothed_signal = Signal.moving_average(signal, w=filter_len)
311316
points, tmp_min_start_idx, tmp_min_end_idx, tmp_max_start_idx, tmp_max_end_idx = [], -1, -1, -1, -1
@@ -532,7 +537,7 @@ def decimate(self,
532537
if base_point_algorithm == self.BasePointAlgorithm.direction_changes:
533538
decimated_indexes = self.get_direction_changes(signal, filter_len = self.params.direction_change_filter_len)
534539
elif base_point_algorithm == self.BasePointAlgorithm.local_min_max:
535-
decimated_indexes = self.get_local_min_max_points(signal)
540+
decimated_indexes = self.get_local_min_max_points(signal, filter_len = self.params.local_min_max_filter_len)
536541
else:
537542
raise NotImplementedError("Selected Base Point Algorithm is not implemented")
538543

funscript_editor/config/hyperparameter.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,11 @@ user_reaction_time_in_milliseconds: 1000
1414

1515
# Signal Processing Hyperparameter
1616
signal:
17+
18+
# Specify the window size for the calculation of smothed tracking signal for the local min and max search.
19+
# For slow actions in you video i recommend to set this to '3' to reduce the false positive predictions.
20+
local_max_min_filter_len: 1
21+
1722
# Specify the window size for the calculation of the reference value for the local min and max search.
1823
avg_sec_for_local_min_max_extraction: 2.0
1924

0 commit comments

Comments
 (0)