 # ==========================================
 
 class PytorchLTX2RotaryPosEmbed(torch.nn.Module):
-    """
-    Reference PyTorch implementation for LTX-2 RoPE Frequency Generation.
-    Splits dim across axes, computes freqs, concatenates, and interleaves.
-    """
     def __init__(self, dim: int, theta: float = 10000.0):
         super().__init__()
         self.dim = dim
         self.theta = theta
 
     def forward(self, ids):
-        # ids: [B, S, Num_Axes]
         num_axes = ids.shape[-1]
         dim_per_axis = self.dim // num_axes
 
-        # Standard RoPE frequencies: theta^(-2i/d)
         freq_indices = torch.arange(0, dim_per_axis, 2, dtype=torch.float32)
         inv_freq = 1.0 / (self.theta ** (freq_indices / dim_per_axis))
 
         freqs_list = []
         for i in range(num_axes):
-            axis_pos = ids[..., i]  # [B, S]
-            # Outer product: [B, S, 1] * [1, 1, D/2] -> [B, S, D/2]
+            axis_pos = ids[..., i]
             freqs = torch.einsum('bs,d->bsd', axis_pos, inv_freq)
             freqs_list.append(freqs)
 
@@ -67,19 +60,24 @@ def forward(self, ids):
         cos = torch.repeat_interleave(cos, 2, dim=-1)
         sin = torch.repeat_interleave(sin, 2, dim=-1)
 
-        # Add head dim for broadcasting: [B, S, 1, D]
-        return cos.unsqueeze(2), sin.unsqueeze(2)
+        # CORRECT: Return [B, S, InnerDim] to match JAX/LTX-2 global RoPE
+        return cos, sin
 
 
 def apply_rotary_emb_pt(x, cos, sin):
-    """Standard PyTorch Interleaved RoPE application."""
-    # x: [B, H, S, D] -> [B, H, S, D//2, 2]
-    b, h, s, d = x.shape
-    x_reshaped = x.view(b, h, s, d // 2, 2)
+    """
+    Standard PyTorch Interleaved RoPE application.
+    Dimension-agnostic: Works for [B, S, D] or [B, H, S, D].
+    """
+    # 1. Reshape last dim to pairs: [..., D] -> [..., D//2, 2]
+    shape = x.shape
+    x_reshaped = x.view(*shape[:-1], -1, 2)
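+    # Adjacent elements (2i, 2i+1) of the last dim form each rotation pair (interleaved layout)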
+
+    # 2. Rotate: [-x2, x1]
     x1, x2 = x_reshaped.unbind(-1)
-    x_rotated = torch.stack((-x2, x1), dim=-1).view(b, h, s, d)
+    x_rotated = torch.stack((-x2, x1), dim=-1).view(*shape)
 
-    # Cast to float32 for rotation parity with JAX
+    # 3. Apply Frequencies (Float32 for parity)
     orig_dtype = x.dtype
     x_f32 = x.to(torch.float32)
     rot_f32 = x_rotated.to(torch.float32)
@@ -91,7 +89,6 @@ def apply_rotary_emb_pt(x, cos, sin):
 
 
 class PytorchLTX2Attention(torch.nn.Module):
-    """Reference LTX-2 Attention."""
     def __init__(self, query_dim, context_dim, heads, dim_head):
         super().__init__()
         inner_dim = dim_head * heads
@@ -100,11 +97,9 @@ def __init__(self, query_dim, context_dim, heads, dim_head):
 
         self.q_norm = torch.nn.RMSNorm(inner_dim, eps=1e-6)
         self.k_norm = torch.nn.RMSNorm(inner_dim, eps=1e-6)
-
         self.to_q = torch.nn.Linear(query_dim, inner_dim, bias=True)
         self.to_k = torch.nn.Linear(context_dim, inner_dim, bias=True)
         self.to_v = torch.nn.Linear(context_dim, inner_dim, bias=True)
-
         self.to_out = torch.nn.Sequential(
             torch.nn.Linear(inner_dim, query_dim, bias=True),
             torch.nn.Identity()
@@ -116,44 +111,41 @@ def forward(self, x, context=None, q_rope=None, k_rope=None, mask=None):
         k = self.to_k(ctx)
         v = self.to_v(ctx)
 
-        q_norm = self.q_norm(q)
-        k_norm = self.k_norm(k)
+        q = self.q_norm(q)
+        k = self.k_norm(k)
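+        # QK-Norm runs on the full inner dim, before RoPE and before splitting heads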
 
-        b, s_q, _ = q.shape
-        _, s_kv, _ = k.shape
-
-        # Reshape to [B, H, S, D]
-        q_h = q_norm.view(b, s_q, self.heads, self.dim_head).transpose(1, 2)
-        k_h = k_norm.view(b, s_kv, self.heads, self.dim_head).transpose(1, 2)
-        v_h = v.view(b, s_kv, self.heads, self.dim_head).transpose(1, 2)
-
+        # CORRECT: Apply RoPE globally BEFORE splitting heads
         if q_rope is not None:
             q_cos, q_sin = q_rope
-            q_h = apply_rotary_emb_pt(q_h, q_cos, q_sin)
+            q = apply_rotary_emb_pt(q, q_cos, q_sin)
 
         if k_rope is not None:
             k_cos, k_sin = k_rope
-            k_h = apply_rotary_emb_pt(k_h, k_cos, k_sin)
+            k = apply_rotary_emb_pt(k, k_cos, k_sin)
+
+        # Split Heads for Attention
+        b, s_q, _ = q.shape
+        _, s_kv, _ = k.shape
+        q_h = q.view(b, s_q, self.heads, self.dim_head).transpose(1, 2)
+        k_h = k.view(b, s_kv, self.heads, self.dim_head).transpose(1, 2)
+        v_h = v.view(b, s_kv, self.heads, self.dim_head).transpose(1, 2)
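+        # q_h / k_h / v_h: [B, heads, S, dim_head]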
 
-        # PyTorch Attention expects mask in [B, H, S, S] or additive
         out = torch.nn.functional.scaled_dot_product_attention(
             q_h, k_h, v_h, attn_mask=mask, dropout_p=0.0
         )
-
         out = out.transpose(1, 2).reshape(b, s_q, -1)
-
-        return self.to_out(out), (q, k, v, q_norm, k_norm, out)
+        return self.to_out(out), (q, k, v, q, k, out)  # Returning normed q/k as placeholder
 
 # ==========================================
-# 2. JAX Imports
+# 2. JAX Imports & Test Suite
 # ==========================================
 from maxdiffusion.models.ltx2.attention_ltx2 import LTX2Attention, LTX2RotaryPosEmbed
 
 class LTX2AttentionTest(unittest.TestCase):
 
     def setUp(self):
-        # Common Parameters
-        self.B, self.S, self.D = 1, 16, 64
+        # S=128 is preferred for TPU Flash Attention block sizes
+        self.B, self.S, self.D = 1, 128, 64
         self.heads = 4
         self.dim_head = 16
         self.context_dim = 64
@@ -163,8 +155,6 @@ def setUp(self):
         self.np_x = np.random.randn(self.B, self.S, self.D).astype(np.float32)
 
     def _init_and_sync_models(self, dtype=jnp.bfloat16):
-        """Initializes PyTorch (CPU) and JAX (TPU) and syncs weights."""
-
         pt_dtype = torch.float32 if dtype == jnp.float32 else torch.bfloat16
         pt_model = PytorchLTX2Attention(self.D, self.context_dim, self.heads, self.dim_head)
         pt_model.to(device="cpu", dtype=pt_dtype)
@@ -197,57 +187,41 @@ def copy_norm(jax_layer, pt_layer):
 
         return pt_model, jax_model
 
-    # ------------------------------------------
-    # 1. Output Shape Tests
-    # ------------------------------------------
     def test_shapes(self):
-        """Verifies JAX model handles Video (3D) and Audio (1D) shapes."""
         model = LTX2Attention(64, 4, 16, 64, rngs=self.rng, attention_kernel="dot_product")
 
-        # Video: [B, S, D]
         x_vid = jnp.zeros((1, 128, 64))
-        out_vid = model(x_vid)
+        out_vid = model(x_vid, deterministic=True)
         self.assertEqual(out_vid.shape, (1, 128, 64))
 
-        # Audio Cross-Attn: [B, S_vid, D] -> [B, S_aud, D]
         x_aud = jnp.zeros((1, 32, 64))
-        out_cross = model(x_vid, encoder_hidden_states=x_aud)
+        out_cross = model(x_vid, encoder_hidden_states=x_aud, deterministic=True)
         self.assertEqual(out_cross.shape, (1, 128, 64))
         print("\n[PASS] Shape Tests Passed.")
 
-    # ------------------------------------------
-    # 2. RoPE Frequency Parity
-    # ------------------------------------------
     def test_rope_frequency_parity(self):
-        """Verifies JAX RoPE Frequencies match PyTorch."""
         dim = 60
         rope_pt = PytorchLTX2RotaryPosEmbed(dim=dim)
         rope_jax = LTX2RotaryPosEmbed(dim=dim)
 
         np_ids = np.random.randint(0, 100, (2, 16, 3)).astype(np.float32)
-
         pt_cos, pt_sin = rope_pt(torch.from_numpy(np_ids))
         jax_cos, jax_sin = rope_jax(jnp.array(np_ids))
 
-        # 1e-5 tolerance for freq generation math
         np.testing.assert_allclose(pt_cos.numpy(), np.array(jax_cos), atol=1e-5)
         np.testing.assert_allclose(pt_sin.numpy(), np.array(jax_sin), atol=1e-5)
         print("[PASS] RoPE Frequency Parity Verified.")
 
-    # ------------------------------------------
-    # 3. Strict Parity Test (Full Model, BF16)
-    # ------------------------------------------
     def test_parity_bf16_strict(self):
-        """Checks if JAX matches PyTorch in BF16."""
         pt_model, jax_model = self._init_and_sync_models(dtype=jnp.bfloat16)
 
         pt_in = torch.from_numpy(self.np_x).to(device="cpu", dtype=torch.bfloat16)
         jax_in = jnp.array(self.np_x).astype(jnp.bfloat16)
 
         with torch.no_grad():
             pt_out, _ = pt_model(pt_in)
-
-        jax_out = jax_model(jax_in)
+
+        jax_out = jax_model(jax_in, deterministic=True)
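+        # deterministic=True disables any stochastic layers (e.g. dropout) for the parity check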
 
         pt_res = pt_out.float().numpy()
         jax_res = np.array(jax_out, dtype=np.float32)
@@ -258,52 +232,35 @@ def test_parity_bf16_strict(self):
         )
         print("\n[PASS] BF16 Strict Parity Test passed.")
 
-    # ------------------------------------------
-    # 4. Layer-wise Diagnostics
-    # ------------------------------------------
     def test_layer_wise_stats(self):
-        """Prints diagnostic stats for every layer (Bfloat16)."""
         pt_model, jax_model = self._init_and_sync_models(dtype=jnp.bfloat16)
 
         pt_in = torch.from_numpy(self.np_x).to(device="cpu", dtype=torch.bfloat16)
         jax_in = jnp.array(self.np_x).astype(jnp.bfloat16)
 
-        # 1. Run PyTorch Step-by-Step (Get intermediates)
         with torch.no_grad():
             pt_out, (pt_q, pt_k, pt_v, pt_qn, pt_kn, pt_attn) = pt_model(pt_in)
 
-        # 2. Run JAX Step-by-Step
         jax_q = jax_model.to_q(jax_in)
         jax_k = jax_model.to_k(jax_in)
         jax_v = jax_model.to_v(jax_in)
-
         jax_qn = jax_model.norm_q(jax_q)
         jax_kn = jax_model.norm_k(jax_k)
 
-        # Pass 3D tensors [B, S, Inner_Dim] directly to attention op
-        # NNXAttentionOp handles the internal logic for the kernel
         jax_attn = jax_model.attention_op.apply_attention(jax_qn, jax_kn, jax_v)
         jax_out = jax_model.to_out(jax_attn)
 
-        # 3. Build & Print Comparison Table
         stats = []
         def add_stat(name, pt_t, jax_t):
-            # Ensure pt_t is a tensor before calling .float().numpy()
             if isinstance(pt_t, torch.Tensor):
                 pt_val = pt_t.float().numpy()
             else:
                 pt_val = pt_t
-
             jax_val = np.array(jax_t, dtype=np.float32)
-
             stats.append({
                 "Layer": name,
-                "PT Max": f"{pt_val.max():.4f}",
-                "JAX Max": f"{jax_val.max():.4f}",
                 "PT Mean": f"{pt_val.mean():.4f}",
                 "JAX Mean": f"{jax_val.mean():.4f}",
-                "PT Min": f"{pt_val.min():.4f}",
-                "JAX Min": f"{jax_val.min():.4f}",
                 "Diff (L1)": f"{np.abs(pt_val - jax_val).mean():.6f}"
             })
 
@@ -319,39 +276,28 @@ def add_stat(name, pt_t, jax_t):
         print("\n[DIAGNOSTIC] Layer-wise Stats:")
         print(df.to_string(index=False))
 
-    # ------------------------------------------
-    # 5. Cross-Attention + RoPE Integration
-    # ------------------------------------------
     def test_cross_attn_rope_integration(self):
-        """Verifies Video->Audio cross-attention with RoPE (Float32)."""
         S_Q, S_KV = 16, 20
         pt_model, jax_model = self._init_and_sync_models(dtype=jnp.float32)
 
         np_x = np.random.randn(self.B, S_Q, self.D).astype(np.float32)
         np_ctx = np.random.randn(self.B, S_KV, self.D).astype(np.float32)
 
-        rope_gen_pt = PytorchLTX2RotaryPosEmbed(dim=64)
+        rope_gen_pt = PytorchLTX2RotaryPosEmbed(dim=64)  # Gen [B, S, InnerDim]
 
         ids_q = torch.randint(0, 100, (self.B, S_Q, 1))
         ids_k = torch.randint(0, 100, (self.B, S_KV, 1))
 
         q_cos_pt, q_sin_pt = rope_gen_pt(ids_q.float())
         k_cos_pt, k_sin_pt = rope_gen_pt(ids_k.float())
 
-        def prep_pt(c, s):
-            c = c.view(self.B, -1, self.heads, self.dim_head).transpose(1, 2)
-            s = s.view(self.B, -1, self.heads, self.dim_head).transpose(1, 2)
-            return c, s
-
-        pt_q_rope = prep_pt(q_cos_pt, q_sin_pt)
-        pt_k_rope = prep_pt(k_cos_pt, k_sin_pt)
-
+        # No reshape needed! Passed directly as [B, S, InnerDim]
         with torch.no_grad():
             pt_out, _ = pt_model(
                 torch.from_numpy(np_x),
                 context=torch.from_numpy(np_ctx),
-                q_rope=pt_q_rope,
-                k_rope=pt_k_rope
+                q_rope=(q_cos_pt, q_sin_pt),
+                k_rope=(k_cos_pt, k_sin_pt)
             )
 
         jax_q_rope = (jnp.array(q_cos_pt.numpy()), jnp.array(q_sin_pt.numpy()))
@@ -361,64 +307,48 @@ def prep_pt(c, s):
             jnp.array(np_x),
             encoder_hidden_states=jnp.array(np_ctx),
             rotary_emb=jax_q_rope,
-            k_rotary_emb=jax_k_rope
+            k_rotary_emb=jax_k_rope,
+            deterministic=True
         )
 
         diff = np.abs(pt_out.numpy() - np.array(jax_out)).max()
         print(f"\n[Cross-Attn + RoPE] Max Diff: {diff:.6f}")
         np.testing.assert_allclose(pt_out.numpy(), np.array(jax_out), atol=1e-5)
         print("[PASS] Cross-Attention with RoPE Parity Verified.")
 
-    # ------------------------------------------
-    # 6. Attention Mask Parity
-    # ------------------------------------------
     def test_attention_mask_parity(self):
-        """
-        Verifies attention masks (padding) work identically using FLASH kernel.
-        Flash kernel in attention_flax expects a multiplicative mask [B, S],
-        while PyTorch SDPA expects an additive mask broadcastable to [B,H,S,S].
-        """
+        S_flash = 512
+        np_x = np.random.randn(self.B, S_flash, self.D).astype(np.float32)
         pt_model, jax_model = self._init_and_sync_models(dtype=jnp.float32)
 
-        # Switch JAX model to use flash attention for this test
+        devices = jax.devices()
+        mesh = Mesh(np.array(devices).reshape(1, -1), ('data', 'context'))
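+        # (1, num_devices) mesh with ('data', 'context') axes; used by the flash attention call below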
+
         jax_model.attention_op.attention_kernel = "flash"
-        jax_model.attention_op.mesh = Mesh(np.array(jax.devices()).reshape(1, -1), ('data', 'context'))
+        jax_model.attention_op.mesh = mesh
         jax_model.attention_op.flash_block_sizes = splash_attention_kernel.BlockSizes(
-            block_q=512,
-            block_kv_compute=128,
-            block_kv=128,
-            block_q_dkv=512,
-            block_kv_dkv=128,
-            block_kv_dkv_compute=128,
-            block_q_dq=512,
-            block_kv_dq=128,
+            block_q=128, block_kv_compute=128, block_kv=128,
+            block_q_dkv=128, block_kv_dkv=128, block_kv_dkv_compute=128,
+            block_q_dq=128, block_kv_dq=128,
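+            # Uniform 128 blocks; S_flash = 512 tiles evenly into every block size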
         )
 
-        np_x = np.random.randn(self.B, self.S, self.D).astype(np.float32)
-
-        # Create mask pattern: 1 = keep, 0 = mask out
-        # Shape: [B, S]
-        mask_pattern_np = np.random.randint(0, 2, (self.B, self.S)).astype(np.float32)
-
-        # PyTorch needs ADDITIVE mask: 0 for keep, -inf for mask out
-        # Broadcastable to [B, H, S_q, S_kv]: [1, 1, 1, 16] is ok for B=1,H=4,S=16
+        mask_pattern_np = np.random.randint(0, 2, (self.B, S_flash)).astype(np.float32)
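+        # 1 = keep, 0 = mask out; PyTorch SDPA takes the additive form, the JAX flash kernel the multiplicative [B, S] pattern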
         pt_mask_additive = torch.from_numpy((1.0 - mask_pattern_np) * -1e9)[:, None, None, :]
-
-        # JAX Flash attention needs MULTIPLICATIVE mask: [1, 16]
         jax_mask_multiplicative = jnp.array(mask_pattern_np)
 
-        # PyTorch
         with torch.no_grad():
             pt_out, _ = pt_model(torch.from_numpy(np_x), mask=pt_mask_additive)
 
-        # JAX
-        with jax_model.attention_op.mesh:
-            jax_out = jax_model(
+        with mesh:
+            jax_out = jax_model(
                 jnp.array(np_x),
-                attention_mask=jax_mask_multiplicative
+                attention_mask=jax_mask_multiplicative,
+                deterministic=True
             )
 
-        np.testing.assert_allclose(pt_out.numpy(), np.array(jax_out), atol=1e-5)
+        diff = np.abs(pt_out.numpy() - np.array(jax_out)).max()
+        print(f"\n[Mask Parity] Max Diff (Flash): {diff:.6f}")
+        np.testing.assert_allclose(pt_out.numpy(), np.array(jax_out), atol=1e-4)
         print("[PASS] Attention Mask Parity Verified.")
 
 if __name__ == "__main__":