 # See the License for the specific language governing permissions and
 # limitations under the License.

-
+import paddle
 import paddle.distributed as dist
 from paddle.distributed.auto_parallel.intermediate.tensor_parallel import (
     PrepareLayerInput,
+    PrepareLayerOutput,
 )


@@ -24,10 +25,12 @@ def hook(layer, inputs, output=None):
         res_inputs = []
         for input in inputs:
             if not input.is_dist():
-                x = dist.shard_tensor(input, process_mesh, [dist.Shard(0), dist.Replicate()])
-                res_inputs.append(dist.reshard(x, process_mesh, [dist.Shard(0), dist.Replicate()]))
+                x = dist.shard_tensor(input, process_mesh, [dist.Shard(0), dist.Replicate(), dist.Replicate()])
+                res_inputs.append(dist.reshard(x, process_mesh, [dist.Shard(0), dist.Replicate(), dist.Replicate()]))
             else:
-                res_inputs.append(dist.reshard(input, process_mesh, [dist.Shard(0), dist.Replicate()]))
+                res_inputs.append(
+                    dist.reshard(input, process_mesh, [dist.Shard(0), dist.Replicate(), dist.Replicate()])
+                )
         return tuple(res_inputs)

     return hook
@@ -38,10 +41,10 @@ def hook(layer, inputs, output=None):
         res_inputs = []
         for input in inputs:
             if not input.is_dist():
-                x = dist.shard_tensor(input, process_mesh, [dist.Shard(0), dist.Shard(1)])
-                res_inputs.append(dist.reshard(x, process_mesh, [dist.Shard(0), dist.Shard(1)]))
+                x = dist.shard_tensor(input, process_mesh, [dist.Shard(0), dist.Replicate(), dist.Shard(1)])
+                res_inputs.append(dist.reshard(x, process_mesh, [dist.Shard(0), dist.Replicate(), dist.Shard(1)]))
             else:
-                res_inputs.append(dist.reshard(input, process_mesh, [dist.Shard(0), dist.Shard(1)]))
+                res_inputs.append(dist.reshard(input, process_mesh, [dist.Shard(0), dist.Replicate(), dist.Shard(1)]))
         return tuple(res_inputs)

     return hook
@@ -52,18 +55,75 @@ def hook(layer, inputs, output=None):
         res_inputs = []
         for input in inputs:
             if not input.is_dist():
-                x = dist.shard_tensor(input, process_mesh, [dist.Replicate(), dist.Replicate()])
-                res_inputs.append(dist.reshard(x, process_mesh, [dist.Replicate(), dist.Replicate()]))
+                x = dist.shard_tensor(input, process_mesh, [dist.Replicate(), dist.Replicate(), dist.Replicate()])
+                res_inputs.append(
+                    dist.reshard(x, process_mesh, [dist.Replicate(), dist.Replicate(), dist.Replicate()])
+                )
             else:
                 res_inputs.append(dist.reshard(input, process_mesh, [dist.Replicate(), dist.Replicate()]))
         return tuple(res_inputs)

     return hook


-def auto_dist_config(self, prefix=""):
+def layer_input_rope_hook(process_mesh):
+    def hook(layer, inputs, output=None):
+        res_inputs = []
+        batch_size = None
+        seq_length = None
+        process_mesh = None
+        placements = None
+        for index in range(len(inputs)):
+            if index == 0:
+                batch_size, seq_length, _, _ = inputs[index]._local_shape
+                process_mesh = inputs[index].process_mesh
+                placements = inputs[index].placements
+            # process position_ids
+            if index == len(inputs) - 1:
+                mesh = dist.auto_parallel.get_mesh()
+                assert "sep" in mesh.dim_names, f"mesh.dim_names:{mesh.dim_names} must contain sep"
+                group = mesh._get_group("sep")
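+                # The local seq_length covers two of chunk_num = 2 * nranks global chunks:
+                # rank r rebuilds position_ids for chunk r and chunk (chunk_num - 1 - r),
+                # the balanced chunk layout commonly used for causal attention under context parallelism.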
+                chunk_size = seq_length // 2
+                chunk_num = group.nranks * 2
+                rank = group.rank
+                first_chunk_ids = paddle.arange(rank * chunk_size, (rank + 1) * chunk_size, dtype="int64")
+                second_chunk_ids = paddle.arange(
+                    (chunk_num - rank - 1) * chunk_size, (chunk_num - rank) * chunk_size, dtype="int64"
+                )
+                position_ids = paddle.concat([first_chunk_ids, second_chunk_ids]).expand((batch_size, seq_length))
+                mp_axis = process_mesh.dim_names.index("mp")
+                placements[mp_axis] = dist.Replicate()  # mp placement: Shard(2) -> Replicate()
+                position_ids = dist.auto_parallel.api.dtensor_from_local(position_ids, process_mesh, placements)
+                res_inputs.append(position_ids)
+            else:
+                res_inputs.append(inputs[index])
+        return tuple(res_inputs)
+
+    return hook
+
+
+def layer_output_rope_hook(process_mesh):
+    def hook(layer, inputs, outputs):
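+        # Reshard every RoPE output so its sequence dimension is sharded again
+        # along the "sep" mesh axis once the rotary embedding has been applied.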
+        res_outputs = []
+        for output in outputs:
+            process_mesh = output.process_mesh
+            placements = output.placements
+            cp_index = process_mesh.dim_names.index("sep")  # get the axis for the split
+            cp_degree = process_mesh.shape[cp_index]
+            assert cp_degree > 1, f"cp_degree:{cp_degree} must be > 1"
+            placements[cp_index] = dist.Shard(1)  # seq_dim: 1
+            output = dist.reshard(output, process_mesh, placements)
+            res_outputs.append(output)
+        return tuple(res_outputs)
+
+    return hook
+
+
+def get_dist_config(model, prefix=""):
+    """Generate the distributed configuration for the Llama model."""
     if prefix != "":
         assert prefix.endswith(".")
+
     config = {
         "sp_config": {
             "parallelize_plan": {
@@ -108,6 +168,30 @@ def auto_dist_config(self, prefix=""):
             }
         },
         "pp_config": {"split_spec": f"{prefix}llama.layers", "global_spec": f"{prefix}llama.global_layer"},
+        "cp_config": {
+            "parallelize_plan": {
+                f"{prefix}llama.layers.*.self_attn.sdpa": dist.ContextParallel(
+                    backend="p2p" if model.config.context_parallel_degree > 1 else "all2all"
+                ),
+            }
+        },
     }

+    if model.config.context_parallel_degree > 1:
+        config["cp_config"]["parallelize_plan"].update(
+            {
+                f"{prefix}llama.layers.*.self_attn.rope_func": [
+                    PrepareLayerInput(layer_input_rope_hook),
+                    PrepareLayerOutput(layer_output_rope_hook),
+                ]
+            }
+        )
+    elif model.config.sep_parallel_degree > 1:
+        # fuse_rope does not support dtensor spmd yet, so the sequence dim has to be resharded explicitly
+        config["cp_config"]["parallelize_plan"].update(
+            {
+                f"{prefix}llama.layers.*.self_attn.rope_func": PrepareLayerOutput(layer_output_rope_hook),
+            }
+        )
+
     return config
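
For reference, a minimal sketch of how the new cp_config branch is exercised. The SimpleNamespace stub standing in for the model and its config is hypothetical and only carries the two degree fields that get_dist_config reads:

from types import SimpleNamespace

# Hypothetical stand-in for the real model object; only the fields read by
# get_dist_config are modelled here.
model = SimpleNamespace(config=SimpleNamespace(context_parallel_degree=2, sep_parallel_degree=1))

config = get_dist_config(model, prefix="")
# With context_parallel_degree > 1, sdpa uses the "p2p" ContextParallel backend and
# rope_func is wrapped with both PrepareLayerInput and PrepareLayerOutput hooks.
print(sorted(config["cp_config"]["parallelize_plan"].keys()))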