Skip to content

Commit 0448d00

Browse files
author
arch
committed
test code for optical flow
1 parent a686609 commit 0448d00

File tree

3 files changed

+296
-18
lines changed

3 files changed

+296
-18
lines changed

funscript_editor/algorithms/funscriptgenerator.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,6 @@ def __init__(self,
9898
self.params = params
9999
self.funscripts = funscripts
100100
self.video_info = FFmpegStream.get_video_info(self.params.video_path)
101-
self.tracking_fps = []
102101
self.score = {
103102
'x': [],
104103
'y': [],
Lines changed: 263 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,263 @@
1+
from os import path
2+
import cv2
3+
import time
4+
import json
5+
import copy
6+
import logging
7+
8+
import numpy as np
9+
10+
from funscript_editor.data.ffmpegstream import FFmpegStream
11+
from dataclasses import dataclass
12+
from sklearn.decomposition import PCA
13+
from PyQt5 import QtCore
14+
from funscript_editor.algorithms.signal import Signal
15+
from funscript_editor.ui.opencvui import OpenCV_GUI, OpenCV_GUI_Parameters
16+
17+
@dataclass
18+
class OpticalFlowFunscriptGeneratorParameter:
19+
""" Funscript Generator Parameter Dataclass with default values """
20+
video_path: str
21+
projection: str
22+
start_frame: int
23+
end_frame: int = -1 # default is video end (-1)
24+
skip_frames: int = 0
25+
min_trajectory_len: int = 60
26+
feature_detect_interval: int = 10
27+
movement_filter: float = 10.0
28+
29+
30+
class OpticalFlowFunscriptGeneratorThread(QtCore.QThread):
31+
""" Funscript Generator Thread
32+
33+
Args:
34+
params (OpticalFlowFunscriptGeneratorParameter): required parameter for the funscript generator
35+
funscript (dict): the references to the Funscript where we store the predicted actions
36+
"""
37+
38+
def __init__(self,
39+
params: OpticalFlowFunscriptGeneratorParameter,
40+
funscripts: dict):
41+
QtCore.QThread.__init__(self)
42+
self.logger = logging.getLogger(__name__)
43+
self.params = params
44+
self.funscripts = funscripts
45+
self.video_info = FFmpegStream.get_video_info(self.params.video_path)
46+
47+
self.ui = OpenCV_GUI(OpenCV_GUI_Parameters(
48+
video_info = self.video_info,
49+
skip_frames = self.params.skip_frames,
50+
end_frame_number = self.params.end_frame
51+
))
52+
53+
54+
#: completed event with reference to the funscript with the predicted actions, status message and success flag
55+
funscriptCompleted = QtCore.pyqtSignal(dict, str, bool)
56+
57+
58+
class OpticalFlowPyrLK:
59+
60+
def __init__(self, min_trajectory_len, feature_detect_interval):
61+
self.min_trajectory_len = min_trajectory_len
62+
self.feature_detect_interval = feature_detect_interval
63+
self.trajectories = []
64+
self.frame_idx = 0
65+
self.prev_frame_gray = None
66+
self.result = []
67+
68+
self.lk_params = dict(
69+
winSize = (15, 15),
70+
maxLevel = 2,
71+
criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, 10, 0.03)
72+
)
73+
74+
self.feature_params = dict(
75+
maxCorners = 20,
76+
qualityLevel = 0.3,
77+
minDistance = 10,
78+
blockSize = 7
79+
)
80+
81+
82+
def update(self, frame_roi):
83+
frame_gray = cv2.cvtColor(frame_roi, cv2.COLOR_BGR2GRAY)
84+
if len(self.trajectories) > 0:
85+
p0 = np.float32([trajectory[-1] for trajectory in self.trajectories]).reshape(-1, 1, 2)
86+
p1, _, _ = cv2.calcOpticalFlowPyrLK(self.prev_frame_gray, frame_gray, p0, None, **self.lk_params)
87+
p0r, _, _ = cv2.calcOpticalFlowPyrLK(frame_gray, self.prev_frame_gray, p1, None, **self.lk_params)
88+
d = abs(p0-p0r).reshape(-1, 2).max(-1)
89+
good = d < 1
90+
91+
new_trajectories = []
92+
for trajectory, (x, y), good_flag in zip(self.trajectories, p1.reshape(-1, 2), good):
93+
if not good_flag:
94+
if len (trajectory) > self.min_trajectory_len:
95+
self.result.append({'end': self.frame_idx, 'trajectory': trajectory})
96+
continue
97+
trajectory.append((x, y))
98+
new_trajectories.append(trajectory)
99+
100+
self.trajectories = new_trajectories
101+
102+
103+
if self.frame_idx % self.feature_detect_interval == 0:
104+
mask = np.zeros_like(frame_gray)
105+
mask[:] = 255
106+
p = cv2.goodFeaturesToTrack(frame_gray, mask = mask, **self.feature_params)
107+
if p is not None:
108+
for x, y in np.float32(p).reshape(-1, 2):
109+
if any(abs(t[-1][0] - x) < 3 and abs(t[-1][1] - y) < 3 for t in self.trajectories):
110+
continue
111+
112+
self.trajectories.append([(x, y)])
113+
114+
115+
self.frame_idx += 1
116+
self.prev_frame_gray = frame_gray
117+
118+
return [t[-1] for t in self.trajectories]
119+
120+
121+
def get_result(self):
122+
result = copy.deepcopy(self.result)
123+
for trajectory in self.trajectories:
124+
if len (trajectory) > self.min_trajectory_len:
125+
result.append({'end': self.frame_idx, 'trajectory': trajectory})
126+
127+
return { 'meta': { 'last_idx': self.frame_idx }, 'data': result }
128+
129+
130+
def extract_movement(self, optical_flow_result, metric_idx = 1):
131+
result = []
132+
for r in optical_flow_result['data']:
133+
zero_before = r['end'] - len(r['trajectory'])
134+
zero_after = optical_flow_result['meta']['last_idx'] - r['end']
135+
trajectory_min = min([item[metric_idx] for item in r['trajectory']])
136+
y = [0 for _ in range(zero_before)] + [(r['trajectory'][i][metric_idx] - trajectory_min)**2 for i in range(len(r['trajectory']))] + [0 for _ in range(zero_after)]
137+
if max(y) - min(y) > self.params.movement_filter:
138+
result.append(y)
139+
140+
return result
141+
142+
143+
def get_absolute_framenumber(self, frame_number: int) -> int:
144+
""" Get the absoulte frame number
145+
146+
Args:
147+
frame_number (int): relative frame number
148+
149+
Returns:
150+
int: absolute frame position
151+
"""
152+
return self.params.start_frame + frame_number
153+
154+
155+
def tracking(self) -> str:
156+
""" Tracking function to track the features in the video
157+
158+
Returns:
159+
str: a process status message e.g. 'end of video reached'
160+
"""
161+
first_frame = FFmpegStream.get_frame(self.params.video_path, self.params.start_frame)
162+
163+
projection_config = self.ui.get_video_projection_config(first_frame, self.params.projection)
164+
165+
video = FFmpegStream(
166+
video_path = self.params.video_path,
167+
config = projection_config,
168+
skip_frames = self.params.skip_frames,
169+
start_frame = self.params.start_frame
170+
)
171+
172+
first_frame = video.read()
173+
if first_frame is None:
174+
return "FFmpeg could not extract the first video frame"
175+
176+
roi = self.ui.bbox_selector(
177+
first_frame,
178+
"Select ROI",
179+
)
180+
181+
optical_flow = OpticalFlowFunscriptGeneratorThread.OpticalFlowPyrLK(
182+
min_trajectory_len = self.params.min_trajectory_len,
183+
feature_detect_interval = self.params.feature_detect_interval
184+
)
185+
186+
status = "End of video reached"
187+
frame_num = 1 # first frame is init frame
188+
while video.isOpen():
189+
frame = video.read()
190+
frame_num += (self.params.skip_frames+1)
191+
192+
if frame is None:
193+
status = 'Reach a corrupt video frame' if video.isOpen() else 'End of video reached'
194+
break
195+
196+
if self.params.end_frame > 0 and frame_num + self.params.start_frame >= self.params.end_frame:
197+
status = "Tracking stop at existing action point"
198+
break
199+
200+
frame_roi = frame[roi[1]:roi[1]+roi[3], roi[0]:roi[0]+roi[2], :]
201+
current_features = optical_flow.update(frame_roi)
202+
203+
for f in current_features:
204+
cv2.circle(frame, (int(roi[0]+f[0]), int(roi[1]+f[1])), 3, (0, 0, 255), -1)
205+
206+
key = self.ui.preview(
207+
frame,
208+
frame_num + self.params.start_frame,
209+
texte = ["Press 'q' to stop tracking"],
210+
boxes = [roi],
211+
)
212+
213+
if self.ui.was_key_pressed('q') or key == ord('q'):
214+
status = 'Tracking stopped by user'
215+
break
216+
217+
result = optical_flow.get_result()
218+
result = self.extract_movement(result)
219+
220+
pca = PCA(n_components=2)
221+
principalComponents = pca.fit_transform(np.transpose(np.array(result)))
222+
result = np.transpose(np.array(principalComponents))
223+
224+
result = np.array(result[0]) - np.array(result[1])
225+
226+
signal = Signal(self.video_info.fps)
227+
points = signal.get_local_min_max_points(result)
228+
229+
val = 0
230+
for k in self.funscripts:
231+
for p in points:
232+
self.funscripts[k].add_action(
233+
val,
234+
FFmpegStream.frame_to_millisec(self.get_absolute_framenumber(p * (1+self.params.skip_frames)), self.video_info.fps)
235+
)
236+
val = 0 if val != 0 else 100
237+
238+
return status
239+
240+
241+
def finished(self, status: str, success :bool) -> None:
242+
""" Process necessary steps to complete the predicted funscript
243+
244+
Args:
245+
status (str): a process status/error message
246+
success (bool): True if funscript was generated else False
247+
"""
248+
for metric in self.funscripts.keys():
249+
# we use this flag internaly to determine if we have to invert the score
250+
# ensure not to publish the invertion with our generated funscript
251+
# in this case we will invert our result again, what we dont want
252+
self.funscripts[metric].data["inverted"] = False
253+
self.ui.close()
254+
self.funscriptCompleted.emit(self.funscripts, status, success)
255+
256+
257+
def run(self) -> None:
258+
try:
259+
status = self.tracking()
260+
self.finished(status, True)
261+
except Exception as ex:
262+
self.logger.critical("The program crashed due to a fatal error", exc_info=ex)
263+
self.finished("The program crashed due to a fatal error", False)

funscript_editor/ui/funscript_generator_window.py

Lines changed: 33 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515

1616
from PyQt5 import QtCore, QtGui, QtWidgets
1717

18+
USE_OPTICALFLOW = True
19+
if USE_OPTICALFLOW:
20+
from funscript_editor.algorithms.opticalflow import OpticalFlowFunscriptGeneratorThread, OpticalFlowFunscriptGeneratorParameter
1821

1922
class FunscriptGeneratorWindow(QtWidgets.QMainWindow):
2023
""" Class to Generate a funscript with minimal UI
@@ -127,22 +130,35 @@ def run(self) -> None:
127130
self.__logger.info('settings: %s', str(self.settings))
128131
self.settings['videoType'] = list(filter(lambda x: PROJECTION[x]['name'] == self.settings['videoType'], PROJECTION.keys()))[0]
129132
self.funscripts = {k.replace('inverted', '').strip(): Funscript(self.fps, inverted = "inverted" in k) for k in self.settings['trackingMetrics'].split('+')}
130-
self.funscript_generator = FunscriptGeneratorThread(
131-
FunscriptGeneratorParameter(
132-
video_path = self.video_file,
133-
track_men = 'two' in self.settings['trackingMethod'],
134-
supervised_tracking = 'Supervised' in self.settings['trackingMethod'],
135-
supervised_tracking_is_exit_condition = "stopping" in self.settings['trackingMethod'],
136-
projection = self.settings['videoType'],
137-
start_frame = self.start_frame,
138-
end_frame = self.end_frame,
139-
number_of_trackers = int(self.settings['numberOfTracker']),
140-
points = self.settings['points'].lower().replace(' ', '_'),
141-
additional_points = self.settings['additionalPoints'].lower().replace(' ', '_'),
142-
skip_frames = int(self.settings['processingSpeed']),
143-
top_points_offset = self.settings['topPointOffset'],
144-
bottom_points_offset = self.settings['bottomPointOffset']
145-
),
146-
self.funscripts)
133+
134+
if USE_OPTICALFLOW:
135+
self.funscript_generator = OpticalFlowFunscriptGeneratorThread(
136+
OpticalFlowFunscriptGeneratorParameter(
137+
video_path = self.video_file,
138+
projection = self.settings['videoType'],
139+
start_frame = self.start_frame,
140+
end_frame = self.end_frame,
141+
skip_frames = int(self.settings['processingSpeed'])
142+
),
143+
self.funscripts)
144+
else:
145+
self.funscript_generator = FunscriptGeneratorThread(
146+
FunscriptGeneratorParameter(
147+
video_path = self.video_file,
148+
track_men = 'two' in self.settings['trackingMethod'],
149+
supervised_tracking = 'Supervised' in self.settings['trackingMethod'],
150+
supervised_tracking_is_exit_condition = "stopping" in self.settings['trackingMethod'],
151+
projection = self.settings['videoType'],
152+
start_frame = self.start_frame,
153+
end_frame = self.end_frame,
154+
number_of_trackers = int(self.settings['numberOfTracker']),
155+
points = self.settings['points'].lower().replace(' ', '_'),
156+
additional_points = self.settings['additionalPoints'].lower().replace(' ', '_'),
157+
skip_frames = int(self.settings['processingSpeed']),
158+
top_points_offset = self.settings['topPointOffset'],
159+
bottom_points_offset = self.settings['bottomPointOffset']
160+
),
161+
self.funscripts)
162+
147163
self.funscript_generator.funscriptCompleted.connect(self.__funscript_generated)
148164
self.funscript_generator.start()

0 commit comments

Comments
 (0)