diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py
index 8bc9033d2f0..35857524669 100644
--- a/tests/test_sft_trainer.py
+++ b/tests/test_sft_trainer.py
@@ -20,7 +20,7 @@
 import pytest
 import torch
 from datasets import Dataset
-from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
+from transformers import AutoModelForCausalLM, AutoTokenizer, DataCollatorForSeq2Seq, TrainingArguments

 from trl import SFTTrainer
 from trl.import_utils import is_peft_available
@@ -486,6 +486,44 @@ def test_sft_trainer_with_model(self):

             assert "model.safetensors" in os.listdir(tmp_dir + "/checkpoint-1")

+        # Tests for no packing, with no formatting_func or dataset_text_field
+        # If no input/output cols exist, we should throw a KeyError
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            training_args = TrainingArguments(
+                output_dir=tmp_dir,
+                dataloader_drop_last=True,
+                evaluation_strategy="steps",
+                max_steps=2,
+                save_steps=1,
+                per_device_train_batch_size=2,
+            )
+            with pytest.raises(KeyError):
+                _ = SFTTrainer(
+                    model=self.model,
+                    args=training_args,
+                    train_dataset=self.dummy_dataset,
+                    dataset_text_field=None,
+                    formatting_func=None,
+                    max_seq_length=16,
+                    packing=False,
+                )
+
+            # If we have input/output cols, then things should work without issue
+            dataset_with_input_output = self.dummy_dataset.rename_column("question", "input").rename_column("answer", "output")
+            trainer = SFTTrainer(
+                model=self.model,
+                args=training_args,
+                train_dataset=dataset_with_input_output,
+                dataset_text_field=None,
+                formatting_func=None,
+                max_seq_length=16,
+                packing=False,
+            )
+            assert isinstance(trainer.data_collator, DataCollatorForSeq2Seq)
+            trainer.train()
+            assert trainer.state.log_history[-1]["train_loss"] is not None
+            assert "model.safetensors" in os.listdir(tmp_dir + "/checkpoint-1")
+
     def test_sft_trainer_with_multiple_eval_datasets(self):
         with tempfile.TemporaryDirectory() as tmp_dir:
             training_args = TrainingArguments(
diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py
index c1e085a51e1..f79f2693b9e 100644
--- a/trl/trainer/sft_trainer.py
+++ b/trl/trainer/sft_trainer.py
@@ -28,6 +28,7 @@
     AutoTokenizer,
     DataCollator,
     DataCollatorForLanguageModeling,
+    DataCollatorForSeq2Seq,
     PreTrainedModel,
     PreTrainedTokenizerBase,
     Trainer,
@@ -169,6 +170,12 @@ def __init__(
                 "You passed a `DataCollatorForCompletionOnlyLM` to the SFTTrainer. This is not compatible with the `packing` argument."
             )

+        # TODO: think about this error handling and whether we want to enforce a seq2seq collator
+        if not packing and formatting_func is None and dataset_text_field is None and data_collator is not None and not isinstance(data_collator, DataCollatorForSeq2Seq):
+            raise ValueError(
+                "If no `formatting_func` or `dataset_text_field` is provided, the `data_collator` should be a `DataCollatorForSeq2Seq` instance."
+            )
+
         if is_peft_available() and peft_config is not None:
             if not isinstance(peft_config, PeftConfig):
                 raise ValueError(
@@ -245,14 +252,14 @@ def make_inputs_require_grad(module, input, output):
                 # if not stays #None
                 formatting_func = get_formatting_func_from_dataset(train_dataset, tokenizer)

+        requires_input_output_keys = False
         if not packing:
-            if dataset_text_field is None and formatting_func is None:
-                raise ValueError(
-                    "You passed `packing=False` to the SFTTrainer, but you didn't pass a `dataset_text_field` or `formatting_func` argument."
-                )
-
+            requires_input_output_keys = dataset_text_field is None and formatting_func is None
             if data_collator is None:
-                data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
+                # Fall back to the appropriate collator, depending on whether input/output keys are required
+                data_collator = (DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True)
+                                 if requires_input_output_keys
+                                 else DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False))

         # Pre-process the datasets only once per node. The remaining processes will use the cache.
         with PartialState().local_main_process_first():
@@ -269,6 +276,7 @@ def make_inputs_require_grad(module, input, output):
                     num_of_sequences,
                     chars_per_token,
                     remove_unused_columns=args.remove_unused_columns if args is not None else True,
+                    requires_input_output_keys=requires_input_output_keys,
                     **dataset_kwargs,
                 )
             if eval_dataset is not None:
@@ -365,6 +373,7 @@ def _prepare_dataset(
         num_of_sequences,
         chars_per_token,
         remove_unused_columns=True,
+        requires_input_output_keys=False,
         append_concat_token=True,
         add_special_tokens=True,
     ):
@@ -384,6 +393,7 @@ def _prepare_dataset(
                 formatting_func,
                 add_special_tokens,
                 remove_unused_columns,
+                requires_input_output_keys,
             )

         else:
@@ -408,10 +418,47 @@ def _prepare_non_packed_dataloader(
         formatting_func=None,
         add_special_tokens=True,
         remove_unused_columns=True,
+        requires_input_output_keys=False,
     ):
         use_formatting_func = formatting_func is not None and dataset_text_field is None
         self._dataset_sanity_checked = False

+        # TODO: fix how EOS tokens are handled
+        # Inspired from https://github.com/allenai/open-instruct/blob/main/open_instruct/finetune.py#L266
+        def tokenize_input_output(element):
+
+            # It is difficult to add special tokens here, as separator / EOS tokens that may be added while tokenizing
+            # the input texts alone may differ from those in the concatenated text, making masking on input length incorrect.
+            # EOS and BOS tokens can be added to the input / output texts beforehand by the user if needed.
+            # TODO: we may need to change the default of add_special_tokens to False.
+            if add_special_tokens:
+                warnings.warn(
+                    "Adding special tokens is not supported for this data format, so the `add_special_tokens` flag will be ignored."
+                )
+
+            new_source = []
+            for input_element, output_element in zip(element["input"], element["output"]):
+                if not input_element.endswith((" ", "\n", "\t")) and not output_element.startswith((" ", "\n", "\t")):
+                    new_source.append(input_element + " " + output_element)
+                else:
+                    new_source.append(input_element + output_element)
+
+            tokenized_example = tokenizer(new_source, max_length=max_seq_length, truncation=True, padding=False)
+            input_ids = tokenized_example.input_ids
+            labels = input_ids
+
+            # Mask the prompt tokens so that no loss is computed on the prompt
+            tokenized_prompt = tokenizer(element["input"], max_length=max_seq_length, truncation=True, padding=False)
+
+            new_labels = [[-100] * len(tokenized_instance) + label_instance[len(tokenized_instance):] for tokenized_instance, label_instance in zip(tokenized_prompt.input_ids, labels)]
+            attention_mask = tokenized_example.attention_mask
+
+            return {
+                "input_ids": input_ids,
+                "labels": new_labels,
+                "attention_mask": attention_mask,
+            }
+
         # Inspired from: https://huggingface.co/learn/nlp-course/chapter7/6?fw=pt
         def tokenize(element):
             outputs = tokenizer(
@@ -444,8 +491,20 @@ def tokenize(element):
                 f"inspect dataset other columns (in this case {extra_columns}), you can subclass `DataCollatorForLanguageModeling` in case you used the default collator and create your own data collator in order to inspect the unused dataset columns."
             )

+        if requires_input_output_keys:
+            if "input" in dataset.column_names and "output" in dataset.column_names:
+                # TODO: if we execute this input path, it is expected that we are using a seq2seq
+                # collator. If that is the case, the tokenizer should have a pad_token; this is set
+                # to eos automatically if it is unset and no tokenizer is provided, but we should
+                # properly handle the case where a tokenizer without a padding token is given.
+                tokenize_func = tokenize_input_output
+            else:
+                raise KeyError("The dataset must contain `input` and `output` columns when no `dataset_text_field` or `formatting_func` is provided.")
+        else:
+            tokenize_func = tokenize
+
         tokenized_dataset = dataset.map(
-            tokenize,
+            tokenize_func,
             batched=True,
             remove_columns=dataset.column_names if remove_unused_columns else None,
             num_proc=self.dataset_num_proc,