
Commit 6e58a17

fix conflict

1 parent facbe9b commit 6e58a17

File tree: 6 files changed, +54 / -14 lines


paddleformers/cli/hparams/model_args.py
Lines changed: 8 additions & 0 deletions

@@ -85,6 +85,14 @@ class ModelArguments:
         default=False,
         metadata={"help": "GPT3 model, use fast layernorm"},
     )
+    fuse_attention_qkv: bool = field(
+        default=None,
+        metadata={"help": "whether to fuse attention qkv"},
+    )
+    fuse_attention_ffn: bool = field(
+        default=None,
+        metadata={"help": "whether to fuse first up and gate proj in mlp block"},
+    )
     attn_impl: str = field(default="flashmask", metadata={"help": "Attention implementation"})
     fuse_gate_detach_matmul: bool = field(
         default=True,
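Both new fields are declared with default=None rather than a hard True/False, which reads as a tri-state flag: unset unless the user overrides it, with the effective value decided downstream. A minimal, self-contained sketch of that pattern (plain dataclasses only; FusionArgs and resolve_fuse_flag are illustrative names, not PaddleFormers APIs, and Optional[bool] just makes the None default explicit):

from dataclasses import dataclass, field
from typing import Optional

@dataclass
class FusionArgs:
    # None means "not set"; the effective value is decided later.
    fuse_attention_qkv: Optional[bool] = field(
        default=None, metadata={"help": "whether to fuse attention qkv"}
    )
    fuse_attention_ffn: Optional[bool] = field(
        default=None, metadata={"help": "whether to fuse first up and gate proj in mlp block"}
    )

def resolve_fuse_flag(cli_value: Optional[bool], config_default: bool) -> bool:
    # An explicit user choice wins; otherwise fall back to the model config.
    return config_default if cli_value is None else cli_value

args = FusionArgs()  # nothing set explicitly
print(resolve_fuse_flag(args.fuse_attention_qkv, config_default=False))  # -> False
print(resolve_fuse_flag(True, config_default=False))                     # -> True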

paddleformers/cli/hparams/parser.py
Lines changed: 1 addition & 1 deletion

@@ -142,7 +142,7 @@ def _parse_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -
     Returns:
         _TRAIN_CLS: _description_
     """
-    parser = PdArgumentParser(_TRAIN_ARGS)
+    parser = PdArgumentParser(_TRAIN_ARGS, conflict_handler="resolve")
     allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
     return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)
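conflict_handler="resolve" is standard argparse behavior; assuming PdArgumentParser forwards the keyword to argparse.ArgumentParser (as HfArgumentParser-style parsers do), a repeated option string no longer raises "conflicting option string" but is replaced by the later definition, which is likely what lets the new fuse_* fields coexist with same-named options declared elsewhere. A small standard-library illustration, independent of PaddleFormers:

import argparse

# With the default conflict handler ("error"), the second add_argument below
# would raise argparse.ArgumentError: conflicting option string: --fuse_attention_qkv
parser = argparse.ArgumentParser(conflict_handler="resolve")
parser.add_argument("--fuse_attention_qkv", default=None, help="first definition")
parser.add_argument("--fuse_attention_qkv", default="auto", help="later definition wins")

args = parser.parse_args([])
print(args.fuse_attention_qkv)  # -> auto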

paddleformers/cli/train/auto_parallel/workflow.py
Lines changed: 1 addition & 5 deletions

@@ -33,8 +33,6 @@
     AutoTokenizer,
     CosineAnnealingWithWarmupDecay,
     LinearAnnealingWithWarmupDecay,
-    LlamaConfig,
-    LlamaForCausalLM,
 )
 from paddleformers.transformers.configuration_utils import LlmMetaConfig
 from paddleformers.utils.log import logger
@@ -147,6 +145,7 @@ def __init__(self, *args, **kwargs):


 def run_auto_parallel(model_args, data_args, generating_args, training_args):
+
     do_enable_linear_fused_grad_add = training_args.enable_linear_fused_grad_add
     # do_enable_mp_async_allreduce = (
     #     training_args.enable_auto_parallel
@@ -311,9 +310,6 @@ def run_auto_parallel(model_args, data_args, generating_args, training_args):
         model = model_class.from_config(config, dtype=dtype)
     else:
         model = model_class.from_config(config, dtype=dtype)
-
-    criterion = model.criterion
-

     if training_args.recompute:

paddleformers/trainer/argparser.py
Lines changed: 0 additions & 1 deletion

@@ -188,7 +188,6 @@ def _add_dataclass_arguments(self, dtype: DataClassType):
                 f"removing line of `from __future__ import annotations` which opts in Postponed "
                 f"Evaluation of Annotations (PEP 563)"
             )
-
         for field in dataclasses.fields(dtype):
             if not field.init:
                 continue
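The surrounding context here is the parser's error message telling users to drop `from __future__ import annotations`: under PEP 563 every dataclass field annotation is stored as a string, so inspecting field.type no longer yields a real type. A standalone illustration of that behavior (standard library only, not PaddleFormers code):

from __future__ import annotations  # PEP 563: annotations are stored as strings

import dataclasses
import typing

@dataclasses.dataclass
class Demo:
    fuse_attention_qkv: bool = False

print(dataclasses.fields(Demo)[0].type)  # -> 'bool'  (a string, not a type)

# get_type_hints re-evaluates the strings into real types, which is what a
# dataclass-driven argument parser has to do (or ask users to drop the import).
print(typing.get_type_hints(Demo)["fuse_attention_qkv"])  # -> <class 'bool'>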

paddleformers/transformers/__init__.py
Lines changed: 1 addition & 1 deletion

@@ -193,7 +193,7 @@
     "llama.configuration": [
         "LlamaConfig",
     ],
-    "llama.modeling": ["LlamaForCausalLM", "LlamaModel", "LlamaForCausalLMPipe", "LlamaRotaryEmbedding"],
+    "llama.modeling": ["LlamaForCausalLM", "LlamaModel", "LlamaForCausalLMPipe"],
     "llama.tokenizer": ["LlamaTokenizer", "Llama3Tokenizer"],
     "llama.tokenizer_fast": ["LlamaTokenizerFast"],
     "optimization": [

paddleformers/transformers/llama/modeling.py
Lines changed: 43 additions & 6 deletions
@@ -20,6 +20,11 @@
 from paddle.distributed.fleet.utils import recompute
 from paddle.distributed.fleet.utils.sequence_parallel_utils import ScatterOp
 
+from paddleformers.transformers.conversion_utils import (
+    StateDictNameMapping,
+    init_name_mappings,
+)
+
 from ...nn.attention.interface import ALL_ATTENTION_FUNCTIONS
 from ...nn.criterion.interface import CriterionLayer
 from ...nn.embedding import Embedding as GeneralEmbedding
@@ -28,14 +33,13 @@
 from ...nn.mlp import MLP
 from ...nn.norm import Norm as GeneralNorm
 from ...nn.pp_model import GeneralModelForCausalLMPipe
-from .auto_dist_config import get_dist_config
-
 from ...utils.log import logger
 from ..cache_utils import Cache, DynamicCache
 from ..masking_utils import create_causal_mask_and_row_indices
 from ..model_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from ..model_utils import PretrainedModel, register_base_model
 from ..modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from .auto_dist_config import get_dist_config
 from .configuration import LlamaConfig
 
 
@@ -162,9 +166,9 @@ def forward(
         q_shape = (batch_size, seq_len, self.num_heads, self.head_dim)
         kv_shape = (batch_size, seq_len, self.num_key_value_heads, self.head_dim)
 
-        query_states = self.q_proj(hidden_states).view(q_shape).transpose(1, 2)
-        key_states = self.k_proj(hidden_states).view(kv_shape).transpose(1, 2)
-        value_states = self.v_proj(hidden_states).view(kv_shape).transpose(1, 2)
+        query_states = self.q_proj(hidden_states).reshape(q_shape).transpose(1, 2)
+        key_states = self.k_proj(hidden_states).reshape(kv_shape).transpose(1, 2)
+        value_states = self.v_proj(hidden_states).reshape(kv_shape).transpose(1, 2)
 
         cos, sin = position_embeddings
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
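The view-to-reshape change swaps a strided-view call for reshape, which typically also accepts inputs that cannot be reinterpreted in place (copying only when it must). A small standalone Paddle sketch of the same [batch, seq, heads, head_dim] split-and-transpose step; shapes and names are illustrative, and paddle.transpose is called with an explicit perm rather than the two-axis form used in the diff:

import paddle

batch_size, seq_len, num_heads, head_dim = 2, 8, 4, 16

# Stand-in for the q_proj output: [batch, seq, hidden] with hidden = heads * head_dim.
hidden_states = paddle.randn([batch_size, seq_len, num_heads * head_dim])

q_shape = (batch_size, seq_len, num_heads, head_dim)

# reshape splits the hidden dimension into (heads, head_dim).
query_states = hidden_states.reshape(q_shape)

# Move the head axis in front of the sequence axis: [batch, heads, seq, head_dim].
query_states = paddle.transpose(query_states, perm=[0, 2, 1, 3])

print(query_states.shape)  # -> [2, 4, 8, 16]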
@@ -327,8 +331,41 @@ class LlamaPretrainedModel(PretrainedModel):
     ]
 
     @classmethod
-    def _get_tensor_parallel_mappings(cls, config: LlamaConfig, is_split=True):
+    def _get_name_mappings(cls, config: LlamaConfig) -> list[StateDictNameMapping]:
+        mappings: list[StateDictNameMapping] = []
+        model_mappings = [
+            ["embed_tokens.weight"],
+            ["norm.weight"],
+        ]
+        for layer_index in range(config.num_hidden_layers):
+            layer_mappings = [
+                [f"layers.{layer_index}.self_attn.q_proj.weight", None, "transpose"],
+                [f"layers.{layer_index}.self_attn.k_proj.weight", None, "transpose"],
+                [f"layers.{layer_index}.self_attn.v_proj.weight", None, "transpose"],
+                [f"layers.{layer_index}.self_attn.o_proj.weight", None, "transpose"],
+                [f"layers.{layer_index}.self_attn.rotary_emb.inv_freq"],
+                [f"layers.{layer_index}.mlp.gate_proj.weight", None, "transpose"],
+                [f"layers.{layer_index}.mlp.down_proj.weight", None, "transpose"],
+                [f"layers.{layer_index}.mlp.up_proj.weight", None, "transpose"],
+                [f"layers.{layer_index}.input_layernorm.weight"],
+                [f"layers.{layer_index}.post_attention_layernorm.weight"],
+            ]
+            model_mappings.extend(layer_mappings)
+
+        init_name_mappings(mappings=model_mappings)
+        # base-model prefix "LlamaModel"
+        if "LlamaModel" not in config.architectures:
+            for mapping in model_mappings:
+                mapping[0] = "model." + mapping[0]
+                mapping[1] = "llama." + mapping[1]
+        if not config.tie_word_embeddings:
+            model_mappings.append(["lm_head.weight", "lm_head.weight", "transpose"])
+
+        mappings = [StateDictNameMapping(*mapping, index=index) for index, mapping in enumerate(model_mappings)]
+        return mappings
 
+    @classmethod
+    def _get_tensor_parallel_mappings(cls, config: LlamaConfig, is_split=True):
         from ..conversion_utils import split_or_merge_func
 
         fn = split_or_merge_func(
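The new _get_name_mappings builds one row per parameter, where each row reads as [source name, target name, optional action] and "transpose" marks weights stored transposed in the external checkpoint. A quick standalone sketch of what the first rows look like for a toy config; it mimics only the list construction above, not StateDictNameMapping or init_name_mappings themselves:

# Toy stand-in for LlamaConfig exposing only the attribute the mapping loop reads.
class ToyConfig:
    num_hidden_layers = 2

model_mappings = [["embed_tokens.weight"], ["norm.weight"]]
for layer_index in range(ToyConfig.num_hidden_layers):
    model_mappings.extend(
        [
            [f"layers.{layer_index}.self_attn.q_proj.weight", None, "transpose"],
            [f"layers.{layer_index}.mlp.gate_proj.weight", None, "transpose"],
            [f"layers.{layer_index}.input_layernorm.weight"],
        ]
    )

# init_name_mappings (from conversion_utils) presumably fills in the missing
# target names; mimic that by defaulting the target to the source name.
for row in model_mappings:
    if len(row) == 1:
        row.append(row[0])
    elif row[1] is None:
        row[1] = row[0]

for row in model_mappings[:4]:
    print(row)
# ['embed_tokens.weight', 'embed_tokens.weight']
# ['norm.weight', 'norm.weight']
# ['layers.0.self_attn.q_proj.weight', 'layers.0.self_attn.q_proj.weight', 'transpose']
# ['layers.0.mlp.gate_proj.weight', 'layers.0.mlp.gate_proj.weight', 'transpose']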
