Olmo3 checkpoint conversion and Refactor Olmo3 model to support interleaved RoPE (and attention) #3112

Open

gagika wants to merge 1 commit into main from agagik-olmo3

Conversation

gagika (Collaborator) commented Feb 8, 2026

Description

This PR adds full support for AllenAI's Olmo3-7B and Olmo3-32B models in MaxText, including checkpoint conversion and forward pass correctness verification.

This implementation addresses several unique architectural features of Olmo3 that required changes to the core layers:

  1. Global QK Normalization: Olmo3 applies RMSNorm across the entire hidden dimension (e.g., 4096) before splitting into heads, whereas MaxText's default RMSNorm applies per-head.

    • Change: Updated src/MaxText/layers/attentions.py to reshape query/key tensors [B, L, H, D] -> [B, L, H*D] before normalization when is_olmo3 is detected (see the first sketch after this list).
  2. Mixed RoPE Strategy: Olmo3 uses a hybrid positional-embedding scheme in which "Sliding Window" layers use standard RoPE, while "Global" layers use YaRN (see the second sketch after this list).

    • Change: Updated src/MaxText/layers/olmo3.py to explicitly override the rope_type to "default" for local sliding layers.
    • Change: Updated src/MaxText/layers/attentions.py to accept a rope_type override in __init__, enabling layer-specific RoPE configurations.
  3. Configuration Alignments:

    • Set rope_interleave: False to match Hugging Face's concatenated RoPE.
    • Set rope_truncate: False to prevent frequency drift in YaRN layers.
    • Set normalize_embedding_logits: False as Olmo3 does not normalize output logits.
    • Renamed config files to use hyphens (olmo3-7b.yml) to match standard naming conventions.
  4. Checkpoint Conversion:

    • Added OLMO3_MAXTEXT_TO_HF_PARAM_MAPPING and hooks in param_mapping.py.
    • Implemented identity hooks for Norms to preserve the specific scaling used in Olmo3 checkpoints.
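
The first sketch below illustrates the reshape in item 1. It is a minimal JAX example with hypothetical names (global_qk_rmsnorm, scale); the actual change reuses MaxText's norm layers inside src/MaxText/layers/attentions.py.

import jax.numpy as jnp

def global_qk_rmsnorm(x, scale, eps=1e-6):
  """Sketch: RMSNorm over the full hidden dimension (H*D) rather than per head."""
  b, l, h, d = x.shape
  flat = x.reshape(b, l, h * d)  # [B, L, H, D] -> [B, L, H*D]
  rms = jnp.sqrt(jnp.mean(flat * flat, axis=-1, keepdims=True) + eps)
  normed = (flat / rms) * scale  # scale is a learned vector of shape [H*D]
  return normed.reshape(b, l, h, d)  # restore the per-head layout for attention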
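
The second sketch illustrates the layer-specific RoPE selection in item 2. The helper name and the assumption that every fourth layer in the cycle is global are for illustration only; the real override is wired through src/MaxText/layers/olmo3.py and the new rope_type argument in attentions.py.

def select_rope_type(layer_idx, cycle_interval=4, global_rope_type="yarn"):
  """Sketch: sliding-window layers use standard RoPE ("default"), while
  global-attention layers keep the config's RoPE type (YaRN for Olmo3).
  Treating every cycle_interval-th layer as global is an assumption here."""
  is_global_layer = (layer_idx + 1) % cycle_interval == 0
  return global_rope_type if is_global_layer else "default"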

Tests

Tested via checkpoint conversion from Hugging Face and running forward_pass_logit_checker.py to verify KL divergence against the reference HF implementation (BF16 and FP32).

1. Checkpoint Conversion:

python3 -m MaxText.utils.ckpt_conversion.to_maxtext MaxText/configs/base.yml \
    model_name=olmo3-7b \
    hf_access_token=${HF_TOKEN} \
    base_output_directory=${BASE_OUTPUT_DIRECTORY} \
    scan_layers=True

2. Logit Verification (Olmo3-7B):

python3 -m tests.utils.forward_pass_logit_checker src/MaxText/configs/base.yml \
    model_name=olmo3-7b \
    load_parameters_path=${CHECKPOINT_PATH} \
    tokenizer_path="allenai/Olmo-3-7B-Instruct" \
    hf_model_path="allenai/Olmo-3-7B-Instruct" \
    run_hf_model=True \
    max_kl_div=0.005 \
    scan_layers=True \
    normalize_embedding_logits=False

Tested logits for both Olmo3-7B and Olmo3-32B:
Max KL divergence for a single token in the set: 0.000005

https://paste.googleplex.com/4668070767493120
https://paste.googleplex.com/5281833875013632

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

github-actions bot commented Feb 8, 2026

🤖 Hi @gagika, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

codecov bot commented Feb 8, 2026

Codecov Report

❌ Patch coverage is 24.35897% with 59 lines in your changes missing coverage. Please review.

Files with missing lines | Patch % | Lines
...xText/utils/ckpt_conversion/utils/param_mapping.py | 0.00% | 57 Missing ⚠️
src/MaxText/layers/decoders.py | 0.00% | 1 Missing and 1 partial ⚠️


github-actions bot left a comment

📋 Review Summary

This pull request introduces support for Olmo3 models, including checkpoint conversion and necessary architectural adjustments for features like global QK normalization and mixed RoPE strategies. The changes are well-structured and the code is clear and maintainable.

🔍 General Feedback

  • The implementation of the Olmo3-specific features within the existing architecture is clean and minimally invasive.
  • The addition of checkpoint conversion utilities for Olmo3 is a valuable contribution.
  • The updates to the testing utilities to better support different dtypes improve the overall quality of the test suite.

Overall, this is a high-quality contribution that is ready for merging.

"""Returns mapping from MaxText to HuggingFace Olmo3 weight paths.

Olmo3 uses an inhomogeneous layer cycle (typically 4 layers: 3 sliding, 1 global).
MaxText handles this by defining sub-layers (layers_0, layers_1...) within a block.

🟢 The docstring for OLMO3_MAXTEXT_TO_HF_PARAM_MAPPING is very detailed. For better readability, consider shortening the main docstring and moving the implementation details about sub-layers into a comment within the function body itself.

Suggested change
MaxText handles this by defining sub-layers (layers_0, layers_1...) within a block.
def OLMO3_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
  """Returns mapping from MaxText to HuggingFace Olmo3 weight paths."""
  # Olmo3 uses an inhomogeneous layer cycle (typically 4 layers: 3 sliding, 1 global).
  # MaxText handles this by defining sub-layers (layers_0, layers_1...) within a block.
  n_layers = config["num_hidden_layers"]
  # Default Olmo3 cycle length is 4 if not specified in config
  layer_cycle_interval = maxtext_config.inhomogeneous_layer_cycle_interval
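
To make the layer-cycle docstring concrete, here is a small sketch of the index arithmetic it implies. The helper name is hypothetical; the real mapping in param_mapping.py also has to emit the full weight-path strings for each sub-layer.

def hf_layer_to_maxtext_sublayer(hf_layer_idx, layer_cycle_interval=4):
  """Sketch: map a flat HuggingFace layer index onto MaxText's
  (scanned block index, sub-layer name) pair for a repeating block
  of `layer_cycle_interval` layers."""
  block_idx = hf_layer_idx // layer_cycle_interval
  sublayer = f"layers_{hf_layer_idx % layer_cycle_interval}"
  return block_idx, sublayer

# e.g. with a cycle of 4, HF layer 9 maps to block 2, sub-layer "layers_1".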

self.use_mrope = use_mrope
self.mrope_section = mrope_section
self.rngs = rngs
# Use the rope type specified in the arguments if provided, otherwise fall back to the one in the config.

🟢 The logic to fall back to config.rope_type is clear. This is a good way to handle the override. The comment is also very helpful.
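
For reference, a minimal sketch of the override-then-fallback pattern discussed here; the class and argument names are placeholders, not the actual attentions.py signature.

class HypotheticalAttention:
  def __init__(self, config, rope_type=None):
    # Illustrative only: prefer the per-layer rope_type argument when given,
    # otherwise fall back to the value in the model config.
    self.rope_type = rope_type if rope_type is not None else config.rope_type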

jnp.float16: torch.float16,
}

# Default to bfloat16 if dtype is unrecognized

🟢 Mapping the jnp dtype to torch dtype is a nice touch for improving the robustness of the logit checker. This is a good defensive programming practice.
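
A self-contained sketch of that mapping pattern, based on the entries visible in the snippet above; the exact dictionary in the logit checker may contain more entries.

import jax.numpy as jnp
import torch

# Sketch: look up the torch dtype matching a jnp dtype, defaulting to
# bfloat16 when the dtype is unrecognized (as the snippet's comment describes).
JNP_TO_TORCH_DTYPE = {
    jnp.bfloat16: torch.bfloat16,
    jnp.float32: torch.float32,
    jnp.float16: torch.float16,
}

def to_torch_dtype(dtype):
  return JNP_TO_TORCH_DTYPE.get(dtype, torch.bfloat16)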

github-actions bot commented Feb 8, 2026

🤖 Hi @gagika, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

github-actions bot left a comment

📋 Review Summary

This pull request introduces support for Olmo3 models, which is a valuable addition. The implementation correctly handles the specific architectural features of Olmo3, such as global QK normalization and the mixed RoPE strategy. The checkpoint conversion utilities and updates to the forward pass checker are also well-implemented.

🔍 General Feedback

  • The use of a rope_type override in the attention layer is a clean way to manage the mixed RoPE requirements of Olmo3.
  • The new GlobalRMSNorm is a good example of extending functionality while maintaining a clear separation of concerns.
  • The parameter mapping for checkpoint conversion is thorough and handles both scanned and unscanned layer configurations correctly.

The overall quality of the code is high, and the changes are well-documented in the PR description. I have one minor suggestion to improve the modularity of the attention layer, but it does not block the merge.

Comment on lines +500 to +502

if self.use_qk_norm and not is_llama4_decoder_block:
self.query_norm = RMSNorm(
num_features=self.head_dim,
# Check if this is Olmo3, which uses a unique "Global" QK Norm strategy.

🟡 Using self.config.model_name.startswith("olmo3") to decide whether to use GlobalRMSNorm makes the attention implementation specific to the Olmo3 model. This could make it harder to reuse this logic for future models that might have a similar feature.

A better approach would be to add a new boolean field, for example global_qk_norm, to the model configuration file (olmo3-7b.yml and olmo3-32b.yml). This would make the attention layer more modular and easier to adapt for other models.

Suggested change
if self.use_qk_norm and not is_llama4_decoder_block:
self.query_norm = RMSNorm(
num_features=self.head_dim,
# Check if this is Olmo3, which uses a unique "Global" QK Norm strategy.
# GlobalRMSNorm flattens (Heads, Dim) to normalize across the entire hidden state.
qk_norm_cls = GlobalRMSNorm if self.config.get("global_qk_norm", False) else RMSNorm
# For RMSNorm use `head_dim` (per-head normalization), while for GlobalRMSNorm use `num_heads * head_dim` (global normalization).
q_features = (self.num_query_heads * self.head_dim) if self.config.get("global_qk_norm", False) else self.head_dim
k_features = (self.num_kv_heads * self.head_dim) if self.config.get("global_qk_norm", False) else self.head_dim
