Skip to content
This repository was archived by the owner on Dec 14, 2023. It is now read-only.

Commit 7d384d6

Browse files
Merge pull request #39 from JCBrouwer/inference/initial-script
Add inference script
2 parents daddaf1 + d663381 commit 7d384d6

File tree

4 files changed

+560
-0
lines changed

4 files changed

+560
-0
lines changed

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
output/
2+
models/lama.ckpt
3+
.vscode/
4+
models/model_scope_diffusers/
15
text-to-video-ms-1.7b/
26

37
# Byte-compiled / optimized / DLL files

inference.py

Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
import argparse
2+
import os
3+
import warnings
4+
from pathlib import Path
5+
from uuid import uuid4
6+
7+
import torch
8+
from diffusers import DPMSolverMultistepScheduler, TextToVideoSDPipeline
9+
from einops import rearrange
10+
from torch.nn.functional import interpolate
11+
12+
from train import export_to_video, handle_memory_attention, load_primary_models
13+
from utils.lama import inpaint_watermark
14+
15+
16+
def initialize_pipeline(model, device="cuda", xformers=False, sdp=False):
17+
with warnings.catch_warnings():
18+
warnings.simplefilter("ignore")
19+
20+
scheduler, tokenizer, text_encoder, vae, unet = load_primary_models(model)
21+
22+
pipeline = TextToVideoSDPipeline.from_pretrained(
23+
pretrained_model_name_or_path=model,
24+
scheduler=scheduler,
25+
tokenizer=tokenizer,
26+
text_encoder=text_encoder.to(device=device, dtype=torch.half),
27+
vae=vae.to(device=device, dtype=torch.half),
28+
unet=unet.to(device=device, dtype=torch.half),
29+
)
30+
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
31+
unet._set_gradient_checkpointing(value=False)
32+
handle_memory_attention(xformers, sdp, unet)
33+
vae.enable_slicing()
34+
return pipeline
35+
36+
37+
def vid2vid(
38+
pipeline, init_video, init_weight, prompt, negative_prompt, height, width, num_inference_steps, guidance_scale
39+
):
40+
num_frames = init_video.shape[2]
41+
init_video = rearrange(init_video, "b c f h w -> (b f) c h w")
42+
latents = pipeline.vae.encode(init_video).latent_dist.sample()
43+
latents = rearrange(latents, "(b f) c h w -> b c f h w", f=num_frames)
44+
latents = pipeline.scheduler.add_noise(
45+
original_samples=latents * 0.18215,
46+
noise=torch.randn_like(latents),
47+
timesteps=(torch.ones(latents.shape[0]) * pipeline.scheduler.num_train_timesteps * (1 - init_weight)).long(),
48+
)
49+
if latents.shape[0] != len(prompt):
50+
latents = latents.repeat(len(prompt), 1, 1, 1, 1)
51+
52+
do_classifier_free_guidance = guidance_scale > 1.0
53+
54+
prompt_embeds = pipeline._encode_prompt(
55+
prompt=prompt,
56+
negative_prompt=negative_prompt,
57+
device=latents.device,
58+
num_images_per_prompt=1,
59+
do_classifier_free_guidance=do_classifier_free_guidance,
60+
)
61+
62+
pipeline.scheduler.set_timesteps(num_inference_steps, device=latents.device)
63+
timesteps = pipeline.scheduler.timesteps
64+
timesteps = timesteps[round(init_weight * len(timesteps)) :]
65+
66+
with pipeline.progress_bar(total=len(timesteps)) as progress_bar:
67+
for t in timesteps:
68+
# expand the latents if we are doing classifier free guidance
69+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
70+
latent_model_input = pipeline.scheduler.scale_model_input(latent_model_input, t)
71+
72+
# predict the noise residual
73+
noise_pred = pipeline.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample
74+
75+
# perform guidance
76+
if do_classifier_free_guidance:
77+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
78+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
79+
80+
# reshape latents
81+
bsz, channel, frames, width, height = latents.shape
82+
latents = latents.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, width, height)
83+
noise_pred = noise_pred.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, width, height)
84+
85+
# compute the previous noisy sample x_t -> x_t-1
86+
latents = pipeline.scheduler.step(noise_pred, t, latents).prev_sample
87+
88+
# reshape latents back
89+
latents = latents[None, :].reshape(bsz, frames, channel, width, height).permute(0, 2, 1, 3, 4)
90+
91+
progress_bar.update()
92+
93+
video_tensor = pipeline.decode_latents(latents)
94+
95+
return video_tensor
96+
97+
98+
@torch.inference_mode()
99+
def inference(
100+
model,
101+
prompt,
102+
negative_prompt=None,
103+
batch_size=1,
104+
num_frames=16,
105+
width=256,
106+
height=256,
107+
num_steps=50,
108+
guidance_scale=9,
109+
init_video=None,
110+
init_weight=0.5,
111+
device="cuda",
112+
xformers=False,
113+
sdp=False,
114+
):
115+
with torch.autocast(device, dtype=torch.half):
116+
pipeline = initialize_pipeline(model, device, xformers, sdp)
117+
118+
prompt = [prompt] * batch_size
119+
negative_prompt = ([negative_prompt] * batch_size) if negative_prompt is not None else None
120+
121+
if init_video is not None:
122+
videos = vid2vid(
123+
pipeline=pipeline,
124+
init_video=init_video.to(device=device, dtype=torch.half),
125+
init_weight=init_weight,
126+
prompt=prompt,
127+
negative_prompt=negative_prompt,
128+
height=height,
129+
width=width,
130+
num_inference_steps=num_steps,
131+
guidance_scale=guidance_scale,
132+
)
133+
134+
else:
135+
videos = pipeline(
136+
prompt=prompt,
137+
negative_prompt=negative_prompt,
138+
num_frames=num_frames,
139+
height=height,
140+
width=width,
141+
num_inference_steps=num_steps,
142+
guidance_scale=guidance_scale,
143+
output_type="pt",
144+
).frames
145+
146+
return videos
147+
148+
149+
if __name__ == "__main__":
150+
import decord
151+
152+
decord.bridge.set_bridge("torch")
153+
154+
parser = argparse.ArgumentParser()
155+
parser.add_argument("-m", "--model", type=str, required=True)
156+
parser.add_argument("-p", "--prompt", type=str, required=True)
157+
parser.add_argument("-n", "--negative-prompt", type=str, default=None)
158+
parser.add_argument("-o", "--output-dir", type=str, default="./output")
159+
parser.add_argument("-B", "--batch-size", type=int, default=1)
160+
parser.add_argument("-T", "--num-frames", type=int, default=16)
161+
parser.add_argument("-W", "--width", type=int, default=256)
162+
parser.add_argument("-H", "--height", type=int, default=256)
163+
parser.add_argument("-s", "--num-steps", type=int, default=50)
164+
parser.add_argument("-g", "--guidance-scale", type=float, default=9)
165+
parser.add_argument("-i", "--init-video", type=str, default=None)
166+
parser.add_argument("-iw", "--init-weight", type=float, default=0.5)
167+
parser.add_argument("-f", "--fps", type=int, default=8)
168+
parser.add_argument("-d", "--device", type=str, default="cuda")
169+
parser.add_argument("-x", "--xformers", action="store_true")
170+
parser.add_argument("-S", "--sdp", action="store_true")
171+
parser.add_argument("-rw", "--remove-watermark", action="store_true")
172+
args = vars(parser.parse_args())
173+
174+
output_dir = args.pop("output_dir")
175+
prompt = args.get("prompt")
176+
fps = args.pop("fps")
177+
remove_watermark = args.pop("remove_watermark")
178+
init_video = args.pop("init_video")
179+
180+
if init_video is not None:
181+
vr = decord.VideoReader(init_video)
182+
init = rearrange(vr[:], "f h w c -> c f h w").div(127.5).sub(1).unsqueeze(0)
183+
init = interpolate(init, size=(args["num_frames"], args["height"], args["width"]), mode="trilinear")
184+
args["init_video"] = init
185+
186+
videos = inference(**args)
187+
188+
os.makedirs(output_dir, exist_ok=True)
189+
out_stem = f"{output_dir}/"
190+
if init_video is not None:
191+
out_stem += f"({Path(init_video).stem}) * {args['init_weight']} | "
192+
out_stem += f"{prompt}"
193+
194+
for video in videos:
195+
196+
if remove_watermark:
197+
video = rearrange(video, "c f h w -> f c h w").add(1).div(2)
198+
video = inpaint_watermark(video)
199+
video = rearrange(video, "f c h w -> f h w c").clamp(0, 1).mul(255)
200+
201+
else:
202+
video = rearrange(video, "c f h w -> f h w c").clamp(-1, 1).add(1).mul(127.5)
203+
204+
video = video.byte().cpu().numpy()
205+
206+
export_to_video(video, f"{out_stem} {str(uuid4())[:8]}.mp4", fps)

0 commit comments

Comments
 (0)