
Commit 0b2113c

Pass attn mask into the pretraining online and offline data flows (#3137)

1 parent d352200 commit 0b2113c

3 files changed: +55 −21 lines changed


paddleformers/cli/train/sft/workflow.py

Lines changed: 37 additions & 10 deletions
@@ -19,6 +19,7 @@
 import os
 from functools import partial

+import numpy as np
 import paddle

 is_sm90 = (
@@ -74,7 +75,7 @@
 )


-def create_pretrained_dataset(training_args, data_args):
+def create_pretrained_dataset(training_args, data_args, model_args):
     assert data_args.input_dir is not None and len(data_args.input_dir.split()) > 1

     check_data_split(
@@ -114,16 +115,40 @@ def create_pretrained_dataset(training_args, data_args):

     from paddleformers.data import Stack

-    def _collate_data(data, stack_fn=Stack()):
-        tokens_ = stack_fn([x["text"] for x in data])
+    def _collate_data(batch, stack_fn=Stack()):
+        input_keys = ["input_ids", "labels", "position_ids", "attn_mask_startend_row_indices"]
+        return_list = []
+        for batch_sequence in batch:
+            # tokens
+            padded_token_ids = np.array([batch_sequence["text"][:-1]])
+            # labels
+            padded_labels = np.array([batch_sequence["text"][1:]])
+            # position_ids
+            padded_position_ids = np.array([sum(batch_sequence["position_ids"], [])[:-1]])
+            return_list.append(
+                [
+                    padded_token_ids,
+                    padded_labels,
+                    padded_position_ids,
+                ]
+            )
+            # attn mask
+            oral_position_ids = batch_sequence["position_ids"]
+            from paddleformers.datasets.collate import (
+                gen_attn_mask_startend_row_indices,
+            )

-        labels = tokens_[:, 1:]
-        tokens = tokens_[:, :-1]
+            return_list[-1].append(
+                gen_attn_mask_startend_row_indices(
+                    oral_position_ids,
+                    data_args.max_seq_len + training_args.num_nextn_predict_layers,
+                    model_args.use_global_causal_attn,
+                )[:, :, :-1, :]
+            )

-        return {
-            "input_ids": tokens,
-            "labels": labels,
-        }
+        return_list = [np.concatenate(tensor_list) for tensor_list in zip(*return_list)]
+        input_dict = dict(zip(input_keys, return_list))
+        return input_dict

     return train_dataset, valid_dataset, test_dataset, _collate_data

@@ -337,7 +362,9 @@ def neft_post_hook(module, input, output):

     if data_args.dataset_type == "pretrain":
         training_args.test_iters = training_args.eval_iters * 10
-        train_dataset, eval_dataset, test_dataset, data_collator = create_pretrained_dataset(training_args, data_args)
+        train_dataset, eval_dataset, test_dataset, data_collator = create_pretrained_dataset(
+            training_args, data_args, model_args
+        )
     else:
         train_dataset = create_dataset_sft(
             task_group=data_args.train_dataset_path,
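For orientation, here is a minimal, self-contained sketch of the shape handling the new `_collate_data` performs, in plain numpy. The sample layout (a packed `text` array plus per-document `position_ids` lists) is taken from the diff above; the token values are hypothetical, and the attention-mask step is omitted because this commit does not show the output layout of `gen_attn_mask_startend_row_indices`.

```python
import numpy as np

# One packed pretraining sample: six tokens from two documents of length 3;
# position_ids restart at each document boundary.
sample = {
    "text": np.array([11, 12, 13, 21, 22, 23]),
    "position_ids": [[0, 1, 2], [0, 1, 2]],
}

# Next-token prediction: inputs drop the last token, labels drop the first.
padded_token_ids = np.array([sample["text"][:-1]])  # shape (1, 5)
padded_labels = np.array([sample["text"][1:]])      # shape (1, 5)

# Flatten the per-document position ids and trim one position so that
# input_ids, labels, and position_ids stay aligned after the shift.
padded_position_ids = np.array([sum(sample["position_ids"], [])[:-1]])  # (1, 5)

print(padded_token_ids)     # [[11 12 13 21 22]]
print(padded_labels)        # [[12 13 21 22 23]]
print(padded_position_ids)  # [[0 1 2 0 1]]
```

Note that the collator hands the untrimmed per-document position ids to `gen_attn_mask_startend_row_indices` and then slices the result with `[:, :, :-1, :]`, presumably so the mask's sequence axis is shortened by the same one token as `input_ids`.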

paddleformers/data/causal_dataset.py

Lines changed: 10 additions & 0 deletions
@@ -443,6 +443,10 @@ def __getitem__(self, idx):
             sample, mask = self.indexed_dataset.get(
                 self.doc_idx[doc_index_f], offset=offset_f, length=offset_l - offset_f + 1
             )
+
+            # position_ids
+            all_position_ids = []
+            all_position_ids.append(list(range(len(sample))))
         else:
             # Otherwise, get the rest of the initial document.
             doc_ids.append(self.doc_idx[doc_index_f])
@@ -468,6 +472,10 @@
                 sample_list.append(sample)
                 if append_mask:
                     mask_list.append(mask)
+            # position_ids
+            all_position_ids = []
+            for item in sample_list:
+                all_position_ids.append(list(range(len(item))))
             sample = np.concatenate(sample_list)
             if append_mask:
                 mask = np.concatenate(mask_list)
@@ -505,6 +513,8 @@
             "CPT": self.CPT,
         }

+        res.update({"position_ids": all_position_ids})
+
         return res

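To make the new field concrete, here is a small sketch of what `__getitem__` now attaches, mirroring the branch that concatenates `sample_list` (the token arrays are hypothetical):

```python
import numpy as np

# Hypothetical packed sample built from two documents of 4 and 2 tokens.
sample_list = [np.array([5, 6, 7, 8]), np.array([9, 10])]

# position_ids: one range per document, each restarting at 0, so document
# boundaries survive into the collator.
all_position_ids = []
for item in sample_list:
    all_position_ids.append(list(range(len(item))))
print(all_position_ids)  # [[0, 1, 2, 3], [0, 1]]

sample = np.concatenate(sample_list)

# The dataset attaches the field alongside the usual keys:
res = {"text": sample}
res.update({"position_ids": all_position_ids})
print(res["text"])  # [ 5  6  7  8  9 10]
```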

paddleformers/datasets/collate.py

Lines changed: 8 additions & 11 deletions
@@ -208,17 +208,14 @@ def collate_fn(batch: List[List[Sequence]], tokenizer, training_args, model_args
         padded_nbatch_pack_offset = pad_batch_data([nbatch_pack_offset], pad_idx=0, max_seq_len=max_seq_len)
         return_list[-1].append(padded_nbatch_pack_offset)

-        if not model_args.stage.lower() == "pt":
-            if model_args.use_attn_mask_startend_row_indices:
-                return_list[-1].append(
-                    gen_attn_mask_startend_row_indices(
-                        original_token_ids, max_seq_len, model_args.use_global_causal_attn
-                    )
-                )
-            else:
-                return_list[-1].append(
-                    gen_self_attn_mask(original_token_ids, max_seq_len, model_args.use_global_causal_attn)
-                )
+        if model_args.use_attn_mask_startend_row_indices:
+            return_list[-1].append(
+                gen_attn_mask_startend_row_indices(original_token_ids, max_seq_len, model_args.use_global_causal_attn)
+            )
+        else:
+            return_list[-1].append(
+                gen_self_attn_mask(original_token_ids, max_seq_len, model_args.use_global_causal_attn)
+            )

     return_list = [np.concatenate(tensor_list) for tensor_list in zip(*return_list)]
     input_dict = dict(zip(input_keys, return_list))
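Note the behavioral change hidden in this hunk: with the `stage != "pt"` guard gone, pretraining batches also receive an attention mask, in whichever of the two encodings `model_args` selects. The sketch below illustrates the difference between the two encodings for a packed batch of two documents; it is an interpretation based on the helper names, not the actual implementation of `gen_self_attn_mask` or `gen_attn_mask_startend_row_indices`.

```python
import numpy as np

def dense_block_causal_mask(doc_lens):
    # Dense [seq, seq] boolean mask: token i attends to token j iff j <= i
    # and both tokens belong to the same document. O(seq^2) memory.
    seq = sum(doc_lens)
    mask = np.zeros((seq, seq), dtype=bool)
    start = 0
    for n in doc_lens:
        for i in range(start, start + n):
            mask[i, start : i + 1] = True
        start += n
    return mask

def end_row_indices(doc_lens):
    # Compact encoding: for each column j, the row index at which attention
    # stops (the end of j's document). O(seq) memory.
    out, offset = [], 0
    for n in doc_lens:
        out.extend([offset + n] * n)
        offset += n
    return np.array(out, dtype=np.int32)

doc_lens = [3, 2]
print(dense_block_causal_mask(doc_lens).astype(int))
# [[1 0 0 0 0]
#  [1 1 0 0 0]
#  [1 1 1 0 0]
#  [0 0 0 1 0]
#  [0 0 0 1 1]]
print(end_row_indices(doc_lens))  # [3 3 3 5 5]
```

The row-indices form stores one integer per token instead of a seq × seq matrix, which is what makes it attractive for long packed pretraining sequences.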
