40 changes: 34 additions & 6 deletions diffsynth/diffusion/loss.py
@@ -3,19 +3,47 @@


def FlowMatchSFTLoss(pipe: BasePipeline, **inputs):
from xfuser.core.distributed import (
get_sequence_parallel_rank,
get_sp_group,
)

max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * len(pipe.scheduler.timesteps))
min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * len(pipe.scheduler.timesteps))

timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,))
timestep = pipe.scheduler.timesteps[timestep_id].to(dtype=pipe.torch_dtype, device=pipe.device)

noise = torch.randn_like(inputs["input_latents"])
input_shape = inputs["input_latents"].shape
input_dtype = inputs["input_latents"].dtype
input_device = inputs["input_latents"].device

# Random noise and timestep IDs are generated by SP local rank 0,
# and broadcast to all other ranks in the SP group.
# Alternative implementation:
    # use a consistent random seed for noise and timestep_id generation across the ranks
    # within each SP group (see the standalone sketch after this file's diff).
if pipe.sp_size > 1:
        sp_group = get_sp_group()
if get_sequence_parallel_rank() == 0:
timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,))
timestep = pipe.scheduler.timesteps[timestep_id].to(dtype=pipe.torch_dtype, device=pipe.device)
else:
timestep = torch.zeros(1, dtype=pipe.torch_dtype, device=pipe.device)
sp_group.broadcast(timestep, src=0)

if get_sequence_parallel_rank() == 0:
noise = torch.randn(input_shape, dtype=input_dtype, device=input_device)
else:
noise = torch.zeros(input_shape, dtype=input_dtype, device=input_device)
sp_group.broadcast(noise, src=0)
else:
timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,))
timestep = pipe.scheduler.timesteps[timestep_id].to(dtype=pipe.torch_dtype, device=pipe.device)
noise = torch.randn(input_shape, dtype=input_dtype, device=input_device)

inputs["latents"] = pipe.scheduler.add_noise(inputs["input_latents"], noise, timestep)
training_target = pipe.scheduler.training_target(inputs["input_latents"], noise, timestep)

models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
noise_pred = pipe.model_fn(**models, **inputs, timestep=timestep)

loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
loss = loss * pipe.scheduler.training_weight(timestep)
return loss
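The comment in FlowMatchSFTLoss above also mentions an alternative to broadcasting: seeding every rank of an SP group identically. A minimal sketch of that idea, using hypothetical helper and argument names that are not part of this PR:

import torch

def sample_noise_and_timestep_with_shared_seed(scheduler, latents, base_seed, sp_group_id,
                                               min_timestep_boundary, max_timestep_boundary):
    # Every rank in the same SP group derives the same seed, so randint/randn produce
    # identical values on all of its ranks without an explicit broadcast.
    # sp_group_id could be computed as global_rank // sp_size (assumed for illustration).
    generator = torch.Generator(device="cpu").manual_seed(base_seed + sp_group_id)
    timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,), generator=generator)
    timestep = scheduler.timesteps[timestep_id].to(dtype=latents.dtype, device=latents.device)
    noise = torch.randn(latents.shape, generator=generator, dtype=latents.dtype).to(latents.device)
    return timestep, noise

The trade-off versus the broadcast used in this PR is that correctness then depends on every rank agreeing on base_seed and on the order of generator calls, whereas an explicit broadcast is robust to such drift.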
1 change: 1 addition & 0 deletions diffsynth/diffusion/parsers.py
@@ -37,6 +37,7 @@ def add_training_config(parser: argparse.ArgumentParser):
parser.add_argument("--find_unused_parameters", default=False, action="store_true", help="Whether to find unused parameters in DDP.")
parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay.")
parser.add_argument("--task", type=str, default="sft", required=False, help="Task type.")
    parser.add_argument("--sp_size", type=int, default=1, help="Sequence parallel size. sp_size > 1 will initialize USP for sequence parallelism.")
return parser

def add_output_config(parser: argparse.ArgumentParser):
113 changes: 107 additions & 6 deletions diffsynth/diffusion/runner.py
@@ -3,7 +3,56 @@
from accelerate import Accelerator
from .training_module import DiffusionTrainingModule
from .logger import ModelLogger
import time
import logging

logger = logging.getLogger(__name__)

def build_dataloader(
accelerator: Accelerator,
dataset: torch.utils.data.Dataset,
num_workers: int = 1,
sp_size: int = 1,
):
if sp_size > 1:
        # When using sequence parallelism, the sampler must guarantee that every rank
        # within the same SP group receives the same sample from the dataloader.
        # (A standalone sketch of the rank -> (dp_rank, sp_rank) mapping follows this file's diff.)
if accelerator is not None:
world_size = accelerator.num_processes
rank = accelerator.process_index
else:
            raise ValueError("Accelerator is None.")

dp_size = world_size // sp_size
if dp_size * sp_size != world_size:
            raise ValueError(
                f"world_size={world_size}, sp_size={sp_size}; world_size must be divisible by sp_size."
)

dp_rank = rank // sp_size
sp_rank = rank % sp_size
        logger.info(f"accelerator.process_index={rank}, accelerator.num_processes={world_size}, "
                    f"sp_size={sp_size}, dp_size={dp_size}, dp_rank={dp_rank}, sp_rank={sp_rank}")
else:
if accelerator is not None:
dp_size = accelerator.num_processes
dp_rank = accelerator.process_index
else:
            raise ValueError("Accelerator is None.")
logger.info(f"dp_size={dp_size}, dp_rank={dp_rank}")

sampler = torch.utils.data.DistributedSampler(dataset=dataset, num_replicas=dp_size, rank=dp_rank)

dataloader_kwargs = dict(
dataset=dataset,
sampler=sampler,
num_workers=num_workers,
pin_memory=True,
collate_fn=lambda x: x[0],
)
dataloader = torch.utils.data.DataLoader(**dataloader_kwargs)

return dataloader

def launch_training_task(
accelerator: Accelerator,
@@ -15,6 +64,7 @@ def launch_training_task(
num_workers: int = 1,
save_steps: int = None,
num_epochs: int = 1,
sp_size: int = 1,
args = None,
):
if args is not None:
@@ -23,29 +73,80 @@
num_workers = args.dataset_num_workers
save_steps = args.save_steps
num_epochs = args.num_epochs

sp_size = args.sp_size

train_step = 0

optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers)

model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)
dataloader = build_dataloader(accelerator, dataset, num_workers, sp_size)
model, optimizer, scheduler = accelerator.prepare(model, optimizer, scheduler)

for epoch_id in range(num_epochs):
for data in tqdm(dataloader):
progress = tqdm(
dataloader,
disable=not accelerator.is_main_process,
desc=f"Epoch {epoch_id + 1}/{num_epochs}",
)

        for data in progress:
            if data is None:
                continue
            logger.info(f"[train] rank={accelerator.process_index}, step={train_step}, prompt: {data['prompt']}")

            iter_start = time.time()
            timing = {}

with accelerator.accumulate(model):
optimizer.zero_grad()

forward_start = time.time()
if dataset.load_from_cache:
loss = model({}, inputs=data)
else:
loss = model(data)
torch.cuda.synchronize()
timing["forward"] = time.time() - forward_start

backward_start = time.time()
accelerator.backward(loss)
torch.cuda.synchronize()
timing["backward"] = time.time() - backward_start

optim_start = time.time()
optimizer.step()
torch.cuda.synchronize()
timing["optimizer"] = time.time() - optim_start

model_logger.on_step_end(accelerator, model, save_steps, loss=loss)
scheduler.step()

torch.cuda.synchronize()
iter_end = time.time()
timing["step"] = iter_end - iter_start
train_step += 1

if accelerator.is_main_process:
def format_time(key: str) -> str:
value = timing.get(key, 0.0)
return f"{value:.3f}s"

postfix_dict = {
"loss": f"{loss.item():.5f}",
"lr": f"{optimizer.param_groups[0]['lr']:.5e}",
"step/t": format_time("step"),
"fwd/t": format_time("forward"),
"bwd/t": format_time("backward"),
"opt/t": format_time("optimizer"),
}
progress.set_postfix(postfix_dict)
log_msg = f"[Step {train_step:6d}] | " + " | ".join(f"{k}: {v}" for k, v in postfix_dict.items())
progress.write(log_msg)

if save_steps is None:
model_logger.on_epoch_end(accelerator, model, epoch_id)
model_logger.on_training_end(accelerator, model, save_steps)

model_logger.on_training_end(accelerator, model, save_steps)

def launch_data_process_task(
accelerator: Accelerator,
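As a sanity check on the rank mapping in build_dataloader above, here is a tiny standalone sketch (the world size and SP size are assumed purely for illustration) of how global ranks map onto (dp_rank, sp_rank). Because the DistributedSampler shards only across dp_rank, all ranks inside one SP group receive the same samples:

# Assumed example: 8 processes split into sequence-parallel groups of 4.
world_size, sp_size = 8, 4
dp_size = world_size // sp_size  # 2 data-parallel replicas

for rank in range(world_size):
    dp_rank, sp_rank = rank // sp_size, rank % sp_size
    print(f"rank={rank} -> dp_rank={dp_rank}, sp_rank={sp_rank}")

# ranks 0-3 -> dp_rank=0 (SP group 0), ranks 4-7 -> dp_rank=1 (SP group 1)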
27 changes: 21 additions & 6 deletions diffsynth/pipelines/wan_video.py
@@ -31,7 +31,7 @@

class WanVideoPipeline(BasePipeline):

def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16, sp_size=1):
super().__init__(
device=device, torch_dtype=torch_dtype,
height_division_factor=16, width_division_factor=16, time_division_factor=4, time_division_remainder=1
@@ -80,6 +80,7 @@ def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
WanVideoPostUnit_S2V(),
]
self.model_fn = model_fn_wan_video
self.sp_size = sp_size


def enable_usp(self):
@@ -92,7 +93,6 @@ def enable_usp(self):
for block in self.dit2.blocks:
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
self.dit2.forward = types.MethodType(usp_dit_forward, self.dit2)
self.sp_size = get_sequence_parallel_world_size()
self.use_unified_sequence_parallel = True


@@ -105,6 +105,7 @@ def from_pretrained(
audio_processor_config: ModelConfig = None,
redirect_common_files: bool = True,
use_usp: bool = False,
sp_size: int = 1,
vram_limit: float = None,
):
# Redirect model path
@@ -122,16 +123,17 @@
print(f"To avoid repeatedly downloading model files, ({model_config.model_id}, {model_config.origin_file_pattern}) is redirected to {redirect_dict[model_config.origin_file_pattern]}. You can use `redirect_common_files=False` to disable file redirection.")
model_config.model_id = redirect_dict[model_config.origin_file_pattern][0]
model_config.origin_file_pattern = redirect_dict[model_config.origin_file_pattern][1]

if use_usp:
from ..utils.xfuser import initialize_usp
initialize_usp(device)
initialize_usp(device, sp_size)
import torch.distributed as dist
from ..core.device.npu_compatible_device import get_device_name
if dist.is_available() and dist.is_initialized():
device = get_device_name()

# Initialize pipeline
pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype, sp_size=sp_size)
model_pool = pipe.download_and_load_models(model_configs, vram_limit)

# Fetch models
@@ -1379,7 +1381,20 @@ def custom_forward(*inputs):
x = animate_adapter.after_transformer_block(block_id, x, motion_vec)
if tea_cache is not None:
tea_cache.store(x)


    '''
    The all_gather interface in xDiT uses torch's all_gather_into_tensor. As of the torch 2.9 release,
    all_gather_into_tensor still provides no backward and therefore cannot support autograd; the upstream
    PyTorch fix (https://github.com/pytorch/pytorch/pull/168140) has not yet been merged. As a workaround,
    all_gather_into_tensor in xDiT is simply replaced with torch.distributed.nn.functional.all_gather,
    which does support autograd. (A standalone sketch of this replacement follows this file's diff.)

    def all_gather(self, input_: torch.Tensor) -> torch.Tensor:
        ...
        # All-gather.
        # torch.distributed.all_gather_into_tensor(
        #     output_tensor, input_, group=self.device_group
        # )
        gathered_list = torch.distributed.nn.functional.all_gather(input_, group=self.device_group)
        output_tensor = torch.cat(gathered_list, dim=0)
        ...
    '''
x = dit.head(x, t)
if use_unified_sequence_parallel:
if dist.is_initialized() and dist.get_world_size() > 1:
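The docstring in the wan_video.py diff above describes replacing xDiT's call to the non-differentiable torch.distributed.all_gather_into_tensor with torch.distributed.nn.functional.all_gather. Below is a minimal standalone sketch of that idea as a runtime monkey-patch; the patched method name and its signature are assumptions and may not match xDiT's GroupCoordinator exactly:

import types
import torch
import torch.distributed.nn.functional as dist_nn_func
from xfuser.core.distributed import get_sp_group

def patch_sp_group_all_gather():
    sp_group = get_sp_group()

    def differentiable_all_gather(self, input_: torch.Tensor, dim: int = 0) -> torch.Tensor:
        # torch.distributed.nn.functional.all_gather participates in autograd,
        # unlike torch.distributed.all_gather_into_tensor (no backward as of torch 2.9).
        gathered_list = dist_nn_func.all_gather(input_, group=self.device_group)
        return torch.cat(gathered_list, dim=dim)

    # Rebind the coordinator's all_gather at runtime (attribute name assumed).
    sp_group.all_gather = types.MethodType(differentiable_all_gather, sp_group)

This mirrors how enable_usp() in this pipeline already rebinds block forwards with types.MethodType, so the patch could live next to those overrides.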
42 changes: 34 additions & 8 deletions diffsynth/utils/xfuser/xdit_context_parallel.py
@@ -6,20 +6,39 @@
get_sp_group)
from xfuser.core.long_ctx_attention import xFuserLongContextAttention
from ...core.device import parse_nccl_backend, parse_device_type
import logging

logger = logging.getLogger(__name__)

def initialize_usp(device_type):
def initialize_usp(device_type, sp_size):
import torch.distributed as dist
from xfuser.core.distributed import initialize_model_parallel, init_distributed_environment
dist.init_process_group(backend=parse_nccl_backend(device_type), init_method="env://")
from xfuser.core.distributed import (
initialize_model_parallel,
init_distributed_environment,
get_sequence_parallel_world_size,
get_sequence_parallel_rank,
get_data_parallel_world_size,
get_data_parallel_rank,
)

if not dist.is_initialized():
dist.init_process_group(backend=parse_nccl_backend(device_type), init_method="env://")

init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())

sp_degree = sp_size
    dp_degree = dist.get_world_size() // sp_degree
initialize_model_parallel(
sequence_parallel_degree=dist.get_world_size(),
data_parallel_degree=dp_degree,
sequence_parallel_degree=sp_degree,
ring_degree=1,
ulysses_degree=dist.get_world_size(),
ulysses_degree=sp_degree,
)
getattr(torch, device_type).set_device(dist.get_rank())

logger.info(f"[init usp] rank: {dist.get_rank()}, world_size: {dist.get_world_size()}, "
f"sp world size: {get_sequence_parallel_world_size()}, "
f"sp rank: {get_sequence_parallel_rank()}, "
f"dp world size: {get_data_parallel_world_size()}, "
f"dp rank: {get_data_parallel_rank()}")

def sinusoidal_embedding_1d(dim, position):
sinusoid = torch.outer(position.type(torch.float64), torch.pow(
@@ -133,6 +152,13 @@ def usp_attn_forward(self, x, freqs):
k = rearrange(k, "b s (n d) -> b s n d", n=self.num_heads)
v = rearrange(v, "b s (n d) -> b s n d", n=self.num_heads)

    '''
    Refer to https://github.com/xdit-project/xDiT/pull/598 for the xfuser backward error.
    xFuserRingFlashAttnFunc's forward takes 17 inputs (including ctx), but it inherits backward() from
    RingFlashAttnFunc, which only returns one value per input of the parent's forward.

    The math:
    - Parent class (RingFlashAttnFunc): 14 forward inputs -> backward returns 3 gradients + 11 Nones = 14 values.
    - xFuser class (xFuserRingFlashAttnFunc): 17 forward inputs -> backward should return 3 gradients + 14 Nones = 17 values.
    - Actual: backward returns only the inherited 14 values (no override in the subclass).
    - Error: "expected 17, got 13" (13 = 14 - 1 for ctx).
    (A generic toy example of this rule follows this file's diff.)
    '''
x = xFuserLongContextAttention()(
None,
query=q,
Expand All @@ -143,4 +169,4 @@ def usp_attn_forward(self, x, freqs):

del q, k, v
getattr(torch, parse_device_type(x.device)).empty_cache()
return self.o(x)
return self.o(x)
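To make the backward-arity mismatch described in the usp_attn_forward docstring concrete, here is a generic toy example (unrelated to xDiT's actual classes): autograd requires backward() to return exactly one value per forward() argument after ctx, so a subclass that adds forward arguments has to override backward accordingly:

import torch

class Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, factor):        # 2 inputs after ctx
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_out):        # must return 2 values
        return grad_out * ctx.factor, None

class ScaleAndShift(Scale):
    @staticmethod
    def forward(ctx, x, factor, shift): # 3 inputs after ctx
        ctx.factor = factor
        return x * factor + shift

    # Without this override, the inherited backward would return only 2 values and
    # autograd would fail with "expected 3, got 2" -- the same pattern as the xFuser error.
    @staticmethod
    def backward(ctx, grad_out):
        return grad_out * ctx.factor, None, None

x = torch.randn(4, requires_grad=True)
ScaleAndShift.apply(x, 2.0, 1.0).sum().backward()
print(x.grad)  # all elements equal 2.0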