Commit 869fe09

Author: arch
Commit message: improve optical flow
1 parent 8a534ec · commit 869fe09

File tree

2 files changed: +154, -33 lines


funscript_editor/algorithms/opticalflow.py

Lines changed: 54 additions & 33 deletions
@@ -9,10 +9,10 @@

 from funscript_editor.data.ffmpegstream import FFmpegStream
 from dataclasses import dataclass
-from sklearn.decomposition import PCA
 from PyQt5 import QtCore
 from funscript_editor.algorithms.signal import Signal
 from funscript_editor.ui.opencvui import OpenCV_GUI, OpenCV_GUI_Parameters
+from funscript_editor.algorithms.ppca import PPCA

 import matplotlib.pyplot as plt

@@ -24,7 +24,7 @@ class OpticalFlowFunscriptGeneratorParameter:
     start_frame: int
     end_frame: int = -1 # default is video end (-1)
     skip_frames: int = 0
-    min_trajectory_len: int = 50
+    min_trajectory_len: int = 40
     feature_detect_interval: int = 10
     movement_filter: float = 10.0

@@ -59,9 +59,10 @@ def __init__(self,

     class OpticalFlowPyrLK:

-        def __init__(self, min_trajectory_len, feature_detect_interval):
+        def __init__(self, min_trajectory_len, feature_detect_interval, feature_area):
             self.min_trajectory_len = min_trajectory_len
             self.feature_detect_interval = feature_detect_interval
+            self.feature_area = feature_area
             self.trajectories = []
             self.frame_idx = 0
             self.prev_frame_gray = None
@@ -94,6 +95,7 @@ def update(self, frame_roi):
                 for trajectory, (x, y), good_flag in zip(self.trajectories, p1.reshape(-1, 2), good):
                     if not good_flag:
                         if len(trajectory) > self.min_trajectory_len:
+                            # print('add trajectory from', self.frame_idx - len(trajectory), 'to', self.frame_idx)
                             self.result.append({'end': self.frame_idx, 'trajectory': trajectory})
                         continue
                     trajectory.append((x, y))
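
Context for the hunk above: p1 and good are produced earlier in update(), in code this commit does not touch. In OpenCV's lk_track.py sample, which this tracker closely follows, they come from a forward-backward consistency check. A minimal sketch of that idiom (the lk_params values here are assumptions, not the repository's actual settings):

    import cv2
    import numpy as np

    # Hypothetical LK parameters; the real ones live elsewhere in opticalflow.py.
    lk_params = dict(winSize=(15, 15), maxLevel=2,
                     criteria=(cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, 10, 0.03))

    def track_points(prev_gray, frame_gray, trajectories):
        # Track the newest point of every trajectory forward, then backward,
        # and keep only points whose round trip lands within one pixel.
        p0 = np.float32([t[-1] for t in trajectories]).reshape(-1, 1, 2)
        p1, _, _ = cv2.calcOpticalFlowPyrLK(prev_gray, frame_gray, p0, None, **lk_params)
        p0r, _, _ = cv2.calcOpticalFlowPyrLK(frame_gray, prev_gray, p1, None, **lk_params)
        good = abs(p0 - p0r).reshape(-1, 2).max(-1) < 1
        return p1, good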
@@ -103,11 +105,14 @@ def update(self, frame_roi):


            if len(self.trajectories) == 0 or self.frame_idx % self.feature_detect_interval == 0:
-                mask = np.zeros_like(frame_gray)
+                search_img = frame_gray[self.feature_area[1]:self.feature_area[1]+self.feature_area[3], self.feature_area[0]:self.feature_area[0]+self.feature_area[2]]
+                mask = np.zeros_like(search_img)
                 mask[:] = 255
-                p = cv2.goodFeaturesToTrack(frame_gray, mask = mask, **self.feature_params)
+                p = cv2.goodFeaturesToTrack(search_img, mask = mask, **self.feature_params)
                 if p is not None:
                     for x, y in np.float32(p).reshape(-1, 2):
+                        x += self.feature_area[0]
+                        y += self.feature_area[1]
                         if any(abs(t[-1][0] - x) < 3 and abs(t[-1][1] - y) < 3 for t in self.trajectories):
                             continue

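This hunk is the core of the change: corner features are now detected only inside the user-selected feature area and then shifted back into the coordinate frame of the search ROI, while tracking itself still runs on the full ROI. The same pattern in isolation (a hypothetical helper; area is given as x, y, w, h):

    import cv2
    import numpy as np

    # Hypothetical values; the real ones live in OpticalFlowPyrLK.feature_params.
    feature_params = dict(maxCorners=20, qualityLevel=0.3, minDistance=7, blockSize=7)

    def detect_features_in_area(frame_gray, area):
        # Crop the grayscale ROI to the feature area ...
        x, y, w, h = area
        search_img = frame_gray[y:y + h, x:x + w]
        p = cv2.goodFeaturesToTrack(search_img, mask=None, **feature_params)
        if p is None:
            return []
        # ... and translate the detected corners back to ROI coordinates.
        return [(px + x, py + y) for px, py in np.float32(p).reshape(-1, 2)]
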
@@ -135,8 +140,9 @@ def extract_movement(self, optical_flow_result, metric_idx = 1, filter_static_po
             zero_before = r['end'] - len(r['trajectory'])
             zero_after = optical_flow_result['meta']['last_idx'] - r['end']
             trajectory_min = min([item[metric_idx] for item in r['trajectory']])
-            y = [0 for _ in range(zero_before)] + [(r['trajectory'][i][metric_idx] - trajectory_min)**2 for i in range(len(r['trajectory']))] + [0 for _ in range(zero_after)]
-            if not filter_static_points or (max(y) - min(y)) > self.params.movement_filter:
+            action = [(r['trajectory'][i][metric_idx] - trajectory_min) for i in range(len(r['trajectory']))]
+            y = [None for _ in range(zero_before)] + action + [None for _ in range(zero_after)]
+            if not filter_static_points or (max(action) - min(action)) > self.params.movement_filter:
                 result.append(y)

         return result
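
The switch from zero padding to None padding is what makes the PPCA step below work: np.array(..., dtype=float) turns each None into NaN, which ppca.py treats as a missing observation instead of a fake measurement of 0 (the old squaring of the displacement is dropped as well). For example:

    import numpy as np

    # A trajectory that only existed during frames 2 and 3 of a five-frame clip:
    padded = [None, None, 3.0, 5.0, None]
    np.array(padded, dtype=float)            # -> array([nan, nan,  3.,  5., nan])
    np.isnan(np.array(padded, dtype=float))  # -> array([ True,  True, False, False,  True])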
@@ -145,13 +151,16 @@ def extract_movement(self, optical_flow_result, metric_idx = 1, filter_static_po
     def get_absolute_framenumber(self, frame_number: int) -> int:
         """ Get the absolute frame number

+        Note:
+            We have an offset of 1 because we use the first frame for init
+
         Args:
             frame_number (int): relative frame number

         Returns:
             int: absolute frame position
         """
-        return self.params.start_frame + frame_number
+        return self.params.start_frame + frame_number + 1


     def tracking(self) -> str:
@@ -175,14 +184,43 @@ def tracking(self) -> str:
         if first_frame is None:
             return "FFmpeg could not extract the first video frame"

-        roi = self.ui.bbox_selector(
-            first_frame,
+        preview_frame = copy.copy(first_frame)
+        search_roi = self.ui.bbox_selector(
+            preview_frame,
             "Select observe area of a single person",
         )

+        preview_frame = self.ui.draw_box_to_image(
+            preview_frame,
+            search_roi,
+            color=(0,255,0)
+        )
+
+        while True:
+            feature_roi = self.ui.bbox_selector(
+                preview_frame,
+                "Select feature area inside the observe area",
+            )
+
+            if feature_roi[0] > search_roi[0] \
+                    and feature_roi[1] > search_roi[1] \
+                    and feature_roi[0] + feature_roi[2] < search_roi[0] + search_roi[2] \
+                    and feature_roi[1] + feature_roi[3] < search_roi[1] + search_roi[3]:
+                break
+
+            self.logger.warning("Invalid feature area, the selection must lie inside the observe area")
+
+        feature_roi = [
+            feature_roi[0] - search_roi[0],
+            feature_roi[1] - search_roi[1],
+            feature_roi[2],
+            feature_roi[3]
+        ]
+
         optical_flow = OpticalFlowFunscriptGeneratorThread.OpticalFlowPyrLK(
             min_trajectory_len = self.params.min_trajectory_len,
-            feature_detect_interval = self.params.feature_detect_interval
+            feature_detect_interval = self.params.feature_detect_interval,
+            feature_area = feature_roi
         )

         status = "End of video reached"
@@ -199,48 +237,31 @@ def tracking(self) -> str:
                 status = "Tracking stopped at existing action point"
                 break

-            frame_roi = frame[roi[1]:roi[1]+roi[3], roi[0]:roi[0]+roi[2], :]
+            frame_roi = frame[search_roi[1]:search_roi[1]+search_roi[3], search_roi[0]:search_roi[0]+search_roi[2], :]
             current_features = optical_flow.update(frame_roi)

             for f in current_features:
-                cv2.circle(frame, (int(roi[0]+f[0]), int(roi[1]+f[1])), 3, (0, 0, 255), -1)
+                cv2.circle(frame, (int(search_roi[0]+f[0]), int(search_roi[1]+f[1])), 3, (0, 0, 255), -1)

             key = self.ui.preview(
                 frame,
                 frame_num + self.params.start_frame,
                 texte = ["Press 'q' to stop tracking"],
-                boxes = [roi],
+                boxes = [search_roi],
             )

             if self.ui.was_key_pressed('q') or key == ord('q'):
                 status = 'Tracking stopped by user'
                 break

         result = optical_flow.get_result()
-
-        # for filter_static_points in [True, False]:
-        #     test = self.extract_movement(result, filter_static_points=filter_static_points)
-        #     for i in [1, 2, 3, 4]:
-        #         pca = PCA(n_components=i)
-        #         principalComponents = pca.fit_transform(np.transpose(np.array(test)))
-        #         test_result = np.array(principalComponents)
-        #         plt.plot(test_result)
-        #         plt.savefig('debug_{}_{}.png'.format(filter_static_points, i), dpi=400)
-        #         plt.close()
-
         result = self.extract_movement(result)

-        pca = PCA(n_components=1)
-        principalComponents = pca.fit_transform(np.transpose(np.array(result)))
-        result = [x[0] for x in principalComponents]
-
-        # option for pca 2 with two moving persons:
-        # result = np.transpose(np.array(principalComponents))
-        # result = np.array(result[0]) - np.array(result[1])
+        _, _, _, principalComponents, _ = PPCA(np.transpose(np.array(result, dtype=float)), d=1)
+        result = [x[0] for x in principalComponents.tolist()]

         signal = Signal(self.video_info.fps)
         points = signal.get_local_min_max_points(result)
-        # points = signal.get_direction_changes(result, filter_len=4)
         categorized_points = signal.categorize_points(result, points)

         for k in self.funscripts:
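
For orientation: extract_movement returns one padded list per trajectory, so np.array(result, dtype=float) has shape (trajectories, frames) and the transpose handed to PPCA is (frames, trajectories), i.e. N samples in D dimensions in the notation of ppca.py below. A toy-sized sketch of this reduction step (numbers invented for illustration):

    import numpy as np
    from funscript_editor.algorithms.ppca import PPCA

    # Three trajectories observed over six frames; None marks frames where a
    # trajectory did not exist yet or had already ended.
    result = [
        [10.0, 12.0, 15.0, 12.0, 10.0, None],
        [None,  6.0,  8.0,  6.0,  5.0,  5.0],
        [ 9.0, 11.0, 14.0, 11.0, None, None],
    ]

    # Rows become frames (samples), columns become trajectories (features).
    Y = np.transpose(np.array(result, dtype=float))
    _, _, _, X, _ = PPCA(Y, d=1)      # X: expected states, one row per frame
    movement_1d = [x[0] for x in X.tolist()]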
funscript_editor/algorithms/ppca.py

Lines changed: 100 additions & 0 deletions
@@ -0,0 +1,100 @@
+# -*- coding: utf-8 -*-
+"""
+Python implementation of PPCA-EM for data with missing values
+Adapted from the MATLAB implementation by J.J. VerBeek
+"""
+
+import numpy as np
+from numpy import shape, isnan, nanmean, average, log, cov
+from numpy.matlib import repmat
+from numpy.random import normal
+from numpy.linalg import inv, det, eig
+from numpy import identity as eye
+from numpy import trace as tr
+from scipy.linalg import orth
+
+
+def PPCA(Y, d):
+    """
+    Implements probabilistic PCA for data with missing values,
+    using a factorizing distribution over hidden states and hidden observations.
+
+    Args:
+        Y (np.ndarray): input numpy ndarray of data vectors (N by D)
+        d (int): dimension of latent space
+
+    Returns:
+        C (D by d): C*C' + I*ss is the covariance model, C has scaled principal directions as columns
+        ss (float): isotropic variance outside subspace
+        M (D by 1): data mean
+        X (N by d): expected states
+        Ye (N by D): expected complete observations (differs from Y if data is missing)
+
+    Based on MATLAB code from J.J. VerBeek, 2006. http://lear.inrialpes.fr/~verbeek
+    """
+    N, D = shape(Y)  # N observations in D dimensions (i.e. D is the number of features, N the number of samples)
+    threshold = 1E-4  # minimal relative change in the objective function to continue
+    hidden = isnan(Y)
+    missing = hidden.sum()
+
+    if missing > 0:
+        M = nanmean(Y, axis=0)
+    else:
+        M = average(Y, axis=0)
+
+    Ye = Y - repmat(M, N, 1)
+
+    if missing > 0:
+        Ye[hidden] = 0
+
+    # initialize
+    C = normal(loc=0.0, scale=1.0, size=(D, d))
+    CtC = C.T @ C
+    X = Ye @ C @ inv(CtC)
+    recon = X @ C.T
+    recon[hidden] = 0
+    ss = np.sum((recon - Ye) ** 2) / (N * D - missing)
+
+    count = 1
+    old = np.inf
+
+    # EM iterations
+    while (count):
+        Sx = inv(eye(d) + CtC / ss)  # E-step, covariances
+        ss_old = ss
+        if missing > 0:
+            proj = X @ C.T
+            Ye[hidden] = proj[hidden]
+
+        X = Ye @ C @ Sx / ss  # E-step: expected values
+
+        SumXtX = X.T @ X  # M-step
+        C = Ye.T @ X @ (SumXtX + N * Sx).T @ inv(((SumXtX + N * Sx) @ (SumXtX + N * Sx).T))
+        CtC = C.T @ C
+        ss = (np.sum((X @ C.T - Ye) ** 2) + N * np.sum(CtC * Sx) + missing * ss_old) / (N * D)
+        # transform the Sx determinant into a numpy longdouble in order to deal with high dimensionality
+        Sx_det = np.min(Sx).astype(np.longdouble) ** shape(Sx)[0] * det(Sx / np.min(Sx))
+        objective = N * D + N * (D * log(ss) + tr(Sx) - log(Sx_det)) + tr(SumXtX) - missing * log(ss_old)
+
+        rel_ch = np.abs(1 - objective / old)
+        old = objective
+
+        count = count + 1
+        if rel_ch < threshold and count > 5:
+            count = 0
+
+    C = orth(C)
+    covM = cov((Ye @ C).T)
+    if d == 1:
+        covM = [[covM]]
+    vals, vecs = eig(covM)
+    ordr = np.argsort(vals)[::-1]
+    vecs = vecs[:, ordr]
+
+    C = C @ vecs
+    X = Ye @ C
+
+    # add the data mean back to the expected complete data
+    Ye = Ye + repmat(M, N, 1)
+
+    return C, ss, M, X, Ye
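
A quick smoke test for the new module: on complete data (no NaNs) the recovered one-dimensional subspace should agree with the top right-singular vector of the centered data up to sign. This is a sketch under the assumption that EM converges on such an easy problem; it is not part of the commit:

    import numpy as np
    from funscript_editor.algorithms.ppca import PPCA

    rng = np.random.default_rng(0)
    # 200 samples that vary mostly along one direction in 3-D, plus small noise.
    Y = rng.normal(size=(200, 1)) @ np.array([[2.0, 1.0, 0.5]]) \
        + 0.05 * rng.normal(size=(200, 3))

    C, ss, M, X, Ye = PPCA(Y, d=1)

    _, _, Vt = np.linalg.svd(Y - Y.mean(axis=0), full_matrices=False)
    print(abs(float(C[:, 0] @ Vt[0])))  # |cosine| between directions, close to 1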
