
Commit 0479bdf

[BugFix] Fix Llama4 Calibration (#2101)
# SUMMARY:
- Applying the router_scores to the hidden states before passing the hidden states to the experts results in NaNs during calibration.
- I have gone through the forward pass line-by-line, verified that the dimensions all match / make sense, and ensured we are not doing anything different from the transformers definition. However, this issue persists.
- Swapping to apply the scores to the expert outputs (as is common for most MoEs) does not cause this problem and results in high recovery. As such, enabling this for the time being so that the Llama4 pathway does not produce NaN scales.
- We can potentially revisit with another dataset, but considering how good recovery is, I think this is sufficient to unblock release.
- I have left a note about this deviation from the definition in the modeling code.

# Evals:

98% Recovery

```yaml
|   Tasks   |Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----------|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k_llama|      3|flexible_extract|     8|exact_match|↑  |0.934|±  |0.0068|
|           |       |strict_match    |     8|exact_match|↑  |0.931|±  |0.0070|
```

Greater than 98% Recovery

```yaml
|      Groups      |Version|   Filter   |n-shot|  Metric   |   |Value |   |Stderr|
|------------------|------:|------------|------|-----------|---|-----:|---|-----:|
|mmlu_llama        |      1|strict_match|      |exact_match|↑  |0.7997|±  |0.0032|
| - humanities     |      1|strict_match|      |exact_match|↑  |0.7696|±  |0.0059|
| - other          |      1|strict_match|      |exact_match|↑  |0.8172|±  |0.0066|
| - social sciences|      1|strict_match|      |exact_match|↑  |0.8781|±  |0.0058|
| - stem           |      0|strict_match|      |exact_match|↑  |0.7510|±  |0.0074|
```

Greater than 99% recovery

```yaml
|       Tasks       |Version|   Filter   |n-shot|  Metric   |   |Value |   |Stderr|
|-------------------|------:|------------|-----:|-----------|---|-----:|---|-----:|
|arc_challenge_llama|      1|strict_match|     0|exact_match|↑  |0.9296|±  |0.0075|
```

Greater than 99% recovery

```yaml
|  Tasks   |Version|Filter|n-shot|Metric|   |Value |   |Stderr|
|----------|------:|------|-----:|------|---|-----:|---|-----:|
|winogrande|      1|none  |     0|acc   |↑  |0.6835|±  |0.0131|
```

Greater than 98% recovery

```yaml
|truthfulqa_mc2|      3|none  |     0|acc   |↑  |0.6177|±  |0.0164|
```
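For context on the ordering change, here is a minimal, self-contained sketch of the two score-application orders. All tensors and the toy expert below are illustrative stand-ins, not the repository's modules:

```python
import torch

torch.manual_seed(0)

# Toy stand-ins: 4 tokens, hidden dim 8, one nonlinear "expert" (illustrative only)
hidden_states = torch.randn(4, 8)
expert = torch.nn.Sequential(
    torch.nn.Linear(8, 8), torch.nn.ReLU(), torch.nn.Linear(8, 8)
)
router_scores = torch.sigmoid(torch.randn(4, 1))

# Original Llama4 order: scale the expert *input* by the router score
out_pre = expert(hidden_states * router_scores)

# This commit's order (common for most MoEs): scale the expert *output*
out_post = expert(hidden_states) * router_scores

# The two orders are not equivalent through a nonlinear expert, which is
# why the test tolerances below are relaxed rather than kept at 1e-10
print(torch.nn.functional.mse_loss(out_pre, out_post).item())
```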
1 parent 1abfd9e · commit 0479bdf

File tree

2 files changed: +28 −18 lines changed


src/llmcompressor/modeling/llama4.py

Lines changed: 24 additions & 14 deletions
```diff
@@ -48,26 +48,36 @@ def __init__(
 
     def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         hidden_states = hidden_states.reshape(-1, self.hidden_dim)
-        router_scores, router_logits = self.router(hidden_states)  # transformers>=4.54
-
+        router_scores, router_logits = self.router(hidden_states)
         out = self.shared_expert(hidden_states)
 
-        for expert_index in range(self.num_experts):
-            # find expert scores
-            expert_score = router_scores[:, expert_index].unsqueeze(-1)
-            top_token_mask = expert_score[:, 0] > 0
+        _, router_indices = torch.topk(router_logits, self.top_k, dim=1)
+        expert_mask = torch.nn.functional.one_hot(
+            router_indices, num_classes=self.num_experts
+        ).permute(2, 1, 0)  # (num_experts, top_k, batch_size * sequence_length)
+
+        for i in range(self.num_experts):
+            # fetch relevant token indices for this expert
+            token_idx = torch.where(expert_mask[i].squeeze(0))
 
-            # llama4 applies scores before expert relu
-            expert_in = hidden_states * expert_score
+            # Original Llama4 definition - apply score to hidden states
+            # before applying to expert; this results in NaNs during calibration
+            # routed_in = hidden_states * router_scores[:, i].reshape(-1, 1)
 
-            # calibrate experts
             if self.calibrate_all_experts:
-                expert_out = self.experts[expert_index](expert_in)[top_token_mask]
+                # all tokens for this expert
+                expert_out = self.experts[i](hidden_states)[token_idx]
             else:
-                expert_out = self.experts[expert_index](expert_in[top_token_mask])
-
-            # accumulate output
-            out[top_token_mask] += expert_out
+                # only relevant tokens for this expert
+                expert_out = self.experts[i](hidden_states[token_idx])
+
+            if len(token_idx) > 0:
+                # Deviation from original Llama4 definition to avoid
+                # NaNs during calibration
+                weighted_output = expert_out * router_scores[:, i][token_idx].reshape(
+                    -1, 1
+                )
+                out[token_idx] += weighted_output
 
         return out, router_logits
 
```

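For readers tracing the new routing logic above, here is a standalone sketch of what the `topk` / `one_hot` / `permute` sequence produces, using toy sizes (all values illustrative; the `top_k=1` assumption mirrors Llama4's one-expert-per-token text MoE):

```python
import torch

num_tokens, num_experts, top_k = 5, 4, 1  # toy sizes; Llama4 text MoE uses top_k=1

router_logits = torch.randn(num_tokens, num_experts)

# (num_tokens, top_k): which expert each token routes to
_, router_indices = torch.topk(router_logits, top_k, dim=1)

# one_hot: (num_tokens, top_k, num_experts) -> permute: (num_experts, top_k, num_tokens)
expert_mask = torch.nn.functional.one_hot(
    router_indices, num_classes=num_experts
).permute(2, 1, 0)

for i in range(num_experts):
    # squeeze(0) drops the top_k=1 dim; torch.where returns a tuple of index tensors
    token_idx = torch.where(expert_mask[i].squeeze(0))
    print(f"expert {i} receives tokens {token_idx[0].tolist()}")
```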
tests/llmcompressor/modeling/test_calib_llama4.py

Lines changed: 4 additions & 4 deletions
```diff
@@ -85,11 +85,11 @@ def test_calib_llama4_module():
     module = SequentialLlama4TextMoe(original, config, calibrate_all_experts=True)
     with calibration_forward_context(module):
         out, router_logits = module(sample)
-    assert torch.nn.functional.mse_loss(true_out, out) < 1e-10
-    assert torch.nn.functional.mse_loss(true_router_logits, router_logits) < 1e-10
+    assert torch.nn.functional.mse_loss(true_out, out) < 0.1
+    assert torch.nn.functional.mse_loss(true_router_logits, router_logits) < 0.1
 
     module = SequentialLlama4TextMoe(original, config, calibrate_all_experts=False)
     with calibration_forward_context(module):
         out, router_logits = module(sample)
-    assert torch.nn.functional.mse_loss(true_out, out) < 1e-10
-    assert torch.nn.functional.mse_loss(true_router_logits, router_logits) < 1e-10
+    assert torch.nn.functional.mse_loss(true_out, out) < 0.1
+    assert torch.nn.functional.mse_loss(true_router_logits, router_logits) < 0.1
```
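The two test blocks exercise both calibration modes. As a rough sketch of the behavioral difference (toy tensors and a plain `Linear` as the expert, not the test's actual fixtures): with `calibrate_all_experts=True` the expert runs on the full batch so quantization observers see every token, while only the routed rows are kept.

```python
import torch

# Toy illustration of the two calibration modes in the modeling diff above
experts = torch.nn.ModuleList(torch.nn.Linear(8, 8) for _ in range(4))
hidden_states = torch.randn(5, 8)
i = 1
token_idx = (torch.tensor([0, 2]),)  # tokens routed to expert i (illustrative)

for calibrate_all_experts in (True, False):
    if calibrate_all_experts:
        # expert sees the full batch (observers calibrate on every token),
        # but only the routed rows are kept
        expert_out = experts[i](hidden_states)[token_idx]
    else:
        # expert only ever sees its routed tokens
        expert_out = experts[i](hidden_states[token_idx])
    print(calibrate_all_experts, expert_out.shape)  # both: torch.Size([2, 8])
```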
