Skip to content
This repository was archived by the owner on Dec 14, 2023. It is now read-only.

Commit 6ee7ea3

Browse files
committed
add initial inference script
1 parent 25697f9 commit 6ee7ea3

File tree

1 file changed

+83
-0
lines changed

1 file changed

+83
-0
lines changed

inference.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import argparse
2+
from uuid import uuid4
3+
4+
import torch
5+
from diffusers import DPMSolverMultistepScheduler, TextToVideoSDPipeline
6+
7+
from train import export_to_video, handle_memory_attention, load_primary_models
8+
9+
10+
def initialize_pipeline(model, device="cuda", xformers=False, sdp=False):
    """Build a half-precision text-to-video pipeline from a trained model.

    Loads the primary modules via ``load_primary_models``, moves the heavy
    ones to *device* in fp16, and assembles a ``TextToVideoSDPipeline``.

    Args:
        model: Path or hub id of the pretrained model directory.
        device: Target device string (default ``"cuda"``).
        xformers: Enable xformers memory-efficient attention.
        sdp: Enable PyTorch scaled-dot-product attention.

    Returns:
        A configured ``TextToVideoSDPipeline`` ready for sampling.
    """
    scheduler, tokenizer, text_encoder, vae, unet = load_primary_models(model)

    # Move the large modules to the target device in half precision
    # before handing them to the pipeline.
    text_encoder = text_encoder.to(device=device, dtype=torch.half)
    vae = vae.to(device=device, dtype=torch.half)
    unet = unet.to(device=device, dtype=torch.half)

    pipeline = TextToVideoSDPipeline.from_pretrained(
        pretrained_model_name_or_path=model,
        scheduler=scheduler,
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        vae=vae,
        unet=unet,
    )

    # Swap in the DPM-Solver multistep scheduler, keeping the original config.
    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)

    # Gradient checkpointing only helps training; disable it for inference.
    unet._set_gradient_checkpointing(value=False)
    handle_memory_attention(xformers, sdp, unet)
    # Decode the latents in slices to reduce peak VRAM usage.
    vae.enable_slicing()

    return pipeline
26+
27+
@torch.inference_mode()
def inference(
    model,
    prompt,
    batch_size=1,
    num_frames=16,
    width=256,
    height=256,
    num_steps=50,
    guidance_scale=9,
    device="cuda",
    xformers=False,
    sdp=False,
):
    """Generate a batch of videos for a single prompt.

    Args:
        model: Path or hub id of the pretrained model directory.
        prompt: Text prompt, repeated ``batch_size`` times.
        batch_size: Number of videos to sample at once.
        num_frames: Frames per generated video.
        width: Frame width in pixels.
        height: Frame height in pixels.
        num_steps: Number of denoising steps.
        guidance_scale: Classifier-free guidance weight.
        device: Target device string (e.g. ``"cuda"`` or ``"cuda:0"``).
        xformers: Enable xformers memory-efficient attention.
        sdp: Enable PyTorch scaled-dot-product attention.

    Returns:
        The pipeline's ``frames`` output (``output_type="pt"``), one entry
        per batch element.
    """
    # Model loading does not need autocast; only run sampling under it.
    pipeline = initialize_pipeline(model, device, xformers, sdp)

    # torch.autocast expects a bare device *type* ("cuda"/"cpu"), not a
    # device string like "cuda:0" — normalize via torch.device.
    device_type = torch.device(device).type

    with torch.autocast(device_type, dtype=torch.half):
        videos = pipeline(
            prompt=[prompt] * batch_size,
            width=width,
            height=height,
            num_frames=num_frames,
            num_inference_steps=num_steps,
            guidance_scale=guidance_scale,
            output_type="pt",
        ).frames

    return videos
56+
57+
if __name__ == "__main__":
    import os

    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model", type=str, required=True)
    parser.add_argument("-p", "--prompt", type=str, required=True)
    parser.add_argument("-o", "--output-dir", type=str, default="./output")
    parser.add_argument("-B", "--batch-size", type=int, default=1)
    parser.add_argument("-T", "--num-frames", type=int, default=16)
    parser.add_argument("-W", "--width", type=int, default=256)
    parser.add_argument("-H", "--height", type=int, default=256)
    parser.add_argument("-s", "--num-steps", type=int, default=50)
    parser.add_argument("-g", "--guidance-scale", type=float, default=9)
    parser.add_argument("-f", "--fps", type=int, default=8)
    parser.add_argument("-d", "--device", type=str, default="cuda")
    parser.add_argument("-x", "--xformers", action="store_true")
    parser.add_argument("-S", "--sdp", action="store_true")
    args = vars(parser.parse_args())

    # Pop the args that inference() does not accept; keep "prompt" in the
    # dict (it is forwarded to inference) but grab it for the filename too.
    output_dir = args.pop("output_dir")
    prompt = args.get("prompt")
    fps = args.pop("fps")

    # Create the destination up front — the default ./output does not
    # exist on a fresh checkout and writing into it would fail.
    os.makedirs(output_dir, exist_ok=True)

    videos = inference(**args)

    for video in videos:
        # Map from [-1, 1] to uint8 [0, 255] and move channels last for the
        # video writer. Assumes the pipeline returns channels-first frames
        # (C, F, H, W) — TODO confirm against the pipeline's output layout.
        video = video.permute(1, 2, 3, 0).clamp(-1, 1).add(1).mul(127.5).byte().cpu().numpy()
        # NOTE(review): prompt text goes straight into the filename; a prompt
        # containing os-separator characters would still break this path.
        out_file = f"{output_dir}/{prompt} {str(uuid4())[:8]}.mp4"
        export_to_video(video, out_file, fps)

0 commit comments

Comments
 (0)