diff --git a/paddleformers/cli/train/auto_parallel/workflow.py b/paddleformers/cli/train/auto_parallel/workflow.py
index f556074d4cb..c2cd46f4ab7 100644
--- a/paddleformers/cli/train/auto_parallel/workflow.py
+++ b/paddleformers/cli/train/auto_parallel/workflow.py
@@ -27,11 +27,12 @@
 from paddleformers.trainer.trainer import Trainer
 from paddleformers.trainer.trainer_utils import set_seed
 from paddleformers.transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    AutoModelForCausalLMPipe,
     AutoTokenizer,
     CosineAnnealingWithWarmupDecay,
     LinearAnnealingWithWarmupDecay,
-    LlamaConfig,
-    LlamaForCausalLM,
 )
 from paddleformers.transformers.configuration_utils import LlmMetaConfig
 from paddleformers.utils.log import logger
@@ -202,15 +203,8 @@ def run_auto_parallel(model_args, data_args, generating_args, training_args):
                 "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
             )
 
-    # TODO: only support llama model now
-    config_class = LlamaConfig
-    model_class = LlamaForCausalLM
-
-    config = config_class.from_pretrained(model_args.model_name_or_path)
     tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name_or_path)
-    if tokenizer.pad_token_id is None:
-        tokenizer.pad_token_id = tokenizer.eos_token_id
-    # config = AutoConfig.from_pretrained(model_args.model_name_or_path)
+    config = AutoConfig.from_pretrained(model_args.model_name_or_path)
 
     LlmMetaConfig.set_llm_config(config, training_args)
     config.use_fast_layer_norm = model_args.use_fast_layer_norm
@@ -272,6 +266,13 @@ def run_auto_parallel(model_args, data_args, generating_args, training_args):
     if training_args.no_recompute_layers is not None:
         training_args.no_recompute_layers.sort()
 
+    if training_args.use_intermediate_api:
+        config.use_single_model_implementation = True
+        config.tensor_parallel_degree = 1
+        config.sharding_parallel_degree = 1
+        config.sep_parallel_degree = 1
+        config.context_parallel_degree = 1
+
     print("Final pre-training config:", config)
 
     # Set the dtype for loading model
@@ -282,9 +283,33 @@ def run_auto_parallel(model_args, data_args, generating_args, training_args):
     if training_args.bf16:
         dtype = "bfloat16"
 
-    with paddle.LazyGuard():
-        model = model_class.from_config(config, dtype=dtype)
-        criterion = model.criterion
+    model_class = AutoModelForCausalLM
+
+    if not training_args.enable_auto_parallel and training_args.pipeline_parallel_degree > 1:
+        model_class = AutoModelForCausalLMPipe
+
+    architectures_to_check = {"Qwen2Moe", "DeepseekV2", "DeepseekV3"}
+    if (
+        any(architecture in str(config.architectures) for architecture in architectures_to_check)
+        and training_args.data_parallel_degree > 1
+    ):
+        training_args.use_expert_parallel = True
+
+    if model_args.continue_training:
+        if training_args.autotuner_benchmark:
+            model = model_class.from_config(config, dtype=dtype)
+        else:
+            model = model_class.from_pretrained(
+                model_args.model_name_or_path,
+                config=config,
+                dtype=dtype,
+            )
+    else:
+        if training_args.enable_auto_parallel:
+            with paddle.LazyGuard():
+                model = model_class.from_config(config, dtype=dtype)
+        else:
+            model = model_class.from_config(config, dtype=dtype)
 
     if training_args.recompute:
@@ -340,7 +365,6 @@ def fn(layer):
 
     trainer = PretrainingTrainer(
         model=model,
-        criterion=criterion,
         args=training_args,
         data_collator=data_collator,
         train_dataset=train_dataset if training_args.do_train else None,
diff --git a/paddleformers/trainer/argparser.py b/paddleformers/trainer/argparser.py
index a58df14bda7..795128b2944 100644
--- a/paddleformers/trainer/argparser.py
+++ b/paddleformers/trainer/argparser.py
@@ -188,7 +188,6 @@ def _add_dataclass_arguments(self, dtype: DataClassType):
                 f"removing line of `from __future__ import annotations` which opts in Postponed "
                 f"Evaluation of Annotations (PEP 563)"
             )
-
         for field in dataclasses.fields(dtype):
             if not field.init:
                 continue
diff --git a/paddleformers/trainer/trainer.py b/paddleformers/trainer/trainer.py
index b44547181aa..4823055925e 100644
--- a/paddleformers/trainer/trainer.py
+++ b/paddleformers/trainer/trainer.py
@@ -369,7 +369,11 @@ def __init__(
         self._memory_tracker.start()
 
         # Seed must be set before instantiating the model when using model
-        set_random_seed(seed_=self.args.seed)
+        if not self.args.enable_auto_parallel:
+            set_random_seed(seed_=self.args.seed)
+        else:
+            logger.warning("set_random_seed is not supported in auto_parallel mode yet; falling back to set_seed")
+            set_seed(seed=self.args.seed)
 
         self._skip_global_steps = 0  # total skip global steps
diff --git a/paddleformers/transformers/configuration_utils.py b/paddleformers/transformers/configuration_utils.py
index 4a658e9af97..40d3358db20 100644
--- a/paddleformers/transformers/configuration_utils.py
+++ b/paddleformers/transformers/configuration_utils.py
@@ -539,6 +539,9 @@ class PretrainedConfig:
             Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the
             model has a output word embedding layer.
 
+        use_single_model_implementation (`bool`, *optional*, defaults to `False`):
+            Whether to run the model with its single-card implementation. When enabled, all parallel degrees are forced to 1.
+
         dtype (`str`, *optional*):
             The `dtype` of the weights. This attribute can be used to initialize the model to a non-default `dtype`
             (which is normally `float32`) and thus allow for optimal storage allocation. For example, if the saved
@@ -609,6 +612,13 @@ def __init__(self, **kwargs):
         self.use_cache = kwargs.pop("use_cache", False)
         self.tie_word_embeddings = kwargs.pop("tie_word_embeddings", True)
 
+        # for running the model in single-card mode
+        self.use_single_model_implementation = kwargs.pop("use_single_model_implementation", False)
+        if self.use_single_model_implementation:
+            self.tensor_parallel_degree = 1
+            self.sep_parallel_degree = 1
+            self.context_parallel_degree = 1
+
         # for transformers fuse
         self.fuse_linear = kwargs.pop("fuse_linear", False)
         self.fuse_attention_qkv = kwargs.pop("fuse_attention_qkv", False)
diff --git a/paddleformers/transformers/llama/auto_dist_config.py b/paddleformers/transformers/llama/auto_dist_config.py
new file mode 100644
index 00000000000..202e492ce1e
--- /dev/null
+++ b/paddleformers/transformers/llama/auto_dist_config.py
@@ -0,0 +1,40 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle.distributed as dist
+
+
+def get_dist_config(model, prefix=""):
+    """Generate distributed configuration for Llama model"""
+    if prefix != "":
+        assert prefix.endswith(".")
+
+    config = {
+        "mp_config": {
+            "parallelize_plan": {
+                f"{prefix}llama.embed_tokens": dist.ColWiseParallel(gather_output=True),
+                f"{prefix}llama.layers.*.self_attn.qkv_proj": dist.ColWiseParallel(),
+                f"{prefix}llama.layers.*.self_attn.q_proj": dist.ColWiseParallel(),
+                f"{prefix}llama.layers.*.self_attn.k_proj": dist.ColWiseParallel(),
+                f"{prefix}llama.layers.*.self_attn.v_proj": dist.ColWiseParallel(),
+                f"{prefix}llama.layers.*.self_attn.o_proj": dist.RowWiseParallel(),
+                f"{prefix}llama.layers.*.mlp.gate_proj": dist.ColWiseParallel(),
+                f"{prefix}llama.layers.*.mlp.up_proj": dist.ColWiseParallel(),
+                f"{prefix}llama.layers.*.mlp.gate_up_fused_proj": dist.ColWiseParallel(),
+                f"{prefix}llama.layers.*.mlp.down_proj": dist.RowWiseParallel(),
+                f"{prefix}lm_head.weight": dist.ColWiseParallel(),
+            }
+        },
+    }
+    return config
diff --git a/paddleformers/transformers/llama/modeling.py b/paddleformers/transformers/llama/modeling.py
index a457acbaa4d..0850071bc96 100644
--- a/paddleformers/transformers/llama/modeling.py
+++ b/paddleformers/transformers/llama/modeling.py
@@ -20,6 +20,11 @@
 from paddle.distributed.fleet.utils import recompute
 from paddle.distributed.fleet.utils.sequence_parallel_utils import ScatterOp
 
+from paddleformers.transformers.conversion_utils import (
+    StateDictNameMapping,
+    init_name_mappings,
+)
+
 from ...nn.attention.interface import ALL_ATTENTION_FUNCTIONS
 from ...nn.criterion.interface import CriterionLayer
 from ...nn.embedding import Embedding as GeneralEmbedding
@@ -34,6 +39,7 @@
 from ..model_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from ..model_utils import PretrainedModel, register_base_model
 from ..modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from .auto_dist_config import get_dist_config
 from .configuration import LlamaConfig
@@ -160,9 +166,9 @@ def forward(
         q_shape = (batch_size, seq_len, self.num_heads, self.head_dim)
         kv_shape = (batch_size, seq_len, self.num_key_value_heads, self.head_dim)
 
-        query_states = self.q_proj(hidden_states).view(q_shape).transpose(1, 2)
-        key_states = self.k_proj(hidden_states).view(kv_shape).transpose(1, 2)
-        value_states = self.v_proj(hidden_states).view(kv_shape).transpose(1, 2)
+        query_states = self.q_proj(hidden_states).reshape(q_shape).transpose(1, 2)
+        key_states = self.k_proj(hidden_states).reshape(kv_shape).transpose(1, 2)
+        value_states = self.v_proj(hidden_states).reshape(kv_shape).transpose(1, 2)
 
         cos, sin = position_embeddings
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
@@ -324,6 +330,40 @@ class LlamaPretrainedModel(PretrainedModel):
         "down_proj",
     ]
 
+    @classmethod
+    def _get_name_mappings(cls, config: LlamaConfig) -> list[StateDictNameMapping]:
+        mappings: list[StateDictNameMapping] = []
+        model_mappings = [
+            ["embed_tokens.weight"],
+            ["norm.weight"],
+        ]
+        for layer_index in range(config.num_hidden_layers):
+            layer_mappings = [
+                [f"layers.{layer_index}.self_attn.q_proj.weight", None, "transpose"],
+                [f"layers.{layer_index}.self_attn.k_proj.weight", None, "transpose"],
+                [f"layers.{layer_index}.self_attn.v_proj.weight", None, "transpose"],
+                [f"layers.{layer_index}.self_attn.o_proj.weight", None, "transpose"],
+                [f"layers.{layer_index}.self_attn.rotary_emb.inv_freq"],
+                [f"layers.{layer_index}.mlp.gate_proj.weight", None, "transpose"],
[f"layers.{layer_index}.mlp.down_proj.weight", None, "transpose"], + [f"layers.{layer_index}.mlp.up_proj.weight", None, "transpose"], + [f"layers.{layer_index}.input_layernorm.weight"], + [f"layers.{layer_index}.post_attention_layernorm.weight"], + ] + model_mappings.extend(layer_mappings) + + init_name_mappings(mappings=model_mappings) + # base-model prefix "LlamaModel" + if "LlamaModel" not in config.architectures: + for mapping in model_mappings: + mapping[0] = "model." + mapping[0] + mapping[1] = "llama." + mapping[1] + if not config.tie_word_embeddings: + model_mappings.append(["lm_head.weight", "lm_head.weight", "transpose"]) + + mappings = [StateDictNameMapping(*mapping, index=index) for index, mapping in enumerate(model_mappings)] + return mappings + @classmethod def _get_tensor_parallel_mappings(cls, config: LlamaConfig, is_split=True): from ..conversion_utils import split_or_merge_func @@ -689,6 +729,10 @@ def forward( attentions=outputs.attentions, ) + def auto_dist_config(self, prefix=""): + assert self.config.use_single_model_implementation, "Use `get_dist_config` only in single card mode." + return get_dist_config(self, prefix) + class LlamaForCausalLMPipe(GeneralModelForCausalLMPipe): config_class = LlamaConfig