diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py
index 35857524669..e6251a7cc64 100644
--- a/tests/test_sft_trainer.py
+++ b/tests/test_sft_trainer.py
@@ -508,12 +508,12 @@ def test_sft_trainer_with_model(self):
             packing=False,
         )
 
-        # if we have input/output cols, then things should work without issue
-        dataset_with_input_output = self.dummy_dataset.rename_column("question", "input").rename_column("answer", "output")
+        # with instruction format, things should work without issue
+        dataset_with_prompt_collection = self.dummy_instruction_dataset
         trainer = SFTTrainer(
             model=self.model,
             args=training_args,
-            train_dataset=dataset_with_input_output,
+            train_dataset=dataset_with_prompt_collection,
             dataset_text_field=None,
             formatting_func=None,
             max_seq_length=16,
diff --git a/trl/extras/dataset_formatting.py b/trl/extras/dataset_formatting.py
index 31cf567209a..6534ad63ba3 100644
--- a/trl/extras/dataset_formatting.py
+++ b/trl/extras/dataset_formatting.py
@@ -58,7 +58,7 @@ def format_dataset(examples):
 
 
 def get_formatting_func_from_dataset(
-    dataset: Union[Dataset, ConstantLengthDataset], tokenizer: AutoTokenizer
+    dataset: Union[Dataset, ConstantLengthDataset], tokenizer: AutoTokenizer, apply_chat_instruction_template: bool = True
 ) -> Optional[Callable]:
     r"""
     Finds the correct formatting function based on the dataset structure. Currently supported datasets are:
@@ -83,6 +83,7 @@ def get_formatting_func_from_dataset(
             return conversations_formatting_function(tokenizer, "conversations")
         elif dataset.features == FORMAT_MAPPING["instruction"]:
             logging.info("Formatting dataset with instruction format")
-            return instructions_formatting_function(tokenizer)
+            if apply_chat_instruction_template:
+                return instructions_formatting_function(tokenizer)
 
     return None
diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py
index f79f2693b9e..616c4790f64 100644
--- a/trl/trainer/sft_trainer.py
+++ b/trl/trainer/sft_trainer.py
@@ -145,6 +145,7 @@ def __init__(
         dataset_num_proc: Optional[int] = None,
         dataset_batch_size: int = 1000,
         neftune_noise_alpha: Optional[float] = None,
+        apply_chat_instruction_template: bool = True,
         model_init_kwargs: Optional[Dict] = None,
         dataset_kwargs: Optional[Dict] = None,
     ):
@@ -250,15 +251,15 @@ def make_inputs_require_grad(module, input, output):
         if formatting_func is None and dataset_text_field is None:
             # check if dataset has ChatML format or instruction format and is supported
             # if not stays #None
-            formatting_func = get_formatting_func_from_dataset(train_dataset, tokenizer)
+            formatting_func = get_formatting_func_from_dataset(train_dataset, tokenizer, apply_chat_instruction_template)
 
-        requires_input_output_keys = False
+        requires_prompt_completion_keys = False
         if not packing:
-            requires_input_output_keys = (dataset_text_field is None and formatting_func is None)
+            requires_prompt_completion_keys = (dataset_text_field is None and formatting_func is None)
         if data_collator is None:
-            # Fall back to the appropriate collator type based on the input_output_keys
+            # Fall back to the appropriate collator type based on requires_prompt_completion_keys
             data_collator = (DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True)
-                if requires_input_output_keys
+                if requires_prompt_completion_keys
                 else DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False))
 
         # Pre-process the datasets only once per node. The remaining processes will use the cache.
@@ -276,7 +277,7 @@ def make_inputs_require_grad(module, input, output):
                 num_of_sequences,
                 chars_per_token,
                 remove_unused_columns=args.remove_unused_columns if args is not None else True,
-                requires_input_output_keys=requires_input_output_keys,
+                requires_prompt_completion_keys=requires_prompt_completion_keys,
                 **dataset_kwargs,
             )
         if eval_dataset is not None:
@@ -373,7 +374,7 @@ def _prepare_dataset(
         num_of_sequences,
         chars_per_token,
         remove_unused_columns=True,
-        requires_input_output_keys=False,
+        requires_prompt_completion_keys=False,
         append_concat_token=True,
         add_special_tokens=True,
     ):
@@ -393,7 +394,7 @@ def _prepare_dataset(
                 formatting_func,
                 add_special_tokens,
                 remove_unused_columns,
-                requires_input_output_keys,
+                requires_prompt_completion_keys,
             )
 
         else:
@@ -418,14 +419,13 @@ def _prepare_non_packed_dataloader(
         formatting_func=None,
         add_special_tokens=True,
         remove_unused_columns=True,
-        requires_input_output_keys=False,
+        requires_prompt_completion_keys=False,
     ):
         use_formatting_func = formatting_func is not None and dataset_text_field is None
         self._dataset_sanity_checked = False
 
-        # TODO : fix how EOS tokens are handled
         # Inspired from https://github.com/allenai/open-instruct/blob/main/open_instruct/finetune.py#L266
-        def tokenize_input_output(element):
+        def tokenize_prompt_completion(element):
             # It is difficult to add special tokens here, as separator / EOS tokens that may be added while tokenizing
             # input texts may differ from concatenated text, making masking on input length incorrect.
 
@@ -437,7 +437,7 @@ def tokenize_input_output(element):
             )
 
             new_source = []
-            for (input_element, output_element) in zip(element['input'], element['output']):
+            for (input_element, output_element) in zip(element['prompt'], element['completion']):
                 if not input_element.endswith((' ', '\n', '\t')) and not output_element.startswith((' ', '\n', '\t')):
                     new_source.append(input_element + ' ' + output_element)
                 else:
@@ -491,15 +491,15 @@ def tokenize(element):
                     f"inspect dataset other columns (in this case {extra_columns}), you can subclass `DataCollatorForLanguageModeling` in case you used the default collator and create your own data collator in order to inspect the unused dataset columns."
                 )
 
-        if requires_input_output_keys:
-            if "input" in dataset.column_names and "output" in dataset.column_names:
+        if requires_prompt_completion_keys:
+            if "prompt" in dataset.column_names and "completion" in dataset.column_names:
                 # TODO: if we execute this input path, it is expected that we are using a seq2seq
                 # collator. If that is the case, the tokenizer should have a pad_token; this is set
                 # to eos automatically if it's unset and no tokenizer is provided, but we should
                 # properly handle if a tokenizer with no padding token is given.
-                tokenize_func = tokenize_input_output
+                tokenize_func = tokenize_prompt_completion
             else:
-                raise KeyError("Missing input / output keys")
+                raise KeyError("Missing prompt / completion keys")
         else:
             tokenize_func = tokenize
 
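
For reviewers, a minimal usage sketch of the prompt/completion path this diff enables, assuming a TRL build with the patch applied. The model name, toy records, and output directory are placeholders and are not part of this PR; the `SFTTrainer` arguments mirror the signature changed above.

# Sketch only (not part of the diff): exercises the new prompt/completion path.
# Assumes this patch is applied; "gpt2" and the toy records are placeholders.
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import SFTTrainer

# A dataset with the "prompt" / "completion" columns checked for above.
dataset = Dataset.from_dict({
    "prompt": ["What is 2 + 2?", "Name a prime number."],
    "completion": ["4", "7"],
})

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# The seq2seq collator needs a pad token (see the TODO in the diff); gpt2 has none.
tokenizer.pad_token = tokenizer.eos_token

# With dataset_text_field=None, formatting_func=None, and packing=False, the trainer
# sets requires_prompt_completion_keys, tokenizes with tokenize_prompt_completion,
# and falls back to DataCollatorForSeq2Seq.
trainer = SFTTrainer(
    model=model,
    args=TrainingArguments(output_dir="tmp_out", per_device_train_batch_size=2),
    train_dataset=dataset,
    tokenizer=tokenizer,
    dataset_text_field=None,
    formatting_func=None,
    packing=False,
    max_seq_length=64,
)
trainer.train()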