 from typing import Any, Callable, Literal, Optional, Union
 
 import paddle
-from paddlefleet import LayerSpec, parallel_state
+from paddlefleet import LayerSpec
 from paddlefleet.models.gpt import GPTModel as FleetGPTModel
 from paddlefleet.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec
-from paddlefleet.transformer.transformer_config import TransformerConfig
+
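+# Prefer paddlefleet's dedicated GPTConfig; fall back to aliasing TransformerConfig when it is unavailable.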
+try:
+    from paddlefleet.models.gpt.gpt_config import GPTConfig
+except ImportError:
+    from paddlefleet.transformer.transformer_config import (
+        TransformerConfig as GPTConfig,
+    )
+
+
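+# gpt_builder is an optional import; record its availability so call sites can assert on HAS_PADDLEFLEET.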
+try:
+    from paddlefleet.gpt_builders import gpt_builder
+
+    HAS_PADDLEFLEET = True
+except ImportError:
+    HAS_PADDLEFLEET = False
 
 from paddleformers.transformers.model_utils import PretrainedModel
 
 from .model_provider import ModelProviderMixin
-from .vocab_utils import calculate_padded_vocab_size
 
 logger = logging.getLogger(__name__)
 
@@ -52,6 +65,7 @@ def local_layer_spec(config: "GPTModelProvider") -> LayerSpec:
     Returns:
         LayerSpec: Module specification for local implementation layers
     """
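+    # Guard against use when the optional paddlefleet import above failed.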
+    assert HAS_PADDLEFLEET
     return get_gpt_layer_local_spec(
         num_experts=config.num_moe_experts,
         moe_grouped_gemm=config.moe_grouped_gemm,
@@ -61,7 +75,7 @@ def local_layer_spec(config: "GPTModelProvider") -> LayerSpec:
 
 
 @dataclass
-class GPTModelProvider(TransformerConfig, ModelProviderMixin[GPTModel]):
+class GPTModelProvider(GPTConfig, ModelProviderMixin[GPTModel]):
     """Configuration and provider for PaddleFleet GPT models.
 
     This class extends TransformerConfig with GPT-specific parameters and
@@ -78,15 +92,16 @@ class GPTModelProvider(TransformerConfig, ModelProviderMixin[GPTModel]):
     rotary_percent: float = 1.0
     seq_len_interpolation_factor: Optional[float] = None
     seq_length: int = 1024
+
+    max_sequence_length: int = 1024
+
     attention_softmax_in_fp32: bool = False
     deallocate_pipeline_outputs: bool = True
     scatter_embedding_sequence_parallel: bool = True
     tp_only_amax_red: bool = False
     tp_comm_overlap_cfg: Optional[Union[str, dict[str, Any]]] = None
     """Config file when tp_comm_overlap is enabled."""
 
-    transformer_layer_spec: Union[LayerSpec, Callable[["GPTModelProvider"], LayerSpec]] = local_layer_spec
-
     generation_config: Optional[Any] = None
 
     # This represents the unpadded vocab size
@@ -134,6 +149,7 @@ def provide(self, pre_process=None, post_process=None, vp_stage=None) -> GPTMode
         Returns:
             GPTModel: Configured PaddleFleet GPT model instance
         """
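+        # provide() delegates to gpt_builder below, so the optional import must have succeeded.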
+        assert HAS_PADDLEFLEET
         vp_size = self.virtual_pipeline_model_parallel_size
         is_pipeline_asymmetric = getattr(self, "account_for_embedding_in_pipeline_split", False) or getattr(
             self, "account_for_loss_in_pipeline_split", False
@@ -151,25 +167,6 @@ def provide(self, pre_process=None, post_process=None, vp_stage=None) -> GPTMode
                 self.num_layers // p_size
             ) % vp_size == 0, "Make sure the number of model chunks is the same across all pipeline stages."
 
-        transformer_layer_spec = self.transformer_layer_spec
-        print(f"transformer_layer_spec {transformer_layer_spec}")
-        print(f"param: {inspect.signature(transformer_layer_spec).parameters}")
-
-        if not isinstance(transformer_layer_spec, LayerSpec):
-            # Check if the transformer_layer_spec function accepts vp_stage parameter
-            if "vp_stage" in inspect.signature(transformer_layer_spec).parameters:
-                transformer_layer_spec = transformer_layer_spec(self, vp_stage=vp_stage)
-            else:
-                transformer_layer_spec = transformer_layer_spec(self)
-
-        assert self.vocab_size is not None, "vocab_size must be configured before calling provide()"
-        if self.should_pad_vocab:
-            padded_vocab_size = calculate_padded_vocab_size(
-                self.vocab_size, self.make_vocab_size_divisible_by, self.tensor_model_parallel_size
-            )
-        else:
-            padded_vocab_size = self.vocab_size
-
         # Initialize model as meta data instead of allocating data on a device
         model_init_device_context = contextlib.nullcontext
         if self.init_model_with_meta_device:
@@ -187,26 +184,7 @@ def provide(self, pre_process=None, post_process=None, vp_stage=None) -> GPTMode
         """
 
         with model_init_device_context():
-            model = GPTModel(
-                self,
-                transformer_layer_spec=transformer_layer_spec,
-                vocab_size=padded_vocab_size,
-                max_sequence_length=self.seq_length,
-                fp16_lm_cross_entropy=self.fp16_lm_cross_entropy,
-                parallel_output=self.parallel_output,
-                share_embeddings_and_output_weights=self.share_embeddings_and_output_weights,
-                position_embedding_type=self.position_embedding_type,
-                rotary_percent=self.rotary_percent,
-                rotary_base=self.rotary_base,
-                seq_len_interpolation_factor=self.seq_len_interpolation_factor,
-                pre_process=pre_process
-                or parallel_state.is_pipeline_first_stage(ignore_virtual=False, vp_stage=vp_stage),
-                post_process=post_process
-                or parallel_state.is_pipeline_last_stage(ignore_virtual=False, vp_stage=vp_stage),
-                scatter_embedding_sequence_parallel=self.scatter_embedding_sequence_parallel,
-                vp_stage=vp_stage,
-                **kwargs,
-            )
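+            # Model construction is delegated to paddlefleet's gpt_builder, with this provider acting as the config.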
+            model = gpt_builder(self, num_stages=1)
 
         return model
 
@@ -220,6 +198,7 @@ def mtp_block_spec(config: "GPTModelProvider", vp_stage: Optional[int] = None) -
     Returns:
         LayerSpec: The MTP module specification
     """
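+    # Same guard as above: the optional paddlefleet import must be available.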
+    assert HAS_PADDLEFLEET
     if getattr(config, "mtp_num_layers", None):
         from paddlefleet.models.gpt.gpt_layer_specs import get_gpt_mtp_block_spec
 