11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | | -import copy |
15 | 14 | import math |
16 | 15 | import os |
17 | 16 | import random |
@@ -127,7 +126,7 @@ class DataArguments: |
127 | 126 | split: str = field(default="949,50,1", metadata={"help": "Train/valid/test data split."}) |
128 | 127 |
129 | 128 | max_seq_length: int = field( |
130 | | - default=1024, |
| 129 | + default=8192, |
131 | 130 | metadata={ |
132 | 131 | "help": "The maximum total input sequence length after tokenization. Sequences longer " |
133 | 132 | "than this will be truncated, sequences shorter will be padded." |
@@ -180,11 +179,15 @@ class ModelArguments: |
180 | 179 | default=None, |
181 | 180 | metadata={"help": "num_hidden_layers."}, |
182 | 181 | ) |
| 182 | + use_global_causal_attn: bool = field( |
| 183 | + default=False, metadata={"help": "Whether to use global causal attention in packing data"} |
| 184 | + ) |
183 | 185 |
184 | 186 |
185 | 187 | def create_pretrained_dataset( |
186 | 188 | data_args, |
187 | 189 | training_args, |
| 190 | + model_args, |
188 | 191 | data_file, |
189 | 192 | tokenizer, |
190 | 193 | need_data=True, |
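
The new `use_global_causal_attn` flag decides whether tokens in a packed sample may attend across document boundaries. Its help string is the only documentation in this diff, so the snippet below is a sketch of the assumed semantics rather than paddleformers' actual implementation: with the flag off, each packed document gets a block-diagonal causal mask; with it on, the whole packed sequence is treated as one causal stream. `build_dense_causal_mask` is a hypothetical helper used only for illustration.

```python
# Illustration only: this helper is hypothetical and not part of paddleformers.
# It makes the assumed semantics of use_global_causal_attn concrete by building
# a dense [seq_len, seq_len] boolean mask for one packed sample.
import numpy as np

def build_dense_causal_mask(position_ids, use_global_causal_attn=False):
    """position_ids: list of per-document position lists, e.g. [[0, 1, 2], [0, 1]]."""
    doc_lengths = [len(p) for p in position_ids]
    seq_len = sum(doc_lengths)
    # plain lower-triangular (causal) mask over the whole packed sequence
    mask = np.tril(np.ones((seq_len, seq_len), dtype=bool))
    if not use_global_causal_attn:
        # restrict attention to the token's own document (block-diagonal causal)
        doc_id = np.repeat(np.arange(len(doc_lengths)), doc_lengths)
        mask &= doc_id[:, None] == doc_id[None, :]
    return mask

# Two packed documents of length 3 and 2: with the flag off, the second document
# cannot see the first; with it on, attention is ordinary causal across the pack.
print(build_dense_causal_mask([[0, 1, 2], [0, 1]]).astype(int))
```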
@@ -234,16 +237,52 @@ def print_dataset(data, mode="train"): |
234 | 237 |
235 | 238 | from paddleformers.data import Stack |
236 | 239 |
237 | | - def _collate_data(data, stack_fn=Stack()): |
238 | | - tokens_ = stack_fn([x["text"] for x in data]) |
| 240 | + def _collate_data(batch, stack_fn=Stack()): |
| 242 | + # original no-mask collate path (kept for reference) |
| 242 | + # tokens_ = stack_fn([x["text"] for x in batch]) |
| 243 | + |
| 244 | + # labels = copy.deepcopy(tokens_)[:, 1:] |
| 245 | + # tokens = tokens_[:, :-1] |
| 246 | + |
| 247 | + # return { |
| 248 | + # "input_ids": tokens, |
| 249 | + # "labels": labels, |
| 250 | + # } |
| 251 | + |
| 252 | + # data with attn_mask_startend_row_indices for flashmask |
| 253 | + input_keys = ["input_ids", "labels", "position_ids", "attn_mask_startend_row_indices"] |
| 254 | + return_list = [] |
| 255 | + for batch_sequence in batch: |
| 256 | + # tokens |
| 257 | + padded_token_ids = np.array([batch_sequence["text"][:-1]]) |
| 258 | + # labels |
| 259 | + padded_labels = np.array([batch_sequence["text"][1:]]) |
| 260 | + # position_ids |
| 261 | + padded_position_ids = np.array([sum(batch_sequence["position_ids"], [])[:-1]]) |
| 262 | + return_list.append( |
| 263 | + [ |
| 264 | + padded_token_ids, |
| 265 | + padded_labels, |
| 266 | + padded_position_ids, |
| 267 | + ] |
| 268 | + ) |
| 269 | + # attn mask |
| 270 | + oral_position_ids = batch_sequence["position_ids"] |
| 271 | + from paddleformers.datasets.collate import ( |
| 272 | + gen_attn_mask_startend_row_indices, |
| 273 | + ) |
239 | 274 |
240 | | - labels = copy.deepcopy(tokens_)[:, 1:] |
241 | | - tokens = tokens_[:, :-1] |
| 275 | + return_list[-1].append( |
| 276 | + gen_attn_mask_startend_row_indices( |
| 277 | + oral_position_ids, |
| 278 | + data_args.max_seq_length + training_args.num_nextn_predict_layers, |
| 279 | + model_args.use_global_causal_attn, |
| 280 | + )[:, :, :-1, :] |
| 281 | + ) |
242 | 282 |
243 | | - return { |
244 | | - "input_ids": tokens, |
245 | | - "labels": labels, |
246 | | - } |
| 283 | + return_list = [np.concatenate(tensor_list) for tensor_list in zip(*return_list)] |
| 284 | + input_dict = dict(zip(input_keys, return_list)) |
| 285 | + return input_dict |
247 | 286 |
248 | 287 | if need_data: |
249 | 288 | if training_args.do_train: |
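
The rewritten `_collate_data` builds one packed sample at a time: it shifts `text` by one position to form `input_ids` and `labels`, flattens the nested per-document `position_ids` with `sum(..., [])`, and asks `gen_attn_mask_startend_row_indices` for the FlashMask indices (the real helper lives in `paddleformers.datasets.collate`; its exact signature, and the extra dimension trimmed by the `[:, :, :-1, :]` slice, are not shown in this diff). The worked example below mirrors that flow with a hypothetical `row_indices_sketch` stand-in for the helper, under the assumption that each value records the first query row that may no longer attend to the corresponding key.

```python
# Worked example of the new collate path for a single packed sample.
# row_indices_sketch is a hypothetical stand-in for paddleformers'
# gen_attn_mask_startend_row_indices, shown only to illustrate the encoding.
import numpy as np

sample = {
    "text": [5, 6, 7, 8, 9, 10],             # two packed documents: [5, 6, 7] and [8, 9, 10]
    "position_ids": [[0, 1, 2], [0, 1, 2]],  # per-document positions, still nested
}

# next-token shift: inputs drop the last token, labels drop the first
input_ids = np.array([sample["text"][:-1]])  # [[5, 6, 7, 8, 9]]
labels = np.array([sample["text"][1:]])      # [[6, 7, 8, 9, 10]]

# sum(list_of_lists, []) flattens the nested position ids before the same shift
position_ids = np.array([sum(sample["position_ids"], [])[:-1]])  # [[0, 1, 2, 0, 1]]

def row_indices_sketch(position_ids, use_global_causal_attn=False):
    # for each key position, store the first query row that may no longer attend to it
    doc_lengths = [len(p) for p in position_ids]
    seq_len = sum(doc_lengths)
    out = np.zeros((1, 1, seq_len), dtype="int32")
    offset = 0
    for length in doc_lengths:
        out[0, 0, offset:offset + length] = seq_len if use_global_causal_attn else offset + length
        offset += length
    return out

print(row_indices_sketch(sample["position_ids"]))  # [[[3 3 3 6 6 6]]]
```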
@@ -523,6 +562,7 @@ def main(): |
523 | 562 | train_dataset, eval_dataset, test_dataset, data_collator = create_pretrained_dataset( |
524 | 563 | data_args, |
525 | 564 | training_args, |
| 565 | + model_args, |
526 | 566 | data_file, |
527 | 567 | tokenizer, |
528 | 568 | need_data=training_args.should_load_dataset, |