Skip to content

Commit f7e8e6b

Browse files
committed
Switch to a custom Aim callback to prevent multiple run instantiations under
FSDP. Add a tracker factory. Signed-off-by: Dushyant Behl <dushyantbehl@users.noreply.github.com>
1 parent 4f93a7c commit f7e8e6b

File tree

3 files changed

+102
-54
lines changed

3 files changed

+102
-54
lines changed

tuning/sft_trainer.py

Lines changed: 35 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Standard
22
from datetime import datetime
3-
from typing import Optional, Union, List
3+
from typing import Optional, Union, List, Dict
44
import json
55
import os, time
66

@@ -26,8 +26,7 @@
2626
from tuning.data import tokenizer_data_utils
2727
from tuning.utils.config_utils import get_hf_peft_config
2828
from tuning.utils.data_type_utils import get_torch_dtype
29-
from tuning.tracker.tracker import Tracker, get_tracker
30-
from tuning.tracker.aimstack_tracker import AimStackTracker
29+
from tuning.tracker.tracker import Tracker, TrackerFactory
3130

3231
logger = logging.get_logger("sft_trainer")
3332

@@ -92,7 +91,8 @@ def train(
9291
Union[peft_config.LoraConfig, peft_config.PromptTuningConfig]
9392
] = None,
9493
callbacks: Optional[List[TrainerCallback]] = None,
95-
tracker: Optional[Tracker] = Tracker() # default tracker is dummy tracker
94+
tracker: Optional[Tracker] = None,
95+
exp_metadata: Optional[Dict] = None
9696
):
9797
"""Call the SFTTrainer
9898
@@ -123,6 +123,7 @@ def train(
123123
train_args.fsdp_config = {"xla": False}
124124

125125
task_type = "CAUSAL_LM"
126+
additional_metrics = {}
126127

127128
model_load_time = time.time()
128129
model = AutoModelForCausalLM.from_pretrained(
@@ -131,8 +132,7 @@ def train(
131132
torch_dtype=get_torch_dtype(model_args.torch_dtype),
132133
use_flash_attention_2=model_args.use_flash_attn,
133134
)
134-
model_load_time = time.time() - model_load_time
135-
tracker.track(metric=model_load_time, name='model_load_time')
135+
additional_metrics['model_load_time'] = time.time() - model_load_time
136136

137137
peft_config = get_hf_peft_config(task_type, peft_config)
138138

@@ -262,6 +262,16 @@ def train(
262262
peft_config=peft_config,
263263
)
264264

265+
# We track additional metrics and experiment metadata after
266+
# Trainer object creation to ensure that this is not repeated
267+
# multiple times for FSDP runs.
268+
if tracker is not None:
269+
# Currently tracked only on process zero.
270+
if trainer.is_world_process_zero:
271+
for k,v in additional_metrics.items():
272+
tracker.track(metric=v, name=k, stage='additional_metrics')
273+
tracker.set_params(params=exp_metadata, name='experiment_metadata')
274+
265275
if run_distributed and peft_config is not None:
266276
trainer.accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(
267277
model
@@ -286,7 +296,7 @@ def main(**kwargs):
286296
default="pt",
287297
)
288298
parser.add_argument(
289-
"--extra_metadata",
299+
"--exp_metadata",
290300
type=str,
291301
default=None,
292302
)
@@ -315,27 +325,37 @@ def main(**kwargs):
315325
else:
316326
tracker_config=None
317327

318-
# Initialize the tracker early so we can calculate custom metrics like model_load_time.
319-
tracker = get_tracker(tracker_name, tracker_config)
320-
321328
# Initialize callbacks
322329
file_logger_callback = FileLoggingCallback(logger)
323330
peft_saving_callback = PeftSavingCallback()
324331
callbacks = [peft_saving_callback, file_logger_callback]
325332

333+
# Initialize the tracker
334+
tracker = TrackerFactory.get_tracker(tracker_name, tracker_config)
326335
tracker_callback = tracker.get_hf_callback()
327336
if tracker_callback is not None:
328337
callbacks.append(tracker_callback)
329338

330-
# track extra metadata
331-
if additional.extra_metadata is not None:
339+
# extra metadata passed via client
340+
metadata = None
341+
if additional.exp_metadata is not None:
332342
try:
333-
metadata = json.loads(additional.extra_metadata)
334-
tracker.track_metadata(metadata)
343+
metadata = json.loads(additional.exp_metadata)
344+
if metadata is None or not isinstance(metadata, Dict):
345+
logger.warning('metadata cannot be converted to simple k:v dict ignoring')
346+
metadata = None
335347
except:
336348
logger.error("failed while parsing extra metadata. pass a valid json")
337349

338-
train(model_args, data_args, training_args, tune_config, callbacks, tracker)
350+
train(
351+
model_args=model_args,
352+
data_args=data_args,
353+
train_args=training_args,
354+
peft_config=tune_config,
355+
callbacks=callbacks,
356+
tracker=tracker,
357+
exp_metadata=metadata
358+
)
339359

340360
if __name__ == "__main__":
341361
fire.Fire(main)

tuning/tracker/aimstack_tracker.py

Lines changed: 49 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,49 @@
11
# Standard
22
import os
33

4-
from tuning.tracker.tracker import Tracker
4+
from .tracker import Tracker
5+
from tuning.config.tracker_configs import AimConfig
56

67
# Third Party
78
from aim.hugging_face import AimCallback
89

10+
class CustomAimCallback(AimCallback):
    """Aim HuggingFace callback that creates the Aim run only on the
    world process zero, so FSDP runs do not instantiate multiple runs."""

    # Optional path to export the run hash generated by Aim.
    # This is used to link back to the experiments from outside aimstack.
    aim_run_hash_export_path = None

    def on_init_end(self, args, state, control, **kwargs):
        # Only rank zero is allowed to create the Aim run.
        if state and not state.is_world_process_zero:
            return

        self.setup()  # initializes the run_hash

        # Persist the run hash if an export path was configured.
        if self.aim_run_hash_export_path:
            with open(self.aim_run_hash_export_path, 'w') as export_file:
                export_file.write('{"run_hash":"' + str(self._run.hash) + '"}\n')

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        # Call setup directly to make sure hyper parameters and model
        # info are recorded.
        self.setup(args=args, state=state, model=model)

    def track_metrics(self, metric, name, context):
        # NOTE(review): _run is presumably initialized to None by the
        # AimCallback base class before setup() — confirm.
        run = self._run
        if run is not None:
            run.track(metric, name=name, context=context)

    def set_params(self, params, name):
        run = self._run
        if run is not None:
            for param_key, param_value in params.items():
                run.set((name, param_key), param_value, strict=False)
40+
941
class AimStackTracker(Tracker):
    """Tracker implementation backed by AimStack."""

    # Class-level name so TrackerFactory can match this subclass via
    # `T._name` without instantiating it (the base class only sets
    # _name per-instance in __init__, which is invisible on the class).
    _name = 'aim'

    def __init__(self, tracker_config: AimConfig):
        """Initialize with an AimConfig describing the Aim backend."""
        super().__init__(name='aim', tracker_config=tracker_config)

    def get_hf_callback(self):
        """Build, cache and return the HuggingFace callback for Aim.

        Chooses, in priority order: a remote Aim server (ip + port),
        an explicit repo path, or the default local repo.
        """
        c = self.config
        exp = c.experiment
        ip = c.aim_remote_server_ip
        # NOTE(review): the next two reads were hidden by the diff hunk
        # boundary; field names inferred from usage below — confirm
        # against tuning.config.tracker_configs.AimConfig.
        port = c.aim_remote_server_port
        repo = c.aim_repo
        hash_export_path = c.aim_run_hash_export_path

        # Bug fix: the original used independent `if` statements, so the
        # repo/default branch always overwrote the remote-server callback.
        # A single if/elif/else chain enforces the intended priority.
        if ip is not None and port is not None:
            aim_callback = CustomAimCallback(
                repo="aim://" + ip + ":" + port + "/",
                experiment=exp)
        elif repo:
            aim_callback = CustomAimCallback(repo=repo, experiment=exp)
        else:
            aim_callback = CustomAimCallback(experiment=exp)

        # The callback writes the run hash itself (on rank zero only).
        aim_callback.aim_run_hash_export_path = hash_export_path
        self.hf_callback = aim_callback
        return self.hf_callback

    def track(self, metric, name, stage='additional_metrics'):
        """Track a single metric under the given stage context.

        NOTE(review): assumes get_hf_callback() was called first so
        self.hf_callback exists — matches the call order in sft_trainer.
        """
        context = {'subset': stage}
        self.hf_callback.track_metrics(metric, name=name, context=context)

    def set_params(self, params, name='extra_params'):
        """Best-effort: attach run parameters to the Aim run."""
        try:
            self.hf_callback.set_params(params, name)
        except Exception:
            # Deliberately best-effort: tracking failures must not
            # abort training. Narrowed from a bare `except:` so
            # KeyboardInterrupt/SystemExit still propagate.
            pass

tuning/tracker/tracker.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,29 @@
11
# Generic Tracker API
22

3-
from tuning.tracker.aimstack_tracker import AimStackTracker
4-
53
class Tracker:
    """No-op base tracker; concrete backends subclass this.

    Also serves as the default "null" tracker when no backend is
    configured, so callers can invoke the tracker API unconditionally.
    """

    def __init__(self, name=None, tracker_config=None) -> None:
        # Always bind config so attribute access never raises, even for
        # the no-op tracker created without a configuration (the
        # original only assigned it when tracker_config was not None).
        self.config = tracker_config
        self._name = "None" if name is None else name

    def get_hf_callback(self):
        """Return a HuggingFace TrainerCallback, or None for no-op.

        Bug fix: the original signature was missing `self`, so calling
        tracker.get_hf_callback() on an instance raised TypeError.
        """
        return None

    def track(self, metric, name, stage):
        """Track a single metric; no-op in the base tracker."""
        pass

    # Object passed here is supposed to be a KV object
    # for the parameters to be associated with a run
    def set_params(self, params, name):
        """Attach run parameters; no-op in the base tracker."""
        pass
2222

23-
def get_tracker(tracker_name, tracker_config):
24-
if tracker_name == 'aim':
25-
if tracker_config is not None:
26-
tracker = AimStackTracker(tracker_config)
27-
else:
28-
tracker = Tracker()
29-
return tracker
23+
class TrackerFactory:
    """Creates Tracker instances by backend name."""

    @staticmethod
    def get_tracker(tracker_name, tracker_config):
        """Return the Tracker subclass registered as `tracker_name`,
        constructed with `tracker_config`; falls back to the no-op
        base Tracker when no subclass matches."""
        for subclass in Tracker.__subclasses__():
            # getattr with a default avoids AttributeError for
            # subclasses that only set _name per-instance in __init__
            # (a class attribute is required for matching to work).
            if getattr(subclass, '_name', None) == tracker_name:
                return subclass(tracker_config)
        # No registered backend matched: return the null tracker.
        return Tracker()

0 commit comments

Comments
 (0)