
Commit 70176d5

New demo for just loading single N-frame clips per video
1 parent d2b1040 commit 70176d5

3 files changed: 93 additions & 42 deletions


README.md

Lines changed: 15 additions & 8 deletions
@@ -54,9 +54,10 @@ for image in frames:
 - [1. Requirements](#1-requirements)
 - [2. Custom Dataset](#2-custom-dataset)
 - [3. Video Frame Sampling Method](#3-video-frame-sampling-method)
-- [4. Using VideoFrameDataset for Training](#4-using-videoframedataset-for-training)
-- [5. Conclusion](#5-conclusion)
-- [6. Acknowledgements](#6-acknowledgements)
+- [4. Alternate Video Frame Sampling Methods](#4-alternate-video-frame-sampling-methods)
+- [5. Using VideoFrameDataset for Training](#5-using-videoframedataset-for-training)
+- [6. Conclusion](#6-conclusion)
+- [7. Acknowledgements](#7-acknowledgements)

 ### 1. Requirements
 ```
@@ -118,20 +119,26 @@ When loading a video, only a number of its frames are loaded. They are chosen in
 1. The frame indices [1,N] are divided into NUM_SEGMENTS even segments. From each segment, FRAMES_PER_SEGMENT consecutive indices are chosen at random.
 This results in NUM_SEGMENTS*FRAMES_PER_SEGMENT chosen indices, whose frames are loaded as PIL images, put into a list, and returned when calling
 `dataset[i]`.
+
+### 4. Alternate Video Frame Sampling Methods
+If you do not want to use sparse temporal sampling and instead want to sample a single N-frame continuous
+clip from a video, this is possible. Set `NUM_SEGMENTS=1` and `FRAMES_PER_SEGMENT=N`. Because VideoFrameDataset
+chooses a random start index per segment and takes `FRAMES_PER_SEGMENT` continuous frames from each sampled start
+index, this results in a single N-frame continuous clip per video. An example of this is in `demo.py`.

-### 4. Using VideoFrameDataset for training
+### 5. Using VideoFrameDataset for Training
 As demonstrated in `demo.py`, we can use PyTorch's `torch.utils.data.DataLoader` class with VideoFrameDataset to take care of shuffling, batching, and more.
-To turn the lists of PIL images returned by VideoFrameDataset into tensors, the transform `video_dataset.imglist_totensor()` can be supplied
+To turn the lists of PIL images returned by VideoFrameDataset into tensors, the transform `video_dataset.ImglistToTensor()` can be supplied
 as the `transform` parameter to VideoFrameDataset. This turns a list of N PIL images into a batch of images/frames of shape `N x CHANNELS x HEIGHT x WIDTH`.
-We can further chain preprocessing and augmentation functions that act on batches of images onto the end of `imglist_totensor()`.
+We can further chain preprocessing and augmentation functions that act on batches of images onto the end of `ImglistToTensor()`.

 As of `torchvision 0.8.0`, all torchvision transforms can operate on batches of images, and they apply deterministic or random transformations
 identically on all images of the batch. Therefore, any torchvision transform can be used here to apply video-uniform preprocessing and augmentation.

-### 5. Conclusion
+### 6. Conclusion
 A proper code-based explanation of how to use VideoFrameDataset for training is provided in `demo.py`.

-### 6. Acknowledgements
+### 7. Acknowledgements
 We thank the authors of TSN for their [codebase](https://github.com/yjxiong/tsn-pytorch), from which we took VideoFrameDataset and adapted it.
 ```
 @InProceedings{wang2016_TemporalSegmentNetworks,
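
The new README section describes the sampling scheme only in prose. As a rough illustration, the logic reduces to the sketch below. This is a hypothetical reimplementation for clarity, not code from `video_dataset.py`; the function name `sample_frame_indices` and its exact arithmetic are assumptions.

```python
import random

def sample_frame_indices(num_frames, num_segments, frames_per_segment):
    """Hypothetical sketch of sparse temporal sampling (not the repo's exact code)."""
    # Split the usable index range into num_segments even segments, then take
    # frames_per_segment consecutive indices from a random start in each segment.
    segment_length = (num_frames - frames_per_segment + 1) // num_segments
    indices = []
    for segment in range(num_segments):
        start = segment * segment_length + random.randint(0, max(segment_length - 1, 0))
        indices.extend(range(start, start + frames_per_segment))
    return indices

# Sparse sampling: 3 segments x 1 frame each from a 30-frame video, e.g. [4, 13, 27]
print(sample_frame_indices(num_frames=30, num_segments=3, frames_per_segment=1))
# Section 4's alternate mode: NUM_SEGMENTS=1, FRAMES_PER_SEGMENT=9 yields one
# continuous 9-frame clip starting at a random index, e.g. [12, 13, ..., 20]
print(sample_frame_indices(num_frames=30, num_segments=1, frames_per_segment=9))
```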

demo.py

Lines changed: 61 additions & 19 deletions
@@ -1,18 +1,43 @@
-from video_dataset import VideoFrameDataset, imglist_totensor
+from video_dataset import VideoFrameDataset, ImglistToTensor
 from torchvision import transforms
 import torch
 import matplotlib.pyplot as plt
+from mpl_toolkits.axes_grid1 import ImageGrid
 import os

+"""
+Ignore this function and look at "main" below.
+"""
+def plot_video(rows, cols, frame_list, plot_width, plot_height):
+    fig = plt.figure(figsize=(plot_width, plot_height))
+    grid = ImageGrid(fig, 111,  # similar to subplot(111)
+                     nrows_ncols=(rows, cols),  # creates a rows x cols grid of Axes
+                     axes_pad=0.3,  # pad between Axes, in inches
+                     )
+
+    for index, (ax, im) in enumerate(zip(grid, frame_list)):
+        # Iterating over the grid returns the Axes.
+        ax.imshow(im)
+        ax.set_title(index)
+    plt.show()

 if __name__ == '__main__':
     """
     This demo uses the dummy dataset inside the folder "demo_dataset".
     It is structured just like a real dataset would need to be structured.
+
+    TABLE OF CODE CONTENTS:
+    1. Minimal demo without image transforms
+    2. Minimal demo without sparse temporal sampling, for single continuous frame clips, without image transforms
+    3. Demo with image transforms
+    4. Demo 3 continued with PyTorch dataloader
+
     """
     videos_root = os.path.join(os.getcwd(), 'demo_dataset')
     annotation_file = os.path.join(videos_root, 'annotations.txt')

+
     """ DEMO 1 WITHOUT IMAGE TRANSFORMS """
     dataset = VideoFrameDataset(
         root_path=videos_root,
@@ -29,20 +54,42 @@
     frames = sample[0]  # list of PIL images
     label = sample[1]  # integer label

-    for image in frames:
-        plt.imshow(image)
-        plt.title(label)
-        plt.show()
-        plt.pause(1)
+    plot_video(rows=1, cols=5, frame_list=frames, plot_width=15., plot_height=3.)
+
+
+    """ DEMO 2 SINGLE CONTINUOUS FRAME CLIP INSTEAD OF SAMPLED FRAMES, WITHOUT TRANSFORMS """
+    # If you do not want to use sparse temporal sampling and instead
+    # want to load N consecutive frames starting from a random
+    # start index, this is easy. Simply set NUM_SEGMENTS=1 and
+    # FRAMES_PER_SEGMENT=N. Each time a sample is loaded, N
+    # frames will be loaded from a new random start index.
+    dataset = VideoFrameDataset(
+        root_path=videos_root,
+        annotationfile_path=annotation_file,
+        num_segments=1,
+        frames_per_segment=9,
+        imagefile_template='img_{:05d}.jpg',
+        transform=None,
+        random_shift=True,
+        test_mode=False
+    )
+
+    sample = dataset[3]
+    frames = sample[0]  # list of PIL images
+    label = sample[1]  # integer label
+
+    plot_video(rows=3, cols=3, frame_list=frames, plot_width=10., plot_height=5.)


-    """ DEMO 2 WITH TRANSFORMS """
+    """ DEMO 3 WITH TRANSFORMS """
     # As of torchvision 0.8.0, torchvision transforms support batches of images
     # of size (BATCH x CHANNELS x HEIGHT x WIDTH) and apply deterministic or random
     # transformations identically on all images of the batch. Any torchvision
     # transform for image augmentation can thus also be used for video augmentation.
     preprocess = transforms.Compose([
-        transforms.Lambda(imglist_totensor),  # list of PIL images to (FRAMES x CHANNELS x HEIGHT x WIDTH) tensor
+        ImglistToTensor(),  # list of PIL images to (FRAMES x CHANNELS x HEIGHT x WIDTH) tensor
         transforms.Resize(299),  # image batch, resize smaller edge to 299
         transforms.CenterCrop(299),  # image batch, center crop to square 299x299
         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
@@ -60,12 +107,10 @@
     )

     sample = dataset[2]
-    # tensor of shape (NUM_SEGMENTS*FRAMES_PER_SEGMENT) x CHANNELS x HEIGHT x WIDTH
-    frame_tensor = sample[0]
-    print('Video Tensor Size:', frame_tensor.size())
-    # integer label
-    label = sample[1]
+    frame_tensor = sample[0]  # tensor of shape (NUM_SEGMENTS*FRAMES_PER_SEGMENT) x CHANNELS x HEIGHT x WIDTH
+    label = sample[1]  # integer label

+    print('Video Tensor Size:', frame_tensor.size())

     def denormalize(video_tensor):
         """
@@ -82,14 +127,11 @@ def denormalize(video_tensor):


     frame_tensor = denormalize(frame_tensor)
-    for image in frame_tensor:
-        plt.imshow(image)
-        plt.title(label)
-        plt.show()
-        plt.pause(1)
+    plot_video(rows=1, cols=5, frame_list=frame_tensor.permute(0, 2, 3, 1), plot_width=15., plot_height=3.)  # permute CHW -> HWC for imshow


-    """ DEMO 2 CONTINUED: DATALOADER """
+    """ DEMO 3 CONTINUED: DATALOADER """
     dataloader = torch.utils.data.DataLoader(
         dataset=dataset,
         batch_size=2,
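
The diff view cuts the DataLoader demo off after `batch_size=2,`. For orientation, consuming the loader looks roughly like the sketch below; the `shuffle` and `num_workers` arguments and the loop body are illustrative assumptions, not lines from this commit.

```python
# Sketch only: argument values below batch_size are assumed, not from this commit.
dataloader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=2,
    shuffle=True,    # assumed
    num_workers=0,   # assumed
)

for video_batch, labels in dataloader:
    # video_batch: BATCH x (NUM_SEGMENTS*FRAMES_PER_SEGMENT) x CHANNELS x HEIGHT x WIDTH
    print(video_batch.size(), labels.size())
    break
```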

video_dataset.py

Lines changed: 17 additions & 15 deletions
@@ -43,7 +43,7 @@ class VideoFrameDataset(torch.utils.data.Dataset):
     loads x RGB frames of a video (sparse temporal sampling) and evenly
     chooses those frames from start to end of the video, returning
     a list of x PIL images or ``FRAMES x CHANNELS x HEIGHT x WIDTH``
-    tensors where FRAMES=x if the ``imglist_totensor()``
+    tensors where FRAMES=x if the ``ImglistToTensor()``
     transform is used.

     More specifically, the frame range [0,N] is divided into NUM_SEGMENTS
@@ -235,19 +235,21 @@ def _get(self, record, indices):
     def __len__(self):
         return len(self.video_list)

-def imglist_totensor(img_list):
+class ImglistToTensor(torch.nn.Module):
     """
-    Converts each PIL image in a list to
-    a torch Tensor and stacks them into
-    a single tensor. Can be used as first transform
-    for ``VideoFrameDataset``.
-    To use this with torchvision.transforms.Compose, wrap this
-    function in a torchvision lambda like
-    this ``torchvision.transforms.Lambda(imglist_totensor)``.
-
-    Args:
-        img_list: list of PIL images.
-    Returns:
-        tensor of size ``NUM_IMAGES x CHANNELS x HEIGHT x WIDTH``
+    Converts a list of PIL images in the range [0,255] to a torch.FloatTensor
+    of shape (NUM_IMAGES x CHANNELS x HEIGHT x WIDTH) in the range [0,1].
+    Can be used as first transform for ``VideoFrameDataset``.
     """
-    return torch.stack([transforms.functional.to_tensor(pic) for pic in img_list])
+    def forward(self, img_list):
+        """
+        Converts each PIL image in a list to
+        a torch Tensor and stacks them into
+        a single tensor.
+
+        Args:
+            img_list: list of PIL images.
+        Returns:
+            tensor of size ``NUM_IMAGES x CHANNELS x HEIGHT x WIDTH``
+        """
+        return torch.stack([transforms.functional.to_tensor(pic) for pic in img_list])
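
Turning `imglist_totensor` into the `ImglistToTensor` module means it can be dropped straight into `transforms.Compose` without a `transforms.Lambda` wrapper. A minimal usage sketch follows; the dummy frames and the crop size are made up for illustration.

```python
from PIL import Image
from torchvision import transforms
from video_dataset import ImglistToTensor

frames = [Image.new('RGB', (256, 256)) for _ in range(5)]  # dummy 5-frame clip
preprocess = transforms.Compose([
    ImglistToTensor(),           # list of PIL images -> FloatTensor in [0,1]
    transforms.CenterCrop(224),  # applied identically to every frame
])
print(preprocess(frames).shape)  # torch.Size([5, 3, 224, 224])
```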
