Commit c57e5b1

waliwali777 authored and sevenan2 committed
update auto dist config
1 parent ef7b36e commit c57e5b1

File tree

1 file changed: +94 / -10 lines


paddleformers/transformers/llama/auto_dist_config.py

Lines changed: 94 additions & 10 deletions
@@ -12,10 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
+import paddle
 import paddle.distributed as dist
 from paddle.distributed.auto_parallel.intermediate.tensor_parallel import (
     PrepareLayerInput,
+    PrepareLayerOutput,
 )
 
 

@@ -24,10 +25,12 @@ def hook(layer, inputs, output=None):
         res_inputs = []
         for input in inputs:
             if not input.is_dist():
-                x = dist.shard_tensor(input, process_mesh, [dist.Shard(0), dist.Replicate()])
-                res_inputs.append(dist.reshard(x, process_mesh, [dist.Shard(0), dist.Replicate()]))
+                x = dist.shard_tensor(input, process_mesh, [dist.Shard(0), dist.Replicate(), dist.Replicate()])
+                res_inputs.append(dist.reshard(x, process_mesh, [dist.Shard(0), dist.Replicate(), dist.Replicate()]))
             else:
-                res_inputs.append(dist.reshard(input, process_mesh, [dist.Shard(0), dist.Replicate()]))
+                res_inputs.append(
+                    dist.reshard(input, process_mesh, [dist.Shard(0), dist.Replicate(), dist.Replicate()])
+                )
         return tuple(res_inputs)
 
     return hook
@@ -38,10 +41,10 @@ def hook(layer, inputs, output=None):
         res_inputs = []
         for input in inputs:
             if not input.is_dist():
-                x = dist.shard_tensor(input, process_mesh, [dist.Shard(0), dist.Shard(1)])
-                res_inputs.append(dist.reshard(x, process_mesh, [dist.Shard(0), dist.Shard(1)]))
+                x = dist.shard_tensor(input, process_mesh, [dist.Shard(0), dist.Replicate(), dist.Shard(1)])
+                res_inputs.append(dist.reshard(x, process_mesh, [dist.Shard(0), dist.Replicate(), dist.Shard(1)]))
             else:
-                res_inputs.append(dist.reshard(input, process_mesh, [dist.Shard(0), dist.Shard(1)]))
+                res_inputs.append(dist.reshard(input, process_mesh, [dist.Shard(0), dist.Replicate(), dist.Shard(1)]))
         return tuple(res_inputs)
 
     return hook
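
Note on the hunks above: the placement lists grow from two to three entries, one per dimension of a 3-D process mesh. Below is a minimal sketch of what [dist.Shard(0), dist.Replicate(), dist.Shard(1)] means on such a mesh; the mesh shape and dim names are invented for illustration and are not defined by this commit. It needs four ranks, e.g. via python -m paddle.distributed.launch.

# Illustration only: mesh shape and dim names below are assumptions,
# not taken from this commit.
import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([[[0, 1]], [[2, 3]]], dim_names=["dp", "sep", "mp"])  # shape 2 x 1 x 2

x = paddle.randn([8, 1024, 64])

# One placement entry per mesh dimension: shard tensor dim 0 over mesh dim 0,
# replicate over mesh dim 1, shard tensor dim 1 over mesh dim 2.
dx = dist.shard_tensor(x, mesh, [dist.Shard(0), dist.Replicate(), dist.Shard(1)])
print(dx.placements, dx._local_shape)  # local shape becomes [4, 512, 64]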
@@ -52,18 +55,75 @@ def hook(layer, inputs, output=None):
         res_inputs = []
         for input in inputs:
             if not input.is_dist():
-                x = dist.shard_tensor(input, process_mesh, [dist.Replicate(), dist.Replicate()])
-                res_inputs.append(dist.reshard(x, process_mesh, [dist.Replicate(), dist.Replicate()]))
+                x = dist.shard_tensor(input, process_mesh, [dist.Replicate(), dist.Replicate(), dist.Replicate()])
+                res_inputs.append(
+                    dist.reshard(x, process_mesh, [dist.Replicate(), dist.Replicate(), dist.Replicate()])
+                )
             else:
                 res_inputs.append(dist.reshard(input, process_mesh, [dist.Replicate(), dist.Replicate()]))
         return tuple(res_inputs)
 
     return hook
 
 
-def auto_dist_config(self, prefix=""):
+def layer_input_rope_hook(process_mesh):
+    def hook(layer, inputs, output=None):
+        res_inputs = []
+        batch_size = None
+        seq_length = None
+        process_mesh = None
+        placements = None
+        for index in range(len(inputs)):
+            if index == 0:
+                batch_size, seq_length, _, _ = inputs[index]._local_shape
+                process_mesh = inputs[index].process_mesh
+                placements = inputs[index].placements
+            # process position_ids
+            if index == len(inputs) - 1:
+                mesh = dist.auto_parallel.get_mesh()
+                assert "sep" in mesh.dim_names, f"mesh.dim_names:{mesh.dim_names} must contain sep"
+                group = mesh._get_group("sep")
+                chunk_size = seq_length // 2
+                chunk_num = group.nranks * 2
+                rank = group.rank
+                first_chunk_ids = paddle.arange(rank * chunk_size, (rank + 1) * chunk_size, dtype="int64")
+                second_chunk_ids = paddle.arange(
+                    (chunk_num - rank - 1) * chunk_size, (chunk_num - rank) * chunk_size, dtype="int64"
+                )
+                position_ids = paddle.concat([first_chunk_ids, second_chunk_ids]).expand((batch_size, seq_length))
+                mp_axis = process_mesh.dim_names.index("mp")
+                placements[mp_axis] = dist.Replicate()  # mp placement: Shard(2) -> Replicate
+                position_ids = dist.auto_parallel.api.dtensor_from_local(position_ids, process_mesh, placements)
+                res_inputs.append(position_ids)
+            else:
+                res_inputs.append(inputs[index])
+        return tuple(res_inputs)
+
+    return hook
+
+
+def layer_output_rope_hook(process_mesh):
+    def hook(layer, inputs, outputs):
+        res_outputs = []
+        for output in outputs:
+            process_mesh = output.process_mesh
+            placements = output.placements
+            cp_index = process_mesh.dim_names.index("sep")  # get the axis for the split
+            cp_degree = process_mesh.shape[cp_index]
+            assert cp_degree > 1, f"cp_degree:{cp_degree} must > 1"
+            placements[cp_index] = dist.Shard(1)  # seq_dim: 1
+            output = dist.reshard(output, process_mesh, placements)
+            res_outputs.append(output)
+        return tuple(res_outputs)
+
+    return hook
+
+
+def get_dist_config(model, prefix=""):
+    """Generate distributed configuration for Llama model"""
     if prefix != "":
         assert prefix.endswith(".")
+
     config = {
         "sp_config": {
             "parallelize_plan": {
@@ -108,6 +168,30 @@ def auto_dist_config(self, prefix=""):
             }
         },
         "pp_config": {"split_spec": f"{prefix}llama.layers", "global_spec": f"{prefix}llama.global_layer"},
+        "cp_config": {
+            "parallelize_plan": {
+                f"{prefix}llama.layers.*.self_attn.sdpa": dist.ContextParallel(
+                    backend="p2p" if model.config.context_parallel_degree > 1 else "all2all"
+                ),
+            }
+        },
     }
 
+    if model.config.context_parallel_degree > 1:
+        config["cp_config"]["parallelize_plan"].update(
+            {
+                f"{prefix}llama.layers.*.self_attn.rope_func": [
+                    PrepareLayerInput(layer_input_rope_hook),
+                    PrepareLayerOutput(layer_output_rope_hook),
+                ]
+            }
+        )
+    elif model.config.sep_parallel_degree > 1:
+        # fuse_rope does not support dtensor spmd yet, so the sequence dim needs an extra reshard
+        config["cp_config"]["parallelize_plan"].update(
+            {
+                f"{prefix}llama.layers.*.self_attn.rope_func": PrepareLayerOutput(layer_output_rope_hook),
+            }
+        )
+
     return config
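
The parallelize_plan keys returned by get_dist_config are dotted sublayer paths behind an optional prefix (which must end with "."), with a wildcard for the layer index. Whether Paddle's planner matches them with fnmatch-style globbing internally is not shown in this commit; the snippet below only visualizes the key format, with invented layer names:

# Illustration only: layer names are invented, and fnmatch is used here just
# to visualize the wildcard keys, not to claim it is what Paddle does.
from fnmatch import fnmatch

prefix = "model."  # when non-empty, get_dist_config asserts it ends with "."
key = f"{prefix}llama.layers.*.self_attn.sdpa"

for name in [
    "model.llama.layers.0.self_attn.sdpa",
    "model.llama.layers.11.self_attn.sdpa",
    "model.llama.layers.0.mlp.gate_proj",
]:
    print(name, "->", fnmatch(name, key))
# model.llama.layers.0.self_attn.sdpa -> True
# model.llama.layers.11.self_attn.sdpa -> True
# model.llama.layers.0.mlp.gate_proj -> False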
