@@ -624,15 +624,29 @@ def get_caller(num_frames=1):
624624 return f"In { file_name } , line { line_number } "
625625
626626
def log_rank_0(
    msg, include_caller=False, rank=None, to_print=False, level=logging.INFO
) -> None:
    """Emit a message only on rank 0 of a (possibly) distributed job.

    Args:
        msg: The message to emit.
        include_caller: Prefix the message with the caller's file/line
            (resolved two frames up, i.e. the caller of log_rank_0).
        rank: Explicit rank; when None, queried from the process group
            (defaults to 0 when the process group is not initialized).
        to_print: Emit via print() instead of the module logger.
        level: Logging level used when routed through the logger.
    """
    if rank is None:
        rank = get_rank() if is_initialized() else 0
    # Only rank 0 (or a negative rank) emits anything.
    if rank > 0:
        return

    if include_caller:
        msg = f"{get_caller(num_frames=2)}: {msg}"

    if to_print:
        print(msg)
        return

    # logger.log() honors any level (including CRITICAL), unlike the
    # previous match chain, which silently downgraded levels other than
    # WARNING/ERROR/DEBUG to INFO.
    logger.log(level, msg)
637651
638652
@@ -673,6 +687,13 @@ def skip_precheck_loops():
673687 accelerator .get_state_dict = old_get_state
674688
675689
690+ def _get_checkpoint_dir (args , samples_seen ) -> Path :
691+ subdir = (
692+ "last_epoch" if args .keep_last_checkpoint_only else f"samples_{ samples_seen } "
693+ )
694+ return Path (args .output_dir ) / "hf_format" / subdir
695+
696+
676697def save_hf_format_accelerate (
677698 args ,
678699 model ,
@@ -681,13 +702,11 @@ def save_hf_format_accelerate(
681702 samples_seen ,
682703 is_lora = False ,
683704):
684- # Build the subdirectory name
685- subdir = (
686- "last_epoch" if args .keep_last_checkpoint_only else f"samples_{ samples_seen } "
687- )
705+ # Build the final output directory path
706+ final_output_dir = _get_checkpoint_dir (args , samples_seen )
688707
689708 log_rank_0 (
690- f"\033 [93mSaving model in huggingface format at: { subdir } \033 [0m" ,
709+ f"\033 [93mSaving model in huggingface format at: { final_output_dir } \033 [0m" ,
691710 to_print = True ,
692711 )
693712 start = time .time ()
@@ -697,9 +716,6 @@ def save_hf_format_accelerate(
697716 else :
698717 convert_dolomite = True
699718
700- # Build the final output directory path
701- final_output_dir = Path (args .output_dir ) / "hf_format" / subdir
702-
703719 if args .use_dolomite and convert_dolomite :
704720 tmpdir = TemporaryDirectory ("w" ) # pylint: disable=consider-using-with
705721 output_dir = Path (tmpdir .name )
@@ -797,6 +813,48 @@ def set_random_seed(seed):
797813 torch .cuda .manual_seed_all (seed )
798814
799815
816+ def _get_checkpoint_dir_size (checkpoint_dir ) -> int :
817+ total = 0
818+ for dirpath , _ , filenames in os .walk (checkpoint_dir ):
819+ for f in filenames :
820+ fp = os .path .join (dirpath , f )
821+ if os .path .isfile (fp ):
822+ total += os .path .getsize (fp )
823+ return total
824+
825+
def check_disk_space_for_next_checkpoint(
    model: Model, output_dir: Path, warn_steps_ahead: int = 3
) -> None:
    """Warn (rank 0 only) when free disk space may not cover upcoming checkpoints.

    The size of the last saved checkpoint is used as the per-checkpoint
    estimate; if none has been recorded yet, nothing happens. This is a
    best-effort check: any failure is logged at ERROR level, never raised.

    Args:
        model: Carries last_checkpoint_size (bytes, or None).
        output_dir: Directory whose filesystem is checked for free space.
        warn_steps_ahead: How many future checkpoints must fit before no
            warning is issued.
    """
    last_size = model.last_checkpoint_size
    if last_size is None:
        # Nothing saved yet, so there is no basis for an estimate.
        return

    def _mb_size(num_bytes):
        return f"{num_bytes / 1024 / 1024:.2f} MB"

    try:
        free_bytes = shutil.disk_usage(output_dir).free
        needed_bytes = last_size * warn_steps_ahead

        log_rank_0(
            f"Disk space info: free={_mb_size(free_bytes)}, last_checkpoint_size={_mb_size(last_size)} (output_dir={output_dir})"
        )
        if free_bytes < needed_bytes:
            log_rank_0(
                f"Estimated free disk space ({_mb_size(free_bytes)}) is less than the estimated size of the next {warn_steps_ahead} checkpoints ({_mb_size(needed_bytes)}). "
                "The next checkpoint(s) may fail due to insufficient disk space.",
                level=logging.WARNING,
            )
    except Exception as e:
        log_rank_0(
            f"Could not check disk space after checkpoint: {e}",
            level=logging.ERROR,
        )
856+
857+
800858def save_checkpoint (
801859 args ,
802860 accelerator : Accelerator ,
@@ -808,6 +866,10 @@ def save_checkpoint(
808866 hf_format : bool = True ,
809867 full_state : bool = False ,
810868) -> None :
869+ # Warn if disk space is low.
870+ output_dir = Path (args .output_dir )
871+ check_disk_space_for_next_checkpoint (model , output_dir , warn_steps_ahead = 3 )
872+
811873 if hf_format :
812874 save_hf_format_accelerate (
813875 args = args ,
@@ -827,6 +889,12 @@ def save_checkpoint(
827889 samples_seen = samples_seen ,
828890 )
829891
892+ # Track last checkpoint size.
893+ if hf_format :
894+ checkpoint_dir = _get_checkpoint_dir (args , samples_seen )
895+ if checkpoint_dir .exists ():
896+ model .last_checkpoint_size = _get_checkpoint_dir_size (checkpoint_dir )
897+
830898
831899def save_full_state (args , accelerator , is_lora : bool , epoch : int , samples_seen : int ):
832900 """
0 commit comments