
Commit fb30834

Merge pull request #25 from JCBrouwer/video-folder-dataset
Add video folder dataset
2 parents: 2a48e2d + f66580f

File tree: 4 files changed, +120 -4 lines

.gitignore (+2 -0)
configs/video_folder.yaml (+38 -0)
train.py (+16 -4)
utils/dataset.py (+64 -0)

.gitignore (2 additions, 0 deletions)

@@ -1,3 +1,5 @@
+text-to-video-ms-1.7b/
+
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]

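The newly ignored text-to-video-ms-1.7b/ directory matches the pretrained_model_path in configs/video_folder.yaml below, so the locally downloaded ModelScope weights presumably stay out of version control.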
configs/video_folder.yaml (38 additions, 0 deletions)

@@ -0,0 +1,38 @@
+pretrained_model_path: "./text-to-video-ms-1.7b/"
+output_dir: "./output"
+train_text_encoder: False
+
+train_data:
+  type: folder
+  path: "path/to/folder/of/videos/"
+  n_sample_frames: 16
+  width: 256
+  height: 256
+  fps: 24
+  fallback_prompt: "" # used when a video doesn't have a corresponding .txt file with a prompt
+
+validation_data:
+  prompt: ""
+  sample_preview: True
+  num_frames: 48
+  width: 256
+  height: 256
+  num_inference_steps: 50
+  guidance_scale: 9
+
+learning_rate: 1e-5
+adam_weight_decay: 1e-2
+train_batch_size: 1
+max_train_steps: 50000
+checkpointing_steps: 5000
+validation_steps: 500
+trainable_modules:
+  - "attn1"
+  - "attn2"
+  - "attn3"
+seed: 1234
+mixed_precision: "fp16"
+use_8bit_adam: False # This seems to be incompatible at the moment.
+gradient_checkpointing: True
+enable_xformers_memory_efficient_attention: False
+enable_torch_2_attn: True
train.py (16 additions, 4 deletions)

@@ -10,6 +10,7 @@
 from typing import Dict, Optional, Tuple
 from omegaconf import OmegaConf

+import cv2
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
@@ -24,7 +25,7 @@
 from accelerate.logging import get_logger
 from accelerate.utils import set_seed

-from .models.unet_3d_condition import UNet3DConditionModel
+from models.unet_3d_condition import UNet3DConditionModel
 from diffusers.models import AutoencoderKL
 from diffusers import DPMSolverMultistepScheduler, DDPMScheduler, TextToVideoSDPipeline
 from diffusers.optimization import get_scheduler
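Dropping the leading dot from .models.unet_3d_condition turns a package-relative import into a top-level one; the relative form raises "ImportError: attempted relative import with no known parent package" when train.py is executed directly as a script, so this change presumably makes python train.py runnable on its own.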
@@ -34,7 +35,7 @@
 from diffusers.models.attention import BasicTransformerBlock

 from transformers import CLIPTextModel, CLIPTokenizer
-from utils.dataset import VideoDataset
+from utils.dataset import VideoDataset, VideoFolderDataset
 from einops import rearrange, repeat

 already_printed_unet = False
@@ -60,6 +61,14 @@ def accelerate_set_verbose(accelerator):
     transformers.utils.logging.set_verbosity_error()
     diffusers.utils.logging.set_verbosity_error()

+def export_to_video(video_frames, output_video_path, fps):
+    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
+    h, w, _ = video_frames[0].shape
+    video_writer = cv2.VideoWriter(output_video_path, fourcc, fps=fps, frameSize=(w, h))
+    for i in range(len(video_frames)):
+        img = cv2.cvtColor(video_frames[i], cv2.COLOR_RGB2BGR)
+        video_writer.write(img)
+
 def create_output_folders(output_dir, config):
     now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
     out_dir = os.path.join(output_dir, f"train_{now}")
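One caveat with the helper above: it never calls video_writer.release(), so OpenCV only finalizes the MP4 container when the writer happens to be garbage-collected, which can leave truncated files on abrupt exits. A minimal hardened sketch of the same helper, with the release() call as the only substantive addition:

import cv2

def export_to_video(video_frames, output_video_path, fps):
    # Encode a sequence of RGB uint8 frames of shape (h, w, 3) into an .mp4 file.
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    h, w, _ = video_frames[0].shape
    video_writer = cv2.VideoWriter(output_video_path, fourcc, fps=fps, frameSize=(w, h))
    for frame in video_frames:
        # OpenCV expects BGR channel order, so convert from the pipeline's RGB output.
        video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    video_writer.release()  # flush and finalize the container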
@@ -306,7 +315,10 @@ def main(
     )

     # Get the training dataset
-    train_dataset = VideoDataset(**train_data, tokenizer=tokenizer)
+    if train_data.pop("type", "regular") == "folder":
+        train_dataset = VideoFolderDataset(**train_data, tokenizer=tokenizer)
+    else:
+        train_dataset = VideoDataset(**train_data, tokenizer=tokenizer)

     # DataLoaders creation:
     train_dataloader = torch.utils.data.DataLoader(
@@ -513,7 +525,7 @@ def finetune_unet(batch, train_encoder=False):
                 num_inference_steps=validation_data.num_inference_steps,
                 guidance_scale=validation_data.guidance_scale
             ).frames
-            video_path = export_to_video(video_frames, out_file)
+            export_to_video(video_frames, out_file, train_data.get('fps', 8))

             del pipeline
             gc.collect()
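With configs/video_folder.yaml above, validation previews are therefore written at the training rate of 24 fps; the literal 8 only acts as a fallback when train_data omits the fps key.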

utils/dataset.py (64 additions, 0 deletions)

@@ -9,6 +9,7 @@

 from torch.utils.data import Dataset
 from einops import rearrange
+from glob import glob

 class VideoDataset(Dataset):
     def __init__(
@@ -199,3 +200,66 @@ def __getitem__(self, index):
         }

         return example
+
+class VideoFolderDataset(Dataset):
+    def __init__(
+        self,
+        tokenizer=None,
+        width: int = 256,
+        height: int = 256,
+        n_sample_frames: int = 16,
+        fps: int = 8,
+        path: str = "./data",
+        fallback_prompt: str = "",
+        **kwargs
+    ):
+        self.tokenizer = tokenizer
+
+        self.fallback_prompt = fallback_prompt
+
+        self.video_files = glob(f"{path}/*.mp4")
+
+        self.width = width
+        self.height = height
+
+        self.n_sample_frames = n_sample_frames
+        self.fps = fps
+
+    def get_prompt_ids(self, prompt):
+        return self.tokenizer(
+            prompt,
+            truncation=True,
+            padding="max_length",
+            max_length=self.tokenizer.model_max_length,
+            return_tensors="pt",
+        ).input_ids
+
+    def __len__(self):
+        return len(self.video_files)
+
+    def __getitem__(self, index):
+        vr = decord.VideoReader(self.video_files[index], width=self.width, height=self.height)
+        native_fps = vr.get_avg_fps()
+        every_nth_frame = round(native_fps / self.fps)
+
+        effective_length = len(vr) // every_nth_frame
+
+        if effective_length < self.n_sample_frames:
+            return self.__getitem__(random.randint(0, len(self.video_files) - 1))
+
+        effective_idx = random.randint(0, effective_length - self.n_sample_frames)
+
+        idxs = every_nth_frame * np.arange(effective_idx, effective_idx + self.n_sample_frames)
+
+        video = vr.get_batch(idxs)
+        video = rearrange(video, "f h w c -> f c h w")
+
+        if os.path.exists(self.video_files[index].replace(".mp4", ".txt")):
+            with open(self.video_files[index].replace(".mp4", ".txt"), "r") as f:
+                prompt = f.read()
+        else:
+            prompt = self.fallback_prompt
+
+        prompt_ids = self.get_prompt_ids(prompt)
+
+        return {"pixel_values": (video / 127.5 - 1.0), "prompt_ids": prompt_ids[0], "text_prompt": prompt}

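A worked example of the frame-sampling arithmetic in __getitem__: for a 30 fps source with the config's fps: 24, every_nth_frame = round(30 / 24) = 1, so every native frame is kept; with fps: 8 it is round(30 / 8) = 4, i.e. every 4th frame for an effective rate of 7.5 fps. Clips with fewer than n_sample_frames usable frames are skipped by recursing into a random other file, so the folder must contain at least one sufficiently long video or that recursion never terminates. A hypothetical usage sketch (the checkpoint layout, paths, and prompt below are assumptions, not part of the diff):

from transformers import CLIPTokenizer
from utils.dataset import VideoFolderDataset

# Tokenizer loaded from the checkpoint referenced by the config; the
# "tokenizer" subfolder is an assumption about the ModelScope layout.
tokenizer = CLIPTokenizer.from_pretrained("./text-to-video-ms-1.7b/", subfolder="tokenizer")

dataset = VideoFolderDataset(
    tokenizer=tokenizer,
    path="path/to/folder/of/videos/",  # *.mp4 files, each optionally paired with a .txt prompt
    width=256,
    height=256,
    n_sample_frames=16,
    fps=24,
    fallback_prompt="a video clip",  # used when no .txt sidecar exists
)

sample = dataset[0]
print(sample["pixel_values"].shape)  # (16, 3, 256, 256), values scaled to [-1, 1]
print(sample["text_prompt"])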