
Conversation

@rootonchair
Contributor

@rootonchair rootonchair commented Jan 9, 2026

What does this PR do?

Fixes #12925

Test script t2i

import torch
from diffusers.pipelines.ltx2 import LTX2Pipeline, LTX2LatentUpsamplePipeline
from diffusers.pipelines.ltx2.utils import DISTILLED_SIGMA_VALUES, STAGE_2_DISTILLED_SIGMA_VALUES
from diffusers.pipelines.ltx2.export_utils import encode_video
from diffusers.utils import load_image

pipe = LTX2Pipeline.from_pretrained("rootonchair/LTX-2-19b-distilled", torch_dtype=torch.bfloat16)
upsample_pipe = LTX2LatentUpsamplePipeline.from_pretrained("rootonchair/LTX-2-19b-distilled", subfolder="upsample_pipeline", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()

prompt = "A beautiful sunset over the ocean"
negative_prompt = "shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static."

frame_rate = 24.0
video_latent, audio_latent = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    width=768,
    height=512,
    num_frames=121,
    frame_rate=frame_rate,
    num_inference_steps=8,
    sigmas=DISTILLED_SIGMA_VALUES,
    guidance_scale=1.0,
    output_type="latent",
    return_dict=False,
)

upscaled_video_latent = upsample_pipe(
    latents=video_latent,
    output_type="latent",
    return_dict=False,
)[0]

video, audio = pipe(
    latents=upscaled_video_latent,
    audio_latents=audio_latent,
    prompt=prompt,
    negative_prompt=negative_prompt,
    num_inference_steps=3,
    sigmas=STAGE_2_DISTILLED_SIGMA_VALUES,
    guidance_scale=1.0,
    output_type="np",
    return_dict=False,
)
video = (video * 255).round().astype("uint8")
video = torch.from_numpy(video)

encode_video(
    video[0],
    fps=frame_rate,
    audio=audio[0].float().cpu(),
    audio_sample_rate=pipe.vocoder.config.output_sampling_rate,
    output_path="ltx2_sample.mp4",
)

Before submitting

Who can review?

@sayakpaul

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@rootonchair rootonchair marked this pull request as ready for review January 11, 2026 15:39
@rootonchair
Contributor Author

Running results
Original

output.mp4

Converted ckpt

ltx2_sample.mp4

@sayakpaul sayakpaul requested a review from dg845 January 12, 2026 03:41
Member

@sayakpaul sayakpaul left a comment

Thanks for shipping this so quickly! Left some comments, LMK if they make sense.

"upsample_residual": (True, True, True),
"upsample_factor": (2, 2, 2),
"timestep_conditioning": False,
"timestep_conditioning": timestep_conditioning,
Member

Where is timestep_conditioning used in the corresponding Video VAE class?

Contributor Author

@rootonchair rootonchair Jan 12, 2026

timestep_conditioning is used here. It depends on the metadata stored in the original checkpoint, and the distilled checkpoint is configured to use timestep_conditioning. You can confirm it with the script below:

from ltx_core.loader.sft_loader import SafetensorsModelStateDictLoader
model_loader = SafetensorsModelStateDictLoader()
model_loader.metadata("weights/ltx-2-19b-distilled.safetensors")["vae"]["timestep_conditioning"] # True

Hence, I got some unexpected keys if I don't set timestep_conditioning to True when converting the distilled weights.

Collaborator

Hi @rootonchair, there was a recent update on the official Lightricks/LTX-2 repo which updated the VAE for the distilled checkpoint: https://huggingface.co/Lightricks/LTX-2/commit/1931de987c8e265eb64a9123227b903754e3cc68. So I think rootonchair/LTX-2-19b-distilled needs to be updated with the new converted VAE as well.

It looks like timestep_conditioning should be set to False for the new video VAE, according to the config metadata on ltx-2-19b-distilled.safetensors. So we should probably remove the unexpected keys (last_time_embedder, last_scale_shift_table) rather than processing them in the VAE conversion script.

import json
import safetensors
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download("Lightricks/LTX-2", "ltx-2-19b-distilled.safetensors")
with safetensors.safe_open(ckpt_path, framework="pt") as f:
    config = json.loads(f.metadata()["config"])
config["vae"]["timestep_conditioning"]  # Should be False

Contributor Author

I think I will convert the new weights. As for the conversion script, I think we should keep that handling there until the original repo decides to delete it.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Collaborator

@dg845 dg845 left a comment

Thanks for the PR! Left some comments about the distilled sigmas schedule.

If I print out the timesteps for the Stage 1 distilled pipeline, I get (for commit faeccc5):

Distilled timesteps: tensor([1000.0000,  999.6502,  999.2961,  998.9380,  998.5754,  994.4882,
         979.3755,  929.6974,  100.0000], device='cuda:0')

Here the sigmas (and thus the timesteps) are shifted toward a terminal value of 0.1, and use_dynamic_shifting is applied as well. However, I believe the distilled sigmas are used as-is in the original LTX 2.0 code:

https://github.com/Lightricks/LTX-2/blob/391c0a2462f8bc2238fa0a2b7717bc1914969230/packages/ltx-pipelines/src/ltx_pipelines/utils/helpers.py#L132

So I think when creating the distilled scheduler we need to disable use_dynamic_shifting and shift_terminal so that the distilled sigmas are used without changes.
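
For example, a minimal sketch of building such a scheduler from the converted repo's config (this mirrors what the later scripts in this thread do) would be:

from diffusers import FlowMatchEulerDiscreteScheduler

# Disable dynamic shifting and terminal shifting so the distilled sigmas pass through unchanged
distilled_scheduler = FlowMatchEulerDiscreteScheduler.from_config(
    pipe.scheduler.config, use_dynamic_shifting=False, shift_terminal=None
)
pipe.scheduler = distilled_scheduler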

Can you check whether the final distilled sigmas match up with those of the original implementation?

@dg845
Collaborator

dg845 commented Jan 13, 2026

The original test script didn't work for me, but I was able to get a working version as follows:

import torch
from diffusers.pipelines.ltx2 import LTX2Pipeline, LTX2LatentUpsamplePipeline
from diffusers.pipelines.ltx2.latent_upsampler import LTX2LatentUpsamplerModel
from diffusers.pipelines.ltx2.utils import DISTILLED_SIGMA_VALUES, STAGE_2_DISTILLED_SIGMA_VALUES
from diffusers.pipelines.ltx2.export_utils import encode_video

device = "cuda:0"
width = 768
height = 512

pipe = LTX2Pipeline.from_pretrained(
    "rootonchair/LTX-2-19b-distilled", torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload(device=device)

prompt = "A beautiful sunset over the ocean"
negative_prompt = "shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static."

frame_rate = 24.0
video_latent, audio_latent = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    width=width,
    height=height,
    num_frames=121,
    frame_rate=frame_rate,
    num_inference_steps=8,
    sigmas=DISTILLED_SIGMA_VALUES,
    guidance_scale=1.0,
    output_type="latent",
    return_dict=False,
)

latent_upsampler = LTX2LatentUpsamplerModel.from_pretrained(
    "rootonchair/LTX-2-19b-distilled",
    subfolder="upsample_pipeline/latent_upsampler",
    torch_dtype=torch.bfloat16,
)
upsample_pipe = LTX2LatentUpsamplePipeline(vae=pipe.vae, latent_upsampler=latent_upsampler)
upsample_pipe.enable_model_cpu_offload(device=device)
upscaled_video_latent = upsample_pipe(
    latents=video_latent,
    output_type="latent",
    return_dict=False,
)[0]

video, audio = pipe(
    latents=upscaled_video_latent,
    audio_latents=audio_latent,
    prompt=prompt,
    negative_prompt=negative_prompt,
    width=width * 2,
    height=height * 2,
    num_inference_steps=3,
    sigmas=STAGE_2_DISTILLED_SIGMA_VALUES,
    guidance_scale=1.0,
    output_type="np",
    return_dict=False,
)
video = (video * 255).round().astype("uint8")
video = torch.from_numpy(video)

encode_video(
    video[0],
    fps=frame_rate,
    audio=audio[0].float().cpu(),
    audio_sample_rate=pipe.vocoder.config.output_sampling_rate,
    output_path="ltx2_distilled_sample.mp4",
)

The necessary changes were to create LTX2LatentUpsamplePipeline directly from the models, since from_pretrained doesn't work when there are two pipelines (the distilled pipeline and the latent upsampling pipeline) in the same repo (a known limitation), and to set Stage 2 inference to use width * 2 and height * 2, since the upsampling pipeline upsamples the video by 2x in both width and height.

rootonchair and others added 2 commits January 13, 2026 12:01
Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
@sayakpaul
Member

Not jeopardizing this PR at all but while we're at the two-stage pipeline stuff, it could also be cool to verify it with the distilled LoRA that we have in place (PR already merged: #12933).

So, what we would do is:

  1. Obtain video and audio latents with the regular LTX2 pipeline
  2. Upsample the video latents
  3. Load the distilled LoRA into the regular LTX2 pipeline
  4. Run the pipeline with 4 steps and these sigmas and with a guidance scale of 1.

Once we're close to merging the PR, we could document all of these to inform the community.

@rootonchair
Contributor Author

@dg845 thank you for your detailed review. Let me take a closer look at that.

@rootonchair
Contributor Author

Not jeopardizing this PR at all but while we're at the two-stage pipeline stuff, it could also be cool to verify it with the distilled LoRA that we have in place (PR already merged: #12933).

So, what we would do is:

  1. Obtain video and audio latents with the regular LTX2 pipeline
  2. Upsample the video latents
  3. Load the distilled LoRA into the regular LTX2 pipeline
  4. Run the pipeline with 4 steps and these sigmas and with a guidance scale of 1.

Once we're close to merging the PR, we could document all of these to inform the community.

That sounds interesting. Let's run a quick test on the two-stage distilled LoRA too.

@dg845
Collaborator

dg845 commented Jan 13, 2026

For two-stage inference with the Stage 2 distilled LoRA, I think this script should work:

import torch
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.pipelines.ltx2 import LTX2Pipeline, LTX2LatentUpsamplePipeline
from diffusers.pipelines.ltx2.latent_upsampler import LTX2LatentUpsamplerModel
from diffusers.pipelines.ltx2.utils import DISTILLED_SIGMA_VALUES, STAGE_2_DISTILLED_SIGMA_VALUES
from diffusers.pipelines.ltx2.export_utils import encode_video

device = "cuda:0"
width = 768
height = 512

pipe = LTX2Pipeline.from_pretrained(
    "rootonchair/LTX-2-19b-distilled", torch_dtype=torch.bfloat16
)
# This scheduler should use distilled sigmas without any changes
new_scheduler = FlowMatchEulerDiscreteScheduler.from_config(
    pipe.scheduler.config, use_dynamic_shifting=False, shift_terminal=None
)
pipe.scheduler = new_scheduler
pipe.enable_model_cpu_offload(device=device)

prompt = "A beautiful sunset over the ocean"
negative_prompt = "shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static."

frame_rate = 24.0
video_latent, audio_latent = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    width=width,
    height=height,
    num_frames=121,
    frame_rate=frame_rate,
    num_inference_steps=8,
    sigmas=DISTILLED_SIGMA_VALUES,
    guidance_scale=1.0,
    output_type="latent",
    return_dict=False,
)

latent_upsampler = LTX2LatentUpsamplerModel.from_pretrained(
    "rootonchair/LTX-2-19b-distilled",
    subfolder="upsample_pipeline/latent_upsampler",
    torch_dtype=torch.bfloat16,
)
upsample_pipe = LTX2LatentUpsamplePipeline(vae=pipe.vae, latent_upsampler=latent_upsampler)
upsample_pipe.enable_model_cpu_offload(device=device)
upscaled_video_latent = upsample_pipe(
    latents=video_latent,
    output_type="latent",
    return_dict=False,
)[0]

# Load Stage 2 distilled LoRA
pipe.load_lora_weights(
    "Lightricks/LTX-2", adapter_name="stage_2_distilled", weight_name="ltx-2-19b-distilled-lora-384.safetensors"
)
pipe.set_adapters("stage_2_distilled", 1.0)
# VAE tiling seems necessary to avoid OOM error when VAE decoding
pipe.vae.enable_tiling()
video, audio = pipe(
    latents=upscaled_video_latent,
    audio_latents=audio_latent,
    prompt=prompt,
    negative_prompt=negative_prompt,
    width=width * 2,
    height=height * 2,
    num_inference_steps=3,
    sigmas=STAGE_2_DISTILLED_SIGMA_VALUES,
    guidance_scale=1.0,
    output_type="np",
    return_dict=False,
)
video = (video * 255).round().astype("uint8")
video = torch.from_numpy(video)

encode_video(
    video[0],
    fps=frame_rate,
    audio=audio[0].float().cpu(),
    audio_sample_rate=pipe.vocoder.config.output_sampling_rate,
    output_path="ltx2_distilled_sample.mp4",
)

Sample with LoRA and scheduler fix:

ltx2_distilled_sample_lora_fix.mp4

@rootonchair
Contributor Author

For two-stage inference with the Stage 2 distilled LoRA, I think this script should work:

import torch
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.pipelines.ltx2 import LTX2Pipeline, LTX2LatentUpsamplePipeline
from diffusers.pipelines.ltx2.latent_upsampler import LTX2LatentUpsamplerModel
from diffusers.pipelines.ltx2.utils import DISTILLED_SIGMA_VALUES, STAGE_2_DISTILLED_SIGMA_VALUES
from diffusers.pipelines.ltx2.export_utils import encode_video

device = "cuda:0"
width = 768
height = 512

pipe = LTX2Pipeline.from_pretrained(
    "rootonchair/LTX-2-19b-distilled", torch_dtype=torch.bfloat16
)
# This scheduler should use distilled sigmas without any changes
new_scheduler = FlowMatchEulerDiscreteScheduler.from_config(
    pipe.scheduler.config, use_dynamic_shifting=False, shift_terminal=None
)
pipe.scheduler = new_scheduler
pipe.enable_model_cpu_offload(device=device)

prompt = "A beautiful sunset over the ocean"
negative_prompt = "shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static."

frame_rate = 24.0
video_latent, audio_latent = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    width=width,
    height=height,
    num_frames=121,
    frame_rate=frame_rate,
    num_inference_steps=8,
    sigmas=DISTILLED_SIGMA_VALUES,
    guidance_scale=1.0,
    output_type="latent",
    return_dict=False,
)

latent_upsampler = LTX2LatentUpsamplerModel.from_pretrained(
    "rootonchair/LTX-2-19b-distilled",
    subfolder="upsample_pipeline/latent_upsampler",
    torch_dtype=torch.bfloat16,
)
upsample_pipe = LTX2LatentUpsamplePipeline(vae=pipe.vae, latent_upsampler=latent_upsampler)
upsample_pipe.enable_model_cpu_offload(device=device)
upscaled_video_latent = upsample_pipe(
    latents=video_latent,
    output_type="latent",
    return_dict=False,
)[0]

# Load Stage 2 distilled LoRA
pipe.load_lora_weights(
    "Lightricks/LTX-2", adapter_name="stage_2_distilled", weight_name="ltx-2-19b-distilled-lora-384.safetensors"
)
pipe.set_adapters("stage_2_distilled", 1.0)
# VAE tiling seems necessary to avoid OOM error when VAE decoding
pipe.vae.enable_tiling()
video, audio = pipe(
    latents=upscaled_video_latent,
    audio_latents=audio_latent,
    prompt=prompt,
    negative_prompt=negative_prompt,
    width=width * 2,
    height=height * 2,
    num_inference_steps=3,
    sigmas=STAGE_2_DISTILLED_SIGMA_VALUES,
    guidance_scale=1.0,
    output_type="np",
    return_dict=False,
)
video = (video * 255).round().astype("uint8")
video = torch.from_numpy(video)

encode_video(
    video[0],
    fps=frame_rate,
    audio=audio[0].float().cpu(),
    audio_sample_rate=pipe.vocoder.config.output_sampling_rate,
    output_path="ltx2_distilled_sample.mp4",
)

Sample with LoRA and scheduler fix:

ltx2_distilled_sample_lora_fix.mp4

I think we should run stage 1 with the original LTX2 weights and not the distilled checkpoint. WDYT?

@sayakpaul
Member

Yes, the first stage, in this case, should use the non-distilled ckpt.

@dg845
Collaborator

dg845 commented Jan 13, 2026

Fixed script (I think):

import torch
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.pipelines.ltx2 import LTX2Pipeline, LTX2LatentUpsamplePipeline
from diffusers.pipelines.ltx2.latent_upsampler import LTX2LatentUpsamplerModel
from diffusers.pipelines.ltx2.utils import DISTILLED_SIGMA_VALUES, STAGE_2_DISTILLED_SIGMA_VALUES
from diffusers.pipelines.ltx2.export_utils import encode_video

device = "cuda:0"
width = 768
height = 512

pipe = LTX2Pipeline.from_pretrained(
    "Lightricks/LTX-2", torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload(device=device)

prompt = "A beautiful sunset over the ocean"
negative_prompt = "shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static."

# Stage 1 default (non-distilled) inference
frame_rate = 24.0
video_latent, audio_latent = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    width=width,
    height=height,
    num_frames=121,
    frame_rate=frame_rate,
    num_inference_steps=40,
    sigmas=None,
    guidance_scale=4.0,
    output_type="latent",
    return_dict=False,
)

latent_upsampler = LTX2LatentUpsamplerModel.from_pretrained(
    "Lightricks/LTX-2",
    subfolder="latent_upsampler",
    torch_dtype=torch.bfloat16,
)
upsample_pipe = LTX2LatentUpsamplePipeline(vae=pipe.vae, latent_upsampler=latent_upsampler)
upsample_pipe.enable_model_cpu_offload(device=device)
upscaled_video_latent = upsample_pipe(
    latents=video_latent,
    output_type="latent",
    return_dict=False,
)[0]

# Load Stage 2 distilled LoRA
pipe.load_lora_weights(
    "Lightricks/LTX-2", adapter_name="stage_2_distilled", weight_name="ltx-2-19b-distilled-lora-384.safetensors"
)
pipe.set_adapters("stage_2_distilled", 1.0)
# VAE tiling seems necessary to avoid OOM error when VAE decoding
pipe.vae.enable_tiling()
# Change scheduler to use Stage 2 distilled sigmas as is
new_scheduler = FlowMatchEulerDiscreteScheduler.from_config(
    pipe.scheduler.config, use_dynamic_shifting=False, shift_terminal=None
)
pipe.scheduler = new_scheduler
# Stage 2 inference with distilled LoRA and sigmas
video, audio = pipe(
    latents=upscaled_video_latent,
    audio_latents=audio_latent,
    prompt=prompt,
    negative_prompt=negative_prompt,
    width=width * 2,
    height=height * 2,
    num_inference_steps=3,
    sigmas=STAGE_2_DISTILLED_SIGMA_VALUES,
    guidance_scale=1.0,
    output_type="np",
    return_dict=False,
)
video = (video * 255).round().astype("uint8")
video = torch.from_numpy(video)

encode_video(
    video[0],
    fps=frame_rate,
    audio=audio[0].float().cpu(),
    audio_sample_rate=pipe.vocoder.config.output_sampling_rate,
    output_path="ltx2_distilled_sample.mp4",
)

@dg845
Collaborator

dg845 commented Jan 13, 2026

If I test the distilled pipeline with the prompt "a dog dancing to energetic electronic dance music", I get the following sample:

ltx2_distilled_sample_dog_edm.mp4

I would expect the audio to be music for this prompt, but instead the audio is only noise, so I think there might be something wrong with the way audio is currently being handled in the distilled pipeline. (The video also doesn't follow the prompt closely; I'm not sure if this is a symptom of the audio being messed up or if there are also bugs for video processing.)

@sayakpaul
Member

@dg845 should the second stage inference with LoRA be run with 4 num_inference_steps? Also, is the following necessary?

width=width * 2,
height=height * 2,

@dg845
Collaborator

dg845 commented Jan 14, 2026

should the second stage inference with LoRA be run with 4 num_inference_steps?

I believe it should be run with 3, as STAGE_2_DISTILLED_SIGMA_VALUES has 3 values excluding the trailing 0.0. Note that if sigmas is set to STAGE_2_DISTILLED_SIGMA_VALUES, this will currently override num_inference_steps.

Also, is the following necessary?

It is currently necessary as otherwise we'd get a shape error, but we could modify the code to infer the latent_height and latent_width from supplied latents, by changing this part:

latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
latent_height = height // self.vae_spatial_compression_ratio
latent_width = width // self.vae_spatial_compression_ratio
video_sequence_length = latent_num_frames * latent_height * latent_width
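
A rough sketch of that change (hypothetical; dimension names follow the snippet above):

if latents is not None:
    # Hypothetical: infer latent dimensions from the supplied (unpacked) video latents [B, C, F, H, W]
    latent_num_frames, latent_height, latent_width = latents.shape[2:]
else:
    latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
    latent_height = height // self.vae_spatial_compression_ratio
    latent_width = width // self.vae_spatial_compression_ratio
video_sequence_length = latent_num_frames * latent_height * latent_width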

@dg845
Collaborator

dg845 commented Jan 14, 2026

Wrote a sample commit f4d47b9 which infers the latent dimensions (e.g. latent_height, latent_width, etc.) if latents or audio_latents is supplied. This should make it easier to run two-stage inference without having to manage the width and height.

@rootonchair feel free to cherry-pick the changes if they work for you.

@rootonchair
Contributor Author

Thanks for the PR! Left some comments about the distilled sigmas schedule.

If I print out the timesteps for the Stage 1 distilled pipeline, I get (for commit faeccc5):

Distilled timesteps: tensor([1000.0000,  999.6502,  999.2961,  998.9380,  998.5754,  994.4882,
         979.3755,  929.6974,  100.0000], device='cuda:0')

Here the sigmas (and thus the timesteps) are shifted toward a terminal value of 0.1, and use_dynamic_shifting is applied as well. However, I believe the distilled sigmas are used as-is in the original LTX 2.0 code:

https://github.com/Lightricks/LTX-2/blob/391c0a2462f8bc2238fa0a2b7717bc1914969230/packages/ltx-pipelines/src/ltx_pipelines/utils/helpers.py#L132

So I think when creating the distilled scheduler we need to disable use_dynamic_shifting and shift_terminal so that the distilled sigmas are used without changes.

Can you check whether the final distilled sigmas match up with those of the original implementation?

I have checked against the original implementation and the timesteps match each other, since the sigmas are later scaled up by 1000 here.

@rootonchair
Contributor Author

If I test the distilled pipeline with the prompt "a dog dancing to energetic electronic dance music", I get the following sample:

ltx2_distilled_sample_dog_edm.mp4

I would expect the audio to be music for this prompt, but instead the audio is only noise, so I think there might be something wrong with the way audio is currently being handled in the distilled pipeline. (The video also doesn't follow the prompt closely; I'm not sure if this is a symptom of the audio being messed up or if there are also bugs for video processing.)

This is really weird; let's investigate this more before we ship the model.

@rootonchair
Contributor Author

Wrote a sample commit f4d47b9 which infers the latent dimensions (e.g. latent_height, latent_width, etc.) if latents or audio_latents is supplied. This should make it easier to run two-stage inference without having to manage the width and height.

@rootonchair feel free to cherry-pick the changes if they work for you.

Thanks

f4d47b9

@dg845 I think your commit is good. Just need to check a few more things before applying it.

if latents is not None:
    if latents.ndim == 4:
        # latents are of shape [B, C, L, M], need to be packed
        latents = self._pack_audio_latents(latents)
Collaborator

@dg845 dg845 Jan 14, 2026

When the output_type is set to "latent", LTX2Pipeline will return the denormalized audio latents:

audio_latents = self._denormalize_audio_latents(
    audio_latents, self.audio_vae.latents_mean, self.audio_vae.latents_std
)
audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins)

Since the DiT expects normalized latents, I think we need to normalize the audio latents here:

                latents = self._pack_audio_latents(latents)
                latents = self._normalize_audio_latents(latents, self.audio_vae.latents_mean, self.audio_vae.latents_std)

where _normalize_audio_latents is something like

    @staticmethod
    def _normalize_audio_latents(
        latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor
    ) -> torch.Tensor:
        # Normalize latents across the combined channel and mel bin dimension [B, L, C * M]
        latents_mean = latents_mean.to(latents.device, latents.dtype)
        latents_std = latents_std.to(latents.device, latents.dtype)
        latents = (latents - latents_mean) / latents_std
        return latents

This should make the Stage 2 audio latents have the expected distribution.

Collaborator

@dg845 dg845 Jan 14, 2026

Note that although LTX2Pipeline also returns the denormalized video latents when output_type="latent", LTX2LatentUpsamplePipeline returns the normalized latents, so the Stage 2 video latents are not affected by this issue.

Collaborator

(However, prepare_latents-type methods usually expect supplied latents to be normalized, since we generally return them as-is without normalizing them, so we might need to think through the design on where we expect latents to be normalized or denormalized. CC @sayakpaul.)

Member

Hmm, this is a bit of a spiraling situation, and I wonder if exposing a flag like normalize_latents (or leveraging the latents_normalized flag as used in the upsampling pipeline) could make sense here (kept to a reasonable default), with users educated about when to use which value for normalize_latents. I think this also comes with the uniqueness of the LTX2 pipeline a bit.

The situation is getting confusing likely because the video latents returned from the upsampling pipeline are always normalized. If users want them unnormalized (with a flag normalize_latents=False), then I guess both upsampled_video_latent and audio_latent could be passed as-is to the stage 2 pipeline?

(However, prepare_latents-type methods usually expect supplied latents to be normalized, since we generally return them as-is without normalizing them, so we might need to think through the design on where we expect latents to be normalized or denormalized.

I think this is both true and false. Inside Flux Kontext, the image latents are normalized inside prepare_latents():

image_latents = self._encode_vae_image(image=image, generator=generator)

This is what we also follow for LTX2 I2V:

init_latents = self._normalize_latents(init_latents, self.vae.latents_mean, self.vae.latents_std)

Contributor Author

One argument for keeping the returned latents denormalized by default is that there could be a pipeline where each module has a different normalization method. Passing only denormalized latents, with each module normalizing the latents itself, helps the modules stay plug-and-play. Moreover, the compute required for normalization is acceptable.

Member

Yes, I agree with that perspective. Let's see what @dg845 has to say about this.

Collaborator

I think this is a good idea as well.

@rootonchair
Contributor Author

I just updated the model repository: moved latent_upsampler to the top level and removed upsample_pipeline.

@rootonchair
Contributor Author

The new distilled weights have just been updated!

@rootonchair rootonchair requested a review from sayakpaul January 18, 2026 17:25
Member

@sayakpaul sayakpaul left a comment

Thanks @rootonchair!

Could you also provide some end-to-end examples of using it in the two-stage way and also with LoRA?

Comment on lines -712 to +711
return latents, latent_length
return latents
Member

Why is this being changed?

Contributor Author

This value was computed inside prepare_audio_latents and returned for use outside it; however, that computation has since been moved out into the __call__ function.

Collaborator

When I was supporting inferring the latent dimensions from latents/audio_latents in eb01780 I moved the calculation into the __call__ method as I thought it made more sense there with the new logic. This also makes the audio latents logic more parallel to the video latents logic in both __call__ and prepare_latents/prepare_audio_latents.


# 5. Prepare timesteps
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
Member

Should we convert the sigmas into an np.array, or is there no need?

Contributor Author

I don't see it being implemented in other pipelines, so I think it's not needed: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/flux2/pipeline_flux2.py#L932

@JoeGaffney

JoeGaffney commented Jan 20, 2026

If I test the distilled pipeline with the prompt "a dog dancing to energetic electronic dance music", I get the following sample:

ltx2_distilled_sample_dog_edm.mp4
I would expect the audio to be music for this prompt, but instead the audio is only noise, so I think there might be something wrong with the way audio is currently being handled in the distilled pipeline. (The video also doesn't follow the prompt closely; I'm not sure if this is a symptom of the audio being messed up or if there are also bugs for video processing.)

Hey @dg845, I have been testing LTX-2 in both Comfy and diffusers, though I haven't tried the distilled checkpoint in diffusers yet, and I am getting really poor results in diffusers. With Comfy I was doing one pass with just the distilled checkpoint, skipping the upscale, and the result was pretty decent, possibly better than the two-pass setup.

Have you been able to get good results with diffusers?

"a dog dancing to energetic electronic dance music" tested with your prompt adherence is not great but quality and audio is better.

image
LTX-2_00001_.mp4

@sayakpaul sayakpaul requested a review from dg845 January 20, 2026 12:41
@rootonchair
Contributor Author

It seems like removing the last element in the distilled sigmas causes a huge degradation in quality. @JoeGaffney could you try again? In the meantime, I will dig into why.

@rootonchair
Contributor Author

Here is an end-to-end example for the distilled checkpoint:

import torch
from diffusers.pipelines.ltx2 import LTX2Pipeline, LTX2LatentUpsamplePipeline
from diffusers.pipelines.ltx2.latent_upsampler import LTX2LatentUpsamplerModel
from diffusers.pipelines.ltx2.utils import DISTILLED_SIGMA_VALUES, STAGE_2_DISTILLED_SIGMA_VALUES
from diffusers.pipelines.ltx2.export_utils import encode_video

device = "cuda"
width = 768
height = 512
random_seed = 42
generator = torch.Generator(device).manual_seed(random_seed)
model_path = "rootonchair/LTX-2-19b-distilled"

pipe = LTX2Pipeline.from_pretrained(
    model_path, torch_dtype=torch.bfloat16
)
pipe.enable_sequential_cpu_offload(device=device)

prompt = "A beautiful sunset over the ocean"
negative_prompt = "shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static."

frame_rate = 24.0
video_latent, audio_latent = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    width=width,
    height=height,
    num_frames=121,
    frame_rate=frame_rate,
    num_inference_steps=8,
    sigmas=DISTILLED_SIGMA_VALUES,
    guidance_scale=1.0,
    generator=generator,
    output_type="latent",
    return_dict=False,
)

latent_upsampler = LTX2LatentUpsamplerModel.from_pretrained(
    model_path,
    subfolder="latent_upsampler",
    torch_dtype=torch.bfloat16,
)
upsample_pipe = LTX2LatentUpsamplePipeline(vae=pipe.vae, latent_upsampler=latent_upsampler)
upsample_pipe.enable_model_cpu_offload(device=device)
upscaled_video_latent = upsample_pipe(
    latents=video_latent,
    output_type="latent",
    return_dict=False,
)[0]

video, audio = pipe(
    latents=upscaled_video_latent,
    audio_latents=audio_latent,
    prompt=prompt,
    negative_prompt=negative_prompt,
    width=width * 2,
    height=height * 2,
    num_inference_steps=3,
    sigmas=STAGE_2_DISTILLED_SIGMA_VALUES,
    generator=generator,
    guidance_scale=1.0,
    output_type="np",
    return_dict=False,
)
video = (video * 255).round().astype("uint8")
video = torch.from_numpy(video)

encode_video(
    video[0],
    fps=frame_rate,
    audio=audio[0].float().cpu(),
    audio_sample_rate=pipe.vocoder.config.output_sampling_rate,
    output_path="ltx2_distilled_sample.mp4",
)

@rootonchair
Contributor Author

After debugging, it seems like the audio issue comes from the 2nd-stage denoising, as the audio from the first stage is still normal.

Comment on lines -664 to +670
default="Lightricks/LTX-2",
default=None,
Collaborator

I would prefer to keep the default as "Lightricks/LTX-2", as I think it is more convenient than having it be None and the script user having to specify a repo id each time.

Collaborator

@dg845 dg845 left a comment

Thanks for the changes! I think the sigmas are still being calculated incorrectly. If I print out the scheduler sigmas after

self._num_timesteps = len(timesteps)

on commit 62acd4c, I get the following output:

Timesteps: tensor([1000.0000,  999.6502,  999.2961,  998.9380,  998.5754,  994.4882,
         979.3755,  929.6974,  100.0000], device='cuda:0')
Num timesteps: 9
Sigmas: tensor([1.0000, 0.9997, 0.9993, 0.9989, 0.9986, 0.9945, 0.9794, 0.9297, 0.1000,
        0.0000], device='cuda:0')
Num sigmas: 10
100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:19<00:00,  2.16s/it]
Timesteps: tensor([999.9924, 999.9713, 999.8963, 100.0000], device='cuda:0')
Num timesteps: 4
Sigmas: tensor([1.0000, 1.0000, 0.9999, 0.1000, 0.0000], device='cuda:0')
Num sigmas: 5
100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:34<00:00,  8.50s/it]

whereas if I print out the sigmas from the original LTX 2 code (before this line), I get

Sigmas: tensor([1.0000, 0.9937, 0.9875, 0.9812, 0.9750, 0.9094, 0.7250, 0.4219, 0.0000],
       device='cuda:0')
Num sigmas: 9
100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:10<00:00,  1.34s/it]
Sigmas: tensor([0.9094, 0.7250, 0.4219, 0.0000], device='cuda:0')
Num sigmas: 4
100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:21<00:00,  7.07s/it]

So the original code is using the distilled sigmas as-is. I believe the scheduler in rootonchair/LTX-2-19b-distilled sets use_dynamic_shifting=True and shift_terminal=0.1 (the settings for the non-distilled pipeline), which explains the discrepancy; these should be set to use_dynamic_shifting=False and shift_terminal=None, respectively.

Furthermore, because FlowMatchEulerDiscreteScheduler will add a trailing 0 to the sigma schedule:

# 6. Append the terminal sigma value.
# If a model requires inverted sigma schedule for denoising but timesteps without inversion, the
# `invert_sigmas` flag can be set to `True`. This case is only required in Mochi
if self.config.invert_sigmas:
    sigmas = 1.0 - sigmas
    timesteps = sigmas * self.config.num_train_timesteps
    sigmas = torch.cat([sigmas, torch.ones(1, device=sigmas.device)])
else:
    sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])

I think removing the trailing 0.0 from DISTILLED_SIGMA_VALUES/STAGE_2_DISTILLED_SIGMA_VALUES is correct.

@dg845
Collaborator

dg845 commented Jan 21, 2026

I think one thing we're also not currently doing is renoising the latents for Stage 2. The original LTX 2 distilled pipeline code calls denoise_audio_video for the Stage 2 denoising loop, which internally calls noise_video_state/noise_audio_state before denoising. These functions themselves call create_noised_state.

create_noised_state calls a GaussianNoiser instance, which will sample standard Gaussian noise and then blend it with the initial latents:

        scaled_mask = latent_state.denoise_mask * noise_scale
        latent = noise * scaled_mask + latent_state.latent * (1 - scaled_mask)

Here denoise_mask is usually an all-ones tensor, and for Stage 2 noise_scale is set to STAGE_2_DISTILLED_SIGMA_VALUES[0] (see this line), which is 0.909375. (This is another reason why I think the current sigmas calculation is incorrect: the current first sigma is very close to 1 for Stage 2, which means applying renoising would completely overwrite the Stage 1 latents with noise, which doesn't make sense.)

Here is a script that implements renoising (without the scheduler and sigmas fixes suggested in #12934 (review)):

Distilled Script with Renoising
import torch
from diffusers.pipelines.ltx2 import LTX2Pipeline, LTX2LatentUpsamplePipeline
from diffusers.pipelines.ltx2.latent_upsampler import LTX2LatentUpsamplerModel
from diffusers.pipelines.ltx2.utils import DISTILLED_SIGMA_VALUES, STAGE_2_DISTILLED_SIGMA_VALUES
from diffusers.pipelines.ltx2.export_utils import encode_video


device = "cuda:0"
width = 768
height = 512
random_seed = 42
generator = torch.Generator(device).manual_seed(random_seed)
model_path = "rootonchair/LTX-2-19b-distilled"

pipe = LTX2Pipeline.from_pretrained(
    model_path, torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload(device=device)
pipe.vae.enable_tiling()

prompt = "A beautiful sunset over the ocean"
negative_prompt = "shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static."

frame_rate = 24.0
video_latent, audio_latent = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    width=width,
    height=height,
    num_frames=121,
    frame_rate=frame_rate,
    num_inference_steps=8,
    sigmas=DISTILLED_SIGMA_VALUES,
    guidance_scale=1.0,
    generator=generator,
    output_type="latent",
    return_dict=False,
)

with torch.no_grad():
    stage_1_video = video_latent.to(pipe.vae.dtype)
    stage_1_video = pipe.vae.decode(stage_1_video, None, return_dict=False)[0]
    stage_1_video = pipe.video_processor.postprocess_video(stage_1_video, output_type="np")
    stage_1_video = (stage_1_video * 255).round().astype("uint8")
    stage_1_video = torch.from_numpy(stage_1_video)

    stage_1_audio = audio_latent.to(pipe.audio_vae.dtype)
    stage_1_audio = pipe.audio_vae.decode(stage_1_audio, return_dict=False)[0]
    stage_1_audio = pipe.vocoder(stage_1_audio)

encode_video(
    stage_1_video[0],
    fps=frame_rate,
    audio=stage_1_audio[0].float().cpu(),
    audio_sample_rate=pipe.vocoder.config.output_sampling_rate,
    output_path="ltx2_distilled_sample_stage_1.mp4",
)

latent_upsampler = LTX2LatentUpsamplerModel.from_pretrained(
    model_path,
    subfolder="latent_upsampler",
    torch_dtype=torch.bfloat16,
)
upsample_pipe = LTX2LatentUpsamplePipeline(vae=pipe.vae, latent_upsampler=latent_upsampler)
upsample_pipe.enable_model_cpu_offload(device=device)
upscaled_video_latent = upsample_pipe(
    latents=video_latent,
    output_type="latent",
    return_dict=False,
)[0]

# Renoise latents for Stage 2
noise_scale = STAGE_2_DISTILLED_SIGMA_VALUES[0]  # 0.909375 < 1

video_noise = torch.randn(
    upscaled_video_latent.size(),
    generator=generator,
    dtype=upscaled_video_latent.dtype,
    device=upscaled_video_latent.device,
)
upscaled_video_latent = noise_scale * video_noise + (1 - noise_scale) * upscaled_video_latent

audio_noise = torch.randn(
    audio_latent.size(),
    generator=generator,
    dtype=audio_latent.dtype,
    device=audio_latent.device,
)
audio_latent = noise_scale * audio_noise + (1 - noise_scale) * audio_latent

video, audio = pipe(
    latents=upscaled_video_latent,
    audio_latents=audio_latent,
    prompt=prompt,
    negative_prompt=negative_prompt,
    width=width * 2,
    height=height * 2,
    num_inference_steps=3,
    sigmas=STAGE_2_DISTILLED_SIGMA_VALUES,
    generator=generator,
    guidance_scale=1.0,
    output_type="np",
    return_dict=False,
)
video = (video * 255).round().astype("uint8")
video = torch.from_numpy(video)

encode_video(
    video[0],
    fps=frame_rate,
    audio=audio[0].float().cpu(),
    audio_sample_rate=pipe.vocoder.config.output_sampling_rate,
    output_path="ltx2_distilled_sample_stage_2.mp4",
)

The script implements renoising outside the pipeline, but ideally it would be implemented inside the pipeline, perhaps here:

if latents is not None:
    if latents.ndim == 5:
        # latents are of shape [B, C, F, H, W], need to be packed
        latents = self._pack_latents(
            latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
        )

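As a rough sketch (hypothetical; it reuses randn_tensor from diffusers.utils.torch_utils and the sigmas argument passed to the pipeline), the renoising could be folded in right after the packing step:

if latents is not None:
    if latents.ndim == 5:
        # latents are of shape [B, C, F, H, W], need to be packed
        latents = self._pack_latents(
            latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
        )
    # Hypothetical: blend the supplied latents with fresh noise at the first sigma,
    # mirroring the original pipeline's GaussianNoiser behaviour
    noise = randn_tensor(latents.shape, generator=generator, device=latents.device, dtype=latents.dtype)
    latents = sigmas[0] * noise + (1 - sigmas[0]) * latents
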
Unfortunately, renoising doesn't seem to totally fix the audio for the Stage 2 sample:

ltx2_distilled_sample_stage_2.mp4


Development

Successfully merging this pull request may close these issues.

LTX-2 distilled checkpoint support
