
Commit 87159f1

Merge pull request #1 from RaivoKoot/add-demo
New demo for just loading single N-frame clips per video
2 parents 356d82b + 0e3155e commit 87159f1

3 files changed: +100 −50 lines changed

README.md

Lines changed: 22 additions & 16 deletions
@@ -57,10 +57,11 @@ for image in frames:
 - [1. Requirements](#1-requirements)
 - [2. Custom Dataset](#2-custom-dataset)
 - [3. Video Frame Sampling Method](#3-video-frame-sampling-method)
-- [4. Using VideoFrameDataset for Training](#4-using-videoframedataset-for-training)
-- [5. Conclusion](#5-conclusion)
-- [6. Upcoming Features](#6-upcoming-features)
-- [7. Acknowledgements](#6-acknowledgements)
+- [4. Alternate Video Frame Sampling Methods](#4-alternate-video-frame-sampling-methods)
+- [5. Using VideoFrameDataset for Training](#5-using-videoframedataset-for-training)
+- [6. Conclusion](#6-conclusion)
+- [7. Upcoming Features](#7-upcoming-features)
+- [8. Acknowledgements](#8-acknowledgements)
 
 ### 1. Requirements
 ```
@@ -119,33 +120,38 @@ the `imagefile_template` parameter as "img_{:05d}.jpg", is all that it takes to
 
 ### 3. Video Frame Sampling Method
 When loading a video, only a number of its frames are loaded. They are chosen in the following way:
-1. The frame indices [1,N] are divided into NUM_SEGMENTS even segments. From each segment, FRAMES_PER_SEGMENT consecutive indices are chosen at random.
+1. The frame indices [1,N] are divided into NUM_SEGMENTS even segments. From each segment, a random start-index is sampled from which FRAMES_PER_SEGMENT consecutive indices are loaded.
 This results in NUM_SEGMENTS*FRAMES_PER_SEGMENT chosen indices, whose frames are loaded as PIL images and put into a list and returned when calling
-`dataset[i]`.
-
+`dataset[i]`.
 ![alt text](https://github.com/RaivoKoot/images/blob/main/Sparse_Temporal_Sampling.jpg "Sparse-Temporal-Sampling-Strategy")
+
+### 4. Alternate Video Frame Sampling Methods
+If you do not want to use sparse temporal sampling and instead want to sample a single N-frame continuous
+clip from a video, this is possible. Set `NUM_SEGMENTS=1` and `FRAMES_PER_SEGMENT=N`. Because VideoFrameDataset
+will choose a random start index per segment and take `FRAMES_PER_SEGMENT` continuous frames from each sampled start
+index, this will result in a single N-frame continuous clip per video. An example of this is in `demo.py`.
 
-### 4. Using VideoFrameDataset for training
+### 5. Using VideoFrameDataset for training
 As demonstrated in `demo.py`, we can use PyTorch's `torch.utils.data.DataLoader` class with VideoFrameDataset to take care of shuffling, batching, and more.
-To turn the lists of PIL images returned by VideoFrameDataset into tensors, the transform `video_dataset.imglist_totensor()` can be supplied
+To turn the lists of PIL images returned by VideoFrameDataset into tensors, the transform `video_dataset.ImglistToTensor()` can be supplied
 as the `transform` parameter to VideoFrameDataset. This turns a list of N PIL images into a batch of images/frames of shape `N x CHANNELS x HEIGHT x WIDTH`.
-We can further chain preprocessing and augmentation functions that act on batches of images onto the end of `imglist_totensor()`.
+We can further chain preprocessing and augmentation functions that act on batches of images onto the end of `ImglistToTensor()`.
 
 As of `torchvision 0.8.0`, all torchvision transforms can now also operate on batches of images, and they apply deterministic or random transformations
 on the batch identically on all images of the batch. Therefore, any torchvision transform can be used here to apply video-uniform preprocessing and augmentation.
 
 REMEMBER:
-Pytorch transforms are applied to individual dataset samples (in this case a video frame PIL list, or a frame tensor after `imglist_totensor()`) before
-batching. So, any transforms used here must expect its input to be a frame tensor of shape `FRAMES x CHANNELS x HEIGHT x WIDTH` or a list of PIL images if `imglist_totensor()` is not used.
-### 5. Conclusion
+PyTorch transforms are applied to individual dataset samples (in this case a video frame PIL list, or a frame tensor after `ImglistToTensor()`) before
+batching. So, any transform used here must expect its input to be a frame tensor of shape `FRAMES x CHANNELS x HEIGHT x WIDTH` or a list of PIL images if `ImglistToTensor()` is not used.
+### 6. Conclusion
 A proper code-based explanation on how to use VideoFrameDataset for training is provided in `demo.py`
 
-### 6. Upcoming Features
-- [ ] Add demo for sampling a single continous-frame clip from videos.
+### 7. Upcoming Features
+- [x] Add demo for sampling a single continuous-frame clip from videos.
 - [ ] Add support for arbitrary labels that are more than just a single integer.
 - [ ] Add support for specifying START_FRAME and END_FRAME for a video instead of NUM_FRAMES.
 
-### 7. Acknowledgements
+### 8. Acknowledgements
 We thank the authors of TSN for their [codebase](https://github.com/yjxiong/tsn-pytorch), from which we took VideoFrameDataset and adapted it
 for general use and compatibility.
 ```
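
The README hunk above describes the sampling rule only in prose. As a toy sketch of that rule under the stated segment/consecutive-frame scheme: this is illustrative only, and `sample_frame_indices` with its arguments are hypothetical names, not the library's internal API.

```python
import random

def sample_frame_indices(num_frames, num_segments, frames_per_segment):
    # Divide [0, num_frames) into num_segments even segments.
    segment_length = num_frames // num_segments
    indices = []
    for segment in range(num_segments):
        segment_start = segment * segment_length
        # Pick a random start inside the segment that leaves room
        # for frames_per_segment consecutive frames.
        max_offset = segment_length - frames_per_segment
        start = segment_start + random.randint(0, max_offset)
        indices.extend(range(start, start + frames_per_segment))
    return indices  # NUM_SEGMENTS * FRAMES_PER_SEGMENT indices in total

# Sparse temporal sampling: 3 segments, 1 frame from each.
print(sample_frame_indices(num_frames=30, num_segments=3, frames_per_segment=1))
# Section 4's single continuous clip: NUM_SEGMENTS=1, FRAMES_PER_SEGMENT=N.
print(sample_frame_indices(num_frames=30, num_segments=1, frames_per_segment=9))
```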

demo.py

Lines changed: 61 additions & 19 deletions
@@ -1,18 +1,43 @@
-from video_dataset import VideoFrameDataset, imglist_totensor
+from video_dataset import VideoFrameDataset, ImglistToTensor
 from torchvision import transforms
 import torch
 import matplotlib.pyplot as plt
+from mpl_toolkits.axes_grid1 import ImageGrid
 import os
 
+"""
+Ignore this function and look at "main" below.
+"""
+def plot_video(rows, cols, frame_list, plot_width, plot_height):
+    fig = plt.figure(figsize=(plot_width, plot_height))
+    grid = ImageGrid(fig, 111,  # similar to subplot(111)
+                     nrows_ncols=(rows, cols),  # creates a rows x cols grid of axes
+                     axes_pad=0.3,  # pad between axes in inch.
+                     )
+
+    for index, (ax, im) in enumerate(zip(grid, frame_list)):
+        # Iterating over the grid returns the Axes.
+        ax.imshow(im)
+        ax.set_title(index)
+    plt.show()
 
 if __name__ == '__main__':
     """
     This demo uses the dummy dataset inside of the folder "demo_dataset".
     It is structured just like a real dataset would need to be structured.
+
+    TABLE OF CODE CONTENTS:
+    1. Minimal demo without image transforms
+    2. Minimal demo without sparse temporal sampling for single continuous frame clips, without image transforms
+    3. Demo with image transforms
+    4. Demo 3 continued with PyTorch dataloader
+
     """
     videos_root = os.path.join(os.getcwd(), 'demo_dataset')
     annotation_file = os.path.join(videos_root, 'annotations.txt')
 
+
+
     """ DEMO 1 WITHOUT IMAGE TRANSFORMS """
     dataset = VideoFrameDataset(
         root_path=videos_root,
@@ -29,20 +54,42 @@
     frames = sample[0]  # list of PIL images
     label = sample[1]  # integer label
 
-    for image in frames:
-        plt.imshow(image)
-        plt.title(label)
-        plt.show()
-        plt.pause(1)
+    plot_video(rows=1, cols=5, frame_list=frames, plot_width=15., plot_height=3.)
+
+
+
+    """ DEMO 2 SINGLE CONTINUOUS FRAME CLIP INSTEAD OF SAMPLED FRAMES, WITHOUT TRANSFORMS """
+    # If you do not want to use sparse temporal sampling, and instead
+    # want to just load N consecutive frames starting from a random
+    # start index, this is easy. Simply set NUM_SEGMENTS=1 and
+    # FRAMES_PER_SEGMENT=N. Each time a sample is loaded, N
+    # frames will be loaded from a new random start index.
+    dataset = VideoFrameDataset(
+        root_path=videos_root,
+        annotationfile_path=annotation_file,
+        num_segments=1,
+        frames_per_segment=9,
+        imagefile_template='img_{:05d}.jpg',
+        transform=None,
+        random_shift=True,
+        test_mode=False
+    )
+
+    sample = dataset[3]
+    frames = sample[0]  # list of PIL images
+    label = sample[1]  # integer label
+
+    plot_video(rows=3, cols=3, frame_list=frames, plot_width=10., plot_height=5.)
+
 
 
-    """ DEMO 2 WITH TRANSFORMS """
+    """ DEMO 3 WITH TRANSFORMS """
     # As of torchvision 0.8.0, torchvision transforms support batches of images
     # of size (BATCH x CHANNELS x HEIGHT x WIDTH) and apply deterministic or random
     # transformations on the batch identically on all images of the batch. Any torchvision
     # transform for image augmentation can thus also be used for video augmentation.
     preprocess = transforms.Compose([
-        transforms.Lambda(imglist_totensor),  # list of PIL images to (FRAMES x CHANNELS x HEIGHT x WIDTH) tensor
+        ImglistToTensor(),  # list of PIL images to (FRAMES x CHANNELS x HEIGHT x WIDTH) tensor
         transforms.Resize(299),  # image batch, resize smaller edge to 299
        transforms.CenterCrop(299),  # image batch, center crop to square 299x299
         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
@@ -60,12 +107,10 @@
     )
 
     sample = dataset[2]
-    # tensor of shape (NUM_SEGMENTS*FRAMES_PER_SEGMENT) x CHANNELS x HEIGHT x WIDTH
-    frame_tensor = sample[0]
-    print('Video Tensor Size:', frame_tensor.size())
-    # integer label
-    label = sample[1]
+    frame_tensor = sample[0]  # tensor of shape (NUM_SEGMENTS*FRAMES_PER_SEGMENT) x CHANNELS x HEIGHT x WIDTH
+    label = sample[1]  # integer label
 
+    print('Video Tensor Size:', frame_tensor.size())
 
     def denormalize(video_tensor):
         """
@@ -82,14 +127,11 @@ def denormalize(video_tensor):
 
 
     frame_tensor = denormalize(frame_tensor)
-    for image in frame_tensor:
-        plt.imshow(image)
-        plt.title(label)
-        plt.show()
-        plt.pause(1)
+    plot_video(rows=1, cols=5, frame_list=frame_tensor.permute(0, 2, 3, 1), plot_width=15., plot_height=3.)  # permute CHW -> HWC so imshow can display each frame
+
 
 
-    """ DEMO 2 CONTINUED: DATALOADER """
+    """ DEMO 3 CONTINUED: DATALOADER """
     dataloader = torch.utils.data.DataLoader(
         dataset=dataset,
         batch_size=2,
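
The visible hunk stops inside the `DataLoader` constructor. For orientation, here is a minimal sketch of how a dataloader like this is typically consumed; the batch shapes assume `ImglistToTensor()` is the first transform, and the loop body is a placeholder rather than code from this commit.

```python
# Each batch stacks per-sample frame tensors, so video_batch has shape
# (BATCH x FRAMES x CHANNELS x HEIGHT x WIDTH) and labels has shape (BATCH,).
for epoch in range(2):
    for video_batch, labels in dataloader:
        print(video_batch.size(), labels.size())
        # A real training step would run a model forward pass here,
        # compute a loss, and call loss.backward().
```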

video_dataset.py

Lines changed: 17 additions & 15 deletions
@@ -43,7 +43,7 @@ class VideoFrameDataset(torch.utils.data.Dataset):
     loads x RGB frames of a video (sparse temporal sampling) and evenly
     chooses those frames from start to end of the video, returning
     a list of x PIL images or ``FRAMES x CHANNELS x HEIGHT x WIDTH``
-    tensors where FRAMES=x if the ``imglist_totensor()``
+    tensors where FRAMES=x if the ``ImglistToTensor()``
     transform is used.
 
     More specifically, the frame range [0,N] is divided into NUM_SEGMENTS
@@ -235,19 +235,21 @@ def _get(self, record, indices):
     def __len__(self):
         return len(self.video_list)
 
-def imglist_totensor(img_list):
+class ImglistToTensor(torch.nn.Module):
     """
-    Converts each PIL image in a list to
-    a torch Tensor and stacks them into
-    a single tensor. Can be used as first transform
-    for ``VideoFrameDataset``.
-    To use this with torchvision.transforms.Compose, wrap this
-    function in a torchvision lambda like
-    this ``torchvision.transforms.Lambda(imglist_totensor)``.
-
-    Args:
-        img_list: list of PIL images.
-    Returns:
-        tensor of size ``NUM_IMAGES x CHANNELS x HEIGHT x WIDTH``
+    Converts a list of PIL images in the range [0,255] to a torch.FloatTensor
+    of shape (NUM_IMAGES x CHANNELS x HEIGHT x WIDTH) in the range [0,1].
+    Can be used as first transform for ``VideoFrameDataset``.
     """
-    return torch.stack([transforms.functional.to_tensor(pic) for pic in img_list])
+    def forward(self, img_list):
+        """
+        Converts each PIL image in a list to
+        a torch Tensor and stacks them into
+        a single tensor.
+
+        Args:
+            img_list: list of PIL images.
+        Returns:
+            tensor of size ``NUM_IMAGES x CHANNELS x HEIGHT x WIDTH``
+        """
+        return torch.stack([transforms.functional.to_tensor(pic) for pic in img_list])
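
Because the commit turns `imglist_totensor` from a plain function into an `nn.Module`, it can now be placed directly in `transforms.Compose` without a `Lambda` wrapper, and used standalone. A small usage sketch with dummy frames (illustrative, not part of the commit):

```python
from PIL import Image
from video_dataset import ImglistToTensor

# Five dummy 64x48 RGB frames standing in for decoded video frames.
frames = [Image.new('RGB', (64, 48)) for _ in range(5)]

to_tensor = ImglistToTensor()  # an nn.Module, so calling it runs forward()
video = to_tensor(frames)
print(video.shape)  # torch.Size([5, 3, 48, 64]) -- NUM_IMAGES x C x H x W
```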
