Skip to content

Commit baaa8d0

Browse files
dg845 and sayakpaul authored
LTX 2 Improve encode_video by Accepting More Input Types (#13057)
* Support different pipeline outputs for LTX 2 encode_video * Update examples to use improved encode_video function * Fix comment * Address review comments * make style and make quality * Have non-iterator video inputs respect video_chunks_number * make style and make quality * Add warning when encode_video receives a non-denormalized np.ndarray * make style and make quality --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent 44f4dc0 commit baaa8d0

File tree

5 files changed

+68
-20
lines changed

5 files changed

+68
-20
lines changed

docs/source/en/api/pipelines/ltx2.md

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,6 @@ video, audio = pipe(
106106
output_type="np",
107107
return_dict=False,
108108
)
109-
video = (video * 255).round().astype("uint8")
110-
video = torch.from_numpy(video)
111109

112110
encode_video(
113111
video[0],
@@ -185,8 +183,6 @@ video, audio = pipe(
185183
output_type="np",
186184
return_dict=False,
187185
)
188-
video = (video * 255).round().astype("uint8")
189-
video = torch.from_numpy(video)
190186

191187
encode_video(
192188
video[0],

src/diffusers/pipelines/ltx2/export_utils.py

Lines changed: 68 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,20 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16+
from collections.abc import Iterator
1617
from fractions import Fraction
17-
from typing import Optional
18+
from itertools import chain
19+
from typing import List, Optional, Union
1820

21+
import numpy as np
22+
import PIL.Image
1923
import torch
24+
from tqdm import tqdm
2025

21-
from ...utils import is_av_available
26+
from ...utils import get_logger, is_av_available
27+
28+
29+
logger = get_logger(__name__) # pylint: disable=invalid-name
2230

2331

2432
_CAN_USE_AV = is_av_available()
@@ -101,11 +109,59 @@ def _write_audio(
101109

102110

103111
def encode_video(
104-
video: torch.Tensor, fps: int, audio: Optional[torch.Tensor], audio_sample_rate: Optional[int], output_path: str
112+
video: Union[List[PIL.Image.Image], np.ndarray, torch.Tensor, Iterator[torch.Tensor]],
113+
fps: int,
114+
audio: Optional[torch.Tensor],
115+
audio_sample_rate: Optional[int],
116+
output_path: str,
117+
video_chunks_number: int = 1,
105118
) -> None:
106-
video_np = video.cpu().numpy()
107-
108-
_, height, width, _ = video_np.shape
119+
"""
120+
Encodes a video with audio using the PyAV library. Based on code from the original LTX-2 repo:
121+
https://github.com/Lightricks/LTX-2/blob/4f410820b198e05074a1e92de793e3b59e9ab5a0/packages/ltx-pipelines/src/ltx_pipelines/utils/media_io.py#L182
122+
123+
Args:
124+
video (`List[PIL.Image.Image]` or `np.ndarray` or `torch.Tensor`):
125+
A video tensor of shape [frames, height, width, channels] with integer pixel values in [0, 255]. If the
126+
input is a `np.ndarray`, it is expected to be a float array with values in [0, 1] (which is what pipelines
127+
usually return with `output_type="np"`).
128+
        fps (`int`):
129+
The frames per second (FPS) of the encoded video.
130+
audio (`torch.Tensor`, *optional*):
131+
An audio waveform of shape [audio_channels, samples].
132+
        audio_sample_rate (`int`, *optional*):
133+
The sampling rate of the audio waveform. For LTX 2, this is typically 24000 (24 kHz).
134+
output_path (`str`):
135+
The path to save the encoded video to.
136+
video_chunks_number (`int`, *optional*, defaults to `1`):
137+
The number of chunks to split the video into for encoding. Each chunk will be encoded separately. The
138+
number of chunks to use often depends on the tiling config for the video VAE.
139+
"""
140+
if isinstance(video, list) and isinstance(video[0], PIL.Image.Image):
141+
# Pipeline output_type="pil"; assumes each image is in "RGB" mode
142+
video_frames = [np.array(frame) for frame in video]
143+
video = np.stack(video_frames, axis=0)
144+
video = torch.from_numpy(video)
145+
elif isinstance(video, np.ndarray):
146+
# Pipeline output_type="np"
147+
is_denormalized = np.logical_and(np.zeros_like(video) <= video, video <= np.ones_like(video))
148+
if np.all(is_denormalized):
149+
video = (video * 255).round().astype("uint8")
150+
else:
151+
logger.warning(
152+
"Supplied `numpy.ndarray` does not have values in [0, 1]. The values will be assumed to be pixel "
153+
"values in [0, ..., 255] and will be used as is."
154+
)
155+
video = torch.from_numpy(video)
156+
157+
if isinstance(video, torch.Tensor):
158+
# Split into video_chunks_number along the frame dimension
159+
video = torch.tensor_split(video, video_chunks_number, dim=0)
160+
video = iter(video)
161+
162+
first_chunk = next(video)
163+
164+
_, height, width, _ = first_chunk.shape
109165

110166
container = av.open(output_path, mode="w")
111167
stream = container.add_stream("libx264", rate=int(fps))
@@ -119,10 +175,12 @@ def encode_video(
119175

120176
audio_stream = _prepare_audio_stream(container, audio_sample_rate)
121177

122-
for frame_array in video_np:
123-
frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24")
124-
for packet in stream.encode(frame):
125-
container.mux(packet)
178+
for video_chunk in tqdm(chain([first_chunk], video), total=video_chunks_number, desc="Encoding video chunks"):
179+
video_chunk_cpu = video_chunk.to("cpu").numpy()
180+
for frame_array in video_chunk_cpu:
181+
frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24")
182+
for packet in stream.encode(frame):
183+
container.mux(packet)
126184

127185
# Flush encoder
128186
for packet in stream.encode():

src/diffusers/pipelines/ltx2/pipeline_ltx2.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,6 @@
6969
... output_type="np",
7070
... return_dict=False,
7171
... )
72-
>>> video = (video * 255).round().astype("uint8")
73-
>>> video = torch.from_numpy(video)
7472
7573
>>> encode_video(
7674
... video[0],

src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,6 @@
7575
... output_type="np",
7676
... return_dict=False,
7777
... )
78-
>>> video = (video * 255).round().astype("uint8")
79-
>>> video = torch.from_numpy(video)
8078
8179
>>> encode_video(
8280
... video[0],

src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,6 @@
7676
... output_type="np",
7777
... return_dict=False,
7878
... )[0]
79-
>>> video = (video * 255).round().astype("uint8")
80-
>>> video = torch.from_numpy(video)
8179
8280
>>> encode_video(
8381
... video[0],

0 commit comments

Comments (0)