From ef7b36ef791d56c487673344c677e2220b3f080e Mon Sep 17 00:00:00 2001 From: xuexixi Date: Wed, 5 Nov 2025 17:28:29 +0800 Subject: [PATCH 1/6] refactor model in intermediate api mode --- .../transformers/configuration_utils.py | 10 ++ .../transformers/llama/auto_dist_config.py | 113 ++++++++++++++++++ paddleformers/transformers/llama/modeling.py | 8 ++ 3 files changed, 131 insertions(+) create mode 100644 paddleformers/transformers/llama/auto_dist_config.py diff --git a/paddleformers/transformers/configuration_utils.py b/paddleformers/transformers/configuration_utils.py index 4a658e9af97..9b0c29632b8 100644 --- a/paddleformers/transformers/configuration_utils.py +++ b/paddleformers/transformers/configuration_utils.py @@ -539,6 +539,9 @@ class PretrainedConfig: Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the model has a output word embedding layer. + run_single_model (`bool`, *optional*, defaults to `False`): + Whether to run the model in single card mode. When enabled, all parallel degree configurations will be disabled. + dtype (`str`, *optional*): The `dtype` of the weights. This attribute can be used to initialize the model to a non-default `dtype` (which is normally `float32`) and thus allow for optimal storage allocation. For example, if the saved @@ -609,6 +612,13 @@ def __init__(self, **kwargs): self.use_cache = kwargs.pop("use_cache", False) self.tie_word_embeddings = kwargs.pop("tie_word_embeddings", True) + # for run model in single card mode + self.run_single_model = kwargs.pop("run_single_model", False) + if self.run_single_model: + self.tensor_parallel_degree = 1 + self.sep_parallel_degree = 1 + self.context_parallel_degree = 1 + # for transformers fuse self.fuse_linear = kwargs.pop("fuse_linear", False) self.fuse_attention_qkv = kwargs.pop("fuse_attention_qkv", False) diff --git a/paddleformers/transformers/llama/auto_dist_config.py b/paddleformers/transformers/llama/auto_dist_config.py new file mode 100644 index 00000000000..f3a6532fe1e --- /dev/null +++ b/paddleformers/transformers/llama/auto_dist_config.py @@ -0,0 +1,113 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import paddle.distributed as dist +from paddle.distributed.auto_parallel.intermediate.tensor_parallel import ( + PrepareLayerInput, +) + + +def layer_input_parallel_row_hook(process_mesh): + def hook(layer, inputs, output=None): + res_inputs = [] + for input in inputs: + if not input.is_dist(): + x = dist.shard_tensor(input, process_mesh, [dist.Shard(0), dist.Replicate()]) + res_inputs.append(dist.reshard(x, process_mesh, [dist.Shard(0), dist.Replicate()])) + else: + res_inputs.append(dist.reshard(input, process_mesh, [dist.Shard(0), dist.Replicate()])) + return tuple(res_inputs) + + return hook + + +def layer_input_parallel_row_and_col_hook(process_mesh): + def hook(layer, inputs, output=None): + res_inputs = [] + for input in inputs: + if not input.is_dist(): + x = dist.shard_tensor(input, process_mesh, [dist.Shard(0), dist.Shard(1)]) + res_inputs.append(dist.reshard(x, process_mesh, [dist.Shard(0), dist.Shard(1)])) + else: + res_inputs.append(dist.reshard(input, process_mesh, [dist.Shard(0), dist.Shard(1)])) + return tuple(res_inputs) + + return hook + + +def layer_input_replicate_hook(process_mesh): + def hook(layer, inputs, output=None): + res_inputs = [] + for input in inputs: + if not input.is_dist(): + x = dist.shard_tensor(input, process_mesh, [dist.Replicate(), dist.Replicate()]) + res_inputs.append(dist.reshard(x, process_mesh, [dist.Replicate(), dist.Replicate()])) + else: + res_inputs.append(dist.reshard(input, process_mesh, [dist.Replicate(), dist.Replicate()])) + return tuple(res_inputs) + + return hook + + +def auto_dist_config(self, prefix=""): + if prefix != "": + assert prefix.endswith(".") + config = { + "sp_config": { + "parallelize_plan": { + f"{prefix}llama.embed_tokens": [ + dist.ColWiseParallel(), + dist.SequenceParallelBegin(), + ], + f"{prefix}llama.reshard_row": PrepareLayerInput(layer_input_parallel_row_hook), + f"{prefix}llama.reshard_row_and_col": PrepareLayerInput(layer_input_parallel_row_and_col_hook), + f"{prefix}llama.global_layer.reshard_replicate": PrepareLayerInput(layer_input_replicate_hook), + f"{prefix}llama.layers.*.self_attn.qkv_proj": dist.ColWiseParallel(), + f"{prefix}llama.layers.*.self_attn.q_proj": dist.ColWiseParallel(), + f"{prefix}llama.layers.*.self_attn.k_proj": dist.ColWiseParallel(), + f"{prefix}llama.layers.*.self_attn.v_proj": dist.ColWiseParallel(), + f"{prefix}llama.layers.*.self_attn.o_proj": dist.RowWiseParallel(), + f"{prefix}llama.layers.*.self_attn": dist.SequenceParallelDisable(), + f"{prefix}llama.layers.*.mlp.gate_proj": dist.ColWiseParallel(), + f"{prefix}llama.layers.*.mlp.up_proj": dist.ColWiseParallel(), + f"{prefix}llama.layers.*.mlp.gate_up_fused_proj": dist.ColWiseParallel(), + f"{prefix}llama.layers.*.mlp.down_proj": dist.RowWiseParallel(), + f"{prefix}llama.layers.*.mlp": dist.SequenceParallelDisable(need_transpose=False), + f"{prefix}lm_head.weight": dist.ColWiseParallel(), + f"{prefix}lm_head": dist.SequenceParallelEnd(), + } + }, + "mp_config": { + "parallelize_plan": { + f"{prefix}llama.embed_tokens": dist.ColWiseParallel(gather_output=True), + f"{prefix}llama.reshard_row": PrepareLayerInput(layer_input_parallel_row_hook), + f"{prefix}llama.reshard_row_and_col": PrepareLayerInput(layer_input_parallel_row_and_col_hook), + f"{prefix}llama.global_layer.reshard_replicate": PrepareLayerInput(layer_input_replicate_hook), + f"{prefix}llama.layers.*.self_attn.qkv_proj": dist.ColWiseParallel(), + f"{prefix}llama.layers.*.self_attn.q_proj": dist.ColWiseParallel(), + f"{prefix}llama.layers.*.self_attn.k_proj": 
dist.ColWiseParallel(), + f"{prefix}llama.layers.*.self_attn.v_proj": dist.ColWiseParallel(), + f"{prefix}llama.layers.*.self_attn.o_proj": dist.RowWiseParallel(), + f"{prefix}llama.layers.*.mlp.gate_proj": dist.ColWiseParallel(), + f"{prefix}llama.layers.*.mlp.up_proj": dist.ColWiseParallel(), + f"{prefix}llama.layers.*.mlp.gate_up_fused_proj": dist.ColWiseParallel(), + f"{prefix}llama.layers.*.mlp.down_proj": dist.RowWiseParallel(), + f"{prefix}lm_head.weight": dist.ColWiseParallel(), + } + }, + "pp_config": {"split_spec": f"{prefix}llama.layers", "global_spec": f"{prefix}llama.global_layer"}, + } + + return config diff --git a/paddleformers/transformers/llama/modeling.py b/paddleformers/transformers/llama/modeling.py index a457acbaa4d..75208e2d1ba 100644 --- a/paddleformers/transformers/llama/modeling.py +++ b/paddleformers/transformers/llama/modeling.py @@ -28,6 +28,8 @@ from ...nn.mlp import MLP from ...nn.norm import Norm as GeneralNorm from ...nn.pp_model import GeneralModelForCausalLMPipe +from .auto_dist_config import get_dist_config + from ...utils.log import logger from ..cache_utils import Cache, DynamicCache from ..masking_utils import create_causal_mask_and_row_indices @@ -326,6 +328,8 @@ class LlamaPretrainedModel(PretrainedModel): @classmethod def _get_tensor_parallel_mappings(cls, config: LlamaConfig, is_split=True): + if config.run_single_model: + return {} from ..conversion_utils import split_or_merge_func fn = split_or_merge_func( @@ -689,6 +693,10 @@ def forward( attentions=outputs.attentions, ) + def auto_dist_config(self, prefix=""): + assert self.config.run_single_model, "Use `get_dist_config` only in single card mode." + return get_dist_config(self, prefix) + class LlamaForCausalLMPipe(GeneralModelForCausalLMPipe): config_class = LlamaConfig From c57e5b1c0f6c88d5e84cf8b046dcac028a5fbeac Mon Sep 17 00:00:00 2001 From: xuexixi Date: Wed, 5 Nov 2025 19:03:20 +0800 Subject: [PATCH 2/6] update auto dist config --- .../transformers/llama/auto_dist_config.py | 104 ++++++++++++++++-- 1 file changed, 94 insertions(+), 10 deletions(-) diff --git a/paddleformers/transformers/llama/auto_dist_config.py b/paddleformers/transformers/llama/auto_dist_config.py index f3a6532fe1e..bcb526a9856 100644 --- a/paddleformers/transformers/llama/auto_dist_config.py +++ b/paddleformers/transformers/llama/auto_dist_config.py @@ -12,10 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
- +import paddle import paddle.distributed as dist from paddle.distributed.auto_parallel.intermediate.tensor_parallel import ( PrepareLayerInput, + PrepareLayerOutput, ) @@ -24,10 +25,12 @@ def hook(layer, inputs, output=None): res_inputs = [] for input in inputs: if not input.is_dist(): - x = dist.shard_tensor(input, process_mesh, [dist.Shard(0), dist.Replicate()]) - res_inputs.append(dist.reshard(x, process_mesh, [dist.Shard(0), dist.Replicate()])) + x = dist.shard_tensor(input, process_mesh, [dist.Shard(0), dist.Replicate(), dist.Replicate()]) + res_inputs.append(dist.reshard(x, process_mesh, [dist.Shard(0), dist.Replicate(), dist.Replicate()])) else: - res_inputs.append(dist.reshard(input, process_mesh, [dist.Shard(0), dist.Replicate()])) + res_inputs.append( + dist.reshard(input, process_mesh, [dist.Shard(0), dist.Replicate(), dist.Replicate()]) + ) return tuple(res_inputs) return hook @@ -38,10 +41,10 @@ def hook(layer, inputs, output=None): res_inputs = [] for input in inputs: if not input.is_dist(): - x = dist.shard_tensor(input, process_mesh, [dist.Shard(0), dist.Shard(1)]) - res_inputs.append(dist.reshard(x, process_mesh, [dist.Shard(0), dist.Shard(1)])) + x = dist.shard_tensor(input, process_mesh, [dist.Shard(0), dist.Replicate(), dist.Shard(1)]) + res_inputs.append(dist.reshard(x, process_mesh, [dist.Shard(0), dist.Replicate(), dist.Shard(1)])) else: - res_inputs.append(dist.reshard(input, process_mesh, [dist.Shard(0), dist.Shard(1)])) + res_inputs.append(dist.reshard(input, process_mesh, [dist.Shard(0), dist.Replicate(), dist.Shard(1)])) return tuple(res_inputs) return hook @@ -52,8 +55,10 @@ def hook(layer, inputs, output=None): res_inputs = [] for input in inputs: if not input.is_dist(): - x = dist.shard_tensor(input, process_mesh, [dist.Replicate(), dist.Replicate()]) - res_inputs.append(dist.reshard(x, process_mesh, [dist.Replicate(), dist.Replicate()])) + x = dist.shard_tensor(input, process_mesh, [dist.Replicate(), dist.Replicate(), dist.Replicate()]) + res_inputs.append( + dist.reshard(x, process_mesh, [dist.Replicate(), dist.Replicate(), dist.Replicate()]) + ) else: res_inputs.append(dist.reshard(input, process_mesh, [dist.Replicate(), dist.Replicate()])) return tuple(res_inputs) @@ -61,9 +66,64 @@ def hook(layer, inputs, output=None): return hook -def auto_dist_config(self, prefix=""): +def layer_input_rope_hook(process_mesh): + def hook(layer, inputs, output=None): + res_inputs = [] + batch_size = None + seq_length = None + process_mesh = None + placements = None + for index in range(len(inputs)): + if index == 0: + batch_size, seq_length, _, _ = inputs[index]._local_shape + process_mesh = inputs[index].process_mesh + placements = inputs[index].placements + # process position_ids + if index == len(inputs) - 1: + mesh = dist.auto_parallel.get_mesh() + assert "sep" in mesh.dim_names, f"mesh.dim_names:{mesh.dim_names} must contain sep" + group = mesh._get_group("sep") + chunk_size = seq_length // 2 + chunk_num = group.nranks * 2 + rank = group.rank + first_chunk_ids = paddle.arange(rank * chunk_size, (rank + 1) * chunk_size, dtype="int64") + second_chunk_ids = paddle.arange( + (chunk_num - rank - 1) * chunk_size, (chunk_num - rank) * chunk_size, dtype="int64" + ) + position_ids = paddle.concat([first_chunk_ids, second_chunk_ids]).expand((batch_size, seq_length)) + mp_axis = process_mesh.dim_names.index("mp") + placements[mp_axis] = dist.Replicate() # mp placament shard(2) -> replicate + position_ids = dist.auto_parallel.api.dtensor_from_local(position_ids, process_mesh, 
placements) + res_inputs.append(position_ids) + else: + res_inputs.append(inputs[index]) + return tuple(res_inputs) + + return hook + + +def layer_output_rope_hook(process_mesh): + def hook(layer, inputs, outputs): + res_outputs = [] + for output in outputs: + process_mesh = output.process_mesh + placements = output.placements + cp_index = process_mesh.dim_names.index("sep") # get the axis for the split + cp_degree = process_mesh.shape[cp_index] + assert cp_degree > 1, f"cp_degree:{cp_degree} must > 1" + placements[cp_index] = dist.Shard(1) # seq_dim:1 + output = dist.reshard(output, process_mesh, placements) + res_outputs.append(output) + return tuple(res_outputs) + + return hook + + +def get_dist_config(model, prefix=""): + """Generate distributed configuration for Llama model""" if prefix != "": assert prefix.endswith(".") + config = { "sp_config": { "parallelize_plan": { @@ -108,6 +168,30 @@ def auto_dist_config(self, prefix=""): } }, "pp_config": {"split_spec": f"{prefix}llama.layers", "global_spec": f"{prefix}llama.global_layer"}, + "cp_config": { + "parallelize_plan": { + f"{prefix}llama.layers.*.self_attn.sdpa": dist.ContextParallel( + backend="p2p" if model.config.context_parallel_degree > 1 else "all2all" + ), + } + }, } + if model.config.context_parallel_degree > 1: + config["cp_config"]["parallelize_plan"].update( + { + f"{prefix}llama.layers.*.self_attn.rope_func": [ + PrepareLayerInput(layer_input_rope_hook), + PrepareLayerOutput(layer_output_rope_hook), + ] + } + ) + elif model.config.sep_parallel_degree > 1: + # fuse_rope is not support dtensor spmd yet,thus need to extraly reshard sequence dim + config["cp_config"]["parallelize_plan"].update( + { + f"{prefix}llama.layers.*.self_attn.rope_func": PrepareLayerOutput(layer_output_rope_hook), + } + ) + return config From 1c76af6900bca5497656444a208d72b897ab9d5e Mon Sep 17 00:00:00 2001 From: xuexixi Date: Tue, 11 Nov 2025 13:48:26 +0800 Subject: [PATCH 3/6] fix parallel_matmul --- paddleformers/transformers/llama/modeling.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/paddleformers/transformers/llama/modeling.py b/paddleformers/transformers/llama/modeling.py index 75208e2d1ba..4daedc9004b 100644 --- a/paddleformers/transformers/llama/modeling.py +++ b/paddleformers/transformers/llama/modeling.py @@ -328,8 +328,6 @@ class LlamaPretrainedModel(PretrainedModel): @classmethod def _get_tensor_parallel_mappings(cls, config: LlamaConfig, is_split=True): - if config.run_single_model: - return {} from ..conversion_utils import split_or_merge_func fn = split_or_merge_func( From 464668feac84ff95f96765e29331194a23c78653 Mon Sep 17 00:00:00 2001 From: xuexixi Date: Thu, 13 Nov 2025 14:53:10 +0800 Subject: [PATCH 4/6] adapt workflow in auto parallel --- .../cli/train/auto_parallel/workflow.py | 62 +++++-- .../transformers/llama/auto_dist_config.py | 157 ------------------ paddleformers/transformers/llama/modeling.py | 1 + 3 files changed, 50 insertions(+), 170 deletions(-) diff --git a/paddleformers/cli/train/auto_parallel/workflow.py b/paddleformers/cli/train/auto_parallel/workflow.py index f556074d4cb..99b4c9e5258 100644 --- a/paddleformers/cli/train/auto_parallel/workflow.py +++ b/paddleformers/cli/train/auto_parallel/workflow.py @@ -27,6 +27,9 @@ from paddleformers.trainer.trainer import Trainer from paddleformers.trainer.trainer_utils import set_seed from paddleformers.transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoModelForCausalLMPipe, AutoTokenizer, CosineAnnealingWithWarmupDecay, 
LinearAnnealingWithWarmupDecay, @@ -144,7 +147,6 @@ def __init__(self, *args, **kwargs): def run_auto_parallel(model_args, data_args, generating_args, training_args): - do_enable_linear_fused_grad_add = training_args.enable_linear_fused_grad_add # do_enable_mp_async_allreduce = ( # training_args.enable_auto_parallel @@ -202,15 +204,8 @@ def run_auto_parallel(model_args, data_args, generating_args, training_args): "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." ) - # TODO: only support llama model now - config_class = LlamaConfig - model_class = LlamaForCausalLM - - config = config_class.from_pretrained(model_args.model_name_or_path) tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name_or_path) - if tokenizer.pad_token_id is None: - tokenizer.pad_token_id = tokenizer.eos_token_id - # config = AutoConfig.from_pretrained(model_args.model_name_or_path) + config = AutoConfig.from_pretrained(model_args.model_name_or_path) LlmMetaConfig.set_llm_config(config, training_args) config.use_fast_layer_norm = model_args.use_fast_layer_norm @@ -272,6 +267,13 @@ def run_auto_parallel(model_args, data_args, generating_args, training_args): if training_args.no_recompute_layers is not None: training_args.no_recompute_layers.sort() + if training_args.use_intermediate_api: + config.run_single_model = True + config.tensor_parallel_degree = 1 + config.sharding_parallel_degree = 1 + config.sep_parallel_degree = 1 + config.context_parallel_degree = 1 + print("Final pre-training config:", config) # Set the dtype for loading model @@ -282,9 +284,44 @@ def run_auto_parallel(model_args, data_args, generating_args, training_args): if training_args.bf16: dtype = "bfloat16" - with paddle.LazyGuard(): - model = model_class.from_config(config, dtype=dtype) - criterion = model.criterion + model_class = AutoModelForCausalLM + + if not training_args.enable_auto_parallel and training_args.pipeline_parallel_degree > 1: + model_class = AutoModelForCausalLMPipe + if "LLama" in str(config.architectures): + try: + from utils.register_reshard import register_pp_reshard_information + + register_pp_reshard_information(config.num_hidden_layers) + except: + print("Not register llama pp reshard information.") + + architectures_to_check = {"Qwen2Moe", "DeepseekV2", "DeepseekV3"} + if ( + any(architecture in str(config.architectures) for architecture in architectures_to_check) + and training_args.data_parallel_degree > 1 + ): + training_args.use_expert_parallel = True + + if model_args.continue_training: + # NOTE(gongenlei): new add + if training_args.autotuner_benchmark: + model = model_class.from_config(config, dtype=dtype) + else: + model = model_class.from_pretrained( + model_args.model_name_or_path, + config=config, + dtype=dtype, + ) + else: + if training_args.enable_auto_parallel: + with paddle.LazyGuard(): + model = model_class.from_config(config, dtype=dtype) + else: + model = model_class.from_config(config, dtype=dtype) + + criterion = model.criterion + if training_args.recompute: @@ -340,7 +377,6 @@ def fn(layer): trainer = PretrainingTrainer( model=model, - criterion=criterion, args=training_args, data_collator=data_collator, train_dataset=train_dataset if training_args.do_train else None, diff --git a/paddleformers/transformers/llama/auto_dist_config.py b/paddleformers/transformers/llama/auto_dist_config.py index bcb526a9856..202e492ce1e 100644 --- a/paddleformers/transformers/llama/auto_dist_config.py +++ b/paddleformers/transformers/llama/auto_dist_config.py @@ -12,111 +12,7 @@ # 
See the License for the specific language governing permissions and # limitations under the License. -import paddle import paddle.distributed as dist -from paddle.distributed.auto_parallel.intermediate.tensor_parallel import ( - PrepareLayerInput, - PrepareLayerOutput, -) - - -def layer_input_parallel_row_hook(process_mesh): - def hook(layer, inputs, output=None): - res_inputs = [] - for input in inputs: - if not input.is_dist(): - x = dist.shard_tensor(input, process_mesh, [dist.Shard(0), dist.Replicate(), dist.Replicate()]) - res_inputs.append(dist.reshard(x, process_mesh, [dist.Shard(0), dist.Replicate(), dist.Replicate()])) - else: - res_inputs.append( - dist.reshard(input, process_mesh, [dist.Shard(0), dist.Replicate(), dist.Replicate()]) - ) - return tuple(res_inputs) - - return hook - - -def layer_input_parallel_row_and_col_hook(process_mesh): - def hook(layer, inputs, output=None): - res_inputs = [] - for input in inputs: - if not input.is_dist(): - x = dist.shard_tensor(input, process_mesh, [dist.Shard(0), dist.Replicate(), dist.Shard(1)]) - res_inputs.append(dist.reshard(x, process_mesh, [dist.Shard(0), dist.Replicate(), dist.Shard(1)])) - else: - res_inputs.append(dist.reshard(input, process_mesh, [dist.Shard(0), dist.Replicate(), dist.Shard(1)])) - return tuple(res_inputs) - - return hook - - -def layer_input_replicate_hook(process_mesh): - def hook(layer, inputs, output=None): - res_inputs = [] - for input in inputs: - if not input.is_dist(): - x = dist.shard_tensor(input, process_mesh, [dist.Replicate(), dist.Replicate(), dist.Replicate()]) - res_inputs.append( - dist.reshard(x, process_mesh, [dist.Replicate(), dist.Replicate(), dist.Replicate()]) - ) - else: - res_inputs.append(dist.reshard(input, process_mesh, [dist.Replicate(), dist.Replicate()])) - return tuple(res_inputs) - - return hook - - -def layer_input_rope_hook(process_mesh): - def hook(layer, inputs, output=None): - res_inputs = [] - batch_size = None - seq_length = None - process_mesh = None - placements = None - for index in range(len(inputs)): - if index == 0: - batch_size, seq_length, _, _ = inputs[index]._local_shape - process_mesh = inputs[index].process_mesh - placements = inputs[index].placements - # process position_ids - if index == len(inputs) - 1: - mesh = dist.auto_parallel.get_mesh() - assert "sep" in mesh.dim_names, f"mesh.dim_names:{mesh.dim_names} must contain sep" - group = mesh._get_group("sep") - chunk_size = seq_length // 2 - chunk_num = group.nranks * 2 - rank = group.rank - first_chunk_ids = paddle.arange(rank * chunk_size, (rank + 1) * chunk_size, dtype="int64") - second_chunk_ids = paddle.arange( - (chunk_num - rank - 1) * chunk_size, (chunk_num - rank) * chunk_size, dtype="int64" - ) - position_ids = paddle.concat([first_chunk_ids, second_chunk_ids]).expand((batch_size, seq_length)) - mp_axis = process_mesh.dim_names.index("mp") - placements[mp_axis] = dist.Replicate() # mp placament shard(2) -> replicate - position_ids = dist.auto_parallel.api.dtensor_from_local(position_ids, process_mesh, placements) - res_inputs.append(position_ids) - else: - res_inputs.append(inputs[index]) - return tuple(res_inputs) - - return hook - - -def layer_output_rope_hook(process_mesh): - def hook(layer, inputs, outputs): - res_outputs = [] - for output in outputs: - process_mesh = output.process_mesh - placements = output.placements - cp_index = process_mesh.dim_names.index("sep") # get the axis for the split - cp_degree = process_mesh.shape[cp_index] - assert cp_degree > 1, f"cp_degree:{cp_degree} must > 
1" - placements[cp_index] = dist.Shard(1) # seq_dim:1 - output = dist.reshard(output, process_mesh, placements) - res_outputs.append(output) - return tuple(res_outputs) - - return hook def get_dist_config(model, prefix=""): @@ -125,36 +21,9 @@ def get_dist_config(model, prefix=""): assert prefix.endswith(".") config = { - "sp_config": { - "parallelize_plan": { - f"{prefix}llama.embed_tokens": [ - dist.ColWiseParallel(), - dist.SequenceParallelBegin(), - ], - f"{prefix}llama.reshard_row": PrepareLayerInput(layer_input_parallel_row_hook), - f"{prefix}llama.reshard_row_and_col": PrepareLayerInput(layer_input_parallel_row_and_col_hook), - f"{prefix}llama.global_layer.reshard_replicate": PrepareLayerInput(layer_input_replicate_hook), - f"{prefix}llama.layers.*.self_attn.qkv_proj": dist.ColWiseParallel(), - f"{prefix}llama.layers.*.self_attn.q_proj": dist.ColWiseParallel(), - f"{prefix}llama.layers.*.self_attn.k_proj": dist.ColWiseParallel(), - f"{prefix}llama.layers.*.self_attn.v_proj": dist.ColWiseParallel(), - f"{prefix}llama.layers.*.self_attn.o_proj": dist.RowWiseParallel(), - f"{prefix}llama.layers.*.self_attn": dist.SequenceParallelDisable(), - f"{prefix}llama.layers.*.mlp.gate_proj": dist.ColWiseParallel(), - f"{prefix}llama.layers.*.mlp.up_proj": dist.ColWiseParallel(), - f"{prefix}llama.layers.*.mlp.gate_up_fused_proj": dist.ColWiseParallel(), - f"{prefix}llama.layers.*.mlp.down_proj": dist.RowWiseParallel(), - f"{prefix}llama.layers.*.mlp": dist.SequenceParallelDisable(need_transpose=False), - f"{prefix}lm_head.weight": dist.ColWiseParallel(), - f"{prefix}lm_head": dist.SequenceParallelEnd(), - } - }, "mp_config": { "parallelize_plan": { f"{prefix}llama.embed_tokens": dist.ColWiseParallel(gather_output=True), - f"{prefix}llama.reshard_row": PrepareLayerInput(layer_input_parallel_row_hook), - f"{prefix}llama.reshard_row_and_col": PrepareLayerInput(layer_input_parallel_row_and_col_hook), - f"{prefix}llama.global_layer.reshard_replicate": PrepareLayerInput(layer_input_replicate_hook), f"{prefix}llama.layers.*.self_attn.qkv_proj": dist.ColWiseParallel(), f"{prefix}llama.layers.*.self_attn.q_proj": dist.ColWiseParallel(), f"{prefix}llama.layers.*.self_attn.k_proj": dist.ColWiseParallel(), @@ -167,31 +36,5 @@ def get_dist_config(model, prefix=""): f"{prefix}lm_head.weight": dist.ColWiseParallel(), } }, - "pp_config": {"split_spec": f"{prefix}llama.layers", "global_spec": f"{prefix}llama.global_layer"}, - "cp_config": { - "parallelize_plan": { - f"{prefix}llama.layers.*.self_attn.sdpa": dist.ContextParallel( - backend="p2p" if model.config.context_parallel_degree > 1 else "all2all" - ), - } - }, } - - if model.config.context_parallel_degree > 1: - config["cp_config"]["parallelize_plan"].update( - { - f"{prefix}llama.layers.*.self_attn.rope_func": [ - PrepareLayerInput(layer_input_rope_hook), - PrepareLayerOutput(layer_output_rope_hook), - ] - } - ) - elif model.config.sep_parallel_degree > 1: - # fuse_rope is not support dtensor spmd yet,thus need to extraly reshard sequence dim - config["cp_config"]["parallelize_plan"].update( - { - f"{prefix}llama.layers.*.self_attn.rope_func": PrepareLayerOutput(layer_output_rope_hook), - } - ) - return config diff --git a/paddleformers/transformers/llama/modeling.py b/paddleformers/transformers/llama/modeling.py index 4daedc9004b..b437ced81c5 100644 --- a/paddleformers/transformers/llama/modeling.py +++ b/paddleformers/transformers/llama/modeling.py @@ -328,6 +328,7 @@ class LlamaPretrainedModel(PretrainedModel): @classmethod def 
_get_tensor_parallel_mappings(cls, config: LlamaConfig, is_split=True): + from ..conversion_utils import split_or_merge_func fn = split_or_merge_func( From d0e3854c4cac6ca677833087610552fcd97ca5aa Mon Sep 17 00:00:00 2001 From: dengsiwei02 Date: Fri, 5 Dec 2025 11:04:36 +0800 Subject: [PATCH 5/6] rename run_single_model and remove redundant code --- paddleformers/cli/train/auto_parallel/workflow.py | 10 +--------- paddleformers/transformers/configuration_utils.py | 6 +++--- paddleformers/transformers/llama/modeling.py | 2 +- 3 files changed, 5 insertions(+), 13 deletions(-) diff --git a/paddleformers/cli/train/auto_parallel/workflow.py b/paddleformers/cli/train/auto_parallel/workflow.py index 99b4c9e5258..d3fff85791b 100644 --- a/paddleformers/cli/train/auto_parallel/workflow.py +++ b/paddleformers/cli/train/auto_parallel/workflow.py @@ -268,7 +268,7 @@ def run_auto_parallel(model_args, data_args, generating_args, training_args): training_args.no_recompute_layers.sort() if training_args.use_intermediate_api: - config.run_single_model = True + config.use_single_model_implementation = True config.tensor_parallel_degree = 1 config.sharding_parallel_degree = 1 config.sep_parallel_degree = 1 @@ -288,13 +288,6 @@ def run_auto_parallel(model_args, data_args, generating_args, training_args): if not training_args.enable_auto_parallel and training_args.pipeline_parallel_degree > 1: model_class = AutoModelForCausalLMPipe - if "LLama" in str(config.architectures): - try: - from utils.register_reshard import register_pp_reshard_information - - register_pp_reshard_information(config.num_hidden_layers) - except: - print("Not register llama pp reshard information.") architectures_to_check = {"Qwen2Moe", "DeepseekV2", "DeepseekV3"} if ( @@ -304,7 +297,6 @@ def run_auto_parallel(model_args, data_args, generating_args, training_args): training_args.use_expert_parallel = True if model_args.continue_training: - # NOTE(gongenlei): new add if training_args.autotuner_benchmark: model = model_class.from_config(config, dtype=dtype) else: diff --git a/paddleformers/transformers/configuration_utils.py b/paddleformers/transformers/configuration_utils.py index 9b0c29632b8..40d3358db20 100644 --- a/paddleformers/transformers/configuration_utils.py +++ b/paddleformers/transformers/configuration_utils.py @@ -539,7 +539,7 @@ class PretrainedConfig: Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the model has a output word embedding layer. - run_single_model (`bool`, *optional*, defaults to `False`): + use_single_model_implementation (`bool`, *optional*, defaults to `False`): Whether to run the model in single card mode. When enabled, all parallel degree configurations will be disabled. 
dtype (`str`, *optional*): @@ -613,8 +613,8 @@ def __init__(self, **kwargs): self.tie_word_embeddings = kwargs.pop("tie_word_embeddings", True) # for run model in single card mode - self.run_single_model = kwargs.pop("run_single_model", False) - if self.run_single_model: + self.use_single_model_implementation = kwargs.pop("use_single_model_implementation", False) + if self.use_single_model_implementation: self.tensor_parallel_degree = 1 self.sep_parallel_degree = 1 self.context_parallel_degree = 1 diff --git a/paddleformers/transformers/llama/modeling.py b/paddleformers/transformers/llama/modeling.py index b437ced81c5..27bbcc0db1a 100644 --- a/paddleformers/transformers/llama/modeling.py +++ b/paddleformers/transformers/llama/modeling.py @@ -693,7 +693,7 @@ def forward( ) def auto_dist_config(self, prefix=""): - assert self.config.run_single_model, "Use `get_dist_config` only in single card mode." + assert self.config.use_single_model_implementation, "Use `get_dist_config` only in single card mode." return get_dist_config(self, prefix) From 05b7388308bcc96b0d7be59385171a5b260a6512 Mon Sep 17 00:00:00 2001 From: dengsiwei02 Date: Mon, 8 Dec 2025 16:45:59 +0800 Subject: [PATCH 6/6] fix conflict --- .../cli/train/auto_parallel/workflow.py | 6 +-- paddleformers/trainer/argparser.py | 1 - paddleformers/trainer/trainer.py | 6 ++- paddleformers/transformers/llama/modeling.py | 49 ++++++++++++++++--- 4 files changed, 49 insertions(+), 13 deletions(-) diff --git a/paddleformers/cli/train/auto_parallel/workflow.py b/paddleformers/cli/train/auto_parallel/workflow.py index d3fff85791b..c2cd46f4ab7 100644 --- a/paddleformers/cli/train/auto_parallel/workflow.py +++ b/paddleformers/cli/train/auto_parallel/workflow.py @@ -33,8 +33,6 @@ AutoTokenizer, CosineAnnealingWithWarmupDecay, LinearAnnealingWithWarmupDecay, - LlamaConfig, - LlamaForCausalLM, ) from paddleformers.transformers.configuration_utils import LlmMetaConfig from paddleformers.utils.log import logger @@ -147,6 +145,7 @@ def __init__(self, *args, **kwargs): def run_auto_parallel(model_args, data_args, generating_args, training_args): + do_enable_linear_fused_grad_add = training_args.enable_linear_fused_grad_add # do_enable_mp_async_allreduce = ( # training_args.enable_auto_parallel @@ -311,9 +310,6 @@ def run_auto_parallel(model_args, data_args, generating_args, training_args): model = model_class.from_config(config, dtype=dtype) else: model = model_class.from_config(config, dtype=dtype) - - criterion = model.criterion - if training_args.recompute: diff --git a/paddleformers/trainer/argparser.py b/paddleformers/trainer/argparser.py index a58df14bda7..795128b2944 100644 --- a/paddleformers/trainer/argparser.py +++ b/paddleformers/trainer/argparser.py @@ -188,7 +188,6 @@ def _add_dataclass_arguments(self, dtype: DataClassType): f"removing line of `from __future__ import annotations` which opts in Postponed " f"Evaluation of Annotations (PEP 563)" ) - for field in dataclasses.fields(dtype): if not field.init: continue diff --git a/paddleformers/trainer/trainer.py b/paddleformers/trainer/trainer.py index b44547181aa..4823055925e 100644 --- a/paddleformers/trainer/trainer.py +++ b/paddleformers/trainer/trainer.py @@ -369,7 +369,11 @@ def __init__( self._memory_tracker.start() # Seed must be set before instantiating the model when using model - set_random_seed(seed_=self.args.seed) + if not self.args.enable_auto_parallel: + set_random_seed(seed_=self.args.seed) + else: + logger.warning("set_seed not support yet in auto_parallel mode") + 
set_seed(seed=self.args.seed) self._skip_global_steps = 0 # total skip global steps diff --git a/paddleformers/transformers/llama/modeling.py b/paddleformers/transformers/llama/modeling.py index 27bbcc0db1a..0850071bc96 100644 --- a/paddleformers/transformers/llama/modeling.py +++ b/paddleformers/transformers/llama/modeling.py @@ -20,6 +20,11 @@ from paddle.distributed.fleet.utils import recompute from paddle.distributed.fleet.utils.sequence_parallel_utils import ScatterOp +from paddleformers.transformers.conversion_utils import ( + StateDictNameMapping, + init_name_mappings, +) + from ...nn.attention.interface import ALL_ATTENTION_FUNCTIONS from ...nn.criterion.interface import CriterionLayer from ...nn.embedding import Embedding as GeneralEmbedding @@ -28,14 +33,13 @@ from ...nn.mlp import MLP from ...nn.norm import Norm as GeneralNorm from ...nn.pp_model import GeneralModelForCausalLMPipe -from .auto_dist_config import get_dist_config - from ...utils.log import logger from ..cache_utils import Cache, DynamicCache from ..masking_utils import create_causal_mask_and_row_indices from ..model_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ..model_utils import PretrainedModel, register_base_model from ..modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from .auto_dist_config import get_dist_config from .configuration import LlamaConfig @@ -162,9 +166,9 @@ def forward( q_shape = (batch_size, seq_len, self.num_heads, self.head_dim) kv_shape = (batch_size, seq_len, self.num_key_value_heads, self.head_dim) - query_states = self.q_proj(hidden_states).view(q_shape).transpose(1, 2) - key_states = self.k_proj(hidden_states).view(kv_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(kv_shape).transpose(1, 2) + query_states = self.q_proj(hidden_states).reshape(q_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).reshape(kv_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).reshape(kv_shape).transpose(1, 2) cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) @@ -327,8 +331,41 @@ class LlamaPretrainedModel(PretrainedModel): ] @classmethod - def _get_tensor_parallel_mappings(cls, config: LlamaConfig, is_split=True): + def _get_name_mappings(cls, config: LlamaConfig) -> list[StateDictNameMapping]: + mappings: list[StateDictNameMapping] = [] + model_mappings = [ + ["embed_tokens.weight"], + ["norm.weight"], + ] + for layer_index in range(config.num_hidden_layers): + layer_mappings = [ + [f"layers.{layer_index}.self_attn.q_proj.weight", None, "transpose"], + [f"layers.{layer_index}.self_attn.k_proj.weight", None, "transpose"], + [f"layers.{layer_index}.self_attn.v_proj.weight", None, "transpose"], + [f"layers.{layer_index}.self_attn.o_proj.weight", None, "transpose"], + [f"layers.{layer_index}.self_attn.rotary_emb.inv_freq"], + [f"layers.{layer_index}.mlp.gate_proj.weight", None, "transpose"], + [f"layers.{layer_index}.mlp.down_proj.weight", None, "transpose"], + [f"layers.{layer_index}.mlp.up_proj.weight", None, "transpose"], + [f"layers.{layer_index}.input_layernorm.weight"], + [f"layers.{layer_index}.post_attention_layernorm.weight"], + ] + model_mappings.extend(layer_mappings) + + init_name_mappings(mappings=model_mappings) + # base-model prefix "LlamaModel" + if "LlamaModel" not in config.architectures: + for mapping in model_mappings: + mapping[0] = "model." + mapping[0] + mapping[1] = "llama." 
+ mapping[1] + if not config.tie_word_embeddings: + model_mappings.append(["lm_head.weight", "lm_head.weight", "transpose"]) + + mappings = [StateDictNameMapping(*mapping, index=index) for index, mapping in enumerate(model_mappings)] + return mappings + @classmethod + def _get_tensor_parallel_mappings(cls, config: LlamaConfig, is_split=True): from ..conversion_utils import split_or_merge_func fn = split_or_merge_func(