83 changes: 83 additions & 0 deletions src/MaxText/utils/ckpt_conversion/to_huggingface.py
@@ -29,6 +29,11 @@
scan_layers: (bool) Whether the MaxText model was trained with scanned layers.
This must match the training configuration of the checkpoint.

Optional Flags:
--override_model_architecture: If set, overrides the HF model configuration
with values from the MaxText configuration
(e.g., num_heads, hidden_size) instead of raising a ValueError on mismatch.

Environment Variables:
HF_AUTH_TOKEN: (Required) A HuggingFace authentication token. This is needed
to download the correct tokenizer configuration and to upload
@@ -59,6 +64,7 @@
from transformers import AutoTokenizer, AutoProcessor

from absl import app
from absl import flags

from MaxText import pyconfig
from MaxText.utils.ckpt_conversion.utils.param_mapping import (
@@ -80,6 +86,15 @@
from maxtext.utils import max_logging
from maxtext.utils import max_utils

flags.DEFINE_bool(
"override_model_architecture",
False,
"If True, overrides Hugging Face config architecture parameters (heads, layers, dims) "
"with values from the MaxText config. If False, raises a ValueError on mismatch.",
)

FLAGS = flags.FLAGS
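
A hypothetical invocation showing how the new absl flag sits alongside the usual MaxText-style key=value arguments (the config path, bucket path, and model name below are illustrative, not from this PR):

    python3 -m MaxText.utils.ckpt_conversion.to_huggingface MaxText/configs/base.yml \
        model_name=gemma2-2b \
        load_parameters_path=gs://<bucket>/<ckpt_path> \
        scan_layers=false \
        --override_model_architecture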


def _get_model_mappings(
model_name: str, scan_layers: bool, hf_config_dict: dict, maxtext_config: pyconfig.HyperParameters
@@ -109,6 +124,71 @@ def _get_model_mappings(
}


def _validate_or_update_architecture(hf_config, max_config, override: bool):
"""Validates consistency between HF and MaxText configs or overrides HF config if requested.

Args:
hf_config: The Hugging Face configuration object.
max_config: The MaxText configuration object (HyperParameters).
override: Boolean, if True, update hf_config with max_config values.
If False, raise error on mismatch.
"""
# Mapping from Hugging Face config attribute -> MaxText config attribute
# Note: We use derived MaxText attributes (e.g. emb_dim) which account for scale factors.
attributes_to_check = [
("num_attention_heads", "num_query_heads"),
("num_key_value_heads", "num_kv_heads"),
("head_dim", "head_dim"),
("hidden_size", "emb_dim"),
("intermediate_size", "mlp_dim"),
("num_hidden_layers", "num_decoder_layers"),
("vocab_size", "vocab_size"),
]

mismatches = []

for hf_attr, mt_attr in attributes_to_check:
    # Skip if the HF config doesn't define this attribute (not every architecture exposes every field, e.g. head_dim)
if not hasattr(hf_config, hf_attr):
continue

# Skip checks if MaxText config doesn't have the attribute (shouldn't happen for valid configs)
if not hasattr(max_config, mt_attr):
continue

hf_value = getattr(hf_config, hf_attr)
mt_value = getattr(max_config, mt_attr)

# Handle None values
if hf_value is None or mt_value is None:
continue

# Compare values (with tolerance for floats)
is_match = False
if isinstance(hf_value, float) or isinstance(mt_value, float):
try:
is_match = abs(float(hf_value) - float(mt_value)) < 1e-6
except (ValueError, TypeError):
is_match = hf_value == mt_value
else:
is_match = hf_value == mt_value

if not is_match:
if override:
max_logging.log(f"⚠️ Overwriting HF Config '{hf_attr}': {hf_value} -> {mt_value} (from MaxText '{mt_attr}')")
setattr(hf_config, hf_attr, mt_value)
else:
mismatches.append(f"{hf_attr} (HF={hf_value} vs MaxText={mt_value})")

if mismatches:
error_msg = (
"Architecture mismatches detected between standard Hugging Face config and provided MaxText config:\n - "
+ "\n - ".join(mismatches)
+ "\n\nAction Required: Pass the flag `--override_model_architecture` to force the conversion using MaxText values."
)
raise ValueError(error_msg)


def main(argv: Sequence[str]) -> None:
"""Main function to convert a MaxText checkpoint to HuggingFace format.

@@ -151,6 +231,9 @@ def main(argv: Sequence[str]) -> None:
raise ValueError(f"Unsupported model name: {config.model_name}. Supported models are: {list(HF_IDS.keys())}")
hf_config_obj = HF_MODEL_CONFIGS[model_key]

# Validate architecture consistency (raising ValueError on mismatch) or override HF config if specified.
_validate_or_update_architecture(hf_config_obj, config, override=FLAGS.override_model_architecture)
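
If the configs disagree and the flag is not set, the conversion stops at this call. Based on the message constructed in `_validate_or_update_architecture` above, the failure would read roughly like this (attribute and values hypothetical):

    ValueError: Architecture mismatches detected between the standard Hugging Face config and the provided MaxText config:
     - num_attention_heads (HF=32 vs MaxText=16)

    Action Required: Pass the flag `--override_model_architecture` to force the conversion using MaxText values.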

# 2. Load Tokenizer
if model_key not in HF_IDS:
raise ValueError(f"HF Tokenizer ID not found for model key: {model_key}")