Skip to content

Commit 7019bea

Browse files
author
arch
committed
add prototype for kalman filter
1 parent e321f9a commit 7019bea

File tree

6 files changed

+119
-10
lines changed

6 files changed

+119
-10
lines changed

docs/app/docs/user-guide/config.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ Config Files:
1919

2020
#### `settings.yaml`
2121

22+
- `use_kalman_filter`: Enable Kalman Filter
2223
- `use_zoom` (bool): Enable or disable an additional step to zoom in the Video before selecting a tracking feature for the Woman or Men.
2324
- `zoom_factor:` (float): Set the desired zoom value which will be used when the zoom function is activated.
2425
- `tracking_direction` (str): Specify the tracking direction. Allowed values are `'x'` and `'y'`.

funscript_editor/algorithms/funscriptgenerator.py

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from funscript_editor.utils.config import HYPERPARAMETER, SETTINGS, PROJECTION
1919
from datetime import datetime
2020
from funscript_editor.data.ffmpegstream import FFmpegStream, VideoInfo
21+
from funscript_editor.algorithms.kalmanfilter import KalmanFilter2D
2122

2223
import funscript_editor.algorithms.signalprocessing as sp
2324
import numpy as np
@@ -44,6 +45,7 @@ class FunscriptGeneratorParameter:
4445
bottom_threshold: float = float(HYPERPARAMETER['bottom_threshold'])
4546
preview_scaling: float = float(SETTINGS['preview_scaling'])
4647
projection: str = str(SETTINGS['projection']).lower()
48+
use_kalman_filter: bool = SETTINGS['use_kalman_filter']
4749

4850

4951
class FunscriptGenerator(QtCore.QThread):
@@ -403,13 +405,13 @@ def delete_last_tracking_predictions(self, num :int) -> None:
403405
Args:
404406
num (int): number of frames to remove from predicted boxes
405407
"""
406-
if len(self.bboxes['Woman']) <= num-1:
408+
if len(self.bboxes['Woman']) <= num:
407409
self.bboxes['Woman'] = []
408410
self.bboxes['Men'] = []
409411
else:
410-
for i in range(len(self.bboxes['Woman'])-1,len(self.bboxes['Woman'])-num,-1):
411-
del self.bboxes['Woman'][i]
412-
if self.params.track_men: del self.bboxes['Men'][i]
412+
for _ in range(num):
413+
del self.bboxes['Woman'][-1]
414+
if self.params.track_men: del self.bboxes['Men'][-1]
413415

414416

415417
def preview_scaling(self, preview_image :np.ndarray) -> np.ndarray:
@@ -438,7 +440,7 @@ def get_vr_projection_config(self, image :np.ndarray) -> dict:
438440
Returns:
439441
dict: projection config
440442
"""
441-
config = PROJECTION[self.params.projection]
443+
config = copy.deepcopy(PROJECTION[self.params.projection])
442444

443445
self.determine_preview_scaling(config['parameter']['width'], config['parameter']['height'])
444446

@@ -639,7 +641,7 @@ def tracking(self) -> str:
639641

640642
if self.was_key_pressed('q') or cv2.waitKey(1) == ord('q'):
641643
status = 'Tracking stopped by user'
642-
self.delete_last_tracking_predictions(int(self.get_average_tracking_fps()+1)*3)
644+
self.delete_last_tracking_predictions(int((self.get_average_tracking_fps()+1)*2.2))
643645
break
644646

645647
(successWoman, bboxWoman) = trackerWoman.result()
@@ -663,6 +665,7 @@ def tracking(self) -> str:
663665

664666
video.stop()
665667
self.logger.info(status)
668+
self.logger.info('Calculate score')
666669
self.calculate_score()
667670
return status
668671

@@ -762,12 +765,38 @@ def get_score_with_offset(self, idx_dict) -> list:
762765
return score
763766

764767

768+
def apply_kalman_filter(self) -> None:
769+
""" Apply Kalman Filter to the tracking prediction """
770+
if len(self.bboxes['Woman']) < self.video_info.fps: return
771+
772+
# TODO: we should use the center of the tracking box not x0,y0 of the box
773+
774+
self.logger.info("Apply kalman filter")
775+
kalman = KalmanFilter2D(self.video_info.fps)
776+
kalman.init(self.bboxes['Woman'][0][0], self.bboxes['Woman'][0][1])
777+
for idx, item in enumerate(self.bboxes['Woman']):
778+
prediction = kalman.update(item[0], item[1])
779+
self.bboxes['Woman'][idx] = (prediction[0], prediction[1], item[2], item[3])
780+
781+
if self.params.track_men:
782+
kalman = KalmanFilter2D(self.video_info.fps)
783+
kalman.init(self.bboxes['Men'][0][0], self.bboxes['Men'][0][1])
784+
for idx, item in enumerate(self.bboxes['Men']):
785+
prediction = kalman.update(item[0], item[1])
786+
self.bboxes['Men'][idx] = (prediction[0], prediction[1], item[2], item[3])
787+
788+
765789
def run(self) -> None:
766790
""" The Funscript Generator Thread Function """
767791
# NOTE: score['y'] and score['x'] should have the same number size so it should be enouth to check one score length
768792
with Listener(on_press=self.on_key_press) as listener:
769793
status = self.tracking()
794+
795+
if self.params.use_kalman_filter:
796+
self.apply_kalman_filter()
797+
770798
if len(self.score['y']) >= HYPERPARAMETER['min_frames']:
799+
self.logger.info("Scale score")
771800
if self.params.direction != 'x':
772801
self.scale_score(status, direction='y')
773802
else:
@@ -777,6 +806,7 @@ def run(self) -> None:
777806
self.finished(status + ' -> Tracking time insufficient', False)
778807
return
779808

809+
self.logger.info("Determine local max and min")
780810
if self.params.direction != 'x':
781811
idx_dict = sp.get_local_max_and_min_idx(self.score['y'], self.video_info.fps)
782812
else:
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
""" Kalman Filter """
2+
3+
import cv2
4+
5+
import numpy as np
6+
7+
8+
class KalmanFilter2D:
9+
""" Kalman 2D Filter
10+
11+
Args:
12+
fps (float): frames per second
13+
"""
14+
15+
def __init__(self, fps: float):
16+
self.kalman = cv2.KalmanFilter(4, 2)
17+
self.prediction = np.zeros((2, 1), np.float32)
18+
self.prediction_counter = 0
19+
self.fps = fps
20+
21+
22+
def init(self, x: float, y: float) -> None:
23+
""" Initialize the 2D Kalman Filter
24+
25+
Args:
26+
x (float): x measurement
27+
y (float): y measurement
28+
"""
29+
# dt = 1.0 / self.fps
30+
dt = 1.0
31+
32+
self.kalman.measurementMatrix = np.array(
33+
[
34+
[1, 0, 0, 0],
35+
[0, 1, 0, 0]
36+
], np.float32)
37+
38+
self.kalman.transitionMatrix = np.array(
39+
[
40+
[1, 0, dt,0 ],
41+
[0, 1, 0, dt],
42+
[0, 0, 1, 0 ],
43+
[0, 0, 0, 1 ]
44+
], np.float32)
45+
46+
self.kalman.processNoiseCov = np.array(
47+
[
48+
[1, 0, 0, 0],
49+
[0, 1, 0, 0],
50+
[0, 0, 1, 0],
51+
[0, 0, 0, 1]
52+
], np.float32) * 0.03
53+
54+
self.kalman.correct(np.array([np.float32(x-1), np.float32(y-1)], np.float32))
55+
self.prediction = self.kalman.predict()
56+
self.prediction_counter += 1
57+
58+
59+
def update(self, x: float, y: float) -> np.ndarray:
60+
""" Update the 2D Kalman Filter
61+
62+
Args:
63+
x (float): x measurement
64+
y (float): y measurement
65+
66+
Returns:
67+
list: prediction [x', y', vx', vy']
68+
"""
69+
self.kalman.correct(np.array([np.float32(x), np.float32(y)], np.float32))
70+
self.prediction = self.kalman.predict()
71+
self.prediction_counter += 1
72+
# TODO howo to init/fit the OpenCV Kalman Filter?
73+
if self.prediction_counter < self.fps:
74+
return [round(x), round(y), 0.0, 0.0]
75+
else:
76+
return [round(item[0]) if idx < 2 else float(item[0]) for idx, item in enumerate(self.prediction)]

funscript_editor/config/hyperparameter.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ avg_sec_for_local_min_max_extraction: 1.9
1111

1212
# Specify the minimum required frames for the tracking. Wee need this parameter to
1313
# ensure there is at leas two strokes in the tracking result.
14-
min_frames: 120
14+
min_frames: 100
1515

1616
# Shift predicted top points by given frame number. Positive values delay the position
1717
# and negative values result in an earlier position.

funscript_editor/config/settings.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
# Python Funscript Editor settings
22

3+
# Enable Kalman Filter
4+
use_kalman_filter: False
5+
36
# Enable or disable an additional step to zoom in the Video before selecting a tracking
47
# feature for the Woman or Men.
58
use_zoom: False

funscript_editor/data/ffmpegstream.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -155,9 +155,9 @@ def get_projection(
155155
(config['parameter']['height'], config['parameter']['width'], 3)
156156
)
157157

158+
pipe.terminate()
158159
pipe.stdin.close()
159160
pipe.stdout.close()
160-
pipe.terminate()
161161

162162
return projection
163163

@@ -218,7 +218,6 @@ def millisec_to_timestamp(millis :int)->str:
218218
def stop(self) -> None:
219219
""" Stop FFmpeg video stream """
220220
self.stopped = True
221-
self.thread.join()
222221

223222

224223
def read(self) -> np.ndarray:
@@ -305,5 +304,5 @@ def run(self) -> None:
305304

306305
self.stopped = True
307306
self.logger.info('Close FFmpeg Stream')
308-
pipe.stdout.close()
309307
pipe.terminate()
308+
pipe.stdout.close()

0 commit comments

Comments
 (0)