     PipelineLayer,
     RowParallelLinear,
 )
+from paddlefleet.tensor_parallel import (
+    ColumnParallelLinear as FleetColumnParallelLinear,
+)
+from paddlefleet.tensor_parallel import RowParallelLinear as FleetRowParallelLinear

 from ...trainer.argparser import strtobool
 from ...transformers import linear_utils
@@ -92,6 +96,11 @@ def get_lora_layers():
     from .lora_layers import (
         ColumnParallelLoRALinear,
         ColumnSequenceParallelLoRALinear,
+        FleetColumnParallelLoRALinear,
+        FleetColumnSequenceParallelLoRALinear,
+        FleetLoRALinear,
+        FleetRowParallelLoRALinear,
+        FleetRowSequenceParallelLoRALinear,
         LoRAConv2D,
         LoRALinear,
         RowParallelLoRALinear,
@@ -105,6 +114,11 @@ def get_lora_layers():
         "LoRALinear": LoRALinear,
         "RowParallelLoRALinear": RowParallelLoRALinear,
         "RowSequenceParallelLoRALinear": RowSequenceParallelLoRALinear,
+        "FleetLoRALinear": FleetLoRALinear,
+        "FleetRowParallelLoRALinear": FleetRowParallelLoRALinear,
+        "FleetColumnParallelLoRALinear": FleetColumnParallelLoRALinear,
+        "FleetRowSequenceParallelLoRALinear": FleetRowSequenceParallelLoRALinear,
+        "FleetColumnSequenceParallelLoRALinear": FleetColumnSequenceParallelLoRALinear,
     }


@@ -115,6 +129,12 @@ def get_lora_layers():
 LoRALinear = lora_layers["LoRALinear"]
 RowParallelLoRALinear = lora_layers["RowParallelLoRALinear"]
 RowSequenceParallelLoRALinear = lora_layers["RowSequenceParallelLoRALinear"]
+FleetLoRALinear = lora_layers["FleetLoRALinear"]
+FleetRowParallelLoRALinear = lora_layers["FleetRowParallelLoRALinear"]
+FleetColumnParallelLoRALinear = lora_layers["FleetColumnParallelLoRALinear"]
+FleetRowSequenceParallelLoRALinear = lora_layers["FleetRowSequenceParallelLoRALinear"]
+FleetColumnSequenceParallelLoRALinear = lora_layers["FleetColumnSequenceParallelLoRALinear"]
+

 from ...quantization.quantization_linear import (
     ColumnParallelQuantizationLinear,
@@ -167,6 +187,8 @@ def __init__(self, model, lora_config: LoRAConfig) -> None:
             self.lora_config.lora_use_mixer or self.lora_config.use_mora
         ):
             raise NotImplementedError("lora_use_mixer or mora is not supported in tensor parallel mode.")
+        if hasattr(self.model.config, "tensor_model_parallel_size"):
+            self.model.config.tensor_parallel_degree = self.model.config.tensor_model_parallel_size
         if self.lora_config.tensor_parallel_degree != self.model.config.tensor_parallel_degree:
             self.lora_config.tensor_parallel_degree = self.model.config.tensor_parallel_degree
             logger.warning(
@@ -566,7 +588,10 @@ def replace_name_and_gen_index_lora(path):
         if is_main_process:
             lora_config_to_save.save_pretrained(save_directory)
             if save_model_config:
-                model_config_to_save = copy.deepcopy(self.model.config)
+                if hasattr(self.model, "config_to_save"):
+                    model_config_to_save = copy.deepcopy(self.model.config_to_save)
+                else:
+                    model_config_to_save = copy.deepcopy(self.model.config)
                 if merge_tensor_parallel:
                     model_config_to_save.tensor_parallel_degree = -1
                 model_config_to_save.save_pretrained(save_directory)
@@ -712,6 +737,115 @@ def _find_and_replace_module(self, model, module_name, lora_config):
                 self.add_lora_split_mapping(module_name + ".weight_quanter._scale", is_column=False)
                 self.add_lora_split_mapping(module_name + ".activation_quanter._scale", is_column=False)
                 self.add_lora_split_mapping(module_name + ".activation_quanter.quanter._scale", is_column=False)
+        elif isinstance(module, FleetColumnParallelLinear) or isinstance(module, FleetRowParallelLinear):
+            if module.world_size == 1:
+                lora_module = FleetLoRALinear(
+                    in_features=module.weight.shape[0],
+                    out_features=module.weight.shape[1],
+                    skip_bias_add=module.skip_bias_add,
+                    r=lora_config.r,
+                    lora_alpha=lora_config.lora_alpha,
+                    lora_dropout=lora_config.lora_dropout,
+                    rslora=lora_config.rslora,
+                    lora_plus_scale=lora_config.lora_plus_scale,
+                    pissa=lora_config.pissa,
+                    bias_attr=False if module.bias is None else None,
+                    use_quick_lora=lora_config.use_quick_lora,
+                    lora_use_mixer=lora_config.lora_use_mixer,
+                    use_mora=lora_config.use_mora,
+                    mp_moe=getattr(module.weight, "mp_moe", False),
+                    is_distributed=getattr(module.weight, "is_distributed", False),
+                    lorapro=lora_config.lorapro,
+                )
+            elif isinstance(module, FleetRowParallelLinear):
+                # recover the original in_features
+                if module.sequence_parallel:
+                    lora_module = FleetRowSequenceParallelLoRALinear(
+                        in_features=module.weight.shape[0] * module.world_size,
+                        out_features=module.weight.shape[1],
+                        skip_bias_add=module.skip_bias_add,
+                        has_bias=module.bias is not None,
+                        input_is_parallel=module.input_is_parallel,
+                        r=lora_config.r,
+                        lora_alpha=lora_config.lora_alpha,
+                        lora_dropout=lora_config.lora_dropout,
+                        rslora=lora_config.rslora,
+                        lora_plus_scale=lora_config.lora_plus_scale,
+                        use_quick_lora=lora_config.use_quick_lora,
+                    )
+                else:
+                    lora_module = FleetRowParallelLoRALinear(
+                        in_features=module.weight.shape[0] * module.world_size,
+                        out_features=module.weight.shape[1],
+                        skip_bias_add=module.skip_bias_add,
+                        has_bias=module.bias is not None,
+                        input_is_parallel=module.input_is_parallel,
+                        r=lora_config.r,
+                        lora_alpha=lora_config.lora_alpha,
+                        lora_dropout=lora_config.lora_dropout,
+                        rslora=lora_config.rslora,
+                        lora_plus_scale=lora_config.lora_plus_scale,
+                        pissa=lora_config.pissa,
+                        use_quick_lora=lora_config.use_quick_lora,
+                    )
+                # Lora row parallel will split the lora_A matrix
+                self.add_lora_split_mapping(module_name + ".lora_A", is_column=False)
+
+                # for lora qat
+                if self.lora_config.do_qat:
+                    self.add_lora_split_mapping(module_name + ".weight_quanter._scale", is_column=False)
+                    self.add_lora_split_mapping(module_name + ".activation_quanter._scale", is_column=False)
+                    self.add_lora_split_mapping(module_name + ".activation_quanter.quanter._scale", is_column=False)
+            elif isinstance(module, FleetColumnParallelLinear):
+                # recover the original output_features
+                output_features = module.weight.shape[1] * module.world_size
+                if module.sequence_parallel:
+                    lora_module = FleetColumnSequenceParallelLoRALinear(
+                        in_features=module.weight.shape[0],
+                        out_features=output_features,
+                        skip_bias_add=module.skip_bias_add,
+                        gather_output=module.gather_output,
+                        has_bias=module.bias is not None,
+                        r=lora_config.r,
+                        lora_alpha=lora_config.lora_alpha,
+                        lora_dropout=lora_config.lora_dropout,
+                        rslora=lora_config.rslora,
+                        lora_plus_scale=lora_config.lora_plus_scale,
+                        lora_A_weight_attr=paddle.ParamAttr(
+                            initializer=nn.initializer.KaimingUniform(
+                                negative_slope=math.sqrt(5), nonlinearity="leaky_relu"
+                            )
+                        ),
+                        use_quick_lora=lora_config.use_quick_lora,
+                    )
+                else:
+                    lora_module = FleetColumnParallelLoRALinear(
+                        in_features=module.weight.shape[0],
+                        out_features=output_features,
+                        skip_bias_add=module.skip_bias_add,
+                        gather_output=module.gather_output,
+                        has_bias=module.bias is not None,
+                        r=lora_config.r,
+                        lora_alpha=lora_config.lora_alpha,
+                        lora_dropout=lora_config.lora_dropout,
+                        rslora=lora_config.rslora,
+                        lora_plus_scale=lora_config.lora_plus_scale,
+                        pissa=lora_config.pissa,
+                        lora_A_weight_attr=paddle.ParamAttr(
+                            initializer=nn.initializer.KaimingUniform(
+                                negative_slope=math.sqrt(5), nonlinearity="leaky_relu"
+                            )
+                        ),
+                        use_quick_lora=lora_config.use_quick_lora,
+                    )
+                # Lora column parallel will split the lora_B matrix
+                self.add_lora_split_mapping(module_name + ".lora_B", is_column=True)
+
+                # for lora qat
+                if self.lora_config.do_qat:
+                    self.add_lora_split_mapping(module_name + ".weight_quanter._scale", is_column=True)
+                    self.add_lora_split_mapping(module_name + ".activation_quanter._scale", is_column=False)
+                    self.add_lora_split_mapping(module_name + ".activation_quanter.quanter._scale", is_column=False)
         elif isinstance(module, QuantizationLinear):
             lora_module = QuantizationLoRALinear(module, lora_config)
         elif isinstance(module, ColumnParallelQuantizationLinear):
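
Note: as a minimal sketch (not part of the diff), the dispatch added to _find_and_replace_module above boils down to choosing a Fleet LoRA class from the module type, its tensor-parallel world size, and whether sequence parallelism is enabled. The class names and the world_size / sequence_parallel attributes come from the code above; the helper name pick_fleet_lora_cls is hypothetical and only illustrates the selection logic.

def pick_fleet_lora_cls(module):
    # With a single tensor-parallel rank, fall back to the plain Fleet LoRA linear.
    if module.world_size == 1:
        return FleetLoRALinear
    # Row-parallel modules split lora_A (is_column=False); column-parallel modules split lora_B (is_column=True).
    if isinstance(module, FleetRowParallelLinear):
        return FleetRowSequenceParallelLoRALinear if module.sequence_parallel else FleetRowParallelLoRALinear
    return FleetColumnSequenceParallelLoRALinear if module.sequence_parallel else FleetColumnParallelLoRALinear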