
Commit 9524073

fix conflict
1 parent: facbe9b

4 files changed: +49, -13 lines

paddleformers/cli/train/auto_parallel/workflow.py

Lines changed: 1 addition & 5 deletions
@@ -33,8 +33,6 @@
     AutoTokenizer,
     CosineAnnealingWithWarmupDecay,
     LinearAnnealingWithWarmupDecay,
-    LlamaConfig,
-    LlamaForCausalLM,
 )
 from paddleformers.transformers.configuration_utils import LlmMetaConfig
 from paddleformers.utils.log import logger
@@ -147,6 +145,7 @@ def __init__(self, *args, **kwargs):
 
 
 def run_auto_parallel(model_args, data_args, generating_args, training_args):
+
     do_enable_linear_fused_grad_add = training_args.enable_linear_fused_grad_add
     # do_enable_mp_async_allreduce = (
     #     training_args.enable_auto_parallel
@@ -311,9 +310,6 @@ def run_auto_parallel(model_args, data_args, generating_args, training_args):
         model = model_class.from_config(config, dtype=dtype)
     else:
         model = model_class.from_config(config, dtype=dtype)
-
-    criterion = model.criterion
-
 
     if training_args.recompute:
 
paddleformers/trainer/argparser.py

Lines changed: 0 additions & 1 deletion
@@ -188,7 +188,6 @@ def _add_dataclass_arguments(self, dtype: DataClassType):
                 f"removing line of `from __future__ import annotations` which opts in Postponed "
                 f"Evaluation of Annotations (PEP 563)"
             )
-
         for field in dataclasses.fields(dtype):
            if not field.init:
                continue

paddleformers/trainer/trainer.py

Lines changed: 5 additions & 1 deletion
@@ -369,7 +369,11 @@ def __init__(
         self._memory_tracker.start()
 
         # Seed must be set before instantiating the model when using model
-        set_random_seed(seed_=self.args.seed)
+        if not self.args.enable_auto_parallel:
+            set_random_seed(seed_=self.args.seed)
+        else:
+            logger.warning("set_seed not support yet in auto_parallel mode")
+
         set_seed(seed=self.args.seed)
 
         self._skip_global_steps = 0  # total skip global steps
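The net effect of this hunk: with `enable_auto_parallel` on, the trainer skips `set_random_seed` and only calls `set_seed`. Below is a tiny standalone sketch of that control flow with the trainer pieces stubbed out; `seed_everything`, the stub helpers, and the `SimpleNamespace` args are stand-ins for illustration, not the real paddleformers imports.

import logging
from types import SimpleNamespace

logger = logging.getLogger("trainer-sketch")

def set_random_seed(seed_):  # stand-in for paddleformers' helper
    print(f"set_random_seed({seed_})")

def set_seed(seed):  # stand-in for paddleformers' helper
    print(f"set_seed({seed})")

def seed_everything(args):
    # Mirrors the guarded block added in trainer.py: skip set_random_seed
    # under auto-parallel, always call set_seed afterwards.
    if not args.enable_auto_parallel:
        set_random_seed(seed_=args.seed)
    else:
        logger.warning("set_seed not support yet in auto_parallel mode")
    set_seed(seed=args.seed)

seed_everything(SimpleNamespace(seed=42, enable_auto_parallel=True))
# only set_seed(42) runs; a warning is emitted instead of set_random_seed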

paddleformers/transformers/llama/modeling.py

Lines changed: 43 additions & 6 deletions
@@ -20,6 +20,11 @@
 from paddle.distributed.fleet.utils import recompute
 from paddle.distributed.fleet.utils.sequence_parallel_utils import ScatterOp
 
+from paddleformers.transformers.conversion_utils import (
+    StateDictNameMapping,
+    init_name_mappings,
+)
+
 from ...nn.attention.interface import ALL_ATTENTION_FUNCTIONS
 from ...nn.criterion.interface import CriterionLayer
 from ...nn.embedding import Embedding as GeneralEmbedding
@@ -28,14 +33,13 @@
 from ...nn.mlp import MLP
 from ...nn.norm import Norm as GeneralNorm
 from ...nn.pp_model import GeneralModelForCausalLMPipe
-from .auto_dist_config import get_dist_config
-
 from ...utils.log import logger
 from ..cache_utils import Cache, DynamicCache
 from ..masking_utils import create_causal_mask_and_row_indices
 from ..model_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from ..model_utils import PretrainedModel, register_base_model
 from ..modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from .auto_dist_config import get_dist_config
 from .configuration import LlamaConfig
 
 
@@ -162,9 +166,9 @@ def forward(
         q_shape = (batch_size, seq_len, self.num_heads, self.head_dim)
         kv_shape = (batch_size, seq_len, self.num_key_value_heads, self.head_dim)
 
-        query_states = self.q_proj(hidden_states).view(q_shape).transpose(1, 2)
-        key_states = self.k_proj(hidden_states).view(kv_shape).transpose(1, 2)
-        value_states = self.v_proj(hidden_states).view(kv_shape).transpose(1, 2)
+        query_states = self.q_proj(hidden_states).reshape(q_shape).transpose(1, 2)
+        key_states = self.k_proj(hidden_states).reshape(kv_shape).transpose(1, 2)
+        value_states = self.v_proj(hidden_states).reshape(kv_shape).transpose(1, 2)
 
         cos, sin = position_embeddings
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
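This hunk only swaps `.view(...)` for `.reshape(...)`; the shape bookkeeping is unchanged: the projection output of shape (batch_size, seq_len, num_heads * head_dim) is split into the four-dimensional `q_shape`/`kv_shape` and then transposed so the head axis precedes the sequence axis. A minimal standalone illustration of that bookkeeping, using toy sizes and `paddle.transpose` with an explicit `perm` instead of the method-style call in the model:

import paddle

# toy sizes for illustration only
batch_size, seq_len, num_heads, head_dim = 2, 8, 4, 16

hidden_states = paddle.randn([batch_size, seq_len, num_heads * head_dim])
q_shape = (batch_size, seq_len, num_heads, head_dim)

# [batch, seq, heads * dim] -> [batch, seq, heads, dim] -> [batch, heads, seq, dim]
query_states = paddle.transpose(hidden_states.reshape(q_shape), perm=[0, 2, 1, 3])
print(query_states.shape)  # [2, 4, 8, 16]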
@@ -327,8 +331,41 @@ class LlamaPretrainedModel(PretrainedModel):
     ]
 
     @classmethod
-    def _get_tensor_parallel_mappings(cls, config: LlamaConfig, is_split=True):
+    def _get_name_mappings(cls, config: LlamaConfig) -> list[StateDictNameMapping]:
+        mappings: list[StateDictNameMapping] = []
+        model_mappings = [
+            ["embed_tokens.weight"],
+            ["norm.weight"],
+        ]
+        for layer_index in range(config.num_hidden_layers):
+            layer_mappings = [
+                [f"layers.{layer_index}.self_attn.q_proj.weight", None, "transpose"],
+                [f"layers.{layer_index}.self_attn.k_proj.weight", None, "transpose"],
+                [f"layers.{layer_index}.self_attn.v_proj.weight", None, "transpose"],
+                [f"layers.{layer_index}.self_attn.o_proj.weight", None, "transpose"],
+                [f"layers.{layer_index}.self_attn.rotary_emb.inv_freq"],
+                [f"layers.{layer_index}.mlp.gate_proj.weight", None, "transpose"],
+                [f"layers.{layer_index}.mlp.down_proj.weight", None, "transpose"],
+                [f"layers.{layer_index}.mlp.up_proj.weight", None, "transpose"],
+                [f"layers.{layer_index}.input_layernorm.weight"],
+                [f"layers.{layer_index}.post_attention_layernorm.weight"],
+            ]
+            model_mappings.extend(layer_mappings)
+
+        init_name_mappings(mappings=model_mappings)
+        # base-model prefix "LlamaModel"
+        if "LlamaModel" not in config.architectures:
+            for mapping in model_mappings:
+                mapping[0] = "model." + mapping[0]
+                mapping[1] = "llama." + mapping[1]
+        if not config.tie_word_embeddings:
+            model_mappings.append(["lm_head.weight", "lm_head.weight", "transpose"])
+
+        mappings = [StateDictNameMapping(*mapping, index=index) for index, mapping in enumerate(model_mappings)]
+        return mappings
 
+    @classmethod
+    def _get_tensor_parallel_mappings(cls, config: LlamaConfig, is_split=True):
         from ..conversion_utils import split_or_merge_func
 
         fn = split_or_merge_func(

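The new `_get_name_mappings` builds, per layer, `[source_name, target_name, action]` entries that `StateDictNameMapping` consumes during checkpoint conversion; projection and MLP weights carry a "transpose" action, while embedding and norm weights map through unchanged. The following is a rough standalone sketch of the resulting mapping layout for a one-layer config, with the paddleformers helpers stubbed out. It assumes, without checking the library, that `init_name_mappings` defaults the target name to the source name when none is given, and that the "model."/"llama." prefixes are the checkpoint-side and paddleformers-side names respectively; the helper name `build_llama_name_mappings` is made up for the sketch.

# Standalone sketch only: init_name_mappings / StateDictNameMapping are replaced by plain lists.

def build_llama_name_mappings(num_hidden_layers=1, tie_word_embeddings=False,
                              architectures=("LlamaForCausalLM",)):
    model_mappings = [["embed_tokens.weight"], ["norm.weight"]]
    for layer_index in range(num_hidden_layers):
        model_mappings += [
            [f"layers.{layer_index}.self_attn.q_proj.weight", None, "transpose"],
            [f"layers.{layer_index}.self_attn.k_proj.weight", None, "transpose"],
            [f"layers.{layer_index}.self_attn.v_proj.weight", None, "transpose"],
            [f"layers.{layer_index}.self_attn.o_proj.weight", None, "transpose"],
            [f"layers.{layer_index}.self_attn.rotary_emb.inv_freq"],
            [f"layers.{layer_index}.mlp.gate_proj.weight", None, "transpose"],
            [f"layers.{layer_index}.mlp.down_proj.weight", None, "transpose"],
            [f"layers.{layer_index}.mlp.up_proj.weight", None, "transpose"],
            [f"layers.{layer_index}.input_layernorm.weight"],
            [f"layers.{layer_index}.post_attention_layernorm.weight"],
        ]

    # stand-in for init_name_mappings: default the target name to the source name
    for mapping in model_mappings:
        if len(mapping) == 1:
            mapping.append(mapping[0])
        elif mapping[1] is None:
            mapping[1] = mapping[0]

    # same prefixing step as the diff (assumed meaning: checkpoint keys use "model.*",
    # the causal-LM wrapper stores the base model under a "llama.*" attribute)
    if "LlamaModel" not in architectures:
        for mapping in model_mappings:
            mapping[0] = "model." + mapping[0]
            mapping[1] = "llama." + mapping[1]

    if not tie_word_embeddings:
        model_mappings.append(["lm_head.weight", "lm_head.weight", "transpose"])
    return model_mappings

for mapping in build_llama_name_mappings()[:3]:
    print(mapping)
# ['model.embed_tokens.weight', 'llama.embed_tokens.weight']
# ['model.norm.weight', 'llama.norm.weight']
# ['model.layers.0.self_attn.q_proj.weight', 'llama.layers.0.self_attn.q_proj.weight', 'transpose']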