Commit be9e7df

【fleet】fix Fleet lora model (#2997)
1 parent e5b8ac1 · commit be9e7df

File tree (3 files changed, +203 -1 lines):

- paddleformers/peft/lora/lora_layers.py
- paddleformers/peft/lora/lora_model.py
- paddleformers/transformers/glm4_moe/modeling.py

paddleformers/peft/lora/lora_layers.py

Lines changed: 65 additions & 0 deletions

@@ -300,6 +300,19 @@ def extra_repr(self):
         return f"in_features={self.weight.shape[0]}, out_features={self.weight.shape[1]}, rank={self.r}{name}"


+class FleetLoRALinear(LoRALinear):
+    def __init__(self, in_features, out_features, skip_bias_add, **kwargs):
+        super().__init__(in_features, out_features, **kwargs)
+        self.skip_bias_add = skip_bias_add
+
+    def forward(self, input: paddle.Tensor):
+        out_bias = self.bias if self.skip_bias_add else None
+        if self.skip_bias_add:
+            self.bias = None
+        output = super().forward(input)
+        return output, out_bias
+
+
 class RowParallelLoRALinear(RowParallelLinear):
     def __init__(
         self,

@@ -461,6 +474,19 @@ def extra_repr(self):
         return f"in_features={self.weight.shape[0]}, out_features={self.weight.shape[1]}, rank={self.r}{name}"


+class FleetRowParallelLoRALinear(RowParallelLoRALinear):
+    def __init__(self, in_features, out_features, skip_bias_add, **kwargs):
+        super().__init__(in_features, out_features, **kwargs)
+        self.skip_bias_add = skip_bias_add
+
+    def forward(self, input: paddle.Tensor):
+        out_bias = self.bias if self.skip_bias_add else None
+        if self.skip_bias_add:
+            self.bias = None
+        output = super().forward(input)
+        return output, out_bias
+
+
 class RowSequenceParallelLoRALinear(RowSequenceParallelLinear):
     def __init__(
         self,

@@ -579,6 +605,19 @@ def extra_repr(self):
         return f"in_features={self.weight.shape[0]}, out_features={self.weight.shape[1]}, rank={self.r}{name}"


+class FleetRowSequenceParallelLoRALinear(RowSequenceParallelLoRALinear):
+    def __init__(self, in_features, out_features, skip_bias_add, **kwargs):
+        super().__init__(in_features, out_features, **kwargs)
+        self.skip_bias_add = skip_bias_add
+
+    def forward(self, input: paddle.Tensor):
+        out_bias = self.bias if self.skip_bias_add else None
+        if self.skip_bias_add:
+            self.bias = None
+        output = super().forward(input)
+        return output, out_bias
+
+
 class ColumnParallelLoRALinear(ColumnParallelLinear):
     def __init__(
         self,

@@ -722,6 +761,19 @@ def extra_repr(self):
         return f"in_features={self.weight.shape[0]}, out_features={self.weight.shape[1]}, rank={self.r}{name}"


+class FleetColumnParallelLoRALinear(ColumnParallelLoRALinear):
+    def __init__(self, in_features, out_features, skip_bias_add, **kwargs):
+        super().__init__(in_features, out_features, **kwargs)
+        self.skip_bias_add = skip_bias_add
+
+    def forward(self, input: paddle.Tensor):
+        out_bias = self.bias if self.skip_bias_add else None
+        if self.skip_bias_add:
+            self.bias = None
+        output = super().forward(input)
+        return output, out_bias
+
+
 class ColumnSequenceParallelLoRALinear(ColumnSequenceParallelLinear):
     def __init__(
         self,

@@ -843,6 +895,19 @@ def extra_repr(self):
         return f"in_features={self.weight.shape[0]}, out_features={self.weight.shape[1]}, rank={self.r}{name}"


+class FleetColumnSequenceParallelLoRALinear(ColumnSequenceParallelLoRALinear):
+    def __init__(self, in_features, out_features, skip_bias_add, **kwargs):
+        super().__init__(in_features, out_features, **kwargs)
+        self.skip_bias_add = skip_bias_add
+
+    def forward(self, input: paddle.Tensor):
+        out_bias = self.bias if self.skip_bias_add else None
+        if self.skip_bias_add:
+            self.bias = None
+        output = super().forward(input)
+        return output, out_bias
+
+
 class LoRAConv2D(nn.Conv2D):
     # LoRA implemented in a dense layer
     def __init__(
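
All five Fleet* wrappers added above follow the same pattern: they forward constructor kwargs to the existing LoRA layer and, when skip_bias_add is set, return an (output, bias) tuple instead of adding the bias inside forward, matching the calling convention of Fleet's tensor-parallel linears. Below is a minimal, self-contained sketch of that convention; SkipBiasLinear is a toy stand-in for illustration, not a class from this commit.

# Toy illustration of the skip_bias_add convention (not from this commit).
import paddle


class SkipBiasLinear(paddle.nn.Linear):
    def __init__(self, in_features, out_features, skip_bias_add=False):
        super().__init__(in_features, out_features)
        self.skip_bias_add = skip_bias_add

    def forward(self, x):
        out = paddle.matmul(x, self.weight)
        if self.skip_bias_add:
            # Defer the bias add to the caller, which may fuse it with a later op.
            return out, self.bias
        return out + self.bias, None


layer = SkipBiasLinear(8, 16, skip_bias_add=True)
out, bias = layer(paddle.randn([4, 8]))
y = out + bias  # the caller applies the bias explicitly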

paddleformers/peft/lora/lora_model.py

Lines changed: 135 additions & 1 deletion
@@ -33,6 +33,10 @@
     PipelineLayer,
     RowParallelLinear,
 )
+from paddlefleet.tensor_parallel import (
+    ColumnParallelLinear as FleetColumnParallelLinear,
+)
+from paddlefleet.tensor_parallel import RowParallelLinear as FleetRowParallelLinear

 from ...trainer.argparser import strtobool
 from ...transformers import linear_utils

@@ -92,6 +96,11 @@ def get_lora_layers():
     from .lora_layers import (
         ColumnParallelLoRALinear,
         ColumnSequenceParallelLoRALinear,
+        FleetColumnParallelLoRALinear,
+        FleetColumnSequenceParallelLoRALinear,
+        FleetLoRALinear,
+        FleetRowParallelLoRALinear,
+        FleetRowSequenceParallelLoRALinear,
         LoRAConv2D,
         LoRALinear,
         RowParallelLoRALinear,

@@ -105,6 +114,11 @@ def get_lora_layers():
         "LoRALinear": LoRALinear,
         "RowParallelLoRALinear": RowParallelLoRALinear,
         "RowSequenceParallelLoRALinear": RowSequenceParallelLoRALinear,
+        "FleetLoRALinear": FleetLoRALinear,
+        "FleetRowParallelLoRALinear": FleetRowParallelLoRALinear,
+        "FleetColumnParallelLoRALinear": FleetColumnParallelLoRALinear,
+        "FleetRowSequenceParallelLoRALinear": FleetRowSequenceParallelLoRALinear,
+        "FleetColumnSequenceParallelLoRALinear": FleetColumnSequenceParallelLoRALinear,
     }


@@ -115,6 +129,12 @@ def get_lora_layers():
 LoRALinear = lora_layers["LoRALinear"]
 RowParallelLoRALinear = lora_layers["RowParallelLoRALinear"]
 RowSequenceParallelLoRALinear = lora_layers["RowSequenceParallelLoRALinear"]
+FleetLoRALinear = lora_layers["FleetLoRALinear"]
+FleetRowParallelLoRALinear = lora_layers["FleetRowParallelLoRALinear"]
+FleetColumnParallelLoRALinear = lora_layers["FleetColumnParallelLoRALinear"]
+FleetRowSequenceParallelLoRALinear = lora_layers["FleetRowSequenceParallelLoRALinear"]
+FleetColumnSequenceParallelLoRALinear = lora_layers["FleetColumnSequenceParallelLoRALinear"]
+

 from ...quantization.quantization_linear import (
     ColumnParallelQuantizationLinear,
@@ -167,6 +187,8 @@ def __init__(self, model, lora_config: LoRAConfig) -> None:
             self.lora_config.lora_use_mixer or self.lora_config.use_mora
         ):
             raise NotImplementedError("lora_use_mixer or mora is not supported in tensor parallel mode.")
+        if hasattr(self.model.config, "tensor_model_parallel_size"):
+            self.model.config.tensor_parallel_degree = self.model.config.tensor_model_parallel_size
         if self.lora_config.tensor_parallel_degree != self.model.config.tensor_parallel_degree:
             self.lora_config.tensor_parallel_degree = self.model.config.tensor_parallel_degree
             logger.warning(

@@ -566,7 +588,10 @@ def replace_name_and_gen_index_lora(path):
         if is_main_process:
             lora_config_to_save.save_pretrained(save_directory)
             if save_model_config:
-                model_config_to_save = copy.deepcopy(self.model.config)
+                if hasattr(self.model, "config_to_save"):
+                    model_config_to_save = copy.deepcopy(self.model.config_to_save)
+                else:
+                    model_config_to_save = copy.deepcopy(self.model.config)
                 if merge_tensor_parallel:
                     model_config_to_save.tensor_parallel_degree = -1
                 model_config_to_save.save_pretrained(save_directory)
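
The two hunks above bridge naming differences between a Fleet model config and the existing LoRA code path: a tensor_model_parallel_size attribute, if present, is mirrored into tensor_parallel_degree so the downstream degree checks keep working, and saving prefers the model's config_to_save attribute when one exists. A rough sketch of the aliasing idea follows; DummyConfig is illustrative only, not the real config class.

# Illustrative only: DummyConfig stands in for the real model config.
class DummyConfig:
    tensor_model_parallel_size = 2  # Fleet-style attribute name


config = DummyConfig()
if hasattr(config, "tensor_model_parallel_size"):
    # Expose the value under the name the LoRA code already reads.
    config.tensor_parallel_degree = config.tensor_model_parallel_size

assert config.tensor_parallel_degree == 2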
@@ -712,6 +737,115 @@ def _find_and_replace_module(self, model, module_name, lora_config):
                     self.add_lora_split_mapping(module_name + ".weight_quanter._scale", is_column=False)
                     self.add_lora_split_mapping(module_name + ".activation_quanter._scale", is_column=False)
                     self.add_lora_split_mapping(module_name + ".activation_quanter.quanter._scale", is_column=False)
+        elif isinstance(module, FleetColumnParallelLinear) or isinstance(module, FleetRowParallelLinear):
+            if module.world_size == 1:
+                lora_module = FleetLoRALinear(
+                    in_features=module.weight.shape[0],
+                    out_features=module.weight.shape[1],
+                    skip_bias_add=module.skip_bias_add,
+                    r=lora_config.r,
+                    lora_alpha=lora_config.lora_alpha,
+                    lora_dropout=lora_config.lora_dropout,
+                    rslora=lora_config.rslora,
+                    lora_plus_scale=lora_config.lora_plus_scale,
+                    pissa=lora_config.pissa,
+                    bias_attr=False if module.bias is None else None,
+                    use_quick_lora=lora_config.use_quick_lora,
+                    lora_use_mixer=lora_config.lora_use_mixer,
+                    use_mora=lora_config.use_mora,
+                    mp_moe=getattr(module.weight, "mp_moe", False),
+                    is_distributed=getattr(module.weight, "is_distributed", False),
+                    lorapro=lora_config.lorapro,
+                )
+            elif isinstance(module, FleetRowParallelLinear):
+                # recover the original output_features
+                if module.sequence_parallel:
+                    lora_module = FleetRowSequenceParallelLoRALinear(
+                        in_features=module.weight.shape[0] * module.world_size,
+                        out_features=module.weight.shape[1],
+                        skip_bias_add=module.skip_bias_add,
+                        has_bias=module.bias is not None,
+                        input_is_parallel=module.input_is_parallel,
+                        r=lora_config.r,
+                        lora_alpha=lora_config.lora_alpha,
+                        lora_dropout=lora_config.lora_dropout,
+                        rslora=lora_config.rslora,
+                        lora_plus_scale=lora_config.lora_plus_scale,
+                        use_quick_lora=lora_config.use_quick_lora,
+                    )
+                else:
+                    lora_module = FleetRowParallelLoRALinear(
+                        in_features=module.weight.shape[0] * module.world_size,
+                        out_features=module.weight.shape[1],
+                        skip_bias_add=module.skip_bias_add,
+                        has_bias=module.bias is not None,
+                        input_is_parallel=module.input_is_parallel,
+                        r=lora_config.r,
+                        lora_alpha=lora_config.lora_alpha,
+                        lora_dropout=lora_config.lora_dropout,
+                        rslora=lora_config.rslora,
+                        lora_plus_scale=lora_config.lora_plus_scale,
+                        pissa=lora_config.pissa,
+                        use_quick_lora=lora_config.use_quick_lora,
+                    )
+                # Lora column parallel will spilt lora A matrix
+                self.add_lora_split_mapping(module_name + ".lora_A", is_column=False)
+
+                # for lora qat
+                if self.lora_config.do_qat:
+                    self.add_lora_split_mapping(module_name + ".weight_quanter._scale", is_column=False)
+                    self.add_lora_split_mapping(module_name + ".activation_quanter._scale", is_column=False)
+                    self.add_lora_split_mapping(module_name + ".activation_quanter.quanter._scale", is_column=False)
+            elif isinstance(module, FleetColumnParallelLinear):
+                # recover the original output_features
+                output_features = module.weight.shape[1] * module.world_size
+                if module.sequence_parallel:
+                    lora_module = FleetColumnSequenceParallelLoRALinear(
+                        in_features=module.weight.shape[0],
+                        out_features=output_features,
+                        skip_bias_add=module.skip_bias_add,
+                        gather_output=module.gather_output,
+                        has_bias=module.bias is not None,
+                        r=lora_config.r,
+                        lora_alpha=lora_config.lora_alpha,
+                        lora_dropout=lora_config.lora_dropout,
+                        rslora=lora_config.rslora,
+                        lora_plus_scale=lora_config.lora_plus_scale,
+                        lora_A_weight_attr=paddle.ParamAttr(
+                            initializer=nn.initializer.KaimingUniform(
+                                negative_slope=math.sqrt(5), nonlinearity="leaky_relu"
+                            )
+                        ),
+                        use_quick_lora=lora_config.use_quick_lora,
+                    )
+                else:
+                    lora_module = FleetColumnParallelLoRALinear(
+                        in_features=module.weight.shape[0],
+                        out_features=output_features,
+                        skip_bias_add=module.skip_bias_add,
+                        gather_output=module.gather_output,
+                        has_bias=module.bias is not None,
+                        r=lora_config.r,
+                        lora_alpha=lora_config.lora_alpha,
+                        lora_dropout=lora_config.lora_dropout,
+                        rslora=lora_config.rslora,
+                        lora_plus_scale=lora_config.lora_plus_scale,
+                        pissa=lora_config.pissa,
+                        lora_A_weight_attr=paddle.ParamAttr(
+                            initializer=nn.initializer.KaimingUniform(
+                                negative_slope=math.sqrt(5), nonlinearity="leaky_relu"
+                            )
+                        ),
+                        use_quick_lora=lora_config.use_quick_lora,
+                    )
+                # Lora column parallel will spilt lora B matrix
+                self.add_lora_split_mapping(module_name + ".lora_B", is_column=True)
+
+                # for lora qat
+                if self.lora_config.do_qat:
+                    self.add_lora_split_mapping(module_name + ".weight_quanter._scale", is_column=True)
+                    self.add_lora_split_mapping(module_name + ".activation_quanter._scale", is_column=False)
+                    self.add_lora_split_mapping(module_name + ".activation_quanter.quanter._scale", is_column=False)
         elif isinstance(module, QuantizationLinear):
             lora_module = QuantizationLoRALinear(module, lora_config)
         elif isinstance(module, ColumnParallelQuantizationLinear):
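
The new branch above selects a replacement class per Fleet module: a Fleet parallel linear running with world_size == 1 degrades to the plain FleetLoRALinear, otherwise the row- or column-parallel variant (or its sequence-parallel version) is chosen, and the corresponding lora_A / lora_B tensor-parallel split mappings are registered. A condensed, illustrative summary of that selection is sketched below; the helper is not the actual implementation, though the attribute names are taken from the diff.

# Illustrative summary of the dispatch above, not the real code path.
def pick_fleet_lora_class(module, is_row_parallel: bool) -> str:
    if module.world_size == 1:
        return "FleetLoRALinear"  # no parallel split needed
    if is_row_parallel:
        # row parallel: lora_A is split (is_column=False in the mapping)
        return ("FleetRowSequenceParallelLoRALinear"
                if module.sequence_parallel else "FleetRowParallelLoRALinear")
    # column parallel: lora_B is split (is_column=True in the mapping)
    return ("FleetColumnSequenceParallelLoRALinear"
            if module.sequence_parallel else "FleetColumnParallelLoRALinear")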

paddleformers/transformers/glm4_moe/modeling.py

Lines changed: 3 additions & 0 deletions

@@ -58,6 +58,8 @@ class GLMMoEModelProvider(GPTModelProvider):

     bias_activation_fusion: bool = True

+    transform_rules = {"tensor_parallel_degree": "tensor_model_parallel_size", "dtype": "params_dtype"}
+

 def eager_attention_forward(
     module: nn.Layer,

@@ -1494,6 +1496,7 @@ def __new__(cls, config):
         gpt_model = model_provider.provide()
         gpt_model._gen_aoa_config = cls._gen_aoa_config
         gpt_model._gen_inv_aoa_config = cls._gen_inv_aoa_config
+        gpt_model._get_tensor_parallel_mappings = cls._get_tensor_parallel_mappings
         gpt_model.config_to_save = config
         return gpt_model
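
The transform_rules entry appears to map PaddleFormers config attribute names to their Fleet-side counterparts (tensor_parallel_degree -> tensor_model_parallel_size, dtype -> params_dtype). A hedged sketch of what applying such a rename table could look like follows; apply_transform_rules is a hypothetical helper, not a function from this repository.

# Hypothetical helper illustrating a key-rename table like transform_rules.
def apply_transform_rules(src: dict, rules: dict) -> dict:
    out = dict(src)
    for old_key, new_key in rules.items():
        if old_key in out:
            out[new_key] = out.pop(old_key)
    return out


rules = {"tensor_parallel_degree": "tensor_model_parallel_size", "dtype": "params_dtype"}
print(apply_transform_rules({"tensor_parallel_degree": 2, "dtype": "bfloat16"}, rules))
# {'tensor_model_parallel_size': 2, 'params_dtype': 'bfloat16'}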
