-
Notifications
You must be signed in to change notification settings - Fork 327
[CPU] Linearize gpt_oss model and add example to quantize it to w4a8 #2113
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
dsikka
merged 6 commits into
vllm-project:main
from
isharif168:convert_to_linear_gpt_oss
Dec 19, 2025
+338
−0
Merged
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
05c808c
[CPU] Linearize gpt_oss model and add separate example to quantize it…
isharif168 3c79b94
Merge branch 'main' into convert_to_linear_gpt_oss
dsikka a35ccce
Merge branch 'main' into convert_to_linear_gpt_oss
dsikka bf54df6
Merge branch 'main' into convert_to_linear_gpt_oss
shanjiaz 0bbb0e3
Merge branch 'main' into convert_to_linear_gpt_oss
dsikka 708b6b8
Address quality checks issues
isharif168 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,79 @@ | ||
| import torch | ||
| from compressed_tensors.quantization import QuantizationScheme | ||
| from compressed_tensors.quantization.quant_args import ( | ||
| QuantizationArgs, | ||
| QuantizationStrategy, | ||
| QuantizationType, | ||
| ) | ||
| from transformers import AutoModelForCausalLM, AutoTokenizer | ||
|
|
||
| from llmcompressor import oneshot | ||
| from llmcompressor.modeling.gpt_oss import convert_model_for_quantization_gptoss | ||
| from llmcompressor.modifiers.quantization import QuantizationModifier | ||
|
|
||
|
|
||
| def main(): | ||
| MODEL_ID = "openai/gpt-oss-20b" | ||
| BASE_NAME = MODEL_ID.rstrip("/").split("/")[-1] | ||
| OUTPUT_DIR = f"{BASE_NAME}-w4a8-channelwise" | ||
|
|
||
| print(f"[GPT-OSS] Loading model: {MODEL_ID}") | ||
| model = AutoModelForCausalLM.from_pretrained( | ||
| MODEL_ID, | ||
| torch_dtype=torch.bfloat16, | ||
| device_map="auto", | ||
| trust_remote_code=True, | ||
| ) | ||
| tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True) | ||
|
|
||
| # ---- GPT-OSS MoE → linear experts conversion ---- | ||
| print("[GPT-OSS] Converting fused MoE experts to LinearExperts for quantization...") | ||
| convert_model_for_quantization_gptoss(model) | ||
| print("[GPT-OSS] Conversion completed.") | ||
|
|
||
| # ---- Quantization config: W4A8 (int4 weights, int8 activations) ---- | ||
|
|
||
| # Weights: 4-bit, channelwise, symmetric, static | ||
| weights_args = QuantizationArgs( | ||
| num_bits=4, | ||
| type=QuantizationType.INT, | ||
| strategy=QuantizationStrategy.CHANNEL, | ||
| symmetric=True, | ||
| dynamic=False, | ||
| ) | ||
|
|
||
| # Activations: 8-bit, per-token, asymmetric, dynamic | ||
| activations_args = QuantizationArgs( | ||
| num_bits=8, | ||
| type=QuantizationType.INT, | ||
| strategy=QuantizationStrategy.TOKEN, | ||
| symmetric=False, | ||
| dynamic=True, | ||
| observer=None, | ||
| ) | ||
|
|
||
| # Apply to all Linear layers, excluding lm_head | ||
| scheme = QuantizationScheme( | ||
| targets=["Linear"], | ||
| weights=weights_args, | ||
| input_activations=activations_args, | ||
| ) | ||
|
|
||
| recipe = QuantizationModifier( | ||
| config_groups={"group_0": scheme}, | ||
| ignore=["lm_head"], | ||
| ) | ||
|
|
||
| print(f"[GPT-OSS] Starting oneshot quantization → {OUTPUT_DIR}") | ||
| oneshot( | ||
| model=model, | ||
| recipe=recipe, | ||
| tokenizer=tokenizer, | ||
| output_dir=OUTPUT_DIR, | ||
| trust_remote_code_model=True, | ||
| ) | ||
| print(f"[GPT-OSS] Quantization finished. Quantized model written to: {OUTPUT_DIR}") | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| main() | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,259 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from dataclasses import dataclass | ||
| from typing import List, Optional | ||
|
|
||
| import torch | ||
| import torch.nn as nn | ||
|
|
||
|
|
||
| class LinearExpert(nn.Module): | ||
| """ | ||
| One MoE expert with separate gate / up / down projections. | ||
| This mirrors the GPT-OSS expert behavior: | ||
| gate = clamp(gate_proj(x)) | ||
| up = clamp(up_proj(x)) | ||
| glu = gate * sigmoid(alpha * gate) | ||
| y = down_proj((up + 1) * glu) | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, hidden_size: int, intermediate_size: int, alpha: float, limit: float | ||
| ): | ||
| super().__init__() | ||
| self.alpha = alpha | ||
| self.limit = limit | ||
|
|
||
| self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=True) | ||
| self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=True) | ||
| self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=True) | ||
|
|
||
| def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
| gate = self.gate_proj(x) | ||
| up = self.up_proj(x) | ||
|
|
||
| gate = gate.clamp(max=self.limit) | ||
| up = up.clamp(min=-self.limit, max=self.limit) | ||
|
|
||
| glu = gate * torch.sigmoid(self.alpha * gate) | ||
| act = (up + 1) * glu | ||
| return self.down_proj(act) | ||
|
|
||
|
|
||
| class LinearExperts(nn.Module): | ||
| """ | ||
| Container of multiple LinearExpert modules, driven by | ||
| router_indices / routing_weights. | ||
| This is the "separate gate/up" layout. | ||
| It is meant to replace the original GPT-OSS `experts` submodule. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| hidden_size: int, | ||
| intermediate_size: int, | ||
| num_experts: int, | ||
| alpha: float = 1.702, | ||
| limit: float = 7.0, | ||
| ): | ||
| super().__init__() | ||
| self.hidden_size = hidden_size | ||
| self.expert_dim = intermediate_size | ||
| self.num_experts = num_experts | ||
| self.alpha = alpha | ||
| self.limit = limit | ||
|
|
||
| self.experts = nn.ModuleList( | ||
| [ | ||
| LinearExpert(hidden_size, intermediate_size, alpha, limit) | ||
| for _ in range(num_experts) | ||
| ] | ||
| ) | ||
|
|
||
| @torch.no_grad() | ||
| def copy_from_fused_weights( | ||
| self, | ||
| legacy_gate_up_W: torch.Tensor, # [E, H, 2D] | ||
| legacy_gate_up_b: torch.Tensor, # [E, 2D] | ||
| legacy_down_W: torch.Tensor, # [E, D, H] | ||
| legacy_down_b: torch.Tensor, # [E, H] | ||
| ) -> None: | ||
| """ | ||
| De-interleave fused gate_up weights/bias and copy into separate gate/up experts. | ||
| """ | ||
| E, H, twoD = legacy_gate_up_W.shape | ||
| assert E == self.num_experts | ||
| D = twoD // 2 | ||
| assert D == self.expert_dim | ||
|
|
||
| for i in range(E): | ||
| Wi = legacy_gate_up_W[i] # [H, 2D] | ||
| bi = legacy_gate_up_b[i] # [2D] | ||
|
|
||
| Wg = Wi[:, 0::2].contiguous() # [H, D] | ||
| Wu = Wi[:, 1::2].contiguous() # [H, D] | ||
| bg = bi[0::2].contiguous() # [D] | ||
| bu = bi[1::2].contiguous() # [D] | ||
|
|
||
| expert = self.experts[i] | ||
| expert.gate_proj.weight.copy_(Wg.t()) | ||
| expert.gate_proj.bias.copy_(bg) | ||
| expert.up_proj.weight.copy_(Wu.t()) | ||
| expert.up_proj.bias.copy_(bu) | ||
|
|
||
| expert.down_proj.weight.copy_(legacy_down_W[i].t()) | ||
| expert.down_proj.bias.copy_(legacy_down_b[i]) | ||
|
|
||
| def forward( | ||
| self, | ||
| hidden_states: torch.Tensor, # [B, T, H] | ||
| router_indices: Optional[ | ||
| torch.Tensor | ||
| ] = None, # [B, T, top_k] or [tokens, top_k] | ||
| routing_weights: Optional[torch.Tensor] = None, # [B, T, E] or [tokens, E] | ||
| ) -> torch.Tensor: | ||
| """ | ||
| Implements the MoE computation using the router outputs. | ||
| This is compatible with the GPT-OSS MoE call pattern: | ||
| experts(hidden_states, router_indices, routing_weights) | ||
| """ | ||
| assert ( | ||
| routing_weights is not None and router_indices is not None | ||
| ), "router inputs required" | ||
|
|
||
| # Normalize shapes to [tokens, H], [tokens, top_k], [tokens, E] | ||
| if hidden_states.dim() == 3: | ||
| B, T, H = hidden_states.shape | ||
| x = hidden_states.reshape(-1, H) | ||
| else: | ||
| # Already flattened | ||
| B, _ = 1, hidden_states.shape[0] | ||
| H = hidden_states.shape[-1] | ||
| x = hidden_states | ||
|
|
||
| if router_indices.dim() == 3: | ||
| router_indices = router_indices.reshape(-1, router_indices.shape[-1]) | ||
| if routing_weights.dim() == 3: | ||
| routing_weights = routing_weights.reshape(-1, routing_weights.shape[-1]) | ||
|
|
||
| num_experts_plus_dummy = routing_weights.shape[1] | ||
| out = torch.zeros_like(x) | ||
|
|
||
| # GPT-OSS router uses an extra "no expert" bucket at index E | ||
| with torch.no_grad(): | ||
| expert_mask = torch.nn.functional.one_hot( | ||
| router_indices, num_classes=num_experts_plus_dummy | ||
| ).permute(2, 1, 0) | ||
| expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() | ||
|
|
||
| for idx in expert_hit: | ||
isharif168 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| e = idx[0].item() | ||
| if e == self.num_experts: | ||
| # Skip "no expert" bucket | ||
| continue | ||
|
|
||
| _, token_idx = torch.where(expert_mask[e]) | ||
| xi = x[token_idx] | ||
|
|
||
| expert = self.experts[e] | ||
| yi = expert(xi) | ||
|
|
||
| w = routing_weights[token_idx, e, None] | ||
| out.index_add_(0, token_idx, (yi * w).to(out.dtype)) | ||
|
|
||
| return out.view(B, -1, H) | ||
|
|
||
|
|
||
| @dataclass | ||
| class ExpertMeta: | ||
| path: str | ||
| hidden_size: int | ||
| intermediate_size: int | ||
| num_experts: int | ||
| device: torch.device | ||
| dtype: torch.dtype | ||
|
|
||
|
|
||
| def get_module_by_path(root: nn.Module, dotpath: str) -> nn.Module: | ||
isharif168 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| m: nn.Module = root | ||
| if not dotpath: | ||
| return root | ||
| for p in dotpath.split("."): | ||
| m = getattr(m, p) | ||
| return m | ||
|
|
||
|
|
||
| def set_module_by_path(root: nn.Module, dotpath: str, new_module: nn.Module) -> None: | ||
isharif168 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| parts = dotpath.split(".") | ||
| parent = get_module_by_path(root, ".".join(parts[:-1])) | ||
| setattr(parent, parts[-1], new_module) | ||
|
|
||
|
|
||
| def find_experts(model: nn.Module) -> List[ExpertMeta]: | ||
| """ | ||
| Locate GPT-OSS MoE expert modules under model.model.layers[*].mlp.experts. | ||
| """ | ||
| metas: List[ExpertMeta] = [] | ||
isharif168 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| for li, layer in enumerate(model.model.layers): | ||
| experts = layer.mlp.experts | ||
| device = next(experts.parameters(), torch.zeros(())).device | ||
| dtype = next(experts.parameters(), torch.zeros(())).dtype | ||
isharif168 marked this conversation as resolved.
Show resolved
Hide resolved
isharif168 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| intermediate = getattr(experts, "expert_dim", None) | ||
| if intermediate is None: | ||
| intermediate = getattr(experts, "intermediate_size") | ||
isharif168 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| metas.append( | ||
| ExpertMeta( | ||
| path=f"model.layers.{li}.mlp.experts", | ||
| hidden_size=experts.hidden_size, | ||
| intermediate_size=intermediate, | ||
| num_experts=experts.num_experts, | ||
| device=device, | ||
| dtype=dtype, | ||
| ) | ||
| ) | ||
| return metas | ||
|
|
||
|
|
||
| def convert_model_for_quantization_gptoss(model: nn.Module) -> None: | ||
| """ | ||
| In-place conversion of a GPT-OSS model: | ||
| - Finds all fused MoE expert blocks (with gate_up_proj/down_proj). | ||
| - Replaces them with LinearExperts that expose plain nn.Linear | ||
| parameters (gate_proj, up_proj, down_proj), which play nicely | ||
| with LLM Compressor W4A8 quantization. | ||
| """ | ||
| metas = find_experts(model) | ||
| for meta in metas: | ||
| legacy = get_module_by_path(model, meta.path) | ||
|
|
||
| # Sanity check that this is the fused layout we expect. | ||
| if not all( | ||
| hasattr(legacy, attr) | ||
| for attr in [ | ||
| "gate_up_proj", | ||
| "gate_up_proj_bias", | ||
| "down_proj", | ||
| "down_proj_bias", | ||
| ] | ||
| ): | ||
| continue | ||
|
|
||
| new_exp = LinearExperts( | ||
| hidden_size=meta.hidden_size, | ||
| intermediate_size=meta.intermediate_size, | ||
| num_experts=meta.num_experts, | ||
| ).to(device=meta.device, dtype=meta.dtype) | ||
|
|
||
| new_exp.copy_from_fused_weights( | ||
| legacy_gate_up_W=legacy.gate_up_proj, | ||
| legacy_gate_up_b=legacy.gate_up_proj_bias, | ||
| legacy_down_W=legacy.down_proj, | ||
| legacy_down_b=legacy.down_proj_bias, | ||
| ) | ||
|
|
||
| set_module_by_path(model, meta.path, new_exp) | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.