1818from funscript_editor .utils .config import HYPERPARAMETER , SETTINGS , PROJECTION
1919from datetime import datetime
2020from funscript_editor .data .ffmpegstream import FFmpegStream , VideoInfo
21+ from funscript_editor .algorithms .kalmanfilter import KalmanFilter2D
2122
2223import funscript_editor .algorithms .signalprocessing as sp
2324import numpy as np
@@ -44,6 +45,7 @@ class FunscriptGeneratorParameter:
4445 bottom_threshold : float = float (HYPERPARAMETER ['bottom_threshold' ])
4546 preview_scaling : float = float (SETTINGS ['preview_scaling' ])
4647 projection : str = str (SETTINGS ['projection' ]).lower ()
48+ use_kalman_filter : bool = SETTINGS ['use_kalman_filter' ]
4749
4850
4951class FunscriptGenerator (QtCore .QThread ):
@@ -403,13 +405,13 @@ def delete_last_tracking_predictions(self, num :int) -> None:
403405 Args:
404406 num (int): number of frames to remove from predicted boxes
405407 """
406- if len (self .bboxes ['Woman' ]) <= num - 1 :
408+ if len (self .bboxes ['Woman' ]) <= num :
407409 self .bboxes ['Woman' ] = []
408410 self .bboxes ['Men' ] = []
409411 else :
410- for i in range (len ( self . bboxes [ 'Woman' ]) - 1 , len ( self . bboxes [ 'Woman' ]) - num , - 1 ):
411- del self .bboxes ['Woman' ][i ]
412- if self .params .track_men : del self .bboxes ['Men' ][i ]
412+ for _ in range (num ):
413+ del self .bboxes ['Woman' ][- 1 ]
414+ if self .params .track_men : del self .bboxes ['Men' ][- 1 ]
413415
414416
415417 def preview_scaling (self , preview_image :np .ndarray ) -> np .ndarray :
@@ -438,7 +440,7 @@ def get_vr_projection_config(self, image :np.ndarray) -> dict:
438440 Returns:
439441 dict: projection config
440442 """
441- config = PROJECTION [self .params .projection ]
443+ config = copy . deepcopy ( PROJECTION [self .params .projection ])
442444
443445 self .determine_preview_scaling (config ['parameter' ]['width' ], config ['parameter' ]['height' ])
444446
@@ -639,7 +641,7 @@ def tracking(self) -> str:
639641
640642 if self .was_key_pressed ('q' ) or cv2 .waitKey (1 ) == ord ('q' ):
641643 status = 'Tracking stopped by user'
642- self .delete_last_tracking_predictions (int (self .get_average_tracking_fps ()+ 1 )* 3 )
644+ self .delete_last_tracking_predictions (int (( self .get_average_tracking_fps ()+ 1 )* 2.2 ) )
643645 break
644646
645647 (successWoman , bboxWoman ) = trackerWoman .result ()
@@ -663,6 +665,7 @@ def tracking(self) -> str:
663665
664666 video .stop ()
665667 self .logger .info (status )
668+ self .logger .info ('Calculate score' )
666669 self .calculate_score ()
667670 return status
668671
@@ -762,12 +765,38 @@ def get_score_with_offset(self, idx_dict) -> list:
762765 return score
763766
764767
768+ def apply_kalman_filter (self ) -> None :
769+ """ Apply Kalman Filter to the tracking prediction """
770+ if len (self .bboxes ['Woman' ]) < self .video_info .fps : return
771+
772+ # TODO: we should use the center of the tracking box not x0,y0 of the box
773+
774+ self .logger .info ("Apply kalman filter" )
775+ kalman = KalmanFilter2D (self .video_info .fps )
776+ kalman .init (self .bboxes ['Woman' ][0 ][0 ], self .bboxes ['Woman' ][0 ][1 ])
777+ for idx , item in enumerate (self .bboxes ['Woman' ]):
778+ prediction = kalman .update (item [0 ], item [1 ])
779+ self .bboxes ['Woman' ][idx ] = (prediction [0 ], prediction [1 ], item [2 ], item [3 ])
780+
781+ if self .params .track_men :
782+ kalman = KalmanFilter2D (self .video_info .fps )
783+ kalman .init (self .bboxes ['Men' ][0 ][0 ], self .bboxes ['Men' ][0 ][1 ])
784+ for idx , item in enumerate (self .bboxes ['Men' ]):
785+ prediction = kalman .update (item [0 ], item [1 ])
786+ self .bboxes ['Men' ][idx ] = (prediction [0 ], prediction [1 ], item [2 ], item [3 ])
787+
788+
765789 def run (self ) -> None :
766790 """ The Funscript Generator Thread Function """
767791 # NOTE: score['y'] and score['x'] should have the same number size so it should be enouth to check one score length
768792 with Listener (on_press = self .on_key_press ) as listener :
769793 status = self .tracking ()
794+
795+ if self .params .use_kalman_filter :
796+ self .apply_kalman_filter ()
797+
770798 if len (self .score ['y' ]) >= HYPERPARAMETER ['min_frames' ]:
799+ self .logger .info ("Scale score" )
771800 if self .params .direction != 'x' :
772801 self .scale_score (status , direction = 'y' )
773802 else :
@@ -777,6 +806,7 @@ def run(self) -> None:
777806 self .finished (status + ' -> Tracking time insufficient' , False )
778807 return
779808
809+ self .logger .info ("Determine local max and min" )
780810 if self .params .direction != 'x' :
781811 idx_dict = sp .get_local_max_and_min_idx (self .score ['y' ], self .video_info .fps )
782812 else :
0 commit comments