diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py
index a800ed6f6..b8f3721a2 100644
--- a/tests/test_sft_trainer.py
+++ b/tests/test_sft_trainer.py
@@ -59,12 +59,12 @@
     num_train_epochs=5,
     per_device_train_batch_size=4,
     per_device_eval_batch_size=4,
-    gradient_accumulation_steps=4,
+    gradient_accumulation_steps=1,
     learning_rate=0.00001,
     weight_decay=0,
     warmup_ratio=0.03,
     lr_scheduler_type="cosine",
-    logging_steps=1,
+    logging_strategy="epoch",
     include_tokens_per_second=True,
     packing=False,
     max_seq_length=4096,
@@ -91,6 +91,7 @@ def test_resume_training_from_checkpoint():

         sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, None)
         _validate_training(tempdir)
+        _validate_num_checkpoints(tempdir, train_args.num_train_epochs)

         # Get trainer state of latest checkpoint
         init_trainer_state, _ = _get_latest_checkpoint_trainer_state(tempdir)
@@ -100,6 +101,7 @@ def test_resume_training_from_checkpoint():
         train_args.num_train_epochs += 5
         sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, None)
         _validate_training(tempdir)
+        _validate_num_checkpoints(tempdir, train_args.num_train_epochs)

         # Get trainer state of latest checkpoint
         final_trainer_state, _ = _get_latest_checkpoint_trainer_state(tempdir)
@@ -242,6 +244,26 @@ def _get_latest_checkpoint_trainer_state(dir_path: str, checkpoint_index: int =
         last_checkpoint: The path to the checkpoint directory.
     """
     trainer_state = None
+    last_checkpoint = _get_checkpoint(dir_path, checkpoint_index)
+    trainer_state_file = os.path.join(last_checkpoint, "trainer_state.json")
+    with open(trainer_state_file, "r", encoding="utf-8") as f:
+        trainer_state = json.load(f)
+    return trainer_state, last_checkpoint
+
+
+def _get_checkpoint(dir_path: str, checkpoint_index: int = -1):
+    """
+    Get the latest or specified checkpoint directory.
+
+    Args:
+        dir_path (str): The directory path where checkpoint folders are located.
+        checkpoint_index (int, optional): The index of the checkpoint to retrieve,
+                                          based on the checkpoint number. The default
+                                          is -1, which returns the latest checkpoint.
+
+    Returns:
+        last_checkpoint: The path to the checkpoint directory.
+    """
     last_checkpoint = None
     checkpoints = [
         os.path.join(dir_path, d)
@@ -252,10 +274,7 @@ def _get_latest_checkpoint_trainer_state(dir_path: str, checkpoint_index: int =
         last_checkpoint = sorted(checkpoints, key=lambda x: int(x.split("-")[-1]))[
             checkpoint_index
         ]
-        trainer_state_file = os.path.join(last_checkpoint, "trainer_state.json")
-        with open(trainer_state_file, "r", encoding="utf-8") as f:
-            trainer_state = json.load(f)
-    return trainer_state, last_checkpoint
+    return last_checkpoint


 def _get_training_logs_by_epoch(dir_path: str, epoch: int = None):
@@ -398,7 +417,8 @@ def test_run_causallm_pt_and_inference():

         # validate peft tuning configs
         _validate_training(tempdir)
-        checkpoint_path = _get_checkpoint_path(tempdir)
+        _validate_num_checkpoints(tempdir, train_args.num_train_epochs)
+        checkpoint_path = _get_checkpoint(tempdir)
         adapter_config = _get_adapter_config(checkpoint_path)

         _validate_adapter_config(
@@ -434,7 +454,7 @@ def test_run_causallm_pt_and_inference_with_formatting_data():

         # validate peft tuning configs
         _validate_training(tempdir)
-        checkpoint_path = _get_checkpoint_path(tempdir)
+        checkpoint_path = _get_checkpoint(tempdir)
         adapter_config = _get_adapter_config(checkpoint_path)
         _validate_adapter_config(
             adapter_config, "PROMPT_TUNING", MODEL_ARGS.model_name_or_path
@@ -467,7 +487,7 @@ def test_run_causallm_pt_and_inference_JSON_file_formatter():

         # validate peft tuning configs
         _validate_training(tempdir)
-        checkpoint_path = _get_checkpoint_path(tempdir)
+        checkpoint_path = _get_checkpoint(tempdir)
         adapter_config = _get_adapter_config(checkpoint_path)
         _validate_adapter_config(
             adapter_config, "PROMPT_TUNING", MODEL_ARGS.model_name_or_path
@@ -499,7 +519,7 @@ def test_run_causallm_pt_init_text():

         # validate peft tuning configs
         _validate_training(tempdir)
-        checkpoint_path = _get_checkpoint_path(tempdir)
+        checkpoint_path = _get_checkpoint(tempdir)
         adapter_config = _get_adapter_config(checkpoint_path)
         _validate_adapter_config(
             adapter_config, "PROMPT_TUNING", MODEL_ARGS.model_name_or_path
@@ -621,7 +641,8 @@ def test_run_causallm_lora_and_inference(request, target_modules, expected):

         # validate lora tuning configs
         _validate_training(tempdir)
-        checkpoint_path = _get_checkpoint_path(tempdir)
+        _validate_num_checkpoints(tempdir, train_args.num_train_epochs)
+        checkpoint_path = _get_checkpoint(tempdir)
         adapter_config = _get_adapter_config(checkpoint_path)
         _validate_adapter_config(adapter_config, "LORA")

@@ -663,7 +684,7 @@ def test_successful_lora_target_modules_default_from_main():
         _validate_training(tempdir)
         # Calling LoRA tuning from the main results in 'added_tokens_info.json'
         assert "added_tokens_info.json" in os.listdir(tempdir)
-        checkpoint_path = _get_checkpoint_path(tempdir)
+        checkpoint_path = _get_checkpoint(tempdir)
         adapter_config = _get_adapter_config(checkpoint_path)
         _validate_adapter_config(adapter_config, "LORA")

@@ -692,7 +713,8 @@ def test_run_causallm_ft_and_inference(dataset_path):
         data_args.training_data_path = dataset_path

         _test_run_causallm_ft(TRAIN_ARGS, MODEL_ARGS, data_args, tempdir)
-        _test_run_inference(checkpoint_path=_get_checkpoint_path(tempdir))
+        _validate_num_checkpoints(tempdir, TRAIN_ARGS.num_train_epochs)
+        _test_run_inference(checkpoint_path=_get_checkpoint(tempdir))


 def test_run_causallm_ft_save_with_save_model_dir_save_strategy_no():
@@ -741,7 +763,7 @@ def test_run_causallm_ft_pretokenized(dataset_path):

         # validate full ft configs
         _validate_training(tempdir)
-        checkpoint_path = _get_checkpoint_path(tempdir)
+        checkpoint_path = _get_checkpoint(tempdir)

         # Load the model
         loaded_model = TunedCausalLM.load(checkpoint_path, MODEL_NAME)
@@ -797,8 +819,9 @@ def _validate_logfile(log_file_path, check_eval=False):
         assert "validation_loss" in train_log_contents


-def _get_checkpoint_path(dir_path):
-    return os.path.join(dir_path, "checkpoint-5")
+def _validate_num_checkpoints(dir_path, expected_num):
+    checkpoints = [d for d in os.listdir(dir_path) if d.startswith("checkpoint")]
+    assert len(checkpoints) == expected_num


 def _get_adapter_config(dir_path):
diff --git a/tests/trackers/test_aim_tracker.py b/tests/trackers/test_aim_tracker.py
index d2aa301b7..a470e632e 100644
--- a/tests/trackers/test_aim_tracker.py
+++ b/tests/trackers/test_aim_tracker.py
@@ -30,7 +30,7 @@
     DATA_ARGS,
     MODEL_ARGS,
     TRAIN_ARGS,
-    _get_checkpoint_path,
+    _get_checkpoint,
     _test_run_inference,
     _validate_training,
 )
@@ -99,7 +99,7 @@ def test_e2e_run_with_aim_tracker(aimrepo):
         _validate_training(tempdir)

         # validate inference
-        _test_run_inference(checkpoint_path=_get_checkpoint_path(tempdir))
+        _test_run_inference(checkpoint_path=_get_checkpoint(tempdir))


 @pytest.mark.skipif(aim_not_available, reason="Requires aimstack to be installed")
diff --git a/tests/trackers/test_file_logging_tracker.py b/tests/trackers/test_file_logging_tracker.py
index e5e62ab8b..8e34dbdea 100644
--- a/tests/trackers/test_file_logging_tracker.py
+++ b/tests/trackers/test_file_logging_tracker.py
@@ -25,7 +25,7 @@
     DATA_ARGS,
     MODEL_ARGS,
     TRAIN_ARGS,
-    _get_checkpoint_path,
+    _get_checkpoint,
     _test_run_causallm_ft,
     _test_run_inference,
     _validate_training,
@@ -45,7 +45,7 @@ def test_run_with_file_logging_tracker():
         train_args.trackers = ["file_logger"]

         _test_run_causallm_ft(TRAIN_ARGS, MODEL_ARGS, DATA_ARGS, tempdir)
-        _test_run_inference(_get_checkpoint_path(tempdir))
+        _test_run_inference(checkpoint_path=_get_checkpoint(tempdir))


 def test_sample_run_with_file_logger_updated_filename():