
Commit 266b00a

Test
1 parent a4766f6 commit 266b00a

File tree

1 file changed (+57, -127 lines)


src/maxdiffusion/tests/test_attention_ltx2.py

Lines changed: 57 additions & 127 deletions
@@ -32,28 +32,21 @@
 # ==========================================

 class PytorchLTX2RotaryPosEmbed(torch.nn.Module):
-  """
-  Reference PyTorch implementation for LTX-2 RoPE Frequency Generation.
-  Splits dim across axes, computes freqs, concatenates, and interleaves.
-  """
   def __init__(self, dim: int, theta: float = 10000.0):
     super().__init__()
     self.dim = dim
     self.theta = theta

   def forward(self, ids):
-    # ids: [B, S, Num_Axes]
     num_axes = ids.shape[-1]
     dim_per_axis = self.dim // num_axes

-    # Standard RoPE frequencies: theta^(-2i/d)
     freq_indices = torch.arange(0, dim_per_axis, 2, dtype=torch.float32)
     inv_freq = 1.0 / (self.theta ** (freq_indices / dim_per_axis))

     freqs_list = []
     for i in range(num_axes):
-      axis_pos = ids[..., i]  # [B, S]
-      # Outer product: [B, S, 1] * [1, 1, D/2] -> [B, S, D/2]
+      axis_pos = ids[..., i]
       freqs = torch.einsum('bs,d->bsd', axis_pos, inv_freq)
       freqs_list.append(freqs)

@@ -67,19 +60,24 @@ def forward(self, ids):
     cos = torch.repeat_interleave(cos, 2, dim=-1)
     sin = torch.repeat_interleave(sin, 2, dim=-1)

-    # Add head dim for broadcasting: [B, S, 1, D]
-    return cos.unsqueeze(2), sin.unsqueeze(2)
+    # CORRECT: Return [B, S, InnerDim] to match JAX/LTX-2 global RoPE
+    return cos, sin


 def apply_rotary_emb_pt(x, cos, sin):
-  """Standard PyTorch Interleaved RoPE application."""
-  # x: [B, H, S, D] -> [B, H, S, D//2, 2]
-  b, h, s, d = x.shape
-  x_reshaped = x.view(b, h, s, d // 2, 2)
+  """
+  Standard PyTorch Interleaved RoPE application.
+  Dimension-agnostic: Works for [B, S, D] or [B, H, S, D].
+  """
+  # 1. Reshape last dim to pairs: [..., D] -> [..., D//2, 2]
+  shape = x.shape
+  x_reshaped = x.view(*shape[:-1], -1, 2)
+
+  # 2. Rotate: [-x2, x1]
   x1, x2 = x_reshaped.unbind(-1)
-  x_rotated = torch.stack((-x2, x1), dim=-1).view(b, h, s, d)
+  x_rotated = torch.stack((-x2, x1), dim=-1).view(*shape)

-  # Cast to float32 for rotation parity with JAX
+  # 3. Apply Frequencies (Float32 for parity)
   orig_dtype = x.dtype
   x_f32 = x.to(torch.float32)
   rot_f32 = x_rotated.to(torch.float32)
@@ -91,7 +89,6 @@ def apply_rotary_emb_pt(x, cos, sin):


 class PytorchLTX2Attention(torch.nn.Module):
-  """Reference LTX-2 Attention."""
   def __init__(self, query_dim, context_dim, heads, dim_head):
     super().__init__()
     inner_dim = dim_head * heads
@@ -100,11 +97,9 @@ def __init__(self, query_dim, context_dim, heads, dim_head):

     self.q_norm = torch.nn.RMSNorm(inner_dim, eps=1e-6)
     self.k_norm = torch.nn.RMSNorm(inner_dim, eps=1e-6)
-
     self.to_q = torch.nn.Linear(query_dim, inner_dim, bias=True)
     self.to_k = torch.nn.Linear(context_dim, inner_dim, bias=True)
     self.to_v = torch.nn.Linear(context_dim, inner_dim, bias=True)
-
     self.to_out = torch.nn.Sequential(
       torch.nn.Linear(inner_dim, query_dim, bias=True),
       torch.nn.Identity()
@@ -116,44 +111,41 @@ def forward(self, x, context=None, q_rope=None, k_rope=None, mask=None):
     k = self.to_k(ctx)
     v = self.to_v(ctx)

-    q_norm = self.q_norm(q)
-    k_norm = self.k_norm(k)
+    q = self.q_norm(q)
+    k = self.k_norm(k)

-    b, s_q, _ = q.shape
-    _, s_kv, _ = k.shape
-
-    # Reshape to [B, H, S, D]
-    q_h = q_norm.view(b, s_q, self.heads, self.dim_head).transpose(1, 2)
-    k_h = k_norm.view(b, s_kv, self.heads, self.dim_head).transpose(1, 2)
-    v_h = v.view(b, s_kv, self.heads, self.dim_head).transpose(1, 2)
-
+    # CORRECT: Apply RoPE globally BEFORE splitting heads
     if q_rope is not None:
       q_cos, q_sin = q_rope
-      q_h = apply_rotary_emb_pt(q_h, q_cos, q_sin)
+      q = apply_rotary_emb_pt(q, q_cos, q_sin)

     if k_rope is not None:
       k_cos, k_sin = k_rope
-      k_h = apply_rotary_emb_pt(k_h, k_cos, k_sin)
+      k = apply_rotary_emb_pt(k, k_cos, k_sin)
+
+    # Split Heads for Attention
+    b, s_q, _ = q.shape
+    _, s_kv, _ = k.shape
+    q_h = q.view(b, s_q, self.heads, self.dim_head).transpose(1, 2)
+    k_h = k.view(b, s_kv, self.heads, self.dim_head).transpose(1, 2)
+    v_h = v.view(b, s_kv, self.heads, self.dim_head).transpose(1, 2)

-    # PyTorch Attention expects mask in [B, H, S, S] or additive
     out = torch.nn.functional.scaled_dot_product_attention(
       q_h, k_h, v_h, attn_mask=mask, dropout_p=0.0
     )
-
     out = out.transpose(1, 2).reshape(b, s_q, -1)
-
-    return self.to_out(out), (q, k, v, q_norm, k_norm, out)
+    return self.to_out(out), (q, k, v, q, k, out)  # Returning normed q/k as placeholder

 # ==========================================
-# 2. JAX Imports
+# 2. JAX Imports & Test Suite
 # ==========================================
 from maxdiffusion.models.ltx2.attention_ltx2 import LTX2Attention, LTX2RotaryPosEmbed

 class LTX2AttentionTest(unittest.TestCase):

   def setUp(self):
-    # Common Parameters
-    self.B, self.S, self.D = 1, 16, 64
+    # S=128 is preferred for TPU Flash Attention block sizes
+    self.B, self.S, self.D = 1, 128, 64
     self.heads = 4
     self.dim_head = 16
     self.context_dim = 64
@@ -163,8 +155,6 @@ def setUp(self):
     self.np_x = np.random.randn(self.B, self.S, self.D).astype(np.float32)

   def _init_and_sync_models(self, dtype=jnp.bfloat16):
-    """Initializes PyTorch (CPU) and JAX (TPU) and syncs weights."""
-
     pt_dtype = torch.float32 if dtype == jnp.float32 else torch.bfloat16
     pt_model = PytorchLTX2Attention(self.D, self.context_dim, self.heads, self.dim_head)
     pt_model.to(device="cpu", dtype=pt_dtype)
@@ -197,57 +187,41 @@ def copy_norm(jax_layer, pt_layer):

     return pt_model, jax_model

-  # ------------------------------------------
-  # 1. Output Shape Tests
-  # ------------------------------------------
   def test_shapes(self):
-    """Verifies JAX model handles Video (3D) and Audio (1D) shapes."""
     model = LTX2Attention(64, 4, 16, 64, rngs=self.rng, attention_kernel="dot_product")

-    # Video: [B, S, D]
     x_vid = jnp.zeros((1, 128, 64))
-    out_vid = model(x_vid)
+    out_vid = model(x_vid, deterministic=True)
     self.assertEqual(out_vid.shape, (1, 128, 64))

-    # Audio Cross-Attn: [B, S_vid, D] -> [B, S_aud, D]
     x_aud = jnp.zeros((1, 32, 64))
-    out_cross = model(x_vid, encoder_hidden_states=x_aud)
+    out_cross = model(x_vid, encoder_hidden_states=x_aud, deterministic=True)
     self.assertEqual(out_cross.shape, (1, 128, 64))
     print("\n[PASS] Shape Tests Passed.")

-  # ------------------------------------------
-  # 2. RoPE Frequency Parity
-  # ------------------------------------------
   def test_rope_frequency_parity(self):
-    """Verifies JAX RoPE Frequencies match PyTorch."""
     dim = 60
     rope_pt = PytorchLTX2RotaryPosEmbed(dim=dim)
     rope_jax = LTX2RotaryPosEmbed(dim=dim)

     np_ids = np.random.randint(0, 100, (2, 16, 3)).astype(np.float32)
-
     pt_cos, pt_sin = rope_pt(torch.from_numpy(np_ids))
     jax_cos, jax_sin = rope_jax(jnp.array(np_ids))

-    # 1e-5 tolerance for freq generation math
     np.testing.assert_allclose(pt_cos.numpy(), np.array(jax_cos), atol=1e-5)
     np.testing.assert_allclose(pt_sin.numpy(), np.array(jax_sin), atol=1e-5)
     print("[PASS] RoPE Frequency Parity Verified.")

-  # ------------------------------------------
-  # 3. Strict Parity Test (Full Model, BF16)
-  # ------------------------------------------
   def test_parity_bf16_strict(self):
-    """Checks if JAX matches PyTorch in BF16."""
     pt_model, jax_model = self._init_and_sync_models(dtype=jnp.bfloat16)

     pt_in = torch.from_numpy(self.np_x).to(device="cpu", dtype=torch.bfloat16)
     jax_in = jnp.array(self.np_x).astype(jnp.bfloat16)

     with torch.no_grad():
       pt_out, _ = pt_model(pt_in)
-
-    jax_out = jax_model(jax_in)
+
+    jax_out = jax_model(jax_in, deterministic=True)

     pt_res = pt_out.float().numpy()
     jax_res = np.array(jax_out, dtype=np.float32)
@@ -258,52 +232,35 @@ def test_parity_bf16_strict(self):
     )
     print("\n[PASS] BF16 Strict Parity Test passed.")

-  # ------------------------------------------
-  # 4. Layer-wise Diagnostics
-  # ------------------------------------------
   def test_layer_wise_stats(self):
-    """Prints diagnostic stats for every layer (Bfloat16)."""
     pt_model, jax_model = self._init_and_sync_models(dtype=jnp.bfloat16)

     pt_in = torch.from_numpy(self.np_x).to(device="cpu", dtype=torch.bfloat16)
     jax_in = jnp.array(self.np_x).astype(jnp.bfloat16)

-    # 1. Run PyTorch Step-by-Step (Get intermediates)
     with torch.no_grad():
       pt_out, (pt_q, pt_k, pt_v, pt_qn, pt_kn, pt_attn) = pt_model(pt_in)

-    # 2. Run JAX Step-by-Step
     jax_q = jax_model.to_q(jax_in)
     jax_k = jax_model.to_k(jax_in)
     jax_v = jax_model.to_v(jax_in)
-
     jax_qn = jax_model.norm_q(jax_q)
     jax_kn = jax_model.norm_k(jax_k)

-    # Pass 3D tensors [B, S, Inner_Dim] directly to attention op
-    # NNXAttentionOp handles the internal logic for the kernel
     jax_attn = jax_model.attention_op.apply_attention(jax_qn, jax_kn, jax_v)
     jax_out = jax_model.to_out(jax_attn)

-    # 3. Build & Print Comparison Table
     stats = []
     def add_stat(name, pt_t, jax_t):
-      # Ensure pt_t is a tensor before calling .float().numpy()
       if isinstance(pt_t, torch.Tensor):
         pt_val = pt_t.float().numpy()
       else:
         pt_val = pt_t
-
       jax_val = np.array(jax_t, dtype=np.float32)
-
       stats.append({
         "Layer": name,
-        "PT Max": f"{pt_val.max():.4f}",
-        "JAX Max": f"{jax_val.max():.4f}",
         "PT Mean": f"{pt_val.mean():.4f}",
         "JAX Mean": f"{jax_val.mean():.4f}",
-        "PT Min": f"{pt_val.min():.4f}",
-        "JAX Min": f"{jax_val.min():.4f}",
         "Diff (L1)": f"{np.abs(pt_val - jax_val).mean():.6f}"
       })

@@ -319,39 +276,28 @@ def add_stat(name, pt_t, jax_t):
     print("\n[DIAGNOSTIC] Layer-wise Stats:")
     print(df.to_string(index=False))

-  # ------------------------------------------
-  # 5. Cross-Attention + RoPE Integration
-  # ------------------------------------------
   def test_cross_attn_rope_integration(self):
-    """Verifies Video->Audio cross-attention with RoPE (Float32)."""
     S_Q, S_KV = 16, 20
     pt_model, jax_model = self._init_and_sync_models(dtype=jnp.float32)

     np_x = np.random.randn(self.B, S_Q, self.D).astype(np.float32)
     np_ctx = np.random.randn(self.B, S_KV, self.D).astype(np.float32)

-    rope_gen_pt = PytorchLTX2RotaryPosEmbed(dim=64)
+    rope_gen_pt = PytorchLTX2RotaryPosEmbed(dim=64)  # Gen [B, S, InnerDim]

     ids_q = torch.randint(0, 100, (self.B, S_Q, 1))
     ids_k = torch.randint(0, 100, (self.B, S_KV, 1))

     q_cos_pt, q_sin_pt = rope_gen_pt(ids_q.float())
     k_cos_pt, k_sin_pt = rope_gen_pt(ids_k.float())

-    def prep_pt(c, s):
-      c = c.view(self.B, -1, self.heads, self.dim_head).transpose(1, 2)
-      s = s.view(self.B, -1, self.heads, self.dim_head).transpose(1, 2)
-      return c, s
-
-    pt_q_rope = prep_pt(q_cos_pt, q_sin_pt)
-    pt_k_rope = prep_pt(k_cos_pt, k_sin_pt)
-
+    # No reshape needed! Passed directly as [B, S, InnerDim]
     with torch.no_grad():
       pt_out, _ = pt_model(
         torch.from_numpy(np_x),
         context=torch.from_numpy(np_ctx),
-        q_rope=pt_q_rope,
-        k_rope=pt_k_rope
+        q_rope=(q_cos_pt, q_sin_pt),
+        k_rope=(k_cos_pt, k_sin_pt)
       )

     jax_q_rope = (jnp.array(q_cos_pt.numpy()), jnp.array(q_sin_pt.numpy()))
@@ -361,64 +307,48 @@ def prep_pt(c, s):
       jnp.array(np_x),
       encoder_hidden_states=jnp.array(np_ctx),
       rotary_emb=jax_q_rope,
-      k_rotary_emb=jax_k_rope
+      k_rotary_emb=jax_k_rope,
+      deterministic=True
     )

     diff = np.abs(pt_out.numpy() - np.array(jax_out)).max()
     print(f"\n[Cross-Attn + RoPE] Max Diff: {diff:.6f}")
     np.testing.assert_allclose(pt_out.numpy(), np.array(jax_out), atol=1e-5)
     print("[PASS] Cross-Attention with RoPE Parity Verified.")

-  # ------------------------------------------
-  # 6. Attention Mask Parity
-  # ------------------------------------------
   def test_attention_mask_parity(self):
-    """
-    Verifies attention masks (padding) work identically using FLASH kernel.
-    Flash kernel in attention_flax expects a multiplicative mask [B, S],
-    while PyTorch SDPA expects an additive mask broadcastable to [B,H,S,S].
-    """
+    S_flash = 512
+    np_x = np.random.randn(self.B, S_flash, self.D).astype(np.float32)
     pt_model, jax_model = self._init_and_sync_models(dtype=jnp.float32)

-    # Switch JAX model to use flash attention for this test
+    devices = jax.devices()
+    mesh = Mesh(np.array(devices).reshape(1, -1), ('data', 'context'))
+
     jax_model.attention_op.attention_kernel = "flash"
-    jax_model.attention_op.mesh = Mesh(np.array(jax.devices()).reshape(1,-1), ('data', 'context'))
+    jax_model.attention_op.mesh = mesh
     jax_model.attention_op.flash_block_sizes = splash_attention_kernel.BlockSizes(
-      block_q=512,
-      block_kv_compute=128,
-      block_kv=128,
-      block_q_dkv=512,
-      block_kv_dkv=128,
-      block_kv_dkv_compute=128,
-      block_q_dq=512,
-      block_kv_dq=128,
+      block_q=128, block_kv_compute=128, block_kv=128,
+      block_q_dkv=128, block_kv_dkv=128, block_kv_dkv_compute=128,
+      block_q_dq=128, block_kv_dq=128,
     )

-    np_x = np.random.randn(self.B, self.S, self.D).astype(np.float32)
-
-    # Create mask pattern: 1 = keep, 0 = mask out
-    # Shape: [B, S]
-    mask_pattern_np = np.random.randint(0, 2, (self.B, self.S)).astype(np.float32)
-
-    # PyTorch needs ADDITIVE mask: 0 for keep, -inf for mask out
-    # Broadcastable to [B, H, S_q, S_kv]: [1, 1, 1, 16] is ok for B=1,H=4,S=16
+    mask_pattern_np = np.random.randint(0, 2, (self.B, S_flash)).astype(np.float32)
     pt_mask_additive = torch.from_numpy((1.0 - mask_pattern_np) * -1e9)[:, None, None, :]
-
-    # JAX Flash attention needs MULTIPLICATIVE mask: [1, 16]
     jax_mask_multiplicative = jnp.array(mask_pattern_np)

-    # PyTorch
     with torch.no_grad():
       pt_out, _ = pt_model(torch.from_numpy(np_x), mask=pt_mask_additive)

-    # JAX
-    with jax_model.attention_op.mesh:
-      jax_out = jax_model(
+    with mesh:
+      jax_out = jax_model(
         jnp.array(np_x),
-        attention_mask=jax_mask_multiplicative
+        attention_mask=jax_mask_multiplicative,
+        deterministic=True
       )

-    np.testing.assert_allclose(pt_out.numpy(), np.array(jax_out), atol=1e-5)
+    diff = np.abs(pt_out.numpy() - np.array(jax_out)).max()
+    print(f"\n[Mask Parity] Max Diff (Flash): {diff:.6f}")
+    np.testing.assert_allclose(pt_out.numpy(), np.array(jax_out), atol=1e-4)
     print("[PASS] Attention Mask Parity Verified.")

 if __name__ == "__main__":
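The docstring and comments removed in this last hunk describe the two mask conventions under test: PyTorch SDPA takes an additive bias broadcastable to [B, H, S_q, S_kv], while the flash path takes a multiplicative keep-mask of shape [B, S]. A minimal standalone sketch of deriving both forms from one keep/drop pattern, with hypothetical sizes (the test itself uses S_flash=512):

import numpy as np
import torch

B, S = 1, 8
keep = np.array([[1, 1, 0, 1, 0, 1, 1, 0]], dtype=np.float32)   # 1 = keep, 0 = mask out

# Additive bias for torch SDPA: 0 where kept, a large negative where masked,
# shaped [B, 1, 1, S_kv] so it broadcasts over heads and query positions.
additive = torch.from_numpy((1.0 - keep) * -1e9)[:, None, None, :]

# Multiplicative mask for the flash path: the [B, S] keep pattern itself.
multiplicative = keep

# Sanity check: softmax over uniformly zero logits plus the bias puts ~0 weight
# on masked key positions and spreads the rest over the kept ones.
probs = torch.softmax(torch.zeros(B, 1, S, S) + additive, dim=-1)
print(probs[0, 0, 0].numpy().round(3), multiplicative[0])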
