Olmo3 checkpoint conversion and Refactor Olmo3 model to support interleaved RoPE (and attention) #3112
Conversation
🤖 Hi @gagika, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
📋 Review Summary
This pull request introduces support for Olmo3 models, including checkpoint conversion and necessary architectural adjustments for features like global QK normalization and mixed RoPE strategies. The changes are well-structured and the code is clear and maintainable.
🔍 General Feedback
- The implementation of the Olmo3-specific features within the existing architecture is clean and minimally invasive.
- The addition of checkpoint conversion utilities for Olmo3 is a valuable contribution.
- The updates to the testing utilities to better support different dtypes improve the overall quality of the test suite.
Overall, this is a high-quality contribution that is ready for merging.
| """Returns mapping from MaxText to HuggingFace Olmo3 weight paths. | ||
|
|
||
| Olmo3 uses an inhomogeneous layer cycle (typically 4 layers: 3 sliding, 1 global). | ||
| MaxText handles this by defining sub-layers (layers_0, layers_1...) within a block. |
🟢 The docstring for `OLMO3_MAXTEXT_TO_HF_PARAM_MAPPING` is very detailed. For better readability, consider shortening the main docstring and moving the implementation details about sub-layers into a comment within the function body itself.
Suggested change:

```python
def OLMO3_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
  """Returns mapping from MaxText to HuggingFace Olmo3 weight paths."""
  # Olmo3 uses an inhomogeneous layer cycle (typically 4 layers: 3 sliding, 1 global).
  # MaxText handles this by defining sub-layers (layers_0, layers_1...) within a block.
  n_layers = config["num_hidden_layers"]
  # Default Olmo3 cycle length is 4 if not specified in config
  layer_cycle_interval = maxtext_config.inhomogeneous_layer_cycle_interval
```
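To make the sub-layer scheme concrete, here is a minimal sketch, not the PR's actual mapping code, of how a flat HuggingFace layer index might be folded into a (block, sub-layer) pair when the cycle interval is 4; the helper name and the `layers/layers_{i}` path string are assumptions for illustration only.

```python
# Illustrative only: split a flat HF layer index into a (sub-layer path, block index)
# pair under an inhomogeneous cycle of length `layer_cycle_interval`.
# The helper name and the MaxText path template below are hypothetical.
def hf_layer_to_maxtext_sublayer(hf_layer_idx: int, layer_cycle_interval: int = 4):
  block_idx = hf_layer_idx // layer_cycle_interval   # which repeated (scanned) block
  sub_idx = hf_layer_idx % layer_cycle_interval      # which sub-layer inside the block
  return f"layers/layers_{sub_idx}", block_idx

# Example: with a cycle of 4, HF layer 5 lands in block 1, sub-layer layers_1.
print(hf_layer_to_maxtext_sublayer(5))  # ('layers/layers_1', 1)
```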
```python
self.use_mrope = use_mrope
self.mrope_section = mrope_section
self.rngs = rngs
# Use the rope type specified in the arguments if provided, otherwise fall back to the one in the config.
```
🟢 The logic to fall back to `config.rope_type` is clear. This is a good way to handle the override. The comment is also very helpful.
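For reference, the override-with-fallback pattern praised here typically boils down to a one-liner like this sketch; the function name and the argument names are assumptions, not code copied from `attentions.py` (the `"default"`/YaRN split comes from the PR description).

```python
def resolve_rope_type(rope_type_override, config_rope_type):
  # Prefer the explicit per-layer override; otherwise fall back to the config-wide value.
  return rope_type_override if rope_type_override is not None else config_rope_type

# e.g. Olmo3 sliding-window layers could pass "default", while global layers pass None
# and inherit the YaRN setting from the config.
assert resolve_rope_type("default", "yarn") == "default"
assert resolve_rope_type(None, "yarn") == "yarn"
```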
```python
    jnp.float16: torch.float16,
}

# Default to bfloat16 if dtype is unrecognized
```
🟢 Mapping the jnp dtype to torch dtype is a nice touch for improving the robustness of the logit checker. This is a good defensive programming practice.
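A minimal sketch of that defensive lookup, assuming a module-level mapping dict with a bfloat16 default; the identifiers here are illustrative and not the checker's actual names.

```python
import jax.numpy as jnp
import torch

# Illustrative mapping from JAX dtypes to torch dtypes; the real table in the
# logit checker may cover more entries.
JNP_TO_TORCH_DTYPE = {
    jnp.float32: torch.float32,
    jnp.bfloat16: torch.bfloat16,
    jnp.float16: torch.float16,
}

def to_torch_dtype(jnp_dtype):
  # Default to bfloat16 if the dtype is unrecognized, as the review comment notes.
  return JNP_TO_TORCH_DTYPE.get(jnp_dtype, torch.bfloat16)

print(to_torch_dtype(jnp.bfloat16))  # torch.bfloat16
```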
📋 Review Summary
This pull request introduces support for Olmo3 models, which is a valuable addition. The implementation correctly handles the specific architectural features of Olmo3, such as global QK normalization and the mixed RoPE strategy. The checkpoint conversion utilities and updates to the forward pass checker are also well-implemented.
🔍 General Feedback
- The use of a `rope_type` override in the attention layer is a clean way to manage the mixed RoPE requirements of Olmo3.
- The new `GlobalRMSNorm` is a good example of extending functionality while maintaining a clear separation of concerns.
- The parameter mapping for checkpoint conversion is thorough and handles both scanned and unscanned layer configurations correctly.
The overall quality of the code is high, and the changes are well-documented in the PR description. I have one minor suggestion to improve the modularity of the attention layer, but it does not block the merge.
```python
if self.use_qk_norm and not is_llama4_decoder_block:
  self.query_norm = RMSNorm(
      num_features=self.head_dim,
  # Check if this is Olmo3, which uses a unique "Global" QK Norm strategy.
```
🟡 Using `self.config.model_name.startswith("olmo3")` to decide whether to use `GlobalRMSNorm` makes the attention implementation specific to the Olmo3 model. This could make it harder to reuse this logic for future models that have a similar feature.

A better approach would be to add a new boolean field, for example `global_qk_norm`, to the model configuration files (`olmo3-7b.yml` and `olmo3-32b.yml`). This would make the attention layer more modular and easier to adapt for other models.
Suggested change:

```python
if self.use_qk_norm and not is_llama4_decoder_block:
  # Check if this is Olmo3, which uses a unique "Global" QK Norm strategy.
  # GlobalRMSNorm flattens (Heads, Dim) to normalize across the entire hidden state.
  qk_norm_cls = GlobalRMSNorm if self.config.get("global_qk_norm", False) else RMSNorm
  # For RMSNorm use `head_dim` (per-head normalization), while for GlobalRMSNorm use `num_heads * head_dim` (global normalization).
  q_features = (self.num_query_heads * self.head_dim) if self.config.get("global_qk_norm", False) else self.head_dim
  k_features = (self.num_kv_heads * self.head_dim) if self.config.get("global_qk_norm", False) else self.head_dim
```
Description
This PR adds full support for AllenAI's Olmo3-7B and Olmo3-32B models in MaxText, including checkpoint conversion and forward pass correctness verification.
This implementation addresses several unique architectural features of Olmo3 that required changes to the core layers:
**Global QK Normalization**: Olmo3 applies RMSNorm across the entire hidden dimension (e.g., 4096) before splitting into heads, whereas MaxText's default `RMSNorm` applies per-head.
- Updated `src/MaxText/layers/attentions.py` to reshape query/key tensors `[B, L, H, D] -> [B, L, H*D]` before normalization when `is_olmo3` is detected (a reshape sketch is included after this feature list).

**Mixed RoPE Strategy**: Olmo3 uses a hybrid positional embedding strategy where "Sliding Window" layers use standard RoPE, while "Global" layers use YaRN.
- Updated `src/MaxText/layers/olmo3.py` to explicitly override the `rope_type` to `"default"` for local sliding layers.
- Updated `src/MaxText/layers/attentions.py` to accept a `rope_type` override in `__init__`, enabling layer-specific RoPE configurations.

**Configuration Alignments**:
- `rope_interleave: False` to match Hugging Face's concatenated RoPE.
- `rope_truncate: False` to prevent frequency drift in YaRN layers.
- `normalize_embedding_logits: False` as Olmo3 does not normalize output logits.
- Config file (`olmo3-7b.yml`) renamed to match standard naming conventions.
**Checkpoint Conversion**:
- Added `OLMO3_MAXTEXT_TO_HF_PARAM_MAPPING` and hooks in `param_mapping.py`.
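To illustrate the global QK normalization described above, here is a minimal JAX sketch of the reshape-normalize-reshape pattern, assuming a learned scale over the full hidden dimension; the function and variable names are illustrative and not the PR's actual code.

```python
import jax
import jax.numpy as jnp

def global_qk_rmsnorm(x, scale, eps=1e-6):
  """RMSNorm over the full hidden dimension (H * D) instead of per head.

  x: query or key tensor of shape [B, L, H, D]
  scale: learned scale of shape [H * D]
  """
  b, l, h, d = x.shape
  flat = x.reshape(b, l, h * d)                       # [B, L, H, D] -> [B, L, H*D]
  rms = jnp.sqrt(jnp.mean(flat.astype(jnp.float32) ** 2, axis=-1, keepdims=True) + eps)
  normed = (flat / rms) * scale                       # normalize across the whole hidden state
  return normed.reshape(b, l, h, d).astype(x.dtype)   # back to per-head layout

# Tiny usage example with random values.
q = jax.random.normal(jax.random.PRNGKey(0), (2, 4, 8, 64))  # [B, L, H, D]
scale = jnp.ones(8 * 64)
print(global_qk_rmsnorm(q, scale).shape)  # (2, 4, 8, 64)
```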
Tests

Tested via checkpoint conversion from Hugging Face and running `forward_pass_logit_checker.py` to verify KL divergence against the reference HF implementation (BF16 and FP32).

1. Checkpoint Conversion:
```bash
python3 -m MaxText.utils.ckpt_conversion.to_maxtext MaxText/configs/base.yml \
  model_name=olmo3-7b \
  hf_access_token=${HF_TOKEN} \
  base_output_directory=${BASE_OUTPUT_DIRECTORY} \
  scan_layers=True
```

2. Logit Verification (Olmo3-7B):
```bash
python3 -m tests.utils.forward_pass_logit_checker src/MaxText/configs/base.yml \
  model_name=olmo3-7b \
  load_parameters_path=${CHECKPOINT_PATH} \
  tokenizer_path="allenai/Olmo-3-7B-Instruct" \
  hf_model_path="allenai/Olmo-3-7B-Instruct" \
  run_hf_model=True \
  max_kl_div=0.005 \
  scan_layers=True \
  normalize_embedding_logits=False
```

Tested both Olmo3-7B and Olmo3-32B logits:
Max KL divergence for a single token in the set: 0.000005

https://paste.googleplex.com/4668070767493120
https://paste.googleplex.com/5281833875013632
Checklist
Before submitting this PR, please make sure (put X in square brackets):
`gemini-review` label.