[Prototype] Add vLLM Apriel2 model with plugin-based registration #447
Draft: tscholak wants to merge 32 commits into main from feature/vllm-apriel2-models
Conversation
- Add README.md documenting the algebraic structure of the conversion system (surgery monoid, action law, plan composition, total vs partial operations)
- Add prune_supernet_step1.yaml and prune_supernet_step2.yaml examples demonstrating the two-step workflow for pruning a homogeneous supernet to a heterogeneous network with different mixer types per layer

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Add modeling_apriel2.py with a full vLLM-optimized implementation supporting attention, mamba, GDN, and KDA mixer types
- Add register() function for runtime model registration via vLLM's ModelRegistry (no patching required)
- Based on Nanda's vllm_diff.patch, adapted for external package use

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
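A minimal sketch of how plugin-based registration through vLLM's ModelRegistry can look; the architecture name and module path below are assumptions for illustration, not necessarily the values used in this PR.

```python
def register() -> None:
    """Register the Apriel2 architecture with vLLM at runtime (sketch)."""
    from vllm import ModelRegistry

    # The "module:Class" string form lets vLLM import the model class lazily.
    # Architecture name and module path are assumed for illustration.
    ModelRegistry.register_model(
        "Apriel2ForCausalLM",
        "fast_llm_external_models.apriel2.vllm.modeling_apriel2:Apriel2ForCausalLM",
    )
```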
- Refactor weight loading: each mixer module (Attention, MLP, GDN, KDA) now handles its own weight structure via load_weights() methods
- Fix KDA mamba_type to use "gdn_attention" for vLLM backend registration
- Add KDA op registration import for custom op support
- Remove unused positions parameter from KDA forward
- Add config_convertor.py for Apriel2TextConfig to vLLM config mapping
- Add test_apriel2.py for coherence and logit comparison testing between vLLM and Transformers implementations

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Remove all PyTorch fallback implementations to ensure fast CUDA kernels are always used. The module now fails loudly at import/instantiation if required kernels are missing.

Changes:
- Remove torch_causal_conv1d_fn and torch_causal_conv1d_update fallbacks
- Remove torch_selective_scan_fn and torch_selective_state_update stubs
- Remove torch_chunk_gated_delta_rule function
- Remove _recurrent_gated_delta_rule method from Apriel2GatedDeltaNet
- Remove _forward_local method from GatedRMSNormalization
- Remove TestFastVsSlowPath test class (no longer needed)
- Handle CausalConv1d seq_len==1 edge case via update() instead of fallback
- Add ImportError at module load for missing causal_conv1d/mamba_ssm
- Add ImportError at class init for missing FLA kernels

Required packages:
- causal_conv1d (for CausalConv1d)
- mamba_ssm (for Mamba/SSM operations)
- fla (for GDN, KDA, GatedRMSNormalization)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
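A minimal sketch of the fail-loud import pattern described in the commit above, using the standard causal_conv1d entry points; the exact guard and error text in the PR may differ.

```python
try:
    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
except ImportError as e:
    # Fail loudly at import time instead of silently falling back to slow PyTorch code.
    raise ImportError(
        "Apriel2 requires the causal_conv1d CUDA kernels; install the "
        "causal_conv1d package (no PyTorch fallback is provided)."
    ) from e
```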
The chunk_gated_delta_rule call was always passing initial_state=None, ignoring any existing recurrent state from previous decode cycles. This broke continued generation scenarios (prefill -> decode -> prefill). Changed initial_state=None to initial_state=recurrent_state to match the correct behavior already present in KDA's chunk_kda call. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
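An illustrative sketch of the corrected call, assuming FLA's chunk_gated_delta_rule keyword arguments and the tensor names q, k, v, g, beta; the actual code in the PR may differ.

```python
from fla.ops.gated_delta_rule import chunk_gated_delta_rule  # assumed import path


def gdn_chunk_step(q, k, v, g, beta, recurrent_state):
    # Pass the cached recurrent state instead of the previous hard-coded None,
    # so prefill -> decode -> prefill continues from the correct state.
    core_out, recurrent_state = chunk_gated_delta_rule(
        q, k, v, g, beta,
        initial_state=recurrent_state,
        output_final_state=True,
    )
    return core_out, recurrent_state
```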
Add test_vs_qwen3next_with_cache and test_vs_fla_with_cache tests that verify mixer implementations through all inference phases:
- Phase 1: Initial prefill with cache population
- Phase 2: Single-token decode using cached states
- Phase 3: Prefill again (decode→prefill transition)

Tests compare outputs and recurrent states at each phase. Convolution states are not compared due to different storage formats between implementations (Apriel2 stores kernel_size-1, references store kernel_size).

For GDN, Phase 3 documents expected divergence from Qwen3Next due to its bug where chunk mode ignores initial_state. For KDA, all phases should match since FLA correctly passes initial_state in chunk mode.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Merge test_vs_qwen3next and test_vs_qwen3next_with_cache into a single parameterized test with a use_cache fixture
- Merge test_vs_fla and test_vs_fla_with_cache similarly
- Add use_cache (False/True) and decode_steps (4) fixtures
- Use proper Apriel2Cache from cache.py instead of ad-hoc SimpleCache
- Use the same total sequence length for both cache and non-cache modes
- Skip cache tests when seq_len < decode_steps + 2 (too small for 3 phases)
- Split sequence as: prefill=2/3, decode=4, prefill2=1/3 of remaining

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Fix KDA mode selection to match FLA: use fused_recurrent only when seq_len <= 64 AND not training (single expression instead of override)
- Replace use_cache fixture with explicit phase fixtures (prefill_len, decode_steps, prefill2_len) for clearer test parameterization
- Update test_chunked_vs_recurrent to use Apriel2Cache and fixtures
- Rename config_dict to mixer_config for consistency across all tests
- Remove unused qwen3_config fixture (recreated inline where needed)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
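A hypothetical sketch of the phase fixtures named in the commit above; the parameter values are placeholders, not the ones used in the test suite.

```python
import pytest


@pytest.fixture(params=[16, 48])
def prefill_len(request):
    return request.param


@pytest.fixture(params=[4])
def decode_steps(request):
    return request.param


@pytest.fixture(params=[8])
def prefill2_len(request):
    return request.param
```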
CausalConv1d is now tested through KDA equivalence tests which use CausalConv1d for q_conv, k_conv, v_conv. The isolated tests were also obsolete since CPU fallback was removed. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Move all cache classes (_AttentionCache, _SSMCache, _DummyCacheLayer, Apriel2Cache, _LayerListAccessor) into modeling_apriel2.py for better tooling compatibility, since modeling code is expected to live together in one file.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Enable "fast" mode (bf16/sdpa) tests that were previously skipped
- Add test_dtype fixture parameter to all tests that create models
- Convert models to the correct dtype with .to(device="cuda", dtype=test_dtype)
- Create input tensors with an explicit dtype parameter
- Fix assert_close to cast tensors to the same dtype before comparison

All 1718 mixer equivalence tests now pass in both fp32 and bf16 modes.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
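A sketch of the dtype-robust comparison mentioned above: both tensors are cast to a common dtype before assert_close. The helper name and tolerances are placeholders.

```python
import torch


def assert_close_common_dtype(actual, expected, rtol=1e-3, atol=1e-3):
    # Cast to a shared dtype so fp32-vs-bf16 comparisons do not fail on dtype mismatch.
    common = torch.promote_types(actual.dtype, expected.dtype)
    torch.testing.assert_close(
        actual.to(common), expected.to(common), rtol=rtol, atol=atol
    )
```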
- Add gdn_mixer_config and kda_mixer_config fixtures to centralize mixer config dict construction (eliminates 6 duplicate dicts)
- Add kda_hidden_size fixture for derived hidden_size calculation
- Add make_apriel2_config() helper for minimal Apriel2TextConfig construction (eliminates 4 duplicate config blocks)
- Update all GDN and KDA tests to use the new fixtures
- Consolidate duplicate imports within test methods

Net reduction: 47 lines (-125/+78)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Fix rope_theta parameter: use the 'rope_theta' key instead of 'base' in the get_rope() call. This fixes attention alignment (0.002 fp32 / 0.05 bf16)
- Switch GDN from qwen3_fused_gdn_gating to fused_gdn_gating
- Add commented-out GQA head expansion code for GDN (WIP)
- Add dtype parameter to test_apriel2.py for bf16/fp32 comparison
- Use flash_attention_2 for bf16 transformers to match the vLLM backend

Current alignment status:
- attn-swa: ✅ MATCH (0.002 fp32 / 0.05 bf16)
- KDA: ✅ MATCH (0.003 fp32 / 0.07 bf16)
- GDN: ❌ MISMATCH (14.6 - investigation ongoing)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Base automatically changed from fix/require-cuda-kernels-no-fallbacks to oo/apriel_modeling_bug on January 19, 2026, 14:42
…sigmoid

The vLLM KDA implementation was hardcoding activation="sigmoid" for the output normalization, while the transformers implementation defaults to "silu" when not specified in config. This caused significant logprob differences (avg 1.1) between vLLM and transformers.

Now reads norm_activation from mixer_config.normalization.activation with default "silu" to match transformers behavior.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
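A minimal sketch of the config lookup described above, assuming the mixer config is a plain dict; the real config object in the PR may expose attributes instead.

```python
def resolve_norm_activation(mixer_config: dict) -> str:
    # Default to "silu" so vLLM matches the transformers behavior when
    # normalization.activation is not specified in the config.
    return mixer_config.get("normalization", {}).get("activation", "silu")
```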
Changes to transformers model (modeling_apriel2.py):
- Add USE_VLLM_CONV, USE_VLLM_GDN_OPS, USE_VLLM_GATED_NORM flags
- Restructure kernel imports to use vLLM ops when flags enabled
- Add _debug_enabled, _debug_layer, _debug_final flags for debugging
- Handle vLLM vs FLA signature differences for fused_recurrent_gated_delta_rule

Changes to vLLM model (vllm/modeling_apriel2.py):
- Add _debug_enabled, _debug_layer flags for GDN mixer
- Add _debug_final, _debug_lm_head flags for final norm and LM head
- Gate debug prints with boolean flags instead of num_tokens checks

Changes to test script (vllm/test_apriel2.py):
- Add comprehensive comparison command for vLLM vs TF logprob testing
- Test across prompt sizes, decode lengths, and batch sizes

Results: Prefill logprobs now match perfectly between vLLM and TF when using vLLM kernels (USE_VLLM_GDN_OPS=True, USE_VLLM_GATED_NORM=True). Some divergence remains during multi-token decode for certain prompt lengths.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Add _debug_state flag and _debug_state_stats() method to both TF and vLLM GDN mixer classes to track recurrent state evolution during prefill and decode phases.

Key additions:
- TF: Debug state after prefill and during decode for layer 1
- vLLM: Debug state with correct slot indexing for decode phase
- Print state statistics (mean, std, min, max, first8 values)

This helps investigate the decode divergence at specific prompt lengths (50, 51, 59, 60, 70 tokens) where vLLM and TF produce different results.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Add pure_gdn_step1.yaml: converts fixed -> pattern with all GDN blocks
- Add pure_gdn_step2.yaml: unwraps stochastic -> pure GDN mixer
- Improve TF GDN debug logging with try/except for tensor access
- Add vLLM GDN debug output logging during decode phase
- Add first-mismatch details in test_apriel2.py compare output

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Replace scattered class-level and function-local debug flags with top-level DEBUG_* constants for easier control:
- DEBUG_GDN_LAYER: GDN layer forward pass (tensors, shapes)
- DEBUG_GDN_STATE: GDN recurrent state during decode
- DEBUG_GDN_OUTPUT: GDN output hidden states during decode
- DEBUG_KDA_LAYER: KDA layer outputs
- DEBUG_DECODER_LAYER: Decoder layer outputs (residual, norm)
- DEBUG_FINAL_NORM: Final norm before LM head
- DEBUG_LM_HEAD: LM head input/output

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Always call repeat_interleave for K→V head expansion (no-op when value_heads_per_key == 1) to avoid conditional branches that confuse torch.compile's shape inference. Also temporarily comment out compilation_config in test script while investigating hybrid model compilation issues. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
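A sketch of the branch-free head expansion described above; the (batch, heads, seq, head_dim) layout and helper name are assumptions for illustration.

```python
import torch


def expand_key_heads(k: torch.Tensor, value_heads_per_key: int) -> torch.Tensor:
    # Always call repeat_interleave: with repeats=1 it simply copies the tensor,
    # and avoiding the conditional keeps torch.compile's shape inference simple.
    return k.repeat_interleave(value_heads_per_key, dim=1)
```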
Also keep USE_VLLM_* flags at False for upstream kernel testing. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Change AttentionDecoderLayer.forward signature: move positions to an optional kwarg
- All layers now accept (hidden_states, residual, positions=None, **kwargs)
- Remove isinstance dispatch in the Apriel2Model.forward loop
- Call all layer types uniformly with the same arguments

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
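A sketch of the uniform layer interface described above; the class name is a placeholder, and the real layer classes carry attention/mixer state not shown here.

```python
import torch


class Apriel2DecoderLayerBase(torch.nn.Module):
    def forward(self, hidden_states, residual, positions=None, **kwargs):
        # Every layer type shares this signature, so the model loop can call
        # layer(hidden_states, residual, positions=positions) without isinstance checks.
        raise NotImplementedError
```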
Match Llama's approach: use torch._check to assert relationship between positions and input_ids sizes without hardcoding values. This helps the compiler understand dynamic shapes during chunked prefill warmup. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
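A minimal sketch of the dynamic-shape hint described above, assuming input_ids and positions share the flattened token dimension; the exact condition asserted in the PR may differ.

```python
import torch


def hint_token_count(input_ids: torch.Tensor, positions: torch.Tensor) -> None:
    # torch._check records the size relationship symbolically (no hardcoded value),
    # which helps the compiler reason about shapes during chunked-prefill warmup.
    torch._check(input_ids.size(0) == positions.size(0))
```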
Summary
Adds a vLLM-optimized Apriel2 model implementation to fast_llm_external_models, using ModelRegistry.register_model() for runtime registration (no vLLM patching required).
Attribution
Model implementation based on work by @nandahkrishna from the apriel2-vllm branch. This PR adapts that implementation for plugin-based registration as an alternative to patching vLLM directly.
Goal
Evaluate whether vLLM's plugin/registration mechanism can work for us as a short-term solution, avoiding the need to maintain a patched vLLM fork.
Usage
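A hypothetical usage sketch based on the summary above; the register() import path and checkpoint path are assumptions, so check the package for the actual entry point.

```python
from fast_llm_external_models.apriel2.vllm import register  # assumed import path

register()  # registers Apriel2 with vLLM's ModelRegistry at runtime

from vllm import LLM, SamplingParams

llm = LLM(model="/path/to/apriel2-checkpoint", trust_remote_code=True)
outputs = llm.generate(["Hello, world"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
```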
vLLM vs Transformers Alignment Verification
Statistical comparison using the test_apriel2.py stats command.
Models Tested
- pure-gdn
- attn-swa
- every5th-kda
Per-Position Token Match Rate (no-compile mode)
Key Findings
1. Divergence is NOT mixer-specific
All models (GDN, SWA, KDA) show similar divergence patterns between vLLM and Transformers. This indicates the issue is in shared model code (RMSNorm, MLP, embeddings) rather than mixer implementations.
2. torch.compile has minimal impact
Compile vs no-compile produces nearly identical results:
Previous reports of GDN torch.compile issues appear to have been measurement artifacts.
3. Divergence accumulates over decode steps
Small numerical differences compound during autoregressive generation, causing progressive divergence.
4. Prefill is well-aligned
All models show excellent prefill alignment (95-98% match, avg diff ~0.04), making them reliable for likelihood-based evaluation (MMLU, etc.).
Implications
For likelihood-based evaluation (MMLU)
✅ All models reliable - prefill-only evaluation shows 95-98% alignment
For generative evaluation (GSM8K)
Root Cause Investigation Needed
The divergence affects all model types equally, suggesting the issue is in shared model code (RMSNorm, MLP, embeddings) rather than in the individual mixer implementations.
Test Configuration
Test plan
🤖 Generated with Claude Code