20 | 20 | from paddle.distributed.fleet.utils import recompute |
21 | 21 | from paddle.distributed.fleet.utils.sequence_parallel_utils import ScatterOp |
22 | 22 |
| 23 | +from paddleformers.transformers.conversion_utils import ( |
| 24 | + StateDictNameMapping, |
| 25 | + init_name_mappings, |
| 26 | +) |
| 27 | + |
23 | 28 | from ...nn.attention.interface import ALL_ATTENTION_FUNCTIONS |
24 | 29 | from ...nn.criterion.interface import CriterionLayer |
25 | 30 | from ...nn.embedding import Embedding as GeneralEmbedding |
28 | 33 | from ...nn.mlp import MLP |
29 | 34 | from ...nn.norm import Norm as GeneralNorm |
30 | 35 | from ...nn.pp_model import GeneralModelForCausalLMPipe |
31 | | -from .auto_dist_config import get_dist_config |
32 | | - |
33 | 36 | from ...utils.log import logger |
34 | 37 | from ..cache_utils import Cache, DynamicCache |
35 | 38 | from ..masking_utils import create_causal_mask_and_row_indices |
36 | 39 | from ..model_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast |
37 | 40 | from ..model_utils import PretrainedModel, register_base_model |
38 | 41 | from ..modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update |
| 42 | +from .auto_dist_config import get_dist_config |
39 | 43 | from .configuration import LlamaConfig |
40 | 44 |
41 | 45 |
@@ -162,9 +166,9 @@ def forward( |
162 | 166 | q_shape = (batch_size, seq_len, self.num_heads, self.head_dim) |
163 | 167 | kv_shape = (batch_size, seq_len, self.num_key_value_heads, self.head_dim) |
164 | 168 |
165 | | - query_states = self.q_proj(hidden_states).view(q_shape).transpose(1, 2) |
166 | | - key_states = self.k_proj(hidden_states).view(kv_shape).transpose(1, 2) |
167 | | - value_states = self.v_proj(hidden_states).view(kv_shape).transpose(1, 2) |
| 169 | + query_states = self.q_proj(hidden_states).reshape(q_shape).transpose(1, 2) |
| 170 | + key_states = self.k_proj(hidden_states).reshape(kv_shape).transpose(1, 2) |
| 171 | + value_states = self.v_proj(hidden_states).reshape(kv_shape).transpose(1, 2) |
168 | 172 |
169 | 173 | cos, sin = position_embeddings |
170 | 174 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) |
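For reference, a minimal standalone sketch (not from this PR) of the per-head layout the three `reshape(...).transpose(...)` lines above produce, written with plain `reshape` and the explicit-`perm` form of `paddle.transpose`; the sizes are hypothetical and chosen only to make the shapes visible.

```python
import paddle

# Hypothetical sizes, for illustration only.
batch_size, seq_len, num_heads, head_dim = 2, 8, 4, 16
hidden_states = paddle.randn([batch_size, seq_len, num_heads * head_dim])

# [batch, seq, hidden] -> [batch, seq, heads, head_dim] -> [batch, heads, seq, head_dim]
q = hidden_states.reshape([batch_size, seq_len, num_heads, head_dim])
q = paddle.transpose(q, perm=[0, 2, 1, 3])
print(q.shape)  # [2, 4, 8, 16]
```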
@@ -327,8 +331,41 @@ class LlamaPretrainedModel(PretrainedModel): |
327 | 331 | ] |
328 | 332 |
329 | 333 | @classmethod |
330 | | - def _get_tensor_parallel_mappings(cls, config: LlamaConfig, is_split=True): |
| 334 | + def _get_name_mappings(cls, config: LlamaConfig) -> list[StateDictNameMapping]: |
| 335 | + mappings: list[StateDictNameMapping] = [] |
| 336 | + model_mappings = [ |
| 337 | + ["embed_tokens.weight"], |
| 338 | + ["norm.weight"], |
| 339 | + ] |
| 340 | + for layer_index in range(config.num_hidden_layers): |
| 341 | + layer_mappings = [ |
| 342 | + [f"layers.{layer_index}.self_attn.q_proj.weight", None, "transpose"], |
| 343 | + [f"layers.{layer_index}.self_attn.k_proj.weight", None, "transpose"], |
| 344 | + [f"layers.{layer_index}.self_attn.v_proj.weight", None, "transpose"], |
| 345 | + [f"layers.{layer_index}.self_attn.o_proj.weight", None, "transpose"], |
| 346 | + [f"layers.{layer_index}.self_attn.rotary_emb.inv_freq"], |
| 347 | + [f"layers.{layer_index}.mlp.gate_proj.weight", None, "transpose"], |
| 348 | + [f"layers.{layer_index}.mlp.down_proj.weight", None, "transpose"], |
| 349 | + [f"layers.{layer_index}.mlp.up_proj.weight", None, "transpose"], |
| 350 | + [f"layers.{layer_index}.input_layernorm.weight"], |
| 351 | + [f"layers.{layer_index}.post_attention_layernorm.weight"], |
| 352 | + ] |
| 353 | + model_mappings.extend(layer_mappings) |
| 354 | + |
| 355 | + init_name_mappings(mappings=model_mappings) |
| 356 | + # base-model prefix: unless the architecture is the bare "LlamaModel", source keys get "model." and target keys get "llama." |
| 357 | + if "LlamaModel" not in config.architectures: |
| 358 | + for mapping in model_mappings: |
| 359 | + mapping[0] = "model." + mapping[0] |
| 360 | + mapping[1] = "llama." + mapping[1] |
| 361 | + if not config.tie_word_embeddings: |
| 362 | + model_mappings.append(["lm_head.weight", "lm_head.weight", "transpose"]) |
| 363 | + |
| 364 | + mappings = [StateDictNameMapping(*mapping, index=index) for index, mapping in enumerate(model_mappings)] |
| 365 | + return mappings |
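A self-contained sketch (not from this PR, no paddleformers imports) of the key pairs `_get_name_mappings` is expected to yield. It assumes `init_name_mappings` mirrors single-entry mappings into identical source/target names, and that the `model.` / `llama.` prefixes land on the checkpoint-side and Paddle-attribute-side keys respectively, as the hunk above suggests.

```python
# Standalone illustration of the prefixing logic for a hypothetical 1-layer
# config with a head (architectures != ["LlamaModel"]) and tie_word_embeddings=False.
def sketch_name_pairs(num_hidden_layers=1, has_head=True, tie_word_embeddings=False):
    names = ["embed_tokens.weight", "norm.weight"]
    per_layer = [
        "self_attn.q_proj.weight", "self_attn.k_proj.weight",
        "self_attn.v_proj.weight", "self_attn.o_proj.weight",
        "self_attn.rotary_emb.inv_freq",
        "mlp.gate_proj.weight", "mlp.down_proj.weight", "mlp.up_proj.weight",
        "input_layernorm.weight", "post_attention_layernorm.weight",
    ]
    for i in range(num_hidden_layers):
        names += [f"layers.{i}.{name}" for name in per_layer]

    # Assumed behaviour of init_name_mappings: ["x"] becomes the pair ("x", "x").
    pairs = [[name, name] for name in names]
    if has_head:
        # Source (checkpoint) keys gain "model.", target (Paddle) keys gain "llama.".
        pairs = [["model." + src, "llama." + dst] for src, dst in pairs]
    if not tie_word_embeddings:
        pairs.append(["lm_head.weight", "lm_head.weight"])
    return pairs

for src, dst in sketch_name_pairs()[:3]:
    print(f"{src} -> {dst}")
# model.embed_tokens.weight -> llama.embed_tokens.weight
# model.norm.weight -> llama.norm.weight
# model.layers.0.self_attn.q_proj.weight -> llama.layers.0.self_attn.q_proj.weight
```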
331 | 366 |
| 367 | + @classmethod |
| 368 | + def _get_tensor_parallel_mappings(cls, config: LlamaConfig, is_split=True): |
332 | 369 | from ..conversion_utils import split_or_merge_func |
333 | 370 |
|
334 | 371 | fn = split_or_merge_func( |