 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from collections.abc import Iterator
 from fractions import Fraction
-from typing import Optional
+from itertools import chain
+from typing import List, Optional, Union
 
+import numpy as np
+import PIL.Image
 import torch
+from tqdm import tqdm
 
-from ...utils import is_av_available
+from ...utils import get_logger, is_av_available
+
+
+logger = get_logger(__name__)  # pylint: disable=invalid-name
 
 
 _CAN_USE_AV = is_av_available()
@@ -101,11 +109,59 @@ def _write_audio(
 
 
 def encode_video(
-    video: torch.Tensor, fps: int, audio: Optional[torch.Tensor], audio_sample_rate: Optional[int], output_path: str
+    video: Union[List[PIL.Image.Image], np.ndarray, torch.Tensor, Iterator[torch.Tensor]],
+    fps: int,
+    audio: Optional[torch.Tensor],
+    audio_sample_rate: Optional[int],
+    output_path: str,
+    video_chunks_number: int = 1,
 ) -> None:
-    video_np = video.cpu().numpy()
-
-    _, height, width, _ = video_np.shape
+    """
+    Encodes a video with audio using the PyAV library. Based on code from the original LTX-2 repo:
+    https://github.com/Lightricks/LTX-2/blob/4f410820b198e05074a1e92de793e3b59e9ab5a0/packages/ltx-pipelines/src/ltx_pipelines/utils/media_io.py#L182
+
+    Args:
+        video (`List[PIL.Image.Image]` or `np.ndarray` or `torch.Tensor` or `Iterator[torch.Tensor]`):
+            The video to encode, with shape [frames, height, width, channels] and integer pixel values in [0, 255].
+            If the input is a `np.ndarray`, it is expected to be a float array with values in [0, 1] (which is what
+            pipelines usually return with `output_type="np"`).
+        fps (`int`):
+            The frames per second (FPS) of the encoded video.
+        audio (`torch.Tensor`, *optional*):
+            An audio waveform of shape [audio_channels, samples].
+        audio_sample_rate (`int`, *optional*):
+            The sampling rate of the audio waveform. For LTX 2, this is typically 24000 (24 kHz).
+        output_path (`str`):
+            The path to save the encoded video to.
+        video_chunks_number (`int`, *optional*, defaults to `1`):
+            The number of chunks to split the video into for encoding. Each chunk will be encoded separately. The
+            number of chunks to use often depends on the tiling config of the video VAE.
+    """
+    if isinstance(video, list) and isinstance(video[0], PIL.Image.Image):
+        # Pipeline output_type="pil"; assumes each image is in "RGB" mode
+        video_frames = [np.array(frame) for frame in video]
+        video = np.stack(video_frames, axis=0)
+        video = torch.from_numpy(video)
+    elif isinstance(video, np.ndarray):
+        # Pipeline output_type="np"
+        is_denormalized = np.logical_and(np.zeros_like(video) <= video, video <= np.ones_like(video))
+        if np.all(is_denormalized):
+            video = (video * 255).round().astype("uint8")
+        else:
+            logger.warning(
+                "Supplied `numpy.ndarray` does not have values in [0, 1]. The values will be assumed to be pixel "
+                "values in [0, ..., 255] and will be used as is."
+            )
+        video = torch.from_numpy(video)
+
+    if isinstance(video, torch.Tensor):
+        # Split into video_chunks_number chunks along the frame dimension
+        video = torch.tensor_split(video, video_chunks_number, dim=0)
+        video = iter(video)
+
+    first_chunk = next(video)
+
+    _, height, width, _ = first_chunk.shape
 
     container = av.open(output_path, mode="w")
     stream = container.add_stream("libx264", rate=int(fps))
@@ -119,10 +175,12 @@ def encode_video(
 
     audio_stream = _prepare_audio_stream(container, audio_sample_rate)
 
-    for frame_array in video_np:
-        frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24")
-        for packet in stream.encode(frame):
-            container.mux(packet)
+    for video_chunk in tqdm(chain([first_chunk], video), total=video_chunks_number, desc="Encoding video chunks"):
+        video_chunk_cpu = video_chunk.to("cpu").numpy()
+        for frame_array in video_chunk_cpu:
+            frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24")
+            for packet in stream.encode(frame):
+                container.mux(packet)
 
     # Flush encoder
     for packet in stream.encode():
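A minimal usage sketch of the new signature, assuming `encode_video` from this module is importable (the import path is omitted here); the pipeline-style dummy data, shapes, and chunk count are illustrative, and only the parameter names and expected input formats come from the diff above:

```python
# Usage sketch (assumptions: encode_video is imported from this module; the dummy
# video/audio data and the chunk count below are made up for illustration).
import numpy as np
import torch

# Fake pipeline output with output_type="np": 49 RGB frames at 512x768, floats in [0, 1]
video = np.random.rand(49, 512, 768, 3).astype("float32")

# Mono waveform at 24 kHz, roughly matching the clip duration at 24 fps
audio = torch.zeros(1, int(24000 * 49 / 24))

encode_video(
    video,
    fps=24,
    audio=audio,
    audio_sample_rate=24000,
    output_path="output.mp4",
    video_chunks_number=4,  # encode in 4 chunks, e.g. to mirror the video VAE tiling config
)
```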