
[FEATURE] [WAN2.2] Add support to train LoRA for WanVacePipeline #12683

@SlimRG

Description


I can't train a LoRA (aka PEFT) model using only diffusers.
I know about DiffSync, but it has different code and some broken capabilities.

I have a dataset. Each folder contains 4 files:
orig.mp4 - original video (1280*720)
cropped.mp4 - cropped video with gray frames
mask.png - mask for outpainting
prompt.txt - positive prompt for outpainting

For easy navigation, I created two CSV files (train.csv and validate.csv), as in DiffSync:

video,vace_video,vace_video_mask,prompt
video-000015-1285-1-1/orig.mp4,video-000015-1285-1-1/cropped.mp4,video-000015-1285-1-1/mask.png,"Outpaint. Dim, warm ambient lighting from out-of-focus lanterns illuminates a dark interior space with indistinct wooden structures; on the left, a figure partially obscures the view, while on the right, another person moves closer before exiting frame."
...

So, I want to create a pipeline:

def init_pipeline() -> WanVACEPipeline:
    """
    Init WanVACEPipeline + VAE.
    """
    vae = AutoencoderKLWan.from_pretrained(
        MODEL_ID,
        subfolder="vae",
        torch_dtype=torch.float32,
    )

    pipe: WanVACEPipeline = WanVACEPipeline.from_pretrained(
        MODEL_ID,  # "linoyts/Wan2.2-VACE-Fun-14B-diffusers"
        vae=vae,
        torch_dtype=BASE_DTYPE, # bfloat16
    )

    pipe.scheduler = UniPCMultistepScheduler.from_config(
        pipe.scheduler.config, flow_shift=FLOW_SHIFT # 5.0
    )

    # flash-attention
    if hasattr(pipe.transformer, "set_attention_backend"):
        pipe.transformer.set_attention_backend("flash")

    extra_models_to_quantize = []

    if hasattr(pipe, "transformer_2") and pipe.transformer_2 is not None:
        if hasattr(pipe.transformer_2, "set_attention_backend"):
            pipe.transformer_2.set_attention_backend("flash")
        # We train the first transformer, so the second one can be quantized
        extra_models_to_quantize.append("transformer_2")

    # Freeze all
    if hasattr(pipe, "components"):
        for name, module in pipe.components.items():
            if isinstance(module, nn.Module):
                for p in module.parameters():
                    p.requires_grad = False
    else:
        # fallback if there is no components dict
        for attr in ["vae", "transformer", "transformer_2", "text_encoder"]:
            if hasattr(pipe, attr):
                m = getattr(pipe, attr)
                if isinstance(m, nn.Module):
                    for p in m.parameters():
                        p.requires_grad = False

    pipe.to(DEVICE)

    # mmgp (offload + quant)
    if ENABLE_MMGP and HAS_MMGP:
        offload.profile(
            pipe,
            MMGP_PROFILE,
            quantizeTransformer=False,        # do not quantize the first transformer
            extraModelsToQuantize=extra_models_to_quantize,
        )
    else:
        print("[mmgp] Не используется (либо не установлен, либо отключен).")

    # Patch __call__ to remove no_grad
    _patch_pipeline_call_for_grad(pipe)

    return pipe

LoRA config:

def inject_lora_into_transformer(
    transformer: nn.Module,
    r: int,
    alpha: int,
    dropout: float,
) -> nn.Module:
    replaced = 0

    def _recursive(module: nn.Module) -> None:
        nonlocal replaced
        for name, child in list(module.named_children()):
            if isinstance(child, nn.Linear):
                setattr(module, name, LoRALinear(child, r=r, alpha=alpha, dropout=dropout))
                replaced += 1
            else:
                _recursive(child)

    _recursive(transformer)

    if replaced == 0:
        print("[LoRA] WARNING: не найдено ни одного nn.Linear в transformer для LoRA.")
    else:
        print(f"[LoRA] Обёрнуто Linear-слоёв: {replaced}")

    transformer.to(DEVICE)

    return transformer


def get_lora_parameters(transformer: nn.Module) -> List[torch.nn.Parameter]:
    params: List[torch.nn.Parameter] = []
    for name, p in transformer.named_parameters():
        if "lora_A" in name or "lora_B" in name:
            params.append(p)

    if not params:
        raise RuntimeError("Не найдено trainable LoRA-параметров (lora_A/lora_B).")

    total = sum(p.numel() for p in params)
    print(f"[LoRA] Всего LoRA-параметров: {total / 1e6:.2f}M")

    return params

When I try to train, I get an error from the pipeline (WanVACEPipeline):

Traceback (most recent call last):
  File "d:\Experiments\Train_IN2OUT\8__trainLoRA.py", line 717, in <module>
    train()
    ~~~~~^^
  File "d:\Experiments\Train_IN2OUT\8__trainLoRA.py", line 629, in train
    pipe.train()
    ^^^^^^^^^^
  File "C:\Users\BBCCA\AppData\Local\Programs\Python\Python313\Lib\site-packages\diffusers\configuration_utils.py", line 144, in __getattr__
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: 'WanVACEPipeline' object has no attribute 'train'

So, can you add a train() method as in PyTorch, or give an example of how to train a LoRA?
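
This is roughly the kind of example I am looking for, based on the other diffusers LoRA training scripts: add the adapter to pipe.transformer via PEFT and train the transformer directly on noisy latents instead of going through pipe.__call__ (which is wrapped in no_grad). A rough sketch of what I mean is below; the target_modules names, the encode_* / sample_sigmas helpers and the timestep convention are placeholders/guesses on my side, not the real API:

import torch
import torch.nn.functional as F
from peft import LoraConfig

# Pipelines are not nn.Module, so pipe.train() does not exist;
# only the transformer itself is switched to train mode.
pipe.transformer.requires_grad_(False)
pipe.transformer.add_adapter(
    LoraConfig(
        r=LORA_RANK,
        lora_alpha=LORA_ALPHA,
        init_lora_weights="gaussian",
        target_modules=["to_q", "to_k", "to_v", "to_out.0"],  # guessed projection names
    )
)
pipe.transformer.train()

# After add_adapter() only the LoRA parameters are trainable
lora_params = [p for p in pipe.transformer.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(lora_params, lr=LEARNING_RATE)

for batch in train_loader:
    # encode_gt_latents / encode_prompts / sample_sigmas are placeholders for
    # whatever WanVACEPipeline does internally (VAE encoding, prompt encoding,
    # timestep sampling); the VACE conditioning inputs are omitted here.
    latents = encode_gt_latents(pipe, batch)                 # [B, C, T', H', W']
    prompt_embeds = encode_prompts(pipe, batch)
    noise = torch.randn_like(latents)
    sigmas = sample_sigmas(latents.shape[0], device=latents.device)  # [B] in (0, 1)
    sigmas_b = sigmas.view(-1, 1, 1, 1, 1)

    # flow-matching style interpolation and target (if that is what Wan 2.2 uses)
    noisy_latents = (1.0 - sigmas_b) * latents + sigmas_b * noise
    target = noise - latents

    model_pred = pipe.transformer(
        hidden_states=noisy_latents,
        timestep=sigmas * 1000,  # not sure about the exact timestep convention
        encoder_hidden_states=prompt_embeds,
        return_dict=False,
    )[0]

    loss = F.mse_loss(model_pred.float(), target.float())
    loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)

If something like this is the intended path, even a short official train_*_lora example for WanVACEPipeline in examples/ would be enough for me.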




My full code:

# -*- coding: utf-8 -*-
from __future__ import annotations

import os
import csv
import math
import random
import time
from pathlib import Path
from typing import Dict, List, Any, Tuple

import numpy as np
import av
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter

from diffusers import AutoencoderKLWan, WanVACEPipeline
from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler

# mmgp (optional)
try:
    from mmgp import offload, profile_type

    HAS_MMGP = True
except Exception:
    HAS_MMGP = False


# ==========================
#        CONFIG
# ==========================

# Dataset (as in the CSV script)
DATA_ROOT = Path(r"D:/Experiments/Train_IN2OUT/4__dataset")
TRAIN_CSV_PATH = DATA_ROOT / "metadata_vace_train.csv"
VAL_CSV_PATH = DATA_ROOT / "metadata_vace_val.csv"

# Wan2.2 VACE Fun model
MODEL_ID = "linoyts/Wan2.2-VACE-Fun-14B-diffusers"

# Video parameters (everything is already 1280x720x17)
FRAME_H = 720
FRAME_W = 1280
NUM_FRAMES = 17

# Training
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BASE_DTYPE = torch.bfloat16  # Wan 2.x likes bfloat16
USE_AMP = True               # autocast
BATCH_SIZE = 1               # 14B model, so better keep it at 1
NUM_EPOCHS = 2               # tune to your needs
MAX_TRAIN_STEPS = None       # can be left as None, then it is computed automatically
NUM_INFERENCE_STEPS_TRAIN = 12  # diffusion steps during training (fewer for speed)
NUM_INFERENCE_STEPS_VAL = 16    # validation can be a bit slower

LEARNING_RATE = 1e-4
WEIGHT_DECAY = 0.01

# LoRA hyperparameters
LORA_RANK = 32
LORA_ALPHA = 64
LORA_DROPOUT = 0.05

# Logging and saving
LOG_DIR = Path(r"D:/Experiments/Train_IN2OUT/logs_wan_vace_lora")
OUTPUT_LORA_DIR = Path(r"D:/Experiments/Train_IN2OUT/wan_vace_lora_out")
SAVE_EVERY_STEPS = 2000
VAL_EVERY_STEPS = 1000
MAX_VAL_BATCHES = 64  # how many validation samples to run each time

# mmgp
ENABLE_MMGP = True  # ENABLE this if mmgp is installed and you want offload
MMGP_PROFILE = profile_type.HighRAM_HighVRAM if HAS_MMGP else None

# Misc
SEED = 42
FLOW_SHIFT = 5.0  # for UniPCMultistepScheduler
GUIDANCE_SCALE_TRAIN = 5.0
GUIDANCE_SCALE_VAL = 5.0

os.makedirs(LOG_DIR, exist_ok=True)
os.makedirs(OUTPUT_LORA_DIR, exist_ok=True)


# ==========================
#   Seed fixing
# ==========================

def set_seed(seed: int) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


set_seed(SEED)


# ==========================
#   Dataset / DataLoader
# ==========================

class VACELoraSample:
    """
    One dataset record:
        gt_path       -> video (orig.mp4)
        cond_path     -> vace_video (cropped.mp4)
        mask_path     -> vace_video_mask (mask.png)
        prompt        -> positive prompt
    """

    __slots__ = ("gt_path", "cond_path", "mask_path", "prompt")

    def __init__(self, gt_path: str, cond_path: str, mask_path: str, prompt: str) -> None:
        self.gt_path = gt_path
        self.cond_path = cond_path
        self.mask_path = mask_path
        self.prompt = prompt


class VACELoraDataset(Dataset):
    def __init__(self, csv_path: Path, root_dir: Path) -> None:
        super().__init__()
        self.root_dir = root_dir
        self.samples: List[VACELoraSample] = []

        with csv_path.open("r", encoding="utf-8", newline="") as f:
            reader = csv.DictReader(f)
            for row in reader:
                # columns: video, vace_video, vace_video_mask, prompt
                video_rel = row["video"]
                cond_rel = row["vace_video"]
                mask_rel = row["vace_video_mask"]
                prompt = row["prompt"]

                gt_path = (root_dir / video_rel).as_posix()
                cond_path = (root_dir / cond_rel).as_posix()
                mask_path = (root_dir / mask_rel).as_posix()

                self.samples.append(VACELoraSample(gt_path, cond_path, mask_path, prompt))

        if not self.samples:
            raise RuntimeError(f"Пустой датасет: {csv_path}")

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        s = self.samples[idx]
        # For batch_size=1 we don't need a custom collate function
        return {
            "gt_path": s.gt_path,
            "cond_path": s.cond_path,
            "mask_path": s.mask_path,
            "prompt": s.prompt,
        }


def make_dataloaders() -> tuple[DataLoader, DataLoader]:
    train_ds = VACELoraDataset(TRAIN_CSV_PATH, DATA_ROOT)
    val_ds = VACELoraDataset(VAL_CSV_PATH, DATA_ROOT)

    train_loader = DataLoader(
        train_ds,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
        drop_last=True,
    )

    val_loader = DataLoader(
        val_ds,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=2,
        pin_memory=True,
        drop_last=False,
    )

    return train_loader, val_loader


# ==========================
#   Video loading (PyAV)
# ==========================

def load_video_pyav(
    path: str,
    num_frames: int = NUM_FRAMES,
) -> Tuple[List[Image.Image], torch.Tensor]:
    """
    Reads a video via PyAV.
    Returns:
        - a list of PIL.Image (RGB) of length num_frames
        - a torch.Tensor of shape [1, num_frames, 3, H, W] in range [0, 1]
    """
    container = av.open(path)
    stream = container.streams.video[0]

    frames: List[Image.Image] = []
    for frame in container.decode(stream):
        img = frame.to_rgb().to_ndarray()  # H, W, 3, uint8
        pil = Image.fromarray(img)
        frames.append(pil)
    container.close()

    if len(frames) == 0:
        raise RuntimeError(f"Видео без кадров: {path}")

    # We have 17 frames, but normalize the frame count just in case
    if len(frames) >= num_frames:
        # Take a centered window
        start = (len(frames) - num_frames) // 2
        frames = frames[start: start + num_frames]
    else:
        # Pad with the last frame
        last = frames[-1]
        for _ in range(num_frames - len(frames)):
            frames.append(last.copy())

    # Assemble a tensor [1, T, 3, H, W] in [0, 1]
    arr = np.stack(
        [np.array(f, dtype=np.float32) / 255.0 for f in frames],
        axis=0,
    )  # [T, H, W, 3]
    arr = torch.from_numpy(arr).permute(0, 3, 1, 2).unsqueeze(0)  # [1, T, 3, H, W]

    return frames, arr


# ==========================
#   LoRA wrapper for Linear
# ==========================

class LoRALinear(nn.Module):
    """
    Simple LoRA add-on for nn.Linear:
        y = W x + scale * B(A x)

    The base Linear is frozen; only A/B are trained.
    """

    def __init__(
        self,
        linear: nn.Linear,
        r: int = 32,
        alpha: int = 64,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        if r <= 0:
            raise ValueError("LoRA rank r должен быть > 0")

        self.linear = linear
        self.r = r
        self.lora_A = nn.Parameter(torch.zeros(r, linear.in_features))
        self.lora_B = nn.Parameter(torch.zeros(linear.out_features, r))
        self.scaling = alpha / r
        self.dropout = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity()

        # Initialization
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B)

        # Freeze the base Linear
        self.linear.weight.requires_grad = False
        if self.linear.bias is not None:
            self.linear.bias.requires_grad = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Main path
        result = self.linear(x)

        # LoRA contribution
        lora_x = self.dropout(x)
        # [*, in] -> [*, r]
        lora_x = F.linear(lora_x, self.lora_A)
        # [*, r] -> [*, out]
        lora_x = F.linear(lora_x, self.lora_B) * self.scaling

        return result + lora_x


def inject_lora_into_transformer(
    transformer: nn.Module,
    r: int,
    alpha: int,
    dropout: float,
) -> nn.Module:
    """
    Recursively wrap ALL nn.Linear layers in the transformer with LoRALinear.
    """
    replaced = 0

    def _recursive(module: nn.Module) -> None:
        nonlocal replaced
        for name, child in list(module.named_children()):
            if isinstance(child, nn.Linear):
                setattr(module, name, LoRALinear(child, r=r, alpha=alpha, dropout=dropout))
                replaced += 1
            else:
                _recursive(child)

    _recursive(transformer)

    if replaced == 0:
        print("[LoRA] WARNING: no nn.Linear layers found in the transformer for LoRA.")
    else:
        print(f"[LoRA] Wrapped Linear layers: {replaced}")

    # Move the transformer (including LoRA parameters) to the target device
    transformer.to(DEVICE)

    return transformer


def get_lora_parameters(transformer: nn.Module) -> List[torch.nn.Parameter]:
    """
    Collect only the LoRA parameters (for the optimizer).
    """
    params: List[torch.nn.Parameter] = []
    for name, p in transformer.named_parameters():
        if "lora_A" in name or "lora_B" in name:
            params.append(p)

    if not params:
        raise RuntimeError("No trainable LoRA parameters found (lora_A/lora_B).")

    total = sum(p.numel() for p in params)
    print(f"[LoRA] Total LoRA parameters: {total / 1e6:.2f}M")

    return params


def save_lora_checkpoint(
    pipe: WanVACEPipeline,
    step: int,
    output_dir: Path,
) -> None:
    """
    Save the LoRA adapter:
      - only the state_dict of the transformer's LoRALinear layers is saved.
    A small converter script to any convenient format can be written later.
    """
    ckpt_dir = output_dir / f"step_{step:07d}"
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    # Collect only the transformer's LoRA parameters
    lora_state = {
        k: v.cpu()
        for k, v in pipe.transformer.state_dict().items()
        if "lora_A" in k or "lora_B" in k
    }

    ckpt_path = ckpt_dir / "lora_transformer.pt"
    print(f"[SAVE] Сохраняем LoRA в {ckpt_path} ...")
    torch.save(lora_state, ckpt_path)
    print("[SAVE] Готово.")


# ==========================
#   Pipeline initialization
# ==========================

def _patch_pipeline_call_for_grad(pipe: WanVACEPipeline) -> None:
    """
    Diffusers pipelines' __call__ is usually decorated with @torch.no_grad().
    Here we carefully extract the original __call__ without no_grad, if possible.
    """
    call_fn = pipe.__call__
    orig_call = getattr(call_fn, "__wrapped__", None)

    if orig_call is not None:
        # bind as a bound method to the pipe object
        pipe._call_with_grad = orig_call.__get__(pipe, type(pipe))
        print("[PIPE] Found __wrapped__; will use pipe._call_with_grad() without no_grad.")
    else:
        # fallback: use it as is (in case the decorator is absent)
        pipe._call_with_grad = pipe.__call__
        print("[PIPE] __wrapped__ not found; using pipe.__call__ as is.")


def init_pipeline() -> WanVACEPipeline:
    """
    Initialize WanVACEPipeline + VAE.
    """
    print("Загрузка VAE (bfloat16)...")
    vae = AutoencoderKLWan.from_pretrained(
        MODEL_ID,
        subfolder="vae",
        torch_dtype=torch.bfloat16,
    )

    print("Загрузка WanVACEPipeline (bfloat16)...")
    pipe: WanVACEPipeline = WanVACEPipeline.from_pretrained(
        MODEL_ID,
        vae=vae,
        torch_dtype=BASE_DTYPE,
    )

    # UniPC with flow_shift, as usual
    pipe.scheduler = UniPCMultistepScheduler.from_config(
        pipe.scheduler.config, flow_shift=FLOW_SHIFT
    )

    # Enable flash attention if supported
    if hasattr(pipe.transformer, "set_attention_backend"):
        pipe.transformer.set_attention_backend("flash")

    extra_models_to_quantize = []

    if hasattr(pipe, "transformer_2") and pipe.transformer_2 is not None:
        if hasattr(pipe.transformer_2, "set_attention_backend"):
            pipe.transformer_2.set_attention_backend("flash")
        # no LoRA goes onto the second transformer → it can be quantized/offloaded
        extra_models_to_quantize.append("transformer_2")

    # Freeze ALL pipeline modules (vae, transformer, transformer_2, text_encoder, ...)
    if hasattr(pipe, "components"):
        for name, module in pipe.components.items():
            if isinstance(module, nn.Module):
                for p in module.parameters():
                    p.requires_grad = False
    else:
        # fallback if there is no components dict (just in case)
        for attr in ["vae", "transformer", "transformer_2", "text_encoder"]:
            if hasattr(pipe, attr):
                m = getattr(pipe, attr)
                if isinstance(m, nn.Module):
                    for p in m.parameters():
                        p.requires_grad = False

    # Move to device (base weights)
    pipe.to(DEVICE)

    # mmgp (offload + quantization only for the second transformer)
    if ENABLE_MMGP and HAS_MMGP:
        print("[mmgp] Enabling offload.profile (quantize only transformer_2).")
        offload.profile(
            pipe,
            MMGP_PROFILE,
            quantizeTransformer=False,        # do NOT quantize the first transformer
            extraModelsToQuantize=extra_models_to_quantize,
        )
    else:
        print("[mmgp] Not used (either not installed or disabled).")

    # Patch __call__ to remove no_grad
    _patch_pipeline_call_for_grad(pipe)

    return pipe


# ==========================
#   Generation helper
# ==========================

def run_vace_forward(
    pipe: WanVACEPipeline,
    cond_frames: List[Image.Image],
    mask_img: Image.Image,
    prompt: str,
    num_inference_steps: int,
    guidance_scale: float,
) -> torch.Tensor:
    """
    One WanVACEPipeline run.
    Returns: a video tensor [1, T, 3, H, W] in range [0, 1] (output_type="pt").
    IMPORTANT: uses pipe._call_with_grad so that autograd works.
    """
    # A single mask for the whole video → replicate it per frame
    mask_list = [mask_img] * len(cond_frames)

    out = pipe._call_with_grad(
        video=cond_frames,
        mask=mask_list,
        prompt=prompt,
        height=FRAME_H,
        width=FRAME_W,
        num_frames=len(cond_frames),
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        output_type="pt",   # критично для градиента
        return_dict=True,
    )

    frames = out.frames  # usually torch.Tensor or np.ndarray

    if isinstance(frames, torch.Tensor):
        video = frames  # expected [B, T, C, H, W] or [T, C, H, W]
    else:
        # np.ndarray or list → convert to tensor
        frames_np = np.array(frames)
        video = torch.from_numpy(frames_np)

    if video.dim() == 4:
        # [T, C, H, W] → [1, T, C, H, W]
        video = video.unsqueeze(0)

    # Bring to [1, T, 3, H, W]
    if video.shape[2] != 3:
        raise RuntimeError(f"Expected 3 channels, got: {video.shape}")

    # For pt output diffusers usually returns [0, 1]; leave as is.
    return video


# ==========================
#     Train / Val loops
# ==========================

def validate(
    pipe: WanVACEPipeline,
    val_loader: DataLoader,
    writer: SummaryWriter,
    global_step: int,
) -> float:
    pipe.eval()
    total_loss = 0.0
    count = 0

    with torch.no_grad():  # validation without gradients
        for batch_idx, batch in enumerate(val_loader):
            if batch_idx >= MAX_VAL_BATCHES:
                break

            gt_path = batch["gt_path"]
            cond_path = batch["cond_path"]
            mask_path = batch["mask_path"]
            prompt = batch["prompt"]

            # For batch_size=1, convert to plain strings
            if isinstance(gt_path, list):
                gt_path = gt_path[0]
                cond_path = cond_path[0]
                mask_path = mask_path[0]
                prompt = prompt[0]

            # Load GT and condition videos
            cond_frames, _ = load_video_pyav(cond_path, num_frames=NUM_FRAMES)
            _, gt_tensor = load_video_pyav(gt_path, num_frames=NUM_FRAMES)
            gt_tensor = gt_tensor.to(DEVICE, dtype=torch.float32)

            mask_img = Image.open(mask_path).convert("L")

            with torch.cuda.amp.autocast(enabled=USE_AMP, dtype=BASE_DTYPE):
                pred_video = run_vace_forward(
                    pipe,
                    cond_frames,
                    mask_img,
                    prompt,
                    num_inference_steps=NUM_INFERENCE_STEPS_VAL,
                    guidance_scale=GUIDANCE_SCALE_VAL,
                )
                pred_video = pred_video.to(DEVICE, dtype=torch.float32)

                # Align T
                min_frames = min(pred_video.shape[1], gt_tensor.shape[1])
                pred_use = pred_video[:, :min_frames]
                gt_use = gt_tensor[:, :min_frames]

                loss = F.l1_loss(pred_use, gt_use)

            total_loss += loss.item()
            count += 1

            # For the first batch, log an image pair to TensorBoard
            if batch_idx == 0:
                # frame 0 (gt vs pred)
                gt_frame = gt_use[0, 0]  # [3,H,W]
                pred_frame = pred_use[0, 0]
                # Build a 1x2 grid
                grid = torch.stack([gt_frame, pred_frame], dim=0)
                writer.add_images("val/gt_pred_frame0", grid, global_step)

    avg_loss = total_loss / max(1, count)
    writer.add_scalar("val/loss", avg_loss, global_step)
    print(f"[VAL] step={global_step} loss={avg_loss:.6f} (по {count} батчам)")
    return avg_loss


def train() -> None:
    train_loader, val_loader = make_dataloaders()

    pipe = init_pipeline()

    # Attach LoRA to the first transformer
    print("[LoRA] Injecting LoRA into pipe.transformer ...")
    pipe.transformer = inject_lora_into_transformer(
        pipe.transformer,
        r=LORA_RANK,
        alpha=LORA_ALPHA,
        dropout=LORA_DROPOUT,
    )

    # Collect trainable LoRA parameters
    lora_params = get_lora_parameters(pipe.transformer)

    optimizer = torch.optim.AdamW(
        lora_params,
        lr=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY,
        betas=(0.9, 0.999),
        eps=1e-8,
    )

    writer = SummaryWriter(LOG_DIR)

    # Compute the maximum number of steps if not specified
    if MAX_TRAIN_STEPS is not None:
        max_steps = MAX_TRAIN_STEPS
    else:
        max_steps = NUM_EPOCHS * math.ceil(len(train_loader.dataset) / BATCH_SIZE)

    print(f"Старт обучения: max_steps={max_steps}, epochs={NUM_EPOCHS}")

    global_step = 0
    start_time = time.time()

    scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)

    for epoch in range(NUM_EPOCHS):
        pipe.train()
        epoch_loss = 0.0
        for batch_idx, batch in enumerate(train_loader):
            if global_step >= max_steps:
                break

            gt_path = batch["gt_path"]
            cond_path = batch["cond_path"]
            mask_path = batch["mask_path"]
            prompt = batch["prompt"]

            # batch_size=1 → unpack
            if isinstance(gt_path, list):
                gt_path = gt_path[0]
                cond_path = cond_path[0]
                mask_path = mask_path[0]
                prompt = prompt[0]

            cond_frames, _ = load_video_pyav(cond_path, num_frames=NUM_FRAMES)
            _, gt_tensor = load_video_pyav(gt_path, num_frames=NUM_FRAMES)
            gt_tensor = gt_tensor.to(DEVICE, dtype=torch.float32)

            mask_img = Image.open(mask_path).convert("L")

            optimizer.zero_grad(set_to_none=True)

            with torch.cuda.amp.autocast(enabled=USE_AMP, dtype=BASE_DTYPE):
                pred_video = run_vace_forward(
                    pipe,
                    cond_frames,
                    mask_img,
                    prompt,
                    num_inference_steps=NUM_INFERENCE_STEPS_TRAIN,
                    guidance_scale=GUIDANCE_SCALE_TRAIN,
                )
                pred_video = pred_video.to(DEVICE, dtype=torch.float32)

                # Align by number of frames
                min_frames = min(pred_video.shape[1], gt_tensor.shape[1])
                pred_use = pred_video[:, :min_frames]
                gt_use = gt_tensor[:, :min_frames]

                loss = F.l1_loss(pred_use, gt_use)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            global_step += 1
            epoch_loss += loss.item()

            if global_step % 10 == 0:
                avg_loss = epoch_loss / (batch_idx + 1)
                elapsed = time.time() - start_time
                print(
                    f"[TRAIN] epoch={epoch+1} step={global_step}/{max_steps} "
                    f"batch_loss={loss.item():.6f} avg_epoch_loss={avg_loss:.6f} "
                    f"time={elapsed/60:.1f}min"
                )
                writer.add_scalar("train/loss", loss.item(), global_step)

            # Periodic validation
            if global_step % VAL_EVERY_STEPS == 0:
                _ = validate(pipe, val_loader, writer, global_step)
                pipe.train()

            # Periodic LoRA saving
            if global_step % SAVE_EVERY_STEPS == 0:
                save_lora_checkpoint(pipe, global_step, OUTPUT_LORA_DIR)

            if global_step >= max_steps:
                break

        # end of epoch
        avg_epoch_loss = epoch_loss / max(1, (batch_idx + 1))
        writer.add_scalar("train/epoch_loss", avg_epoch_loss, epoch + 1)
        print(f"=== Epoch {epoch+1}/{NUM_EPOCHS} finished. avg_loss={avg_epoch_loss:.6f} ===")

        # End-of-epoch validation
        _ = validate(pipe, val_loader, writer, global_step)

    # Final save
    save_lora_checkpoint(pipe, global_step, OUTPUT_LORA_DIR)
    writer.close()
    print("Обучение завершено.")


if __name__ == "__main__":
    train()
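
A related question: how should the trained adapter be saved so that pipe.load_lora_weights() can read it back later? My guess, assuming WanVACEPipeline exposes the same save_lora_weights helper as the other Wan pipelines and the adapter was added via PEFT's add_adapter as in the sketch above (not verified):

from peft.utils import get_peft_model_state_dict

# Sketch only: collect the PEFT LoRA state dict of the transformer and save it
# in the diffusers LoRA layout.
transformer_lora_layers = get_peft_model_state_dict(pipe.transformer)
WanVACEPipeline.save_lora_weights(
    OUTPUT_LORA_DIR,
    transformer_lora_layers=transformer_lora_layers,
)

# later, for inference:
# pipe.load_lora_weights(OUTPUT_LORA_DIR)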

@DN6 @a-r-r-o-w
