diff --git a/fast_llm_external_models/apriel2/examples/pure_gdn_step1.yaml b/fast_llm_external_models/apriel2/examples/pure_gdn_step1.yaml new file mode 100644 index 000000000..4719062f2 --- /dev/null +++ b/fast_llm_external_models/apriel2/examples/pure_gdn_step1.yaml @@ -0,0 +1,19 @@ +# Step 1: Convert fixed -> pattern with all GDN blocks +# +# Sets main_mixer_name to gdn for all layers +# Run before pure_gdn_step2.yaml +# +# Usage: +# python convert.py /tmp/apriel2-0.5b-dev /tmp/apriel2-0.5b-pure-gdn \ +# -s examples/pure_gdn_step1.yaml \ +# -s examples/pure_gdn_step2.yaml + +decoder: + type: pattern + # Single block type - all layers use GDN + pattern: [gdn_block] + + blocks: + gdn_block: + mixer: + main_mixer_name: gdn diff --git a/fast_llm_external_models/apriel2/examples/pure_gdn_step2.yaml b/fast_llm_external_models/apriel2/examples/pure_gdn_step2.yaml new file mode 100644 index 000000000..fd5994f77 --- /dev/null +++ b/fast_llm_external_models/apriel2/examples/pure_gdn_step2.yaml @@ -0,0 +1,18 @@ +# Step 2: Unwrap stochastic -> pure GDN +# +# Converts stochastic mixer to non-stochastic GDN for all layers +# Run after pure_gdn_step1.yaml +# +# Usage: +# python convert.py /tmp/apriel2-0.5b-dev /tmp/apriel2-0.5b-pure-gdn \ +# -s examples/pure_gdn_step1.yaml \ +# -s examples/pure_gdn_step2.yaml + +decoder: + blocks: + gdn_block: + mixer: + type: gdn + init: transfer + convolution_layer: + kernel_size: 4 diff --git a/fast_llm_external_models/apriel2/modeling_apriel2.py b/fast_llm_external_models/apriel2/modeling_apriel2.py index 076e7f4b8..10f80338a 100644 --- a/fast_llm_external_models/apriel2/modeling_apriel2.py +++ b/fast_llm_external_models/apriel2/modeling_apriel2.py @@ -19,27 +19,49 @@ from transformers.models.mistral.modeling_mistral import MistralMLP, MistralRMSNorm, apply_rotary_pos_emb from transformers.processing_utils import Unpack from transformers.utils import logging -from transformers.utils.import_utils import ( - is_causal_conv1d_available, - is_mamba_ssm_available, - is_torch_flex_attn_available, -) from .configuration_apriel2 import Apriel2Config, Apriel2TextConfig -# GDN implementation - matches Fast-LLM's gdn.py exactly +# ============================================================================= +# Kernel implementation flags (for debugging vLLM vs FLA/mamba_ssm differences) +# ============================================================================= +USE_VLLM_CONV = False +USE_VLLM_GDN_OPS = False +USE_VLLM_GATED_NORM = False +USE_VLLM_MAMBA_OPS = False # Not yet implemented in vLLM wrapper + +# Causal conv1d try: - from fla.ops.gated_delta_rule import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule + if USE_VLLM_CONV: + from vllm.model_executor.layers.mamba.ops.causal_conv1d import causal_conv1d_fn + else: + from causal_conv1d import causal_conv1d_fn + # causal_conv1d_update always from causal_conv1d (vLLM's has different signature) + from causal_conv1d import causal_conv1d_update +except ImportError: + causal_conv1d_fn = None + causal_conv1d_update = None + +# GDN ops (chunk_gated_delta_rule, fused_recurrent_gated_delta_rule) +try: + if USE_VLLM_GDN_OPS: + from vllm.model_executor.layers.fla.ops import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule + else: + from fla.ops.gated_delta_rule import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule except ImportError: chunk_gated_delta_rule = None fused_recurrent_gated_delta_rule = None +# Gated RMSNorm try: - from fla.modules.fused_norm_gate import rms_norm_gated + if 
USE_VLLM_GATED_NORM: + from vllm.model_executor.layers.fla.ops.layernorm_guard import rmsnorm_fn as rms_norm_gated + else: + from fla.modules.fused_norm_gate import rms_norm_gated except ImportError: rms_norm_gated = None -# KDA implementation - matches Fast-LLM's kda.py +# KDA ops try: from fla.ops.kda import chunk_kda, fused_recurrent_kda from fla.ops.kda.gate import fused_kda_gate @@ -48,26 +70,17 @@ fused_recurrent_kda = None fused_kda_gate = None - +# Mamba/SSM ops try: - from causal_conv1d import causal_conv1d_fn as _causal_conv1d_fn - from causal_conv1d import causal_conv1d_update as _causal_conv1d_update - from mamba_ssm.ops.selective_scan_interface import selective_scan_fn - from mamba_ssm.ops.triton.selective_state_update import selective_state_update + if USE_VLLM_MAMBA_OPS: + raise ImportError("vLLM mamba ops not yet wrapped") + else: + from mamba_ssm.ops.selective_scan_interface import selective_scan_fn + from mamba_ssm.ops.triton.selective_state_update import selective_state_update except ImportError: - _causal_conv1d_fn = None - _causal_conv1d_update = None selective_scan_fn = None selective_state_update = None - -is_fast_path_available = is_mamba_ssm_available() and is_causal_conv1d_available() - -if is_torch_flex_attn_available(): - from torch.nn.attention.flex_attention import BlockMask -else: - BlockMask = torch.Tensor - logger = logging.get_logger(__name__) @@ -520,10 +533,10 @@ def __init__( activation: str = "silu", **kwargs, ): - if not is_fast_path_available: + if causal_conv1d_fn is None: raise ImportError( - "CausalConv1d requires CUDA kernels from causal_conv1d and mamba_ssm. " - "Install with: pip install causal-conv1d mamba-ssm" + "CausalConv1d requires CUDA kernels from causal_conv1d. " + "Install with: pip install causal-conv1d" ) # Remove padding from kwargs since we handle it ourselves kwargs.pop("padding", None) @@ -564,6 +577,55 @@ def forward( batch_size, dim, seq_len = x.shape state_len = self.kernel_size[0] - 1 + if USE_VLLM_CONV: + # vLLM expects x as [dim, total_tokens] + # x shape: [batch, dim, seq] + # x_flat[:, t] should equal x[batch_for_t, :, seq_for_t] + # permute to [dim, batch, seq], then reshape to [dim, batch*seq] + x_flat = x.permute(1, 0, 2).reshape(dim, batch_size * seq_len).contiguous() + + # Create conv_states buffer: [batch, dim, state_len] + # vLLM requires stride(1) == 1 (dim dimension contiguous) + # Create as [batch, state_len, dim] contiguous, then transpose to get right strides + conv_states = x.new_zeros(batch_size, state_len, dim).transpose(1, 2) + + # Create query_start_loc: cumulative sequence lengths + # For batch_size sequences each of length seq_len + query_start_loc = torch.arange( + 0, batch_size * seq_len + 1, seq_len, + dtype=torch.int32, device=x.device + ) + + # has_initial_state: all False (no prior state) + has_initial_state = torch.zeros(batch_size, dtype=torch.bool, device=x.device) + + # cache_indices: identity mapping + cache_indices = torch.arange(batch_size, dtype=torch.int32, device=x.device) + + # Call vLLM's causal_conv1d_fn + out_flat = causal_conv1d_fn( + x_flat, + self._weight, + self.bias, + conv_states, + query_start_loc, + cache_indices=cache_indices, + has_initial_state=has_initial_state, + activation=self._activation, + ) + + # Convert back: [dim, total_tokens] -> [batch, dim, seq] + # out_flat shape: [dim, batch*seq] + # reshape to [dim, batch, seq], then permute to [batch, dim, seq] + out = out_flat.reshape(dim, batch_size, seq_len).permute(1, 0, 2) + + if return_final_state: + # conv_states was 
updated in-place by vLLM's implementation + # Return it in the expected format: [batch, dim, state_len] + return out, conv_states + return out + + # FLA/causal_conv1d path below # Edge case: seq_len==1 with return_final_state # CUDA kernel limitation: return_final_states requires channel-last layout, # which is impossible when seq_len==1. Handle via update() with zero-init state. @@ -573,7 +635,7 @@ def forward( # Create channel-last state: stride(1) == 1 conv_state = x.new_zeros(batch_size, state_len, dim).transpose(1, 2) # Use update() which handles single tokens efficiently - out = _causal_conv1d_update( + out = causal_conv1d_update( x.squeeze(2), # [batch, dim, 1] -> [batch, dim] conv_state, self._weight, @@ -596,7 +658,7 @@ def forward( else: final_state = None - out = _causal_conv1d_fn( + out = causal_conv1d_fn( x, self._weight, bias=self.bias, @@ -633,7 +695,7 @@ def update( Returns: Output tensor [batch, dim] """ - return _causal_conv1d_update( + return causal_conv1d_update( x, conv_state, self._weight, @@ -1089,12 +1151,6 @@ def forward( **kwargs, ): """Forward pass for Mamba.""" - # Check for CUDA when using fast path - if is_fast_path_available and "cuda" not in self.in_proj.weight.device.type: - raise RuntimeError( - "Mamba with CUDA kernels requires CUDA device. Current device: " + str(self.in_proj.weight.device) - ) - cache_position = kwargs.get("cache_position", None) batch, seqlen, dim = hidden_states.shape @@ -1281,15 +1337,10 @@ def _get_states_from_cache(self, inference_params, batch_size, initialize_states return ssm_state, conv_state -def _l2norm(x: torch.Tensor, dim: int = -1, eps: float = 1e-6) -> torch.Tensor: - """L2 normalization matching Fast-LLM's implementation.""" - return x * torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps) - - class GatedRMSNormalization(nn.Module): """ Gated RMS normalization layer matching Fast-LLM's implementation. - Uses fla.modules.fused_norm_gate.rms_norm_gated (required). + Uses fla.modules.fused_norm_gate.rms_norm_gated or vLLM's rmsnorm_fn. Args: hidden_size: Size of the hidden dimension @@ -1301,24 +1352,38 @@ def __init__(self, hidden_size: int, eps: float = 1e-5, activation: str = "silu" super().__init__() if rms_norm_gated is None: raise ImportError( - "GatedRMSNormalization requires rms_norm_gated from fla library. " "Install with: pip install fla-core" + "GatedRMSNormalization requires rms_norm_gated. " + "Install fla-core or ensure vLLM is available." ) self.weight = nn.Parameter(torch.ones(hidden_size)) self.eps = eps self.activation = activation def forward(self, input_: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: - return rms_norm_gated( - input_, - gate, - self.weight, - None, - activation=self.activation, - eps=self.eps, - residual=None, - prenorm=False, - residual_in_fp32=False, - ) + if USE_VLLM_GATED_NORM: + # vLLM's rmsnorm_fn signature: (x, weight, bias, z, eps, group_size, norm_before_gate) + return rms_norm_gated( + input_, + self.weight, + None, # bias + z=gate, + eps=self.eps, + group_size=None, + norm_before_gate=True, + ) + else: + # FLA's rms_norm_gated signature + return rms_norm_gated( + input_, + gate, + self.weight, + None, + activation=self.activation, + eps=self.eps, + residual=None, + prenorm=False, + residual_in_fp32=False, + ) class Apriel2GatedDeltaNet(nn.Module): @@ -1391,6 +1456,45 @@ def __init__( "GatedDeltaNet requires the fla library for optimized kernels. 
" "Install with: pip install fla-core" ) + _debug_enabled = False # Set to True for debugging + _debug_layer = False # num_tokens <= 10 + _debug_state = False # Debug recurrent state + _debug_output = False # Debug output hidden states during decode + + def _debug_tensor(self, name: str, t: torch.Tensor): + if not self._debug_enabled: + return + if t is None: + print(f"[TF-GDN layer={self.layer_idx}] {name}: None") + return + try: + flat = t.flatten()[:8] + vals = ", ".join(f"{v:.6f}" for v in flat.float().tolist()) + print(f"[TF-GDN layer={self.layer_idx}] {name}: shape={t.shape}, dtype={t.dtype}, " + f"mean={t.float().mean().item():.6f}, std={t.float().std().item():.6f}, " + f"first8=[{vals}]") + except Exception as e: + print(f"[TF-GDN layer={self.layer_idx}] {name}: ERROR accessing tensor: {e}") + + def _debug_print(self, msg: str): + if not self._debug_enabled: + return + print(f"[TF-GDN layer={self.layer_idx}] {msg}") + + def _debug_state_stats(self, name: str, state: torch.Tensor, seq_len: int): + """Debug recurrent state with statistics.""" + if not self._debug_state or state is None: + return + try: + flat = state.flatten() + first8 = ", ".join(f"{v:.6f}" for v in flat[:8].float().tolist()) + print(f"[TF-GDN L{self.layer_idx}] {name} (seq_len={seq_len}): shape={state.shape}, " + f"mean={state.float().mean().item():.6f}, std={state.float().std().item():.6f}, " + f"min={state.float().min().item():.6f}, max={state.float().max().item():.6f}, " + f"first8=[{first8}]") + except Exception as e: + print(f"[TF-GDN L{self.layer_idx}] {name}: ERROR accessing state: {e}") + def _fix_query_key_value_ordering(self, mixed_qkvz: torch.Tensor, mixed_ba: torch.Tensor): """ Split QKVZ and BA tensors using Fast-LLM's flat layout. @@ -1436,6 +1540,9 @@ def forward(self, hidden_states: torch.Tensor, past_key_values=None, attention_m cache_position = kwargs.get("cache_position", None) batch_size, seq_len, _ = hidden_states.shape + self._debug_print(f"===== FORWARD START (batch={batch_size}, seq={seq_len}) =====") + self._debug_tensor("hidden_states", hidden_states) + # Get conv and recurrent state from cache if available conv_state = None recurrent_state = None @@ -1448,13 +1555,22 @@ def forward(self, hidden_states: torch.Tensor, past_key_values=None, attention_m use_precomputed_states = ( past_key_values is not None and conv_state is not None and seq_len == 1 and cache_position is not None ) + self._debug_print(f"use_precomputed_states={use_precomputed_states}") # Project to QKVZ and BA mixed_qkvz = self.in_proj_qkvz(hidden_states) mixed_ba = self.in_proj_ba(hidden_states) + self._debug_tensor("mixed_qkvz", mixed_qkvz) + self._debug_tensor("mixed_ba", mixed_ba) # Split into components using Fast-LLM's flat layout query, key, value, z, beta, alpha = self._fix_query_key_value_ordering(mixed_qkvz, mixed_ba) + self._debug_tensor("query (after split)", query) + self._debug_tensor("key (after split)", key) + self._debug_tensor("value (after split)", value) + self._debug_tensor("z (after split)", z) + self._debug_tensor("beta (after split)", beta) + self._debug_tensor("alpha (after split)", alpha) # Flatten QKV for convolution (no Z in conv) query_flat = query.reshape(batch_size, seq_len, -1) @@ -1462,10 +1578,15 @@ def forward(self, hidden_states: torch.Tensor, past_key_values=None, attention_m value_flat = value.reshape(batch_size, seq_len, -1) mixed_qkv = torch.cat([query_flat, key_flat, value_flat], dim=-1) mixed_qkv = mixed_qkv.transpose(1, 2) # [batch, conv_dim, seq] + mixed_qkv_before_conv = mixed_qkv 
# Save for debug + self._debug_tensor("mixed_qkv (before conv)", mixed_qkv) + self._debug_tensor("conv_weight", self.convolution.weight) + self._debug_tensor("conv_bias", self.convolution.bias) # Apply causal convolution if use_precomputed_states: # Single token decode - use cached conv state + self._debug_print("Using conv.update (decode path)") mixed_qkv = self.convolution.update( mixed_qkv.squeeze(2), # [batch, conv_dim, 1] -> [batch, conv_dim] conv_state, @@ -1474,6 +1595,7 @@ def forward(self, hidden_states: torch.Tensor, past_key_values=None, attention_m ) # [batch, conv_dim] -> [batch, conv_dim, 1] else: # Prefill mode + self._debug_print("Using conv.forward (prefill path)") use_cache = past_key_values is not None if use_cache: mixed_qkv, final_state = self.convolution(mixed_qkv, return_final_state=True) @@ -1482,25 +1604,49 @@ def forward(self, hidden_states: torch.Tensor, past_key_values=None, attention_m mixed_qkv = self.convolution(mixed_qkv) mixed_qkv = mixed_qkv.transpose(1, 2) # [batch, seq, conv_dim] + self._debug_tensor("mixed_qkv (after conv)", mixed_qkv) # Split back after convolution query_flat, key_flat, value_flat = torch.split(mixed_qkv, (self.key_dim, self.key_dim, self.value_dim), dim=-1) query = query_flat.reshape(batch_size, seq_len, self.key_heads, self.key_head_dim) key = key_flat.reshape(batch_size, seq_len, self.key_heads, self.key_head_dim) value = value_flat.reshape(batch_size, seq_len, self.value_heads, self.value_head_dim) + self._debug_tensor("query (after conv)", query) + self._debug_tensor("key (after conv)", key) + self._debug_tensor("value (after conv)", value) # Compute gating - match Fast-LLM exactly beta_gate = beta.sigmoid() g = -self.A_log.float().exp() * F.softplus(alpha.float() + self.dt_bias) + self._debug_tensor("beta_gate", beta_gate) + self._debug_tensor("g", g) + self._debug_tensor("A_log", self.A_log) + self._debug_tensor("dt_bias", self.dt_bias) # Expand K heads to V heads if grouped query attention if self.value_heads_per_key > 1: query = query.repeat_interleave(self.value_heads_per_key, dim=2) key = key.repeat_interleave(self.value_heads_per_key, dim=2) + self._debug_print(f"Expanded q/k heads: {self.key_heads} -> {self.value_heads}") + self._debug_tensor("query (after expand)", query) + self._debug_tensor("key (after expand)", key) # Run gated delta rule (FLA kernels required) + self._debug_tensor("recurrent_state (initial)", recurrent_state) if not use_precomputed_states: # Chunked mode for prefill + self._debug_print("Using chunk_gated_delta_rule (prefill)") + # Debug PREFILL INPUTS before kernel call + if self._debug_state: + print(f"[TF-GDN L{self.layer_idx}] PREFILL INPUTS:") + print(f" hidden_states: shape={hidden_states.shape}, first8={hidden_states.flatten()[:8].tolist()}") + print(f" mixed_qkv_before_conv: shape={mixed_qkv_before_conv.shape}, first8={mixed_qkv_before_conv.flatten()[:8].tolist()}") + print(f" q: shape={query.shape}, first8={query.flatten()[:8].tolist()}") + print(f" k: shape={key.shape}, first8={key.flatten()[:8].tolist()}") + print(f" v: shape={value.shape}, first8={value.flatten()[:8].tolist()}") + print(f" g: shape={g.shape}, first8={g.flatten()[:8].tolist()}") + print(f" beta: shape={beta_gate.shape}, first8={beta_gate.flatten()[:8].tolist()}") + print(f" initial_state: {recurrent_state}") output, last_recurrent_state = chunk_gated_delta_rule( query, key, @@ -1514,18 +1660,42 @@ def forward(self, hidden_states: torch.Tensor, past_key_values=None, attention_m # Ensure state is in same dtype as hidden_states 
(fla kernel may return float32) if last_recurrent_state is not None: last_recurrent_state = last_recurrent_state.to(hidden_states.dtype) + self._debug_state_stats("PREFILL out_state", last_recurrent_state, seq_len) else: # Recurrent mode for single token decode - output, last_recurrent_state = fused_recurrent_gated_delta_rule( - query, - key, - value, - g=g, - beta=beta_gate, - initial_state=recurrent_state, - output_final_state=past_key_values is not None, - use_qk_l2norm_in_kernel=True, - ) + self._debug_print("Using fused_recurrent_gated_delta_rule (decode)") + self._debug_state_stats("DECODE in_state", recurrent_state, seq_len) + # Debug decode inputs + if self._debug_state: + print(f"[TF-GDN L{self.layer_idx}] DECODE inputs: q={query.flatten()[:4].tolist()}, k={key.flatten()[:4].tolist()}, v={value.flatten()[:4].tolist()}, g={g.flatten()[:4].tolist()}, beta={beta_gate.flatten()[:4].tolist()}") + # vLLM and FLA have different signatures: + # - vLLM: inplace_final_state (default True, set False to avoid ssm_state_indices requirement) + # - FLA: output_final_state + if USE_VLLM_GDN_OPS: + output, last_recurrent_state = fused_recurrent_gated_delta_rule( + query, + key, + value, + g=g, + beta=beta_gate, + initial_state=recurrent_state, + inplace_final_state=False, + use_qk_l2norm_in_kernel=True, + ) + else: + output, last_recurrent_state = fused_recurrent_gated_delta_rule( + query, + key, + value, + g=g, + beta=beta_gate, + initial_state=recurrent_state, + output_final_state=past_key_values is not None, + use_qk_l2norm_in_kernel=True, + ) + self._debug_state_stats("DECODE out_state", last_recurrent_state, seq_len) + + self._debug_tensor("output (after FLA)", output) # Update recurrent state in cache if past_key_values is not None: @@ -1535,12 +1705,44 @@ def forward(self, hidden_states: torch.Tensor, past_key_values=None, attention_m z_shape_og = z.shape output = output.reshape(-1, output.shape[-1]) z_flat = z.reshape(-1, z.shape[-1]) + self._debug_tensor("output (before norm)", output) + self._debug_tensor("z_flat (for norm)", z_flat) + # Debug last token before norm (reshaped has tokens * heads rows) + batch_size, num_tokens = hidden_states.shape[:2] + if self._debug_layer and num_tokens > 0: + num_heads = self.value_heads + last_token_start = (num_tokens - 1) * num_heads + last_out = output[last_token_start:last_token_start+1, :8] + last_z = z_flat[last_token_start:last_token_start+1, :8] + print(f"[TF-GDN layer={self.layer_idx}] output before norm (last token, head 0): [{', '.join(f'{v:.6f}' for v in last_out.flatten().float().tolist())}]") + print(f"[TF-GDN layer={self.layer_idx}] z before norm (last token, head 0): [{', '.join(f'{v:.6f}' for v in last_z.flatten().float().tolist())}]") + self._debug_tensor("norm.weight", self.norm.weight) + self._debug_print(f"norm.eps={self.norm.eps}, norm.activation={self.norm.activation}") output = self.norm(output, z_flat) + self._debug_tensor("output (after norm)", output) + # Debug last token after norm + if self._debug_layer and num_tokens > 0: + last_out_after = output[last_token_start:last_token_start+1, :8] + print(f"[TF-GDN layer={self.layer_idx}] output after norm (last token, head 0): [{', '.join(f'{v:.6f}' for v in last_out_after.flatten().float().tolist())}]") output = output.reshape(z_shape_og) output = output.reshape(output.shape[0], output.shape[1], -1) # Output projection output = self.out_proj(output) + self._debug_tensor("output (final)", output) + # Show last token specifically + if self._debug_layer and output.dim() == 3: + 
last_token = output[0, -1, :8] + vals = ", ".join(f"{v:.6f}" for v in last_token.float().tolist()) + print(f"[TF-GDN layer={self.layer_idx}] output (last token): last_token_first8=[{vals}]") + # Debug output hidden states during decode + # Get decode step from cache + decode_step = past_key_values.get_seq_length() if past_key_values is not None else 0 + if self._debug_output and use_precomputed_states and output.dim() == 3: + flat = output.flatten() + first8 = ", ".join(f"{v:.6f}" for v in flat[:8].float().tolist()) + print(f"[TF-GDN L{self.layer_idx}] STEP={decode_step} OUTPUT hs: mean={output.float().mean().item():.6f}, std={output.float().std().item():.6f}, first8=[{first8}]") + self._debug_print("===== FORWARD END =====") return (output,) @@ -2115,6 +2317,21 @@ def _create_norm(self, norm_config: dict, hidden_size: int, rms_norm_eps: float) else: raise ValueError(f"Unknown normalization type: {norm_type}") + _debug_layer = False # Set to True to debug layer outputs + + def _debug_tensor(self, name: str, t: torch.Tensor, show_last=False): + if not self._debug_layer or t is None: + return + if show_last: + # Show last token + last = t[0, -1, :8] + vals = ", ".join(f"{v:.6f}" for v in last.float().tolist()) + print(f"[TF Layer {self.layer_idx}] {name}: shape={t.shape}, last_token_first8=[{vals}]") + else: + flat = t.flatten()[:8] + vals = ", ".join(f"{v:.6f}" for v in flat.float().tolist()) + print(f"[TF Layer {self.layer_idx}] {name}: shape={t.shape}, first8=[{vals}]") + def forward( self, hidden_states: torch.Tensor, @@ -2126,8 +2343,14 @@ def forward( position_embeddings=None, **kwargs, ) -> tuple: + num_tokens = hidden_states.size(1) + self._debug_layer = False # Disabled for testing + + self._debug_tensor("input hidden_states", hidden_states) + residual = hidden_states hidden_states = self.input_layernorm(hidden_states) + self._debug_tensor("after input_layernorm", hidden_states) mixer_outputs = self.mixer( hidden_states, @@ -2140,13 +2363,23 @@ def forward( **kwargs, ) hidden_states = mixer_outputs[0] + self._debug_tensor("mixer output", hidden_states) + hidden_states = residual + hidden_states + self._debug_tensor("after residual add 1", hidden_states) # MLP residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) + self._debug_tensor("after post_attention_layernorm", hidden_states) + hidden_states = self.mlp(hidden_states) + self._debug_tensor("after mlp", hidden_states) + hidden_states = residual + hidden_states + self._debug_tensor("after residual add 2 (final)", hidden_states) + # Also show last token for final layer comparison + self._debug_tensor("after residual add 2 (last token)", hidden_states, show_last=True) outputs = (hidden_states,) if output_attentions: @@ -2411,8 +2644,24 @@ def forward( ) # Apply final normalization + # Debug final norm + batch_size, seq_len = hidden_states.shape[:2] + _debug_final = False # seq_len <= 10 + if _debug_final: + # Show LAST token (to match vLLM) + last_token = hidden_states[0, -1, :8] + vals = ", ".join(f"{v:.6f}" for v in last_token.float().tolist()) + print(f"[TF Final] hidden_states (before norm): shape={hidden_states.shape}, last_token_first8=[{vals}]") + print(f"[TF Final] norm.weight: first8=[{', '.join(f'{v:.6f}' for v in self.norm.weight.flatten()[:8].float().tolist())}]") + print(f"[TF Final] norm.variance_epsilon={self.norm.variance_epsilon}") + hidden_states = self.norm(hidden_states) + if _debug_final: + last_token = hidden_states[0, -1, :8] + vals = ", ".join(f"{v:.6f}" for v in 
last_token.float().tolist()) + print(f"[TF Final] hidden_states (after norm): shape={hidden_states.shape}, last_token_first8=[{vals}]") + # Add final hidden state if requested if output_hidden_states: all_hidden_states += (hidden_states,) @@ -2494,10 +2743,27 @@ def forward( hidden_states = outputs.last_hidden_state + # Debug LM head input + batch_size, seq_len = hidden_states.shape[:2] + _debug_lm_head = False # seq_len <= 10 + if _debug_lm_head: + # Show LAST token's first 8 features (to match vLLM which only passes last token) + last_token_hs = hidden_states[0, -1, :8] + vals = ", ".join(f"{v:.6f}" for v in last_token_hs.float().tolist()) + print(f"[TF LM Head] input hidden_states: shape={hidden_states.shape}, last_token_first8=[{vals}]") + print(f"[TF LM Head] lm_head.weight: shape={self.lm_head.weight.shape}, first8=[{', '.join(f'{v:.6f}' for v in self.lm_head.weight.flatten()[:8].float().tolist())}]") + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) + if _debug_lm_head: + # Get last token logits + last_logits = logits[0, -1] + top_vals, top_idx = last_logits.topk(5) + print(f"[TF LM Head] logits shape={logits.shape}") + print(f"[TF LM Head] last token top-5 logits: {[(idx.item(), val.item()) for idx, val in zip(top_idx, top_vals)]}") + loss = None if labels is not None: # Upcast to float if we need to compute the loss to avoid potential precision issues diff --git a/fast_llm_external_models/apriel2/vllm/README.md b/fast_llm_external_models/apriel2/vllm/README.md new file mode 100644 index 000000000..402dff642 --- /dev/null +++ b/fast_llm_external_models/apriel2/vllm/README.md @@ -0,0 +1,23 @@ +# vLLM Support for Apriel2 + +## Usage + +Register Apriel2 before creating the LLM: + +```python +from fast_llm_external_models.apriel2.vllm import register +from vllm import LLM + +register() + +llm = LLM(model="path/to/apriel2/checkpoint") +``` + +## Entry Point (Alternative) + +Add to your `pyproject.toml` to auto-register on vLLM import: + +```toml +[project.entry-points."vllm.plugins"] +apriel2 = "fast_llm_external_models.apriel2.vllm:register" +``` diff --git a/fast_llm_external_models/apriel2/vllm/__init__.py b/fast_llm_external_models/apriel2/vllm/__init__.py new file mode 100644 index 000000000..d258911a5 --- /dev/null +++ b/fast_llm_external_models/apriel2/vllm/__init__.py @@ -0,0 +1,17 @@ +"""vLLM model implementation for Apriel2. + +This module provides vLLM-optimized implementations of Apriel2 models. +See README.md for usage instructions. + +Placement switching (for stochastic mixer models): + placements = llm.collective_rpc("get_layer_placements") + llm.collective_rpc("set_layer_placements", args=(new_placement,)) + +Plugin usage (for vLLM subprocess registration): + Set VLLM_PLUGINS=fast_llm_external_models.apriel2.vllm.config_convertor + before starting vLLM to ensure registration in subprocesses. +""" + +from fast_llm_external_models.apriel2.vllm.modeling_apriel2 import Apriel2ForCausalLM + +__all__ = ["Apriel2ForCausalLM"] diff --git a/fast_llm_external_models/apriel2/vllm/config_convertor.py b/fast_llm_external_models/apriel2/vllm/config_convertor.py new file mode 100644 index 000000000..0b15733f5 --- /dev/null +++ b/fast_llm_external_models/apriel2/vllm/config_convertor.py @@ -0,0 +1,101 @@ +"""Config convertor and registration for Apriel2 models. 
+ +This module provides: +1. A custom ModelArchConfigConvertor for Apriel2's nested decoder config format +2. A register() function for vLLM's plugin system (entry_points) + +Registration is automatic when fast-llm is installed. vLLM discovers the +entry point defined in setup.cfg and calls register() in all processes. +""" + +from vllm import ModelRegistry +from vllm.transformers_utils.model_arch_config_convertor import ( + MODEL_ARCH_CONFIG_CONVERTORS, + ModelArchConfigConvertorBase, +) + + +class Apriel2TextModelArchConfigConvertor(ModelArchConfigConvertorBase): + """Config convertor for Apriel2TextConfig with nested decoder structure. + + Apriel2 configs use a nested decoder format instead of standard HuggingFace + attributes like num_hidden_layers. This convertor extracts the required + values from the nested structure. + """ + + def _get_first_attention_block(self): + """Find the first attention block config. + + Handles both regular and stochastic mixer types. For stochastic mixers, + looks up the main_mixer_name to find the attention config. + """ + decoder = getattr(self.hf_text_config, 'decoder', {}) + decoder_type = decoder.get('type', 'fixed') + + if decoder_type == 'fixed': + block = decoder.get('block', {}) + mixer = block.get('mixer', {}) + mixer_type = mixer.get('type', 'attention') + if mixer_type == 'stochastic': + main_mixer_name = mixer.get('main_mixer_name', 'attention') + return mixer.get('mixers', {}).get(main_mixer_name, {}) + elif mixer_type == 'attention': + return mixer + elif decoder_type == 'pattern': + blocks = decoder.get('blocks', {}) + pattern = decoder.get('pattern', []) + for block_name in pattern: + block = blocks.get(block_name, {}) + mixer = block.get('mixer', {}) + if mixer.get('type') == 'attention': + return mixer + return {} + + def get_num_hidden_layers(self) -> int: + decoder = getattr(self.hf_text_config, 'decoder', {}) + return decoder.get('num_blocks', 0) + + def get_total_num_attention_heads(self) -> int: + return self._get_first_attention_block().get('heads', 0) + + def get_total_num_kv_heads(self) -> int: + return self._get_first_attention_block().get( + 'head_groups', self.get_total_num_attention_heads() + ) + + def get_head_size(self) -> int: + return self._get_first_attention_block().get('head_size', 0) + + +def register(): + """Register Apriel2 models and config convertors with vLLM. + + This function is called automatically by vLLM's plugin system via Python's + entry_points mechanism. The entry point is defined in fast-llm's setup.cfg: + + [options.entry_points] + vllm.general_plugins = + apriel2 = fast_llm_external_models.apriel2.vllm.config_convertor:register + + vLLM discovers all entry points in the 'vllm.general_plugins' group using + importlib.metadata and calls each plugin's register function during startup. + This happens in every process (parent and subprocesses spawned by AsyncLLM), + ensuring model registration is available everywhere. + + The VLLM_PLUGINS environment variable can optionally filter which plugins + are loaded (comma-separated list of plugin names to enable). + + Safe to call multiple times - skips registration if already done. 
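+
+    Example (a sketch mirroring vllm/README.md, for explicit registration when
+    the entry point is not installed):
+
+        from fast_llm_external_models.apriel2.vllm.config_convertor import register
+        from vllm import LLM
+
+        register()  # idempotent; safe even if the plugin already ran
+        llm = LLM(model="path/to/apriel2/checkpoint")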
+ """ + # Skip if already registered + if 'apriel2_text' in MODEL_ARCH_CONFIG_CONVERTORS: + return + + # Register config convertor (only apriel2_text, not apriel2 with vision encoder) + MODEL_ARCH_CONFIG_CONVERTORS['apriel2_text'] = Apriel2TextModelArchConfigConvertor + + # Register model class + ModelRegistry.register_model( + "Apriel2ForCausalLM", + "fast_llm_external_models.apriel2.vllm:Apriel2ForCausalLM", + ) diff --git a/fast_llm_external_models/apriel2/vllm/modeling_apriel2.py b/fast_llm_external_models/apriel2/vllm/modeling_apriel2.py new file mode 100644 index 000000000..100014a60 --- /dev/null +++ b/fast_llm_external_models/apriel2/vllm/modeling_apriel2.py @@ -0,0 +1,2863 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Inference-only Apriel2 model compatible with HuggingFace weights. + +Apriel2 is a hybrid model that supports multiple mixer types (attention, mamba, +GatedDeltaNet, KDA) with flexible block patterns. This implementation is +optimized for vLLM inference. +""" + +import logging +import math +from collections.abc import Iterable +from itertools import islice + +import torch + +logger = logging.getLogger(__name__) +import triton +from einops import rearrange +from torch import nn +from transformers import PretrainedConfig +from transformers.activations import ACT2FN + +from vllm.v1.attention.backend import AttentionMetadata +from vllm.attention.layer import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import ( + CacheConfig, + ModelConfig, + SpeculativeConfig, + VllmConfig, + get_current_vllm_config, +) +from vllm.distributed import ( + divide, + get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) +from vllm.forward_context import ForwardContext, get_forward_context +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fla.ops import ( + chunk_gated_delta_rule, + fused_recurrent_gated_delta_rule, +) +from vllm.model_executor.models.qwen3_next import ( + fused_gdn_gating as qwen3_fused_gdn_gating, +) +from vllm.model_executor.layers.fla.ops.kda import ( + FusedRMSNormGated, + chunk_kda, + fused_kda_gate, + fused_recurrent_kda, +) + +# Import to register kda_attention custom op +import vllm.model_executor.layers.kda # noqa: F401 +from vllm.model_executor.layers.layernorm import RMSNorm, RMSNormGated +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase +from vllm.v1.attention.backend import AttentionBackend +from vllm.v1.attention.selector import get_mamba_attn_backend +from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer +from vllm.model_executor.layers.mamba.mamba_mixer2 import mamba_v2_sharded_weight_loader +from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateDtypeCalculator, + MambaStateShapeCalculator, +) +from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( + causal_conv1d_fn, + causal_conv1d_update, +) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from 
vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, + sharded_weight_loader, +) +from vllm.model_executor.models.interfaces import HasInnerState, SupportsPP +from vllm.model_executor.models.utils import ( + AutoWeightsLoader, + WeightsMapper, + extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) +from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform +from vllm.sequence import IntermediateTensors +from vllm.triton_utils import tl, triton +from vllm.utils.torch_utils import direct_register_custom_op +from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata +from vllm.v1.kv_cache_interface import ( + FullAttentionSpec, + KVCacheSpec, + MambaSpec, + SlidingWindowSpec, +) +from vllm.logger import init_logger + +apriel2_logger = init_logger(__name__) + +# ============================================================================= +# Debug Flags +# ============================================================================= +# Top-level debug flags that control all debug output in the module. +# Set these to True to enable debugging for specific components. + +DEBUG_GDN_LAYER = False # Debug GDN layer forward pass (tensors, shapes) +DEBUG_GDN_STATE = False # Debug GDN recurrent state during decode +DEBUG_GDN_OUTPUT = False # Debug GDN output hidden states during decode +DEBUG_KDA_LAYER = False # Debug KDA layer outputs +DEBUG_DECODER_LAYER = False # Debug decoder layer outputs (residual, norm) +DEBUG_FINAL_NORM = False # Debug final norm before LM head +DEBUG_LM_HEAD = False # Debug LM head input/output + + +# ============================================================================= +# KV Cache Spec Computation +# ============================================================================= + +from dataclasses import dataclass +from typing import Literal + + +# Valid mamba_type values for MambaSpec +MambaType = Literal["gdn_attention", "kda_attention", "mamba"] + +# Cache for unified page size computation (keyed by object ids). +# Avoids recomputing the same value for every layer during model init. 
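+# Key: the (id(config), id(vllm_config)) pair built in get_unified_page_size_for_config();
+# value: the resulting (block_size, unified_page_size) tuple.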
+_unified_page_size_cache: dict[tuple[int, int], tuple[int, int]] = {} + + +def _get_dtype_size(dtype: torch.dtype) -> int: + """Get size in bytes for a torch dtype.""" + if isinstance(dtype, str): + # Handle string dtype names (e.g., "auto", "bfloat16") + if dtype == "auto": + dtype = torch.bfloat16 # Default to bfloat16 + else: + dtype = getattr(torch, dtype, torch.bfloat16) + return torch.tensor([], dtype=dtype).element_size() + + +@dataclass +class AttentionBlockParams: + """Parameters for an attention block's KV cache.""" + num_kv_heads: int + head_size: int + window_size: int | None + dtype: torch.dtype + + @property + def page_size_per_token(self) -> int: + """Bytes per token for K + V.""" + return 2 * self.num_kv_heads * self.head_size * _get_dtype_size(self.dtype) + + +@dataclass +class MambaBlockParams: + """Parameters for a mamba-like block's state cache.""" + shapes: tuple + dtypes: tuple + mamba_type: MambaType + + @property + def natural_page_size(self) -> int: + """Natural page size based on state shapes.""" + return sum( + _get_dtype_size(dtype) * math.prod(shape) + for shape, dtype in zip(self.shapes, self.dtypes) + ) + + +BlockParams = AttentionBlockParams | MambaBlockParams + + +def _create_mixer_params( + mixer_config: dict, + mixer_type: str, + vllm_config: VllmConfig, +) -> BlockParams | None: + """Create BlockParams for a single mixer config. + + This is the single source of truth for converting a mixer config dict + into typed BlockParams. Used by both top-level and stochastic mixer handling. + + Args: + mixer_config: The mixer configuration dict. + mixer_type: The mixer type string (attention, gdn, kda, mamba). + vllm_config: The vLLM config for cache/parallel settings. + + Returns: + BlockParams for this mixer, or None if the mixer type doesn't need cache. 
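+
+    Example (illustrative sketch; the numeric values are hypothetical, the
+    field names match the handling below):
+
+        gdn_cfg = {
+            "type": "gdn",
+            "key_heads": 16, "value_heads": 32,
+            "key_head_dim": 64, "value_head_dim": 128,
+            "convolution_layer": {"kernel_size": 4},
+        }
+        params = _create_mixer_params(gdn_cfg, "gdn", vllm_config)
+        # -> MambaBlockParams(mamba_type="gdn_attention", shapes=..., dtypes=...)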
+ """ + cache_config = vllm_config.cache_config + model_dtype = vllm_config.model_config.dtype + tp_size = vllm_config.parallel_config.tensor_parallel_size + + if mixer_type == "attention" or mixer_type == "sliding_window": + cache_dtype = cache_config.cache_dtype + if cache_dtype is None or cache_dtype == "auto": + kv_cache_dtype = model_dtype + elif isinstance(cache_dtype, str): + kv_cache_dtype = getattr(torch, cache_dtype, model_dtype) + else: + kv_cache_dtype = cache_dtype + + return AttentionBlockParams( + num_kv_heads=mixer_config["head_groups"], + head_size=mixer_config["head_size"], + window_size=mixer_config.get("window_size"), + dtype=kv_cache_dtype, + ) + + elif mixer_type == "gdn": + shapes = MambaStateShapeCalculator.gated_delta_net_state_shape( + tp_world_size=tp_size, + num_k_heads=mixer_config["key_heads"], + num_v_heads=mixer_config["value_heads"], + head_k_dim=mixer_config["key_head_dim"], + head_v_dim=mixer_config["value_head_dim"], + conv_kernel_size=mixer_config["convolution_layer"]["kernel_size"], + num_spec=0, + ) + dtypes = MambaStateDtypeCalculator.gated_delta_net_state_dtype( + model_dtype, + cache_config.mamba_cache_dtype, + ) + return MambaBlockParams( + shapes=shapes, + dtypes=dtypes, + mamba_type="gdn_attention", + ) + + elif mixer_type == "kda": + shapes = MambaStateShapeCalculator.kda_state_shape( + tp_world_size=tp_size, + num_heads=mixer_config["heads"], + head_dim=mixer_config["head_dim"], + conv_kernel_size=mixer_config["convolution_layer"]["kernel_size"], + ) + dtypes = MambaStateDtypeCalculator.kda_state_dtype( + model_dtype, + cache_config.mamba_cache_dtype, + ) + return MambaBlockParams( + shapes=shapes, + dtypes=dtypes, + mamba_type="kda_attention", + ) + + elif mixer_type == "mamba": + d_state = mixer_config["state_size"] + d_conv = mixer_config["d_conv"] + d_inner = mixer_config.get("d_inner") + if d_inner is None: + raise ValueError("Mamba mixer must specify 'd_inner'") + shapes = MambaStateShapeCalculator.mamba1_state_shape( + tp_world_size=tp_size, + intermediate_size=d_inner, + state_size=d_state, + conv_kernel=d_conv, + ) + dtypes = MambaStateDtypeCalculator.mamba1_state_dtype( + model_dtype, + cache_config.mamba_cache_dtype, + cache_config.mamba_ssm_cache_dtype, + ) + return MambaBlockParams( + shapes=shapes, + dtypes=dtypes, + mamba_type="mamba", + ) + + # Unknown mixer type - may not need cache + return None + + +def get_block_params( + blocks_config: dict[str, dict], + vllm_config: VllmConfig, +) -> dict[str, BlockParams]: + """Parse block configs and compute cache parameters ONCE. + + This is the single source of truth for shapes, dtypes, and page sizes. + Downstream functions use these precomputed params. + + Args: + blocks_config: Dict mapping block names to their configs. + vllm_config: The vLLM config for cache/parallel settings. + + Returns: + Dict mapping block names to their BlockParams. 
+ """ + params: dict[str, BlockParams] = {} + + for block_name, block_config in blocks_config.items(): + mixer_config = block_config.get("mixer", {}) + mixer_type = mixer_config.get("type", "attention") + + if mixer_type == "stochastic": + # For stochastic mixers, compute params for ALL sub-mixers + # This creates the "convex hull" of cache requirements so the unified + # page size is large enough for any mixer type + mixers = mixer_config.get("mixers", {}) + for sub_mixer_name, sub_mixer_config in mixers.items(): + sub_mixer_type = sub_mixer_config.get("type", "attention") + sub_block_name = f"{block_name}.{sub_mixer_name}" + sub_params = _create_mixer_params(sub_mixer_config, sub_mixer_type, vllm_config) + if sub_params is not None: + params[sub_block_name] = sub_params + else: + # Regular (non-stochastic) mixer + mixer_params = _create_mixer_params(mixer_config, mixer_type, vllm_config) + if mixer_params is not None: + params[block_name] = mixer_params + else: + raise ValueError(f"Block '{block_name}': unknown mixer type '{mixer_type}'") + + return params + + +def get_block_page_sizes( + block_params: dict[str, BlockParams], +) -> tuple[int | None, dict[str, int]]: + """Extract page sizes from precomputed block params. + + Args: + block_params: Dict mapping block names to their BlockParams. + + Returns: + Tuple of: + - attn_page_per_token: Bytes per token for attention (None if no attention). + - mamba_page_sizes: Dict mapping mamba block names to natural page sizes. + """ + attn_page_per_token: int | None = None + mamba_page_sizes: dict[str, int] = {} + + for block_name, params in block_params.items(): + if isinstance(params, AttentionBlockParams): + # All attention blocks should have same head config + attn_page_per_token = params.page_size_per_token + elif isinstance(params, MambaBlockParams): + mamba_page_sizes[block_name] = params.natural_page_size + + return attn_page_per_token, mamba_page_sizes + + +def unify_block_page_sizes( + attn_page_per_token: int | None, + mamba_page_sizes: dict[str, int], + default_block_size: int = 16, + alignment: int = 16, +) -> tuple[int, int]: + """Compute unified (block_size, page_size) for all block types. + + The unified page_size must work for both attention (which scales with + block_size) and mamba-like blocks (fixed state sizes). We achieve this by: + 1. Finding max mamba page size + 2. Computing block_size so attention page >= max mamba page + 3. Padding mamba pages to match attention page + + Args: + attn_page_per_token: Bytes per token for attention (None if no attention). + mamba_page_sizes: Dict of mamba-like block names to natural page sizes. + default_block_size: Minimum block size for attention. + alignment: Block size alignment (FlashAttention needs 16). + + Returns: + Tuple of (block_size, unified_page_size). 
+ """ + # Pure attention model + if not mamba_page_sizes: + block_size = max(default_block_size, alignment) + if attn_page_per_token is None: + return block_size, 0 + return block_size, block_size * attn_page_per_token + + # Pure mamba model + if attn_page_per_token is None: + max_mamba_page = max(mamba_page_sizes.values()) + return default_block_size, max_mamba_page + + # Hybrid model: need to align attention and mamba page sizes + max_mamba_page = max(mamba_page_sizes.values()) + + # Compute minimum block_size so attention page >= max mamba page + # attn_page = block_size * attn_page_per_token >= max_mamba_page + min_block_size = -(-max_mamba_page // attn_page_per_token) # ceiling division + + # Align to kernel requirements + aligned_block_size = alignment * -(-min_block_size // alignment) + + # Use larger of default and computed + block_size = max(default_block_size, aligned_block_size) + + # Unified page size (attention page, mamba will be padded to match) + unified_page_size = block_size * attn_page_per_token + + apriel2_logger.info( + "Page size unification: max_mamba=%d, attn_per_token=%d, " + "block_size=%d, unified_page=%d", + max_mamba_page, attn_page_per_token, block_size, unified_page_size + ) + + return block_size, unified_page_size + + +def get_blocks_config(decoder_config: dict) -> dict[str, dict]: + """Extract the blocks config dict from a decoder config. + + Handles both 'fixed' (single block) and 'pattern' (multiple blocks) modes. + + Args: + decoder_config: The decoder config dict from model config. + + Returns: + Dict mapping block names to their configs. + """ + seq_type = decoder_config.get("type", "fixed") + + if seq_type == "fixed": + # Single block type - synthesize a name + block_config = decoder_config.get("block", {}) + return {"block": block_config} + elif seq_type == "pattern": + return decoder_config.get("blocks", {}) + else: + return {} + + +def get_unified_page_size_for_config( + config: PretrainedConfig, + vllm_config: VllmConfig, +) -> tuple[int, int]: + """Compute unified (block_size, page_size) for the model config. + + This is used by layer-level get_kv_cache_spec() methods to ensure + all layers return specs with matching page_size_bytes, even when + vLLM iterates over layers individually (e.g., TransformersForCausalLM). + + Results are cached by object identity to avoid redundant computation + when called from each layer's get_kv_cache_spec(). + + Args: + config: The HuggingFace model config. + vllm_config: The vLLM config. + + Returns: + Tuple of (block_size, unified_page_size). + """ + cache_key = (id(config), id(vllm_config)) + if cache_key in _unified_page_size_cache: + return _unified_page_size_cache[cache_key] + + decoder_config = getattr(config, "decoder", {}) or {} + blocks_config = get_blocks_config(decoder_config) + block_params = get_block_params(blocks_config, vllm_config) + attn_page_per_token, mamba_page_sizes = get_block_page_sizes(block_params) + result = unify_block_page_sizes(attn_page_per_token, mamba_page_sizes) + + _unified_page_size_cache[cache_key] = result + return result + + +class Apriel2Config(PretrainedConfig): + """Configuration for Apriel2 models. + + This config supports both text-only and multimodal variants with + flexible decoder block patterns (attention, mamba, GDN, KDA). 
+ """ + + model_type = "apriel2" + + def __init__( + self, + vocab_size: int = 131072, + hidden_size: int = 4096, + intermediate_size: int = 14336, + num_hidden_layers: int = 32, + num_attention_heads: int = 32, + num_key_value_heads: int = 8, + head_dim: int = 128, + hidden_act: str = "silu", + max_position_embeddings: int = 131072, + rms_norm_eps: float = 1e-5, + tie_word_embeddings: bool = True, + rope_theta: float = 500000.0, + rope_scaling: dict | None = None, + # Apriel2 specific + decoder: dict | None = None, + embeddings: dict | None = None, + head: dict | None = None, + vision_encoder: dict | None = None, + image_token_index: int | None = None, + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.head_dim = head_dim + self.hidden_act = hidden_act + self.max_position_embeddings = max_position_embeddings + self.rms_norm_eps = rms_norm_eps + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + + # Apriel2 specific configs + self.decoder = decoder or {} + self.embeddings = embeddings or { + "max_position_embeddings": max_position_embeddings + } + self.head = head or {"normalization": {"epsilon": rms_norm_eps}} + self.vision_encoder = vision_encoder + self.image_token_index = image_token_index + + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) + + @property + def layers_block_type(self) -> list[str]: + """Return block types for each layer (for hybrid model detection).""" + decoder_config = self.decoder + seq_type = decoder_config.get("type", "fixed") + num_blocks = decoder_config.get("num_blocks", self.num_hidden_layers) + + if seq_type == "fixed": + block_config = decoder_config.get("block", {}) + mixer_type = block_config.get("mixer", {}).get("type", "attention") + return [mixer_type] * num_blocks + elif seq_type == "pattern": + pattern = decoder_config.get("pattern", ["attention"]) + blocks_config = decoder_config.get("blocks", {}) + result = [] + for i in range(num_blocks): + block_name = pattern[i % len(pattern)] + mixer_type = ( + blocks_config.get(block_name, {}) + .get("mixer", {}) + .get("type", "attention") + ) + result.append(mixer_type) + return result + return ["attention"] * num_blocks + + +class Apriel2MLP(nn.Module): + """Apriel2 MLP with gated activation (SwiGLU style).""" + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: QuantizationConfig | None = None, + bias: bool = False, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + input_size=hidden_size, + output_sizes=[intermediate_size] * 2, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + input_size=intermediate_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) + if hidden_act != "silu": + raise ValueError( + f"Unsupported activation: {hidden_act}. Only silu is supported." 
+ ) + self.act_fn = SiluAndMul() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x, _ = self.gate_up_proj(x) + x = self.act_fn(x) + x, _ = self.down_proj(x) + return x + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loaded = set() + for name, weight in weights: + if name == "gate_proj.weight": + self.gate_up_proj.weight_loader(self.gate_up_proj.weight, weight, 0) + loaded.add("gate_up_proj.weight") + elif name == "up_proj.weight": + self.gate_up_proj.weight_loader(self.gate_up_proj.weight, weight, 1) + loaded.add("gate_up_proj.weight") + elif name == "down_proj.weight": + self.down_proj.weight_loader(self.down_proj.weight, weight) + loaded.add("down_proj.weight") + return loaded + + +class Apriel2Attention(nn.Module): + """Apriel2 attention layer with rotary embeddings and GQA support.""" + + def __init__( + self, + config: Apriel2Config, + mixer_config: dict, + layer_idx: int, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + + # Extract from mixer config (required) + self.total_num_heads = mixer_config["heads"] + self.total_num_kv_heads = mixer_config["head_groups"] + self.head_dim = mixer_config["head_size"] + + tp_size = get_tensor_model_parallel_world_size() + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + + if self.total_num_kv_heads >= tp_size: + assert self.total_num_kv_heads % tp_size == 0 + else: + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + + # Bias configuration - supports per-layer overrides + default_bias = mixer_config.get("add_linear_biases", False) + + def get_layer_bias(layer_name: str) -> bool: + layer_cfg = mixer_config.get(layer_name, {}) + bias_cfg = layer_cfg.get("bias", {}) + enabled = bias_cfg.get("enabled") + return default_bias if enabled is None else enabled + + q_bias = get_layer_bias("query_layer") + k_bias = get_layer_bias("key_layer") + v_bias = get_layer_bias("value_layer") + o_bias = get_layer_bias("dense_layer") + + self.qkv_proj = QKVParallelLinear( + hidden_size=self.hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=q_bias or k_bias or v_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + self.o_proj = RowParallelLinear( + input_size=self.total_num_heads * self.head_dim, + output_size=self.hidden_size, + bias=o_bias, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + # Rotary embeddings + rotary_config = mixer_config.get("rotary", {}) + rope_theta = rotary_config["theta"] + max_pos = config.embeddings["max_position_embeddings"] + + self.rotary_emb = get_rope( + self.head_dim, + max_position=max_pos, + rope_parameters={"rope_theta": rope_theta}, + ) + + # Sliding window support + self.window_size = mixer_config.get("window_size", None) + + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + per_layer_sliding_window=self.window_size, + prefix=f"{prefix}.attn", + ) + + def forward( + self, + hidden_states: torch.Tensor, + output: 
torch.Tensor, + positions: torch.Tensor | None = None, + **kwargs, + ) -> None: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output[:], _ = self.o_proj(attn_output) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loaded = set() + for name, weight in weights: + if name == "q_proj.weight": + self.qkv_proj.weight_loader(self.qkv_proj.weight, weight, "q") + loaded.add("qkv_proj.weight") + elif name == "k_proj.weight": + self.qkv_proj.weight_loader(self.qkv_proj.weight, weight, "k") + loaded.add("qkv_proj.weight") + elif name == "v_proj.weight": + self.qkv_proj.weight_loader(self.qkv_proj.weight, weight, "v") + loaded.add("qkv_proj.weight") + elif name == "q_proj.bias": + self.qkv_proj.weight_loader(self.qkv_proj.bias, weight, "q") + loaded.add("qkv_proj.bias") + elif name == "k_proj.bias": + self.qkv_proj.weight_loader(self.qkv_proj.bias, weight, "k") + loaded.add("qkv_proj.bias") + elif name == "v_proj.bias": + self.qkv_proj.weight_loader(self.qkv_proj.bias, weight, "v") + loaded.add("qkv_proj.bias") + elif name == "o_proj.weight": + self.o_proj.weight_loader(self.o_proj.weight, weight) + loaded.add("o_proj.weight") + elif name == "o_proj.bias": + self.o_proj.weight_loader(self.o_proj.bias, weight) + loaded.add("o_proj.bias") + return loaded + + def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec: + """Return cache spec for attention with unified page size for hybrid models.""" + config = vllm_config.model_config.hf_config + block_size, _ = get_unified_page_size_for_config(config, vllm_config) + + # Get dtype from cache config + cache_dtype = vllm_config.cache_config.cache_dtype + if cache_dtype is None or cache_dtype == "auto": + kv_cache_dtype = vllm_config.model_config.dtype + elif isinstance(cache_dtype, str): + kv_cache_dtype = getattr(torch, cache_dtype, vllm_config.model_config.dtype) + else: + kv_cache_dtype = cache_dtype + + if self.window_size is not None: + return SlidingWindowSpec( + block_size=block_size, + num_kv_heads=self.num_kv_heads, + head_size=self.head_dim, + dtype=kv_cache_dtype, + sliding_window=self.window_size, + ) + else: + return FullAttentionSpec( + block_size=block_size, + num_kv_heads=self.num_kv_heads, + head_size=self.head_dim, + dtype=kv_cache_dtype, + ) + + +class Apriel2MambaMixer(nn.Module): + """Apriel2 Mamba mixer layer wrapping vLLM's MambaMixer.""" + + def __init__( + self, + config: Apriel2Config, + mixer_config: dict, + layer_idx: int, + model_config: ModelConfig | None = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ) -> None: + super().__init__() + self.layer_idx = layer_idx + + # Extract mamba params from config - architecture values required + d_state = mixer_config["state_size"] + d_conv = mixer_config["d_conv"] + expand = mixer_config.get("expand", None) + d_inner = mixer_config.get("d_inner", None) + if d_inner is None: + if expand is None: + raise ValueError("mixer_config must specify either 'd_inner' or 'expand'") + d_inner = int(expand * config.hidden_size) + dt_rank = mixer_config.get("dt_rank", "auto") + if dt_rank == "auto": + dt_rank = math.ceil(config.hidden_size / 16) + + conv_bias = mixer_config.get("conv_bias", True) + bias = mixer_config.get("add_linear_biases", False) + + self.mamba = MambaMixer( + hidden_size=config.hidden_size, + ssm_state_size=d_state, + 
conv_kernel_size=d_conv, + intermediate_size=d_inner, + time_step_rank=dt_rank, + use_conv_bias=conv_bias, + use_bias=bias, + use_rms_norm=False, + activation=mixer_config.get("activation", "silu"), + model_config=model_config, + cache_config=cache_config, + prefix=prefix, + ) + + def forward( + self, + hidden_states: torch.Tensor, + output: torch.Tensor, + positions: torch.Tensor | None = None, + **kwargs, + ) -> None: + self.mamba(hidden_states, output) + + +# ============================================================================ +# GDN custom op registration +# ============================================================================ + + +def apriel2_gdn_attention_core( + mixed_qkv: torch.Tensor, + b: torch.Tensor, + a: torch.Tensor, + core_attn_out: torch.Tensor, + layer_name: str, +) -> None: + forward_context: ForwardContext = get_forward_context() + self = forward_context.no_compile_layers[layer_name] + self._forward_core( + mixed_qkv=mixed_qkv, + b=b, + a=a, + core_attn_out=core_attn_out, + ) + + +def apriel2_gdn_attention_core_fake( + mixed_qkv: torch.Tensor, + b: torch.Tensor, + a: torch.Tensor, + core_attn_out: torch.Tensor, + layer_name: str, +) -> None: + return + + +direct_register_custom_op( + op_name="apriel2_gdn_attention_core", + op_func=apriel2_gdn_attention_core, + mutates_args=["core_attn_out"], + fake_impl=apriel2_gdn_attention_core_fake, +) + + +@triton.jit +def fused_gdn_gating_kernel( + A_log_ptr, + a_ptr, + b_ptr, + dt_bias_ptr, + g_ptr, + beta_ptr, + num_heads: tl.constexpr, + total_elements: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + SOFTPLUS_THRESHOLD: tl.constexpr, +): + """Fused kernel for GDN gating computation.""" + pid = tl.program_id(0) + offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offset < total_elements + + # Load and convert to fp32 for math operations (exp/log require fp32/fp64) + A_log = tl.load(A_log_ptr + offset % num_heads, mask=mask).to(tl.float32) + dt_bias = tl.load(dt_bias_ptr + offset % num_heads, mask=mask).to(tl.float32) + a = tl.load(a_ptr + offset, mask=mask).to(tl.float32) + b = tl.load(b_ptr + offset, mask=mask).to(tl.float32) + + # g = -exp(A_log) * softplus(a + dt_bias) + # Use numerically stable softplus: for large x, softplus(x) ≈ x + A = tl.exp(A_log) + x = a + dt_bias + softplus_val = tl.where(x <= SOFTPLUS_THRESHOLD, tl.log(1.0 + tl.exp(x)), x) + g = -A * softplus_val + + # beta = sigmoid(b) + beta = tl.sigmoid(b) + + tl.store(g_ptr + offset, g, mask=mask) + tl.store(beta_ptr + offset, beta, mask=mask) + + +def fused_gdn_gating( + A_log: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + dt_bias: torch.Tensor, + softplus_threshold: float = 20.0, +) -> tuple[torch.Tensor, torch.Tensor]: + """Compute GDN gating: g = -exp(A_log) * softplus(a + dt_bias), beta = sigmoid(b)""" + batch_size = a.shape[0] + num_heads = a.shape[-1] + + g = torch.empty_like(a, dtype=torch.float32) + beta = torch.empty_like(b) + + total_elements = batch_size * num_heads + BLOCK_SIZE = 256 + grid = ((total_elements + BLOCK_SIZE - 1) // BLOCK_SIZE,) + + fused_gdn_gating_kernel[grid]( + A_log, + a.reshape(-1), + b.reshape(-1), + dt_bias, + g.reshape(-1), + beta.reshape(-1), + num_heads, + total_elements, + BLOCK_SIZE, + softplus_threshold, + ) + + g = g.unsqueeze(0) # Add batch dim for chunk_gated_delta_rule + beta = beta.unsqueeze(0) + + return g, beta + + +class Apriel2GatedDeltaNet(nn.Module, AttentionLayerBase): + """Gated Delta Net mixer for Apriel2 using vLLM infrastructure. 
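+
+ Data flow, as implemented below: in_proj_qkvz / in_proj_ba project the hidden
+ states, a short causal convolution mixes the q/k/v streams, the gated delta
+ rule kernels (chunk_gated_delta_rule for prefill, fused_recurrent_gated_delta_rule
+ for decode) update the recurrent state, and the result is gated through
+ RMSNormGated before out_proj.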
+ + Inherits from AttentionLayerBase directly (not MambaBase) to avoid + the global mamba_page_size_padded assumption that breaks heterogeneous + block models like Apriel2. + """ + + # State cache set by vLLM's bind_kv_cache + kv_cache: tuple[torch.Tensor, ...] + + @property + def mamba_type(self) -> str: + return "gdn_attention" + + def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]: + if self.model_config is None or self.cache_config is None: + raise ValueError("model_config and cache_config must be set") + return MambaStateDtypeCalculator.gated_delta_net_state_dtype( + self.model_config.dtype, self.cache_config.mamba_cache_dtype + ) + + def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: + return MambaStateShapeCalculator.gated_delta_net_state_shape( + self.tp_size, + self.num_k_heads, + self.num_v_heads, + self.head_k_dim, + self.head_v_dim, + self.conv_kernel_size, + self.num_spec, + ) + + def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec: + """Return cache spec with unified page size for hybrid models. + + The unified page size ensures all layers (attention, mamba, GDN, KDA) + have matching page_size_bytes, which is required by vLLM's KV cache + management. + """ + config = vllm_config.model_config.hf_config + _, unified_page_size = get_unified_page_size_for_config(config, vllm_config) + + block_size = ( + vllm_config.cache_config.mamba_block_size + or vllm_config.model_config.max_model_len + ) + return MambaSpec( + shapes=self.get_state_shape(), + dtypes=self.get_state_dtype(), + block_size=block_size, + page_size_padded=unified_page_size, + mamba_type=self.mamba_type, + num_speculative_blocks=self.num_spec, + ) + + def get_attn_backend(self) -> type[AttentionBackend]: + """Get the attention backend for GDN.""" + return get_mamba_attn_backend(self.mamba_type) + + def __init__( + self, + config: Apriel2Config, + mixer_config: dict, + layer_idx: int, + model_config: ModelConfig | None = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + speculative_config: SpeculativeConfig | None = None, + prefix: str = "", + ) -> None: + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + self.hidden_size = config.hidden_size + + # Config params - required architecture values + self.num_v_heads = mixer_config["value_heads"] + self.num_k_heads = mixer_config["key_heads"] + self.head_k_dim = mixer_config["key_head_dim"] + self.head_v_dim = mixer_config["value_head_dim"] + conv_config = mixer_config["convolution_layer"] + self.conv_kernel_size = conv_config["kernel_size"] + # Internal defaults for implementation details + self.layer_norm_epsilon = mixer_config.get("norm_eps", 1e-5) + self.activation = conv_config.get("activation", "silu") + self.act = ACT2FN[self.activation] + + self.layer_idx = layer_idx + self.prefix = prefix + self.model_config = model_config + self.cache_config = cache_config + self.quant_config = quant_config + self.speculative_config = speculative_config + self.num_spec = ( + self.speculative_config.num_speculative_tokens + if self.speculative_config + else 0 + ) + + # Derived dimensions + self.key_dim = self.head_k_dim * self.num_k_heads + self.value_dim = self.head_v_dim * self.num_v_heads + self.conv_dim = self.key_dim * 2 + self.value_dim + self.value_heads_per_key = self.num_v_heads // self.num_k_heads + + # Convolution layer using vLLM's ColumnParallelLinear pattern + self.conv1d = ColumnParallelLinear( + 
input_size=self.conv_kernel_size, + output_size=self.conv_dim, + bias=False, + prefix=f"{prefix}.conv1d", + ) + self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) + + # Input projections + self.projection_size_qkvz = self.key_dim * 2 + self.value_dim * 2 + self.projection_size_ba = self.num_v_heads * 2 + + self.in_proj_qkvz = ColumnParallelLinear( + input_size=self.hidden_size, + output_size=self.projection_size_qkvz, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.in_proj_qkvz", + ) + self.in_proj_ba = ColumnParallelLinear( + input_size=self.hidden_size, + output_size=self.projection_size_ba, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.in_proj_ba", + ) + + # Set up weight loaders for conv1d + query_key_settings = (self.key_dim, 0, False) + value_settings = (self.value_dim, 0, False) + + delattr(self.conv1d.weight, "weight_loader") + set_weight_attrs( + self.conv1d.weight, + { + "weight_loader": mamba_v2_sharded_weight_loader( + [query_key_settings, query_key_settings, value_settings], + self.tp_size, + self.tp_rank, + ) + }, + ) + + # Time step and decay parameters + self.dt_bias = nn.Parameter( + torch.ones(self.num_v_heads // self.tp_size), + ) + self.A_log = nn.Parameter( + torch.empty(divide(self.num_v_heads, self.tp_size)), + ) + + set_weight_attrs(self.A_log, {"weight_loader": sharded_weight_loader(0)}) + set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)}) + + # Output normalization and projection + self.norm = RMSNormGated( + self.head_v_dim, + eps=self.layer_norm_epsilon, + group_size=None, + norm_before_gate=True, + device=current_platform.current_device(), + dtype=config.torch_dtype if hasattr(config, 'torch_dtype') else None, + ) + + self.out_proj = RowParallelLinear( + self.value_dim, + self.hidden_size, + bias=False, + input_is_parallel=True, + quant_config=quant_config, + prefix=f"{prefix}.out_proj", + ) + + # Register with compilation context + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + + def fix_query_key_value_ordering( + self, + mixed_qkvz: torch.Tensor, + mixed_ba: torch.Tensor, + ): + """Derives query, key, value, z, b, a tensors from projections. 
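+
+ Per tensor-parallel shard, the QKVZ projection is split along the last
+ dimension into sizes [key_dim, key_dim, value_dim, value_dim] and the BA
+ projection into [num_v_heads, num_v_heads], following the layout below.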
+ + Uses flat layout matching Fast-LLM and transformers reference: + - QKVZ: [Q_all | K_all | V_all | Z_all] + - BA: [b_all | a_all] + """ + num_tokens = mixed_qkvz.size(0) + + # Split QKVZ using flat layout (matching Fast-LLM/transformers reference) + qkvz_sizes = ( + self.key_dim // self.tp_size, # Q: key_heads * key_head_dim + self.key_dim // self.tp_size, # K: key_heads * key_head_dim + self.value_dim // self.tp_size, # V: value_heads * value_head_dim + self.value_dim // self.tp_size, # Z: value_heads * value_head_dim + ) + query, key, value, z = torch.split(mixed_qkvz, qkvz_sizes, dim=-1) + + # Reshape to head format: [tokens, heads, head_dim] + query = query.reshape(num_tokens, self.num_k_heads // self.tp_size, self.head_k_dim) + key = key.reshape(num_tokens, self.num_k_heads // self.tp_size, self.head_k_dim) + value = value.reshape(num_tokens, self.num_v_heads // self.tp_size, self.head_v_dim) + z = z.reshape(num_tokens, self.num_v_heads // self.tp_size, self.head_v_dim) + + # Split BA using flat layout: [b_all | a_all] + ba_sizes = ( + self.num_v_heads // self.tp_size, # b (beta) + self.num_v_heads // self.tp_size, # a (alpha) + ) + b, a = torch.split(mixed_ba, ba_sizes, dim=-1) + + return query, key, value, z, b, a + + def rearrange_mixed_qkv(self, mixed_qkv: torch.Tensor | None): + if mixed_qkv is None: + return None, None, None + query, key, value = torch.split( + mixed_qkv, + [ + self.key_dim // self.tp_size, + self.key_dim // self.tp_size, + self.value_dim // self.tp_size, + ], + dim=-1, + ) + query, key = map( + lambda x: rearrange(x, "l (h d) -> 1 l h d", d=self.head_k_dim), + (query, key), + ) + value = rearrange(value, "l (h d) -> 1 l h d", d=self.head_v_dim) + return query.contiguous(), key.contiguous(), value.contiguous() + + def _debug_state_stats(self, name: str, state: torch.Tensor, seq_len: int): + """Debug recurrent state with statistics.""" + if not DEBUG_GDN_STATE or state is None: + return + flat = state.flatten() + first8 = ", ".join(f"{v:.6f}" for v in flat[:8].float().tolist()) + print(f"[vLLM-GDN {self.prefix}] {name} (seq_len={seq_len}): shape={state.shape}, " + f"mean={state.float().mean().item():.6f}, std={state.float().std().item():.6f}, " + f"min={state.float().min().item():.6f}, max={state.float().max().item():.6f}, " + f"first8=[{first8}]") + + def _debug_tensor(self, name: str, t: torch.Tensor): + if not DEBUG_GDN_LAYER: + return + if t is None: + print(f"[GDN {self.prefix}] {name}: None") + return + flat = t.flatten()[:8] + vals = ", ".join(f"{v:.6f}" for v in flat.float().tolist()) + print(f"[GDN {self.prefix}] {name}: shape={t.shape}, dtype={t.dtype}, " + f"mean={t.float().mean().item():.6f}, std={t.float().std().item():.6f}, " + f"first8=[{vals}]") + + def _debug_print(self, msg: str): + if not DEBUG_GDN_LAYER: + return + print(f"[GDN {self.prefix}] {msg}") + + def forward( + self, + hidden_states: torch.Tensor, + output: torch.Tensor, + positions: torch.Tensor | None = None, + **kwargs, + ) -> None: + """Forward pass with custom op for core attention.""" + num_tokens = hidden_states.size(0) + + # self._cached_hidden_states = hidden_states # Cache for debug in _forward_core + # self._debug_print(f"===== FORWARD START (num_tokens={num_tokens}) =====") + # self._debug_tensor("hidden_states", hidden_states) + + # Part 1: Input Projection + projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states) + projected_states_ba, _ = self.in_proj_ba(hidden_states) + # self._debug_tensor("projected_states_qkvz", projected_states_qkvz) + # 
self._debug_tensor("projected_states_ba", projected_states_ba) + + query, key, value, z, b, a = self.fix_query_key_value_ordering( + projected_states_qkvz, projected_states_ba + ) + # self._debug_tensor("query (after fix_ordering)", query) + # self._debug_tensor("key (after fix_ordering)", key) + # self._debug_tensor("value (after fix_ordering)", value) + # self._debug_tensor("z (after fix_ordering)", z) + # self._debug_tensor("b (after fix_ordering)", b) + # self._debug_tensor("a (after fix_ordering)", a) + + # Flatten heads: [tokens, heads, head_dim] -> [tokens, heads * head_dim] + query = query.reshape(query.size(0), -1) + key = key.reshape(key.size(0), -1) + value = value.reshape(value.size(0), -1) + mixed_qkv = torch.cat((query, key, value), dim=-1) + # self._debug_tensor("mixed_qkv (flattened)", mixed_qkv) + + # Part 2: Core Attention (Custom Op) + core_attn_out = torch.zeros( + (num_tokens, self.num_v_heads // self.tp_size, self.head_v_dim), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + + torch.ops.vllm.apriel2_gdn_attention_core( + mixed_qkv, + b, + a, + core_attn_out, + self.prefix, + ) + # self._debug_tensor("core_attn_out (after custom op)", core_attn_out) + + # Part 3: Output Projection + z_shape_og = z.shape + core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) + z = z.reshape(-1, z.shape[-1]) + # self._debug_tensor("core_attn_out (before norm)", core_attn_out) + # self._debug_tensor("z (before norm)", z) + # Debug last token before norm (reshaped has tokens * heads rows) + if DEBUG_GDN_LAYER and num_tokens > 0: + num_heads = self.num_v_heads // self.tp_size + last_token_start = (num_tokens - 1) * num_heads + last_attn = core_attn_out[last_token_start:last_token_start+1, :8] + last_z = z[last_token_start:last_token_start+1, :8] + print(f"[GDN {self.prefix}] core_attn_out before norm (last token, head 0): [{', '.join(f'{v:.6f}' for v in last_attn.flatten().float().tolist())}]") + print(f"[GDN {self.prefix}] z before norm (last token, head 0): [{', '.join(f'{v:.6f}' for v in last_z.flatten().float().tolist())}]") + # self._debug_tensor("norm.weight", self.norm.weight) + # self._debug_print(f"norm.norm_before_gate={self.norm.norm_before_gate}, norm.eps={self.norm.eps}") + core_attn_out = self.norm(core_attn_out, z) + # self._debug_tensor("core_attn_out (after norm)", core_attn_out) + # Debug last token after norm + if DEBUG_GDN_LAYER and num_tokens > 0: + last_attn_after = core_attn_out[last_token_start:last_token_start+1, :8] + print(f"[GDN {self.prefix}] core_attn_out after norm (last token, head 0): [{', '.join(f'{v:.6f}' for v in last_attn_after.flatten().float().tolist())}]") + core_attn_out = core_attn_out.reshape(z_shape_og) + core_attn_out = rearrange(core_attn_out, "... h d -> ... 
(h d)") + # self._debug_tensor("core_attn_out (before out_proj)", core_attn_out) + output[:num_tokens], _ = self.out_proj(core_attn_out) + # self._debug_tensor("output (final)", output[:num_tokens]) + # Show last token specifically + if DEBUG_GDN_LAYER: + last_token = output[num_tokens-1, :8] + vals = ", ".join(f"{v:.6f}" for v in last_token.float().tolist()) + print(f"[GDN {self.prefix}] output (last token): last_token_first8=[{vals}]") + # Debug output hidden states during decode (num_tokens == 1) + if DEBUG_GDN_OUTPUT and num_tokens == 1: + flat = output[:num_tokens].flatten() + first8 = ", ".join(f"{v:.6f}" for v in flat[:8].float().tolist()) + print(f"[vLLM-GDN {self.prefix}] OUTPUT hs: shape={output[:num_tokens].shape}, mean={output[:num_tokens].float().mean().item():.6f}, std={output[:num_tokens].float().std().item():.6f}, first8=[{first8}]") + # self._debug_print("===== FORWARD END =====") + + def _forward_core( + self, + mixed_qkv: torch.Tensor, + b: torch.Tensor, + a: torch.Tensor, + core_attn_out: torch.Tensor, + ): + """Core attention computation (called by custom op).""" + # self._debug_print("===== _forward_core START =====") + # self._debug_tensor("mixed_qkv (input to core)", mixed_qkv) + # self._debug_tensor("b (input to core)", b) + # self._debug_tensor("a (input to core)", a) + + forward_context = get_forward_context() + attn_metadata: AttentionMetadata = forward_context.attn_metadata + + if attn_metadata is None: + # self._debug_print("attn_metadata is None, returning early") + return + + assert isinstance(attn_metadata, dict) + attn_metadata = attn_metadata[self.prefix] + assert isinstance(attn_metadata, GDNAttentionMetadata) + + has_initial_state = attn_metadata.has_initial_state + non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc + non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor + num_actual_tokens = attn_metadata.num_actual_tokens + + # self._debug_print(f"num_actual_tokens={num_actual_tokens}, num_prefills={attn_metadata.num_prefills}, num_decodes={attn_metadata.num_decodes}") + # self._debug_print(f"has_initial_state={has_initial_state}") + # self._debug_print(f"non_spec_query_start_loc={non_spec_query_start_loc}") + + self_kv_cache = self.kv_cache[forward_context.virtual_engine] + conv_state = self_kv_cache[0].transpose(-1, -2) + ssm_state = self_kv_cache[1] + + # self._debug_tensor("conv_state (from cache)", conv_state) + # self._debug_tensor("ssm_state (from cache)", ssm_state) + + mixed_qkv = mixed_qkv[:num_actual_tokens] + b = b[:num_actual_tokens] + a = a[:num_actual_tokens] + + # self._debug_tensor("mixed_qkv (truncated)", mixed_qkv) + # self._debug_tensor("b (truncated)", b) + # self._debug_tensor("a (truncated)", a) + + # Convolution + conv_weights = self.conv1d.weight.view( + self.conv1d.weight.size(0), self.conv1d.weight.size(2) + ) + # self._debug_tensor("conv_weights", conv_weights) + # self._debug_tensor("conv1d.bias", self.conv1d.bias) + # self._debug_print(f"activation={self.activation}") + + if attn_metadata.num_prefills > 0: + # self._debug_print("Using causal_conv1d_fn (prefill path)") + mixed_qkv_T = mixed_qkv.transpose(0, 1) + # self._debug_tensor("mixed_qkv_T (before conv)", mixed_qkv_T) + mixed_qkv = causal_conv1d_fn( + mixed_qkv_T, + conv_weights, + self.conv1d.bias, + activation=self.activation, + conv_states=conv_state, + has_initial_state=has_initial_state, + cache_indices=non_spec_state_indices_tensor, + query_start_loc=non_spec_query_start_loc, + metadata=attn_metadata, + ).transpose(0, 1) + 
else: + # self._debug_print("Using causal_conv1d_update (decode path)") + mixed_qkv = causal_conv1d_update( + mixed_qkv, + conv_state, + conv_weights, + self.conv1d.bias, + self.activation, + conv_state_indices=non_spec_state_indices_tensor[:num_actual_tokens], + validate_data=True, + ) + + # self._debug_tensor("mixed_qkv (after conv)", mixed_qkv) + + query, key, value = self.rearrange_mixed_qkv(mixed_qkv) + # self._debug_tensor("query (after rearrange)", query) + # self._debug_tensor("key (after rearrange)", key) + # self._debug_tensor("value (after rearrange)", value) + + # Expand K heads to V heads for grouped query attention + # (matches Fast-LLM and transformers reference implementations) + # Always call repeat_interleave (no-op when value_heads_per_key == 1) to avoid + # conditional branches that confuse torch.compile + # self._debug_print(f"Expanding K heads to V heads (value_heads_per_key={self.value_heads_per_key})") + query = query.repeat_interleave(self.value_heads_per_key, dim=2) + key = key.repeat_interleave(self.value_heads_per_key, dim=2) + # self._debug_tensor("query (after expand)", query) + # self._debug_tensor("key (after expand)", key) + + # self._debug_tensor("A_log", self.A_log) + # self._debug_tensor("dt_bias", self.dt_bias) + g, beta = fused_gdn_gating(self.A_log, a, b, self.dt_bias) + # self._debug_tensor("g (from gating)", g) + # self._debug_tensor("beta (from gating)", beta) + + # Recurrent attention + if attn_metadata.num_prefills > 0: + # self._debug_print("Using chunk_gated_delta_rule (prefill)") + initial_state = ssm_state[non_spec_state_indices_tensor].contiguous() + initial_state[~has_initial_state, ...] = 0 + # self._debug_tensor("initial_state", initial_state) + # Debug PREFILL INPUTS before kernel call + if DEBUG_GDN_STATE: + print(f"[vLLM-GDN {self.prefix}] PREFILL INPUTS:") + print(f" hidden_states: shape={self._cached_hidden_states.shape}, first8={self._cached_hidden_states.flatten()[:8].tolist()}") + print(f" mixed_qkv (input): shape={mixed_qkv.shape}, first8={mixed_qkv.flatten()[:8].tolist()}") + print(f" q: shape={query.shape}, first8={query.flatten()[:8].tolist()}") + print(f" k: shape={key.shape}, first8={key.flatten()[:8].tolist()}") + print(f" v: shape={value.shape}, first8={value.flatten()[:8].tolist()}") + print(f" g: shape={g.shape}, first8={g.flatten()[:8].tolist()}") + print(f" beta: shape={beta.shape}, first8={beta.flatten()[:8].tolist()}") + print(f" initial_state: {initial_state}") + print(f" cu_seqlens: {non_spec_query_start_loc}") + core_out, last_state = chunk_gated_delta_rule( + q=query, + k=key, + v=value, + g=g, + beta=beta, + initial_state=initial_state, + output_final_state=True, + cu_seqlens=non_spec_query_start_loc, + head_first=False, + use_qk_l2norm_in_kernel=True, + ) + # self._debug_tensor("core_out (from chunk_gated_delta_rule)", core_out) + # self._debug_tensor("last_state", last_state) + # # Debug prefill state - get seq_len from query_start_loc + # if non_spec_query_start_loc is not None and len(non_spec_query_start_loc) >= 2: + # prefill_seq_len = int(non_spec_query_start_loc[1] - non_spec_query_start_loc[0]) + # else: + # prefill_seq_len = num_actual_tokens + # self._debug_state_stats("PREFILL out_state", last_state, prefill_seq_len) + ssm_state[non_spec_state_indices_tensor] = last_state.to(ssm_state.dtype) + else: + # self._debug_print("Using fused_recurrent_gated_delta_rule (decode)") + # # For decode, access the correct slot using state indices + # if non_spec_state_indices_tensor is not None and 
len(non_spec_state_indices_tensor) > 0: + # slot_idx = int(non_spec_state_indices_tensor[0]) + # actual_state = ssm_state[slot_idx:slot_idx+1] + # # self._debug_state_stats("DECODE in_state", actual_state, num_actual_tokens) + # Debug decode inputs + if DEBUG_GDN_STATE: + print(f"[vLLM-GDN {self.prefix}] DECODE inputs: q={query.flatten()[:4].tolist()}, k={key.flatten()[:4].tolist()}, v={value.flatten()[:4].tolist()}, g={g.flatten()[:4].tolist()}, beta={beta.flatten()[:4].tolist()}") + core_out, _ = fused_recurrent_gated_delta_rule( + q=query, + k=key, + v=value, + g=g, + beta=beta, + initial_state=ssm_state, + inplace_final_state=True, + cu_seqlens=non_spec_query_start_loc[:attn_metadata.num_decodes + 1], + ssm_state_indices=non_spec_state_indices_tensor, + use_qk_l2norm_in_kernel=True, + ) + # self._debug_tensor("core_out (from fused_recurrent)", core_out) + # if non_spec_state_indices_tensor is not None and len(non_spec_state_indices_tensor) > 0: + # actual_state = ssm_state[slot_idx:slot_idx+1] + # # self._debug_state_stats("DECODE out_state", actual_state, num_actual_tokens) + + core_attn_out[:num_actual_tokens] = core_out.squeeze(0)[:num_actual_tokens] + # self._debug_tensor("core_attn_out (final output)", core_attn_out) + # self._debug_print("===== _forward_core END =====") + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + # Checkpoint uses "convolution", model uses "conv1d" + loaded = set() + for name, weight in weights: + if name == "convolution.weight": + self.conv1d.weight_loader(self.conv1d.weight, weight) + loaded.add("conv1d.weight") + elif name == "in_proj_qkvz.weight": + self.in_proj_qkvz.weight_loader(self.in_proj_qkvz.weight, weight) + loaded.add("in_proj_qkvz.weight") + elif name == "in_proj_ba.weight": + self.in_proj_ba.weight_loader(self.in_proj_ba.weight, weight) + loaded.add("in_proj_ba.weight") + elif name == "out_proj.weight": + self.out_proj.weight_loader(self.out_proj.weight, weight) + loaded.add("out_proj.weight") + elif name == "norm.weight": + self.norm.weight.data.copy_(weight) + loaded.add("norm.weight") + elif name == "A_log": + self.A_log.data.copy_(weight) + loaded.add("A_log") + elif name == "dt_bias": + self.dt_bias.data.copy_(weight) + loaded.add("dt_bias") + return loaded + + +class Apriel2KDAMixer(nn.Module, AttentionLayerBase): + """Kimi Delta Attention mixer for Apriel2 using vLLM's KDA infrastructure. + + Inherits from AttentionLayerBase directly (not MambaBase) to avoid + the global mamba_page_size_padded assumption that breaks heterogeneous + block models like Apriel2. + """ + + # State cache set by vLLM's bind_kv_cache + kv_cache: tuple[torch.Tensor, ...] + + @property + def mamba_type(self) -> str: + # Use "gdn_attention" to match vLLM's KDA backend registration + return "gdn_attention" + + def get_state_dtype( + self, + ) -> tuple[torch.dtype, torch.dtype, torch.dtype, torch.dtype]: + if self.model_config is None or self.cache_config is None: + raise ValueError("model_config and cache_config must be set") + return MambaStateDtypeCalculator.kda_state_dtype( + self.model_config.dtype, self.cache_config.mamba_cache_dtype + ) + + def get_state_shape( + self, + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], tuple[int, ...]]: + return MambaStateShapeCalculator.kda_state_shape( + self.tp_size, self.num_heads, self.head_dim, conv_kernel_size=self.conv_size + ) + + def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec: + """Return cache spec with unified page size for hybrid models. 
+ + The unified page size ensures all layers (attention, mamba, GDN, KDA) + have matching page_size_bytes, which is required by vLLM's KV cache + management. + """ + config = vllm_config.model_config.hf_config + _, unified_page_size = get_unified_page_size_for_config(config, vllm_config) + + block_size = ( + vllm_config.cache_config.mamba_block_size + or vllm_config.model_config.max_model_len + ) + return MambaSpec( + shapes=self.get_state_shape(), + dtypes=self.get_state_dtype(), + block_size=block_size, + page_size_padded=unified_page_size, + mamba_type=self.mamba_type, + ) + + def get_attn_backend(self) -> type[AttentionBackend]: + """Get the attention backend for KDA.""" + return get_mamba_attn_backend(self.mamba_type) + + def __init__( + self, + config: Apriel2Config, + mixer_config: dict, + layer_idx: int, + model_config: ModelConfig | None = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ) -> None: + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + self.hidden_size = config.hidden_size + self.model_config = model_config + self.cache_config = cache_config + + # Extract KDA config params - architecture values required + self.num_heads = mixer_config["heads"] + self.head_dim = mixer_config["head_dim"] + conv_config = mixer_config["convolution_layer"] + self.conv_size = conv_config["kernel_size"] + # Internal defaults for implementation details + norm_config = mixer_config.get("normalization", {}) + rms_norm_eps = norm_config.get("epsilon", 1e-6) + norm_activation = norm_config.get("activation", "silu") + + self.layer_idx = layer_idx + self.prefix = prefix + + assert self.num_heads % self.tp_size == 0 + self.local_num_heads = divide(self.num_heads, self.tp_size) + projection_size = self.head_dim * self.num_heads + + # Use vLLM's parallel layers + self.q_proj = ColumnParallelLinear( + self.hidden_size, + projection_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_proj", + ) + self.k_proj = ColumnParallelLinear( + self.hidden_size, + projection_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.k_proj", + ) + self.v_proj = ColumnParallelLinear( + self.hidden_size, + projection_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.v_proj", + ) + + self.f_a_proj = ReplicatedLinear( + self.hidden_size, + self.head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.f_a_proj", + ) + self.f_b_proj = ColumnParallelLinear( + self.head_dim, + projection_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.f_b_proj", + ) + self.dt_bias = nn.Parameter( + torch.empty(divide(projection_size, self.tp_size), dtype=torch.float32) + ) + set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)}) + + self.b_proj = ColumnParallelLinear( + self.hidden_size, + self.num_heads, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.b_proj", + ) + + # Convolutions as parallel linears + self.q_conv1d = ColumnParallelLinear( + input_size=self.conv_size, + output_size=projection_size, + bias=False, + params_dtype=torch.float32, + prefix=f"{prefix}.q_conv1d", + ) + self.k_conv1d = ColumnParallelLinear( + input_size=self.conv_size, + output_size=projection_size, + bias=False, + params_dtype=torch.float32, + prefix=f"{prefix}.k_conv1d", + ) + self.v_conv1d = ColumnParallelLinear( + input_size=self.conv_size, + output_size=projection_size, + bias=False, + 
params_dtype=torch.float32, + prefix=f"{prefix}.v_conv1d", + ) + # Shape conv weights correctly + self.q_conv1d.weight.data = self.q_conv1d.weight.data.unsqueeze(1) + self.k_conv1d.weight.data = self.k_conv1d.weight.data.unsqueeze(1) + self.v_conv1d.weight.data = self.v_conv1d.weight.data.unsqueeze(1) + + # Store A_log as 1D to match checkpoint format - fused_kda_gate accepts [H] or [1,1,H,1] + self.A_log = nn.Parameter( + torch.empty(self.local_num_heads, dtype=torch.float32) + ) + set_weight_attrs(self.A_log, {"weight_loader": sharded_weight_loader(0)}) + + self.g_a_proj = ReplicatedLinear( + self.hidden_size, + self.head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.g_a_proj", + ) + self.g_b_proj = ColumnParallelLinear( + self.head_dim, + projection_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.g_b_proj", + ) + self.o_norm = FusedRMSNormGated( + self.head_dim, eps=rms_norm_eps, activation=norm_activation + ) + self.o_proj = RowParallelLinear( + projection_size, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + # Register with compilation context + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + + def forward( + self, + hidden_states: torch.Tensor, + output: torch.Tensor, + positions: torch.Tensor | None = None, + **kwargs, + ) -> None: + num_tokens = hidden_states.size(0) + q = self.q_proj(hidden_states)[0] + k = self.k_proj(hidden_states)[0] + v = self.v_proj(hidden_states)[0] + + beta = self.b_proj(hidden_states)[0].float().sigmoid() + g1 = self.f_b_proj(self.f_a_proj(hidden_states)[0])[0] + g1 = fused_kda_gate(g1, self.A_log, self.head_dim, g_bias=self.dt_bias) + beta = beta.unsqueeze(0) + g1 = g1.unsqueeze(0) + + g_proj_states = self.g_b_proj(self.g_a_proj(hidden_states)[0])[0] + g2 = rearrange(g_proj_states, "... (h d) -> ... 
h d", d=self.head_dim) + + core_attn_out = torch.zeros( + (1, num_tokens, self.local_num_heads, self.head_dim), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + torch.ops.vllm.kda_attention( + q, + k, + v, + g1, + beta, + core_attn_out, + self.prefix, + ) + core_attn_out = self.o_norm(core_attn_out, g2) + core_attn_out = rearrange(core_attn_out, "1 n h d -> n (h d)") + output[:] = self.o_proj(core_attn_out)[0] + + def _forward( + self, + q_proj_states: torch.Tensor, + k_proj_states: torch.Tensor, + v_proj_states: torch.Tensor, + g1: torch.Tensor, + beta: torch.Tensor, + core_attn_out: torch.Tensor, + ) -> None: + forward_context = get_forward_context() + attn_metadata: AttentionMetadata = forward_context.attn_metadata + + if attn_metadata is None: + return + + assert isinstance(attn_metadata, dict) + attn_metadata = attn_metadata[self.prefix] + assert isinstance(attn_metadata, GDNAttentionMetadata) + has_initial_state = attn_metadata.has_initial_state + non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc + non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor + num_actual_tokens = attn_metadata.num_actual_tokens + constant_caches = self.kv_cache[forward_context.virtual_engine] + + q_proj_states = q_proj_states[:num_actual_tokens] + k_proj_states = k_proj_states[:num_actual_tokens] + v_proj_states = v_proj_states[:num_actual_tokens] + g1 = g1[:num_actual_tokens] + beta = beta[:num_actual_tokens] + + (conv_state_q, conv_state_k, conv_state_v, recurrent_state) = constant_caches + conv_state_q = conv_state_q.transpose(-1, -2) + conv_state_k = conv_state_k.transpose(-1, -2) + conv_state_v = conv_state_v.transpose(-1, -2) + + q_conv_weights = self.q_conv1d.weight.view( + self.q_conv1d.weight.size(0), self.q_conv1d.weight.size(2) + ) + k_conv_weights = self.k_conv1d.weight.view( + self.k_conv1d.weight.size(0), self.k_conv1d.weight.size(2) + ) + v_conv_weights = self.v_conv1d.weight.view( + self.v_conv1d.weight.size(0), self.v_conv1d.weight.size(2) + ) + + if attn_metadata.num_prefills > 0: + q_proj_states = q_proj_states.transpose(0, 1) + k_proj_states = k_proj_states.transpose(0, 1) + v_proj_states = v_proj_states.transpose(0, 1) + q = causal_conv1d_fn( + q_proj_states, + q_conv_weights, + self.q_conv1d.bias, + activation="silu", + conv_states=conv_state_q, + has_initial_state=has_initial_state, + cache_indices=non_spec_state_indices_tensor, + query_start_loc=non_spec_query_start_loc, + metadata=attn_metadata, + ).transpose(0, 1) + k = causal_conv1d_fn( + k_proj_states, + k_conv_weights, + self.k_conv1d.bias, + activation="silu", + conv_states=conv_state_k, + has_initial_state=has_initial_state, + cache_indices=non_spec_state_indices_tensor, + query_start_loc=non_spec_query_start_loc, + metadata=attn_metadata, + ).transpose(0, 1) + v = causal_conv1d_fn( + v_proj_states, + v_conv_weights, + self.v_conv1d.bias, + activation="silu", + conv_states=conv_state_v, + has_initial_state=has_initial_state, + cache_indices=non_spec_state_indices_tensor, + query_start_loc=non_spec_query_start_loc, + metadata=attn_metadata, + ).transpose(0, 1) + else: + decode_conv_indices = non_spec_state_indices_tensor[:num_actual_tokens] + q = causal_conv1d_update( + q_proj_states, + conv_state_q, + q_conv_weights, + self.q_conv1d.bias, + activation="silu", + conv_state_indices=decode_conv_indices, + validate_data=True, + ) + k = causal_conv1d_update( + k_proj_states, + conv_state_k, + k_conv_weights, + self.k_conv1d.bias, + activation="silu", + 
conv_state_indices=decode_conv_indices, + validate_data=True, + ) + v = causal_conv1d_update( + v_proj_states, + conv_state_v, + v_conv_weights, + self.v_conv1d.bias, + activation="silu", + conv_state_indices=decode_conv_indices, + validate_data=True, + ) + + q, k, v = map( + lambda x: rearrange(x, "n (h d) -> 1 n h d", d=self.head_dim), (q, k, v) + ) + + if attn_metadata.num_prefills > 0: + zero_idx = non_spec_state_indices_tensor[~has_initial_state] + recurrent_state[zero_idx] = 0 + initial_state = recurrent_state[non_spec_state_indices_tensor].contiguous() + core_attn_out_non_spec, last_recurrent_state = chunk_kda( + q=q, + k=k, + v=v, + g=g1, + beta=beta, + initial_state=initial_state, + output_final_state=True, + use_qk_l2norm_in_kernel=True, + cu_seqlens=non_spec_query_start_loc, + ) + recurrent_state[non_spec_state_indices_tensor] = last_recurrent_state + else: + core_attn_out_non_spec, _ = fused_recurrent_kda( + q=q, + k=k, + v=v, + g=g1, + beta=beta, + initial_state=recurrent_state, + use_qk_l2norm_in_kernel=True, + cu_seqlens=non_spec_query_start_loc[:attn_metadata.num_decodes + 1], + ssm_state_indices=non_spec_state_indices_tensor, + ) + core_attn_out[0, :num_actual_tokens] = core_attn_out_non_spec[0, :num_actual_tokens] + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + # Checkpoint to model name translations: + # beta_proj → b_proj, q_conv → q_conv1d, k_conv → k_conv1d, v_conv → v_conv1d, norm → o_norm + loaded = set() + for name, weight in weights: + if name == "beta_proj.weight": + self.b_proj.weight_loader(self.b_proj.weight, weight) + loaded.add("b_proj.weight") + elif name == "q_conv.weight": + self.q_conv1d.weight_loader(self.q_conv1d.weight, weight) + loaded.add("q_conv1d.weight") + elif name == "k_conv.weight": + self.k_conv1d.weight_loader(self.k_conv1d.weight, weight) + loaded.add("k_conv1d.weight") + elif name == "v_conv.weight": + self.v_conv1d.weight_loader(self.v_conv1d.weight, weight) + loaded.add("v_conv1d.weight") + elif name == "norm.weight": + self.o_norm.weight.data.copy_(weight) + loaded.add("o_norm.weight") + elif name == "q_proj.weight": + self.q_proj.weight_loader(self.q_proj.weight, weight) + loaded.add("q_proj.weight") + elif name == "k_proj.weight": + self.k_proj.weight_loader(self.k_proj.weight, weight) + loaded.add("k_proj.weight") + elif name == "v_proj.weight": + self.v_proj.weight_loader(self.v_proj.weight, weight) + loaded.add("v_proj.weight") + elif name == "o_proj.weight": + self.o_proj.weight_loader(self.o_proj.weight, weight) + loaded.add("o_proj.weight") + elif name == "f_a_proj.weight": + self.f_a_proj.weight_loader(self.f_a_proj.weight, weight) + loaded.add("f_a_proj.weight") + elif name == "f_b_proj.weight": + self.f_b_proj.weight_loader(self.f_b_proj.weight, weight) + loaded.add("f_b_proj.weight") + elif name == "g_a_proj.weight": + self.g_a_proj.weight_loader(self.g_a_proj.weight, weight) + loaded.add("g_a_proj.weight") + elif name == "g_b_proj.weight": + self.g_b_proj.weight_loader(self.g_b_proj.weight, weight) + loaded.add("g_b_proj.weight") + elif name == "A_log": + self.A_log.data.copy_(weight) + loaded.add("A_log") + elif name == "dt_bias": + self.dt_bias.data.copy_(weight) + loaded.add("dt_bias") + return loaded + + +class Apriel2AttentionDecoderLayer(nn.Module): + """Attention-based decoder layer for Apriel2.""" + + def __init__( + self, + config: Apriel2Config, + layer_idx: int, + block_config: dict, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, 
+ prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + + mixer_config = block_config.get("mixer", {}) + mlp_config = block_config.get("mlp", {}) + norm_config = block_config.get("normalization", {}) + + self.mixer = Apriel2Attention( + config=config, + mixer_config=mixer_config, + layer_idx=layer_idx, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.mixer", + ) + + intermediate_size = mlp_config["intermediate_size"] + mlp_bias = mlp_config.get("add_linear_biases", False) + hidden_act = mlp_config.get("activation", "silu") + rms_norm_eps = norm_config["epsilon"] + + self.mlp = Apriel2MLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + hidden_act=hidden_act, + quant_config=quant_config, + bias=mlp_bias, + prefix=f"{prefix}.mlp", + ) + + self.input_layernorm = RMSNorm(config.hidden_size, eps=rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=rms_norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + residual: torch.Tensor | None, + positions: torch.Tensor | None = None, + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor]: + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + + output = torch.empty_like(hidden_states) + self.mixer(hidden_states, output, positions=positions) + hidden_states, residual = self.post_attention_layernorm(output, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class Apriel2MambaDecoderLayer(nn.Module): + """Mamba-based decoder layer for Apriel2.""" + + def __init__( + self, + config: Apriel2Config, + layer_idx: int, + block_config: dict, + model_config: ModelConfig | None = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + + mixer_config = block_config.get("mixer", {}) + mlp_config = block_config.get("mlp", {}) + norm_config = block_config.get("normalization", {}) + + self.mixer = Apriel2MambaMixer( + config=config, + mixer_config=mixer_config, + layer_idx=layer_idx, + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.mixer", + ) + + intermediate_size = mlp_config["intermediate_size"] + mlp_bias = mlp_config.get("add_linear_biases", False) + hidden_act = mlp_config.get("activation", "silu") + rms_norm_eps = norm_config["epsilon"] + + self.mlp = Apriel2MLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + hidden_act=hidden_act, + quant_config=quant_config, + bias=mlp_bias, + prefix=f"{prefix}.mlp", + ) + + self.input_layernorm = RMSNorm(config.hidden_size, eps=rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=rms_norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + residual: torch.Tensor | None, + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor]: + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + + output = torch.empty_like(hidden_states) + self.mixer(hidden_states, output) + hidden_states, residual = self.post_attention_layernorm(output, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class 
Apriel2GDNDecoderLayer(nn.Module): + """GatedDeltaNet-based decoder layer for Apriel2.""" + + def __init__( + self, + config: Apriel2Config, + layer_idx: int, + block_config: dict, + model_config: ModelConfig | None = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + speculative_config: SpeculativeConfig | None = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + + mixer_config = block_config.get("mixer", {}) + mlp_config = block_config.get("mlp", {}) + norm_config = block_config.get("normalization", {}) + + self.mixer = Apriel2GatedDeltaNet( + config=config, + mixer_config=mixer_config, + layer_idx=layer_idx, + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + speculative_config=speculative_config, + prefix=f"{prefix}.mixer", + ) + + intermediate_size = mlp_config["intermediate_size"] + mlp_bias = mlp_config.get("add_linear_biases", False) + hidden_act = mlp_config.get("activation", "silu") + rms_norm_eps = norm_config["epsilon"] + + self.mlp = Apriel2MLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + hidden_act=hidden_act, + quant_config=quant_config, + bias=mlp_bias, + prefix=f"{prefix}.mlp", + ) + + self.input_layernorm = RMSNorm(config.hidden_size, eps=rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=rms_norm_eps + ) + + def _debug_tensor(self, name: str, t: torch.Tensor, show_last=False): + if not DEBUG_DECODER_LAYER or t is None: + return + if show_last: + # Show last token + last = t[-1, :8] if t.dim() == 2 else t[0, -1, :8] + vals = ", ".join(f"{v:.6f}" for v in last.float().tolist()) + print(f"[vLLM Layer] {name}: shape={t.shape}, last_token_first8=[{vals}]") + else: + flat = t.flatten()[:8] + vals = ", ".join(f"{v:.6f}" for v in flat.float().tolist()) + print(f"[vLLM Layer] {name}: shape={t.shape}, first8=[{vals}]") + + def forward( + self, + hidden_states: torch.Tensor, + residual: torch.Tensor | None, + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor]: + # self._debug_tensor("input hidden_states", hidden_states) + # self._debug_tensor("input residual", residual) + + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + + # self._debug_tensor("after input_layernorm", hidden_states) + # self._debug_tensor("residual after input_layernorm", residual) + + output = torch.empty_like(hidden_states) + self.mixer(hidden_states, output) + # self._debug_tensor("mixer output", output) + + hidden_states, residual = self.post_attention_layernorm(output, residual) + # self._debug_tensor("after post_attention_layernorm", hidden_states) + # self._debug_tensor("residual after post_attention_layernorm", residual) + + hidden_states = self.mlp(hidden_states) + # self._debug_tensor("after mlp", hidden_states) + # Also show last token for final layer comparison + # self._debug_tensor("after mlp (last token)", hidden_states, show_last=True) + # self._debug_tensor("residual (last token)", residual, show_last=True) + + return hidden_states, residual + + +class Apriel2KDADecoderLayer(nn.Module): + """KDA-based decoder layer for Apriel2.""" + + def __init__( + self, + config: Apriel2Config, + layer_idx: int, + block_config: dict, + model_config: ModelConfig | None = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + prefix: 
str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + + mixer_config = block_config.get("mixer", {}) + mlp_config = block_config.get("mlp", {}) + norm_config = block_config.get("normalization", {}) + + self.mixer = Apriel2KDAMixer( + config=config, + mixer_config=mixer_config, + layer_idx=layer_idx, + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.mixer", + ) + + intermediate_size = mlp_config["intermediate_size"] + mlp_bias = mlp_config.get("add_linear_biases", False) + hidden_act = mlp_config.get("activation", "silu") + rms_norm_eps = norm_config["epsilon"] + + self.mlp = Apriel2MLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + hidden_act=hidden_act, + quant_config=quant_config, + bias=mlp_bias, + prefix=f"{prefix}.mlp", + ) + + self.input_layernorm = RMSNorm(config.hidden_size, eps=rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=rms_norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + residual: torch.Tensor | None, + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor]: + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + + output = torch.empty_like(hidden_states) + self.mixer(hidden_states, output) + hidden_states, residual = self.post_attention_layernorm(output, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class Apriel2StochasticMixer(nn.Module): + """Stochastic mixer that contains multiple sub-mixers. + + At inference time, routes inputs to the active mixer (configurable). + All sub-mixer weights are loaded and available for runtime switching. + + Each sub-mixer gets a unique virtual layer index for cache registration, + similar to Falcon H1's approach. This allows each mixer type to have its + own cache allocation without conflicts. 
+ """ + + # Map mixer type to (mixer_class, needs_model_config, needs_speculative_config) + MIXER_REGISTRY: dict[str, tuple[type, bool, bool]] = { + "attention": (Apriel2Attention, False, False), + "mamba": (Apriel2MambaMixer, True, False), + "gdn": (Apriel2GatedDeltaNet, True, True), + "kda": (Apriel2KDAMixer, True, False), + } + + def __init__( + self, + config: Apriel2Config, + mixer_config: dict, + layer_idx: int, + model_config: ModelConfig | None = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + speculative_config: SpeculativeConfig | None = None, + prefix: str = "", + ) -> None: + super().__init__() + self.layer_idx = layer_idx + self.config = config + + # Get sub-mixer configs + mixers_config = mixer_config.get("mixers", {}) + self.active_mixer_name = mixer_config.get("main_mixer_name", list(mixers_config.keys())[0]) + + # Get total number of layers for computing virtual layer indices + decoder_config = getattr(config, "decoder", {}) or {} + num_layers = decoder_config["num_blocks"] + + # Parse the prefix to extract base path (e.g., "model.layers" from "model.layers.0.mixer") + # prefix format: "model.layers.{layer_idx}.mixer" + prefix_parts = prefix.rsplit(".", 2) # ["model.layers", "0", "mixer"] + if len(prefix_parts) >= 3: + layers_base = prefix_parts[0] # "model.layers" + else: + layers_base = "model.layers" + + # Create all sub-mixers with unique virtual layer indices + # Each sub-mixer gets a unique offset based on its position (not type) + # to avoid collisions when multiple mixers have the same type + self.mixers = nn.ModuleDict() + for mixer_index, (name, sub_mixer_config) in enumerate(mixers_config.items()): + sub_mixer_type = sub_mixer_config.get("type", "attention") + + if sub_mixer_type not in self.MIXER_REGISTRY: + raise ValueError(f"Unknown sub-mixer type '{sub_mixer_type}' in stochastic mixer") + + mixer_class, needs_model_config, needs_spec_config = self.MIXER_REGISTRY[sub_mixer_type] + + # Compute virtual layer index using mixer's position index (Falcon H1 style) + # Each sub-mixer gets its own "virtual layer" range: layer_idx + (index+1) * num_layers + # This ensures unique indices even when multiple mixers have the same type + virtual_layer_idx = layer_idx + (mixer_index + 1) * num_layers + + # Build prefix with virtual layer index for cache registration + # This only affects static_forward_context registration, not weight loading + virtual_prefix = f"{layers_base}.{virtual_layer_idx}.stochastic_{name}" + + # Build kwargs based on what each mixer type needs + kwargs = { + "config": config, + "mixer_config": sub_mixer_config, + "layer_idx": layer_idx, # Keep real layer_idx for any internal use + "cache_config": cache_config, + "quant_config": quant_config, + "prefix": virtual_prefix, + } + if needs_model_config: + kwargs["model_config"] = model_config + if needs_spec_config: + kwargs["speculative_config"] = speculative_config + + self.mixers[name] = mixer_class(**kwargs) + logger.debug( + f"Created sub-mixer '{name}' (type={sub_mixer_type}) at virtual layer {virtual_layer_idx} " + f"(real layer {layer_idx}, prefix={virtual_prefix})" + ) + + self._mixer_names = list(self.mixers.keys()) + logger.info( + f"Initialized Apriel2StochasticMixer at layer {layer_idx} with {len(self.mixers)} mixers: " + f"{', '.join(self._mixer_names)} (active={self.active_mixer_name})" + ) + + def set_active_mixer(self, name: str) -> None: + """Set the active mixer by name.""" + if name not in self.mixers: + raise ValueError(f"Unknown 
mixer '{name}'. Available: {self._mixer_names}") + self.active_mixer_name = name + + def get_active_mixer(self) -> str: + """Get the name of the currently active mixer.""" + return self.active_mixer_name + + def forward( + self, + hidden_states: torch.Tensor, + output: torch.Tensor, + positions: torch.Tensor | None = None, + **kwargs, + ) -> None: + """Forward through the active mixer.""" + mixer = self.mixers[self.active_mixer_name] + mixer(hidden_states, output, positions=positions, **kwargs) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + """Load weights for all sub-mixers.""" + loaded = set() + # Group weights by sub-mixer name + weights_by_mixer: dict[str, list[tuple[str, torch.Tensor]]] = {name: [] for name in self.mixers} + + for name, weight in weights: + # Weight names are like "mixers.attention.q_proj.weight" + if name.startswith("mixers."): + parts = name.split(".", 2) # ["mixers", "attention", "q_proj.weight"] + if len(parts) >= 3: + mixer_name = parts[1] + param_name = parts[2] + if mixer_name in weights_by_mixer: + weights_by_mixer[mixer_name].append((param_name, weight)) + + # Load weights for each sub-mixer + for mixer_name, mixer_weights in weights_by_mixer.items(): + if mixer_weights: + sub_loaded = self.mixers[mixer_name].load_weights(mixer_weights) + # Prefix the loaded names with the mixer path + loaded.update(f"mixers.{mixer_name}.{n}" for n in sub_loaded) + + return loaded + + def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec: + """Return cache spec for the active mixer. + + Delegates to the active sub-mixer's get_kv_cache_spec method. + """ + active_mixer = self.mixers[self.active_mixer_name] + return active_mixer.get_kv_cache_spec(vllm_config) + + +class Apriel2StochasticDecoderLayer(nn.Module): + """Stochastic decoder layer that can switch between multiple mixer types.""" + + def __init__( + self, + config: Apriel2Config, + layer_idx: int, + block_config: dict, + model_config: ModelConfig | None = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + speculative_config: SpeculativeConfig | None = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + + mixer_config = block_config.get("mixer", {}) + mlp_config = block_config.get("mlp", {}) + norm_config = block_config.get("normalization", {}) + + self.mixer = Apriel2StochasticMixer( + config=config, + mixer_config=mixer_config, + layer_idx=layer_idx, + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + speculative_config=speculative_config, + prefix=f"{prefix}.mixer", + ) + + intermediate_size = mlp_config["intermediate_size"] + mlp_bias = mlp_config.get("add_linear_biases", False) + hidden_act = mlp_config.get("activation", "silu") + rms_norm_eps = norm_config["epsilon"] + + self.mlp = Apriel2MLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + hidden_act=hidden_act, + quant_config=quant_config, + bias=mlp_bias, + prefix=f"{prefix}.mlp", + ) + + self.input_layernorm = RMSNorm(config.hidden_size, eps=rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=rms_norm_eps + ) + + def set_active_mixer(self, name: str) -> None: + """Set the active mixer for this layer.""" + self.mixer.set_active_mixer(name) + + def get_active_mixer(self) -> str: + """Get the name of the currently active mixer.""" + return self.mixer.get_active_mixer() + + def forward( + self, + hidden_states: 
torch.Tensor, + residual: torch.Tensor | None, + positions: torch.Tensor | None = None, + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor]: + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + + output = torch.empty_like(hidden_states) + self.mixer(hidden_states, output, positions=positions, **kwargs) + hidden_states, residual = self.post_attention_layernorm(output, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +ALL_DECODER_LAYER_TYPES = { + "attention": Apriel2AttentionDecoderLayer, + "mamba": Apriel2MambaDecoderLayer, + "gdn": Apriel2GDNDecoderLayer, + "kda": Apriel2KDADecoderLayer, + "stochastic": Apriel2StochasticDecoderLayer, +} + + +def get_block_config_for_layer( + config: Apriel2Config, layer_idx: int +) -> tuple[str, dict]: + """Get mixer type and block config for a specific layer.""" + decoder_config = config.decoder + seq_type = decoder_config.get("type", "fixed") + + if seq_type == "fixed": + block_config = decoder_config.get("block", {}) + mixer_type = block_config.get("mixer", {}).get("type", "attention") + return mixer_type, block_config + elif seq_type == "pattern": + pattern = decoder_config.get("pattern", ["attention"]) + blocks_config = decoder_config.get("blocks", {}) + block_name = pattern[layer_idx % len(pattern)] + block_config = blocks_config.get(block_name, {}) + mixer_type = block_config.get("mixer", {}).get("type", "attention") + return mixer_type, block_config + else: + return "attention", {} + + +def apriel2_model_invariants( + input_ids, positions, intermediate_tensors=None, inputs_embeds=None +): + """Shape invariants for Apriel2 model compilation. + + These are translated to runtime assertions for unbacked dynamic shapes + and are compiled away for backed shapes. 
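+
+ For a flattened batch of T tokens, input_ids and positions each have T
+ entries, so the check below reduces to
+ positions.size(0) == input_ids.size(0).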
+ """ + if input_ids is not None: + torch._check(positions.size()[0] == input_ids.size()[0]) + + +@support_torch_compile(shape_invariants=apriel2_model_invariants) +class Apriel2Model(nn.Module): + """Apriel2 base model (decoder stack).""" + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.config = config + self.vocab_size = config.vocab_size + + if get_pp_group().is_first_rank or ( + config.tie_word_embeddings and get_pp_group().is_last_rank + ): + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + quant_config=quant_config, + ) + else: + self.embed_tokens = None + + def get_layer(*, prefix: str): + layer_idx = int(prefix.rsplit(".", 1)[1]) + mixer_type, block_config = get_block_config_for_layer(config, layer_idx) + layer_class = ALL_DECODER_LAYER_TYPES.get(mixer_type) + + if layer_class is None: + raise ValueError(f"Unknown mixer type: {mixer_type}") + + if mixer_type == "attention": + return layer_class( + config=config, + layer_idx=layer_idx, + block_config=block_config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ) + elif mixer_type == "mamba": + return layer_class( + config=config, + layer_idx=layer_idx, + block_config=block_config, + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ) + elif mixer_type == "gdn": + return layer_class( + config=config, + layer_idx=layer_idx, + block_config=block_config, + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + speculative_config=vllm_config.speculative_config, + prefix=prefix, + ) + elif mixer_type == "stochastic": + return layer_class( + config=config, + layer_idx=layer_idx, + block_config=block_config, + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + speculative_config=vllm_config.speculative_config, + prefix=prefix, + ) + else: # kda + return layer_class( + config=config, + layer_idx=layer_idx, + block_config=block_config, + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ) + + num_layers = config.decoder["num_blocks"] + self.start_layer, self.end_layer, self.layers = make_layers( + num_layers, + get_layer, + prefix=f"{prefix}.layers" if prefix else "layers", + ) + + if get_pp_group().is_last_rank: + head_norm_eps = config.head["normalization"]["epsilon"] + self.norm = RMSNorm(config.hidden_size, eps=head_norm_eps) + else: + self.norm = None + + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) + + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.embed_input_ids(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for layer in islice(self.layers, 
self.start_layer, self.end_layer): + hidden_states, residual = layer( + hidden_states=hidden_states, + residual=residual, + positions=positions, + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) + + # Debug final norm + if DEBUG_FINAL_NORM: + # Show LAST token (to match TF) + last_hs = hidden_states[-1, :8] + last_res = residual[-1, :8] if residual is not None else None + hs_vals = ", ".join(f"{v:.6f}" for v in last_hs.float().tolist()) + res_vals = ", ".join(f"{v:.6f}" for v in last_res.float().tolist()) if last_res is not None else "None" + print(f"[vLLM Final] hidden_states (before norm): shape={hidden_states.shape}, last_token_first8=[{hs_vals}]") + print(f"[vLLM Final] residual (before norm): shape={residual.shape if residual is not None else None}, last_token_first8=[{res_vals}]") + print(f"[vLLM Final] norm.weight: first8=[{', '.join(f'{v:.6f}' for v in self.norm.weight.flatten()[:8].float().tolist())}]") + print(f"[vLLM Final] norm.variance_epsilon={self.norm.variance_epsilon}") + + hidden_states, _ = self.norm(hidden_states, residual) + + if DEBUG_FINAL_NORM: + last_out = hidden_states[-1, :8] + out_vals = ", ".join(f"{v:.6f}" for v in last_out.float().tolist()) + print(f"[vLLM Final] hidden_states (after norm): shape={hidden_states.shape}, last_token_first8=[{out_vals}]") + + return hidden_states + + +class Apriel2ForCausalLM(nn.Module, HasInnerState, SupportsPP): + """Apriel2 model for causal language modeling. + + Supports hybrid architectures with attention, mamba, GDN, and KDA mixers. + """ + + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={ + "model.decoder.blocks.": "model.layers.", + }, + ) + + # For hybrid models + has_inner_state = True + # Don't use is_hybrid=True - it triggers HybridAttentionMambaModelConfig + # which assumes all mamba-like layers have the same shape. + # Apriel2 has heterogeneous blocks, each with its own get_kv_cache_spec(). 
+ is_hybrid = False + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + self.config = config + self.vllm_config = vllm_config + + self.model = Apriel2Model( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model"), + ) + + if get_pp_group().is_last_rank: + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=vllm_config.quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) + if config.tie_word_embeddings: + self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) + self.logits_processor = LogitsProcessor(config.vocab_size) + else: + self.lm_head = None + + self.make_empty_intermediate_tensors = self.model.make_empty_intermediate_tensors + + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs, + ) -> torch.Tensor | IntermediateTensors: + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor | None: + # Debug LM head input + if DEBUG_LM_HEAD: + flat = hidden_states.flatten()[:8] + vals = ", ".join(f"{v:.6f}" for v in flat.float().tolist()) + print(f"[vLLM LM Head] input hidden_states: shape={hidden_states.shape}, first8=[{vals}]") + if self.lm_head is not None: + lm_weight = self.lm_head.weight + print(f"[vLLM LM Head] lm_head.weight: shape={lm_weight.shape}, first8=[{', '.join(f'{v:.6f}' for v in lm_weight.flatten()[:8].float().tolist())}]") + + logits = self.logits_processor(self.lm_head, hidden_states) + + if DEBUG_LM_HEAD and logits is not None: + # Get last token logits + last_logits = logits[-1] if logits.dim() == 2 else logits[0, -1] + top_vals, top_idx = last_logits.topk(5) + print(f"[vLLM LM Head] logits shape={logits.shape}") + print(f"[vLLM LM Head] last token top-5 logits: {[(idx.item(), val.item()) for idx, val in zip(top_idx, top_vals)]}") + + return logits + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + + def set_layer_placements(self, placement: list[str]) -> dict[int, str]: + """Set the active mixer for each stochastic layer. + + This method is designed to be used with vLLM's apply_model: + llm.apply_model("set_layer_placements", placement) + + Args: + placement: List of mixer names, one per layer. For non-stochastic + layers, the value is ignored. Example: ["attention", "gdn", ...] + + Returns: + Dict mapping layer index to the mixer that was set (only for + stochastic layers that were actually changed). + """ + changed = {} + layers = self.model.layers + for layer_idx, mixer_name in enumerate(placement): + if layer_idx >= len(layers): + break + layer = layers[layer_idx] + if isinstance(layer, Apriel2StochasticDecoderLayer): + layer.set_active_mixer(mixer_name) + changed[layer_idx] = mixer_name + return changed + + def get_layer_placements(self) -> dict[int, str]: + """Get the current active mixer for each stochastic layer. 
+ + This method is designed to be used with vLLM's apply_model: + placements = llm.apply_model("get_layer_placements") + + Returns: + Dict mapping layer index to the currently active mixer name + (only for stochastic layers). + """ + placements = {} + layers = self.model.layers + for layer_idx, layer in enumerate(layers): + if isinstance(layer, Apriel2StochasticDecoderLayer): + placements[layer_idx] = layer.get_active_mixer() + return placements + + +# ----------------------------------------------------------------------------- +# Worker monkey-patching for placement switching via collective_rpc +# ----------------------------------------------------------------------------- +# This allows calling placement methods by string name without cloudpickle: +# placements = llm.collective_rpc("get_layer_placements") +# llm.collective_rpc("set_layer_placements", args=(new_placement,)) + + +def _patch_worker_for_placement_switching(): + """Add placement switching methods to the vLLM GPU worker.""" + try: + from vllm.v1.worker.gpu_worker import Worker + except ImportError: + return # vLLM not available or different version + + if hasattr(Worker, "get_layer_placements"): + return # Already patched + + def _get_layer_placements(self) -> dict[int, str]: + return self.get_model().get_layer_placements() + + def _set_layer_placements(self, placement: list[str]) -> dict[int, str]: + return self.get_model().set_layer_placements(placement) + + Worker.get_layer_placements = _get_layer_placements + Worker.set_layer_placements = _set_layer_placements + + +_patch_worker_for_placement_switching() diff --git a/fast_llm_external_models/apriel2/vllm/test_apriel2.py b/fast_llm_external_models/apriel2/vllm/test_apriel2.py new file mode 100644 index 000000000..33014876a --- /dev/null +++ b/fast_llm_external_models/apriel2/vllm/test_apriel2.py @@ -0,0 +1,1232 @@ +#!/usr/bin/env python3 +"""Test script for Apriel2 vLLM implementation. + +This script tests coherence and numerical correctness of Apriel2 models +by comparing vLLM outputs with the reference Transformers implementation. + +Usage: + # Test coherence (generation quality) + python test_apriel2.py coherence /path/to/model + python test_apriel2.py coherence /path/to/model --placement every2nd-gdn + + # Compare logits between vLLM and Transformers + python test_apriel2.py logits /path/to/model + python test_apriel2.py logits /path/to/model --placement all-gdn + + # Statistical comparison with many prompts (for rigorous testing) + python test_apriel2.py stats /path/to/model --num-prompts 128 + python test_apriel2.py stats /path/to/model --placement every3rd-gdn + + # Run both tests + python test_apriel2.py all /path/to/model + +Placement patterns: + --placement all-attention All layers use attention + --placement all-gdn All layers use GDN + --placement every2nd-gdn Every 2nd layer is GDN (1=attn, 2=gdn, 3=attn, ...) + --placement every3rd-gdn Every 3rd layer is GDN + --placement every4th-gdn Every 4th layer is GDN + --placement attn,gdn,attn,... 
Explicit comma-separated list +""" + +import argparse +import gc +from pathlib import Path + +import numpy as np +import torch +import triton + +from vllm import LLM, SamplingParams +from vllm.config import CompilationConfig +from vllm.config.compilation import CompilationMode + +# Apriel2 model registration is handled automatically via vLLM's plugin system +# (see fast-llm setup.cfg entry_points for vllm.general_plugins) + + +# Set a triton allocator to avoid "no allocator was set" errors +def _triton_allocator(size, align, stream): + return torch.empty(size, dtype=torch.int8, device='cuda').data_ptr() + + +triton.set_allocator(_triton_allocator) + + +def parse_placement(placement_str: str, num_layers: int) -> list[str]: + """Parse placement string into a list of mixer names. + + Args: + placement_str: Either a pattern name or comma-separated mixer names. + Patterns: all-attention, all-gdn, every2nd-gdn, every3rd-gdn, every4th-gdn + Explicit: attention,gdn,attention,gdn,... + num_layers: Number of layers in the model. + + Returns: + List of mixer names, one per layer. + """ + placement_str = placement_str.strip().lower() + + if placement_str == "all-attention": + return ["attention"] * num_layers + elif placement_str == "all-gdn": + return ["gdn"] * num_layers + elif placement_str.startswith("every") and placement_str.endswith("-gdn"): + # Parse "every2nd-gdn", "every3rd-gdn", etc. + n_str = placement_str[5:-4] # Extract "2nd", "3rd", etc. + n_str = n_str.rstrip("ndrdth") # Remove ordinal suffix + n = int(n_str) + placement = [] + for i in range(num_layers): + if (i + 1) % n == 0: # Every nth layer is GDN + placement.append("gdn") + else: + placement.append("attention") + return placement + elif "," in placement_str: + # Explicit comma-separated list + placement = [m.strip() for m in placement_str.split(",")] + if len(placement) != num_layers: + raise ValueError(f"Placement has {len(placement)} entries but model has {num_layers} layers") + return placement + else: + raise ValueError(f"Unknown placement pattern: {placement_str}") + + +def apply_placement(llm: "LLM", placement_str: str | None) -> None: + """Apply placement to a vLLM model if specified. + + Args: + llm: vLLM LLM instance. + placement_str: Placement string or None to skip. 
+ """ + if placement_str is None: + return + + # Get current placements to determine num_layers + placements = llm.collective_rpc("get_layer_placements") + if not placements or not placements[0]: + print(f" Model does not support placement switching, ignoring --placement") + return + + num_layers = len(placements[0]) + current = list(placements[0].values()) + print(f" Current placement: {current[0]} (all {num_layers} layers)") + + new_placement = parse_placement(placement_str, num_layers) + llm.collective_rpc("set_layer_placements", args=(new_placement,)) + + # Verify + placements_after = llm.collective_rpc("get_layer_placements") + attn_count = sum(1 for v in placements_after[0].values() if v == "attention") + gdn_count = sum(1 for v in placements_after[0].values() if v == "gdn") + print(f" Applied placement '{placement_str}': {attn_count} attention, {gdn_count} gdn") + + +def setup_transformers(): + """Register Apriel2 model with Transformers.""" + from transformers import AutoConfig, AutoModelForCausalLM + + from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2TextConfig + from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForCausalLM + + AutoConfig.register("apriel2_text", Apriel2TextConfig) + AutoModelForCausalLM.register(Apriel2TextConfig, Apriel2ForCausalLM) + + +def test_coherence_vllm(model_paths: list[str], prompts: list[str], max_tokens: int = 50, placement: str | None = None): + """Test generation coherence with vLLM.""" + sampling_params = SamplingParams(max_tokens=max_tokens, temperature=0) + + results = {} + for model_path in model_paths: + model_name = Path(model_path).name + print(f"\n{'#'*70}") + print(f"# vLLM: {model_name}") + print(f"{'#'*70}") + + llm = LLM( + model=model_path, + trust_remote_code=True, + gpu_memory_utilization=0.4, + max_model_len=2048, + ) + + apply_placement(llm, placement) + + outputs = llm.generate(prompts, sampling_params) + results[model_name] = {} + + for output in outputs: + prompt = output.prompt + generated = output.outputs[0].text + results[model_name][prompt] = generated + print(f"\nPrompt: {prompt!r}") + print(f"Output: {prompt + generated!r}") + + del llm + gc.collect() + torch.cuda.empty_cache() + + return results + + +def test_coherence_transformers(model_paths: list[str], prompts: list[str], max_tokens: int = 50): + """Test generation coherence with Transformers.""" + from transformers import AutoModelForCausalLM, AutoTokenizer + + setup_transformers() + + results = {} + for model_path in model_paths: + model_name = Path(model_path).name + print(f"\n{'#'*70}") + print(f"# Transformers: {model_name}") + print(f"{'#'*70}") + + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + model = AutoModelForCausalLM.from_pretrained( + model_path, + torch_dtype=torch.bfloat16, + device_map="cuda", + trust_remote_code=True, + ) + model.eval() + + results[model_name] = {} + for prompt in prompts: + input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda") + + with torch.no_grad(): + output_ids = model.generate( + input_ids, + max_new_tokens=max_tokens, + do_sample=False, + pad_token_id=tokenizer.eos_token_id, + ) + + generated = tokenizer.decode(output_ids[0], skip_special_tokens=True) + # Extract just the generated part (remove prompt) + generated_only = generated[len(prompt):] + results[model_name][prompt] = generated_only + print(f"\nPrompt: {prompt!r}") + print(f"Output: {generated!r}") + + del model + torch.cuda.empty_cache() + + return results + + +def 
compare_logits(model_path: str, prompt: str, max_tokens: int = 1, dtype: str = "bfloat16", no_compile: bool = False, revision: str | None = None, debug_gdn: bool = False, placement: str | None = None): + """Compare logits between vLLM and Transformers.""" + from transformers import AutoModelForCausalLM, AutoTokenizer + + setup_transformers() + + # Enable GDN debug if requested + if debug_gdn: + # vLLM GDN class + from fast_llm_external_models.apriel2.vllm.modeling_apriel2 import Apriel2GatedDeltaNet as VLLMGatedDeltaNet + VLLMGatedDeltaNet._debug_global_enable = True + print("GDN debug mode enabled for vLLM") + + torch_dtype = torch.bfloat16 if dtype == "bfloat16" else torch.float32 + + print(f"\n{'='*70}") + print(f"Model: {model_path}") + print(f"Revision: {revision}") + print(f"Prompt: {prompt!r}") + print(f"Dtype: {dtype}") + print(f"No compile: {no_compile}") + print(f"Debug GDN: {debug_gdn}") + print(f"Placement: {placement}") + print(f"{'='*70}\n") + + # Tokenize + tokenizer = AutoTokenizer.from_pretrained(model_path, revision=revision, trust_remote_code=True) + input_ids = tokenizer(prompt, return_tensors="pt").input_ids + print(f"Input tokens: {input_ids.shape[1]}") + print(f"Token IDs: {input_ids[0].tolist()}") + + # --- vLLM --- + compile_label = "no-compile" if no_compile else "compiled" + print(f"\n--- vLLM ({dtype}, {compile_label}) ---") + compilation_config = CompilationConfig(mode=CompilationMode.NONE) if no_compile else None + llm = LLM( + model=model_path, + revision=revision, + trust_remote_code=True, + gpu_memory_utilization=0.4, + max_model_len=2048, + dtype=dtype, + compilation_config=compilation_config, + ) + + apply_placement(llm, placement) + + sampling_params = SamplingParams( + max_tokens=max_tokens, + temperature=0, + logprobs=20, + ) + + outputs = llm.generate([prompt], sampling_params) + output = outputs[0] + + vllm_text = output.outputs[0].text + vllm_token_ids = output.outputs[0].token_ids + vllm_logprobs = output.outputs[0].logprobs + + print(f"Generated text: {vllm_text!r}") + print(f"Generated token IDs: {vllm_token_ids}") + + vllm_first_token_id = None + if vllm_token_ids: + vllm_first_token_id = vllm_token_ids[0] + vllm_first_token = tokenizer.decode([vllm_first_token_id]) + print(f"First generated token: {vllm_first_token!r} (id={vllm_first_token_id})") + + if vllm_logprobs and len(vllm_logprobs) > 0: + print("Top-5 by logprob:") + first_logprobs = vllm_logprobs[0] + sorted_logprobs = sorted(first_logprobs.items(), key=lambda x: x[1].logprob, reverse=True)[:5] + for tid, lp in sorted_logprobs: + token = tokenizer.decode([tid]) + print(f" {token!r} (id={tid}): logprob={lp.logprob:.4f}") + + del llm + gc.collect() + torch.cuda.empty_cache() + + # --- Transformers --- + # Use flash_attention_2 to match vLLM's attention backend (bf16 only) + attn_impl = "flash_attention_2" if dtype == "bfloat16" else "eager" + print(f"\n--- Transformers ({dtype}, {attn_impl}) ---") + model = AutoModelForCausalLM.from_pretrained( + model_path, + revision=revision, + torch_dtype=torch_dtype, + device_map="cuda", + trust_remote_code=True, + attn_implementation=attn_impl, + ) + model.eval() + + # Enable debug on transformers GDN layers if requested + if debug_gdn: + for name, module in model.named_modules(): + if module.__class__.__name__ == "Apriel2GatedDeltaNet": + module._debug_enabled = True # Enable at instance level (TF doesn't have warmup filtering) + print(f"Enabled debug on {name}") + + with torch.no_grad(): + tf_outputs = model(input_ids.to("cuda")) + tf_logits = 
tf_outputs.logits.cpu() + + print(f"Logits shape: {tf_logits.shape}") + + tf_next_token_logits = tf_logits[0, -1, :] + tf_next_token_id = tf_next_token_logits.argmax().item() + tf_next_token = tokenizer.decode([tf_next_token_id]) + print(f"Predicted next token: {tf_next_token!r} (id={tf_next_token_id})") + + tf_logprobs = torch.log_softmax(tf_next_token_logits.float(), dim=-1) + print("Top-5 by logprob:") + tf_top5 = tf_logprobs.topk(5) + for i in range(5): + tid = tf_top5.indices[i].item() + lp = tf_top5.values[i].item() + token = tokenizer.decode([tid]) + print(f" {token!r} (id={tid}): logprob={lp:.4f}") + + del model + torch.cuda.empty_cache() + + # --- Comparison --- + print("\n--- Comparison ---") + match = False + if vllm_first_token_id is not None: + if vllm_first_token_id == tf_next_token_id: + print("MATCH: Both models predict the same next token!") + match = True + else: + vllm_first_token = tokenizer.decode([vllm_first_token_id]) + print(f"MISMATCH: Transformers predicts {tf_next_token!r}, vLLM predicts {vllm_first_token!r}") + + tf_topk = tf_logprobs.topk(10) + if vllm_first_token_id in tf_topk.indices.tolist(): + rank = tf_topk.indices.tolist().index(vllm_first_token_id) + print(f" vLLM's token is rank {rank+1} in transformers' predictions") + else: + print(f" vLLM's token is NOT in transformers' top-10") + + # Compare logprobs + if vllm_logprobs and len(vllm_logprobs) > 0: + print("\n--- Logprob Comparison ---") + first_logprobs = vllm_logprobs[0] + + common_tokens = set(first_logprobs.keys()) & set(range(len(tf_logprobs))) + if common_tokens: + diffs = [] + for tid in list(common_tokens)[:10]: + vllm_lp = first_logprobs[tid].logprob + tf_lp = tf_logprobs[tid].item() + diff = abs(vllm_lp - tf_lp) + diffs.append(diff) + if diff > 0.1: + token = tokenizer.decode([tid]) + print(f" {token!r}: vLLM={vllm_lp:.4f}, TF={tf_lp:.4f}, diff={diff:.4f}") + + avg_diff = sum(diffs) / len(diffs) if diffs else 0 + max_diff = max(diffs) if diffs else 0 + print(f"\n Average logprob diff: {avg_diff:.6f}") + print(f" Max logprob diff: {max_diff:.6f}") + + return match, vllm_text, tf_next_token + + +def compare_comprehensive( + model_path: str, + prompt_sizes: list[int], + decode_lengths: list[int], + batch_sizes: list[int], + dtype: str = "bfloat16", + no_compile: bool = True, + revision: str | None = None, + placement: str | None = None, +): + """Compare vLLM and Transformers across various configurations. + + Returns a list of result dicts with keys: + prompt_size, decode_length, batch_size, avg_diff, max_diff, all_match + """ + from transformers import AutoModelForCausalLM, AutoTokenizer + + setup_transformers() + + torch_dtype = torch.bfloat16 if dtype == "bfloat16" else torch.float32 + + # Load tokenizer + tokenizer = AutoTokenizer.from_pretrained(model_path, revision=revision, trust_remote_code=True) + + # Generate prompts of different sizes using a base text + base_text = ( + "The study of artificial intelligence has evolved significantly over the past decades. " + "Machine learning, a subset of AI, focuses on developing algorithms that can learn from data. " + "Deep learning, in turn, uses neural networks with many layers to model complex patterns. " + "Natural language processing enables computers to understand and generate human language. " + "Computer vision allows machines to interpret and analyze visual information from the world. " + "Reinforcement learning trains agents to make decisions by rewarding desired behaviors. 
" + "The field continues to advance rapidly, with new breakthroughs occurring frequently. " + "Applications range from autonomous vehicles to medical diagnosis and beyond. " + "Ethical considerations around AI development have become increasingly important. " + "Researchers work to ensure AI systems are fair, transparent, and beneficial to society. " + ) + + # Repeat base text to get enough tokens + long_text = (base_text * 20)[:8000] # Plenty of text + + def get_prompt_with_tokens(target_tokens: int) -> str: + """Get a prompt with approximately target_tokens tokens.""" + # Binary search for right length + low, high = 1, len(long_text) + while low < high: + mid = (low + high) // 2 + test_prompt = long_text[:mid] + num_tokens = len(tokenizer.encode(test_prompt)) + if num_tokens < target_tokens: + low = mid + 1 + else: + high = mid + return long_text[:low] + + results = [] + + # Load vLLM once + print(f"\n{'='*70}") + print(f"Loading vLLM model: {model_path}") + print(f"{'='*70}") + + compilation_config = CompilationConfig(mode=CompilationMode.NONE) if no_compile else None + llm = LLM( + model=model_path, + revision=revision, + trust_remote_code=True, + gpu_memory_utilization=0.4, + max_model_len=2048, + dtype=dtype, + compilation_config=compilation_config, + ) + + apply_placement(llm, placement) + + # Load Transformers once + print(f"\nLoading Transformers model...") + attn_impl = "flash_attention_2" if dtype == "bfloat16" else "eager" + tf_model = AutoModelForCausalLM.from_pretrained( + model_path, + revision=revision, + torch_dtype=torch_dtype, + device_map="cuda", + trust_remote_code=True, + attn_implementation=attn_impl, + ) + tf_model.eval() + + print(f"\n{'='*70}") + print(f"Running comparisons...") + print(f"{'='*70}\n") + + # Header + print(f"{'Prompt':<8} {'Decode':<8} {'Batch':<8} {'Avg Diff':<12} {'Max Diff':<12} {'Match':<8}") + print("-" * 60) + + for prompt_size in prompt_sizes: + prompt = get_prompt_with_tokens(prompt_size) + actual_tokens = len(tokenizer.encode(prompt)) + + for decode_length in decode_lengths: + for batch_size in batch_sizes: + # Create batch of prompts + prompts = [prompt] * batch_size + + # vLLM inference + sampling_params = SamplingParams( + max_tokens=decode_length, + temperature=0, + logprobs=20, + ) + vllm_outputs = llm.generate(prompts, sampling_params) + + # Transformers inference + input_ids = tokenizer(prompts, return_tensors="pt", padding=True).input_ids.to("cuda") + + with torch.no_grad(): + if decode_length == 1: + # Just get logits for next token prediction + tf_outputs = tf_model(input_ids) + tf_logits = tf_outputs.logits + else: + # Generate multiple tokens + tf_output_ids = tf_model.generate( + input_ids, + max_new_tokens=decode_length, + do_sample=False, + pad_token_id=tokenizer.eos_token_id, + return_dict_in_generate=True, + output_logits=True, + ) + # Stack logits from generation steps + tf_logits = torch.stack(tf_output_ids.logits, dim=1) + + # Compare logprobs for each position and batch element + all_diffs = [] + all_match = True + + for b in range(batch_size): + vllm_out = vllm_outputs[b] + vllm_logprobs_list = vllm_out.outputs[0].logprobs or [] + vllm_token_ids = vllm_out.outputs[0].token_ids + + for pos in range(min(decode_length, len(vllm_logprobs_list))): + vllm_logprobs = vllm_logprobs_list[pos] + + if decode_length == 1: + # For prefill, use last position + tf_pos_logits = tf_logits[b, -1, :] + else: + # For generation, use corresponding position + tf_pos_logits = tf_logits[b, pos, :] + + tf_pos_logprobs = 
torch.log_softmax(tf_pos_logits.float(), dim=-1) + + # Get TF's predicted token + tf_pred_token = tf_pos_logprobs.argmax().item() + vllm_pred_token = vllm_token_ids[pos] if pos < len(vllm_token_ids) else None + + if vllm_pred_token != tf_pred_token: + if all_match: # First mismatch + vllm_lp_for_tok = vllm_logprobs.get(vllm_pred_token, None) + vllm_lp_val = vllm_lp_for_tok.logprob if vllm_lp_for_tok else "N/A" + tf_lp_vllm_tok = tf_pos_logprobs[vllm_pred_token].item() if vllm_pred_token and vllm_pred_token < len(tf_pos_logprobs) else "N/A" + tf_lp_tf_tok = tf_pos_logprobs[tf_pred_token].item() + print(f" FIRST MISMATCH at pos {pos}: vLLM tok={vllm_pred_token} (lp={vllm_lp_val}), TF tok={tf_pred_token} (lp={tf_lp_tf_tok:.4f})") + print(f" TF logprob for vLLM's token: {tf_lp_vllm_tok}") + all_match = False + + # Compare logprobs for common tokens + for tid, lp_info in vllm_logprobs.items(): + if tid < len(tf_pos_logprobs): + vllm_lp = lp_info.logprob + tf_lp = tf_pos_logprobs[tid].item() + diff = abs(vllm_lp - tf_lp) + all_diffs.append(diff) + + avg_diff = sum(all_diffs) / len(all_diffs) if all_diffs else 0 + max_diff = max(all_diffs) if all_diffs else 0 + match_str = "YES" if all_match else "NO" + + result = { + "prompt_size": actual_tokens, + "decode_length": decode_length, + "batch_size": batch_size, + "avg_diff": avg_diff, + "max_diff": max_diff, + "all_match": all_match, + } + results.append(result) + + print(f"{actual_tokens:<8} {decode_length:<8} {batch_size:<8} {avg_diff:<12.6f} {max_diff:<12.6f} {match_str:<8}") + + # Cleanup + del llm + del tf_model + gc.collect() + torch.cuda.empty_cache() + + # Summary + print(f"\n{'='*60}") + print("SUMMARY") + print(f"{'='*60}") + all_avg = sum(r["avg_diff"] for r in results) / len(results) + all_max = max(r["max_diff"] for r in results) + all_matched = all(r["all_match"] for r in results) + print(f"Overall average diff: {all_avg:.6f}") + print(f"Overall max diff: {all_max:.6f}") + print(f"All predictions match: {'YES' if all_matched else 'NO'}") + + return results + + +def cmd_compare(args): + """Run comprehensive comparison across configurations.""" + prompt_sizes = [int(x) for x in args.prompt_sizes.split(",")] + decode_lengths = [int(x) for x in args.decode_lengths.split(",")] + batch_sizes = [int(x) for x in args.batch_sizes.split(",")] + + for model_path in args.model_paths: + compare_comprehensive( + model_path, + prompt_sizes=prompt_sizes, + decode_lengths=decode_lengths, + batch_sizes=batch_sizes, + dtype=args.dtype, + no_compile=args.no_compile, + revision=getattr(args, 'revision', None), + placement=getattr(args, 'placement', None), + ) + + +def cmd_coherence(args): + """Run coherence test.""" + prompts = [ + "The capital of France is", + "To solve this math problem, I need to", + "Once upon a time, there was a", + ] + + placement = getattr(args, 'placement', None) + + print("\n" + "="*70) + print("COHERENCE TEST: vLLM") + print("="*70) + vllm_results = test_coherence_vllm(args.model_paths, prompts, args.max_tokens, placement=placement) + + print("\n" + "="*70) + print("COHERENCE TEST: Transformers") + print("="*70) + tf_results = test_coherence_transformers(args.model_paths, prompts, args.max_tokens) + + # Compare results + print("\n" + "="*70) + print("COMPARISON SUMMARY") + print("="*70) + for model_name in vllm_results: + print(f"\n{model_name}:") + for prompt in prompts: + vllm_out = vllm_results[model_name].get(prompt, "") + tf_out = tf_results[model_name].get(prompt, "") + # Compare first 20 chars + vllm_start = 
vllm_out[:20].strip() + tf_start = tf_out[:20].strip() + match = "MATCH" if vllm_start == tf_start else "DIFF" + print(f" [{match}] {prompt[:30]!r}...") + if match == "DIFF": + print(f" vLLM: {vllm_start!r}...") + print(f" TF: {tf_start!r}...") + + +# ============================================================================ +# Statistical Testing Infrastructure (v2) +# ============================================================================ +# +# Design goals: +# 1. Dataset-based prompts (C4) for reproducibility +# 2. Controlled tokenization - same token IDs to both backends +# 3. Per-position statistics (prefill + each decode step) +# 4. Configurable Transformers kernel selection +# 5. Full parameter space: prompts, prompt_length, decode_length, batch_size, compile, kernels + +from dataclasses import dataclass, field +from itertools import islice + + +@dataclass +class TokenComparison: + """Comparison data for a single token position.""" + prompt_idx: int + position: int # 0 = prefill (last token), 1+ = decode steps + vllm_token_id: int + tf_token_id: int + token_match: bool + avg_logprob_diff: float + max_logprob_diff: float + top_k_diffs: list[float] = field(default_factory=list) + + +def load_and_tokenize_prompts( + num_prompts: int, + prompt_length: int, + tokenizer, + seed: int = 42, +) -> list[list[int]]: + """Load prompts from C4 dataset and tokenize to exact length. + + Streams through shuffled dataset until we find exactly num_prompts + that have at least prompt_length tokens. + + Args: + num_prompts: Number of prompts to collect + prompt_length: Exact number of tokens per prompt + tokenizer: Tokenizer to use + seed: Random seed for shuffling + + Returns: + List of token ID lists, all exactly prompt_length long + """ + from datasets import load_dataset + + print(f"Loading C4 dataset (streaming, seed={seed})...") + dataset = load_dataset('allenai/c4', 'en', split='train', streaming=True) + + # Shuffle with seed for reproducibility + dataset = dataset.shuffle(seed=seed, buffer_size=10000) + + token_ids_list = [] + samples_checked = 0 + + for sample in dataset: + samples_checked += 1 + text = sample['text'] + + # Tokenize and check length + tokens = tokenizer.encode(text, add_special_tokens=False) + if len(tokens) >= prompt_length: + token_ids_list.append(tokens[:prompt_length]) + + if len(token_ids_list) >= num_prompts: + break + + # Progress every 100 samples + if samples_checked % 100 == 0: + print(f" Checked {samples_checked} samples, found {len(token_ids_list)}/{num_prompts} valid prompts", end="\r") + + print(f" Checked {samples_checked} samples, found {len(token_ids_list)}/{num_prompts} valid prompts") + + if len(token_ids_list) < num_prompts: + print(f" Warning: Only found {len(token_ids_list)} prompts with >= {prompt_length} tokens") + + return token_ids_list + + +def set_transformers_kernels(model_path: str, kernel_config: str) -> None: + """Set kernel configuration in Transformers modeling file. 
+ + Args: + model_path: Path to model (to find modeling file) + kernel_config: 'upstream' or 'vllm' + """ + import importlib + import sys + + # The Transformers model uses the local modeling_apriel2.py in the checkpoint + # We need to modify its flags before loading + modeling_path = Path(model_path) / "modeling_apriel2.py" + if not modeling_path.exists(): + print(f" Warning: No modeling_apriel2.py found at {model_path}") + return + + # Read the file + content = modeling_path.read_text() + + # Set the flags based on kernel_config + if kernel_config == "upstream": + new_content = content.replace("USE_VLLM_CONV = True", "USE_VLLM_CONV = False") + new_content = new_content.replace("USE_VLLM_GDN_OPS = True", "USE_VLLM_GDN_OPS = False") + new_content = new_content.replace("USE_VLLM_GATED_NORM = True", "USE_VLLM_GATED_NORM = False") + elif kernel_config == "vllm": + new_content = content.replace("USE_VLLM_CONV = False", "USE_VLLM_CONV = True") + new_content = new_content.replace("USE_VLLM_GDN_OPS = False", "USE_VLLM_GDN_OPS = True") + new_content = new_content.replace("USE_VLLM_GATED_NORM = False", "USE_VLLM_GATED_NORM = True") + else: + raise ValueError(f"Unknown kernel_config: {kernel_config}") + + if new_content != content: + modeling_path.write_text(new_content) + print(f" Set Transformers kernels to: {kernel_config}") + + # Clear any cached imports + modules_to_remove = [k for k in sys.modules if 'apriel2' in k.lower()] + for mod in modules_to_remove: + del sys.modules[mod] + + +def run_vllm_inference( + model_path: str, + token_ids_list: list[list[int]], + decode_length: int, + batch_size: int, + dtype: str, + no_compile: bool, + revision: str | None, + placement: str | None = None, +) -> tuple[list[list[int]], list[list[dict]]]: + """Run vLLM inference and return generated tokens and logprobs. 
+ + Returns: + - generated_tokens: list of list of token IDs (one list per prompt) + - logprobs_per_position: list of list of logprob dicts (one list per prompt) + """ + from vllm import TokensPrompt + + compile_label = "no-compile" if no_compile else "compiled" + print(f"\nLoading vLLM model ({compile_label}, batch_size={batch_size})...") + compilation_config = CompilationConfig(mode=CompilationMode.NONE) if no_compile else None + + llm = LLM( + model=model_path, + revision=revision, + trust_remote_code=True, + gpu_memory_utilization=0.4, + max_model_len=2048, + dtype=dtype, + compilation_config=compilation_config, + max_num_seqs=batch_size, # Control max concurrent sequences + enable_prefix_caching=False, # Disable for hybrid models + ) + + apply_placement(llm, placement) + + # Create TokensPrompt for each prompt + vllm_prompts = [TokensPrompt(prompt_token_ids=ids) for ids in token_ids_list] + + sampling_params = SamplingParams( + max_tokens=decode_length, + temperature=0, + logprobs=20, + ) + + print(f"Running vLLM inference on {len(vllm_prompts)} prompts (decode_length={decode_length})...") + outputs = llm.generate(vllm_prompts, sampling_params) + + # Extract results + generated_tokens = [] + logprobs_per_position = [] + + for output in outputs: + tokens = list(output.outputs[0].token_ids) if output.outputs[0].token_ids else [] + generated_tokens.append(tokens) + + # Logprobs for each position + lps = [] + if output.outputs[0].logprobs: + for pos_lps in output.outputs[0].logprobs: + lps.append(pos_lps if pos_lps else {}) + logprobs_per_position.append(lps) + + del llm + gc.collect() + torch.cuda.empty_cache() + + return generated_tokens, logprobs_per_position + + +def run_transformers_inference( + model_path: str, + token_ids_list: list[list[int]], + decode_length: int, + batch_size: int, + dtype: str, + revision: str | None, +) -> tuple[list[list[int]], list[list[torch.Tensor]]]: + """Run Transformers inference and return generated tokens and logprobs. + + Args: + batch_size: Number of prompts to process together. For generation, + each prompt still decodes sequentially within the batch. 
+ + Returns: + - generated_tokens: list of list of token IDs (one list per prompt) + - logprobs_per_position: list of list of logprob tensors (one list per prompt) + """ + from transformers import AutoModelForCausalLM + + torch_dtype = torch.bfloat16 if dtype == "bfloat16" else torch.float32 + attn_impl = "flash_attention_2" if dtype == "bfloat16" else "eager" + + print(f"\nLoading Transformers model ({attn_impl}, batch_size={batch_size})...") + model = AutoModelForCausalLM.from_pretrained( + model_path, + revision=revision, + torch_dtype=torch_dtype, + device_map="cuda", + trust_remote_code=True, + attn_implementation=attn_impl, + ) + model.eval() + + generated_tokens = [] + logprobs_per_position = [] + + print(f"Running Transformers inference on {len(token_ids_list)} prompts...") + + # Process in batches + for batch_start in range(0, len(token_ids_list), batch_size): + batch_end = min(batch_start + batch_size, len(token_ids_list)) + batch_token_ids = token_ids_list[batch_start:batch_end] + actual_batch_size = len(batch_token_ids) + + # For batch_size > 1, we need to handle each prompt separately for generation + # because sequences grow at different rates and we need per-token logprobs + for i, token_ids in enumerate(batch_token_ids): + input_ids = torch.tensor([token_ids], device="cuda") + prompt_tokens = [] + prompt_logprobs = [] + + with torch.no_grad(): + # Generate tokens one at a time to get logprobs at each step + for step in range(decode_length): + outputs = model(input_ids) + logits = outputs.logits[:, -1, :] # Last position + logprobs = torch.log_softmax(logits.float(), dim=-1).cpu() + + next_token = logits.argmax(dim=-1).item() + prompt_tokens.append(next_token) + prompt_logprobs.append(logprobs[0]) + + # Append to input for next step + input_ids = torch.cat([input_ids, torch.tensor([[next_token]], device="cuda")], dim=1) + + generated_tokens.append(prompt_tokens) + logprobs_per_position.append(prompt_logprobs) + + processed = batch_end + if processed % 10 == 0 or processed == len(token_ids_list): + print(f" Processed {processed}/{len(token_ids_list)} prompts", end="\r") + + print() + + del model + torch.cuda.empty_cache() + + return generated_tokens, logprobs_per_position + + +def compute_comparisons( + vllm_tokens: list[list[int]], + vllm_logprobs: list[list[dict]], + tf_tokens: list[list[int]], + tf_logprobs: list[list[torch.Tensor]], +) -> list[TokenComparison]: + """Compute per-position comparisons between vLLM and Transformers.""" + comparisons = [] + + for prompt_idx, (vt, vl, tt, tl) in enumerate(zip(vllm_tokens, vllm_logprobs, tf_tokens, tf_logprobs)): + # Compare each position + for pos in range(min(len(vt), len(tt), len(vl), len(tl))): + vllm_token = vt[pos] + tf_token = tt[pos] + vllm_lps = vl[pos] + tf_lps = tl[pos] + + # Compute logprob differences for top-K tokens + diffs = [] + if vllm_lps: + for tid, lp_info in list(vllm_lps.items())[:20]: + vllm_lp = lp_info.logprob + tf_lp = tf_lps[tid].item() + diffs.append(abs(vllm_lp - tf_lp)) + + avg_diff = sum(diffs) / len(diffs) if diffs else 0.0 + max_diff = max(diffs) if diffs else 0.0 + + comparisons.append(TokenComparison( + prompt_idx=prompt_idx, + position=pos, + vllm_token_id=vllm_token, + tf_token_id=tf_token, + token_match=(vllm_token == tf_token), + avg_logprob_diff=avg_diff, + max_logprob_diff=max_diff, + top_k_diffs=diffs, + )) + + return comparisons + + +def print_stats_report(comparisons: list[TokenComparison], title: str = "Statistics"): + """Print comprehensive statistics from comparisons.""" + 
print(f"\n{'='*70}") + print(f" {title}") + print(f"{'='*70}") + + if not comparisons: + print("No comparisons to report.") + return {} + + # Group by position + by_position: dict[int, list[TokenComparison]] = {} + for c in comparisons: + by_position.setdefault(c.position, []).append(c) + + # Overall stats + all_avg_diffs = np.array([c.avg_logprob_diff for c in comparisons]) + all_max_diffs = np.array([c.max_logprob_diff for c in comparisons]) + all_matches = np.array([c.token_match for c in comparisons]) + + n_total = len(comparisons) + n_prompts = len(set(c.prompt_idx for c in comparisons)) + n_positions = len(by_position) + + print(f"\nTotal comparisons: {n_total} ({n_prompts} prompts x {n_positions} positions)") + print(f"Token match rate: {all_matches.sum()}/{n_total} ({100*all_matches.mean():.1f}%)") + + # Per-position stats + print(f"\n--- Per-Position Statistics ---") + print(f"{'Pos':>4} {'N':>6} {'Match%':>8} {'AvgDiff':>10} {'p50':>8} {'p95':>8} {'Max':>8}") + print("-" * 60) + + position_stats = {} + for pos in sorted(by_position.keys()): + pos_comparisons = by_position[pos] + pos_diffs = np.array([c.avg_logprob_diff for c in pos_comparisons]) + pos_matches = np.array([c.token_match for c in pos_comparisons]) + + stats = { + "n": len(pos_comparisons), + "match_rate": pos_matches.mean(), + "avg_diff_mean": pos_diffs.mean(), + "avg_diff_p50": np.percentile(pos_diffs, 50), + "avg_diff_p95": np.percentile(pos_diffs, 95), + "avg_diff_max": pos_diffs.max(), + } + position_stats[pos] = stats + + pos_label = "prefill" if pos == 0 else f"decode{pos}" + print(f"{pos_label:>4} {stats['n']:>6} {100*stats['match_rate']:>7.1f}% " + f"{stats['avg_diff_mean']:>10.4f} {stats['avg_diff_p50']:>8.4f} " + f"{stats['avg_diff_p95']:>8.4f} {stats['avg_diff_max']:>8.4f}") + + # Overall distribution + print(f"\n--- Overall Avg Logprob Diff Distribution ---") + print(f" Mean: {all_avg_diffs.mean():.6f}") + print(f" Std: {all_avg_diffs.std():.6f}") + print(f" p10: {np.percentile(all_avg_diffs, 10):.6f}") + print(f" p50: {np.percentile(all_avg_diffs, 50):.6f}") + print(f" p90: {np.percentile(all_avg_diffs, 90):.6f}") + print(f" p95: {np.percentile(all_avg_diffs, 95):.6f}") + print(f" p99: {np.percentile(all_avg_diffs, 99):.6f}") + print(f" Max: {all_avg_diffs.max():.6f}") + + # Outliers + outlier_threshold = 1.0 + outliers = [c for c in comparisons if c.avg_logprob_diff > outlier_threshold] + if outliers: + print(f"\n--- Outliers (avg diff > {outlier_threshold}) ---") + print(f" Count: {len(outliers)} ({100*len(outliers)/n_total:.1f}%)") + # Show by position + outlier_positions = {} + for o in outliers: + outlier_positions.setdefault(o.position, []).append(o) + for pos, pos_outliers in sorted(outlier_positions.items()): + pos_label = "prefill" if pos == 0 else f"decode{pos}" + print(f" Position {pos_label}: {len(pos_outliers)} outliers") + + return { + "n_total": n_total, + "n_prompts": n_prompts, + "n_positions": n_positions, + "match_rate": all_matches.mean(), + "avg_diff_mean": all_avg_diffs.mean(), + "avg_diff_p50": np.percentile(all_avg_diffs, 50), + "avg_diff_p95": np.percentile(all_avg_diffs, 95), + "avg_diff_max": all_avg_diffs.max(), + "n_outliers": len(outliers), + "position_stats": position_stats, + } + + +def cmd_stats(args): + """Run statistical comparison with many prompts.""" + from transformers import AutoTokenizer + + setup_transformers() + + for model_path in args.model_paths: + print(f"\n{'#'*70}") + print(f"# Statistical Comparison: {Path(model_path).name}") + print(f"# Prompts: 
{args.num_prompts}, prompt_length: {args.prompt_length}, decode_length: {args.decode_length}") + print(f"# Mode: {'no-compile' if args.no_compile else 'compiled'}, TF kernels: {args.tf_kernels}") + print(f"{'#'*70}") + + revision = getattr(args, 'revision', None) + + # Set Transformers kernel configuration + set_transformers_kernels(model_path, args.tf_kernels) + + # Load tokenizer + tokenizer = AutoTokenizer.from_pretrained(model_path, revision=revision, trust_remote_code=True) + + # Load and tokenize prompts from dataset + print(f"\nLoading {args.num_prompts} prompts from C4 (exactly {args.prompt_length} tokens each)...") + token_ids_list = load_and_tokenize_prompts( + args.num_prompts, + args.prompt_length, + tokenizer, + seed=args.seed, + ) + print(f" Prepared {len(token_ids_list)} token sequences") + + # Run vLLM inference + placement = getattr(args, 'placement', None) + vllm_tokens, vllm_logprobs = run_vllm_inference( + model_path, token_ids_list, args.decode_length, + args.batch_size, args.dtype, args.no_compile, revision, placement + ) + + # Run Transformers inference + tf_tokens, tf_logprobs = run_transformers_inference( + model_path, token_ids_list, args.decode_length, + args.batch_size, args.dtype, revision + ) + + # Compute comparisons + comparisons = compute_comparisons(vllm_tokens, vllm_logprobs, tf_tokens, tf_logprobs) + + # Print statistics + mode_label = "no-compile" if args.no_compile else "compiled" + stats = print_stats_report( + comparisons, + f"Results ({mode_label}, TF={args.tf_kernels})" + ) + + print(f"\n{'='*70}") + print(f" SUMMARY: {Path(model_path).name}") + print(f"{'='*70}") + print(f" Mode: {mode_label}") + print(f" TF kernels: {args.tf_kernels}") + print(f" Batch size: {args.batch_size}") + print(f" Dtype: {args.dtype}") + if revision: + print(f" Revision: {revision}") + print(f" Token match rate: {100*stats['match_rate']:.1f}%") + print(f" Avg diff (mean): {stats['avg_diff_mean']:.4f}") + print(f" Avg diff (p95): {stats['avg_diff_p95']:.4f}") + print(f" Avg diff (max): {stats['avg_diff_max']:.4f}") + if stats['n_outliers'] > 0: + print(f" WARNING: {stats['n_outliers']} outliers detected (avg diff > 1.0)") + print() + + +def cmd_logits(args): + """Run logits comparison test.""" + revision = getattr(args, 'revision', None) + debug_gdn = getattr(args, 'debug_gdn', False) + placement = getattr(args, 'placement', None) + for model_path in args.model_paths: + compare_logits(model_path, args.prompt, args.max_tokens, args.dtype, args.no_compile, revision, debug_gdn, placement) + + +def cmd_all(args): + """Run all tests.""" + cmd_coherence(args) + print("\n\n") + cmd_logits(args) + + +def main(): + parser = argparse.ArgumentParser( + description="Test Apriel2 vLLM implementation", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__, + ) + subparsers = parser.add_subparsers(dest="command", required=True) + + # Placement help text (used by multiple subcommands) + placement_help = ( + "Mixer placement: 'all-attention', 'all-gdn', 'every2nd-gdn', 'every3rd-gdn', " + "or comma-separated list like 'attention,gdn,attention,gdn,...'" + ) + + # Coherence test + p_coherence = subparsers.add_parser("coherence", help="Test generation coherence") + p_coherence.add_argument("model_paths", nargs="+", help="Path(s) to model checkpoint(s)") + p_coherence.add_argument("--max-tokens", type=int, default=50, help="Max tokens to generate") + p_coherence.add_argument("--placement", default=None, help=placement_help) + p_coherence.set_defaults(func=cmd_coherence) + + 
# Logits comparison + p_logits = subparsers.add_parser("logits", help="Compare logits between vLLM and Transformers") + p_logits.add_argument("model_paths", nargs="+", help="Path(s) to model checkpoint(s)") + p_logits.add_argument("--prompt", default="The capital of France is", help="Input prompt") + p_logits.add_argument("--max-tokens", type=int, default=1, help="Max tokens to generate") + p_logits.add_argument("--dtype", choices=["bfloat16", "float32"], default="bfloat16", help="Data type") + p_logits.add_argument("--no-compile", action="store_true", help="Disable torch.compile") + p_logits.add_argument("--revision", default=None, help="Model revision") + p_logits.add_argument("--debug-gdn", action="store_true", help="Enable GDN debug output") + p_logits.add_argument("--placement", default=None, help=placement_help) + p_logits.set_defaults(func=cmd_logits) + + # Comprehensive comparison + p_compare = subparsers.add_parser("compare", help="Compare across prompt sizes, decode lengths, and batch sizes") + p_compare.add_argument("model_paths", nargs="+", help="Path(s) to model checkpoint(s)") + p_compare.add_argument("--prompt-sizes", default="5,50,200", help="Comma-separated prompt sizes in tokens") + p_compare.add_argument("--decode-lengths", default="1,5,10", help="Comma-separated decode lengths") + p_compare.add_argument("--batch-sizes", default="1,2,4", help="Comma-separated batch sizes") + p_compare.add_argument("--dtype", choices=["bfloat16", "float32"], default="bfloat16", help="Data type") + p_compare.add_argument("--no-compile", action="store_true", help="Disable torch.compile (default: compile enabled)") + p_compare.add_argument("--revision", default=None, help="Model revision") + p_compare.add_argument("--placement", default=None, help=placement_help) + p_compare.set_defaults(func=cmd_compare) + + # Statistical comparison + p_stats = subparsers.add_parser("stats", help="Statistical comparison with many prompts (per-position analysis)") + p_stats.add_argument("model_paths", nargs="+", help="Path(s) to model checkpoint(s)") + p_stats.add_argument("--num-prompts", type=int, default=64, help="Number of prompts to test") + p_stats.add_argument("--prompt-length", type=int, default=256, help="Number of tokens to prefill") + p_stats.add_argument("--decode-length", type=int, default=10, help="Number of tokens to decode") + p_stats.add_argument("--batch-size", type=int, default=1, help="Batch size for inference") + p_stats.add_argument("--dtype", choices=["bfloat16", "float32"], default="bfloat16", help="Data type") + p_stats.add_argument("--no-compile", action="store_true", help="Disable torch.compile (default: compile enabled)") + p_stats.add_argument("--tf-kernels", choices=["upstream", "vllm"], default="upstream", + help="Transformers kernel config: upstream FLA or vLLM forks") + p_stats.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility") + p_stats.add_argument("--revision", default=None, help="Model revision") + p_stats.add_argument("--placement", default=None, help=placement_help) + p_stats.set_defaults(func=cmd_stats) + + # All tests + p_all = subparsers.add_parser("all", help="Run all tests") + p_all.add_argument("model_paths", nargs="+", help="Path(s) to model checkpoint(s)") + p_all.add_argument("--prompt", default="The capital of France is", help="Input prompt for logits test") + p_all.add_argument("--max-tokens", type=int, default=50, help="Max tokens for coherence test") + p_all.add_argument("--placement", default=None, help=placement_help) + 
p_all.set_defaults(func=cmd_all) + + args = parser.parse_args() + args.func(args) + + +if __name__ == "__main__": + main() diff --git a/setup.cfg b/setup.cfg index 005ae5a8a..e8b3f5b99 100644 --- a/setup.cfg +++ b/setup.cfg @@ -95,3 +95,7 @@ DOCS = [options.entry_points] console_scripts = fast-llm = fast_llm.cli:fast_llm_main + +# vLLM plugin for Apriel2 model registration +vllm.general_plugins = + apriel2 = fast_llm_external_models.apriel2.vllm.config_convertor:register diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 5e7526377..cdb7f5189 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -951,6 +951,7 @@ def update_and_add_testing_config( # TP excluded because no gradient reductions implemented for TP norm in GDN (use STP instead). skip_tests=("sdp", "ms", GRAD_ACC, TP_NO_STP), requires_cuda=True, + auto_model_class=transformers.AutoModelForImageTextToText, )
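
For reference, the pattern resolution performed by get_block_config_for_layer in the vLLM modeling file can be illustrated with a small made-up decoder config (the block names and layer indices below are examples for illustration, not values taken from this PR):

# Illustration only: block names and layer indices are invented for the example.
decoder = {
    "type": "pattern",
    "pattern": ["attn_block", "gdn_block"],   # repeated cyclically over layers
    "blocks": {
        "attn_block": {"mixer": {"type": "attention"}},
        "gdn_block": {"mixer": {"type": "gdn"}},
    },
}
# get_block_config_for_layer picks pattern[layer_idx % len(pattern)]:
#   layer 0 -> "attn_block" -> mixer type "attention"
#   layer 1 -> "gdn_block"  -> mixer type "gdn"
#   layer 5 -> pattern[5 % 2] == "gdn_block" -> mixer type "gdn"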
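
The placement-switching hooks (set_layer_placements / get_layer_placements plus the worker monkey-patch) are exercised through test_apriel2.py; a minimal standalone sketch of the same workflow, assuming a checkpoint in which every layer is a stochastic mixer exposing both "attention" and "gdn" sub-mixers (the model path below is a placeholder), looks like:

# Minimal sketch; the model path is a placeholder, and — as in apply_placement()
# in test_apriel2.py — all layers are assumed to be stochastic mixers.
from vllm import LLM, SamplingParams

llm = LLM(model="/path/to/apriel2-checkpoint", trust_remote_code=True, max_model_len=2048)

# One dict per worker, mapping layer index -> active mixer name.
placements = llm.collective_rpc("get_layer_placements")
num_layers = len(placements[0])

# Same semantics as parse_placement("every2nd-gdn", num_layers) in test_apriel2.py.
new_placement = ["gdn" if (i + 1) % 2 == 0 else "attention" for i in range(num_layers)]
llm.collective_rpc("set_layer_placements", args=(new_placement,))

outputs = llm.generate(["The capital of France is"], SamplingParams(max_tokens=16, temperature=0))
print(outputs[0].outputs[0].text)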
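
The vllm.general_plugins entry point added to setup.cfg points at fast_llm_external_models.apriel2.vllm.config_convertor:register, a module that is not part of this diff. As an illustration only, such a hook typically reduces to a model-registry call along these lines (the architecture name and the body are assumptions, not the actual implementation):

# Assumption: illustrative sketch of a vLLM general-plugin hook; the real
# config_convertor.register() is not shown in this diff.
def register():
    from vllm import ModelRegistry

    if "Apriel2ForCausalLM" not in ModelRegistry.get_supported_archs():
        ModelRegistry.register_model(
            "Apriel2ForCausalLM",
            "fast_llm_external_models.apriel2.vllm.modeling_apriel2:Apriel2ForCausalLM",
        )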