Skip to content

Commit 94204be

Browse files
authored
Add normalization after vae encoding to data_processing_pipeline_for_wan. (#321)
1 parent 1e1058a commit 94204be

File tree

2 files changed

+116
-4
lines changed

2 files changed

+116
-4
lines changed

src/maxdiffusion/data_preprocessing/wan_txt2vid_data_preprocessing.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@
2929
import jax.numpy as jnp
3030
from jax.sharding import Mesh
3131
from maxdiffusion import pyconfig, max_utils
32-
from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline
3332
from maxdiffusion.video_processor import VideoProcessor
33+
from maxdiffusion.checkpointing.wan_checkpointer_2_1 import WanCheckpointer2_1
3434

3535
import tensorflow as tf
3636

@@ -80,7 +80,13 @@ def text_encode(pipeline, prompt: Union[str, List[str]]):
8080
def vae_encode(video, rng, vae, vae_cache):
  """Encode a pixel-space video into normalized VAE latents.

  Args:
    video: Pixel-space video batch accepted by ``vae.encode``.
    rng: PRNG key used to sample from the encoded latent distribution.
    vae: VAE exposing ``latents_mean``, ``latents_std`` and ``z_dim`` that
      define the per-channel normalization statistics.
    vae_cache: Feature cache forwarded to ``vae.encode``.

  Returns:
    Latents of shape (B, z_dim, T, H, W), normalized per channel.
  """
  encoded = vae.encode(video, feat_cache=vae_cache)
  sampled = encoded.latent_dist.sample(rng)
  # Move channels from the trailing axis to axis 1:
  # (B, T, H, W, C) -> (B, C, T, H, W).
  sampled = jnp.transpose(sampled, (0, 4, 1, 2, 3))
  stat_shape = (1, vae.z_dim, 1, 1, 1)
  mean = jnp.array(vae.latents_mean).reshape(stat_shape)
  std = jnp.array(vae.latents_std).reshape(stat_shape)
  # Apply normalization: (x - mean) / std
  return (sampled - mean) / std
8490

8591

8692
def generate_dataset(config, pipeline):
@@ -121,7 +127,6 @@ def generate_dataset(config, pipeline):
121127
video = jnp.array(np.squeeze(np.array(videos), axis=1), dtype=config.weights_dtype)
122128
with mesh:
123129
latents = p_vae_encode(video=video, rng=new_rng)
124-
latents = jnp.transpose(latents, (0, 4, 1, 2, 3))
125130
encoder_hidden_states = text_encode(pipeline, text)
126131
for latent, encoder_hidden_state in zip(latents, encoder_hidden_states):
127132
writer.write(create_example(latent, encoder_hidden_state))
@@ -138,8 +143,10 @@ def generate_dataset(config, pipeline):
138143

139144

140145
def run(config):
  """Load the WAN 2.1 pipeline from a checkpoint and generate the dataset.

  Args:
    config: Parsed maxdiffusion config driving checkpoint loading and
      dataset generation.
  """
  loader = WanCheckpointer2_1(config=config)
  pipeline, _, _ = loader.load_checkpoint()
  # The transformer is not needed for preprocessing.
  del pipeline.transformer
  generate_dataset(config, pipeline)
144151

145152

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
"""
2+
Copyright 2025 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
import os
18+
import pytest
19+
import functools
20+
import jax
21+
import jax.numpy as jnp
22+
from flax.linen import partitioning as nn_partitioning
23+
from jax.sharding import Mesh
24+
from .. import pyconfig
25+
from ..max_utils import (
26+
create_device_mesh,
27+
)
28+
import numpy as np
29+
import unittest
30+
from ..data_preprocessing.wan_txt2vid_data_preprocessing import vae_encode
31+
from ..checkpointing.wan_checkpointer_2_1 import WanCheckpointer2_1
32+
from ..utils import load_video
33+
from ..video_processor import VideoProcessor
34+
import flax
35+
36+
# Directory containing this test file; used to locate the config YAML.
THIS_DIR = os.path.dirname(os.path.abspath(__file__))

# presumably the temporal depth of the VAE feature cache — TODO confirm;
# not referenced in the visible code of this module.
CACHE_T = 2

# True when running under GitHub Actions CI; used to skip smoke tests there.
IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true"

# NOTE(review): disables flax's always-shard-variable behavior for these
# tests — confirm this is required for the checkpoint loader used below.
flax.config.update("flax_always_shard_variable", False)
45+
class DataProcessingTest(unittest.TestCase):
  """Tests for the WAN txt2vid data-preprocessing helpers."""

  def setUp(self):
    """Initialize the shared config and device mesh for all tests."""
    DataProcessingTest.dummy_data = {}
    pyconfig.initialize(
        [
            None,
            os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"),
        ],
        unittest=True,
    )
    self.config = pyconfig.config
    devices_array = create_device_mesh(self.config)
    self.mesh = Mesh(devices_array, self.config.mesh_axes)

  @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions")
  def test_wan_vae_encode_normalization(self):
    """Test wan vae encode function normalization"""
    # setUp already ran pyconfig.initialize and built the device mesh;
    # reuse them instead of repeating that initialization here.
    config = self.config
    mesh = self.mesh

    checkpoint_loader = WanCheckpointer2_1(config=config)
    pipeline, _, _ = checkpoint_loader.load_checkpoint()

    vae_scale_factor_spatial = 2 ** len(pipeline.vae.temperal_downsample)
    video_processor = VideoProcessor(vae_scale_factor=vae_scale_factor_spatial)

    video_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/hiker.mp4"
    video = load_video(video_path)
    videos = [video_processor.preprocess_video([video], height=config.height, width=config.width)]
    videos = jnp.array(np.squeeze(np.array(videos), axis=1), dtype=config.weights_dtype)
    p_vae_encode = jax.jit(functools.partial(vae_encode, vae=pipeline.vae, vae_cache=pipeline.vae_cache))

    rng = jax.random.key(config.seed)
    with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
      latents = p_vae_encode(videos, rng=rng)

    # 1. Verify Channel Count (Wan 2.1 requires 16)
    self.assertEqual(latents.shape[1], 16, f"Expected 16 channels, got {latents.shape[1]}")

    # 2. Verify Global Stats.
    # We expect mean near 0 and variance near 1, with a loose 0.2 tolerance
    # since this is just one video.
    global_mean = jnp.mean(latents)
    global_var = jnp.var(latents)

    self.assertLess(abs(global_mean), 0.2, f"Global mean {global_mean} is too far from 0")
    self.assertAlmostEqual(global_var, 1.0, delta=0.2, msg=f"Global variance {global_var} is too far from 1.0")

    # 3. Verify Channel-wise Range.
    # Ensure no channel is completely "dead" or "exploding".
    channel_vars = jnp.var(latents, axis=(0, 2, 3, 4))
    self.assertTrue(jnp.all(channel_vars > 0.1), "One or more channels have near-zero variance")
    self.assertTrue(jnp.all(channel_vars < 5.0), "One or more channels have exploding variance")

0 commit comments

Comments
 (0)