diff --git a/src/MaxText/layers/qwen3.py b/src/MaxText/layers/qwen3.py index a856849f07..4ca2683064 100644 --- a/src/MaxText/layers/qwen3.py +++ b/src/MaxText/layers/qwen3.py @@ -21,6 +21,7 @@ import jax import jax.nn +from jax import lax from jax.ad_checkpoint import checkpoint_name from jax.sharding import Mesh import jax.numpy as jnp @@ -28,6 +29,9 @@ from flax import linen as nn from flax import nnx +from jax.sharding import PartitionSpec as P +from jax.experimental.shard_map import shard_map + from MaxText.common_types import AttentionType, Config, DType, Array, BATCH, LENGTH_NO_EXP, EMBED, MODEL_MODE_TRAIN from MaxText.layers import attentions from MaxText.layers import initializers as max_initializers @@ -43,243 +47,360 @@ from MaxText.layers.initializers import nd_dense_init, variable_to_logically_partitioned from maxtext.inference import page_manager from maxtext.utils import max_utils +from maxtext.scratch_code import gdn_pallas # ----------------------------------------- # Qwen3-Next Layer Implementations # ----------------------------------------- -def jax_chunk_gated_delta_rule( - query: Array, - key: Array, - value: Array, - g: Array, - beta: Array, +import jax +import jax.numpy as jnp +from jax import lax + +def pallas_chunk_gated_delta_rule( + query: jax.Array, + key: jax.Array, + value: jax.Array, + g: jax.Array, + beta: jax.Array, chunk_size: int = 64, - initial_state: None | Array = None, + initial_state: None | jax.Array = None, use_qk_norm_in_gdn: bool = False, -) -> tuple[Array, None | Array]: + compute_dtype: jnp.dtype = jnp.bfloat16, + mesh: Mesh | None = None, +) -> tuple[jax.Array, None | jax.Array]: """ - A JAX implementation of the chunked Gated Delta Rule, a parallel scan algorithm. - This function implements the core recurrent logic of the Gated Delta Network in - a hardware-efficient way by splitting the sequence into chunks and using - jax.lax.scan for the recurrent part. - - Tensor Shape Abbreviations: - B: batch_size, S: sequence_length, H: num_heads, - D_k: key/query_head_dim, D_v: value_head_dim, - N: num_chunks, C: chunk_size - - Args: - query: Query tensor. Shape (B, S, H, D_k) - key: Key tensor. Shape (B, S, H, D_k) - value: Value tensor. Shape (B, S, H, D_v) - g: Log decay tensor. Shape (B, S, H) - beta: Gate tensor. Shape (B, S, H) - chunk_size: The size of each chunk for processing. - initial_state: Optional initial state for the recurrence. Shape (B, H, D_k, D_v) - use_qk_norm_in_gdn: Whether to apply L2 normalization to query and key. - - Returns: - Output tensor. Shape (B, S, H, D_v) - Final recurrent state. Shape (B, H, D_k, D_v) or None + Pallas-accelerated version of Gated Delta Rule. 
""" - # ========================================================================= # STAGE 1: PREPARATION & PADDING # ========================================================================= initial_dtype = query.dtype if use_qk_norm_in_gdn: + from MaxText.layers.normalizations import l2norm query = l2norm(query, dim=-1, eps=1e-6) key = l2norm(key, dim=-1, eps=1e-6) - # Transpose (B, S, H, D) -> (B, H, S, D) - query = jnp.transpose(query, (0, 2, 1, 3)).astype(jnp.float32) - key = jnp.transpose(key, (0, 2, 1, 3)).astype(jnp.float32) - value = jnp.transpose(value, (0, 2, 1, 3)).astype(jnp.float32) - # Transpose (B, S, H) -> (B, H, S) - beta = jnp.transpose(beta, (0, 2, 1)).astype(jnp.float32) - g = jnp.transpose(g, (0, 2, 1)).astype(jnp.float32) - - batch_size, num_heads, sequence_length, k_head_dim = key.shape - v_head_dim = value.shape[-1] - pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size - - # Padding to make sequence_length divisible by chunk_size - if pad_size > 0: - query = jnp.pad(query, ((0, 0), (0, 0), (0, pad_size), (0, 0))) # (B, H, S_padded, D_k) - key = jnp.pad(key, ((0, 0), (0, 0), (0, pad_size), (0, 0))) # (B, H, S_padded, D_k) - value = jnp.pad(value, ((0, 0), (0, 0), (0, pad_size), (0, 0))) # (B, H, S_padded, D_v) - beta = jnp.pad(beta, ((0, 0), (0, 0), (0, pad_size))) # (B, H, S_padded) - g = jnp.pad(g, ((0, 0), (0, 0), (0, pad_size))) # (B, H, S_padded) - - total_sequence_length = sequence_length + pad_size - # query shape: (B, H, S_padded, D_k) - scale = jax.lax.rsqrt(jnp.array(query.shape[-1]).astype(jnp.float32)) - query = query * scale + g = g.astype(jnp.float32) + query = query.astype(compute_dtype) + key = key.astype(compute_dtype) + value = value.astype(compute_dtype) + beta = beta.astype(compute_dtype) - v_beta = value * jnp.expand_dims(beta, -1) # (B, H, S_padded, D_v) - k_beta = key * jnp.expand_dims(beta, -1) # (B, H, S_padded, D_k) - - # Reshape to chunks - num_chunks = total_sequence_length // chunk_size - # query_c shape: (B, H, N, C, D_k) - query_c = query.reshape(batch_size, num_heads, num_chunks, chunk_size, k_head_dim) - key_c = key.reshape(batch_size, num_heads, num_chunks, chunk_size, k_head_dim) - k_beta_c = k_beta.reshape(batch_size, num_heads, num_chunks, chunk_size, k_head_dim) - v_beta_c = v_beta.reshape(batch_size, num_heads, num_chunks, chunk_size, v_head_dim) - g_c = g.reshape(batch_size, num_heads, num_chunks, chunk_size) # (B, H, N, C) + scale = jax.lax.rsqrt(jnp.array(query.shape[-1], dtype=jnp.float32)).astype(compute_dtype) + query = query * scale - mask = jnp.triu(jnp.ones((chunk_size, chunk_size), dtype=bool), k=0) # (C, C) + B, S, H, K_dim = key.shape + V_dim = value.shape[-1] + + pad_len = (chunk_size - (S % chunk_size)) % chunk_size + if pad_len > 0: + pad_fn = lambda x, val=0.0: jnp.pad(x, ((0,0), (0, pad_len)) + ((0,0),)*(x.ndim-2), constant_values=val) + query = pad_fn(query) + key = pad_fn(key) + value = pad_fn(value) + g = pad_fn(g) + beta = pad_fn(beta) + + num_chunks = query.shape[1] // chunk_size + + def to_chunk(x): + return x.reshape(B, num_chunks, chunk_size, H, -1).transpose(0, 1, 3, 2, 4) + def to_chunk_scalar(x): + return x.reshape(B, num_chunks, chunk_size, H).transpose(0, 1, 3, 2) + + q_c = to_chunk(query) + k_c = to_chunk(key) + v_c = to_chunk(value) + g_c = to_chunk_scalar(g) + beta_c = to_chunk_scalar(beta) # ========================================================================= - # STAGE 2: INTRA-CHUNK CALCULATION (PARALLEL) + # STAGE 2: INTRA-CHUNK PRE-COMPUTATION # 
========================================================================= - # g_cumsum shape: (B, H, N, C) g_cumsum = jnp.cumsum(g_c, axis=-1) - # g_diff shape: (B, H, N, C, C) - g_diff = jnp.expand_dims(g_cumsum, -1) - jnp.expand_dims(g_cumsum, -2) - - # Apply tril to zero out the upper triangle of g_diff. This is crucial because - # the upper triangle contains large positive values that would cause exp() to overflow. - g_diff_tril = jnp.tril(g_diff) - - # Exponentiate the lower triangular g_diff. Since these values are non-positive, - # exp() will not overflow and will produce values between 0 and 1. - g_diff_exp = jnp.exp(g_diff_tril).astype(jnp.float32) - - # The result g_diff_exp is already lower triangular and serves as the decay_mask. - # decay_mask shape: (B, H, N, C, C) - decay_mask = g_diff_exp - - # --- Precompute within-chunk attention --- - # NOTE: Precision set to HIGHEST for numerical accuracy. - prec = jax.lax.Precision.HIGHEST - # attn shape: (B, H, N, C, C) - attn = -jnp.matmul(k_beta_c, jnp.swapaxes(key_c, -1, -2), precision=prec) * decay_mask - attn = jnp.where(mask, 0.0, attn) - - # Iterative refinement of the intra-chunk attention. - # This loop is equivalent to inverting (I - A) where A is the lower triangular part of attn. - def inner_attn_body(i, attn_val): - # indices: (C,) - indices = jnp.arange(chunk_size) - # col_mask: (C,) - col_mask = indices < i - # row: (B, H, N, C) - row = attn_val[..., i, :] * col_mask - # sub_mask: (C, C) - sub_mask = jnp.expand_dims(indices < i, -1) & (indices < i) - # sub: (B, H, N, C, C) - sub = attn_val * sub_mask - # row_exp: (B, H, N, C, 1) - row_exp = jnp.expand_dims(row, -1) - # term: (B, H, N, C, C) - term = row_exp * sub - # summed: (B, H, N, C) - summed = jnp.sum(term, axis=-2) - # update_val: (B, H, N, C) - update_val = row + summed - # original_row: (B, H, N, C) - original_row = attn_val[..., i, :] - # new_row: (B, H, N, C) - new_row = jnp.where(col_mask, update_val, original_row) - return attn_val.at[..., i, :].set(new_row) - - attn = jax.lax.fori_loop(1, chunk_size, inner_attn_body, attn) - - attn = attn + jnp.eye(chunk_size, dtype=attn.dtype) # (B, H, N, C, C) - # value_intra shape: (B, H, N, C, D_v) - value_intra = jnp.matmul(attn, v_beta_c, precision=prec) - # k_cumdecay shape: (B, H, N, C, D_k) - k_cumdecay = jnp.matmul(attn, (k_beta_c * jnp.expand_dims(jnp.exp(g_cumsum), -1)), precision=prec) - # --- End Precompute --- - - output_final_state = initial_state is not None + k_beta = k_c * beta_c[..., None] + + S = jnp.matmul(k_c, k_beta.swapaxes(-1, -2), precision=jax.lax.Precision.HIGHEST) + S = S.astype(jnp.float32) + g_diff = g_cumsum[..., :, None] - g_cumsum[..., None, :] + mask = jnp.tril(jnp.ones((chunk_size, chunk_size), dtype=bool), k=-1) + g_diff = jnp.where(mask, g_diff, -1e30) + S = S * jnp.exp(g_diff) + S = jnp.where(mask, S, 0.0) + + identity = jnp.eye(chunk_size, dtype=jnp.float32) + identity_broadcasted = jnp.broadcast_to(identity, S.shape) + A = jax.scipy.linalg.solve_triangular(identity + S, identity_broadcasted, lower=True, unit_diagonal=True) + + v_beta = v_c * beta_c[..., None] + u_chunks = jnp.matmul(A, v_beta.astype(jnp.float32), precision=jax.lax.Precision.HIGHEST) + u_chunks = u_chunks.astype(compute_dtype) + + k_beta_g = k_beta.astype(jnp.float32) * jnp.exp(g_cumsum)[..., None] + w_chunks = jnp.matmul(A, k_beta_g, precision=jax.lax.Precision.HIGHEST) + w_chunks = w_chunks.astype(compute_dtype) + + # ========================================================================= + # STAGE 3: INTER-CHUNK 
RECURRENCE (Pallas Kernel + shard_map) + # ========================================================================= + # Transpose to (Batch, Heads, NumChunks, ChunkSize, Dim) for Pallas + w_p = w_chunks.transpose(0, 2, 1, 3, 4) + u_p = u_chunks.transpose(0, 2, 1, 3, 4) + q_p = q_c.transpose(0, 2, 1, 3, 4) + k_p = k_c.transpose(0, 2, 1, 3, 4) + v_p = v_c.transpose(0, 2, 1, 3, 4) + g_p = g_cumsum.transpose(0, 2, 1, 3) + beta_p = beta_c.transpose(0, 2, 1, 3) + + # Handle initial state if initial_state is None: - # last_recurrent_state shape: (B, H, D_k, D_v) - last_recurrent_state = jnp.zeros((batch_size, num_heads, k_head_dim, v_head_dim), dtype=value_intra.dtype) + h_init = jnp.zeros((B, H, K_dim, V_dim), dtype=compute_dtype) else: - last_recurrent_state = initial_state.astype(value_intra.dtype) - - # mask_inter shape: (C, C) - mask_inter = jnp.triu(jnp.ones((chunk_size, chunk_size), dtype=bool), k=1) + h_init = initial_state.astype(compute_dtype) + + # Invoke Kernel + if mesh is not None: + # Mesh Partitioning + axis_names = mesh.axis_names + batch_axes = [ax for ax in ('data', 'fsdp', 'fsdp_transpose', 'expert') if ax in axis_names] + batch_spec = tuple(batch_axes) if batch_axes else None + head_axes = [ax for ax in ('tensor', 'model') if ax in axis_names] + head_spec = tuple(head_axes) if head_axes else None + + # Specs: B, H, ... + # h_init is (B, H, K, V) + in_specs = P(batch_spec, head_spec, None, None, None) + scalar_specs = P(batch_spec, head_spec, None, None) + state_spec = P(batch_spec, head_spec, None, None) + + sharded_gdn = shard_map( + gdn_pallas.gdn_pallas_layer, + mesh=mesh, + in_specs=(in_specs, in_specs, in_specs, in_specs, in_specs, scalar_specs, scalar_specs, state_spec), + out_specs=(in_specs, state_spec), # Returns (out, final_state) + check_rep=False + ) + + o_pallas, h_final = sharded_gdn(w_p, u_p, q_p, k_p, v_p, g_p, beta_p, h_init) + else: + # Single Device + o_pallas, h_final = gdn_pallas.gdn_pallas_layer(w_p, u_p, q_p, k_p, v_p, g_p, beta_p, h_init) - # Transpose for scan: (B, H, N, C, D) -> (N, B, H, C, D) - query_scan = jnp.transpose(query_c, (2, 0, 1, 3, 4)) - key_scan = jnp.transpose(key_c, (2, 0, 1, 3, 4)) - value_scan = jnp.transpose(value_intra, (2, 0, 1, 3, 4)) - k_cumdecay_scan = jnp.transpose(k_cumdecay, (2, 0, 1, 3, 4)) - # Transpose for scan: (B, H, N, C) -> (N, B, H, C) - g_scan = jnp.transpose(g_cumsum, (2, 0, 1, 3)) - decay_mask_scan = jnp.transpose(decay_mask, (2, 0, 1, 3, 4)) + o_chunks = o_pallas.transpose(0, 2, 1, 3, 4) - xs = (query_scan, key_scan, value_scan, k_cumdecay_scan, g_scan, decay_mask_scan) + # ========================================================================= + # STAGE 4: FINALIZATION + # ========================================================================= + o = o_chunks.reshape(B, -1, H, V_dim) + + if pad_len > 0: + o = o[:, :S, :, :] + + o = o.astype(initial_dtype) + + return o, h_final +def jax_chunk_gated_delta_rule( + query: jax.Array, + key: jax.Array, + value: jax.Array, + g: jax.Array, + beta: jax.Array, + chunk_size: int = 64, + initial_state: None | jax.Array = None, + use_qk_norm_in_gdn: bool = False, + compute_dtype: jnp.dtype = jnp.bfloat16, +) -> tuple[jax.Array, None | jax.Array]: + """ + Optimized JAX implementation of Gated Delta Rule (Mixed Precision + Stability Fix). + + Precision Strategy: + - Inputs (q, k, v, beta): 'compute_dtype' (e.g. bfloat16) for tensor core usage. + - Gates (g) & State (h): float32 (forced) for numerical stability. 
+ - Matmuls: compute_dtype inputs -> float32 accumulation/output. + """ # ========================================================================= - # STAGE 3: INTER-CHUNK RECURRENCE (SEQUENTIAL VIA SCAN) + # STAGE 1: PREPARATION & PADDING # ========================================================================= - def scan_body(prev_state, x): - q_i, k_i, v_i, k_cumdecay_i, g_i, decay_mask_i = x - # prev_state shape: (B, H, D_k, D_v) - last_recurrent_state = prev_state - prec = jax.lax.Precision.HIGHEST - - # Intra-chunk attention for the current chunk - # attn_i shape: (B, H, C, C) - attn_i = jnp.matmul(q_i, jnp.swapaxes(k_i, -1, -2), precision=prec) * decay_mask_i - attn_i = jnp.where(mask_inter, 0.0, attn_i) - - # Interaction with the recurrent state - # v_prime shape: (B, H, C, D_v) - v_prime = jnp.matmul(k_cumdecay_i, last_recurrent_state, precision=prec) - # v_new shape: (B, H, C, D_v) - v_new = v_i - v_prime - - # g_i is cumulative sum, so exp(g_i) is the decay factor - g_i_exp = jnp.exp(g_i) - # attn_inter shape: (B, H, C, D_v) - attn_inter = jnp.matmul(q_i * jnp.expand_dims(g_i_exp, -1), last_recurrent_state, precision=prec) - - # core_attn_out_i shape: (B, H, C, D_v) - core_attn_out_i = attn_inter + jnp.matmul(attn_i, v_new, precision=prec) - - # Update the recurrent state - # g_i_last_exp shape: (B, H, 1, 1) - g_i_last_exp = jnp.exp(g_i[..., -1, None, None]) - # new_last_recurrent_state shape: (B, H, D_k, D_v) - new_last_recurrent_state = last_recurrent_state * g_i_last_exp - - # g_diff_exp shape: (B, H, C, 1) - g_diff_exp = jnp.expand_dims(jnp.exp(jnp.expand_dims(g_i[..., -1], -1) - g_i), -1) - # k_i_g_diff shape: (B, H, C, D_k) - k_i_g_diff = k_i * g_diff_exp - - # Update term shape: (B, H, D_k, D_v) - update_term = jnp.matmul(jnp.swapaxes(k_i_g_diff, -1, -2), v_new, precision=prec) - new_last_recurrent_state = new_last_recurrent_state + update_term - - return new_last_recurrent_state, core_attn_out_i - - # final_state shape: (B, H, D_k, D_v) - # core_attn_out_stacked shape: (N, B, H, C, D_v) - final_state, core_attn_out_stacked = jax.lax.scan(scan_body, last_recurrent_state, xs) + initial_dtype = query.dtype + + if use_qk_norm_in_gdn: + from MaxText.layers.normalizations import l2norm + query = l2norm(query, dim=-1, eps=1e-6) + key = l2norm(key, dim=-1, eps=1e-6) + + # [MIXED PRECISION START] + # 1. Force Gates 'g' to float32 immediately (crucial for exp/cumsum stability) + g = g.astype(jnp.float32) + + # 2. 
Cast inputs to the requested compute_dtype (likely bf16) to save memory/compute + query = query.astype(compute_dtype) + key = key.astype(compute_dtype) + value = value.astype(compute_dtype) + beta = beta.astype(compute_dtype) + + # Scale Query (keep in compute_dtype) + scale = jax.lax.rsqrt(jnp.array(query.shape[-1], dtype=jnp.float32)).astype(compute_dtype) + query = query * scale + + B, S, H, K_dim = key.shape + V_dim = value.shape[-1] + + # Pad sequence + pad_len = (chunk_size - (S % chunk_size)) % chunk_size + if pad_len > 0: + pad_fn = lambda x, val=0.0: jnp.pad(x, ((0,0), (0, pad_len)) + ((0,0),)*(x.ndim-2), constant_values=val) + query = pad_fn(query) + key = pad_fn(key) + value = pad_fn(value) + g = pad_fn(g) + beta = pad_fn(beta) + + num_chunks = query.shape[1] // chunk_size + + # Helper: (B, S, H, D) -> (B, N, H, C, D) + def to_chunk(x): + return x.reshape(B, num_chunks, chunk_size, H, -1).transpose(0, 1, 3, 2, 4) + + # Helper for scalars: (B, S, H) -> (B, N, H, C) + def to_chunk_scalar(x): + return x.reshape(B, num_chunks, chunk_size, H).transpose(0, 1, 3, 2) + + q_c = to_chunk(query) # compute_dtype + k_c = to_chunk(key) # compute_dtype + v_c = to_chunk(value) # compute_dtype + g_c = to_chunk_scalar(g) # float32 + beta_c = to_chunk_scalar(beta) # compute_dtype # ========================================================================= - # STAGE 4: FINALIZATION + # STAGE 2: INTRA-CHUNK PRE-COMPUTATION (Parallel) # ========================================================================= - # core_attn_out shape: (B, H, N, C, D_v) - core_attn_out = jnp.transpose(core_attn_out_stacked, (1, 2, 0, 3, 4)) + + # 1. Cumulative decay (Must be float32) + g_cumsum = jnp.cumsum(g_c, axis=-1) + + # 2. k_beta preparation (bf16 * bf16 -> bf16) + k_beta = k_c * beta_c[..., None] + + # 3. S Matrix Calculation + # Matmul: bf16 @ bf16 -> Accumulate in float32 (via HIGHEST) + S = jnp.matmul(k_c, k_beta.swapaxes(-1, -2), precision=jax.lax.Precision.HIGHEST) + + # [CRITICAL] Promote S to float32 immediately for interaction with exp(g) + S = S.astype(jnp.float32) + + # [CRITICAL FIX] Apply mask BEFORE exp to prevent 'inf' gradients + g_diff = g_cumsum[..., :, None] - g_cumsum[..., None, :] + mask = jnp.tril(jnp.ones((chunk_size, chunk_size), dtype=bool), k=-1) + g_diff = jnp.where(mask, g_diff, -1e30) + + S = S * jnp.exp(g_diff) + S = jnp.where(mask, S, 0.0) + + # 4. Inversion (A) - Strictly float32 + identity = jnp.eye(chunk_size, dtype=jnp.float32) + identity_broadcasted = jnp.broadcast_to(identity, S.shape) + + A = jax.scipy.linalg.solve_triangular( + identity + S, + identity_broadcasted, + lower=True, + unit_diagonal=True + ) - # core_attn_out shape: (B, H, S_padded, D_v) - core_attn_out = core_attn_out.reshape(batch_size, num_heads, -1, v_head_dim) - # Trim padding: (B, H, S, D_v) - core_attn_out = core_attn_out[:, :, :sequence_length, :] + # 5. 
WY Factors (Keep as float32 to preserve accuracy of the Inverse) + # u: f32 @ bf16 -> f32 -> cast to compute_dtype (storage optimization) + v_beta = v_c * beta_c[..., None] + u_chunks = jnp.matmul(A, v_beta.astype(jnp.float32), precision=jax.lax.Precision.HIGHEST) + u_chunks = u_chunks.astype(compute_dtype) + + # w: f32 @ bf16 -> f32 -> cast to compute_dtype (storage optimization) + k_beta_g = k_beta.astype(jnp.float32) * jnp.exp(g_cumsum)[..., None] + w_chunks = jnp.matmul(A, k_beta_g, precision=jax.lax.Precision.HIGHEST) + w_chunks = w_chunks.astype(compute_dtype) - # Transpose back to (B, S, H, D_v) - core_attn_out = jnp.transpose(core_attn_out, (0, 2, 1, 3)).astype(initial_dtype) + # ========================================================================= + # STAGE 3: INTER-CHUNK RECURRENCE (Scan) + # ========================================================================= + scan_perm_vec = (1, 0, 2, 3, 4) + scan_perm_scl = (1, 0, 2, 3) + + w_scan = w_chunks.transpose(scan_perm_vec) # f32 + u_scan = u_chunks.transpose(scan_perm_vec) # f32 + k_scan = k_c.transpose(scan_perm_vec) # compute_dtype + q_scan = q_c.transpose(scan_perm_vec) # compute_dtype + v_scan = v_c.transpose(scan_perm_vec) # compute_dtype + g_scan = g_cumsum.transpose(scan_perm_scl) # f32 + beta_scan = beta_c.transpose(scan_perm_scl)# compute_dtype + + chunk_decay_all = jnp.exp(g_cumsum[..., -1]).transpose(1, 0, 2) # f32 + + # State MUST be float32 for linear RNNs + if initial_state is None: + h_init = jnp.zeros((B, H, K_dim, V_dim), dtype=jnp.float32) + else: + h_init = initial_state.astype(jnp.float32) + + xs = (w_scan, u_scan, q_scan, k_scan, v_scan, g_scan, beta_scan, chunk_decay_all) + + def scan_body(h, args): + w, u, q, k, v, g, beta, decay_val = args + + # --- Output Computation --- + # 1. Inter-chunk: q(bf16) * exp(g)(f32) -> f32 + # f32 @ h(f32) -> f32 + q_g = q.astype(jnp.float32) * jnp.exp(g)[..., None] + term1 = jnp.matmul(q_g, h, precision=jax.lax.Precision.HIGHEST) + + # 2. 
Intra-chunk: q(bf16) @ k(bf16) -> bf16/f32 + attn = jnp.matmul(q, k.swapaxes(-1, -2), precision=jax.lax.Precision.HIGHEST) + + # [CRITICAL] Promote to f32 before exp mask + attn = attn.astype(jnp.float32) + + # [CRITICAL FIX] Mask before exp + g_diff = g[..., :, None] - g[..., None, :] + mask_intra = jnp.tril(jnp.ones((chunk_size, chunk_size), dtype=bool)) + g_diff = jnp.where(mask_intra, g_diff, -1e30) + + attn = attn * jnp.exp(g_diff) + attn = jnp.where(mask_intra, attn, 0.0) + + # beta is compute_dtype, cast to f32 for mixing + attn = attn * beta.astype(jnp.float32)[..., None, :] + + # attn(f32) @ v(bf16) -> f32 + term2 = jnp.matmul(attn, v.astype(jnp.float32), precision=jax.lax.Precision.HIGHEST) + + o_c = term1 + term2 + + # --- State Update --- + decay_expanded = decay_val[..., None, None] + + # w(f32) @ u(f32) -> f32 + update = jnp.matmul(w.swapaxes(-1, -2), u, precision=jax.lax.Precision.HIGHEST) + update = update.astype(jnp.float32) + + h_new = h * decay_expanded + update + + return h_new, o_c + + final_h, o_chunks = lax.scan(scan_body, h_init, xs) - return core_attn_out, final_state if output_final_state else None + # ========================================================================= + # STAGE 4: FINALIZATION + # ========================================================================= + o = o_chunks.transpose(1, 0, 2, 3, 4) + o = o.reshape(B, -1, H, V_dim) + + if pad_len > 0: + o = o[:, :S, :, :] + + o = o.astype(initial_dtype) + + return o, (final_h if initial_state is not None else None) class Qwen3NextGatedDeltaNet(nnx.Module): @@ -306,13 +427,14 @@ class Qwen3NextGatedDeltaNet(nnx.Module): 2. output = Linear_out(y) """ - def __init__(self, config: Config, *, rngs: nnx.Rngs): + def __init__(self, config: Config, *, rngs: nnx.Rngs, mesh: Mesh=None): """ Args: config: MaxText configuration object. rngs: The random number generators for initialization, passed by the nnx.to_linen wrapper. """ self.config = config + self.mesh = mesh cfg = self.config in_features = cfg.emb_dim @@ -477,8 +599,8 @@ def __call__(self, hidden_states: Array) -> Array: # ========================================================================= # STEP C: Gated Delta Rule Recurrence # ========================================================================= - A_log = self.A_log.value - dt_bias = self.dt_bias.value + A_log = self.A_log[...] + dt_bias = self.dt_bias[...] 
# beta shape: (B, S, H_v) beta = jax.nn.sigmoid(b) # g shape: (B, S, H_v) @@ -497,8 +619,15 @@ def __call__(self, hidden_states: Array) -> Array: # TODO(parambole): Pass and update cache state for jax_chunk_gated_delta_rule # core_attn_out shape: (B, S, H_v, D_v) - core_attn_out, _ = jax_chunk_gated_delta_rule( - query, key, value, g, beta, chunk_size=cfg.gdn_chunk_size, use_qk_norm_in_gdn=cfg.use_qk_norm_in_gdn + # core_attn_out, _ = jax_chunk_gated_delta_rule( + # query, key, value, g, beta, chunk_size=cfg.gdn_chunk_size, use_qk_norm_in_gdn=cfg.use_qk_norm_in_gdn, compute_dtype=cfg.dtype + # ) + core_attn_out, _ = pallas_chunk_gated_delta_rule( + query, key, value, g, beta, + chunk_size=cfg.gdn_chunk_size, + use_qk_norm_in_gdn=cfg.use_qk_norm_in_gdn, + compute_dtype=cfg.dtype, + mesh=self.mesh ) # ========================================================================= @@ -664,7 +793,7 @@ def __init__(self, config: Config, mesh: Mesh, quant: None | Quant = None, *, rn use_bias=False, # Qwen3-Next shared_expert_gate does not have a bias dtype=cfg.dtype, kernel_init=max_initializers.nd_dense_init(1.0, "fan_in", "truncated_normal"), - kernel_axes=("embed", "vocab"), + kernel_axes=("embed", None), matmul_precision=cfg.matmul_precision, rngs=rngs, ) @@ -830,7 +959,7 @@ def __init__( rngs=rngs, ) else: - self.attention = Qwen3NextGatedDeltaNet(config=cfg, rngs=rngs) + self.attention = Qwen3NextGatedDeltaNet(config=cfg, rngs=rngs, mesh=self.mesh) # Second LayerNorm, applied before the MoE block. self.post_attention_layernorm = Qwen3NextRMSNorm( diff --git a/src/maxtext/scratch_code/benchmark_gdn_optimization.py b/src/maxtext/scratch_code/benchmark_gdn_optimization.py new file mode 100644 index 0000000000..2f3ddbea54 --- /dev/null +++ b/src/maxtext/scratch_code/benchmark_gdn_optimization.py @@ -0,0 +1,533 @@ +import time +import functools +import types +import jax +import jax.extend +import jax.numpy as jnp +import numpy as np +from flax import nnx +from typing import Any, cast + +# Import common dependencies +from MaxText import common_types +from MaxText.layers import normalizations +from MaxText.layers.linears import DenseGeneral +from MaxText.layers import qwen3 + +# ============================================================================== +# SECTION 1: THE MIRROR (BASELINE IMPLEMENTATION) +# ============================================================================== +def baseline_chunk_gated_delta_rule( + query, key, value, g, beta, chunk_size=64, initial_state=None, use_qk_norm_in_gdn=False +): + """The ORIGINAL implementation (Do not edit this function).""" + initial_dtype = query.dtype + if use_qk_norm_in_gdn: + query = normalizations.l2norm(query, dim=-1, eps=1e-6) + key = normalizations.l2norm(key, dim=-1, eps=1e-6) + + query = jnp.transpose(query, (0, 2, 1, 3)).astype(jnp.float32) + key = jnp.transpose(key, (0, 2, 1, 3)).astype(jnp.float32) + value = jnp.transpose(value, (0, 2, 1, 3)).astype(jnp.float32) + beta = jnp.transpose(beta, (0, 2, 1)).astype(jnp.float32) + g = jnp.transpose(g, (0, 2, 1)).astype(jnp.float32) + + batch_size, num_heads, sequence_length, k_head_dim = key.shape + v_head_dim = value.shape[-1] + pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size + + if pad_size > 0: + query = jnp.pad(query, ((0, 0), (0, 0), (0, pad_size), (0, 0))) + key = jnp.pad(key, ((0, 0), (0, 0), (0, pad_size), (0, 0))) + value = jnp.pad(value, ((0, 0), (0, 0), (0, pad_size), (0, 0))) + beta = jnp.pad(beta, ((0, 0), (0, 0), (0, pad_size))) + g = jnp.pad(g, ((0, 0), (0, 
0), (0, pad_size))) + + total_sequence_length = sequence_length + pad_size + scale = jax.lax.rsqrt(jnp.array(query.shape[-1]).astype(jnp.float32)) + query = query * scale + + v_beta = value * jnp.expand_dims(beta, -1) + k_beta = key * jnp.expand_dims(beta, -1) + + num_chunks = total_sequence_length // chunk_size + query_c = query.reshape(batch_size, num_heads, num_chunks, chunk_size, k_head_dim) + key_c = key.reshape(batch_size, num_heads, num_chunks, chunk_size, k_head_dim) + k_beta_c = k_beta.reshape(batch_size, num_heads, num_chunks, chunk_size, k_head_dim) + v_beta_c = v_beta.reshape(batch_size, num_heads, num_chunks, chunk_size, v_head_dim) + g_c = g.reshape(batch_size, num_heads, num_chunks, chunk_size) + + mask = jnp.triu(jnp.ones((chunk_size, chunk_size), dtype=bool), k=0) + + g_cumsum = jnp.cumsum(g_c, axis=-1) + g_diff = jnp.expand_dims(g_cumsum, -1) - jnp.expand_dims(g_cumsum, -2) + g_diff_tril = jnp.tril(g_diff) + g_diff_exp = jnp.exp(g_diff_tril).astype(jnp.float32) + decay_mask = g_diff_exp + + prec = jax.lax.Precision.HIGHEST + attn = -jnp.matmul(k_beta_c, jnp.swapaxes(key_c, -1, -2), precision=prec) * decay_mask + attn = jnp.where(mask, 0.0, attn) + + def inner_attn_body(i, attn_val): + indices = jnp.arange(chunk_size) + col_mask = indices < i + row = attn_val[..., i, :] * col_mask + sub_mask = jnp.expand_dims(indices < i, -1) & (indices < i) + sub = attn_val * sub_mask + row_exp = jnp.expand_dims(row, -1) + term = row_exp * sub + summed = jnp.sum(term, axis=-2) + update_val = row + summed + original_row = attn_val[..., i, :] + new_row = jnp.where(col_mask, update_val, original_row) + return attn_val.at[..., i, :].set(new_row) + + attn = jax.lax.fori_loop(1, chunk_size, inner_attn_body, attn) + attn = attn + jnp.eye(chunk_size, dtype=attn.dtype) + value_intra = jnp.matmul(attn, v_beta_c, precision=prec) + k_cumdecay = jnp.matmul(attn, (k_beta_c * jnp.expand_dims(jnp.exp(g_cumsum), -1)), precision=prec) + + output_final_state = initial_state is not None + if initial_state is None: + last_recurrent_state = jnp.zeros((batch_size, num_heads, k_head_dim, v_head_dim), dtype=value_intra.dtype) + else: + last_recurrent_state = initial_state.astype(value_intra.dtype) + + mask_inter = jnp.triu(jnp.ones((chunk_size, chunk_size), dtype=bool), k=1) + + query_scan = jnp.transpose(query_c, (2, 0, 1, 3, 4)) + key_scan = jnp.transpose(key_c, (2, 0, 1, 3, 4)) + value_scan = jnp.transpose(value_intra, (2, 0, 1, 3, 4)) + k_cumdecay_scan = jnp.transpose(k_cumdecay, (2, 0, 1, 3, 4)) + g_scan = jnp.transpose(g_cumsum, (2, 0, 1, 3)) + decay_mask_scan = jnp.transpose(decay_mask, (2, 0, 1, 3, 4)) + + xs = (query_scan, key_scan, value_scan, k_cumdecay_scan, g_scan, decay_mask_scan) + + def scan_body(prev_state, x): + q_i, k_i, v_i, k_cumdecay_i, g_i, decay_mask_i = x + last_recurrent_state = prev_state + prec = jax.lax.Precision.HIGHEST + + attn_i = jnp.matmul(q_i, jnp.swapaxes(k_i, -1, -2), precision=prec) * decay_mask_i + attn_i = jnp.where(mask_inter, 0.0, attn_i) + + v_prime = jnp.matmul(k_cumdecay_i, last_recurrent_state, precision=prec) + v_new = v_i - v_prime + + g_i_exp = jnp.exp(g_i) + attn_inter = jnp.matmul(q_i * jnp.expand_dims(g_i_exp, -1), last_recurrent_state, precision=prec) + + core_attn_out_i = attn_inter + jnp.matmul(attn_i, v_new, precision=prec) + + g_i_last_exp = jnp.exp(g_i[..., -1, None, None]) + new_last_recurrent_state = last_recurrent_state * g_i_last_exp + + g_diff_exp = jnp.expand_dims(jnp.exp(jnp.expand_dims(g_i[..., -1], -1) - g_i), -1) + k_i_g_diff = k_i * g_diff_exp + 
+ update_term = jnp.matmul(jnp.swapaxes(k_i_g_diff, -1, -2), v_new, precision=prec) + new_last_recurrent_state = new_last_recurrent_state + update_term + + return new_last_recurrent_state, core_attn_out_i + + final_state, core_attn_out_stacked = jax.lax.scan(scan_body, last_recurrent_state, xs) + + core_attn_out = jnp.transpose(core_attn_out_stacked, (1, 2, 0, 3, 4)) + core_attn_out = core_attn_out.reshape(batch_size, num_heads, -1, v_head_dim) + core_attn_out = core_attn_out[:, :, :sequence_length, :] + core_attn_out = jnp.transpose(core_attn_out, (0, 2, 1, 3)).astype(initial_dtype) + + return core_attn_out, final_state if output_final_state else None + + +class BaselineGatedDeltaNet(nnx.Module): + """The Mirror/Baseline Wrapper.""" + def __init__(self, config, *, rngs): + self.config = config + cfg = self.config + + in_features = cfg.emb_dim + self.num_v_heads = cfg.gdn_num_value_heads + self.num_k_heads = cfg.gdn_num_key_heads + self.head_k_dim = cfg.gdn_key_head_dim + self.head_v_dim = cfg.gdn_value_head_dim + self.key_dim = self.head_k_dim * self.num_k_heads + self.value_dim = self.head_v_dim * self.num_v_heads + conv_dim = self.key_dim * 2 + self.value_dim + conv_kernel_size = cfg.gdn_conv_kernel_dim + self.v_heads_per_k_head = self.num_v_heads // self.num_k_heads + + self.in_proj_qkvz = DenseGeneral(in_features, (self.key_dim * 2 + self.value_dim * 2), dtype=cfg.dtype, kernel_axes=("embed", "mlp"), matmul_precision=cfg.matmul_precision, rngs=rngs) + self.in_proj_ba = DenseGeneral(in_features, (self.num_v_heads * 2), dtype=cfg.dtype, kernel_axes=("embed", "mlp"), matmul_precision=cfg.matmul_precision, rngs=rngs) + + self.conv1d = nnx.Conv(conv_dim, conv_dim, kernel_size=(conv_kernel_size,), feature_group_count=conv_dim, padding="CAUSAL", use_bias=False, dtype=cfg.dtype, precision=cfg.matmul_precision, rngs=rngs) + + def a_log_init(key, shape, dtype=jnp.float32): + a_vals = jax.random.uniform(key, shape=shape, dtype=dtype, minval=1e-9, maxval=16.0) + return jnp.log(a_vals) + + self.A_log = nnx.Param(a_log_init(rngs.params(), (self.num_v_heads,))) + self.dt_bias = nnx.Param(nnx.initializers.ones(rngs.params(), (self.num_v_heads,))) + + self.norm = normalizations.Qwen3NextRMSNormGated(self.head_v_dim, eps=cfg.normalization_layer_epsilon, dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, rngs=rngs) + self.out_proj = DenseGeneral(self.value_dim, (in_features,), dtype=cfg.dtype, kernel_axes=("mlp", "embed"), matmul_precision=cfg.matmul_precision, rngs=rngs) + + def __call__(self, hidden_states): + cfg = self.config + batch, seq_len, _ = hidden_states.shape + qkvz = self.in_proj_qkvz(hidden_states) + ba = self.in_proj_ba(hidden_states) + + new_shape_qkvz = (batch, seq_len, self.num_k_heads, 2 * self.head_k_dim + 2 * self.head_v_dim * self.v_heads_per_k_head) + mixed_qkvz = qkvz.reshape(new_shape_qkvz) + split_indices_qkvz = [self.head_k_dim, 2 * self.head_k_dim, 2 * self.head_k_dim + (self.v_heads_per_k_head * self.head_v_dim)] + query, key, value_raw, z_raw = jnp.split(mixed_qkvz, split_indices_qkvz, axis=3) + value = value_raw.reshape(batch, seq_len, self.num_v_heads, self.head_v_dim) + z = z_raw.reshape(batch, seq_len, self.num_v_heads, self.head_v_dim) + + new_shape_ba = (batch, seq_len, self.num_k_heads, 2 * self.v_heads_per_k_head) + mixed_ba = ba.reshape(new_shape_ba) + b_raw, a_raw = jnp.split(mixed_ba, [self.v_heads_per_k_head], axis=3) + b = b_raw.reshape(batch, seq_len, self.num_v_heads) + a = a_raw.reshape(batch, seq_len, self.num_v_heads) + + q = query.reshape(batch, seq_len, -1) 
+ k = key.reshape(batch, seq_len, -1) + v = value.reshape(batch, seq_len, -1) + qkv = jnp.concatenate([q, k, v], axis=-1) + + conv_out = self.conv1d(qkv) + qkv_conv = jax.nn.silu(conv_out.astype(jnp.float32)).astype(cfg.dtype) + q_conv, k_conv, v_conv = jnp.split(qkv_conv, [self.key_dim, 2 * self.key_dim], axis=-1) + + query = q_conv.reshape(batch, seq_len, self.num_k_heads, self.head_k_dim) + key = k_conv.reshape(batch, seq_len, self.num_k_heads, self.head_k_dim) + value = v_conv.reshape(batch, seq_len, self.num_v_heads, self.head_v_dim) + + # FIXED .value DEPRECATION + A_log = self.A_log[...] + dt_bias = self.dt_bias[...] + beta = jax.nn.sigmoid(b) + g = -jnp.exp(A_log.astype(jnp.float32)) * jax.nn.softplus(a.astype(jnp.float32) + dt_bias.astype(jnp.float32)) + g = g.astype(cfg.dtype) + + if self.num_v_heads > self.num_k_heads and self.num_v_heads % self.num_k_heads == 0: + repeats = self.num_v_heads // self.num_k_heads + query = jnp.repeat(query, repeats, axis=2) + key = jnp.repeat(key, repeats, axis=2) + + # USING BASELINE KERNEL + core_attn_out, _ = baseline_chunk_gated_delta_rule( + query, key, value, g, beta, chunk_size=cfg.gdn_chunk_size, use_qk_norm_in_gdn=cfg.use_qk_norm_in_gdn + ) + + gated_output_reshaped = self.norm(core_attn_out, z) + gated_output = gated_output_reshaped.reshape(batch, seq_len, -1) + output = self.out_proj(gated_output) + return output + + +# ============================================================================== +# SECTION 2: BENCHMARK HARNESS +# ============================================================================== + +def run_comparison(): + backend = jax.extend.backend.get_backend().platform + print(f"\nDevice: {jax.devices()[0]} ({backend})") + + # --- CONFIGURATION --- + if backend == 'tpu': + # REAL BENCHMARK SETTINGS (Heavy) + DTYPE = jnp.bfloat16 + BATCH = 2 + SEQ_LEN = 4096 + ITERS = 20 + WARMUP = 5 + else: + # CPU DEBUG SETTINGS (Fast) + print("⚠️ Running on CPU: Using reduced dimensions for speed.") + DTYPE = jnp.float32 + BATCH = 1 + SEQ_LEN = 128 + ITERS = 5 + WARMUP = 1 + + HIDDEN_SIZE = 2048 + NUM_KEY_HEADS = 16 + NUM_VALUE_HEADS = 32 + HEAD_DIM = 128 + CONV_KERNEL_DIM = 4 + CHUNK_SIZE = 64 + PROFILE_DIR = "/tmp/maxtext_gdn_profile" + + print(f"Config: Batch={BATCH}, SeqLen={SEQ_LEN}, Dtype={DTYPE}") + print(f"Model: H={HIDDEN_SIZE}, K_Heads={NUM_KEY_HEADS}, V_Heads={NUM_VALUE_HEADS}, HeadDim={HEAD_DIM}") + + dummy_config = types.SimpleNamespace( + emb_dim=HIDDEN_SIZE, + gdn_num_value_heads=NUM_VALUE_HEADS, + gdn_num_key_heads=NUM_KEY_HEADS, + gdn_key_head_dim=HEAD_DIM, + gdn_value_head_dim=HEAD_DIM, + gdn_conv_kernel_dim=CONV_KERNEL_DIM, + dtype=DTYPE, + matmul_precision='default', + normalization_layer_epsilon=1e-6, + weight_dtype=DTYPE, + gdn_chunk_size=CHUNK_SIZE, + use_qk_norm_in_gdn=True, + load_balance_loss_weight=0.0, + scan_layers=False + ) + + # 1. INSTANTIATE MODELS + print("Initializing models...") + rngs_base = nnx.Rngs(0) + baseline_model = BaselineGatedDeltaNet(config=dummy_config, rngs=rngs_base) + + rngs_opt = nnx.Rngs(0) + optimized_model = qwen3.Qwen3NextGatedDeltaNet(config=dummy_config, rngs=rngs_opt) + + # 2. WEIGHT SYNCHRONIZATION + _, params_state = nnx.split(optimized_model) + nnx.update(baseline_model, params_state) + print("✅ Models synchronized with identical weights.") + + # 3. 
INPUTS + key = jax.random.PRNGKey(42) + inputs = jax.random.normal(key, (BATCH, SEQ_LEN, HIDDEN_SIZE), dtype=DTYPE) + + # ------------------------------------------------------------------------- + # Helper: Pure Functional wrappers to avoid Flax/JAX Version mismatch issues + # ------------------------------------------------------------------------- + def create_jitted_train_step(model): + graphdef, params = nnx.split(model) + + @jax.jit + def pure_train_step(params, x): + m = nnx.merge(graphdef, params) + def loss_fn(m_inner): + y = m_inner(x) + return jnp.mean(y) + loss, grads = nnx.value_and_grad(loss_fn)(m) + return loss, grads + + return pure_train_step, params + + def create_jitted_forward(model): + graphdef, params = nnx.split(model) + + @jax.jit + def pure_forward(params, x): + m = nnx.merge(graphdef, params) + return m(x) + + return pure_forward, params + + # ============================================================================== + # PART A: LOGICAL CORRECTNESS + # ============================================================================== + print("\n--- Checking Logical Correctness ---") + + # Create safe functional wrappers + jit_train_base, params_base = create_jitted_train_step(baseline_model) + jit_train_opt, params_opt = create_jitted_train_step(optimized_model) + + loss_base, grads_base = jit_train_base(params_base, inputs) + jax.block_until_ready((loss_base, grads_base)) + + loss_opt, grads_opt = jit_train_opt(params_opt, inputs) + jax.block_until_ready((loss_opt, grads_opt)) + + diff_loss = jnp.abs(loss_base - loss_opt) + print(f"Forward Pass Loss Diff: {float(diff_loss):.2e}") + + flat_grads_base, _ = jax.tree_util.tree_flatten(grads_base) + flat_grads_opt, _ = jax.tree_util.tree_flatten(grads_opt) + + max_grad_diff = 0.0 + for g1, g2 in zip(flat_grads_base, flat_grads_opt): + if hasattr(g1, 'shape'): + d = jnp.max(jnp.abs(g1 - g2)) + max_grad_diff = max(max_grad_diff, float(d)) + + print(f"Backward Pass Grad Diff: {max_grad_diff:.2e}") + + if max_grad_diff > 1e-2: + print("WARNING: Significant divergence in gradients!") + else: + print("✅ Outputs & Gradients match within tolerance.") + + # ============================================================================== + # PART B: SPEED BENCHMARKING + # ============================================================================== + print("\n--- Performance Benchmark ---") + + def benchmark_func(name, func, *args): + print(f"Benchmarking {name}...") + # Warmup + for _ in range(WARMUP): + out = func(*args) + jax.block_until_ready(out) + + # Time it + t0 = time.time() + for _ in range(ITERS): + out = func(*args) + jax.block_until_ready(out) + t_avg = (time.time() - t0) / ITERS * 1000 + print(f" -> {t_avg:.2f} ms") + return t_avg + + # Create forward-only wrappers + jit_fwd_base, _ = create_jitted_forward(baseline_model) + jit_fwd_opt, _ = create_jitted_forward(optimized_model) + + t_fwd_base = benchmark_func("Baseline Forward", jit_fwd_base, params_base, inputs) + t_fwd_opt = benchmark_func("Optimized Forward", jit_fwd_opt, params_opt, inputs) + + t_train_base = benchmark_func("Baseline Train Step", jit_train_base, params_base, inputs) + t_train_opt = benchmark_func("Optimized Train Step", jit_train_opt, params_opt, inputs) + + print(f"\n--- Results ---") + print(f"Forward Speedup: {t_fwd_base/t_fwd_opt:.2f}x ({t_fwd_base:.2f}ms -> {t_fwd_opt:.2f}ms)") + print(f"Training Step Speedup: {t_train_base/t_train_opt:.2f}x ({t_train_base:.2f}ms -> {t_train_opt:.2f}ms)") + + # 
============================================================================== + # PART C: STATIC MEMORY ANALYSIS + # ============================================================================== + print("\n--- Static Memory Analysis (Compiler Estimate) ---") + + def analyze_memory(name, func, *args): + print(f"Analyzing {name}...") + try: + compiled = func.lower(*args).compile() + mem_analysis = compiled.memory_analysis() + + if mem_analysis is None: + print(" Memory analysis not supported on this backend/version.") + return 0 + + # JAX 0.8.0+: The string representation is the most reliable way to read stats + print(f" {mem_analysis}") + + # Try to grab bytes if available, otherwise return 0 (skipping reduction calc) + if hasattr(mem_analysis, 'temp_size_in_bytes'): + return mem_analysis.temp_size_in_bytes + return 0 + + except Exception as e: + print(f" Memory analysis failed: {e}") + return 0 + + mem_base = analyze_memory("Baseline Train Step", jit_train_base, params_base, inputs) + mem_opt = analyze_memory("Optimized Train Step", jit_train_opt, params_opt, inputs) + + if mem_base > 0 and mem_opt > 0: + reduction = (mem_base - mem_opt) / mem_base * 100 + # Positive reduction = Good (Saved memory) + # Negative reduction = Bad (Regression) + print(f"\nMemory Reduction: {reduction:.2f}% (Higher is better)") + + # ============================================================================== + # PART D: PROFILING + # ============================================================================== + print(f"\n--- Profiling Optimized Implementation ---") + print(f"Saving trace to: {PROFILE_DIR}") + + try: + jax.profiler.start_trace(PROFILE_DIR) + for _ in range(WARMUP): # Use warmup count just to get a few samples + out = jit_train_opt(params_opt, inputs) + jax.block_until_ready(out) + jax.profiler.stop_trace() + print("Profiling complete.") + except Exception as e: + print(f"Profiling failed (possibly already active): {e}") + + print(f"Path: {PROFILE_DIR}") + + # ============================================================================== + # PART E: STABILITY STRESS TEST (ADDITION) + # ============================================================================== + print("\n--- Stability Stress Test ---") + + # 1. Define a training step that actually updates parameters (SGD) + def create_jitted_update_step(model, learning_rate=1e-4): + graphdef, params = nnx.split(model) + + @jax.jit + def train_update(params, x): + # Reconstruct model + m = nnx.merge(graphdef, params) + + def loss_fn(m_inner): + # Use a slightly more complex loss to encourage gradient flow + y = m_inner(x) + return jnp.mean(jnp.square(y)) + + loss, grads = nnx.value_and_grad(loss_fn)(m) + + # Manual SGD Update: param = param - lr * grad + new_params = jax.tree_util.tree_map( + lambda p, g: p - learning_rate * g, + params, grads + ) + return loss, grads, new_params + + return train_update, params + + # 2. Initialize the update step + jit_update_opt, current_params = create_jitted_update_step(optimized_model) + + # 3. 
Run simulation loop + TEST_STEPS = 15 + print(f"Running {TEST_STEPS} simulated training steps with FRESH inputs...") + + for step in range(1, TEST_STEPS + 1): + # Generate fresh random input to mimic varying data distribution + step_key = jax.random.fold_in(key, step * 100) + step_input = jax.random.normal(step_key, (BATCH, SEQ_LEN, HIDDEN_SIZE), dtype=DTYPE) + + # Perform update + loss_val, grads_val, current_params = jit_update_opt(current_params, step_input) + + # Block to ensure we catch the error at the specific step + jax.block_until_ready(loss_val) + + # --- CHECKS --- + # 1. Check Loss + if jnp.isnan(loss_val) or jnp.isinf(loss_val): + print(f"\n❌ CRITICAL FAIL at Step {step}: Loss is {loss_val}!") + return + + # 2. Check Gradients (Aggregate check) + grad_any_nan = jax.tree_util.tree_reduce( + lambda acc, x: acc or jnp.any(jnp.isnan(x)) or jnp.any(jnp.isinf(x)), + grads_val, False + ) + if grad_any_nan: + print(f"\n❌ CRITICAL FAIL at Step {step}: Gradients contain NaN or Inf!") + # Optional: Print which specific parameter exploded + flat_grads, struct = jax.tree_util.tree_flatten(grads_val) + for i, g in enumerate(flat_grads): + if jnp.any(jnp.isnan(g)): print(f" -> Gradient index {i} is NaN") + return + + # 3. Check Parameters + param_any_nan = jax.tree_util.tree_reduce( + lambda acc, x: acc or jnp.any(jnp.isnan(x)) or jnp.any(jnp.isinf(x)), + current_params, False + ) + if param_any_nan: + print(f"\n❌ CRITICAL FAIL at Step {step}: Parameters contain NaN or Inf after update!") + return + + print(f" Step {step}: Loss = {float(loss_val):.6f} | Stability: OK") + + print(f"✅ Stability Stress Test Passed: No NaNs encountered in {TEST_STEPS} steps.") + +if __name__ == "__main__": + run_comparison() \ No newline at end of file diff --git a/src/maxtext/scratch_code/gdn_pallas.py b/src/maxtext/scratch_code/gdn_pallas.py new file mode 100644 index 0000000000..e54ae7578e --- /dev/null +++ b/src/maxtext/scratch_code/gdn_pallas.py @@ -0,0 +1,181 @@ +# src/MaxText/kernels/gdn_pallas.py +import functools +import jax +import jax.numpy as jnp +from jax import lax +from jax.experimental import pallas as pl +from jax.experimental.pallas import tpu as pltpu + +# ============================================================================== +# 1. Pallas Kernel Implementation (Forward Pass Logic) +# ============================================================================== +def gdn_scan_kernel_tpu( + w_ref, u_ref, q_ref, k_ref, v_ref, g_ref, beta_ref, h_init_ref, + o_ref, h_final_ref, + # Hyperparameters + num_chunks: int, chunk_size: int, key_dim: int, val_dim: int, + dtype: jnp.dtype = jnp.bfloat16 +): + # 1. Load Initial State + h = h_init_ref[0, 0].astype(jnp.float32) + + for i in range(num_chunks): + # 2. Load Inputs + w = w_ref[0, 0, i] + u = u_ref[0, 0, i] + q = q_ref[0, 0, i] + k = k_ref[0, 0, i] + v = v_ref[0, 0, i] + g = g_ref[0, 0, i] + beta = beta_ref[0, 0, i] + + # 3. 
Output Computation + # Inter-chunk: q * exp(g) @ h + g_exp = jnp.exp(g.astype(jnp.float32)) + q_g = q.astype(jnp.float32) * g_exp[:, None] + term1 = jnp.dot(q_g, h) + + # Intra-chunk: (q @ k.T * decay) @ v + attn = jnp.dot(q.astype(jnp.float32), k.astype(jnp.float32).T) + + # Decay: exp(g[i] - g[j]) + g_diff = g.astype(jnp.float32)[:, None] - g.astype(jnp.float32)[None, :] + + # FIX: Apply mask BEFORE exp to prevent Inf * 0 = NaN + # Use float32 arithmetic for masking to avoid boolean type issues in Mosaic + mask_val = jnp.tri(chunk_size, dtype=jnp.float32) + + # For upper triangle (mask=0), set g_diff to -1e30 so exp() becomes 0 + # g_diff_masked = g_diff * 1.0 + (1.0 - 0.0) * -1e30 = -1e30 + large_neg = -1e30 + g_diff = g_diff * mask_val + (1.0 - mask_val) * large_neg + + attn_decay = jnp.exp(g_diff) + attn = attn * attn_decay + + # Apply Beta gates + attn = attn * beta.astype(jnp.float32)[:, None] + + # V projection + term2 = jnp.dot(attn, v.astype(jnp.float32)) + + o_chunk = term1 + term2 + o_ref[0, 0, i] = o_chunk.astype(dtype) + + # 4. State Update + chunk_decay = jnp.exp(g[chunk_size - 1]) + update = jnp.dot(w.astype(jnp.float32).T, u.astype(jnp.float32)) + h = h * chunk_decay + update + + # 5. Store Final State + h_final_ref[0, 0] = h.astype(dtype) + +# ============================================================================== +# 2. JAX Reference Implementation (For Autodiff) +# ============================================================================== +def _gdn_reference(w, u, q, k, v, g, beta, h_init): + """Pure JAX equivalent for autodiff.""" + perm_vec = (2, 0, 1, 3, 4) + perm_scl = (2, 0, 1, 3) + + w_s = w.transpose(perm_vec) + u_s = u.transpose(perm_vec) + q_s = q.transpose(perm_vec) + k_s = k.transpose(perm_vec) + v_s = v.transpose(perm_vec) + g_s = g.transpose(perm_scl) + beta_s = beta.transpose(perm_scl) + + h_curr = h_init.astype(jnp.float32) + B, H, N, C, Dk = k.shape + + def scan_body(h, args): + wt, ut, qt, kt, vt, gt, betat = args + + # Inter-chunk + gt_exp = jnp.exp(gt.astype(jnp.float32)) + q_g = qt.astype(jnp.float32) * gt_exp[..., None] + term1 = jnp.matmul(q_g, h) + + # Intra-chunk + attn = jnp.matmul(qt.astype(jnp.float32), kt.astype(jnp.float32).swapaxes(-1, -2)) + + # Decay (g[i] - g[j]) + g_diff = gt[..., :, None] - gt[..., None, :] + + # Mask before exp (match Pallas logic) + mask = jnp.tril(jnp.ones((C, C), dtype=jnp.float32)) + g_diff = g_diff * mask + (1.0 - mask) * -1e30 + + attn = attn * jnp.exp(g_diff) + attn = attn * betat.astype(jnp.float32)[..., None] + term2 = jnp.matmul(attn, vt.astype(jnp.float32)) + + out = (term1 + term2).astype(v.dtype) + + # Update + chunk_decay = jnp.exp(gt[..., -1])[..., None, None] + update = jnp.matmul(wt.astype(jnp.float32).swapaxes(-1, -2), ut.astype(jnp.float32)) + h_new = h * chunk_decay + update + + return h_new, out + + h_final, o_scan = lax.scan( + scan_body, + h_curr, + (w_s, u_s, q_s, k_s, v_s, g_s, beta_s) + ) + + return o_scan.transpose(1, 2, 0, 3, 4), h_final.astype(v.dtype) + +# ============================================================================== +# 3. 
Custom VJP Registration +# ============================================================================== + +def _gdn_pallas_forward(w, u, q, k, v, g, beta, h_init): + B, H, N_chunks, C, Dk = k.shape + _, _, _, _, Dv = v.shape + + # Specs + in_specs = pl.BlockSpec(index_map=lambda i, j: (i, j, 0, 0, 0), block_shape=(1, 1, N_chunks, C, Dk)) + val_specs = pl.BlockSpec(index_map=lambda i, j: (i, j, 0, 0, 0), block_shape=(1, 1, N_chunks, C, Dv)) + scalar_specs = pl.BlockSpec(index_map=lambda i, j: (i, j, 0, 0), block_shape=(1, 1, N_chunks, C)) + out_spec = pl.BlockSpec(index_map=lambda i, j: (i, j, 0, 0, 0), block_shape=(1, 1, N_chunks, C, Dv)) + state_spec = pl.BlockSpec(index_map=lambda i, j: (i, j, 0, 0), block_shape=(1, 1, Dk, Dv)) + + kernel_fn = functools.partial( + gdn_scan_kernel_tpu, + num_chunks=N_chunks, chunk_size=C, key_dim=Dk, val_dim=Dv, dtype=v.dtype + ) + + out, h_final = pl.pallas_call( + kernel_fn, + out_shape=[ + jax.ShapeDtypeStruct((B, H, N_chunks, C, Dv), v.dtype), + jax.ShapeDtypeStruct((B, H, Dk, Dv), v.dtype) + ], + grid=(B, H), + in_specs=[in_specs, val_specs, in_specs, in_specs, val_specs, scalar_specs, scalar_specs, state_spec], + out_specs=[out_spec, state_spec], + compiler_params=pltpu.CompilerParams(dimension_semantics=("parallel", "parallel")) + )(w, u, q, k, v, g, beta, h_init) + + return (out, h_final), (w, u, q, k, v, g, beta, h_init) + +def _gdn_pallas_backward(residuals, grad_out_tuple): + grad_out, _ = grad_out_tuple + w, u, q, k, v, g, beta, h_init = residuals + + _, vjp_fn = jax.vjp(_gdn_reference, w, u, q, k, v, g, beta, h_init) + + grad_h_final = jnp.zeros_like(h_init) + grads = vjp_fn((grad_out, grad_h_final)) + return grads + +@functools.partial(jax.custom_vjp, nondiff_argnums=()) +def gdn_pallas_layer(w, u, q, k, v, g, beta, h_init): + # Returns (output, final_state) + res, _ = _gdn_pallas_forward(w, u, q, k, v, g, beta, h_init) + return res + +gdn_pallas_layer.defvjp(_gdn_pallas_forward, _gdn_pallas_backward) \ No newline at end of file diff --git a/src/maxtext/utils/maxtext_utils.py b/src/maxtext/utils/maxtext_utils.py index b9df72c1d9..49b9f84f21 100644 --- a/src/maxtext/utils/maxtext_utils.py +++ b/src/maxtext/utils/maxtext_utils.py @@ -528,33 +528,28 @@ def calculate_gated_delta_net_flops_per_device(config): # 2 * B * S * Channels * Kernel flops_conv = 2 * B * S * (2 * K_dim + V_dim) * K_conv - # 3. Core Gated Delta Net (Attention-like operations) - # Assumptions: - # H = H_v (broadcasting K to V heads if H_v > H_k) - # N = num_chunks & N * C ~ S - # - # Query (Q): [B, S, H_v, D_k] - # Keys (K): [B, S, H_v, D_k] - # Values (V): [B, S, H_v, D_v] - # Intra-Chunk Attention (A): [B, N, H_v, C, C] - # Recurrent State (S): [B, N, H_v, D_k, D_v] - - # - Intra-chunk terms (per chunk C): - # - attn (K*K): 2 * B * S * H_v * C * D_k - # - val_intra (A*V): 2 * B * S * H_v * C * D_v - # - k_cum (A*K): 2 * B * S * H_v * C * D_k - # - inner_attn_body loop (iterative refinement): ≈ (C - 1) * B * H * N * C^2 ≈ B * H * S * C^2 - flops_intra = 2 * B * S * H_v * C * (2 * D_k + D_v) + (B * H_v * S * C**2) - - # - Inter-chunk terms (Recurrent State D_k * D_v): - # - attn_i (Q*K): 2 * B * S * H_v * C * D_k - # - v_prime (K*S): 2 * B * S * H_v * D_k * D_v - # - attn_inter (Q*S): 2 * B * S * H_v * D_k * D_v - # - core_out (A*V): 2 * B * S * H_v * C * D_v - # - update (K*V): 2 * B * S * H_v * D_k * D_v - flops_inter = (2 * B * S * H_v * C * (D_k + D_v)) + (6 * B * S * H_v * D_k * D_v) - - flops_core = flops_intra + flops_inter + # 3. 
Core Gated Delta Net (Optimized WY Representation) + # The implementation broadcasts K heads to V heads if H_v > H_k + H_eff = max(H_k, H_v) + + # Per-token costs derived from jax_chunk_gated_delta_rule: + # Intra-chunk Pre-computation: + # S = K @ K.T: 2 * C * D_k + # A = (I+S)^-1: ~ C^2 (Triangular solve approximation) + # U = A @ V: 2 * C * D_v + # W = A @ K: 2 * C * D_k + # Scan / Output: + # Out_Inter (Q @ h): 2 * D_k * D_v + # Out_Intra_QK (Q @ K.T): 2 * C * D_k + # Out_Intra_AV (Attn @ V): 2 * C * D_v + # State_Update (W.T @ U): 2 * D_k * D_v + + # Summing per-token factors: + # (2*C*D_k) + C^2 + (2*C*D_v) + (2*C*D_k) + (2*D_k*D_v) + (2*C*D_k) + (2*C*D_v) + (2*D_k*D_v) + # = 6*C*D_k + 4*C*D_v + 4*D_k*D_v + C^2 + + flops_core_per_token = H_eff * (6 * C * D_k + 4 * C * D_v + 4 * D_k * D_v + C**2) + flops_core = B * S * flops_core_per_token # Weights part: Projections + Conv gdn_weight_flops = flops_projections + flops_conv diff --git a/tests/unit/train_compile_test.py b/tests/unit/train_compile_test.py index ba87eab068..7a0a54170a 100644 --- a/tests/unit/train_compile_test.py +++ b/tests/unit/train_compile_test.py @@ -22,6 +22,7 @@ import unittest import os.path from tempfile import gettempdir +import absl.flags import pytest @@ -794,3 +795,33 @@ def test_olmo3_7b(self): "max_target_length=1024", ) ) + + @pytest.mark.cpu_only + def test_qwen3_next_tokamax(self): + """AOT test for Qwen3-Next with Tokamax GMM on v5p-128""" + + absl.flags.FLAGS.mark_as_parsed() + + compiled_trainstep_file = "/tmp/test_qwen3_next_tokamax.pickle" + train_compile_main( + ( + "", + os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"), + f"compiled_trainstep_file={compiled_trainstep_file}", + "compile_topology=v5p-64", + "compile_topology_num_slices=1", + "model_name=qwen3-next-80b-a3b", + "max_target_length=4096", + "per_device_batch_size=8", + "dtype=bfloat16", + "weight_dtype=bfloat16", + "sparse_matmul=True", + "ici_fsdp_parallelism=-1", + "ici_expert_parallelism=1", + "ici_tensor_parallelism=2", + "scan_layers=True", + "dataset_type=synthetic", + "tokenizer_type=huggingface", + "tokenizer_path=Qwen/Qwen3-Next-80B-A3B-Instruct", + ) + )