Commit d337a0b

refactor model in intermediate api mode

1 parent 26d3e2c commit d337a0b
7 files changed (+147 −2662 lines changed)

7 files changed

+147
-2662
lines changed

paddleformers/transformers/__init__.py

Lines changed: 0 additions & 35 deletions
@@ -214,41 +214,6 @@
         "LlamaPretrainingCriterion",
         "LlamaNTKScalingRotaryEmbedding",
     ],
-    "llama.modeling_auto": [
-        "enable_fuse_ffn_qkv_pass",
-        "LlamaDecoderLayerAuto",
-        "LlamaAttentionAuto",
-        "LlamaPretrainedModelAuto",
-        "LlamaLMHeadAuto",
-        "LlamaModelAuto",
-        "LlamaForCausalLM3DAuto",
-        "LlamaMLPAuto",
-        "get_mesh",
-        "LlamaRMSNormAuto",
-        "is_pp_enable",
-        "LlamaPretrainingCriterion3DAuto",
-        "global_mesh_starts_with_pp",
-        "scaled_dot_product_attention",
-    ],
-    "llama.modeling_network": [
-        "LlamaPretrainedModelNet",
-        "layer_input_parallel_row_and_col_hook",
-        "LlamaModelNet",
-        "LlamaPretrainingCriterionNet",
-        "layer_input_replicate_hook",
-        "LlamaLMHeadNet",
-        "LlamaForCausalLMNetDPO",
-        "GlobalOutputNet",
-        "layer_input_parallel_row_hook",
-        "LlamaRMSNormNet",
-        "LlamaAttentionNet",
-        "scaled_dot_product_attention",
-        "ReshardLayer",
-        "LlamaForCausalLMNet",
-        "enable_fuse_ffn_qkv_pass",
-        "LlamaMLPNet",
-        "LlamaDecoderLayerNet",
-    ],
     "llama.modeling_pp": ["LlamaForCausalLMPipe"],
     "llama.tokenizer": ["LlamaTokenizer", "Llama3Tokenizer"],
     "llama.tokenizer_fast": ["LlamaTokenizerFast"],

paddleformers/transformers/configuration_utils.py

Lines changed: 10 additions & 0 deletions
@@ -537,6 +537,9 @@ class PretrainedConfig:
             Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the
             model has a output word embedding layer.
 
+        run_single_model (`bool`, *optional*, defaults to `False`):
+            Whether to run the model in single card mode. When enabled, all parallel degree configurations will be disabled.
+
         dtype (`str`, *optional*):
             The `dtype` of the weights. This attribute can be used to initialize the model to a non-default `dtype`
             (which is normally `float32`) and thus allow for optimal storage allocation. For example, if the saved
@@ -601,6 +604,13 @@ def __init__(self, **kwargs):
         self.use_cache = kwargs.pop("use_cache", False)
         self.tie_word_embeddings = kwargs.pop("tie_word_embeddings", True)
 
+        # for run model in single card mode
+        self.run_single_model = kwargs.pop("run_single_model", False)
+        if self.run_single_model:
+            self.tensor_parallel_degree = 1
+            self.sep_parallel_degree = 1
+            self.context_parallel_degree = 1
+
         # for transformers fuse
         self.fuse_linear = kwargs.pop("fuse_linear", False)
         self.fuse_attention_qkv = kwargs.pop("fuse_attention_qkv", False)
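A minimal usage sketch of the new flag (not part of the commit), assuming PretrainedConfig can be constructed directly with keyword arguments, as the __init__(self, **kwargs) signature above suggests; the import path mirrors the file header:

# Hedged sketch: per the __init__ logic added above, run_single_model=True
# should pin all parallel degrees to 1.
from paddleformers.transformers.configuration_utils import PretrainedConfig

cfg = PretrainedConfig(run_single_model=True)
assert cfg.run_single_model is True
assert cfg.tensor_parallel_degree == 1
assert cfg.sep_parallel_degree == 1
assert cfg.context_parallel_degree == 1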

paddleformers/transformers/llama/__init__.py

Lines changed: 0 additions & 35 deletions
@@ -50,41 +50,6 @@
         "LlamaPretrainingCriterion",
         "LlamaNTKScalingRotaryEmbedding",
     ],
-    "modeling_auto": [
-        "enable_fuse_ffn_qkv_pass",
-        "LlamaDecoderLayerAuto",
-        "LlamaAttentionAuto",
-        "LlamaPretrainedModelAuto",
-        "LlamaLMHeadAuto",
-        "LlamaModelAuto",
-        "LlamaForCausalLM3DAuto",
-        "LlamaMLPAuto",
-        "get_mesh",
-        "LlamaRMSNormAuto",
-        "is_pp_enable",
-        "LlamaPretrainingCriterion3DAuto",
-        "global_mesh_starts_with_pp",
-        "scaled_dot_product_attention",
-    ],
-    "modeling_network": [
-        "LlamaPretrainedModelNet",
-        "layer_input_parallel_row_and_col_hook",
-        "LlamaModelNet",
-        "LlamaPretrainingCriterionNet",
-        "layer_input_replicate_hook",
-        "LlamaLMHeadNet",
-        "LlamaForCausalLMNetDPO",
-        "GlobalOutputNet",
-        "layer_input_parallel_row_hook",
-        "LlamaRMSNormNet",
-        "LlamaAttentionNet",
-        "scaled_dot_product_attention",
-        "ReshardLayer",
-        "LlamaForCausalLMNet",
-        "enable_fuse_ffn_qkv_pass",
-        "LlamaMLPNet",
-        "LlamaDecoderLayerNet",
-    ],
     "modeling_pp": ["LlamaForCausalLMPipe"],
     "tokenizer": ["LlamaTokenizer", "Llama3Tokenizer"],
     "tokenizer_fast": ["LlamaTokenizerFast"],
paddleformers/transformers/llama/auto_dist_config.py

Lines changed: 113 additions & 0 deletions
@@ -0,0 +1,113 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import paddle.distributed as dist
+from paddle.distributed.auto_parallel.intermediate.tensor_parallel import (
+    PrepareLayerInput,
+)
+
+
+def layer_input_parallel_row_hook(process_mesh):
+    def hook(layer, inputs, output=None):
+        res_inputs = []
+        for input in inputs:
+            if not input.is_dist():
+                x = dist.shard_tensor(input, process_mesh, [dist.Shard(0), dist.Replicate()])
+                res_inputs.append(dist.reshard(x, process_mesh, [dist.Shard(0), dist.Replicate()]))
+            else:
+                res_inputs.append(dist.reshard(input, process_mesh, [dist.Shard(0), dist.Replicate()]))
+        return tuple(res_inputs)
+
+    return hook
+
+
+def layer_input_parallel_row_and_col_hook(process_mesh):
+    def hook(layer, inputs, output=None):
+        res_inputs = []
+        for input in inputs:
+            if not input.is_dist():
+                x = dist.shard_tensor(input, process_mesh, [dist.Shard(0), dist.Shard(1)])
+                res_inputs.append(dist.reshard(x, process_mesh, [dist.Shard(0), dist.Shard(1)]))
+            else:
+                res_inputs.append(dist.reshard(input, process_mesh, [dist.Shard(0), dist.Shard(1)]))
+        return tuple(res_inputs)
+
+    return hook
+
+
+def layer_input_replicate_hook(process_mesh):
+    def hook(layer, inputs, output=None):
+        res_inputs = []
+        for input in inputs:
+            if not input.is_dist():
+                x = dist.shard_tensor(input, process_mesh, [dist.Replicate(), dist.Replicate()])
+                res_inputs.append(dist.reshard(x, process_mesh, [dist.Replicate(), dist.Replicate()]))
+            else:
+                res_inputs.append(dist.reshard(input, process_mesh, [dist.Replicate(), dist.Replicate()]))
+        return tuple(res_inputs)
+
+    return hook
+
+
+def auto_dist_config(self, prefix=""):
+    if prefix != "":
+        assert prefix.endswith(".")
+    config = {
+        "sp_config": {
+            "parallelize_plan": {
+                f"{prefix}llama.embed_tokens": [
+                    dist.ColWiseParallel(),
+                    dist.SequenceParallelBegin(),
+                ],
+                f"{prefix}llama.reshard_row": PrepareLayerInput(layer_input_parallel_row_hook),
+                f"{prefix}llama.reshard_row_and_col": PrepareLayerInput(layer_input_parallel_row_and_col_hook),
+                f"{prefix}llama.global_layer.reshard_replicate": PrepareLayerInput(layer_input_replicate_hook),
+                f"{prefix}llama.layers.*.self_attn.qkv_proj": dist.ColWiseParallel(),
+                f"{prefix}llama.layers.*.self_attn.q_proj": dist.ColWiseParallel(),
+                f"{prefix}llama.layers.*.self_attn.k_proj": dist.ColWiseParallel(),
+                f"{prefix}llama.layers.*.self_attn.v_proj": dist.ColWiseParallel(),
+                f"{prefix}llama.layers.*.self_attn.o_proj": dist.RowWiseParallel(),
+                f"{prefix}llama.layers.*.self_attn": dist.SequenceParallelDisable(),
+                f"{prefix}llama.layers.*.mlp.gate_proj": dist.ColWiseParallel(),
+                f"{prefix}llama.layers.*.mlp.up_proj": dist.ColWiseParallel(),
+                f"{prefix}llama.layers.*.mlp.gate_up_fused_proj": dist.ColWiseParallel(),
+                f"{prefix}llama.layers.*.mlp.down_proj": dist.RowWiseParallel(),
+                f"{prefix}llama.layers.*.mlp": dist.SequenceParallelDisable(need_transpose=False),
+                f"{prefix}lm_head.weight": dist.ColWiseParallel(),
+                f"{prefix}lm_head": dist.SequenceParallelEnd(),
+            }
+        },
+        "mp_config": {
+            "parallelize_plan": {
+                f"{prefix}llama.embed_tokens": dist.ColWiseParallel(gather_output=True),
+                f"{prefix}llama.reshard_row": PrepareLayerInput(layer_input_parallel_row_hook),
+                f"{prefix}llama.reshard_row_and_col": PrepareLayerInput(layer_input_parallel_row_and_col_hook),
+                f"{prefix}llama.global_layer.reshard_replicate": PrepareLayerInput(layer_input_replicate_hook),
+                f"{prefix}llama.layers.*.self_attn.qkv_proj": dist.ColWiseParallel(),
+                f"{prefix}llama.layers.*.self_attn.q_proj": dist.ColWiseParallel(),
+                f"{prefix}llama.layers.*.self_attn.k_proj": dist.ColWiseParallel(),
+                f"{prefix}llama.layers.*.self_attn.v_proj": dist.ColWiseParallel(),
+                f"{prefix}llama.layers.*.self_attn.o_proj": dist.RowWiseParallel(),
+                f"{prefix}llama.layers.*.mlp.gate_proj": dist.ColWiseParallel(),
+                f"{prefix}llama.layers.*.mlp.up_proj": dist.ColWiseParallel(),
+                f"{prefix}llama.layers.*.mlp.gate_up_fused_proj": dist.ColWiseParallel(),
+                f"{prefix}llama.layers.*.mlp.down_proj": dist.RowWiseParallel(),
+                f"{prefix}lm_head.weight": dist.ColWiseParallel(),
+            }
+        },
+        "pp_config": {"split_spec": f"{prefix}llama.layers", "global_spec": f"{prefix}llama.global_layer"},
+    }
+
+    return config
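The module-level auto_dist_config above ignores its self argument and only assembles a plan dictionary, so the plan's shape can be inspected without building a model. A hedged sketch (not part of the commit), assuming the new file lands at paddleformers/transformers/llama/auto_dist_config.py as inferred from the relative import in modeling.py below, and that the dist plan objects can be constructed outside a launched distributed job:

# Hedged sketch: inspect the generated parallelize plan; the import path is an assumption.
from paddleformers.transformers.llama.auto_dist_config import auto_dist_config

plan = auto_dist_config(None, prefix="model.")     # a non-empty prefix must end with "."
print(sorted(plan))                                # ['mp_config', 'pp_config', 'sp_config']
print(plan["pp_config"]["split_spec"])             # 'model.llama.layers'
print(len(plan["sp_config"]["parallelize_plan"]))  # one entry per wrapped sublayer pattern

In practice the plan is consumed through the auto_dist_config method added at the end of modeling.py below, which gates it behind config.run_single_model.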

paddleformers/transformers/llama/modeling.py

Lines changed: 24 additions & 9 deletions
@@ -39,6 +39,7 @@
     get_skip_recompute_ops,
 )
 from ..refined_recompute import recompute as rr_recompute
+from .auto_dist_config import get_dist_config
 
 try:
     from paddle.incubate.nn.functional import fused_rotary_position_embedding
@@ -178,15 +179,16 @@ def assign_kv_heads(num_kv_heads: int, num_gpus: int):
     return assignment_list
 
 
-def parallel_matmul(x: Tensor, y: Tensor, transpose_y=False, tensor_parallel_output=True):
+def parallel_matmul(x: Tensor, y: Tensor, transpose_y=False, tensor_parallel_output=True, args=None):
     is_fleet_init = True
     tensor_parallel_degree = 1
-    try:
-        hcg = fleet.get_hybrid_communicate_group()
-        model_parallel_group = hcg.get_model_parallel_group()
-        tensor_parallel_degree = hcg.get_model_parallel_world_size()
-    except:
-        is_fleet_init = False
+    if args is None or not args.run_single_model:
+        try:
+            hcg = fleet.get_hybrid_communicate_group()
+            model_parallel_group = hcg.get_model_parallel_group()
+            tensor_parallel_degree = hcg.get_model_parallel_world_size()
+        except:
+            is_fleet_init = False
 
     if paddle.in_dynamic_mode():
         y_is_distributed = y.is_distributed
@@ -1326,6 +1328,8 @@ def _get_hardware_flops(self):
 
     @classmethod
     def _get_name_mappings(cls, config: LlamaConfig) -> list[StateDictNameMapping]:
+        if config.run_single_model:
+            return cls._get_name_mappings()
         mappings: list[StateDictNameMapping] = []
         model_mappings = [
             ["embed_tokens.weight"],
@@ -1360,7 +1364,8 @@ def _get_name_mappings(cls, config: LlamaConfig) -> list[StateDictNameMapping]:
 
     @classmethod
     def _get_tensor_parallel_mappings(cls, config: LlamaConfig, is_split=True):
-
+        if config.run_single_model:
+            return {}
         from ..conversion_utils import split_or_merge_func
 
         fn = split_or_merge_func(
@@ -1420,6 +1425,8 @@ def get_tensor_parallel_split_mappings(num_layers):
 
     @classmethod
     def _get_fuse_or_split_param_mappings(cls, config: LlamaConfig, is_fuse=False):
+        if config.run_single_model:
+            return cls._get_fuse_or_split_param_mappings()
         # return parameter fuse utils
         from ..conversion_utils import split_or_fuse_func
 
@@ -1984,7 +1991,11 @@ def forward(self, hidden_states, tensor_parallel_output=None):
             )
         else:
             logits = parallel_matmul(
-                hidden_states, self.weight, transpose_y=self.transpose_y, tensor_parallel_output=tensor_parallel_output
+                hidden_states,
+                self.weight,
+                transpose_y=self.transpose_y,
+                tensor_parallel_output=tensor_parallel_output,
+                args=self.config,
             )
         return logits
 
@@ -2156,3 +2167,7 @@ def forward(
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
         )
+
+    def auto_dist_config(self, prefix=""):
+        assert self.config.run_single_model, "Use `get_dist_config` only in single card mode."
+        return get_dist_config(self, prefix)
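The parallel_matmul change above reduces to a guard: when the config requests single-card mode, the fleet hybrid-communicate-group lookup is skipped entirely, so nothing distributed is touched. A standalone sketch of that control flow (not the library function itself; the helper name is hypothetical):

from paddle.distributed import fleet

# Standalone sketch mirroring the guard added to parallel_matmul above;
# `args` stands in for the model config carrying run_single_model.
def resolve_tensor_parallel_degree(args=None):
    tensor_parallel_degree = 1
    if args is None or not getattr(args, "run_single_model", False):
        try:
            # Only consulted when single-card mode is off (or no config is passed).
            hcg = fleet.get_hybrid_communicate_group()
            tensor_parallel_degree = hcg.get_model_parallel_world_size()
        except Exception:
            pass  # fleet not initialized: fall back to a plain single-card matmul
    return tensor_parallel_degree

With run_single_model set on the config, the refactored call site now passes args=self.config, so the guard short-circuits before touching any distributed state.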
