From 6ca76ae2521ccaaaa4980f138d9ba5454afefcd4 Mon Sep 17 00:00:00 2001
From: Sukriti-Sharma4
Date: Fri, 23 Feb 2024 16:14:38 -0700
Subject: [PATCH 1/3] support datasets with input and output and no templates

Co-authored-by: Alex-Brooks
---
 trl/trainer/sft_trainer.py | 69 ++++++++++++++++++++++++++++++++++----
 1 file changed, 62 insertions(+), 7 deletions(-)

diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py
index c1e085a51e1..c2f96bfc90b 100644
--- a/trl/trainer/sft_trainer.py
+++ b/trl/trainer/sft_trainer.py
@@ -28,6 +28,7 @@
     AutoTokenizer,
     DataCollator,
     DataCollatorForLanguageModeling,
+    DataCollatorForSeq2Seq,
     PreTrainedModel,
     PreTrainedTokenizerBase,
     Trainer,
@@ -169,6 +170,12 @@ def __init__(
                 "You passed a `DataCollatorForCompletionOnlyLM` to the SFTTrainer. This is not compatible with the `packing` argument."
             )
 
+        # TODO: think about this error handling and whether we want to enforce a seq2seq collator
+        if not packing and formatting_func is None and dataset_text_field is None and data_collator is not None and not isinstance(data_collator, DataCollatorForSeq2Seq):
+            raise ValueError(
+                "If no `formatting_func` or `dataset_text_field` is provided, the data_collator must be a `DataCollatorForSeq2Seq` object."
+            )
+
         if is_peft_available() and peft_config is not None:
             if not isinstance(peft_config, PeftConfig):
                 raise ValueError(
@@ -245,14 +252,14 @@ def make_inputs_require_grad(module, input, output):
                 # if not stays #None
                 formatting_func = get_formatting_func_from_dataset(train_dataset, tokenizer)
 
+        requires_input_output_keys = False
         if not packing:
-            if dataset_text_field is None and formatting_func is None:
-                raise ValueError(
-                    "You passed `packing=False` to the SFTTrainer, but you didn't pass a `dataset_text_field` or `formatting_func` argument."
-                )
-
+            requires_input_output_keys = dataset_text_field is None and formatting_func is None
             if data_collator is None:
-                data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
+                # Fall back to the appropriate collator type based on requires_input_output_keys
+                data_collator = (DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True)
+                                 if requires_input_output_keys
+                                 else DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False))
 
         # Pre-process the datasets only once per node. The remaining processes will use the cache.
         with PartialState().local_main_process_first():
@@ -269,6 +276,7 @@ def make_inputs_require_grad(module, input, output):
                 num_of_sequences,
                 chars_per_token,
                 remove_unused_columns=args.remove_unused_columns if args is not None else True,
+                requires_input_output_keys=requires_input_output_keys,
                 **dataset_kwargs,
             )
             if eval_dataset is not None:
@@ -365,6 +373,7 @@ def _prepare_dataset(
         num_of_sequences,
         chars_per_token,
         remove_unused_columns=True,
+        requires_input_output_keys=False,
         append_concat_token=True,
         add_special_tokens=True,
     ):
@@ -384,6 +393,7 @@ def _prepare_dataset(
                 formatting_func,
                 add_special_tokens,
                 remove_unused_columns,
+                requires_input_output_keys,
             )
 
         else:
@@ -408,10 +418,43 @@ def _prepare_non_packed_dataloader(
         formatting_func=None,
         add_special_tokens=True,
         remove_unused_columns=True,
+        requires_input_output_keys=False,
     ):
         use_formatting_func = formatting_func is not None and dataset_text_field is None
         self._dataset_sanity_checked = False
 
+        # TODO: fix how EOS tokens are handled
+        # Inspired from https://github.com/allenai/open-instruct/blob/main/open_instruct/finetune.py#L266
+        def tokenize_input_output(element):
+
+            eos_token = ''
+            if add_special_tokens:
+                eos_token = tokenizer.eos_token
+                tokenizer.eos_token = None
+
+            new_source = []
+            for input_element, output_element in zip(element['input'], element['output']):
+                if not input_element.endswith((' ', '\n', '\t')) and not output_element.startswith((' ', '\n', '\t')):
+                    new_source.append(input_element + ' ' + output_element + eos_token)
+                else:
+                    new_source.append(input_element + output_element + eos_token)
+
+            tokenized_example = tokenizer(new_source, max_length=max_seq_length, truncation=True, padding=False, add_special_tokens=add_special_tokens)
+            input_ids = tokenized_example.input_ids
+            labels = input_ids
+
+            # mask the prompt part so that no loss is computed on it
+            tokenized_prompt = tokenizer(element['input'], max_length=max_seq_length, truncation=True, add_special_tokens=add_special_tokens)
+
+            new_labels = [[-100] * len(tokenized_instance) + label_instance[len(tokenized_instance):] for tokenized_instance, label_instance in zip(tokenized_prompt.input_ids, labels)]
+            attention_mask = tokenized_example.attention_mask
+
+            return {
+                'input_ids': input_ids,
+                'labels': new_labels,
+                'attention_mask': attention_mask,
+            }
+
         # Inspired from: https://huggingface.co/learn/nlp-course/chapter7/6?fw=pt
         def tokenize(element):
             outputs = tokenizer(
@@ -444,8 +487,20 @@ def tokenize(element):
                 f"inspect dataset other columns (in this case {extra_columns}), you can subclass `DataCollatorForLanguageModeling` in case you used the default collator and create your own data collator in order to inspect the unused dataset columns."
             )
 
+        if requires_input_output_keys:
+            if "input" in dataset.column_names and "output" in dataset.column_names:
+                # TODO: if we execute this input path, it is expected that we are using a seq2seq
+                # collator. If that is the case, the tokenizer should have a pad_token; this is set
+                # to eos automatically if it's unset and no tokenizer is provided, but we should
+                # properly handle the case where a tokenizer with no padding token is given.
+                tokenize_func = tokenize_input_output
+            else:
+                raise KeyError("Dataset must have `input` and `output` keys when no `dataset_text_field` or `formatting_func` is provided")
+        else:
+            tokenize_func = tokenize
+
         tokenized_dataset = dataset.map(
-            tokenize,
+            tokenize_func,
             batched=True,
             remove_columns=dataset.column_names if remove_unused_columns else None,
             num_proc=self.dataset_num_proc,

From 7712f023b051aa09aab87919ed2b0a704a8acf08 Mon Sep 17 00:00:00 2001
From: Alex-Brooks
Date: Fri, 23 Feb 2024 17:31:25 -0600
Subject: [PATCH 2/3] Add tests for sft trainer with no packing / format func
 / data text field

Signed-off-by: Alex-Brooks
---
 tests/test_sft_trainer.py | 40 ++++++++++++++++++++++++++++++++++++++-
 1 file changed, 39 insertions(+), 1 deletion(-)

diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py
index 8bc9033d2f0..35857524669 100644
--- a/tests/test_sft_trainer.py
+++ b/tests/test_sft_trainer.py
@@ -20,7 +20,7 @@
 import pytest
 import torch
 from datasets import Dataset
-from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
+from transformers import AutoModelForCausalLM, AutoTokenizer, DataCollatorForSeq2Seq, TrainingArguments
 
 from trl import SFTTrainer
 from trl.import_utils import is_peft_available
@@ -486,6 +486,44 @@ def test_sft_trainer_with_model(self):
 
             assert "model.safetensors" in os.listdir(tmp_dir + "/checkpoint-1")
 
+        # Tests for no packing, with no formatting func or dataset_text_field.
+        # If no input/output columns exist, we should raise a KeyError.
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            training_args = TrainingArguments(
+                output_dir=tmp_dir,
+                dataloader_drop_last=True,
+                evaluation_strategy="steps",
+                max_steps=2,
+                save_steps=1,
+                per_device_train_batch_size=2,
+            )
+            with pytest.raises(KeyError):
+                _ = SFTTrainer(
+                    model=self.model,
+                    args=training_args,
+                    train_dataset=self.dummy_dataset,
+                    dataset_text_field=None,
+                    formatting_func=None,
+                    max_seq_length=16,
+                    packing=False,
+                )
+
+            # if we have input/output columns, then things should work without issue
+            dataset_with_input_output = self.dummy_dataset.rename_column("question", "input").rename_column("answer", "output")
+            trainer = SFTTrainer(
+                model=self.model,
+                args=training_args,
+                train_dataset=dataset_with_input_output,
+                dataset_text_field=None,
+                formatting_func=None,
+                max_seq_length=16,
+                packing=False,
+            )
+            assert isinstance(trainer.data_collator, DataCollatorForSeq2Seq)
+            trainer.train()
+            assert trainer.state.log_history[-1]["train_loss"] is not None
+            assert "model.safetensors" in os.listdir(tmp_dir + "/checkpoint-1")
+
     def test_sft_trainer_with_multiple_eval_datasets(self):
         with tempfile.TemporaryDirectory() as tmp_dir:
             training_args = TrainingArguments(

From fb53ce628c1af21ff53185a51924f2288068590d Mon Sep 17 00:00:00 2001
From: Sukriti-Sharma4
Date: Sun, 25 Feb 2024 22:08:07 -0700
Subject: [PATCH 3/3] remove handling special tokens

Signed-off-by: Sukriti-Sharma4
---
 trl/trainer/sft_trainer.py | 18 +++++++++++-------
 1 file changed, 11 insertions(+), 7 deletions(-)

diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py
index c2f96bfc90b..f79f2693b9e 100644
--- a/trl/trainer/sft_trainer.py
+++ b/trl/trainer/sft_trainer.py
@@ -427,24 +427,28 @@ def _prepare_non_packed_dataloader(
         # Inspired from https://github.com/allenai/open-instruct/blob/main/open_instruct/finetune.py#L266
         def tokenize_input_output(element):
 
-            eos_token = ''
+            # It is difficult to add special tokens here: separator / EOS tokens added while tokenizing
+            # the input text alone may differ from those in the concatenated text, which would make
+            # masking on the input length incorrect. EOS / BOS tokens can be added to the input / output texts beforehand by the user if needed.
+            # TODO: we may need to change the default of add_special_tokens to False.
             if add_special_tokens:
-                eos_token = tokenizer.eos_token
-                tokenizer.eos_token = None
+                warnings.warn(
+                    "`add_special_tokens` is not supported for this data format, so the flag will be ignored."
+                )
 
             new_source = []
             for input_element, output_element in zip(element['input'], element['output']):
                 if not input_element.endswith((' ', '\n', '\t')) and not output_element.startswith((' ', '\n', '\t')):
-                    new_source.append(input_element + ' ' + output_element + eos_token)
+                    new_source.append(input_element + ' ' + output_element)
                 else:
-                    new_source.append(input_element + output_element + eos_token)
+                    new_source.append(input_element + output_element)
 
-            tokenized_example = tokenizer(new_source, max_length=max_seq_length, truncation=True, padding=False, add_special_tokens=add_special_tokens)
+            tokenized_example = tokenizer(new_source, max_length=max_seq_length, truncation=True, padding=False)
             input_ids = tokenized_example.input_ids
             labels = input_ids
 
             # mask the prompt part so that no loss is computed on it
-            tokenized_prompt = tokenizer(element['input'], max_length=max_seq_length, truncation=True, add_special_tokens=add_special_tokens)
+            tokenized_prompt = tokenizer(element['input'], max_length=max_seq_length, truncation=True, padding=False)
 
             new_labels = [[-100] * len(tokenized_instance) + label_instance[len(tokenized_instance):] for tokenized_instance, label_instance in zip(tokenized_prompt.input_ids, labels)]
             attention_mask = tokenized_example.attention_mask
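
Usage sketch (not part of the patch series): a minimal example of how a dataset
with `input` / `output` columns would be passed once these patches are applied.
The "gpt2" checkpoint and the toy rows below are placeholders chosen for
illustration, not taken from the patches above; any causal LM and any dataset
with these two column names should behave the same way.

    from datasets import Dataset
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from trl import SFTTrainer

    # Toy dataset; only the "input" / "output" column names matter here.
    train_dataset = Dataset.from_dict(
        {
            "input": ["What is the capital of France?", "2 + 2 ="],
            "output": ["Paris", "4"],
        }
    )

    model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder model
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    # DataCollatorForSeq2Seq pads inputs and labels, so a pad token is required.
    tokenizer.pad_token = tokenizer.eos_token

    # packing=False with neither dataset_text_field nor formatting_func set
    # triggers the input/output path: the trainer falls back to
    # DataCollatorForSeq2Seq and masks prompt tokens with -100, so the loss
    # is computed on the completion only.
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        dataset_text_field=None,
        formatting_func=None,
        max_seq_length=16,
        packing=False,
    )
    trainer.train()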