20 | 20 | from paddle.distributed.fleet.utils import recompute |
21 | 21 | from paddle.distributed.fleet.utils.sequence_parallel_utils import ScatterOp |
22 | 22 |
| 23 | +from paddleformers.transformers.conversion_utils import ( |
| 24 | + StateDictNameMapping, |
| 25 | + init_name_mappings, |
| 26 | +) |
| 27 | + |
23 | 28 | from ...nn.attention.interface import ALL_ATTENTION_FUNCTIONS |
24 | 29 | from ...nn.criterion.interface import CriterionLayer |
25 | 30 | from ...nn.embedding import Embedding as GeneralEmbedding |
34 | 39 | from ..model_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast |
35 | 40 | from ..model_utils import PretrainedModel, register_base_model |
36 | 41 | from ..modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update |
| 42 | +from .auto_dist_config import get_dist_config |
37 | 43 | from .configuration import LlamaConfig |
38 | 44 |
39 | 45 |
@@ -160,9 +166,9 @@ def forward( |
160 | 166 | q_shape = (batch_size, seq_len, self.num_heads, self.head_dim) |
161 | 167 | kv_shape = (batch_size, seq_len, self.num_key_value_heads, self.head_dim) |
162 | 168 |
163 | | - query_states = self.q_proj(hidden_states).view(q_shape).transpose(1, 2) |
164 | | - key_states = self.k_proj(hidden_states).view(kv_shape).transpose(1, 2) |
165 | | - value_states = self.v_proj(hidden_states).view(kv_shape).transpose(1, 2) |
| 169 | + query_states = self.q_proj(hidden_states).reshape(q_shape).transpose(1, 2) |
| 170 | + key_states = self.k_proj(hidden_states).reshape(kv_shape).transpose(1, 2) |
| 171 | + value_states = self.v_proj(hidden_states).reshape(kv_shape).transpose(1, 2) |
166 | 172 |
167 | 173 | cos, sin = position_embeddings |
168 | 174 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) |
@@ -324,6 +330,40 @@ class LlamaPretrainedModel(PretrainedModel): |
324 | 330 | "down_proj", |
325 | 331 | ] |
326 | 332 |
| 333 | + @classmethod |
| 334 | + def _get_name_mappings(cls, config: LlamaConfig) -> list[StateDictNameMapping]: |
| 335 | + mappings: list[StateDictNameMapping] = [] |
| 336 | + model_mappings = [ |
| 337 | + ["embed_tokens.weight"], |
| 338 | + ["norm.weight"], |
| 339 | + ] |
| 340 | + for layer_index in range(config.num_hidden_layers): |
| 341 | + layer_mappings = [ |
| 342 | + [f"layers.{layer_index}.self_attn.q_proj.weight", None, "transpose"], |
| 343 | + [f"layers.{layer_index}.self_attn.k_proj.weight", None, "transpose"], |
| 344 | + [f"layers.{layer_index}.self_attn.v_proj.weight", None, "transpose"], |
| 345 | + [f"layers.{layer_index}.self_attn.o_proj.weight", None, "transpose"], |
| 346 | + [f"layers.{layer_index}.self_attn.rotary_emb.inv_freq"], |
| 347 | + [f"layers.{layer_index}.mlp.gate_proj.weight", None, "transpose"], |
| 348 | + [f"layers.{layer_index}.mlp.down_proj.weight", None, "transpose"], |
| 349 | + [f"layers.{layer_index}.mlp.up_proj.weight", None, "transpose"], |
| 350 | + [f"layers.{layer_index}.input_layernorm.weight"], |
| 351 | + [f"layers.{layer_index}.post_attention_layernorm.weight"], |
| 352 | + ] |
| 353 | + model_mappings.extend(layer_mappings) |
| 354 | + |
| 355 | + init_name_mappings(mappings=model_mappings) |
| 356 | + # when the checkpoint is not the bare LlamaModel, prefix source keys with "model." and target keys with "llama."
| 357 | + if "LlamaModel" not in config.architectures: |
| 358 | + for mapping in model_mappings: |
| 359 | + mapping[0] = "model." + mapping[0] |
| 360 | + mapping[1] = "llama." + mapping[1] |
| 361 | + if not config.tie_word_embeddings: |
| 362 | + model_mappings.append(["lm_head.weight", "lm_head.weight", "transpose"]) |
| 363 | + |
| 364 | + mappings = [StateDictNameMapping(*mapping, index=index) for index, mapping in enumerate(model_mappings)] |
| 365 | + return mappings |
| 366 | + |
327 | 367 | @classmethod |
328 | 368 | def _get_tensor_parallel_mappings(cls, config: LlamaConfig, is_split=True): |
329 | 369 | from ..conversion_utils import split_or_merge_func |
@@ -689,6 +729,10 @@ def forward( |
689 | 729 | attentions=outputs.attentions, |
690 | 730 | ) |
691 | 731 |
| 732 | + def auto_dist_config(self, prefix=""): |
| 733 | + assert self.config.use_single_model_implementation, "auto_dist_config is only supported with the single model implementation (use_single_model_implementation=True)."
| 734 | + return get_dist_config(self, prefix) |
| 735 | + |
692 | 736 |
693 | 737 | class LlamaForCausalLMPipe(GeneralModelForCausalLMPipe): |
694 | 738 | config_class = LlamaConfig |
|