@@ -19,6 +19,7 @@
 import os
 from functools import partial
 
+import numpy as np
 import paddle
 
 is_sm90 = (
@@ -74,7 +75,7 @@
 )
 
 
-def create_pretrained_dataset(training_args, data_args):
+def create_pretrained_dataset(training_args, data_args, model_args):
     assert data_args.input_dir is not None and len(data_args.input_dir.split()) > 1
 
     check_data_split(
@@ -114,16 +115,40 @@ def create_pretrained_dataset(training_args, data_args):
 
     from paddleformers.data import Stack
 
-    def _collate_data(data, stack_fn=Stack()):
-        tokens_ = stack_fn([x["text"] for x in data])
+    def _collate_data(batch, stack_fn=Stack()):
+        input_keys = ["input_ids", "labels", "position_ids", "attn_mask_startend_row_indices"]
+        return_list = []
+        for batch_sequence in batch:
+            # input tokens: drop the last token so inputs align with the labels
+            padded_token_ids = np.array([batch_sequence["text"][:-1]])
+            # labels: shift left by one for next-token prediction
+            padded_labels = np.array([batch_sequence["text"][1:]])
+            # position_ids: flatten the per-document position lists of the packed sample
+            padded_position_ids = np.array([sum(batch_sequence["position_ids"], [])[:-1]])
+            return_list.append(
+                [
+                    padded_token_ids,
+                    padded_labels,
+                    padded_position_ids,
+                ]
+            )
+            # attention mask row indices for the packed documents
+            orig_position_ids = batch_sequence["position_ids"]
+            from paddleformers.datasets.collate import (
+                gen_attn_mask_startend_row_indices,
+            )
 
-        labels = tokens_[:, 1:]
-        tokens = tokens_[:, :-1]
+            return_list[-1].append(
+                gen_attn_mask_startend_row_indices(
+                    orig_position_ids,
+                    data_args.max_seq_len + training_args.num_nextn_predict_layers,
+                    model_args.use_global_causal_attn,
+                )[:, :, :-1, :]
+            )
 
-        return {
-            "input_ids": tokens,
-            "labels": labels,
-        }
+        return_list = [np.concatenate(tensor_list) for tensor_list in zip(*return_list)]
+        input_dict = dict(zip(input_keys, return_list))
+        return input_dict
 
     return train_dataset, valid_dataset, test_dataset, _collate_data
 
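For readers skimming the patch: the new collator packs several documents into one sequence and builds next-token inputs and labels by shifting the packed token stream by one. A minimal, self-contained sketch of that shift for a single packed sample — the token ids and the two-document position lists below are made up for illustration and are not the PR's code:

import numpy as np

# One packed sample: two documents of lengths 4 and 3, concatenated.
sample = {
    "text": np.arange(100, 107),                 # 7 token ids across two packed docs
    "position_ids": [[0, 1, 2, 3], [0, 1, 2]],   # per-document positions
}

input_ids = np.array([sample["text"][:-1]])   # drop last token -> model input
labels = np.array([sample["text"][1:]])       # drop first token -> next-token targets
position_ids = np.array([sum(sample["position_ids"], [])[:-1]])  # flatten, trim to match

assert input_ids.shape == labels.shape == position_ids.shape
print(input_ids)      # [[100 101 102 103 104 105]]
print(labels)         # [[101 102 103 104 105 106]]
print(position_ids)   # [[0 1 2 3 0 1]]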
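The fourth collated field comes from gen_attn_mask_startend_row_indices in paddleformers.datasets.collate. Assuming it behaves like typical FlashMask-style helpers — recording, for each token, the row at which attention must stop so documents packed into one sequence cannot attend across document boundaries — a hypothetical toy version of the idea looks like this (toy_end_row_indices, its output shape, and the masking semantics are assumptions for illustration, not the real API):

import numpy as np

# For each column j, record the first row that must NOT attend to token j,
# i.e. the row just past the end of token j's document. Combined with the
# usual causal rule (row >= column), this masks cross-document attention
# without materializing a dense [seq, seq] mask.
def toy_end_row_indices(position_ids):
    ends, offset = [], 0
    for doc_positions in position_ids:
        offset += len(doc_positions)
        ends.extend([offset] * len(doc_positions))  # one shared end row per doc
    return np.array(ends)

print(toy_end_row_indices([[0, 1, 2, 3], [0, 1, 2]]))
# [4 4 4 4 7 7 7] -> doc-1 tokens visible only to rows 0-3, doc-2 tokens to rows 4-6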
@@ -337,7 +362,9 @@ def neft_post_hook(module, input, output):
 
     if data_args.dataset_type == "pretrain":
         training_args.test_iters = training_args.eval_iters * 10
-        train_dataset, eval_dataset, test_dataset, data_collator = create_pretrained_dataset(training_args, data_args)
+        train_dataset, eval_dataset, test_dataset, data_collator = create_pretrained_dataset(
+            training_args, data_args, model_args
+        )
     else:
         train_dataset = create_dataset_sft(
             task_group=data_args.train_dataset_path,