
Commit ef7b36e

waliwali777 authored and sevenan2 committed

refactor model in intermediate api mode

1 parent e98380b commit ef7b36e

3 files changed (+131, -0 lines)

paddleformers/transformers/configuration_utils.py

Lines changed: 10 additions & 0 deletions
@@ -539,6 +539,9 @@ class PretrainedConfig:
             Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the
             model has a output word embedding layer.
 
+        run_single_model (`bool`, *optional*, defaults to `False`):
+            Whether to run the model in single card mode. When enabled, all parallel degree configurations will be disabled.
+
         dtype (`str`, *optional*):
             The `dtype` of the weights. This attribute can be used to initialize the model to a non-default `dtype`
             (which is normally `float32`) and thus allow for optimal storage allocation. For example, if the saved
@@ -609,6 +612,13 @@ def __init__(self, **kwargs):
         self.use_cache = kwargs.pop("use_cache", False)
         self.tie_word_embeddings = kwargs.pop("tie_word_embeddings", True)
 
+        # for run model in single card mode
+        self.run_single_model = kwargs.pop("run_single_model", False)
+        if self.run_single_model:
+            self.tensor_parallel_degree = 1
+            self.sep_parallel_degree = 1
+            self.context_parallel_degree = 1
+
         # for transformers fuse
         self.fuse_linear = kwargs.pop("fuse_linear", False)
         self.fuse_attention_qkv = kwargs.pop("fuse_attention_qkv", False)
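
The snippet below is a usage sketch, not part of the commit: it shows what the new `run_single_model` flag is expected to do, assuming `LlamaConfig` forwards its kwargs to `PretrainedConfig.__init__` the way other PaddleFormers configs do; the import path is likewise an assumption.

```python
# Hypothetical usage sketch (not part of this commit).
from paddleformers.transformers import LlamaConfig  # assumed import path

config = LlamaConfig(run_single_model=True)

# The branch added in __init__ above forces every parallel degree to 1,
# so downstream code sees a plain single-card configuration.
print(config.run_single_model)         # True
print(config.tensor_parallel_degree)   # 1
print(config.sep_parallel_degree)      # 1
print(config.context_parallel_degree)  # 1
```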
paddleformers/transformers/llama/auto_dist_config.py

Lines changed: 113 additions & 0 deletions
@@ -0,0 +1,113 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import paddle.distributed as dist
+from paddle.distributed.auto_parallel.intermediate.tensor_parallel import (
+    PrepareLayerInput,
+)
+
+
+def layer_input_parallel_row_hook(process_mesh):
+    def hook(layer, inputs, output=None):
+        res_inputs = []
+        for input in inputs:
+            if not input.is_dist():
+                x = dist.shard_tensor(input, process_mesh, [dist.Shard(0), dist.Replicate()])
+                res_inputs.append(dist.reshard(x, process_mesh, [dist.Shard(0), dist.Replicate()]))
+            else:
+                res_inputs.append(dist.reshard(input, process_mesh, [dist.Shard(0), dist.Replicate()]))
+        return tuple(res_inputs)
+
+    return hook
+
+
+def layer_input_parallel_row_and_col_hook(process_mesh):
+    def hook(layer, inputs, output=None):
+        res_inputs = []
+        for input in inputs:
+            if not input.is_dist():
+                x = dist.shard_tensor(input, process_mesh, [dist.Shard(0), dist.Shard(1)])
+                res_inputs.append(dist.reshard(x, process_mesh, [dist.Shard(0), dist.Shard(1)]))
+            else:
+                res_inputs.append(dist.reshard(input, process_mesh, [dist.Shard(0), dist.Shard(1)]))
+        return tuple(res_inputs)
+
+    return hook
+
+
+def layer_input_replicate_hook(process_mesh):
+    def hook(layer, inputs, output=None):
+        res_inputs = []
+        for input in inputs:
+            if not input.is_dist():
+                x = dist.shard_tensor(input, process_mesh, [dist.Replicate(), dist.Replicate()])
+                res_inputs.append(dist.reshard(x, process_mesh, [dist.Replicate(), dist.Replicate()]))
+            else:
+                res_inputs.append(dist.reshard(input, process_mesh, [dist.Replicate(), dist.Replicate()]))
+        return tuple(res_inputs)
+
+    return hook
+
+
+def auto_dist_config(self, prefix=""):
+    if prefix != "":
+        assert prefix.endswith(".")
+    config = {
+        "sp_config": {
+            "parallelize_plan": {
+                f"{prefix}llama.embed_tokens": [
+                    dist.ColWiseParallel(),
+                    dist.SequenceParallelBegin(),
+                ],
+                f"{prefix}llama.reshard_row": PrepareLayerInput(layer_input_parallel_row_hook),
+                f"{prefix}llama.reshard_row_and_col": PrepareLayerInput(layer_input_parallel_row_and_col_hook),
+                f"{prefix}llama.global_layer.reshard_replicate": PrepareLayerInput(layer_input_replicate_hook),
+                f"{prefix}llama.layers.*.self_attn.qkv_proj": dist.ColWiseParallel(),
+                f"{prefix}llama.layers.*.self_attn.q_proj": dist.ColWiseParallel(),
+                f"{prefix}llama.layers.*.self_attn.k_proj": dist.ColWiseParallel(),
+                f"{prefix}llama.layers.*.self_attn.v_proj": dist.ColWiseParallel(),
+                f"{prefix}llama.layers.*.self_attn.o_proj": dist.RowWiseParallel(),
+                f"{prefix}llama.layers.*.self_attn": dist.SequenceParallelDisable(),
+                f"{prefix}llama.layers.*.mlp.gate_proj": dist.ColWiseParallel(),
+                f"{prefix}llama.layers.*.mlp.up_proj": dist.ColWiseParallel(),
+                f"{prefix}llama.layers.*.mlp.gate_up_fused_proj": dist.ColWiseParallel(),
+                f"{prefix}llama.layers.*.mlp.down_proj": dist.RowWiseParallel(),
+                f"{prefix}llama.layers.*.mlp": dist.SequenceParallelDisable(need_transpose=False),
+                f"{prefix}lm_head.weight": dist.ColWiseParallel(),
+                f"{prefix}lm_head": dist.SequenceParallelEnd(),
+            }
+        },
+        "mp_config": {
+            "parallelize_plan": {
+                f"{prefix}llama.embed_tokens": dist.ColWiseParallel(gather_output=True),
+                f"{prefix}llama.reshard_row": PrepareLayerInput(layer_input_parallel_row_hook),
+                f"{prefix}llama.reshard_row_and_col": PrepareLayerInput(layer_input_parallel_row_and_col_hook),
+                f"{prefix}llama.global_layer.reshard_replicate": PrepareLayerInput(layer_input_replicate_hook),
+                f"{prefix}llama.layers.*.self_attn.qkv_proj": dist.ColWiseParallel(),
+                f"{prefix}llama.layers.*.self_attn.q_proj": dist.ColWiseParallel(),
+                f"{prefix}llama.layers.*.self_attn.k_proj": dist.ColWiseParallel(),
+                f"{prefix}llama.layers.*.self_attn.v_proj": dist.ColWiseParallel(),
+                f"{prefix}llama.layers.*.self_attn.o_proj": dist.RowWiseParallel(),
+                f"{prefix}llama.layers.*.mlp.gate_proj": dist.ColWiseParallel(),
+                f"{prefix}llama.layers.*.mlp.up_proj": dist.ColWiseParallel(),
+                f"{prefix}llama.layers.*.mlp.gate_up_fused_proj": dist.ColWiseParallel(),
+                f"{prefix}llama.layers.*.mlp.down_proj": dist.RowWiseParallel(),
+                f"{prefix}lm_head.weight": dist.ColWiseParallel(),
+            }
+        },
+        "pp_config": {"split_spec": f"{prefix}llama.layers", "global_spec": f"{prefix}llama.global_layer"},
+    }
+
+    return config
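
To make the reshard hooks above concrete, here is a standalone sketch of the placement change that `layer_input_parallel_row_hook` applies to one activation tensor. It is not part of the commit; the 2x2 `ProcessMesh` and the multi-card launch are assumptions, while the calls (`dist.shard_tensor`, `dist.reshard`, `dist.Shard`, `dist.Replicate`) are the same Paddle auto-parallel APIs the new file relies on.

```python
# Standalone sketch (assumption, not from the commit).
# Run under a multi-card launcher, e.g. `python -m paddle.distributed.launch`.
import paddle
import paddle.distributed as dist

# Assumed 4-card mesh: dim 0 for data parallel, dim 1 for model parallel.
mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["dp", "mp"])

x = paddle.rand([8, 1024])  # a dense, non-distributed activation
if not x.is_dist():
    # Shard along dim 0 over the first mesh axis, replicate over the second --
    # the same placement the hook requests.
    x = dist.shard_tensor(x, mesh, [dist.Shard(0), dist.Replicate()])
# reshard is a no-op here, but it enforces the placement when the input
# arrives with a different one.
x = dist.reshard(x, mesh, [dist.Shard(0), dist.Replicate()])
```

The dict returned by `auto_dist_config` groups these per-layer rules into `sp_config`, `mp_config`, and `pp_config` entries keyed by sublayer name patterns; presumably this mirrors the plan format consumed by Paddle's intermediate auto-parallel entry points, though the consumer is not shown in this diff.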

paddleformers/transformers/llama/modeling.py

Lines changed: 8 additions & 0 deletions
@@ -28,6 +28,8 @@
 from ...nn.mlp import MLP
 from ...nn.norm import Norm as GeneralNorm
 from ...nn.pp_model import GeneralModelForCausalLMPipe
+from .auto_dist_config import get_dist_config
+
 from ...utils.log import logger
 from ..cache_utils import Cache, DynamicCache
 from ..masking_utils import create_causal_mask_and_row_indices
@@ -326,6 +328,8 @@ class LlamaPretrainedModel(PretrainedModel):
 
     @classmethod
     def _get_tensor_parallel_mappings(cls, config: LlamaConfig, is_split=True):
+        if config.run_single_model:
+            return {}
         from ..conversion_utils import split_or_merge_func
 
         fn = split_or_merge_func(
@@ -689,6 +693,10 @@ def forward(
             attentions=outputs.attentions,
         )
 
+    def auto_dist_config(self, prefix=""):
+        assert self.config.run_single_model, "Use `get_dist_config` only in single card mode."
+        return get_dist_config(self, prefix)
+
 
 class LlamaForCausalLMPipe(GeneralModelForCausalLMPipe):
     config_class = LlamaConfig
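
Taken together, the three files suggest the following flow. The sketch below is an assumption rather than something the commit demonstrates: import paths and the class that gains `auto_dist_config` are inferred from the diff, not confirmed by it.

```python
# Hypothetical end-to-end sketch (not part of the commit): intermediate-API
# flow implied by the diff. Import paths and class names are assumptions.
from paddleformers.transformers import LlamaConfig, LlamaForCausalLM

# run_single_model=True forces all parallel degrees to 1, and
# _get_tensor_parallel_mappings returns {}, so the model is built as a
# plain single-card network.
config = LlamaConfig(run_single_model=True)
model = LlamaForCausalLM(config)

# The new method asserts run_single_model and returns the sp/mp/pp
# parallelize plan defined in auto_dist_config.py, keyed by sublayer names.
dist_config = model.auto_dist_config()
print(sorted(dist_config.keys()))  # ['mp_config', 'pp_config', 'sp_config']
```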
