diff --git a/diffsynth/diffusion/loss.py b/diffsynth/diffusion/loss.py
index ae44bb68..d4318286 100644
--- a/diffsynth/diffusion/loss.py
+++ b/diffsynth/diffusion/loss.py
@@ -3,19 +3,47 @@
 
 def FlowMatchSFTLoss(pipe: BasePipeline, **inputs):
+    from xfuser.core.distributed import (
+        get_sequence_parallel_rank,
+        get_sp_group,
+    )
+
     max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * len(pipe.scheduler.timesteps))
     min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * len(pipe.scheduler.timesteps))
-    timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,))
-    timestep = pipe.scheduler.timesteps[timestep_id].to(dtype=pipe.torch_dtype, device=pipe.device)
-
-    noise = torch.randn_like(inputs["input_latents"])
+    input_shape = inputs["input_latents"].shape
+    input_dtype = inputs["input_latents"].dtype
+    input_device = inputs["input_latents"].device
+
+    # Random noise and timestep IDs are generated by SP local rank 0
+    # and broadcast to all other ranks in the SP group.
+    # Alternative implementation: use consistent random seeds for noise and
+    # timestep_id generation across the ranks of each SP group.
+    if pipe.sp_size > 1:
+        sp_group = get_sp_group()
+        if get_sequence_parallel_rank() == 0:
+            timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,))
+            timestep = pipe.scheduler.timesteps[timestep_id].to(dtype=pipe.torch_dtype, device=pipe.device)
+        else:
+            timestep = torch.zeros(1, dtype=pipe.torch_dtype, device=pipe.device)
+        sp_group.broadcast(timestep, src=0)
+
+        if get_sequence_parallel_rank() == 0:
+            noise = torch.randn(input_shape, dtype=input_dtype, device=input_device)
+        else:
+            noise = torch.zeros(input_shape, dtype=input_dtype, device=input_device)
+        sp_group.broadcast(noise, src=0)
+    else:
+        timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,))
+        timestep = pipe.scheduler.timesteps[timestep_id].to(dtype=pipe.torch_dtype, device=pipe.device)
+        noise = torch.randn(input_shape, dtype=input_dtype, device=input_device)
+
     inputs["latents"] = pipe.scheduler.add_noise(inputs["input_latents"], noise, timestep)
     training_target = pipe.scheduler.training_target(inputs["input_latents"], noise, timestep)
-
+
     models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
     noise_pred = pipe.model_fn(**models, **inputs, timestep=timestep)
-
+
     loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
     loss = loss * pipe.scheduler.training_weight(timestep)
     return loss
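
Editorial note: the "consistent random seeds" alternative mentioned in the comment above could look roughly like the sketch below. It is not part of this patch; the helper name and the seed recipe (derived from the data-parallel rank and the global step, which every rank of one SP group shares) are hypothetical.

import torch

def sample_shared_noise_and_timestep(pipe, inputs, dp_rank, global_step,
                                     min_timestep_boundary, max_timestep_boundary):
    # All ranks in one SP group see the same (dp_rank, global_step), hence the same
    # generator state, hence identical timestep and noise, with no broadcast needed.
    seed = (dp_rank * 1_000_003 + global_step) % (2**31 - 1)
    generator = torch.Generator(device="cpu").manual_seed(seed)
    timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,), generator=generator)
    timestep = pipe.scheduler.timesteps[timestep_id].to(dtype=pipe.torch_dtype, device=pipe.device)
    latents = inputs["input_latents"]
    noise = torch.randn(latents.shape, generator=generator, dtype=latents.dtype).to(latents.device)
    return timestep, noise
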
diff --git a/diffsynth/diffusion/parsers.py b/diffsynth/diffusion/parsers.py
index b8c6c6af..860b7347 100644
--- a/diffsynth/diffusion/parsers.py
+++ b/diffsynth/diffusion/parsers.py
@@ -37,6 +37,7 @@ def add_training_config(parser: argparse.ArgumentParser):
     parser.add_argument("--find_unused_parameters", default=False, action="store_true", help="Whether to find unused parameters in DDP.")
     parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay.")
     parser.add_argument("--task", type=str, default="sft", required=False, help="Task type.")
+    parser.add_argument("--sp_size", type=int, default=1, help="Sequence parallel size. sp_size > 1 initializes USP for sequence parallelism.")
     return parser
 
 def add_output_config(parser: argparse.ArgumentParser):
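
Editorial note: a quick, self-contained way to exercise the new flag (this snippet mirrors the add_argument call above rather than importing the project parser, so it makes no assumptions about the other training arguments):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--sp_size", type=int, default=1,
                    help="Sequence parallel size. sp_size > 1 initializes USP for sequence parallelism.")
args = parser.parse_args(["--sp_size", "4"])
assert args.sp_size == 4  # sp_size > 1 turns on USP in the training entry point
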
diff --git a/diffsynth/diffusion/runner.py b/diffsynth/diffusion/runner.py
index f6e2263a..2aabacc8 100644
--- a/diffsynth/diffusion/runner.py
+++ b/diffsynth/diffusion/runner.py
@@ -3,7 +3,56 @@
 from accelerate import Accelerator
 from .training_module import DiffusionTrainingModule
 from .logger import ModelLogger
+import time
+import logging
+logger = logging.getLogger(__name__)
+
+
+def build_dataloader(
+    accelerator: Accelerator,
+    dataset: torch.utils.data.Dataset,
+    num_workers: int = 1,
+    sp_size: int = 1,
+):
+    if sp_size > 1:
+        # When using sequence parallelism, the sampler must guarantee that every rank
+        # within the same SP group fetches the same sample from the dataloader.
+        if accelerator is not None:
+            world_size = accelerator.num_processes
+            rank = accelerator.process_index
+        else:
+            raise ValueError("Accelerator is None.")
+
+        dp_size = world_size // sp_size
+        if dp_size * sp_size != world_size:
+            raise ValueError(
+                f"world_size={world_size}, sp_size={sp_size}: world_size must be divisible by sp_size"
+            )
+
+        dp_rank = rank // sp_size
+        sp_rank = rank % sp_size
+        logger.info(f"accelerator.process_index={rank}, accelerator.num_processes={world_size}, "
+                    f"sp_size={sp_size}, dp_size={dp_size}, dp_rank={dp_rank}, sp_rank={sp_rank}")
+    else:
+        if accelerator is not None:
+            dp_size = accelerator.num_processes
+            dp_rank = accelerator.process_index
+        else:
+            raise ValueError("Accelerator is None.")
+        logger.info(f"dp_size={dp_size}, dp_rank={dp_rank}")
+
+    sampler = torch.utils.data.DistributedSampler(dataset=dataset, num_replicas=dp_size, rank=dp_rank)
+
+    dataloader_kwargs = dict(
+        dataset=dataset,
+        sampler=sampler,
+        num_workers=num_workers,
+        pin_memory=True,
+        collate_fn=lambda x: x[0],
+    )
+    dataloader = torch.utils.data.DataLoader(**dataloader_kwargs)
+
+    return dataloader
+
 
 def launch_training_task(
     accelerator: Accelerator,
@@ -15,6 +64,7 @@
     num_workers: int = 1,
     save_steps: int = None,
     num_epochs: int = 1,
+    sp_size: int = 1,
     args = None,
 ):
     if args is not None:
@@ -23,29 +73,80 @@
         num_workers = args.dataset_num_workers
         save_steps = args.save_steps
         num_epochs = args.num_epochs
-
+        sp_size = args.sp_size
+
+    train_step = 0
+
     optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay)
     scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
-    dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers)
-
-    model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)
+    dataloader = build_dataloader(accelerator, dataset, num_workers, sp_size)
+    model, optimizer, scheduler = accelerator.prepare(model, optimizer, scheduler)
 
     for epoch_id in range(num_epochs):
-        for data in tqdm(dataloader):
+        progress = tqdm(
+            dataloader,
+            disable=not accelerator.is_main_process,
+            desc=f"Epoch {epoch_id + 1}/{num_epochs}",
+        )
+
+        for data in progress:
+            if data is None:
+                continue
+
+            logger.info(f"[train] rank {accelerator.process_index}, step {train_step}, prompt: {data['prompt']}")
+
+            iter_start = time.time()
+            timing = {}
+
             with accelerator.accumulate(model):
                 optimizer.zero_grad()
+
+                forward_start = time.time()
                 if dataset.load_from_cache:
                     loss = model({}, inputs=data)
                 else:
                     loss = model(data)
+                torch.cuda.synchronize()
+                timing["forward"] = time.time() - forward_start
+
+                backward_start = time.time()
                 accelerator.backward(loss)
+                torch.cuda.synchronize()
+                timing["backward"] = time.time() - backward_start
+
+                optim_start = time.time()
                 optimizer.step()
+                torch.cuda.synchronize()
+                timing["optimizer"] = time.time() - optim_start
+
                 model_logger.on_step_end(accelerator, model, save_steps, loss=loss)
                 scheduler.step()
+
+            torch.cuda.synchronize()
+            iter_end = time.time()
+            timing["step"] = iter_end - iter_start
+            train_step += 1
+
+            if accelerator.is_main_process:
+                def format_time(key: str) -> str:
+                    value = timing.get(key, 0.0)
+                    return f"{value:.3f}s"
+
+                postfix_dict = {
+                    "loss": f"{loss.item():.5f}",
+                    "lr": f"{optimizer.param_groups[0]['lr']:.5e}",
+                    "step/t": format_time("step"),
+                    "fwd/t": format_time("forward"),
+                    "bwd/t": format_time("backward"),
+                    "opt/t": format_time("optimizer"),
+                }
+                progress.set_postfix(postfix_dict)
+                log_msg = f"[Step {train_step:6d}] | " + " | ".join(f"{k}: {v}" for k, v in postfix_dict.items())
+                progress.write(log_msg)
+
         if save_steps is None:
             model_logger.on_epoch_end(accelerator, model, epoch_id)
-    model_logger.on_training_end(accelerator, model, save_steps)
+    model_logger.on_training_end(accelerator, model, save_steps)
 
 
 def launch_data_process_task(
     accelerator: Accelerator,
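
Editorial note: to make the rank layout that build_dataloader relies on concrete, here is a small worked example (the numbers are illustrative only):

# world_size = 8 processes, sp_size = 4  ->  dp_size = 2 data-parallel replicas.
# Ranks 0-3 share dp_rank 0 and ranks 4-7 share dp_rank 1, so all ranks of one SP
# group read the same DistributedSampler shard and hence the same sample each step.
world_size, sp_size = 8, 4
dp_size = world_size // sp_size            # 2
for rank in range(world_size):
    dp_rank, sp_rank = rank // sp_size, rank % sp_size
    print(f"rank={rank} -> dp_rank={dp_rank}, sp_rank={sp_rank}")
# rank=0..3 -> dp_rank=0, sp_rank=0..3
# rank=4..7 -> dp_rank=1, sp_rank=0..3
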
diff --git a/diffsynth/pipelines/wan_video.py b/diffsynth/pipelines/wan_video.py
index edd6dffc..5ae98a53 100644
--- a/diffsynth/pipelines/wan_video.py
+++ b/diffsynth/pipelines/wan_video.py
@@ -31,7 +31,7 @@
 
 
-    def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
+    def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16, sp_size=1):
         super().__init__(
             device=device, torch_dtype=torch_dtype, height_division_factor=16, width_division_factor=16,
             time_division_factor=4, time_division_remainder=1
@@ -80,6 +80,7 @@ def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
             WanVideoPostUnit_S2V(),
         ]
         self.model_fn = model_fn_wan_video
+        self.sp_size = sp_size
 
 
     def enable_usp(self):
@@ -92,7 +93,6 @@ def enable_usp(self):
             for block in self.dit2.blocks:
                 block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
             self.dit2.forward = types.MethodType(usp_dit_forward, self.dit2)
-        self.sp_size = get_sequence_parallel_world_size()
         self.use_unified_sequence_parallel = True
 
 
@@ -105,6 +105,7 @@
         audio_processor_config: ModelConfig = None,
         redirect_common_files: bool = True,
         use_usp: bool = False,
+        sp_size: int = 1,
         vram_limit: float = None,
     ):
         # Redirect model path
@@ -122,16 +123,17 @@
                 print(f"To avoid repeatedly downloading model files, ({model_config.model_id}, {model_config.origin_file_pattern}) is redirected to {redirect_dict[model_config.origin_file_pattern]}. You can use `redirect_common_files=False` to disable file redirection.")
                 model_config.model_id = redirect_dict[model_config.origin_file_pattern][0]
                 model_config.origin_file_pattern = redirect_dict[model_config.origin_file_pattern][1]
-
+
         if use_usp:
             from ..utils.xfuser import initialize_usp
-            initialize_usp(device)
+            initialize_usp(device, sp_size)
             import torch.distributed as dist
             from ..core.device.npu_compatible_device import get_device_name
             if dist.is_available() and dist.is_initialized():
                 device = get_device_name()
+
         # Initialize pipeline
-        pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
+        pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype, sp_size=sp_size)
         model_pool = pipe.download_and_load_models(model_configs, vram_limit)
 
         # Fetch models
@@ -1379,7 +1381,20 @@ def custom_forward(*inputs):
                 x = animate_adapter.after_transformer_block(block_id, x, motion_vec)
             if tea_cache is not None:
                 tea_cache.store(x)
-
+
+    '''
+    The all_gather interface in xDiT uses torch's all_gather_into_tensor. As of the torch 2.9 release, this interface still has no backward implementation and therefore cannot support autograd. The upstream fix (https://github.com/pytorch/pytorch/pull/168140) has not been merged yet, so all_gather_into_tensor in xDiT is simply replaced with torch.distributed.nn.functional.all_gather here to enable autograd support:
+
+    def all_gather(self, input_: torch.Tensor, ...) -> torch.Tensor:
+        ...
+        # All-gather.
+        # torch.distributed.all_gather_into_tensor(
+        #     output_tensor, input_, group=self.device_group
+        # )
+        gathered_list = torch.distributed.nn.functional.all_gather(input_, group=self.device_group)
+        output_tensor = torch.cat(gathered_list, dim=0)
+        ...
+    '''
     x = dit.head(x, t)
     if use_unified_sequence_parallel:
         if dist.is_initialized() and dist.get_world_size() > 1:
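
Editorial note: if editing the installed xDiT package is undesirable, the same replacement could in principle be applied as a runtime monkey-patch from the training code. The sketch below is illustrative only: it assumes the SP group object exposes a device_group attribute and an all_gather method as in the snippet above, which may not hold for every xDiT version.

import types
import torch
import torch.distributed.nn.functional as dist_nn
from xfuser.core.distributed import get_sp_group

def _autograd_all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # Differentiable all-gather: torch.distributed.nn.functional.all_gather returns a
    # list of tensors and participates in autograd, unlike all_gather_into_tensor.
    gathered = dist_nn.all_gather(input_, group=self.device_group)
    return torch.cat(gathered, dim=dim)

def patch_sp_group_all_gather():
    sp_group = get_sp_group()
    sp_group.all_gather = types.MethodType(_autograd_all_gather, sp_group)
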
diff --git a/diffsynth/utils/xfuser/xdit_context_parallel.py b/diffsynth/utils/xfuser/xdit_context_parallel.py
index 21dc3b33..bbdcf756 100644
--- a/diffsynth/utils/xfuser/xdit_context_parallel.py
+++ b/diffsynth/utils/xfuser/xdit_context_parallel.py
@@ -6,20 +6,39 @@
                                      get_sp_group)
 from xfuser.core.long_ctx_attention import xFuserLongContextAttention
 from ...core.device import parse_nccl_backend, parse_device_type
+import logging
+logger = logging.getLogger(__name__)
 
 
-def initialize_usp(device_type):
+def initialize_usp(device_type, sp_size):
     import torch.distributed as dist
-    from xfuser.core.distributed import initialize_model_parallel, init_distributed_environment
-    dist.init_process_group(backend=parse_nccl_backend(device_type), init_method="env://")
+    from xfuser.core.distributed import (
+        initialize_model_parallel,
+        init_distributed_environment,
+        get_sequence_parallel_world_size,
+        get_sequence_parallel_rank,
+        get_data_parallel_world_size,
+        get_data_parallel_rank,
+    )
+
+    if not dist.is_initialized():
+        dist.init_process_group(backend=parse_nccl_backend(device_type), init_method="env://")
+    init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())
+
+    sp_degree = sp_size
+    dp_degree = dist.get_world_size() // sp_degree
     initialize_model_parallel(
-        sequence_parallel_degree=dist.get_world_size(),
+        data_parallel_degree=dp_degree,
+        sequence_parallel_degree=sp_degree,
         ring_degree=1,
-        ulysses_degree=dist.get_world_size(),
+        ulysses_degree=sp_degree,
     )
-    getattr(torch, device_type).set_device(dist.get_rank())
-
+    logger.info(f"[init usp] rank: {dist.get_rank()}, world_size: {dist.get_world_size()}, "
+                f"sp world size: {get_sequence_parallel_world_size()}, "
+                f"sp rank: {get_sequence_parallel_rank()}, "
+                f"dp world size: {get_data_parallel_world_size()}, "
+                f"dp rank: {get_data_parallel_rank()}")
 
 
 def sinusoidal_embedding_1d(dim, position):
     sinusoid = torch.outer(position.type(torch.float64), torch.pow(
@@ -133,6 +152,13 @@
     k = rearrange(k, "b s (n d) -> b s n d", n=self.num_heads)
     v = rearrange(v, "b s (n d) -> b s n d", n=self.num_heads)
 
+    '''
+    Refer to https://github.com/xdit-project/xDiT/pull/598 for the xfuser backward error:
+    xFuserRingFlashAttnFunc has 17 inputs (including ctx), but it inherits the backward()
+    method from RingFlashAttnFunc, which only returns 16 values (3 gradients + 13 Nones).
+
+    The math:
+    - Parent class (RingFlashAttnFunc): 14 forward inputs → backward returns 3 gradients + 11 Nones = 14 returns.
+    - xFuser class (xFuserRingFlashAttnFunc): 17 forward inputs → backward should return 3 gradients + 14 Nones = 17 returns.
+    - Actual: backward only returns 14 values, inherited from the parent without an override.
+    - Error: PyTorch expects 17 gradients but gets only 14 → "expected 17, got 13" (13 = 14 - 1 for ctx).
+    '''
     x = xFuserLongContextAttention()(
         None,
         query=q,
@@ -143,4 +169,4 @@
     del q, k, v
     getattr(torch, parse_device_type(x.device)).empty_cache()
 
-    return self.o(x)
\ No newline at end of file
+    return self.o(x)
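
Editorial note: build_dataloader in runner.py derives dp_rank = rank // sp_size, which is only correct if xFuser places consecutive global ranks into the same sequence-parallel group. A defensive check along these lines could be run after initialize_usp (a sketch under that assumption, not part of this patch):

import torch.distributed as dist
from xfuser.core.distributed import get_sequence_parallel_rank

def check_sp_grouping(sp_size: int):
    # If consecutive ranks form one SP group, the SP-local rank equals the global rank
    # modulo sp_size; otherwise the DistributedSampler sharding in build_dataloader
    # would feed different samples to ranks of the same SP group.
    rank = dist.get_rank()
    assert get_sequence_parallel_rank() == rank % sp_size, (
        f"rank {rank}: xFuser SP grouping does not match the layout assumed by build_dataloader"
    )
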
diff --git a/examples/wanvideo/model_training/train.py b/examples/wanvideo/model_training/train.py
index 49734382..d83291c0 100644
--- a/examples/wanvideo/model_training/train.py
+++ b/examples/wanvideo/model_training/train.py
@@ -23,6 +23,7 @@ def __init__(
         task="sft",
         max_timestep_boundary=1.0,
         min_timestep_boundary=0.0,
+        sp_size=1,
     ):
         super().__init__()
         # Warning
@@ -34,9 +35,14 @@
         model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device)
         tokenizer_config = ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/") if tokenizer_path is None else ModelConfig(tokenizer_path)
         audio_processor_config = ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="wav2vec2-large-xlsr-53-english/") if audio_processor_path is None else ModelConfig(audio_processor_path)
-        self.pipe = WanVideoPipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config, audio_processor_config=audio_processor_config)
+        use_usp = sp_size > 1
+        self.pipe = WanVideoPipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device,
+                                                     model_configs=model_configs,
+                                                     tokenizer_config=tokenizer_config,
+                                                     audio_processor_config=audio_processor_config,
+                                                     use_usp=use_usp, sp_size=sp_size)
         self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)
-
+
         # Training mode
         self.switch_pipe_to_training_mode(
             self.pipe, trainable_models,
@@ -127,6 +133,20 @@ def wan_parser():
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         kwargs_handlers=[accelerate.DistributedDataParallelKwargs(find_unused_parameters=args.find_unused_parameters)],
     )
+
+    if accelerator is not None:
+        num_replicas = accelerator.num_processes
+        rank = accelerator.process_index
+        print(f"accelerator.process_index={rank}, accelerator.num_processes={num_replicas}")
+    else:
+        raise ValueError("Failed to init accelerator.")
+
+    if accelerator.is_main_process:
+        print("\n=== accelerator state ===")
+        print(accelerator.state)
+        print(f"DeepSpeed plugin: {accelerator.state.deepspeed_plugin}")
+        print(f"Parallelism config: {accelerator.state.parallelism_config}")
+
     dataset = UnifiedDataset(
         base_path=args.dataset_base_path,
         metadata_path=args.dataset_metadata_path,
@@ -148,6 +168,7 @@
             "input_audio": ToAbsolutePath(args.dataset_base_path) >> LoadAudio(sr=16000),
         }
     )
+
     model = WanTrainingModule(
         model_paths=args.model_paths,
         model_id_with_origin_paths=args.model_id_with_origin_paths,
@@ -169,6 +190,7 @@
         device="cpu" if args.initialize_model_on_cpu else accelerator.device,
         max_timestep_boundary=args.max_timestep_boundary,
         min_timestep_boundary=args.min_timestep_boundary,
+        sp_size=args.sp_size,
     )
     model_logger = ModelLogger(
         args.output_path,