Commit aac30fd

[AutoParallel] Refactor llama3.1 model in intermediate api (#3116)
Co-authored-by: xuexixi <xuexixi@baidu.com>
1 parent 60d38da commit aac30fd

6 files changed: +140 −19 lines changed

paddleformers/cli/train/auto_parallel/workflow.py

Lines changed: 38 additions & 14 deletions
@@ -27,11 +27,12 @@
 from paddleformers.trainer.trainer import Trainer
 from paddleformers.trainer.trainer_utils import set_seed
 from paddleformers.transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    AutoModelForCausalLMPipe,
     AutoTokenizer,
     CosineAnnealingWithWarmupDecay,
     LinearAnnealingWithWarmupDecay,
-    LlamaConfig,
-    LlamaForCausalLM,
 )
 from paddleformers.transformers.configuration_utils import LlmMetaConfig
 from paddleformers.utils.log import logger
@@ -202,15 +203,8 @@ def run_auto_parallel(model_args, data_args, generating_args, training_args):
             "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
         )

-    # TODO: only support llama model now
-    config_class = LlamaConfig
-    model_class = LlamaForCausalLM
-
-    config = config_class.from_pretrained(model_args.model_name_or_path)
     tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name_or_path)
-    if tokenizer.pad_token_id is None:
-        tokenizer.pad_token_id = tokenizer.eos_token_id
-    # config = AutoConfig.from_pretrained(model_args.model_name_or_path)
+    config = AutoConfig.from_pretrained(model_args.model_name_or_path)
     LlmMetaConfig.set_llm_config(config, training_args)
     config.use_fast_layer_norm = model_args.use_fast_layer_norm

@@ -272,6 +266,13 @@ def run_auto_parallel(model_args, data_args, generating_args, training_args):
     if training_args.no_recompute_layers is not None:
         training_args.no_recompute_layers.sort()

+    if training_args.use_intermediate_api:
+        config.use_single_model_implementation = True
+        config.tensor_parallel_degree = 1
+        config.sharding_parallel_degree = 1
+        config.sep_parallel_degree = 1
+        config.context_parallel_degree = 1
+
     print("Final pre-training config:", config)

     # Set the dtype for loading model
@@ -282,9 +283,33 @@
     if training_args.bf16:
         dtype = "bfloat16"

-    with paddle.LazyGuard():
-        model = model_class.from_config(config, dtype=dtype)
-        criterion = model.criterion
+    model_class = AutoModelForCausalLM
+
+    if not training_args.enable_auto_parallel and training_args.pipeline_parallel_degree > 1:
+        model_class = AutoModelForCausalLMPipe
+
+    architectures_to_check = {"Qwen2Moe", "DeepseekV2", "DeepseekV3"}
+    if (
+        any(architecture in str(config.architectures) for architecture in architectures_to_check)
+        and training_args.data_parallel_degree > 1
+    ):
+        training_args.use_expert_parallel = True
+
+    if model_args.continue_training:
+        if training_args.autotuner_benchmark:
+            model = model_class.from_config(config, dtype=dtype)
+        else:
+            model = model_class.from_pretrained(
+                model_args.model_name_or_path,
+                config=config,
+                dtype=dtype,
+            )
+    else:
+        if training_args.enable_auto_parallel:
+            with paddle.LazyGuard():
+                model = model_class.from_config(config, dtype=dtype)
+        else:
+            model = model_class.from_config(config, dtype=dtype)

     if training_args.recompute:

@@ -340,7 +365,6 @@ def fn(layer):

     trainer = PretrainingTrainer(
         model=model,
-        criterion=criterion,
         args=training_args,
         data_collator=data_collator,
         train_dataset=train_dataset if training_args.do_train else None,
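
For orientation, the refactored loading path above can be exercised outside the trainer roughly as follows. This is a minimal sketch: the checkpoint path and dtype are placeholders, not values from this commit; everything else (AutoConfig/AutoModelForCausalLM, LazyGuard, the single-model flags) comes from the diff.

# Minimal sketch of the new loading path in workflow.py.
import paddle

from paddleformers.transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

model_name_or_path = "path/to/llama3.1-checkpoint"  # hypothetical placeholder

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
config = AutoConfig.from_pretrained(model_name_or_path)

# Mirrors the `use_intermediate_api` branch: force the single-model
# implementation and reset all parallel degrees to 1.
config.use_single_model_implementation = True
config.tensor_parallel_degree = 1
config.sharding_parallel_degree = 1
config.sep_parallel_degree = 1
config.context_parallel_degree = 1

# Mirrors the auto-parallel branch: build the model lazily from the config.
with paddle.LazyGuard():
    model = AutoModelForCausalLM.from_config(config, dtype="bfloat16")

Continued training from pretrained weights goes through `from_pretrained` instead, as the `continue_training` branch above shows.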

paddleformers/trainer/argparser.py

Lines changed: 0 additions & 1 deletion
@@ -188,7 +188,6 @@ def _add_dataclass_arguments(self, dtype: DataClassType):
                 f"removing line of `from __future__ import annotations` which opts in Postponed "
                 f"Evaluation of Annotations (PEP 563)"
             )
-
         for field in dataclasses.fields(dtype):
             if not field.init:
                 continue

paddleformers/trainer/trainer.py

Lines changed: 5 additions & 1 deletion
@@ -369,7 +369,11 @@ def __init__(
         self._memory_tracker.start()

         # Seed must be set before instantiating the model when using model
-        set_random_seed(seed_=self.args.seed)
+        if not self.args.enable_auto_parallel:
+            set_random_seed(seed_=self.args.seed)
+        else:
+            logger.warning("set_seed not support yet in auto_parallel mode")
+
         set_seed(seed=self.args.seed)

         self._skip_global_steps = 0  # total skip global steps

paddleformers/transformers/configuration_utils.py

Lines changed: 10 additions & 0 deletions
@@ -539,6 +539,9 @@ class PretrainedConfig:
            Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the
            model has a output word embedding layer.

+        use_single_model_implementation (`bool`, *optional*, defaults to `False`):
+            Whether to run the model in single card mode. When enabled, all parallel degree configurations will be disabled.
+
        dtype (`str`, *optional*):
            The `dtype` of the weights. This attribute can be used to initialize the model to a non-default `dtype`
            (which is normally `float32`) and thus allow for optimal storage allocation. For example, if the saved
@@ -609,6 +612,13 @@ def __init__(self, **kwargs):
         self.use_cache = kwargs.pop("use_cache", False)
         self.tie_word_embeddings = kwargs.pop("tie_word_embeddings", True)

+        # for run model in single card mode
+        self.use_single_model_implementation = kwargs.pop("use_single_model_implementation", False)
+        if self.use_single_model_implementation:
+            self.tensor_parallel_degree = 1
+            self.sep_parallel_degree = 1
+            self.context_parallel_degree = 1
+
         # for transformers fuse
         self.fuse_linear = kwargs.pop("fuse_linear", False)
         self.fuse_attention_qkv = kwargs.pop("fuse_attention_qkv", False)
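
A small sketch of the new flag in isolation; `LlamaConfig` is used only as a concrete `PretrainedConfig` subclass for illustration, and the printed values follow from the `__init__` block above.

# With use_single_model_implementation=True, PretrainedConfig.__init__ pins the
# tensor/sep/context parallel degrees to 1 (single-card mode).
from paddleformers.transformers import LlamaConfig

config = LlamaConfig(use_single_model_implementation=True)

print(config.use_single_model_implementation)  # True
print(config.tensor_parallel_degree)           # 1
print(config.sep_parallel_degree)              # 1
print(config.context_parallel_degree)          # 1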

paddleformers/transformers/llama/auto_dist_config.py

Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

+import paddle.distributed as dist


+def get_dist_config(model, prefix=""):
+    """Generate distributed configuration for Llama model"""
+    if prefix != "":
+        assert prefix.endswith(".")
+
+    config = {
+        "mp_config": {
+            "parallelize_plan": {
+                f"{prefix}llama.embed_tokens": dist.ColWiseParallel(gather_output=True),
+                f"{prefix}llama.layers.*.self_attn.qkv_proj": dist.ColWiseParallel(),
+                f"{prefix}llama.layers.*.self_attn.q_proj": dist.ColWiseParallel(),
+                f"{prefix}llama.layers.*.self_attn.k_proj": dist.ColWiseParallel(),
+                f"{prefix}llama.layers.*.self_attn.v_proj": dist.ColWiseParallel(),
+                f"{prefix}llama.layers.*.self_attn.o_proj": dist.RowWiseParallel(),
+                f"{prefix}llama.layers.*.mlp.gate_proj": dist.ColWiseParallel(),
+                f"{prefix}llama.layers.*.mlp.up_proj": dist.ColWiseParallel(),
+                f"{prefix}llama.layers.*.mlp.gate_up_fused_proj": dist.ColWiseParallel(),
+                f"{prefix}llama.layers.*.mlp.down_proj": dist.RowWiseParallel(),
+                f"{prefix}lm_head.weight": dist.ColWiseParallel(),
+            }
+        },
+    }
+    return config
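
The returned dict follows the `mp_config`/`parallelize_plan` shape consumed by Paddle's intermediate auto-parallel API (that consumer is an assumption; the commit only adds the producer side). A quick way to inspect the plan, assuming the module path inferred from the `from .auto_dist_config import get_dist_config` import in `llama/modeling.py`:

# Hedged sketch: inspect the tensor-parallel plan produced above. The import
# path is an inference; the `model` argument is unused when building the dict.
from paddleformers.transformers.llama.auto_dist_config import get_dist_config

plan = get_dist_config(model=None, prefix="model.")["mp_config"]["parallelize_plan"]
for name, layout in plan.items():
    print(f"{name}: {type(layout).__name__}")
# e.g. "model.llama.embed_tokens: ColWiseParallel",
#      "model.llama.layers.*.self_attn.o_proj: RowWiseParallel", ...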

paddleformers/transformers/llama/modeling.py

Lines changed: 47 additions & 3 deletions
@@ -20,6 +20,11 @@
 from paddle.distributed.fleet.utils import recompute
 from paddle.distributed.fleet.utils.sequence_parallel_utils import ScatterOp

+from paddleformers.transformers.conversion_utils import (
+    StateDictNameMapping,
+    init_name_mappings,
+)
+
 from ...nn.attention.interface import ALL_ATTENTION_FUNCTIONS
 from ...nn.criterion.interface import CriterionLayer
 from ...nn.embedding import Embedding as GeneralEmbedding
@@ -34,6 +39,7 @@
 from ..model_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from ..model_utils import PretrainedModel, register_base_model
 from ..modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from .auto_dist_config import get_dist_config
 from .configuration import LlamaConfig


@@ -160,9 +166,9 @@ def forward(
         q_shape = (batch_size, seq_len, self.num_heads, self.head_dim)
         kv_shape = (batch_size, seq_len, self.num_key_value_heads, self.head_dim)

-        query_states = self.q_proj(hidden_states).view(q_shape).transpose(1, 2)
-        key_states = self.k_proj(hidden_states).view(kv_shape).transpose(1, 2)
-        value_states = self.v_proj(hidden_states).view(kv_shape).transpose(1, 2)
+        query_states = self.q_proj(hidden_states).reshape(q_shape).transpose(1, 2)
+        key_states = self.k_proj(hidden_states).reshape(kv_shape).transpose(1, 2)
+        value_states = self.v_proj(hidden_states).reshape(kv_shape).transpose(1, 2)

         cos, sin = position_embeddings
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
@@ -324,6 +330,40 @@ class LlamaPretrainedModel(PretrainedModel):
         "down_proj",
     ]

+    @classmethod
+    def _get_name_mappings(cls, config: LlamaConfig) -> list[StateDictNameMapping]:
+        mappings: list[StateDictNameMapping] = []
+        model_mappings = [
+            ["embed_tokens.weight"],
+            ["norm.weight"],
+        ]
+        for layer_index in range(config.num_hidden_layers):
+            layer_mappings = [
+                [f"layers.{layer_index}.self_attn.q_proj.weight", None, "transpose"],
+                [f"layers.{layer_index}.self_attn.k_proj.weight", None, "transpose"],
+                [f"layers.{layer_index}.self_attn.v_proj.weight", None, "transpose"],
+                [f"layers.{layer_index}.self_attn.o_proj.weight", None, "transpose"],
+                [f"layers.{layer_index}.self_attn.rotary_emb.inv_freq"],
+                [f"layers.{layer_index}.mlp.gate_proj.weight", None, "transpose"],
+                [f"layers.{layer_index}.mlp.down_proj.weight", None, "transpose"],
+                [f"layers.{layer_index}.mlp.up_proj.weight", None, "transpose"],
+                [f"layers.{layer_index}.input_layernorm.weight"],
+                [f"layers.{layer_index}.post_attention_layernorm.weight"],
+            ]
+            model_mappings.extend(layer_mappings)
+
+        init_name_mappings(mappings=model_mappings)
+        # base-model prefix "LlamaModel"
+        if "LlamaModel" not in config.architectures:
+            for mapping in model_mappings:
+                mapping[0] = "model." + mapping[0]
+                mapping[1] = "llama." + mapping[1]
+            if not config.tie_word_embeddings:
+                model_mappings.append(["lm_head.weight", "lm_head.weight", "transpose"])
+
+        mappings = [StateDictNameMapping(*mapping, index=index) for index, mapping in enumerate(model_mappings)]
+        return mappings
+
     @classmethod
     def _get_tensor_parallel_mappings(cls, config: LlamaConfig, is_split=True):
         from ..conversion_utils import split_or_merge_func
@@ -689,6 +729,10 @@ def forward(
             attentions=outputs.attentions,
         )

+    def auto_dist_config(self, prefix=""):
+        assert self.config.use_single_model_implementation, "Use `get_dist_config` only in single card mode."
+        return get_dist_config(self, prefix)
+

 class LlamaForCausalLMPipe(GeneralModelForCausalLMPipe):
     config_class = LlamaConfig
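
The new `_get_name_mappings` maps HF-style `model.*` checkpoint keys onto the Paddle-side `llama.*` parameters, marking linear weights for transposition. A hedged sketch of how the mappings can be inspected; the config values and the explicit `architectures` assignment are illustrative, not part of the commit.

# Inspect the state-dict name mappings for a tiny illustrative config.
from paddleformers.transformers import LlamaConfig, LlamaForCausalLM

config = LlamaConfig(num_hidden_layers=1, tie_word_embeddings=False)
config.architectures = ["LlamaForCausalLM"]  # triggers the "model." -> "llama." prefixing

for mapping in LlamaForCausalLM._get_name_mappings(config)[:3]:
    # Each entry pairs a source key (e.g. "model.embed_tokens.weight") with its
    # Paddle-side target (e.g. "llama.embed_tokens.weight").
    print(mapping)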
