Commit 0479bdf
authored
[BugFix] Fix Llama4 Calibration (#2101)
# SUMMARY:
- Applying the router_scores to the hidden states before passing the
hidden states to the experts is resulting in NaNs during calibration.
- I have gone through the forward pass line-by-line, verified that the
dimenions all match / make sense and ensured we are not doing anything
different than the transformers definition. However, this issue
persists.
- Swapping to apply the scores to the expert outputs (as is common for
most MoEs) does not cause this problem and results in high recovery. As
such, enabling this for the time being so that the llama4 pathway does
not produce NaN scales
- We can potentially revisit with another dataset but considering how
good recovery is, I think this is sufficient to unblock release.
- I have left a note about this deviation from the definition in the
modeling code
# Evals:
98% Recovery
```yaml
| Tasks |Version| Filter |n-shot| Metric | |Value| |Stderr|
|-----------|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k_llama| 3|flexible_extract| 8|exact_match|↑ |0.934|± |0.0068|
| | |strict_match | 8|exact_match|↑ |0.931|± |0.0070|
```
Greater than 98% Recovery
```yaml
| Groups |Version| Filter |n-shot| Metric | |Value | |Stderr|
|------------------|------:|------------|------|-----------|---|-----:|---|-----:|
|mmlu_llama | 1|strict_match| |exact_match|↑ |0.7997|± |0.0032|
| - humanities | 1|strict_match| |exact_match|↑ |0.7696|± |0.0059|
| - other | 1|strict_match| |exact_match|↑ |0.8172|± |0.0066|
| - social sciences| 1|strict_match| |exact_match|↑ |0.8781|± |0.0058|
| - stem | 0|strict_match| |exact_match|↑ |0.7510|± |0.0074|
```
Greater than 99% recovery
```yaml
| Tasks |Version| Filter |n-shot| Metric | |Value | |Stderr|
|-------------------|------:|------------|-----:|-----------|---|-----:|---|-----:|
|arc_challenge_llama| 1|strict_match| 0|exact_match|↑ |0.9296|± |0.0075|
```
Greater than 99% recovery
```yaml
| Tasks |Version|Filter|n-shot|Metric| |Value | |Stderr|
|----------|------:|------|-----:|------|---|-----:|---|-----:|
|winogrande| 1|none | 0|acc |↑ |0.6835|± |0.0131|
```
Greater than 98% recovery
```yaml
|truthfulqa_mc2| 3|none | 0|acc |↑ | 0.6177|± |0.0164|
```1 parent 1abfd9e commit 0479bdf
File tree
2 files changed
+28
-18
lines changed- src/llmcompressor/modeling
- tests/llmcompressor/modeling
2 files changed
+28
-18
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
48 | 48 | | |
49 | 49 | | |
50 | 50 | | |
51 | | - | |
52 | | - | |
| 51 | + | |
53 | 52 | | |
54 | 53 | | |
55 | | - | |
56 | | - | |
57 | | - | |
58 | | - | |
| 54 | + | |
| 55 | + | |
| 56 | + | |
| 57 | + | |
| 58 | + | |
| 59 | + | |
| 60 | + | |
| 61 | + | |
59 | 62 | | |
60 | | - | |
61 | | - | |
| 63 | + | |
| 64 | + | |
| 65 | + | |
62 | 66 | | |
63 | | - | |
64 | 67 | | |
65 | | - | |
| 68 | + | |
| 69 | + | |
66 | 70 | | |
67 | | - | |
68 | | - | |
69 | | - | |
70 | | - | |
| 71 | + | |
| 72 | + | |
| 73 | + | |
| 74 | + | |
| 75 | + | |
| 76 | + | |
| 77 | + | |
| 78 | + | |
| 79 | + | |
| 80 | + | |
71 | 81 | | |
72 | 82 | | |
73 | 83 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
85 | 85 | | |
86 | 86 | | |
87 | 87 | | |
88 | | - | |
89 | | - | |
| 88 | + | |
| 89 | + | |
90 | 90 | | |
91 | 91 | | |
92 | 92 | | |
93 | 93 | | |
94 | | - | |
95 | | - | |
| 94 | + | |
| 95 | + | |
0 commit comments