
Commit 17b0de7

add mask for run_pretrain.py data (#3152)
1 parent f719f14 commit 17b0de7

File tree

1 file changed: +50 −10 lines changed


examples/experiments/paddlefleet/run_pretrain.py

Lines changed: 50 additions & 10 deletions
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import copy
 import math
 import os
 import random
@@ -127,7 +126,7 @@ class DataArguments:
     split: str = field(default="949,50,1", metadata={"help": "Train/valid/test data split."})
 
     max_seq_length: int = field(
-        default=1024,
+        default=8192,
         metadata={
             "help": "The maximum total input sequence length after tokenization. Sequences longer "
             "than this will be truncated, sequences shorter will be padded."
@@ -180,11 +179,15 @@ class ModelArguments:
         default=None,
         metadata={"help": "num_hidden_layers."},
     )
+    use_global_causal_attn: bool = field(
+        default=False, metadata={"help": "Whether to use global causal attention in packing data"}
+    )
 
 
 def create_pretrained_dataset(
     data_args,
     training_args,
+    model_args,
     data_file,
     tokenizer,
     need_data=True,
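
For orientation: the new use_global_causal_attn flag decides whether each packed document attends only to itself (block-causal) or causally to everything earlier in the pack (global causal). Below is a minimal NumPy sketch of that distinction, assuming document boundaries are the positions where position_ids reset to 0; the helper name and the exact index convention are illustrative only, not the paddleformers gen_attn_mask_startend_row_indices implementation.

import numpy as np

def toy_startend_row_indices(position_ids, use_global_causal_attn=False):
    # Illustrative only: for each query position, record the row at which its
    # attention is cut off. Block-causal masking stops at the end of the
    # token's own document; global causal attention runs to the end of the pack.
    pos = np.asarray(position_ids)
    seq_len = len(pos)
    doc_starts = np.flatnonzero(pos == 0)           # a new document starts at every 0
    doc_ends = np.append(doc_starts[1:], seq_len)   # exclusive end of each document
    end_rows = np.empty(seq_len, dtype=np.int32)
    for start, end in zip(doc_starts, doc_ends):
        end_rows[start:end] = seq_len if use_global_causal_attn else end
    return end_rows[None, None, :]                  # shape [1, 1, seq_len]

# Two documents of lengths 3 and 2 packed into one sequence of length 5.
packed_position_ids = [0, 1, 2, 0, 1]
print(toy_startend_row_indices(packed_position_ids))        # [[[3 3 3 5 5]]]
print(toy_startend_row_indices(packed_position_ids, True))  # [[[5 5 5 5 5]]]

With the flag off, tokens of the first document stop attending at row 3 (their own document's end); with it on, every token may attend up to the end of the packed sequence.
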
@@ -234,16 +237,52 @@ def print_dataset(data, mode="train"):
 
     from paddleformers.data import Stack
 
-    def _collate_data(data, stack_fn=Stack()):
-        tokens_ = stack_fn([x["text"] for x in data])
+    def _collate_data(batch, stack_fn=Stack()):
+        # origin no mask data
+        # tokens_ = stack_fn([x["text"] for x in batch])
+
+        # labels = copy.deepcopy(tokens_)[:, 1:]
+        # tokens = tokens_[:, :-1]
+
+        # return {
+        #     "input_ids": tokens,
+        #     "labels": labels,
+        # }
+
+        # data with attn_mask_startend_row_indices for flashmask
+        input_keys = ["input_ids", "labels", "position_ids", "attn_mask_startend_row_indices"]
+        return_list = []
+        for batch_sequence in batch:
+            # tokens
+            padded_token_ids = np.array([batch_sequence["text"][:-1]])
+            # labels
+            padded_labels = np.array([batch_sequence["text"][1:]])
+            # position_ids
+            padded_position_ids = np.array([sum(batch_sequence["position_ids"], [])[:-1]])
+            return_list.append(
+                [
+                    padded_token_ids,
+                    padded_labels,
+                    padded_position_ids,
+                ]
+            )
+            # attn mask
+            oral_position_ids = batch_sequence["position_ids"]
+            from paddleformers.datasets.collate import (
+                gen_attn_mask_startend_row_indices,
+            )
 
-        labels = copy.deepcopy(tokens_)[:, 1:]
-        tokens = tokens_[:, :-1]
+            return_list[-1].append(
+                gen_attn_mask_startend_row_indices(
+                    oral_position_ids,
+                    data_args.max_seq_length + training_args.num_nextn_predict_layers,
+                    model_args.use_global_causal_attn,
+                )[:, :, :-1, :]
+            )
 
-        return {
-            "input_ids": tokens,
-            "labels": labels,
-        }
+        return_list = [np.concatenate(tensor_list) for tensor_list in zip(*return_list)]
+        input_dict = dict(zip(input_keys, return_list))
+        return input_dict
 
     if need_data:
         if training_args.do_train:
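
The reworked _collate_data builds next-token-prediction inputs and labels by shifting each packed sample one token, flattens the per-document position_ids, and appends a FlashMask-style attn_mask_startend_row_indices tensor (trimmed with [:, :, :-1, :] to match the shortened sequence). The following is a minimal sketch of the per-sample shift-and-flatten bookkeeping only, with made-up token ids; it does not reproduce the mask generation itself.

import numpy as np

# Toy packed sample in the layout the dataset yields: "text" holds the
# concatenated token ids, "position_ids" holds one list per packed document.
sample = {
    "text": np.array([11, 12, 13, 21, 22]),   # two documents packed together
    "position_ids": [[0, 1, 2], [0, 1]],
}

# Next-token prediction: inputs drop the last token, labels drop the first,
# so labels[i] is the target for input_ids[i].
input_ids = np.array([sample["text"][:-1]])
labels = np.array([sample["text"][1:]])

# Position ids are flattened across documents and trimmed to the input length.
position_ids = np.array([sum(sample["position_ids"], [])[:-1]])

assert input_ids.shape == labels.shape == position_ids.shape == (1, 4)
print(input_ids)      # [[11 12 13 21]]
print(labels)         # [[12 13 21 22]]
print(position_ids)   # [[0 1 2 0]]
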
@@ -523,6 +562,7 @@ def main():
     train_dataset, eval_dataset, test_dataset, data_collator = create_pretrained_dataset(
         data_args,
         training_args,
+        model_args,
         data_file,
         tokenizer,
         need_data=training_args.should_load_dataset,
