This repository was archived by the owner on Dec 14, 2023. It is now read-only.

Commit aa30ed4
add watermark inpainter
1 parent 6ee7ea3

4 files changed: 379 additions & 4 deletions

.gitignore

Lines changed: 4 additions & 0 deletions

@@ -1,3 +1,7 @@
+output/
+models/lama.ckpt
+.vscode/
+models/model_scope_diffusers/
 text-to-video-ms-1.7b/
 
 # Byte-compiled / optimized / DLL files

inference.py

Lines changed: 25 additions & 4 deletions

@@ -1,14 +1,22 @@
+import os
 import argparse
+import warnings
 from uuid import uuid4
 
 import torch
 from diffusers import DPMSolverMultistepScheduler, TextToVideoSDPipeline
+from einops import rearrange
 
 from train import export_to_video, handle_memory_attention, load_primary_models
+from utils.lama import inpaint_watermark
 
 
 def initialize_pipeline(model, device="cuda", xformers=False, sdp=False):
-    scheduler, tokenizer, text_encoder, vae, unet = load_primary_models(model)
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+
+        scheduler, tokenizer, text_encoder, vae, unet = load_primary_models(model)
+
     pipeline = TextToVideoSDPipeline.from_pretrained(
         pretrained_model_name_or_path=model,
         scheduler=scheduler,
@@ -69,15 +77,28 @@ def inference(
     parser.add_argument("-d", "--device", type=str, default="cuda")
     parser.add_argument("-x", "--xformers", action="store_true")
     parser.add_argument("-S", "--sdp", action="store_true")
+    parser.add_argument("-rw", "--remove-watermark", action="store_true")
     args = vars(parser.parse_args())
 
     output_dir = args.pop("output_dir")
     prompt = args.get("prompt")
     fps = args.pop("fps")
+    remove_watermark = args.pop("remove_watermark")
 
     videos = inference(**args)
 
+    os.makedirs(output_dir, exist_ok=True)
+
     for video in videos:
-        video = video.permute(1, 2, 3, 0).clamp(-1, 1).add(1).mul(127.5).byte().cpu().numpy()
-        out_file = f"{output_dir}/{prompt} {str(uuid4())[:8]}.mp4"
-        export_to_video(video, out_file, fps)
+
+        if remove_watermark:
+            video = rearrange(video, "c f h w -> f c h w").add(1).div(2)
+            video = inpaint_watermark(video)
+            video = rearrange(video, "f c h w -> f h w c").clamp(0, 1).mul(255)
+
+        else:
+            video = rearrange(video, "c f h w -> f h w c").clamp(-1, 1).add(1).mul(127.5)
+
+        video = video.byte().cpu().numpy()
+
+        export_to_video(video, f"{output_dir}/{prompt} {str(uuid4())[:8]}.mp4", fps)
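
Note on the final loop: both branches end at the same uint8 range, they just normalize at different points. The watermark path (enabled via --remove-watermark) hands inpaint_watermark a [0, 1] float batch in frame-first "f c h w" layout and rescales to [0, 255] afterwards; the plain path maps [-1, 1] straight to [0, 255]. A minimal sketch of the two conversions, with a made-up toy tensor size rather than anything from the repository:

import torch
from einops import rearrange

# Toy stand-in for a decoded video in the pipeline's "c f h w" layout,
# values in [-1, 1]: 3 channels, 16 frames, 32x32 pixels (made-up sizes).
video = torch.rand(3, 16, 32, 32) * 2 - 1

# Watermark path: frame-first layout, [0, 1] floats for the inpainter...
frames = rearrange(video, "c f h w -> f c h w").add(1).div(2)
# ...inpaint_watermark(frames) would run here in inference.py...
out_inpaint = rearrange(frames, "f c h w -> f h w c").clamp(0, 1).mul(255)

# Plain path: map [-1, 1] directly to [0, 255] in one step.
out_plain = rearrange(video, "c f h w -> f h w c").clamp(-1, 1).add(1).mul(127.5)

# With the inpainting call skipped, the two normalizations agree.
assert torch.allclose(out_inpaint, out_plain, atol=1e-4)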

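The new utils/lama.py itself is not shown in this excerpt, so the exact implementation of inpaint_watermark is unknown. From the import and the models/lama.ckpt entry added to .gitignore, a plausible shape, purely as an assumption, is: load a LaMa inpainting checkpoint, mask the watermark region of each frame, and inpaint the frame batch. A hypothetical sketch under those assumptions; the checkpoint format, mask coordinates, and call signature below are guesses, not the repository's code:

import torch

def inpaint_watermark_sketch(frames, ckpt_path="models/lama.ckpt"):
    """Hypothetical stand-in for utils.lama.inpaint_watermark.

    frames: float tensor in [0, 1], shape (f, c, h, w).
    Returns a tensor of the same shape with the masked region filled in.
    """
    # Assumption: the checkpoint is a TorchScript-exported LaMa model that
    # maps (image batch, mask batch) -> inpainted image batch.
    model = torch.jit.load(ckpt_path).eval()

    f, c, h, w = frames.shape
    # Assumption: a fixed rectangle over the watermark corner of every frame.
    mask = torch.zeros(f, 1, h, w)
    mask[:, :, int(h * 0.85):, int(w * 0.70):] = 1.0

    with torch.no_grad():
        return model(frames, mask)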