|
| 1 | +import os |
1 | 2 | import argparse |
| 3 | +import warnings |
2 | 4 | from uuid import uuid4 |
3 | 5 |
|
4 | 6 | import torch |
5 | 7 | from diffusers import DPMSolverMultistepScheduler, TextToVideoSDPipeline |
| 8 | +from einops import rearrange |
6 | 9 |
|
7 | 10 | from train import export_to_video, handle_memory_attention, load_primary_models |
| 11 | +from utils.lama import inpaint_watermark |
8 | 12 |
|
9 | 13 |
|
10 | 14 | def initialize_pipeline(model, device="cuda", xformers=False, sdp=False): |
11 | | - scheduler, tokenizer, text_encoder, vae, unet = load_primary_models(model) |
| 15 | + with warnings.catch_warnings(): |
| 16 | + warnings.simplefilter("ignore") |
| 17 | + |
| 18 | + scheduler, tokenizer, text_encoder, vae, unet = load_primary_models(model) |
| 19 | + |
12 | 20 | pipeline = TextToVideoSDPipeline.from_pretrained( |
13 | 21 | pretrained_model_name_or_path=model, |
14 | 22 | scheduler=scheduler, |
@@ -69,15 +77,28 @@ def inference( |
69 | 77 | parser.add_argument("-d", "--device", type=str, default="cuda") |
70 | 78 | parser.add_argument("-x", "--xformers", action="store_true") |
71 | 79 | parser.add_argument("-S", "--sdp", action="store_true") |
| 80 | + parser.add_argument("-rw", "--remove-watermark", action="store_true") |
72 | 81 | args = vars(parser.parse_args()) |
73 | 82 |
|
74 | 83 | output_dir = args.pop("output_dir") |
75 | 84 | prompt = args.get("prompt") |
76 | 85 | fps = args.pop("fps") |
| 86 | + remove_watermark = args.pop("remove_watermark") |
77 | 87 |
|
78 | 88 | videos = inference(**args) |
79 | 89 |
|
| 90 | + os.makedirs(output_dir, exist_ok=True) |
| 91 | + |
80 | 92 | for video in videos: |
81 | | - video = video.permute(1, 2, 3, 0).clamp(-1, 1).add(1).mul(127.5).byte().cpu().numpy() |
82 | | - out_file = f"{output_dir}/{prompt} {str(uuid4())[:8]}.mp4" |
83 | | - export_to_video(video, out_file, fps) |
| 93 | + |
| 94 | + if remove_watermark: |
| 95 | + video = rearrange(video, "c f h w -> f c h w").add(1).div(2) |
| 96 | + video = inpaint_watermark(video) |
| 97 | + video = rearrange(video, "f c h w -> f h w c").clamp(0, 1).mul(255) |
| 98 | + |
| 99 | + else: |
| 100 | + video = rearrange(video, "c f h w -> f h w c").clamp(-1, 1).add(1).mul(127.5) |
| 101 | + |
| 102 | + video = video.byte().cpu().numpy() |
| 103 | + |
| 104 | + export_to_video(video, f"{output_dir}/{prompt} {str(uuid4())[:8]}.mp4", fps) |
0 commit comments