From 3d78f9d17d26b348755a342a2547e90d49842afb Mon Sep 17 00:00:00 2001 From: Pham Hong Vinh Date: Fri, 9 Jan 2026 19:45:20 +0700 Subject: [PATCH 01/12] add constants for distill sigmas values and allow ltx pipeline to pass in sigmas --- src/diffusers/pipelines/ltx2/pipeline_ltx2.py | 7 ++++++- src/diffusers/pipelines/ltx2/utils.py | 4 ++++ 2 files changed, 10 insertions(+), 1 deletion(-) create mode 100644 src/diffusers/pipelines/ltx2/utils.py diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py index 99d6b71ec3d7..7d1cd3ce656c 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py @@ -750,6 +750,7 @@ def __call__( num_frames: int = 121, frame_rate: float = 24.0, num_inference_steps: int = 40, + sigmas: Optional[List[float]] = None, timesteps: List[int] = None, guidance_scale: float = 4.0, guidance_rescale: float = 0.0, @@ -788,6 +789,10 @@ def __call__( num_inference_steps (`int`, *optional*, defaults to 40): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. timesteps (`List[int]`, *optional*): Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is @@ -958,7 +963,7 @@ def __call__( ) # 5. Prepare timesteps - sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas mu = calculate_shift( video_sequence_length, self.scheduler.config.get("base_image_seq_len", 1024), diff --git a/src/diffusers/pipelines/ltx2/utils.py b/src/diffusers/pipelines/ltx2/utils.py new file mode 100644 index 000000000000..99c82c82c1af --- /dev/null +++ b/src/diffusers/pipelines/ltx2/utils.py @@ -0,0 +1,4 @@ +DISTILLED_SIGMA_VALUES = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0] + +# Reduced schedule for super-resolution stage 2 (subset of distilled values) +STAGE_2_DISTILLED_SIGMA_VALUES = [0.909375, 0.725, 0.421875, 0.0] \ No newline at end of file From 9c754a46aa768d30216a5580a91a2923e25cbf8a Mon Sep 17 00:00:00 2001 From: Pham Hong Vinh Date: Sun, 11 Jan 2026 22:13:58 +0700 Subject: [PATCH 02/12] add time conditioning conversion and token packing for latents --- scripts/convert_ltx2_to_diffusers.py | 15 +++++++++------ src/diffusers/pipelines/ltx2/pipeline_ltx2.py | 10 ++++++++++ 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/scripts/convert_ltx2_to_diffusers.py b/scripts/convert_ltx2_to_diffusers.py index 5367113365a2..2794feffed6f 100644 --- a/scripts/convert_ltx2_to_diffusers.py +++ b/scripts/convert_ltx2_to_diffusers.py @@ -63,6 +63,8 @@ "up_blocks.4": "up_blocks.1", "up_blocks.5": "up_blocks.2.upsamplers.0", "up_blocks.6": "up_blocks.2", + "last_time_embedder": "time_embedder", + "last_scale_shift_table": "scale_shift_table", # Common # For all 3D ResNets "res_blocks": "resnets", @@ -372,7 +374,7 @@ def convert_ltx2_connectors(original_state_dict: Dict[str, Any], version: str) - return connectors -def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: +def get_ltx2_video_vae_config(version: str, timestep_conditioning: bool = False) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: if version == "test": config = { "model_id": "diffusers-internal-dev/dummy-ltx2", @@ -396,7 +398,7 @@ def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, A "downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"), "upsample_residual": (True, True, True), "upsample_factor": (2, 2, 2), - "timestep_conditioning": False, + "timestep_conditioning": timestep_conditioning, "patch_size": 4, "patch_size_t": 1, "resnet_norm_eps": 1e-6, @@ -433,7 +435,7 @@ def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, A "downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"), "upsample_residual": (True, True, True), "upsample_factor": (2, 2, 2), - "timestep_conditioning": False, + "timestep_conditioning": timestep_conditioning, "patch_size": 4, "patch_size_t": 1, "resnet_norm_eps": 1e-6, @@ -450,8 +452,8 @@ def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, A return config, rename_dict, special_keys_remap -def convert_ltx2_video_vae(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]: - config, rename_dict, special_keys_remap = get_ltx2_video_vae_config(version) +def convert_ltx2_video_vae(original_state_dict: Dict[str, Any], version: str, timestep_conditioning: bool) -> Dict[str, Any]: + config, rename_dict, special_keys_remap = get_ltx2_video_vae_config(version, timestep_conditioning) diffusers_config = config["diffusers_config"] with init_empty_weights(): @@ -717,6 +719,7 @@ def get_args(): help="Latent upsampler filename", ) + parser.add_argument("--timestep_conditioning", action="store_true", help="Whether to add timestep condition to the video VAE model") parser.add_argument("--vae", action="store_true", help="Whether to convert the video VAE model") parser.add_argument("--audio_vae", action="store_true", help="Whether to convert the audio VAE model") parser.add_argument("--dit", action="store_true", help="Whether to convert the DiT model") @@ -786,7 +789,7 @@ def main(args): original_vae_ckpt = load_hub_or_local_checkpoint(filename=args.vae_filename) elif combined_ckpt is not None: original_vae_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.vae_prefix) - vae = convert_ltx2_video_vae(original_vae_ckpt, version=args.version) + vae = convert_ltx2_video_vae(original_vae_ckpt, version=args.version, timestep_conditioning=args.timestep_conditioning) if not args.full_pipeline and not args.upsample_pipeline: vae.to(vae_dtype).save_pretrained(os.path.join(args.output_path, "vae")) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py index 7d1cd3ce656c..c662d9c16745 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py @@ -653,6 +653,11 @@ def prepare_latents( latents: Optional[torch.Tensor] = None, ) -> torch.Tensor: if latents is not None: + if latents.ndim == 5: + # latents are of shape [B, C, F, H, W], need to be packed + latents = self._pack_latents( + latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) return latents.to(device=device, dtype=dtype) height = height // self.vae_spatial_compression_ratio @@ -694,6 +699,9 @@ def prepare_audio_latents( latent_length = round(duration_s * latents_per_second) if latents is not None: + if latents.ndim == 4: + # latents are of shape [B, C, L, M], need to be packed + latents = self._pack_audio_latents(latents) return latents.to(device=device, dtype=dtype), latent_length # TODO: confirm whether this logic is correct @@ -1097,6 +1105,8 @@ def __call__( self.transformer_spatial_patch_size, self.transformer_temporal_patch_size, ) + prenorm_latents = latents + prenorm_audio_latents = audio_latents latents = self._denormalize_latents( latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor ) From 6fbeacf53bcc4b3c6281eb5e52e1bd81cf152555 Mon Sep 17 00:00:00 2001 From: Pham Hong Vinh Date: Sun, 11 Jan 2026 23:00:02 +0700 Subject: [PATCH 03/12] make style & quality --- scripts/convert_ltx2_to_diffusers.py | 16 ++++++++++++---- src/diffusers/pipelines/ltx2/utils.py | 2 +- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/scripts/convert_ltx2_to_diffusers.py b/scripts/convert_ltx2_to_diffusers.py index 2794feffed6f..72b334b71e71 100644 --- a/scripts/convert_ltx2_to_diffusers.py +++ b/scripts/convert_ltx2_to_diffusers.py @@ -374,7 +374,9 @@ def convert_ltx2_connectors(original_state_dict: Dict[str, Any], version: str) - return connectors -def get_ltx2_video_vae_config(version: str, timestep_conditioning: bool = False) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: +def get_ltx2_video_vae_config( + version: str, timestep_conditioning: bool = False +) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: if version == "test": config = { "model_id": "diffusers-internal-dev/dummy-ltx2", @@ -452,7 +454,9 @@ def get_ltx2_video_vae_config(version: str, timestep_conditioning: bool = False) return config, rename_dict, special_keys_remap -def convert_ltx2_video_vae(original_state_dict: Dict[str, Any], version: str, timestep_conditioning: bool) -> Dict[str, Any]: +def convert_ltx2_video_vae( + original_state_dict: Dict[str, Any], version: str, timestep_conditioning: bool +) -> Dict[str, Any]: config, rename_dict, special_keys_remap = get_ltx2_video_vae_config(version, timestep_conditioning) diffusers_config = config["diffusers_config"] @@ -719,7 +723,9 @@ def get_args(): help="Latent upsampler filename", ) - parser.add_argument("--timestep_conditioning", action="store_true", help="Whether to add timestep condition to the video VAE model") + parser.add_argument( + "--timestep_conditioning", action="store_true", help="Whether to add timestep condition to the video VAE model" + ) parser.add_argument("--vae", action="store_true", help="Whether to convert the video VAE model") parser.add_argument("--audio_vae", action="store_true", help="Whether to convert the audio VAE model") parser.add_argument("--dit", action="store_true", help="Whether to convert the DiT model") @@ -789,7 +795,9 @@ def main(args): original_vae_ckpt = load_hub_or_local_checkpoint(filename=args.vae_filename) elif combined_ckpt is not None: original_vae_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.vae_prefix) - vae = convert_ltx2_video_vae(original_vae_ckpt, version=args.version, timestep_conditioning=args.timestep_conditioning) + vae = convert_ltx2_video_vae( + original_vae_ckpt, version=args.version, timestep_conditioning=args.timestep_conditioning + ) if not args.full_pipeline and not args.upsample_pipeline: vae.to(vae_dtype).save_pretrained(os.path.join(args.output_path, "vae")) diff --git a/src/diffusers/pipelines/ltx2/utils.py b/src/diffusers/pipelines/ltx2/utils.py index 99c82c82c1af..bd0ae08c1073 100644 --- a/src/diffusers/pipelines/ltx2/utils.py +++ b/src/diffusers/pipelines/ltx2/utils.py @@ -1,4 +1,4 @@ DISTILLED_SIGMA_VALUES = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0] # Reduced schedule for super-resolution stage 2 (subset of distilled values) -STAGE_2_DISTILLED_SIGMA_VALUES = [0.909375, 0.725, 0.421875, 0.0] \ No newline at end of file +STAGE_2_DISTILLED_SIGMA_VALUES = [0.909375, 0.725, 0.421875, 0.0] From 82c2e7f068692ae689eaf872030503f5b8024eaf Mon Sep 17 00:00:00 2001 From: Pham Hong Vinh Date: Sun, 11 Jan 2026 23:01:34 +0700 Subject: [PATCH 04/12] remove prenorm --- src/diffusers/pipelines/ltx2/pipeline_ltx2.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py index c662d9c16745..b26ccca55c0a 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py @@ -1105,8 +1105,6 @@ def __call__( self.transformer_spatial_patch_size, self.transformer_temporal_patch_size, ) - prenorm_latents = latents - prenorm_audio_latents = audio_latents latents = self._denormalize_latents( latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor ) From 837fd85c76148b1703636a8c321ce3ff163d8ab9 Mon Sep 17 00:00:00 2001 From: Pham Hong Vinh Date: Mon, 12 Jan 2026 11:26:31 +0700 Subject: [PATCH 05/12] add sigma param to ltx2 i2v --- src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py index b1711e283191..a33462f70c50 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py @@ -811,6 +811,7 @@ def __call__( num_frames: int = 121, frame_rate: float = 24.0, num_inference_steps: int = 40, + sigmas: Optional[List[float]] = None, timesteps: List[int] = None, guidance_scale: float = 4.0, guidance_rescale: float = 0.0, @@ -851,6 +852,10 @@ def __call__( num_inference_steps (`int`, *optional*, defaults to 40): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. timesteps (`List[int]`, *optional*): Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is @@ -1028,7 +1033,7 @@ def __call__( latent_width = width // self.vae_spatial_compression_ratio video_sequence_length = latent_num_frames * latent_height * latent_width - sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas mu = calculate_shift( video_sequence_length, self.scheduler.config.get("base_image_seq_len", 1024), From 96fbcd8301a81deea58773a74ca620f12d70cebc Mon Sep 17 00:00:00 2001 From: Pham Hong Vinh Date: Mon, 12 Jan 2026 11:30:31 +0700 Subject: [PATCH 06/12] fix copies and add pack latents to i2v --- src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py index a33462f70c50..92206cee4e9b 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py @@ -689,6 +689,11 @@ def prepare_latents( conditioning_mask = self._pack_latents( conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size ).squeeze(-1) + if latents.ndim == 5: + # latents are of shape [B, C, F, H, W], need to be packed + latents = self._pack_latents( + latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) if latents.ndim != 3 or latents.shape[:2] != conditioning_mask.shape: raise ValueError( f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is {conditioning_mask.shape + (num_channels_latents,)}." @@ -754,6 +759,9 @@ def prepare_audio_latents( latent_length = round(duration_s * latents_per_second) if latents is not None: + if latents.ndim == 4: + # latents are of shape [B, C, L, M], need to be packed + latents = self._pack_audio_latents(latents) return latents.to(device=device, dtype=dtype), latent_length # TODO: confirm whether this logic is correct From 9575e0632afb5f02833b33e4d4f1da37274779e7 Mon Sep 17 00:00:00 2001 From: "Vinh H. Pham" Date: Tue, 13 Jan 2026 12:01:03 +0700 Subject: [PATCH 07/12] Apply suggestions from code review Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> --- src/diffusers/pipelines/ltx2/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/ltx2/utils.py b/src/diffusers/pipelines/ltx2/utils.py index bd0ae08c1073..7e143edc9bb6 100644 --- a/src/diffusers/pipelines/ltx2/utils.py +++ b/src/diffusers/pipelines/ltx2/utils.py @@ -1,4 +1,4 @@ -DISTILLED_SIGMA_VALUES = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0] +DISTILLED_SIGMA_VALUES = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875] # Reduced schedule for super-resolution stage 2 (subset of distilled values) -STAGE_2_DISTILLED_SIGMA_VALUES = [0.909375, 0.725, 0.421875, 0.0] +STAGE_2_DISTILLED_SIGMA_VALUES = [0.909375, 0.725, 0.421875] From eb01780ada0f9c2f9aa3c2a6b9c48329295985c3 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Wed, 14 Jan 2026 03:09:53 +0100 Subject: [PATCH 08/12] Infer latent dims if latents/audio_latents is supplied --- src/diffusers/pipelines/ltx2/pipeline_ltx2.py | 47 ++++++++------- .../ltx2/pipeline_ltx2_image2video.py | 57 +++++++++++-------- 2 files changed, 61 insertions(+), 43 deletions(-) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py index 588e05737e88..54f1061da5c9 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py @@ -682,32 +682,23 @@ def prepare_audio_latents( self, batch_size: int = 1, num_channels_latents: int = 8, + audio_latent_length: int = 1, # 1 is just a dummy value num_mel_bins: int = 64, - num_frames: int = 121, - frame_rate: float = 25.0, - sampling_rate: int = 16000, - hop_length: int = 160, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, generator: Optional[torch.Generator] = None, latents: Optional[torch.Tensor] = None, ) -> torch.Tensor: - duration_s = num_frames / frame_rate - latents_per_second = ( - float(sampling_rate) / float(hop_length) / float(self.audio_vae_temporal_compression_ratio) - ) - latent_length = round(duration_s * latents_per_second) - if latents is not None: if latents.ndim == 4: # latents are of shape [B, C, L, M], need to be packed latents = self._pack_audio_latents(latents) - return latents.to(device=device, dtype=dtype), latent_length + return latents.to(device=device, dtype=dtype) # TODO: confirm whether this logic is correct latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio - shape = (batch_size, num_channels_latents, latent_length, latent_mel_bins) + shape = (batch_size, num_channels_latents, audio_latent_length, latent_mel_bins) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( @@ -717,7 +708,7 @@ def prepare_audio_latents( latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) latents = self._pack_audio_latents(latents) - return latents, latent_length + return latents @property def guidance_scale(self): @@ -935,6 +926,14 @@ def __call__( latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 latent_height = height // self.vae_spatial_compression_ratio latent_width = width // self.vae_spatial_compression_ratio + if latents is not None: + if latents.ndim == 5: + _, _, latent_num_frames, latent_height, latent_width = latents.shape # [B, C, F, H, W] + else: + logger.warning( + f"You have supplied packed `latents` of shape {latents.shape}, so the latent dims cannot be" + f" inferred. Make sure the supplied `height`, `width`, and `num_frames` are correct." + ) video_sequence_length = latent_num_frames * latent_height * latent_width num_channels_latents = self.transformer.config.in_channels @@ -950,20 +949,30 @@ def __call__( latents, ) + duration_s = num_frames / frame_rate + audio_latents_per_second = ( + self.audio_sampling_rate / self.audio_hop_length / float(self.audio_vae_temporal_compression_ratio) + ) + audio_num_frames = round(duration_s * audio_latents_per_second) + if audio_latents is not None: + if audio_latents.ndim == 4: + _, _, audio_num_frames, _ = audio_latents.shape # [B, C, L, M] + else: + logger.warning( + f"You have supplied packed `audio_latents` of shape {audio_latents.shape}, so the latent dims" + f" cannot be inferred. Make sure the supplied `num_frames` is correct." + ) + num_mel_bins = self.audio_vae.config.mel_bins if getattr(self, "audio_vae", None) is not None else 64 latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio - num_channels_latents_audio = ( self.audio_vae.config.latent_channels if getattr(self, "audio_vae", None) is not None else 8 ) - audio_latents, audio_num_frames = self.prepare_audio_latents( + audio_latents = self.prepare_audio_latents( batch_size * num_videos_per_prompt, num_channels_latents=num_channels_latents_audio, + audio_latent_length=audio_num_frames, num_mel_bins=num_mel_bins, - num_frames=num_frames, # Video frames, audio frames will be calculated from this - frame_rate=frame_rate, - sampling_rate=self.audio_sampling_rate, - hop_length=self.audio_hop_length, dtype=torch.float32, device=device, generator=generator, diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py index 92206cee4e9b..460ff8eec7de 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py @@ -742,32 +742,23 @@ def prepare_audio_latents( self, batch_size: int = 1, num_channels_latents: int = 8, + audio_latent_length: int = 1, # 1 is just a dummy value num_mel_bins: int = 64, - num_frames: int = 121, - frame_rate: float = 25.0, - sampling_rate: int = 16000, - hop_length: int = 160, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, generator: Optional[torch.Generator] = None, latents: Optional[torch.Tensor] = None, ) -> torch.Tensor: - duration_s = num_frames / frame_rate - latents_per_second = ( - float(sampling_rate) / float(hop_length) / float(self.audio_vae_temporal_compression_ratio) - ) - latent_length = round(duration_s * latents_per_second) - if latents is not None: if latents.ndim == 4: # latents are of shape [B, C, L, M], need to be packed latents = self._pack_audio_latents(latents) - return latents.to(device=device, dtype=dtype), latent_length + return latents.to(device=device, dtype=dtype) # TODO: confirm whether this logic is correct latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio - shape = (batch_size, num_channels_latents, latent_length, latent_mel_bins) + shape = (batch_size, num_channels_latents, audio_latent_length, latent_mel_bins) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( @@ -777,7 +768,7 @@ def prepare_audio_latents( latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) latents = self._pack_audio_latents(latents) - return latents, latent_length + return latents @property def guidance_scale(self): @@ -995,6 +986,19 @@ def __call__( ) # 4. Prepare latent variables + latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + latent_height = height // self.vae_spatial_compression_ratio + latent_width = width // self.vae_spatial_compression_ratio + if latents is not None: + if latents.ndim == 5: + _, _, latent_num_frames, latent_height, latent_width = latents.shape # [B, C, F, H, W] + else: + logger.warning( + f"You have supplied packed `latents` of shape {latents.shape}, so the latent dims cannot be" + f" inferred. Make sure the supplied `height`, `width`, and `num_frames` are correct." + ) + video_sequence_length = latent_num_frames * latent_height * latent_width + if latents is None: image = self.video_processor.preprocess(image, height=height, width=width) image = image.to(device=device, dtype=prompt_embeds.dtype) @@ -1015,20 +1019,30 @@ def __call__( if self.do_classifier_free_guidance: conditioning_mask = torch.cat([conditioning_mask, conditioning_mask]) + duration_s = num_frames / frame_rate + audio_latents_per_second = ( + self.audio_sampling_rate / self.audio_hop_length / float(self.audio_vae_temporal_compression_ratio) + ) + audio_num_frames = round(duration_s * audio_latents_per_second) + if audio_latents is not None: + if audio_latents.ndim == 4: + _, _, audio_num_frames, _ = audio_latents.shape # [B, C, L, M] + else: + logger.warning( + f"You have supplied packed `audio_latents` of shape {audio_latents.shape}, so the latent dims" + f" cannot be inferred. Make sure the supplied `num_frames` is correct." + ) + num_mel_bins = self.audio_vae.config.mel_bins if getattr(self, "audio_vae", None) is not None else 64 latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio - num_channels_latents_audio = ( self.audio_vae.config.latent_channels if getattr(self, "audio_vae", None) is not None else 8 ) - audio_latents, audio_num_frames = self.prepare_audio_latents( + audio_latents = self.prepare_audio_latents( batch_size * num_videos_per_prompt, num_channels_latents=num_channels_latents_audio, + audio_latent_length=audio_num_frames, num_mel_bins=num_mel_bins, - num_frames=num_frames, # Video frames, audio frames will be calculated from this - frame_rate=frame_rate, - sampling_rate=self.audio_sampling_rate, - hop_length=self.audio_hop_length, dtype=torch.float32, device=device, generator=generator, @@ -1036,11 +1050,6 @@ def __call__( ) # 5. Prepare timesteps - latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 - latent_height = height // self.vae_spatial_compression_ratio - latent_width = width // self.vae_spatial_compression_ratio - video_sequence_length = latent_num_frames * latent_height * latent_width - sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas mu = calculate_shift( video_sequence_length, From 7574bf991132f4fc2aab8def4b7855ff2f8a38b7 Mon Sep 17 00:00:00 2001 From: Pham Hong Vinh Date: Wed, 14 Jan 2026 22:50:38 +0700 Subject: [PATCH 09/12] add note for predefined sigmas --- src/diffusers/pipelines/ltx2/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/pipelines/ltx2/utils.py b/src/diffusers/pipelines/ltx2/utils.py index 7e143edc9bb6..8b790a4df0cb 100644 --- a/src/diffusers/pipelines/ltx2/utils.py +++ b/src/diffusers/pipelines/ltx2/utils.py @@ -1,3 +1,5 @@ +# Pre-trained sigma values for distilled model are taken from +# https://github.com/Lightricks/LTX-2/blob/main/packages/ltx-pipelines/src/ltx_pipelines/utils/constants.py DISTILLED_SIGMA_VALUES = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875] # Reduced schedule for super-resolution stage 2 (subset of distilled values) From c22eed5a8ef7e471e3607fc20417eee8dc39826e Mon Sep 17 00:00:00 2001 From: Pham Hong Vinh Date: Wed, 14 Jan 2026 22:51:35 +0700 Subject: [PATCH 10/12] run make style and quality --- src/diffusers/pipelines/ltx2/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/ltx2/utils.py b/src/diffusers/pipelines/ltx2/utils.py index 8b790a4df0cb..f80469817fe6 100644 --- a/src/diffusers/pipelines/ltx2/utils.py +++ b/src/diffusers/pipelines/ltx2/utils.py @@ -1,4 +1,4 @@ -# Pre-trained sigma values for distilled model are taken from +# Pre-trained sigma values for distilled model are taken from # https://github.com/Lightricks/LTX-2/blob/main/packages/ltx-pipelines/src/ltx_pipelines/utils/constants.py DISTILLED_SIGMA_VALUES = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875] From c282485d0b03f1052b00aee795c95658564e5a2d Mon Sep 17 00:00:00 2001 From: Pham Hong Vinh Date: Tue, 20 Jan 2026 23:04:57 +0700 Subject: [PATCH 11/12] revert distill timesteps & set original_state_dict_repo_idd to default None --- scripts/convert_ltx2_to_diffusers.py | 2 +- src/diffusers/pipelines/ltx2/utils.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/scripts/convert_ltx2_to_diffusers.py b/scripts/convert_ltx2_to_diffusers.py index 72b334b71e71..e6a9aea4e46c 100644 --- a/scripts/convert_ltx2_to_diffusers.py +++ b/scripts/convert_ltx2_to_diffusers.py @@ -667,7 +667,7 @@ def get_args(): parser.add_argument( "--original_state_dict_repo_id", - default="Lightricks/LTX-2", + default=None, type=str, help="HF Hub repo id with LTX 2.0 checkpoint", ) diff --git a/src/diffusers/pipelines/ltx2/utils.py b/src/diffusers/pipelines/ltx2/utils.py index f80469817fe6..77a0e3a883a3 100644 --- a/src/diffusers/pipelines/ltx2/utils.py +++ b/src/diffusers/pipelines/ltx2/utils.py @@ -1,6 +1,6 @@ # Pre-trained sigma values for distilled model are taken from # https://github.com/Lightricks/LTX-2/blob/main/packages/ltx-pipelines/src/ltx_pipelines/utils/constants.py -DISTILLED_SIGMA_VALUES = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875] +DISTILLED_SIGMA_VALUES = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0] # Reduced schedule for super-resolution stage 2 (subset of distilled values) -STAGE_2_DISTILLED_SIGMA_VALUES = [0.909375, 0.725, 0.421875] +STAGE_2_DISTILLED_SIGMA_VALUES = [0.909375, 0.725, 0.421875, 0.0] From 62acd4cf85f49ae1c6b65cae38c3be3bef6699c5 Mon Sep 17 00:00:00 2001 From: Pham Hong Vinh Date: Wed, 21 Jan 2026 00:01:43 +0700 Subject: [PATCH 12/12] add latent normalize --- src/diffusers/pipelines/ltx2/pipeline_ltx2.py | 7 +++++++ src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py | 8 ++++++++ 2 files changed, 15 insertions(+) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py index 54f1061da5c9..b72a1079f92f 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py @@ -594,6 +594,12 @@ def _denormalize_latents( latents = latents * latents_std / scaling_factor + latents_mean return latents + @staticmethod + def _normalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor): + latents_mean = latents_mean.to(latents.device, latents.dtype) + latents_std = latents_std.to(latents.device, latents.dtype) + return (latents - latents_mean) / latents_std + @staticmethod def _denormalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor): latents_mean = latents_mean.to(latents.device, latents.dtype) @@ -693,6 +699,7 @@ def prepare_audio_latents( if latents.ndim == 4: # latents are of shape [B, C, L, M], need to be packed latents = self._pack_audio_latents(latents) + latents = self._normalize_audio_latents(latents, self.audio_vae.latents_mean, self.audio_vae.latents_std) return latents.to(device=device, dtype=dtype) # TODO: confirm whether this logic is correct diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py index 460ff8eec7de..a316f5307130 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py @@ -656,6 +656,13 @@ def _unpack_audio_latents( latents = latents.unflatten(2, (-1, num_mel_bins)).transpose(1, 2) return latents + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._normalize_audio_latents + def _normalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor): + latents_mean = latents_mean.to(latents.device, latents.dtype) + latents_std = latents_std.to(latents.device, latents.dtype) + return (latents - latents_mean) / latents_std + @staticmethod # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._denormalize_audio_latents def _denormalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor): @@ -753,6 +760,7 @@ def prepare_audio_latents( if latents.ndim == 4: # latents are of shape [B, C, L, M], need to be packed latents = self._pack_audio_latents(latents) + latents = self._normalize_audio_latents(latents, self.audio_vae.latents_mean, self.audio_vae.latents_std) return latents.to(device=device, dtype=dtype) # TODO: confirm whether this logic is correct