55 changes: 39 additions & 16 deletions tests/test_sft_trainer.py
@@ -59,12 +59,12 @@
     num_train_epochs=5,
     per_device_train_batch_size=4,
     per_device_eval_batch_size=4,
-    gradient_accumulation_steps=4,
+    gradient_accumulation_steps=1,
     learning_rate=0.00001,
     weight_decay=0,
     warmup_ratio=0.03,
     lr_scheduler_type="cosine",
-    logging_steps=1,
+    logging_strategy="epoch",
     include_tokens_per_second=True,
     packing=False,
     max_seq_length=4096,
@@ -91,6 +91,7 @@ def test_resume_training_from_checkpoint():

         sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, None)
         _validate_training(tempdir)
+        _validate_num_checkpoints(tempdir, train_args.num_train_epochs)

         # Get trainer state of latest checkpoint
         init_trainer_state, _ = _get_latest_checkpoint_trainer_state(tempdir)
@@ -100,6 +101,7 @@ def test_resume_training_from_checkpoint():
         train_args.num_train_epochs += 5
         sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, None)
         _validate_training(tempdir)
+        _validate_num_checkpoints(tempdir, train_args.num_train_epochs)

         # Get trainer state of latest checkpoint
         final_trainer_state, _ = _get_latest_checkpoint_trainer_state(tempdir)
@@ -242,6 +244,26 @@ def _get_latest_checkpoint_trainer_state(dir_path: str, checkpoint_index: int =
         last_checkpoint: The path to the checkpoint directory.
     """
     trainer_state = None
+    last_checkpoint = _get_checkpoint(dir_path, checkpoint_index)
+    trainer_state_file = os.path.join(last_checkpoint, "trainer_state.json")
+    with open(trainer_state_file, "r", encoding="utf-8") as f:
+        trainer_state = json.load(f)
+    return trainer_state, last_checkpoint
+
+
+def _get_checkpoint(dir_path: str, checkpoint_index: int = -1):
+    """
+    Get the latest or specified checkpoint directory.
+
+    Args:
+        dir_path (str): The directory path where checkpoint folders are located.
+        checkpoint_index (int, optional): The index of the checkpoint to retrieve,
+            based on the checkpoint number. The default
+            is -1, which returns the latest checkpoint.
+
+    Returns:
+        last_checkpoint: The path to the checkpoint directory.
+    """
     last_checkpoint = None
     checkpoints = [
         os.path.join(dir_path, d)
@@ -252,10 +274,7 @@ def _get_latest_checkpoint_trainer_state(dir_path: str, checkpoint_index: int =
         last_checkpoint = sorted(checkpoints, key=lambda x: int(x.split("-")[-1]))[
             checkpoint_index
         ]
-    trainer_state_file = os.path.join(last_checkpoint, "trainer_state.json")
-    with open(trainer_state_file, "r", encoding="utf-8") as f:
-        trainer_state = json.load(f)
-    return trainer_state, last_checkpoint
+    return last_checkpoint


 def _get_training_logs_by_epoch(dir_path: str, epoch: int = None):
@@ -398,7 +417,8 @@ def test_run_causallm_pt_and_inference():

         # validate peft tuning configs
         _validate_training(tempdir)
-        checkpoint_path = _get_checkpoint_path(tempdir)
+        _validate_num_checkpoints(tempdir, train_args.num_train_epochs)
+        checkpoint_path = _get_checkpoint(tempdir)
         adapter_config = _get_adapter_config(checkpoint_path)

         _validate_adapter_config(
@@ -434,7 +454,7 @@ def test_run_causallm_pt_and_inference_with_formatting_data():

         # validate peft tuning configs
         _validate_training(tempdir)
-        checkpoint_path = _get_checkpoint_path(tempdir)
+        checkpoint_path = _get_checkpoint(tempdir)
         adapter_config = _get_adapter_config(checkpoint_path)
         _validate_adapter_config(
             adapter_config, "PROMPT_TUNING", MODEL_ARGS.model_name_or_path
@@ -467,7 +487,7 @@ def test_run_causallm_pt_and_inference_JSON_file_formatter():

         # validate peft tuning configs
         _validate_training(tempdir)
-        checkpoint_path = _get_checkpoint_path(tempdir)
+        checkpoint_path = _get_checkpoint(tempdir)
         adapter_config = _get_adapter_config(checkpoint_path)
         _validate_adapter_config(
             adapter_config, "PROMPT_TUNING", MODEL_ARGS.model_name_or_path
@@ -499,7 +519,7 @@ def test_run_causallm_pt_init_text():

         # validate peft tuning configs
         _validate_training(tempdir)
-        checkpoint_path = _get_checkpoint_path(tempdir)
+        checkpoint_path = _get_checkpoint(tempdir)
         adapter_config = _get_adapter_config(checkpoint_path)
         _validate_adapter_config(
             adapter_config, "PROMPT_TUNING", MODEL_ARGS.model_name_or_path
@@ -621,7 +641,8 @@ def test_run_causallm_lora_and_inference(request, target_modules, expected):

         # validate lora tuning configs
         _validate_training(tempdir)
-        checkpoint_path = _get_checkpoint_path(tempdir)
+        _validate_num_checkpoints(tempdir, train_args.num_train_epochs)
+        checkpoint_path = _get_checkpoint(tempdir)
         adapter_config = _get_adapter_config(checkpoint_path)
         _validate_adapter_config(adapter_config, "LORA")

@@ -663,7 +684,7 @@ def test_successful_lora_target_modules_default_from_main():
         _validate_training(tempdir)
         # Calling LoRA tuning from the main results in 'added_tokens_info.json'
         assert "added_tokens_info.json" in os.listdir(tempdir)
-        checkpoint_path = _get_checkpoint_path(tempdir)
+        checkpoint_path = _get_checkpoint(tempdir)
         adapter_config = _get_adapter_config(checkpoint_path)
         _validate_adapter_config(adapter_config, "LORA")

@@ -692,7 +713,8 @@ def test_run_causallm_ft_and_inference(dataset_path):
         data_args.training_data_path = dataset_path

         _test_run_causallm_ft(TRAIN_ARGS, MODEL_ARGS, data_args, tempdir)
-        _test_run_inference(checkpoint_path=_get_checkpoint_path(tempdir))
+        _validate_num_checkpoints(tempdir, TRAIN_ARGS.num_train_epochs)
+        _test_run_inference(checkpoint_path=_get_checkpoint(tempdir))


 def test_run_causallm_ft_save_with_save_model_dir_save_strategy_no():
@@ -741,7 +763,7 @@ def test_run_causallm_ft_pretokenized(dataset_path):

         # validate full ft configs
         _validate_training(tempdir)
-        checkpoint_path = _get_checkpoint_path(tempdir)
+        checkpoint_path = _get_checkpoint(tempdir)

         # Load the model
         loaded_model = TunedCausalLM.load(checkpoint_path, MODEL_NAME)
@@ -797,8 +819,9 @@ def _validate_logfile(log_file_path, check_eval=False):
         assert "validation_loss" in train_log_contents


-def _get_checkpoint_path(dir_path):
-    return os.path.join(dir_path, "checkpoint-5")
+def _validate_num_checkpoints(dir_path, expected_num):
+    checkpoints = [d for d in os.listdir(dir_path) if d.startswith("checkpoint")]
+    assert len(checkpoints) == expected_num


 def _get_adapter_config(dir_path):
4 changes: 2 additions & 2 deletions tests/trackers/test_aim_tracker.py
@@ -30,7 +30,7 @@
     DATA_ARGS,
     MODEL_ARGS,
     TRAIN_ARGS,
-    _get_checkpoint_path,
+    _get_checkpoint,
     _test_run_inference,
     _validate_training,
 )
@@ -99,7 +99,7 @@ def test_e2e_run_with_aim_tracker(aimrepo):
         _validate_training(tempdir)

         # validate inference
-        _test_run_inference(checkpoint_path=_get_checkpoint_path(tempdir))
+        _test_run_inference(checkpoint_path=_get_checkpoint(tempdir))


 @pytest.mark.skipif(aim_not_available, reason="Requires aimstack to be installed")
4 changes: 2 additions & 2 deletions tests/trackers/test_file_logging_tracker.py
@@ -25,7 +25,7 @@
     DATA_ARGS,
     MODEL_ARGS,
     TRAIN_ARGS,
-    _get_checkpoint_path,
+    _get_checkpoint,
     _test_run_causallm_ft,
     _test_run_inference,
     _validate_training,
@@ -45,7 +45,7 @@ def test_run_with_file_logging_tracker():
         train_args.trackers = ["file_logger"]

         _test_run_causallm_ft(TRAIN_ARGS, MODEL_ARGS, DATA_ARGS, tempdir)
-        _test_run_inference(_get_checkpoint_path(tempdir))
+        _test_run_inference(checkpoint_path=_get_checkpoint(tempdir))


 def test_sample_run_with_file_logger_updated_filename():
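For reference, here is a minimal standalone sketch (not part of this PR) of how the renamed _get_checkpoint helper and the new _validate_num_checkpoints assertion from the diff behave. The list-comprehension filter and the `if checkpoints:` guard are assumptions filled in for the lines collapsed between hunks, and the `checkpoint-<step>` directory naming is inferred from the numeric sort key; treat it as an illustration, not the exact test-suite code.

# Standalone sketch of the checkpoint helpers exercised by these tests.
import os
import tempfile


def _get_checkpoint(dir_path: str, checkpoint_index: int = -1):
    # Sort checkpoint dirs by their numeric suffix and return the requested
    # one (default -1 returns the latest checkpoint).
    last_checkpoint = None
    checkpoints = [
        os.path.join(dir_path, d)
        for d in os.listdir(dir_path)
        if d.startswith("checkpoint")  # assumed filter for collapsed lines
    ]
    if checkpoints:
        last_checkpoint = sorted(checkpoints, key=lambda x: int(x.split("-")[-1]))[
            checkpoint_index
        ]
    return last_checkpoint


def _validate_num_checkpoints(dir_path, expected_num):
    # One checkpoint per epoch is expected when saving/logging by epoch.
    checkpoints = [d for d in os.listdir(dir_path) if d.startswith("checkpoint")]
    assert len(checkpoints) == expected_num


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tempdir:
        # Fake five per-epoch checkpoints, mirroring num_train_epochs=5.
        for step in (10, 20, 30, 40, 50):
            os.makedirs(os.path.join(tempdir, f"checkpoint-{step}"))
        _validate_num_checkpoints(tempdir, 5)
        assert _get_checkpoint(tempdir).endswith("checkpoint-50")  # latest
        assert _get_checkpoint(tempdir, 0).endswith("checkpoint-10")  # earliest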