11# Standard
22from datetime import datetime
3- from typing import Optional , Union , List
3+ from typing import Optional , Union , List , Dict
44import json
55import os , time
66
2626from tuning .data import tokenizer_data_utils
2727from tuning .utils .config_utils import get_hf_peft_config
2828from tuning .utils .data_type_utils import get_torch_dtype
29- from tuning .tracker .tracker import Tracker , get_tracker
30- from tuning .tracker .aimstack_tracker import AimStackTracker
29+ from tuning .tracker .tracker import Tracker , TrackerFactory
3130
3231logger = logging .get_logger ("sft_trainer" )
3332
@@ -92,7 +91,8 @@ def train(
9291 Union [peft_config .LoraConfig , peft_config .PromptTuningConfig ]
9392 ] = None ,
9493 callbacks : Optional [List [TrainerCallback ]] = None ,
95- tracker : Optional [Tracker ] = Tracker () # default tracker is dummy tracker
94+ tracker : Optional [Tracker ] = None ,
95+ exp_metadata : Optional [Dict ] = None
9696):
9797 """Call the SFTTrainer
9898
@@ -123,6 +123,7 @@ def train(
123123 train_args .fsdp_config = {"xla" : False }
124124
125125 task_type = "CAUSAL_LM"
126+ additional_metrics = {}
126127
127128 model_load_time = time .time ()
128129 model = AutoModelForCausalLM .from_pretrained (
@@ -131,8 +132,7 @@ def train(
131132 torch_dtype = get_torch_dtype (model_args .torch_dtype ),
132133 use_flash_attention_2 = model_args .use_flash_attn ,
133134 )
134- model_load_time = time .time () - model_load_time
135- tracker .track (metric = model_load_time , name = 'model_load_time' )
135+ additional_metrics ['model_load_time' ] = time .time () - model_load_time
136136
137137 peft_config = get_hf_peft_config (task_type , peft_config )
138138
@@ -262,6 +262,16 @@ def train(
262262 peft_config = peft_config ,
263263 )
264264
# Track the additional metrics and the experiment metadata only after the
# Trainer object has been created, to ensure that this is not repeated
# multiple times for FSDP runs.
if tracker is not None:
    # Currently tracked only on process zero. `is_world_process_zero` is a
    # method on the Hugging Face Trainer, so it must be called: the bare
    # bound-method object is always truthy and would let every process
    # through this check.
    if trainer.is_world_process_zero():
        for k, v in additional_metrics.items():
            tracker.track(metric=v, name=k, stage='additional_metrics')
        # NOTE(review): exp_metadata defaults to None in the signature of
        # train(); confirm that Tracker.set_params tolerates params=None.
        tracker.set_params(params=exp_metadata, name='experiment_metadata')
274+
265275 if run_distributed and peft_config is not None :
266276 trainer .accelerator .state .fsdp_plugin .auto_wrap_policy = fsdp_auto_wrap_policy (
267277 model
@@ -286,7 +296,7 @@ def main(**kwargs):
286296 default = "pt" ,
287297 )
288298 parser .add_argument (
289- "--extra_metadata " ,
299+ "--exp_metadata " ,
290300 type = str ,
291301 default = None ,
292302 )
@@ -315,27 +325,37 @@ def main(**kwargs):
315325 else :
316326 tracker_config = None
317327
318- # Initialize the tracker early so we can calculate custom metrics like model_load_time.
319- tracker = get_tracker (tracker_name , tracker_config )
320-
321328 # Initialize callbacks
322329 file_logger_callback = FileLoggingCallback (logger )
323330 peft_saving_callback = PeftSavingCallback ()
324331 callbacks = [peft_saving_callback , file_logger_callback ]
325332
333+ # Initialize the tracker
334+ tracker = TrackerFactory .get_tracker (tracker_name , tracker_config )
326335 tracker_callback = tracker .get_hf_callback ()
327336 if tracker_callback is not None :
328337 callbacks .append (tracker_callback )
329338
# Extra experiment metadata passed by the client as a JSON string.
# `metadata` stays None unless the string parses to a JSON object (dict).
metadata = None
if additional.exp_metadata is not None:
    try:
        parsed_metadata = json.loads(additional.exp_metadata)
    except (TypeError, ValueError):
        # json.JSONDecodeError is a ValueError subclass. Catching only the
        # parse failures keeps KeyboardInterrupt/SystemExit and unrelated
        # bugs from being silently swallowed, as a bare `except:` would.
        logger.error("failed while parsing extra metadata. pass a valid json")
    else:
        if isinstance(parsed_metadata, dict):
            metadata = parsed_metadata
        else:
            # Valid JSON that is not an object (list, number, null, ...)
            # cannot be used as key/value metadata and is ignored.
            logger.warning('metadata cannot be converted to simple k:v dict ignoring')
337349
338- train (model_args , data_args , training_args , tune_config , callbacks , tracker )
350+ train (
351+ model_args = model_args ,
352+ data_args = data_args ,
353+ train_args = training_args ,
354+ peft_config = tune_config ,
355+ callbacks = callbacks ,
356+ tracker = tracker ,
357+ exp_metadata = metadata
358+ )
339359
340360if __name__ == "__main__" :
341361 fire .Fire (main )
0 commit comments