6 changes: 3 additions & 3 deletions tests/test_sft_trainer.py
@@ -508,12 +508,12 @@ def test_sft_trainer_with_model(self):
packing=False,
)

# if we have input/output cols, then things should work without issue
dataset_with_input_output = self.dummy_dataset.rename_column("question", "input").rename_column("answer", "output")
# with instruction format, things should work without issue
dataset_with_prompt_collection = self.dummy_instruction_dataset
trainer = SFTTrainer(
model=self.model,
args=training_args,
train_dataset=dataset_with_input_output,
train_dataset=dataset_with_prompt_collection,
dataset_text_field=None,
formatting_func=None,
max_seq_length=16,
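The test now exercises the instruction-format detection path instead of renaming raw question/answer columns. Below is a minimal sketch of what such an instruction-format dataset might look like; the exact contents of self.dummy_instruction_dataset in the test suite are assumed rather than copied from the PR. The key point is that its features match the {"prompt": string, "completion": string} schema that get_formatting_func_from_dataset recognizes.

from datasets import Dataset

# Assumed shape of the dummy instruction dataset: two string columns,
# "prompt" and "completion", i.e. the instruction format detected by
# trl/extras/dataset_formatting.py.
dummy_instruction_dataset = Dataset.from_list(
    [
        {"prompt": "What is 2+2?", "completion": "4"},
        {"prompt": "What is 3+3?", "completion": "6"},
        {"prompt": "What is 4+4?", "completion": "8"},
    ]
)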
5 changes: 3 additions & 2 deletions trl/extras/dataset_formatting.py
@@ -58,7 +58,7 @@ def format_dataset(examples):


def get_formatting_func_from_dataset(
dataset: Union[Dataset, ConstantLengthDataset], tokenizer: AutoTokenizer
dataset: Union[Dataset, ConstantLengthDataset], tokenizer: AutoTokenizer, apply_chat_instruction_template: bool = True
) -> Optional[Callable]:
r"""
Finds the correct formatting function based on the dataset structure. Currently supported datasets are:
@@ -83,6 +83,7 @@ def get_formatting_func_from_dataset(
return conversations_formatting_function(tokenizer, "conversations")
elif dataset.features == FORMAT_MAPPING["instruction"]:
logging.info("Formatting dataset with instruction format")
return instructions_formatting_function(tokenizer)
if apply_chat_instruction_template:
return instructions_formatting_function(tokenizer)

return None
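The new apply_chat_instruction_template argument controls whether a detected instruction-format dataset is wrapped in the tokenizer's chat template. A hedged sketch of the two behaviours with this change applied; the model name and dataset contents are placeholders:

from datasets import Dataset
from transformers import AutoTokenizer
from trl.extras.dataset_formatting import get_formatting_func_from_dataset

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model
dataset = Dataset.from_list([{"prompt": "What is 2+2?", "completion": "4"}])

# Default: the instruction format is detected and a chat-template
# formatting function is returned.
chat_func = get_formatting_func_from_dataset(dataset, tokenizer)

# With the flag disabled, no formatting function is returned, so SFTTrainer
# falls through to the raw prompt/completion tokenization path.
no_func = get_formatting_func_from_dataset(
    dataset, tokenizer, apply_chat_instruction_template=False
)
assert no_func is None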
30 changes: 15 additions & 15 deletions trl/trainer/sft_trainer.py
@@ -145,6 +145,7 @@ def __init__(
dataset_num_proc: Optional[int] = None,
dataset_batch_size: int = 1000,
neftune_noise_alpha: Optional[float] = None,
apply_chat_instruction_template: bool = True,
model_init_kwargs: Optional[Dict] = None,
dataset_kwargs: Optional[Dict] = None,
):
@@ -250,15 +251,15 @@ def make_inputs_require_grad(module, input, output):
if formatting_func is None and dataset_text_field is None:
# check if dataset has ChatML format or instruction format and is supported
# if not stays #None
formatting_func = get_formatting_func_from_dataset(train_dataset, tokenizer)
formatting_func = get_formatting_func_from_dataset(train_dataset, tokenizer, apply_chat_instruction_template)

requires_input_output_keys = False
requires_prompt_completion_keys = False
if not packing:
requires_input_output_keys = (dataset_text_field is None and formatting_func is None)
requires_prompt_completion_keys = (dataset_text_field is None and formatting_func is None)
if data_collator is None:
# Fall back to the appropriate collator type based on the input_output_keys
data_collator = (DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True)
if requires_input_output_keys
if requires_prompt_completion_keys
else DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False))

# Pre-process the datasets only once per node. The remaining processes will use the cache.
@@ -276,7 +277,7 @@ def make_inputs_require_grad(module, input, output):
num_of_sequences,
chars_per_token,
remove_unused_columns=args.remove_unused_columns if args is not None else True,
requires_input_output_keys=requires_input_output_keys,
requires_prompt_completion_keys=requires_prompt_completion_keys,
**dataset_kwargs,
)
if eval_dataset is not None:
@@ -373,7 +374,7 @@ def _prepare_dataset(
num_of_sequences,
chars_per_token,
remove_unused_columns=True,
requires_input_output_keys=False,
requires_prompt_completion_keys=False,
append_concat_token=True,
add_special_tokens=True,
):
@@ -393,7 +394,7 @@ def _prepare_dataset(
formatting_func,
add_special_tokens,
remove_unused_columns,
requires_input_output_keys,
requires_prompt_completion_keys,
)

else:
@@ -418,14 +419,13 @@ def _prepare_non_packed_dataloader(
formatting_func=None,
add_special_tokens=True,
remove_unused_columns=True,
requires_input_output_keys=False,
requires_prompt_completion_keys=False,
):
use_formatting_func = formatting_func is not None and dataset_text_field is None
self._dataset_sanity_checked = False

# TODO : fix how EOS tokens are handled
# Inspired from https://github.com/allenai/open-instruct/blob/main/open_instruct/finetune.py#L266
def tokenize_input_output(element):
def tokenize_prompt_completion(element):

# It is difficult to add special tokens here, as separator / EOS tokens that may be added while tokenizing
# input texts may differ from concatenated text, making masking on input length incorrect.
@@ -437,7 +437,7 @@ def tokenize_input_output(element):
)

new_source = []
for (input_element, output_element) in zip(element['input'], element['output']):
for (input_element, output_element) in zip(element['prompt'], element['completion']):
if not input_element.endswith((' ', '\n', '\t')) and not output_element.startswith((' ', '\n', '\t')):
new_source.append(input_element + ' ' + output_element)
else:
@@ -491,15 +491,15 @@ def tokenize(element):
f"inspect dataset other columns (in this case {extra_columns}), you can subclass `DataCollatorForLanguageModeling` in case you used the default collator and create your own data collator in order to inspect the unused dataset columns."
)

if requires_input_output_keys:
if "input" in dataset.column_names and "output" in dataset.column_names:
if requires_prompt_completion_keys:
if "prompt" in dataset.column_names and "completion" in dataset.column_names:
# TODO: if we execute this input path, it is expected that we are using a seq2seq
# collator. If that is the case, the tokenizer should had a pad_token; this is set
# to eos automatically if it's unset and no tokenizer is provided, but we should
# properly handle if a tokenizer with no padding token is given.
tokenize_func = tokenize_input_output
tokenize_func = tokenize_prompt_completion
else:
raise KeyError("Missing input / output keys")
raise KeyError("Missing prompt / completion keys")
else:
tokenize_func = tokenize

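Taken together, the changes let SFTTrainer consume a prompt/completion dataset directly, without a chat template: with no dataset_text_field, no formatting_func, and apply_chat_instruction_template=False, the trainer keeps the prompt and completion columns, tokenizes each pair (joining them with a single space only when neither side already supplies whitespace), and falls back to DataCollatorForSeq2Seq. A hedged end-to-end sketch follows; the model name, dataset contents, and hyperparameters are placeholders rather than values from the PR.

from datasets import Dataset
from transformers import TrainingArguments
from trl import SFTTrainer

train_dataset = Dataset.from_list(
    [{"prompt": "What is 2+2?", "completion": "4"}] * 8
)

trainer = SFTTrainer(
    model="facebook/opt-125m",              # placeholder small causal LM
    args=TrainingArguments(output_dir="tmp_sft", per_device_train_batch_size=2),
    train_dataset=train_dataset,
    packing=False,                          # the prompt/completion path is non-packed
    apply_chat_instruction_template=False,  # skip the chat template entirely
    max_seq_length=16,
)
# SFTTrainer now selects DataCollatorForSeq2Seq and tokenizes the
# prompt/completion pairs with tokenize_prompt_completion.
trainer.train()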