-import os
 import argparse
+import os
 import warnings
+from pathlib import Path
 from uuid import uuid4

 import torch
 from diffusers import DPMSolverMultistepScheduler, TextToVideoSDPipeline
 from einops import rearrange
+from torch.nn.functional import interpolate

 from train import export_to_video, handle_memory_attention, load_primary_models
 from utils.lama import inpaint_watermark
@@ -32,47 +34,136 @@ def initialize_pipeline(model, device="cuda", xformers=False, sdp=False):
     return pipeline


+def vid2vid(
+    pipeline, init_video, init_weight, prompt, negative_prompt, height, width, num_inference_steps, guidance_scale
+):
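+    """Video-to-video: noise the init video's latents to an intermediate
+    timestep chosen by init_weight, then denoise them under the new prompt.
+    Higher init_weight keeps more of the init video (roughly 1 - strength
+    in image img2img terms)."""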
+    num_frames = init_video.shape[2]
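+    # fold the frame axis into the batch axis so the image VAE can encode
+    # every frame independently, then restore the (b, c, f, h, w) layout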
+    init_video = rearrange(init_video, "b c f h w -> (b f) c h w")
+    latents = pipeline.vae.encode(init_video).latent_dist.sample()
+    latents = rearrange(latents, "(b f) c h w -> b c f h w", f=num_frames)
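+    # scale by the Stable Diffusion VAE factor 0.18215, then noise the clean
+    # latents to the timestep implied by init_weight (init_weight near 1 adds
+    # almost no noise; near 0 it noises almost all the way to pure noise)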
+    latents = pipeline.scheduler.add_noise(
+        original_samples=latents * 0.18215,
+        noise=torch.randn_like(latents),
+        timesteps=(torch.ones(latents.shape[0]) * pipeline.scheduler.num_train_timesteps * (1 - init_weight)).long(),
+    )
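+    # tile the single init video's latents across the prompt batch if needed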
+    if latents.shape[0] != len(prompt):
+        latents = latents.repeat(len(prompt), 1, 1, 1, 1)
+
+    do_classifier_free_guidance = guidance_scale > 1.0
+
+    prompt_embeds = pipeline._encode_prompt(
+        prompt=prompt,
+        negative_prompt=negative_prompt,
+        device=latents.device,
+        num_images_per_prompt=1,
+        do_classifier_free_guidance=do_classifier_free_guidance,
+    )
+
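+    # run only the tail of the denoising schedule: the first init_weight
+    # fraction of steps is skipped, e.g. 50 steps at init_weight=0.5 leaves
+    # only the last 25 timesteps to denoise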
+    pipeline.scheduler.set_timesteps(num_inference_steps, device=latents.device)
+    timesteps = pipeline.scheduler.timesteps
+    timesteps = timesteps[round(init_weight * len(timesteps)):]
+
+    with pipeline.progress_bar(total=len(timesteps)) as progress_bar:
+        for t in timesteps:
+            # expand the latents if we are doing classifier-free guidance
+            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+            latent_model_input = pipeline.scheduler.scale_model_input(latent_model_input, t)
+
+            # predict the noise residual
+            noise_pred = pipeline.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample
+
+            # perform guidance
+            if do_classifier_free_guidance:
+                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+            # fold the frame axis into the batch axis so the scheduler steps 4-D tensors
+            bsz, channel, frames, height, width = latents.shape
+            latents = latents.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, height, width)
+            noise_pred = noise_pred.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, height, width)
+
+            # compute the previous noisy sample x_t -> x_t-1
+            latents = pipeline.scheduler.step(noise_pred, t, latents).prev_sample
+
+            # reshape latents back to (b, c, f, h, w)
+            latents = latents[None, :].reshape(bsz, frames, channel, height, width).permute(0, 2, 1, 3, 4)
+
+            progress_bar.update()
+
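+    # decode the denoised latents back to pixel space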
+    video_tensor = pipeline.decode_latents(latents)
+
+    return video_tensor
+
+
 @torch.inference_mode()
 def inference(
     model,
     prompt,
+    negative_prompt=None,
     batch_size=1,
     num_frames=16,
     width=256,
     height=256,
     num_steps=50,
     guidance_scale=9,
+    init_video=None,
+    init_weight=0.5,
     device="cuda",
     xformers=False,
     sdp=False,
 ):
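+    # fp16 autocast keeps the whole pipeline in half precision on GPU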
     with torch.autocast(device, dtype=torch.half):
         pipeline = initialize_pipeline(model, device, xformers, sdp)

-        videos = pipeline(
-            prompt=[prompt] * batch_size,
-            width=width,
-            height=height,
-            num_frames=num_frames,
-            num_inference_steps=num_steps,
-            guidance_scale=guidance_scale,
-            output_type="pt",
-        ).frames
+        prompt = [prompt] * batch_size
+        negative_prompt = ([negative_prompt] * batch_size) if negative_prompt is not None else None
+
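+        # an init video routes generation through vid2vid; otherwise fall
+        # back to the stock text-to-video pipeline call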
+        if init_video is not None:
+            videos = vid2vid(
+                pipeline=pipeline,
+                init_video=init_video.to(device=device, dtype=torch.half),
+                init_weight=init_weight,
+                prompt=prompt,
+                negative_prompt=negative_prompt,
+                height=height,
+                width=width,
+                num_inference_steps=num_steps,
+                guidance_scale=guidance_scale,
+            )
+
+        else:
+            videos = pipeline(
+                prompt=prompt,
+                negative_prompt=negative_prompt,
+                num_frames=num_frames,
+                height=height,
+                width=width,
+                num_inference_steps=num_steps,
+                guidance_scale=guidance_scale,
+                output_type="pt",
+            ).frames

     return videos


 if __name__ == "__main__":
+    import decord
+
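+    # the torch bridge makes decord return torch tensors instead of numpy arrays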
+    decord.bridge.set_bridge("torch")
+
     parser = argparse.ArgumentParser()
     parser.add_argument("-m", "--model", type=str, required=True)
     parser.add_argument("-p", "--prompt", type=str, required=True)
+    parser.add_argument("-n", "--negative-prompt", type=str, default=None)
     parser.add_argument("-o", "--output-dir", type=str, default="./output")
     parser.add_argument("-B", "--batch-size", type=int, default=1)
     parser.add_argument("-T", "--num-frames", type=int, default=16)
     parser.add_argument("-W", "--width", type=int, default=256)
     parser.add_argument("-H", "--height", type=int, default=256)
     parser.add_argument("-s", "--num-steps", type=int, default=50)
     parser.add_argument("-g", "--guidance-scale", type=float, default=9)
+    parser.add_argument("-i", "--init-video", type=str, default=None)
+    parser.add_argument("-iw", "--init-weight", type=float, default=0.5)
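+    # example invocation (script and file names here are illustrative):
+    #   python inference.py -m ./my-model -p "a corgi surfing" -i clip.mp4 -iw 0.5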
     parser.add_argument("-f", "--fps", type=int, default=8)
     parser.add_argument("-d", "--device", type=str, default="cuda")
     parser.add_argument("-x", "--xformers", action="store_true")
@@ -84,10 +175,21 @@ def inference(
     prompt = args.get("prompt")
     fps = args.pop("fps")
     remove_watermark = args.pop("remove_watermark")
+    init_video = args.pop("init_video")
+
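+    # load the init clip and resample it to the requested frame count and size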
+    if init_video is not None:
+        vr = decord.VideoReader(init_video)
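+        # (f, h, w, c) uint8 frames -> (1, c, f, h, w) floats in [-1, 1]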
+        init = rearrange(vr[:], "f h w c -> c f h w").div(127.5).sub(1).unsqueeze(0)
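+        # trilinear interpolation resizes time (frames) and space (h, w) together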
+        init = interpolate(init, size=(args["num_frames"], args["height"], args["width"]), mode="trilinear")
+        args["init_video"] = init

     videos = inference(**args)

     os.makedirs(output_dir, exist_ok=True)
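+    # name outputs after the init clip and weight (if any) plus the prompt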
+    out_stem = f"{output_dir}/"
+    if init_video is not None:
+        out_stem += f"({Path(init_video).stem}) * {args['init_weight']} | "
+    out_stem += f"{prompt}"

     for video in videos:

@@ -101,4 +203,4 @@ def inference(

         video = video.byte().cpu().numpy()

-        export_to_video(video, f"{output_dir}/{prompt} {str(uuid4())[:8]}.mp4", fps)
+        export_to_video(video, f"{out_stem} {str(uuid4())[:8]}.mp4", fps)