
Commit 337ba24

LiYuRio and FeixLiu authored
glm45 support pipeline parallel (#3082)
Co-authored-by: YuangLiu <liuyuang@baidu.com>
1 parent 90c5bc9 commit 337ba24

File tree

6 files changed: +85 -86 lines


examples/experiments/paddlefleet/glm45_provider.py

Lines changed: 1 addition & 10 deletions
@@ -17,29 +17,20 @@

 import logging
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Callable, List, Optional, Union
+from typing import Callable, List, Optional, Union

 import paddle
 import paddle.nn.functional as F
-from paddlefleet.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec

 from paddleformers.transformers.gpt_provider import GPTModelProvider

-if TYPE_CHECKING:
-    from paddlefleet.spec_utils import LayerSpec
-
-
 logger = logging.getLogger(__name__)


 @dataclass
 class GLMMoEModelProvider(GPTModelProvider):
     """Base provider for GLM MoE Models."""

-    transformer_layer_spec: Union[
-        "LayerSpec", Callable[["GPTModelProvider"], "LayerSpec"]
-    ] = get_gpt_decoder_block_spec
-
     normalization: str = "RMSNorm"
     hidden_act: Callable = F.silu
     gated_linear_unit: bool = True

examples/experiments/paddlefleet/qwen_provider.py

Lines changed: 1 addition & 10 deletions
@@ -17,29 +17,20 @@

 import logging
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Callable, Optional, Union
+from typing import Callable, Optional

 import paddle
 import paddle.nn.functional as F
-from paddlefleet.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec

 from paddleformers.transformers.gpt_provider import GPTModelProvider

-if TYPE_CHECKING:
-    from paddlefleet.spec_utils import LayerSpec
-
-
 logger = logging.getLogger(__name__)


 @dataclass
 class Qwen3MoEModelProvider(GPTModelProvider):
     """Base provider for Qwen 3 MoE Models."""

-    transformer_layer_spec: Union[
-        "LayerSpec", Callable[["GPTModelProvider"], "LayerSpec"]
-    ] = get_gpt_decoder_block_spec
-
     normalization: str = "RMSNorm"
     hidden_act: Callable = F.silu
     gated_linear_unit: bool = True
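
Both provider diffs above make the same change: the subclass no longer pins transformer_layer_spec to get_gpt_decoder_block_spec, leaving layer-spec selection to the GPTModelProvider base class (and, after this commit, to paddlefleet's gpt_builder). A minimal sketch of the resulting shape, using hypothetical class names and placeholder fields so it runs without paddle installed:

# Hypothetical sketch: providers now only declare model hyperparameters;
# the base class owns everything related to layer-spec construction.
from dataclasses import dataclass


@dataclass
class BaseProvider:
    normalization: str = "LayerNorm"
    gated_linear_unit: bool = False


@dataclass
class MoEProvider(BaseProvider):
    # Override hyperparameters only; no transformer_layer_spec field here.
    normalization: str = "RMSNorm"
    gated_linear_unit: bool = True


if __name__ == "__main__":
    print(MoEProvider())  # MoEProvider(normalization='RMSNorm', gated_linear_unit=True)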

examples/experiments/paddlefleet/run_pretrain.py

Lines changed: 1 addition & 11 deletions
@@ -40,12 +40,7 @@
     speed_metrics,
 )
 from paddleformers.trainer.trainer import Trainer
-from paddleformers.transformers import (
-    AutoConfig,
-    AutoTokenizer,
-    CosineAnnealingWithWarmupDecay,
-    LinearAnnealingWithWarmupDecay,
-)
+from paddleformers.transformers import AutoConfig, AutoTokenizer
 from paddleformers.transformers.configuration_utils import LlmMetaConfig, llmmetaclass
 from paddleformers.utils.batch_sampler import DistributedBatchSampler
 from paddleformers.utils.log import logger
@@ -522,11 +517,6 @@ def main():
     if training_args.decay_steps is None:
         training_args.decay_steps = training_args.max_steps

-    if training_args.warmup_steps > 0:
-        warmup_steps = training_args.warmup_steps
-    else:
-        warmup_steps = training_args.warmup_ratio * training_args.max_steps
-
     lr_scheduler = None

     data_file = get_train_data_file(data_args)

paddleformers/trainer/trainer.py

Lines changed: 57 additions & 1 deletion
@@ -61,6 +61,15 @@
     from paddle.base import core
 except:
     core = None
+try:
+    import paddlefleet.distributed.model as paddlefleet_dist_model
+    from paddlefleet.pipeline_parallel import ParallelBase as PaddleFleetParallelBase
+    from paddlefleet.pipeline_parallel import PipelineLayer as PaddleFleetPipelineLayer
+
+    HAS_PADDLEFLEET = True
+except:
+    HAS_PADDLEFLEET = False
+
 from paddle.distributed import fleet
 from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.hybrid_parallel_optimizer import (
     HybridParallelOptimizer,
@@ -2999,6 +3008,47 @@ def _wrap_model(self, model, training=True):

             return model

+        if HAS_PADDLEFLEET and isinstance(model, PaddleFleetPipelineLayer):
+            prepare_pipeline_inputs_func = (
+                model._prepare_pipeline_inputs_func if hasattr(model, "_prepare_pipeline_inputs_func") else None
+            )
+            model = paddlefleet_dist_model.distributed_model(model)
+            if prepare_pipeline_inputs_func is not None:
+                model._prepare_pipeline_inputs_func = prepare_pipeline_inputs_func
+            else:
+
+                def _prepare_pipeline_inputs_func(inputs):
+                    first_stage_keys = ["input_ids", "attention_mask", "position_ids"]
+                    last_stage_keys = ["labels"]
+
+                    def get_expected_keys(inputs, keys):
+                        ret = tuple([inputs.pop(k) for k in keys if k in inputs])
+                        if len(ret) == 1:
+                            ret = ret[0]
+                        return ret
+
+                    if type(inputs) is dict or type(inputs) is OrderedDict:
+                        return [
+                            get_expected_keys(inputs, first_stage_keys),
+                            get_expected_keys(inputs, last_stage_keys),
+                        ]
+
+                    keys = list(inputs[0].keys())
+                    inputs_batch = {key: [data.pop(key) for data in inputs] for key in keys}
+                    first_stage_inputs_batch = inputs_batch
+                    last_stage_inputs = first_stage_inputs_batch.pop("labels")
+                    outputs = (
+                        first_stage_inputs_batch,
+                        last_stage_inputs,
+                    )
+                    return outputs
+
+                logger.warning(
+                    "Using default prepare pipeline inputs func, only support input_ids and labels as inputs."
+                )
+                model._prepare_pipeline_inputs_func = _prepare_pipeline_inputs_func
+            return model
+
         # train/eval could be run multiple-times - if already wrapped, don't re-wrap it again
         if unwrap_model(model) is not model:
             return model
@@ -3047,7 +3097,10 @@ def _wrap_model(self, model, training=True):
             assert self.optimizer is not None, "optimizer is empty!"
             self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer)

-        in_pipeline_parallel_mode = self.args.pipeline_parallel_degree > 1
+        if HAS_PADDLEFLEET and isinstance(model, PaddleFleetParallelBase):
+            in_pipeline_parallel_mode = True
+        else:
+            in_pipeline_parallel_mode = self.args.pipeline_parallel_degree > 1
         in_sharding_parallel_mode = self.sharding is not None
         in_tensor_parallel_mode = self.args.tensor_parallel_degree > 1
         in_sep_parallel_mode = self.args.sep_parallel_degree > 1
@@ -3382,6 +3435,9 @@ def training_step(
         Return:
             `paddle.Tensor`: The tensor with training loss on this batch.
         """
+        if HAS_PADDLEFLEET and isinstance(model, PaddleFleetParallelBase):
+            return self.training_pipeline_step(model, inputs)
+
         if self.args.pipeline_parallel_degree > 1:
             return self.training_pipeline_step(model, inputs)
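
The default _prepare_pipeline_inputs_func installed above splits a batch into first-stage inputs (input_ids, attention_mask, position_ids) and last-stage inputs (labels). A standalone sketch of that splitting logic, with plain Python lists standing in for tensors:

# Standalone sketch of the batch-splitting logic used by the default
# _prepare_pipeline_inputs_func above; plain lists stand in for tensors.
from collections import OrderedDict


def prepare_pipeline_inputs(inputs):
    first_stage_keys = ["input_ids", "attention_mask", "position_ids"]
    last_stage_keys = ["labels"]

    def get_expected_keys(inputs, keys):
        ret = tuple(inputs.pop(k) for k in keys if k in inputs)
        return ret[0] if len(ret) == 1 else ret

    # A single dict batch becomes [first_stage_inputs, last_stage_inputs].
    if isinstance(inputs, (dict, OrderedDict)):
        return [
            get_expected_keys(inputs, first_stage_keys),
            get_expected_keys(inputs, last_stage_keys),
        ]

    # A list of micro-batch dicts is regrouped per key, then labels are split off.
    keys = list(inputs[0].keys())
    inputs_batch = {key: [data.pop(key) for data in inputs] for key in keys}
    labels = inputs_batch.pop("labels")
    return inputs_batch, labels


if __name__ == "__main__":
    print(prepare_pipeline_inputs({"input_ids": [1, 2, 3], "labels": [2, 3, 4]}))
    # -> [[1, 2, 3], [2, 3, 4]]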

paddleformers/transformers/glm4_moe/modeling.py

Lines changed: 1 addition & 9 deletions
@@ -15,7 +15,7 @@
 from copy import deepcopy
 from dataclasses import dataclass
 from functools import partial
-from typing import TYPE_CHECKING, Callable, Optional, Tuple, Union
+from typing import Optional, Tuple, Union

 import paddle
 import paddle.distributed as dist
@@ -24,7 +24,6 @@
 from paddle.distributed.fleet.utils import recompute
 from paddle.distributed.fleet.utils.sequence_parallel_utils import GatherOp, ScatterOp
 from paddle.nn import functional as F
-from paddlefleet.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec

 from paddleformers.transformers.gpt_provider import GPTModelProvider

@@ -48,18 +47,11 @@
 from ..moe_layer import MoEFlexTokenLayer
 from .configuration import Glm4MoeConfig

-if TYPE_CHECKING:
-    from paddlefleet.transformer import LayerSpec
-

 @dataclass
 class GLMMoEModelProvider(GPTModelProvider):
     """Base provider for GLM MoE Models."""

-    transformer_layer_spec: Union[
-        "LayerSpec", Callable[["GPTModelProvider"], "LayerSpec"]
-    ] = get_gpt_decoder_block_spec
-
     moe_router_load_balancing_type: str = "seq_aux_loss"

     gated_linear_unit: bool = True

paddleformers/transformers/gpt_provider.py

Lines changed: 24 additions & 45 deletions
@@ -23,15 +23,28 @@
 from typing import Any, Callable, Literal, Optional, Union

 import paddle
-from paddlefleet import LayerSpec, parallel_state
+from paddlefleet import LayerSpec
 from paddlefleet.models.gpt import GPTModel as FleetGPTModel
 from paddlefleet.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec
-from paddlefleet.transformer.transformer_config import TransformerConfig
+
+try:
+    from paddlefleet.models.gpt.gpt_config import GPTConfig
+except ImportError:
+    from paddlefleet.transformer.transformer_config import (
+        TransformerConfig as GPTConfig,
+    )
+
+
+try:
+    from paddlefleet.gpt_builders import gpt_builder
+
+    HAS_PADDLEFLEET = True
+except ImportError:
+    HAS_PADDLEFLEET = False

 from paddleformers.transformers.model_utils import PretrainedModel

 from .model_provider import ModelProviderMixin
-from .vocab_utils import calculate_padded_vocab_size

 logger = logging.getLogger(__name__)

@@ -52,6 +65,7 @@ def local_layer_spec(config: "GPTModelProvider") -> LayerSpec:
     Returns:
         LayerSpec: Module specification for local implementation layers
     """
+    assert HAS_PADDLEFLEET
     return get_gpt_layer_local_spec(
         num_experts=config.num_moe_experts,
         moe_grouped_gemm=config.moe_grouped_gemm,
@@ -61,7 +75,7 @@ def local_layer_spec(config: "GPTModelProvider") -> LayerSpec:


 @dataclass
-class GPTModelProvider(TransformerConfig, ModelProviderMixin[GPTModel]):
+class GPTModelProvider(GPTConfig, ModelProviderMixin[GPTModel]):
     """Configuration and provider for PaddleFleet GPT models.

     This class extends TransformerConfig with GPT-specific parameters and
@@ -78,15 +92,16 @@ class GPTModelProvider(TransformerConfig, ModelProviderMixin[GPTModel]):
     rotary_percent: float = 1.0
     seq_len_interpolation_factor: Optional[float] = None
     seq_length: int = 1024
+
+    max_sequence_length: int = 1024
+
     attention_softmax_in_fp32: bool = False
     deallocate_pipeline_outputs: bool = True
     scatter_embedding_sequence_parallel: bool = True
     tp_only_amax_red: bool = False
     tp_comm_overlap_cfg: Optional[Union[str, dict[str, Any]]] = None
     """Config file when tp_comm_overlap is enabled."""

-    transformer_layer_spec: Union[LayerSpec, Callable[["GPTModelProvider"], LayerSpec]] = local_layer_spec
-
     generation_config: Optional[Any] = None

     # This represents the unpadded vocab size
@@ -134,6 +149,7 @@ def provide(self, pre_process=None, post_process=None, vp_stage=None) -> GPTMode
         Returns:
             GPTModel: Configured PaddleFleet GPT model instance
         """
+        assert HAS_PADDLEFLEET
         vp_size = self.virtual_pipeline_model_parallel_size
         is_pipeline_asymmetric = getattr(self, "account_for_embedding_in_pipeline_split", False) or getattr(
             self, "account_for_loss_in_pipeline_split", False
@@ -151,25 +167,6 @@ def provide(self, pre_process=None, post_process=None, vp_stage=None) -> GPTMode
             self.num_layers // p_size
         ) % vp_size == 0, "Make sure the number of model chunks is the same across all pipeline stages."

-        transformer_layer_spec = self.transformer_layer_spec
-        print(f"transformer_layer_spec {transformer_layer_spec}")
-        print(f"param: {inspect.signature(transformer_layer_spec).parameters}")
-
-        if not isinstance(transformer_layer_spec, LayerSpec):
-            # Check if the transformer_layer_spec function accepts vp_stage parameter
-            if "vp_stage" in inspect.signature(transformer_layer_spec).parameters:
-                transformer_layer_spec = transformer_layer_spec(self, vp_stage=vp_stage)
-            else:
-                transformer_layer_spec = transformer_layer_spec(self)
-
-        assert self.vocab_size is not None, "vocab_size must be configured before calling provide()"
-        if self.should_pad_vocab:
-            padded_vocab_size = calculate_padded_vocab_size(
-                self.vocab_size, self.make_vocab_size_divisible_by, self.tensor_model_parallel_size
-            )
-        else:
-            padded_vocab_size = self.vocab_size
-
         # Initialize model as meta data instead of allocating data on a device
         model_init_device_context = contextlib.nullcontext
         if self.init_model_with_meta_device:
@@ -187,26 +184,7 @@
         """

         with model_init_device_context():
-            model = GPTModel(
-                self,
-                transformer_layer_spec=transformer_layer_spec,
-                vocab_size=padded_vocab_size,
-                max_sequence_length=self.seq_length,
-                fp16_lm_cross_entropy=self.fp16_lm_cross_entropy,
-                parallel_output=self.parallel_output,
-                share_embeddings_and_output_weights=self.share_embeddings_and_output_weights,
-                position_embedding_type=self.position_embedding_type,
-                rotary_percent=self.rotary_percent,
-                rotary_base=self.rotary_base,
-                seq_len_interpolation_factor=self.seq_len_interpolation_factor,
-                pre_process=pre_process
-                or parallel_state.is_pipeline_first_stage(ignore_virtual=False, vp_stage=vp_stage),
-                post_process=post_process
-                or parallel_state.is_pipeline_last_stage(ignore_virtual=False, vp_stage=vp_stage),
-                scatter_embedding_sequence_parallel=self.scatter_embedding_sequence_parallel,
-                vp_stage=vp_stage,
-                **kwargs,
-            )
+            model = gpt_builder(self, num_stages=1)

         return model

@@ -220,6 +198,7 @@ def mtp_block_spec(config: "GPTModelProvider", vp_stage: Optional[int] = None) -
     Returns:
         LayerSpec: The MTP module specification
     """
+    assert HAS_PADDLEFLEET
     if getattr(config, "mtp_num_layers", None):
         from paddlefleet.models.gpt.gpt_layer_specs import get_gpt_mtp_block_spec
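
gpt_provider.py now treats the newer paddlefleet entry points as optional: GPTConfig falls back to TransformerConfig under the same name, and the availability of gpt_builder is recorded in a HAS_PADDLEFLEET flag that provide() and the spec helpers assert on. A self-contained sketch of the same optional-import pattern; paddlefleet_next is a deliberately missing placeholder module, and the gpt_builder(config, num_stages=1) call mirrors the signature used in the diff above:

# Sketch of the optional-import pattern used in gpt_provider.py.
# "paddlefleet_next" is a placeholder expected to be absent, so the
# except branches below are exercised when this file is run directly.
from dataclasses import field, make_dataclass

try:
    from paddlefleet_next import GPTConfig  # preferred import path (placeholder)
except ImportError:
    # Fallback alias: downstream code keeps using the single name GPTConfig.
    GPTConfig = make_dataclass("GPTConfig", [("seq_length", int, field(default=1024))])

try:
    from paddlefleet_next import gpt_builder  # optional builder (placeholder)

    HAS_BUILDER = True
except ImportError:
    HAS_BUILDER = False


def provide(config):
    # Assert at the call site, not at import time, so this module stays
    # importable even when the optional dependency is missing.
    assert HAS_BUILDER, "gpt_builder is unavailable in this environment"
    return gpt_builder(config, num_stages=1)


if __name__ == "__main__":
    print(GPTConfig(), "HAS_BUILDER =", HAS_BUILDER)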
